Full Code of tracel-ai/burn for AI

main bcaabad860c5 cached
1817 files
9.5 MB
2.6M tokens
15549 symbols
1 requests
Copy disabled (too large) Download .txt
Showing preview only (10,282K chars total). Download the full file to get everything.
Repository: tracel-ai/burn
Branch: main
Commit: bcaabad860c5
Files: 1817
Total size: 9.5 MB

Directory structure:
gitextract_82tcaglc/

├── .cargo/
│   ├── audit.toml
│   └── config.toml
├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   ├── doc_request.md
│   │   └── feature_request.md
│   ├── PULL_REQUEST_TEMPLATE/
│   │   └── template.md
│   ├── dependabot.yml
│   ├── pull_request_template.md
│   └── workflows/
│       ├── combine-dependabot-prs.yml
│       ├── dependencies.yml
│       ├── publish.yml
│       ├── stale-pr.yml
│       ├── test-gpu.yml
│       ├── test.yml
│       ├── valgrind.yml
│       └── vulnerabilities.yml
├── .gitignore
├── CITATION.cff
├── CODE-OF-CONDUCT.md
├── CONTRIBUTING.md
├── Cargo.toml
├── LICENSE-APACHE
├── LICENSE-MIT
├── NOTICES.md
├── POEM.md
├── README.md
├── _typos.toml
├── benchmarks.toml
├── burn-book/
│   ├── .gitignore
│   ├── .prettierrc.json
│   ├── book.toml
│   └── src/
│       ├── SUMMARY.md
│       ├── advanced/
│       │   ├── README.md
│       │   ├── backend-extension/
│       │   │   ├── README.md
│       │   │   ├── custom-cubecl-kernel.md
│       │   │   └── custom-wgpu-kernel.md
│       │   ├── no-std.md
│       │   └── web-assembly.md
│       ├── basic-workflow/
│       │   ├── README.md
│       │   ├── backend.md
│       │   ├── data.md
│       │   ├── inference.md
│       │   ├── model.md
│       │   └── training.md
│       ├── building-blocks/
│       │   ├── README.md
│       │   ├── autodiff.md
│       │   ├── backend.md
│       │   ├── config.md
│       │   ├── dataset.md
│       │   ├── learner.md
│       │   ├── metric.md
│       │   ├── module.md
│       │   ├── record.md
│       │   └── tensor.md
│       ├── custom-training-loop.md
│       ├── distributed-computing.md
│       ├── examples.md
│       ├── getting-started.md
│       ├── models-and-pretrained-weights.md
│       ├── motivation.md
│       ├── onnx-import.md
│       ├── overview.md
│       ├── performance/
│       │   ├── README.md
│       │   ├── distributed-computing.md
│       │   ├── good-practices/
│       │   │   ├── README.md
│       │   │   ├── asynchronous-execution.md
│       │   │   ├── kernel-fusion.md
│       │   │   └── kernel-selection.md
│       │   └── quantization.md
│       └── saving-and-loading.md
├── codecov.yml
├── contributor-book/
│   ├── .gitignore
│   ├── .prettierrc.json
│   ├── book.toml
│   └── src/
│       ├── SUMMARY.md
│       ├── frequently-encountered-issues/
│       │   ├── README.md
│       │   └── issues-while-adding-ops.md
│       ├── getting-started/
│       │   ├── README.md
│       │   ├── configuring-your-editor.md
│       │   ├── setting-up-the-environment.md
│       │   └── testing.md
│       ├── guides/
│       │   ├── README.md
│       │   ├── adding-a-new-operation-to-burn.md
│       │   └── submitting-examples.md
│       ├── how-to-read-this-book.md
│       ├── overview.md
│       └── project-architecture/
│           ├── README.md
│           ├── backend.md
│           ├── module.md
│           ├── serialization.md
│           └── tensor.md
├── crates/
│   ├── burn/
│   │   ├── Cargo.toml
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── collective.rs
│   │       └── lib.rs
│   ├── burn-autodiff/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── checkpoint/
│   │       │   ├── base.rs
│   │       │   ├── builder.rs
│   │       │   ├── mod.rs
│   │       │   ├── retro_forward.rs
│   │       │   ├── state.rs
│   │       │   └── strategy.rs
│   │       ├── grads.rs
│   │       ├── graph/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── node.rs
│   │       │   ├── requirement.rs
│   │       │   └── traversal.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── backward.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── maxmin.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── sort.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── runtime/
│   │       │   ├── client.rs
│   │       │   ├── graph.rs
│   │       │   ├── memory_management.rs
│   │       │   ├── mod.rs
│   │       │   └── server.rs
│   │       ├── tensor.rs
│   │       └── utils.rs
│   ├── burn-backend/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend/
│   │       │   ├── base.rs
│   │       │   ├── device.rs
│   │       │   ├── mod.rs
│   │       │   ├── ops/
│   │       │   │   ├── activation.rs
│   │       │   │   ├── argwhere.rs
│   │       │   │   ├── bool_tensor.rs
│   │       │   │   ├── cat.rs
│   │       │   │   ├── int_tensor.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── modules/
│   │       │   │   │   ├── attention.rs
│   │       │   │   │   ├── base.rs
│   │       │   │   │   ├── conv.rs
│   │       │   │   │   ├── grid_sample.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── pool.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   ├── qtensor.rs
│   │       │   │   ├── repeat_dim.rs
│   │       │   │   ├── sort.rs
│   │       │   │   ├── tensor.rs
│   │       │   │   └── transaction.rs
│   │       │   └── primitive.rs
│   │       ├── data/
│   │       │   ├── compare.rs
│   │       │   ├── mod.rs
│   │       │   └── tensor.rs
│   │       ├── distribution.rs
│   │       ├── element/
│   │       │   ├── base.rs
│   │       │   ├── cast.rs
│   │       │   ├── mod.rs
│   │       │   └── scalar.rs
│   │       ├── lib.rs
│   │       └── tensor/
│   │           ├── alias.rs
│   │           ├── container.rs
│   │           ├── kind.rs
│   │           ├── mod.rs
│   │           ├── ops/
│   │           │   ├── autodiff.rs
│   │           │   ├── base.rs
│   │           │   ├── bool.rs
│   │           │   ├── float.rs
│   │           │   ├── int.rs
│   │           │   ├── mod.rs
│   │           │   ├── numeric.rs
│   │           │   └── ordered.rs
│   │           └── quantization/
│   │               ├── calibration.rs
│   │               ├── mod.rs
│   │               ├── parameters.rs
│   │               └── scheme.rs
│   ├── burn-backend-tests/
│   │   ├── .cargo/
│   │   │   └── config.toml
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── src/
│   │   │   └── lib.rs
│   │   └── tests/
│   │       ├── autodiff/
│   │       │   ├── abs.rs
│   │       │   ├── adaptive_avgpool1d.rs
│   │       │   ├── adaptive_avgpool2d.rs
│   │       │   ├── add.rs
│   │       │   ├── aggregation.rs
│   │       │   ├── avgpool1d.rs
│   │       │   ├── avgpool2d.rs
│   │       │   ├── backward.rs
│   │       │   ├── bridge.rs
│   │       │   ├── broadcast.rs
│   │       │   ├── cast.rs
│   │       │   ├── cat.rs
│   │       │   ├── ceil.rs
│   │       │   ├── checkpoint.rs
│   │       │   ├── complex.rs
│   │       │   ├── conv1d.rs
│   │       │   ├── conv2d.rs
│   │       │   ├── conv3d.rs
│   │       │   ├── conv_transpose1d.rs
│   │       │   ├── conv_transpose2d.rs
│   │       │   ├── conv_transpose3d.rs
│   │       │   ├── cross.rs
│   │       │   ├── cross_entropy.rs
│   │       │   ├── cummax.rs
│   │       │   ├── cummin.rs
│   │       │   ├── cumprod.rs
│   │       │   ├── cumsum.rs
│   │       │   ├── deform_conv2d.rs
│   │       │   ├── div.rs
│   │       │   ├── erf.rs
│   │       │   ├── exp.rs
│   │       │   ├── expand.rs
│   │       │   ├── flip.rs
│   │       │   ├── floor.rs
│   │       │   ├── gather_scatter.rs
│   │       │   ├── gelu.rs
│   │       │   ├── gradients.rs
│   │       │   ├── log.rs
│   │       │   ├── log1p.rs
│   │       │   ├── log_sigmoid.rs
│   │       │   ├── mask.rs
│   │       │   ├── matmul.rs
│   │       │   ├── maxmin.rs
│   │       │   ├── maxpool1d.rs
│   │       │   ├── maxpool2d.rs
│   │       │   ├── memory_management.rs
│   │       │   ├── mod.rs
│   │       │   ├── mul.rs
│   │       │   ├── multithread.rs
│   │       │   ├── nearest_interpolate.rs
│   │       │   ├── neg.rs
│   │       │   ├── nonzero.rs
│   │       │   ├── permute.rs
│   │       │   ├── pow.rs
│   │       │   ├── recip.rs
│   │       │   ├── relu.rs
│   │       │   ├── remainder.rs
│   │       │   ├── repeat_dim.rs
│   │       │   ├── reshape.rs
│   │       │   ├── round.rs
│   │       │   ├── select.rs
│   │       │   ├── sigmoid.rs
│   │       │   ├── sign.rs
│   │       │   ├── slice.rs
│   │       │   ├── slice_assign.rs
│   │       │   ├── softmax.rs
│   │       │   ├── sort.rs
│   │       │   ├── sqrt.rs
│   │       │   ├── sub.rs
│   │       │   ├── transpose.rs
│   │       │   ├── trig.rs
│   │       │   └── unfold.rs
│   │       ├── autodiff.rs
│   │       ├── common/
│   │       │   ├── autodiff.rs
│   │       │   ├── backend.rs
│   │       │   └── tensor.rs
│   │       ├── cubecl/
│   │       │   ├── avg_pool2d.rs
│   │       │   ├── bernoulli.rs
│   │       │   ├── cast.rs
│   │       │   ├── cat.rs
│   │       │   ├── clamp.rs
│   │       │   ├── contiguous.rs
│   │       │   ├── conv2d.rs
│   │       │   ├── conv3d.rs
│   │       │   ├── conv_transpose2d.rs
│   │       │   ├── conv_transpose3d.rs
│   │       │   ├── cross.rs
│   │       │   ├── gather.rs
│   │       │   ├── mask_fill.rs
│   │       │   ├── mask_where.rs
│   │       │   ├── max_pool2d.rs
│   │       │   ├── max_pool2d_backward.rs
│   │       │   ├── mod.rs
│   │       │   ├── normal.rs
│   │       │   ├── quantization.rs
│   │       │   ├── reduce.rs
│   │       │   ├── repeat_dim.rs
│   │       │   ├── scatter.rs
│   │       │   ├── select.rs
│   │       │   ├── select_assign.rs
│   │       │   ├── slice.rs
│   │       │   ├── slice_assign.rs
│   │       │   ├── unary.rs
│   │       │   └── uniform.rs
│   │       ├── cubecl.rs
│   │       ├── fused_ops/
│   │       │   ├── mod.rs
│   │       │   └── reduce_broadcasted.rs
│   │       ├── fusion.rs
│   │       ├── tensor/
│   │       │   ├── bool/
│   │       │   │   ├── mod.rs
│   │       │   │   └── ops/
│   │       │   │       ├── all.rs
│   │       │   │       ├── any.rs
│   │       │   │       ├── argwhere_nonzero.rs
│   │       │   │       ├── cat.rs
│   │       │   │       ├── comparison.rs
│   │       │   │       ├── create_like.rs
│   │       │   │       ├── expand.rs
│   │       │   │       ├── flip.rs
│   │       │   │       ├── full.rs
│   │       │   │       ├── gather_scatter.rs
│   │       │   │       ├── init.rs
│   │       │   │       ├── logical.rs
│   │       │   │       ├── mask.rs
│   │       │   │       ├── mod.rs
│   │       │   │       ├── movedim.rs
│   │       │   │       ├── permute.rs
│   │       │   │       ├── repeat.rs
│   │       │   │       ├── repeat_dim.rs
│   │       │   │       ├── reshape.rs
│   │       │   │       ├── select.rs
│   │       │   │       ├── stack.rs
│   │       │   │       ├── take.rs
│   │       │   │       ├── transpose.rs
│   │       │   │       ├── tri_mask.rs
│   │       │   │       └── unfold.rs
│   │       │   ├── clone_invariance.rs
│   │       │   ├── float/
│   │       │   │   ├── activation/
│   │       │   │   │   ├── celu.rs
│   │       │   │   │   ├── elu.rs
│   │       │   │   │   ├── gelu.rs
│   │       │   │   │   ├── glu.rs
│   │       │   │   │   ├── hard_sigmoid.rs
│   │       │   │   │   ├── leaky_relu.rs
│   │       │   │   │   ├── log_sigmoid.rs
│   │       │   │   │   ├── mish.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── prelu.rs
│   │       │   │   │   ├── quiet_softmax.rs
│   │       │   │   │   ├── relu.rs
│   │       │   │   │   ├── selu.rs
│   │       │   │   │   ├── sigmoid.rs
│   │       │   │   │   ├── silu.rs
│   │       │   │   │   ├── softmax.rs
│   │       │   │   │   ├── softmin.rs
│   │       │   │   │   ├── softplus.rs
│   │       │   │   │   ├── softsign.rs
│   │       │   │   │   ├── tanh_activation.rs
│   │       │   │   │   └── thresholded_relu.rs
│   │       │   │   ├── grid/
│   │       │   │   │   ├── affine_grid.rs
│   │       │   │   │   ├── meshgrid.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   ├── linalg/
│   │       │   │   │   ├── cosine_similarity.rs
│   │       │   │   │   ├── diag.rs
│   │       │   │   │   ├── lu_decomposition.rs
│   │       │   │   │   ├── matvec.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── outer.rs
│   │       │   │   │   ├── trace.rs
│   │       │   │   │   └── vector_norm.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── module/
│   │       │   │   │   ├── adaptive_avgpool1d.rs
│   │       │   │   │   ├── adaptive_avgpool2d.rs
│   │       │   │   │   ├── attention.rs
│   │       │   │   │   ├── avgpool1d.rs
│   │       │   │   │   ├── avgpool2d.rs
│   │       │   │   │   ├── bicubic_interpolate.rs
│   │       │   │   │   ├── bilinear_interpolate.rs
│   │       │   │   │   ├── conv1d.rs
│   │       │   │   │   ├── conv2d.rs
│   │       │   │   │   ├── conv3d.rs
│   │       │   │   │   ├── conv_transpose1d.rs
│   │       │   │   │   ├── conv_transpose2d.rs
│   │       │   │   │   ├── conv_transpose3d.rs
│   │       │   │   │   ├── deform_conv2d.rs
│   │       │   │   │   ├── forward.rs
│   │       │   │   │   ├── lanczos3_interpolate.rs
│   │       │   │   │   ├── linear.rs
│   │       │   │   │   ├── maxpool1d.rs
│   │       │   │   │   ├── maxpool2d.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── nearest_interpolate.rs
│   │       │   │   │   └── unfold4d.rs
│   │       │   │   ├── ops/
│   │       │   │   │   ├── abs.rs
│   │       │   │   │   ├── add.rs
│   │       │   │   │   ├── aggregation.rs
│   │       │   │   │   ├── all.rs
│   │       │   │   │   ├── any.rs
│   │       │   │   │   ├── arg.rs
│   │       │   │   │   ├── cast.rs
│   │       │   │   │   ├── cat.rs
│   │       │   │   │   ├── ceil.rs
│   │       │   │   │   ├── chunk.rs
│   │       │   │   │   ├── clamp.rs
│   │       │   │   │   ├── close.rs
│   │       │   │   │   ├── comparison.rs
│   │       │   │   │   ├── create_like.rs
│   │       │   │   │   ├── cross.rs
│   │       │   │   │   ├── cumulative.rs
│   │       │   │   │   ├── div.rs
│   │       │   │   │   ├── dot.rs
│   │       │   │   │   ├── erf.rs
│   │       │   │   │   ├── exp.rs
│   │       │   │   │   ├── expand.rs
│   │       │   │   │   ├── finite.rs
│   │       │   │   │   ├── flatten.rs
│   │       │   │   │   ├── flip.rs
│   │       │   │   │   ├── floor.rs
│   │       │   │   │   ├── fmod.rs
│   │       │   │   │   ├── full.rs
│   │       │   │   │   ├── gather_scatter.rs
│   │       │   │   │   ├── grid_sample.rs
│   │       │   │   │   ├── inf.rs
│   │       │   │   │   ├── init.rs
│   │       │   │   │   ├── iter_dim.rs
│   │       │   │   │   ├── log.rs
│   │       │   │   │   ├── log1p.rs
│   │       │   │   │   ├── mask.rs
│   │       │   │   │   ├── matmul.rs
│   │       │   │   │   ├── maxmin.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── movedim.rs
│   │       │   │   │   ├── mul.rs
│   │       │   │   │   ├── nan.rs
│   │       │   │   │   ├── narrow.rs
│   │       │   │   │   ├── neg.rs
│   │       │   │   │   ├── one_hot.rs
│   │       │   │   │   ├── padding.rs
│   │       │   │   │   ├── permute.rs
│   │       │   │   │   ├── powf.rs
│   │       │   │   │   ├── powf_scalar.rs
│   │       │   │   │   ├── prod.rs
│   │       │   │   │   ├── random.rs
│   │       │   │   │   ├── recip.rs
│   │       │   │   │   ├── remainder.rs
│   │       │   │   │   ├── repeat.rs
│   │       │   │   │   ├── repeat_dim.rs
│   │       │   │   │   ├── reshape.rs
│   │       │   │   │   ├── round.rs
│   │       │   │   │   ├── select.rs
│   │       │   │   │   ├── sign.rs
│   │       │   │   │   ├── slice.rs
│   │       │   │   │   ├── slice_assign.rs
│   │       │   │   │   ├── sort_argsort.rs
│   │       │   │   │   ├── split.rs
│   │       │   │   │   ├── sqrt.rs
│   │       │   │   │   ├── square.rs
│   │       │   │   │   ├── squeeze.rs
│   │       │   │   │   ├── stack.rs
│   │       │   │   │   ├── sub.rs
│   │       │   │   │   ├── take.rs
│   │       │   │   │   ├── topk.rs
│   │       │   │   │   ├── transaction.rs
│   │       │   │   │   ├── transpose.rs
│   │       │   │   │   ├── tri.rs
│   │       │   │   │   ├── trig.rs
│   │       │   │   │   ├── trunc.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   ├── primitive.rs
│   │       │   │   ├── quantization/
│   │       │   │   │   ├── calibration.rs
│   │       │   │   │   ├── data.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── ops/
│   │       │   │   │   │   ├── extended/
│   │       │   │   │   │   │   ├── abs.rs
│   │       │   │   │   │   │   ├── add.rs
│   │       │   │   │   │   │   ├── aggregation.rs
│   │       │   │   │   │   │   ├── all.rs
│   │       │   │   │   │   │   ├── any.rs
│   │       │   │   │   │   │   ├── arg.rs
│   │       │   │   │   │   │   ├── cat.rs
│   │       │   │   │   │   │   ├── ceil.rs
│   │       │   │   │   │   │   ├── chunk.rs
│   │       │   │   │   │   │   ├── clamp.rs
│   │       │   │   │   │   │   ├── cos.rs
│   │       │   │   │   │   │   ├── cosh.rs
│   │       │   │   │   │   │   ├── div.rs
│   │       │   │   │   │   │   ├── erf.rs
│   │       │   │   │   │   │   ├── exp.rs
│   │       │   │   │   │   │   ├── expand.rs
│   │       │   │   │   │   │   ├── flip.rs
│   │       │   │   │   │   │   ├── floor.rs
│   │       │   │   │   │   │   ├── gather_scatter.rs
│   │       │   │   │   │   │   ├── log.rs
│   │       │   │   │   │   │   ├── log1p.rs
│   │       │   │   │   │   │   ├── map_comparison.rs
│   │       │   │   │   │   │   ├── mask.rs
│   │       │   │   │   │   │   ├── maxmin.rs
│   │       │   │   │   │   │   ├── mod.rs
│   │       │   │   │   │   │   ├── mul.rs
│   │       │   │   │   │   │   ├── narrow.rs
│   │       │   │   │   │   │   ├── neg.rs
│   │       │   │   │   │   │   ├── permute.rs
│   │       │   │   │   │   │   ├── powf.rs
│   │       │   │   │   │   │   ├── powf_scalar.rs
│   │       │   │   │   │   │   ├── recip.rs
│   │       │   │   │   │   │   ├── remainder.rs
│   │       │   │   │   │   │   ├── repeat_dim.rs
│   │       │   │   │   │   │   ├── reshape.rs
│   │       │   │   │   │   │   ├── round.rs
│   │       │   │   │   │   │   ├── select.rs
│   │       │   │   │   │   │   ├── sin.rs
│   │       │   │   │   │   │   ├── sinh.rs
│   │       │   │   │   │   │   ├── slice.rs
│   │       │   │   │   │   │   ├── sort_argsort.rs
│   │       │   │   │   │   │   ├── split.rs
│   │       │   │   │   │   │   ├── sqrt.rs
│   │       │   │   │   │   │   ├── stack.rs
│   │       │   │   │   │   │   ├── sub.rs
│   │       │   │   │   │   │   ├── tan.rs
│   │       │   │   │   │   │   ├── tanh.rs
│   │       │   │   │   │   │   ├── topk.rs
│   │       │   │   │   │   │   └── transpose.rs
│   │       │   │   │   │   ├── matmul.rs
│   │       │   │   │   │   ├── mod.rs
│   │       │   │   │   │   └── quantize.rs
│   │       │   │   │   └── scheme.rs
│   │       │   │   └── stats/
│   │       │   │       ├── cov.rs
│   │       │   │       ├── display.rs
│   │       │   │       ├── eye.rs
│   │       │   │       ├── median.rs
│   │       │   │       ├── mod.rs
│   │       │   │       └── var.rs
│   │       │   ├── int/
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ops/
│   │       │   │   │   ├── abs.rs
│   │       │   │   │   ├── add.rs
│   │       │   │   │   ├── aggregation.rs
│   │       │   │   │   ├── all.rs
│   │       │   │   │   ├── any.rs
│   │       │   │   │   ├── arange.rs
│   │       │   │   │   ├── arange_step.rs
│   │       │   │   │   ├── arg.rs
│   │       │   │   │   ├── bitwise.rs
│   │       │   │   │   ├── cartesian_grid.rs
│   │       │   │   │   ├── cast.rs
│   │       │   │   │   ├── cat.rs
│   │       │   │   │   ├── chunk.rs
│   │       │   │   │   ├── comparison.rs
│   │       │   │   │   ├── create_like.rs
│   │       │   │   │   ├── cumulative.rs
│   │       │   │   │   ├── div.rs
│   │       │   │   │   ├── expand.rs
│   │       │   │   │   ├── flip.rs
│   │       │   │   │   ├── full.rs
│   │       │   │   │   ├── gather_scatter.rs
│   │       │   │   │   ├── init.rs
│   │       │   │   │   ├── mask.rs
│   │       │   │   │   ├── matmul.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── movedim.rs
│   │       │   │   │   ├── mul.rs
│   │       │   │   │   ├── one_hot.rs
│   │       │   │   │   ├── permute.rs
│   │       │   │   │   ├── random.rs
│   │       │   │   │   ├── remainder.rs
│   │       │   │   │   ├── repeat.rs
│   │       │   │   │   ├── repeat_dim.rs
│   │       │   │   │   ├── reshape.rs
│   │       │   │   │   ├── roll.rs
│   │       │   │   │   ├── select.rs
│   │       │   │   │   ├── sign.rs
│   │       │   │   │   ├── slice.rs
│   │       │   │   │   ├── slice_assign.rs
│   │       │   │   │   ├── sort_argsort.rs
│   │       │   │   │   ├── stack.rs
│   │       │   │   │   ├── sub.rs
│   │       │   │   │   ├── take.rs
│   │       │   │   │   ├── topk.rs
│   │       │   │   │   ├── transpose.rs
│   │       │   │   │   ├── tri.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   └── primitive.rs
│   │       │   ├── mod.rs
│   │       │   └── multi_threads.rs
│   │       └── tensor.rs
│   ├── burn-candle/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── candle_utils.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── utils.rs
│   │       └── tensor.rs
│   ├── burn-collective/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── multinode-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── README.md
│   │   │   └── src/
│   │   │       ├── bin/
│   │   │       │   ├── global.rs
│   │   │       │   ├── node.rs
│   │   │       │   └── test_launcher.rs
│   │   │       ├── lib.rs
│   │   │       └── shared.rs
│   │   └── src/
│   │       ├── api.rs
│   │       ├── config.rs
│   │       ├── global/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── node/
│   │       │   │   ├── base.rs
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ring.rs
│   │       │   │   ├── sync.rs
│   │       │   │   ├── tree.rs
│   │       │   │   └── worker.rs
│   │       │   ├── orchestrator/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── state.rs
│   │       │   └── shared.rs
│   │       ├── lib.rs
│   │       ├── local/
│   │       │   ├── all_reduce/
│   │       │   │   ├── base.rs
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   ├── ring.rs
│   │       │   │   └── tree.rs
│   │       │   ├── broadcast/
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   └── tree.rs
│   │       │   ├── client.rs
│   │       │   ├── mod.rs
│   │       │   ├── reduce/
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   └── tree.rs
│   │       │   ├── server.rs
│   │       │   └── tensor_map.rs
│   │       └── tests/
│   │           ├── all_reduce.rs
│   │           ├── broadcast.rs
│   │           ├── mod.rs
│   │           └── reduce.rs
│   ├── burn-communication/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── base.rs
│   │       ├── data_service.rs
│   │       ├── lib.rs
│   │       ├── util.rs
│   │       └── websocket/
│   │           ├── base.rs
│   │           ├── client.rs
│   │           ├── mod.rs
│   │           └── server.rs
│   ├── burn-core/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── config.rs
│   │   │   ├── data/
│   │   │   │   ├── dataloader/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── batch.rs
│   │   │   │   │   ├── batcher.rs
│   │   │   │   │   ├── builder.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── multithread.rs
│   │   │   │   │   ├── split.rs
│   │   │   │   │   └── strategy.rs
│   │   │   │   └── mod.rs
│   │   │   ├── lib.rs
│   │   │   ├── module/
│   │   │   │   ├── base.rs
│   │   │   │   ├── display.rs
│   │   │   │   ├── initializer.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── param/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── constant.rs
│   │   │   │   │   ├── id.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── primitive.rs
│   │   │   │   │   ├── running.rs
│   │   │   │   │   ├── tensor.rs
│   │   │   │   │   └── visitor.rs
│   │   │   │   ├── quantize.rs
│   │   │   │   └── reinit.rs
│   │   │   ├── record/
│   │   │   │   ├── base.rs
│   │   │   │   ├── file.rs
│   │   │   │   ├── memory.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── primitive.rs
│   │   │   │   ├── recorder.rs
│   │   │   │   ├── serde/
│   │   │   │   │   ├── adapter.rs
│   │   │   │   │   ├── data.rs
│   │   │   │   │   ├── de.rs
│   │   │   │   │   ├── error.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── ser.rs
│   │   │   │   ├── settings.rs
│   │   │   │   └── tensor.rs
│   │   │   ├── tensor.rs
│   │   │   └── vision.rs
│   │   └── tests/
│   │       ├── test_derive_config.rs
│   │       ├── test_derive_module.rs
│   │       ├── test_derive_record.rs
│   │       └── test_record_resilience.rs
│   ├── burn-cpu/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-cubecl/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── fusion.rs
│   │       ├── kernel/
│   │       │   ├── attention/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune.rs
│   │       │   ├── binary.rs
│   │       │   ├── binary_float.rs
│   │       │   ├── binary_int.rs
│   │       │   ├── cast/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bool_cast.rs
│   │       │   │   └── mod.rs
│   │       │   ├── clamp.rs
│   │       │   ├── comparison.rs
│   │       │   ├── contiguous.rs
│   │       │   ├── conv/
│   │       │   │   ├── backward_data/
│   │       │   │   │   ├── fallback.rs
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── backward_weight/
│   │       │   │   │   ├── fallback.rs
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── conv_transpose2d/
│   │       │   │   │   ├── base.rs
│   │       │   │   │   ├── col2im.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── transpose_direct.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── conv_transpose3d.rs
│   │       │   │   ├── deform_conv2d.rs
│   │       │   │   ├── deform_conv_transpose2d.rs
│   │       │   │   ├── direct.rs
│   │       │   │   ├── forward/
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── im2col.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune_key.rs
│   │       │   ├── cross.rs
│   │       │   ├── grid_sample/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bilinear.rs
│   │       │   │   └── mod.rs
│   │       │   ├── index/
│   │       │   │   ├── flip.rs
│   │       │   │   ├── gather.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── repeat_dim.rs
│   │       │   │   ├── scatter.rs
│   │       │   │   ├── select.rs
│   │       │   │   ├── select_assign.rs
│   │       │   │   ├── slice.rs
│   │       │   │   └── slice_assign.rs
│   │       │   ├── interpolate/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bicubic.rs
│   │       │   │   ├── bilinear.rs
│   │       │   │   ├── lanczos3.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── nearest.rs
│   │       │   │   └── nearest_backward.rs
│   │       │   ├── mask/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mask_fill.rs
│   │       │   │   ├── mask_where.rs
│   │       │   │   └── mod.rs
│   │       │   ├── matmul/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── tune/
│   │       │   │   │   ├── base.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   └── utils.rs
│   │       │   ├── mod.rs
│   │       │   ├── pool/
│   │       │   │   ├── adaptive_avg_pool2d.rs
│   │       │   │   ├── adaptive_avg_pool2d_backward.rs
│   │       │   │   ├── avg_pool2d.rs
│   │       │   │   ├── avg_pool2d_backward.rs
│   │       │   │   ├── max_pool2d.rs
│   │       │   │   ├── max_pool2d_backward.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── pool2d.rs
│   │       │   ├── prng/
│   │       │   │   ├── bernoulli.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── normal.rs
│   │       │   │   └── uniform.rs
│   │       │   ├── quantization/
│   │       │   │   ├── dequantize.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── quantize.rs
│   │       │   ├── reduce/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune.rs
│   │       │   ├── unary_float.rs
│   │       │   ├── unary_int.rs
│   │       │   ├── unary_numeric.rs
│   │       │   └── utils.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── numeric.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── template/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   └── source.rs
│   │       ├── tensor/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   └── quantization.rs
│   │       └── tune_key.rs
│   ├── burn-cubecl-fusion/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── base.rs
│   │       ├── engine/
│   │       │   ├── codegen/
│   │       │   │   ├── base.rs
│   │       │   │   ├── io.rs
│   │       │   │   ├── ir.rs
│   │       │   │   ├── kernel.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── tensor.rs
│   │       │   │   └── view.rs
│   │       │   ├── fuser.rs
│   │       │   ├── launch/
│   │       │   │   ├── base.rs
│   │       │   │   ├── executor.rs
│   │       │   │   ├── input.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── output.rs
│   │       │   │   ├── plan.rs
│   │       │   │   ├── runner.rs
│   │       │   │   └── vectorization/
│   │       │   │       ├── base.rs
│   │       │   │       ├── mod.rs
│   │       │   │       └── planner.rs
│   │       │   ├── mod.rs
│   │       │   ├── scoring.rs
│   │       │   ├── settings.rs
│   │       │   └── trace/
│   │       │       ├── base.rs
│   │       │       ├── block.rs
│   │       │       ├── fuser.rs
│   │       │       └── mod.rs
│   │       ├── lib.rs
│   │       ├── optim/
│   │       │   ├── base.rs
│   │       │   ├── elemwise/
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── optimization.rs
│   │       │   ├── matmul/
│   │       │   │   ├── args.rs
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── optimization.rs
│   │       │   │   └── tune.rs
│   │       │   ├── mod.rs
│   │       │   ├── reduce/
│   │       │   │   ├── args.rs
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── optimization.rs
│   │       │   │   └── tune.rs
│   │       │   └── reduce_broadcasted/
│   │       │       ├── fuser/
│   │       │       │   ├── base.rs
│   │       │       │   ├── block.rs
│   │       │       │   ├── full.rs
│   │       │       │   ├── full_analyzer.rs
│   │       │       │   └── mod.rs
│   │       │       ├── launch.rs
│   │       │       ├── mod.rs
│   │       │       ├── optimization.rs
│   │       │       ├── tune.rs
│   │       │       └── unit.rs
│   │       └── tune.rs
│   ├── burn-cuda/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-dataset/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── hf_dataset.rs
│   │   │   └── speech_commands.rs
│   │   ├── src/
│   │   │   ├── audio/
│   │   │   │   ├── mod.rs
│   │   │   │   └── speech_commands.rs
│   │   │   ├── dataset/
│   │   │   │   ├── base.rs
│   │   │   │   ├── dataframe.rs
│   │   │   │   ├── fake.rs
│   │   │   │   ├── in_memory.rs
│   │   │   │   ├── iterator.rs
│   │   │   │   ├── mod.rs
│   │   │   │   └── sqlite.rs
│   │   │   ├── lib.rs
│   │   │   ├── nlp/
│   │   │   │   ├── ag_news.rs
│   │   │   │   ├── mod.rs
│   │   │   │   └── text_folder.rs
│   │   │   ├── source/
│   │   │   │   ├── huggingface/
│   │   │   │   │   ├── downloader.rs
│   │   │   │   │   ├── importer.py
│   │   │   │   │   └── mod.rs
│   │   │   │   └── mod.rs
│   │   │   ├── transform/
│   │   │   │   ├── composed.rs
│   │   │   │   ├── mapper.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── options.rs
│   │   │   │   ├── partial.rs
│   │   │   │   ├── sampler.rs
│   │   │   │   ├── selection.rs
│   │   │   │   ├── shuffle.rs
│   │   │   │   └── window.rs
│   │   │   └── vision/
│   │   │       ├── cifar.rs
│   │   │       ├── image_folder.rs
│   │   │       ├── mnist.rs
│   │   │       └── mod.rs
│   │   └── tests/
│   │       └── data/
│   │           ├── dataset-fmt.csv
│   │           ├── dataset.csv
│   │           ├── dataset.json
│   │           ├── dataset_coco.json
│   │           ├── segmask_folder/
│   │           │   └── annotations/
│   │           │       ├── mask_checkerboard.txt
│   │           │       ├── mask_random_2colors.txt
│   │           │       └── mask_random_3colors.txt
│   │           └── text_folder/
│   │               ├── negative/
│   │               │   ├── sample1.txt
│   │               │   └── sample2.txt
│   │               └── positive/
│   │                   ├── sample1.txt
│   │                   └── sample2.txt
│   ├── burn-derive/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── config/
│   │       │   ├── analyzer.rs
│   │       │   ├── analyzer_enum.rs
│   │       │   ├── analyzer_struct.rs
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── module/
│   │       │   ├── base.rs
│   │       │   ├── codegen.rs
│   │       │   ├── codegen_enum.rs
│   │       │   ├── codegen_struct.rs
│   │       │   ├── display.rs
│   │       │   ├── generics.rs
│   │       │   ├── mod.rs
│   │       │   ├── record.rs
│   │       │   ├── record_enum.rs
│   │       │   └── record_struct.rs
│   │       ├── record/
│   │       │   ├── base.rs
│   │       │   ├── codegen.rs
│   │       │   ├── item/
│   │       │   │   ├── codegen.rs
│   │       │   │   ├── codegen_enum.rs
│   │       │   │   ├── codegen_struct.rs
│   │       │   │   └── mod.rs
│   │       │   └── mod.rs
│   │       └── shared/
│   │           ├── attribute.rs
│   │           ├── enum_variant.rs
│   │           ├── field.rs
│   │           ├── generics.rs
│   │           └── mod.rs
│   ├── burn-dispatch/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── device.rs
│   │       ├── lib.rs
│   │       ├── macros.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       └── tensor.rs
│   ├── burn-fusion/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── client.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── binary.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── unary.rs
│   │       ├── search/
│   │       │   ├── block.rs
│   │       │   ├── merging.rs
│   │       │   ├── mod.rs
│   │       │   └── optimization/
│   │       │       ├── blocks.rs
│   │       │       ├── mod.rs
│   │       │       └── stream.rs
│   │       ├── server.rs
│   │       ├── stream/
│   │       │   ├── base.rs
│   │       │   ├── context.rs
│   │       │   ├── execution/
│   │       │   │   ├── base.rs
│   │       │   │   ├── explorer.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ordering.rs
│   │       │   │   ├── policy.rs
│   │       │   │   ├── processor.rs
│   │       │   │   ├── tests.rs
│   │       │   │   └── validator.rs
│   │       │   ├── memory_checks.rs
│   │       │   ├── mod.rs
│   │       │   ├── multi.rs
│   │       │   ├── queue/
│   │       │   │   ├── base.rs
│   │       │   │   ├── execution.rs
│   │       │   │   └── mod.rs
│   │       │   ├── shared_tensors.rs
│   │       │   └── store/
│   │       │       ├── base.rs
│   │       │       ├── index.rs
│   │       │       └── mod.rs
│   │       └── tensor.rs
│   ├── burn-ir/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── builder.rs
│   │       ├── handle.rs
│   │       ├── lib.rs
│   │       ├── operation.rs
│   │       ├── scalar.rs
│   │       └── tensor.rs
│   ├── burn-ndarray/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── adaptive_avgpool.rs
│   │       │   ├── avgpool.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── conv.rs
│   │       │   ├── deform_conv.rs
│   │       │   ├── grid_sample.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── interpolate.rs
│   │       │   ├── macros.rs
│   │       │   ├── matmul.rs
│   │       │   ├── maxpool.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── padding.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── quantization.rs
│   │       │   ├── simd/
│   │       │   │   ├── avgpool.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── binary.rs
│   │       │   │   ├── binary_elemwise.rs
│   │       │   │   ├── cmp.rs
│   │       │   │   ├── conv.rs
│   │       │   │   ├── maxpool.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── unary.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── parallel.rs
│   │       ├── rand.rs
│   │       ├── sharing.rs
│   │       ├── storage.rs
│   │       └── tensor.rs
│   ├── burn-nn/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── activation/
│   │   │   │   ├── activation_wrapper.rs
│   │   │   │   ├── celu.rs
│   │   │   │   ├── elu.rs
│   │   │   │   ├── gelu.rs
│   │   │   │   ├── glu.rs
│   │   │   │   ├── hard_shrink.rs
│   │   │   │   ├── hard_sigmoid.rs
│   │   │   │   ├── hard_swish.rs
│   │   │   │   ├── leaky_relu.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── prelu.rs
│   │   │   │   ├── relu.rs
│   │   │   │   ├── selu.rs
│   │   │   │   ├── shrink.rs
│   │   │   │   ├── sigmoid.rs
│   │   │   │   ├── soft_shrink.rs
│   │   │   │   ├── softplus.rs
│   │   │   │   ├── softsign.rs
│   │   │   │   ├── swiglu.rs
│   │   │   │   ├── tanh.rs
│   │   │   │   └── thresholded_relu.rs
│   │   │   ├── lib.rs
│   │   │   ├── loss/
│   │   │   │   ├── binary_cross_entropy.rs
│   │   │   │   ├── cosine_embedding.rs
│   │   │   │   ├── cross_entropy.rs
│   │   │   │   ├── ctc.rs
│   │   │   │   ├── huber.rs
│   │   │   │   ├── kldiv.rs
│   │   │   │   ├── lp_loss.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── mse.rs
│   │   │   │   ├── poisson.rs
│   │   │   │   ├── pretrained/
│   │   │   │   │   ├── gram_matrix/
│   │   │   │   │   │   ├── gram_matrix_loss.rs
│   │   │   │   │   │   ├── mod.rs
│   │   │   │   │   │   ├── vgg19.rs
│   │   │   │   │   │   └── weights.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── reduction.rs
│   │   │   │   ├── rnnt.rs
│   │   │   │   └── smooth_l1.rs
│   │   │   ├── modules/
│   │   │   │   ├── attention/
│   │   │   │   │   ├── cross_attention.rs
│   │   │   │   │   ├── mask.rs
│   │   │   │   │   ├── mha.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── cache/
│   │   │   │   │   ├── autoregressive.rs
│   │   │   │   │   ├── base.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── conv/
│   │   │   │   │   ├── checks.rs
│   │   │   │   │   ├── conv1d.rs
│   │   │   │   │   ├── conv2d.rs
│   │   │   │   │   ├── conv3d.rs
│   │   │   │   │   ├── conv_transpose1d.rs
│   │   │   │   │   ├── conv_transpose2d.rs
│   │   │   │   │   ├── conv_transpose3d.rs
│   │   │   │   │   ├── deform_conv2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── dropout.rs
│   │   │   │   ├── embedding.rs
│   │   │   │   ├── interpolate/
│   │   │   │   │   ├── interpolate1d.rs
│   │   │   │   │   ├── interpolate2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── linear.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── noise.rs
│   │   │   │   ├── norm/
│   │   │   │   │   ├── batch.rs
│   │   │   │   │   ├── group.rs
│   │   │   │   │   ├── instance.rs
│   │   │   │   │   ├── layer.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── normalization_wrapper.rs
│   │   │   │   │   └── rms.rs
│   │   │   │   ├── pool/
│   │   │   │   │   ├── adaptive_avg_pool1d.rs
│   │   │   │   │   ├── adaptive_avg_pool2d.rs
│   │   │   │   │   ├── avg_pool1d.rs
│   │   │   │   │   ├── avg_pool2d.rs
│   │   │   │   │   ├── max_pool1d.rs
│   │   │   │   │   ├── max_pool2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── pos_encoding.rs
│   │   │   │   ├── rnn/
│   │   │   │   │   ├── basic.rs
│   │   │   │   │   ├── gate_controller.rs
│   │   │   │   │   ├── gru.rs
│   │   │   │   │   ├── lstm.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── rope_encoding.rs
│   │   │   │   ├── transformer/
│   │   │   │   │   ├── decoder.rs
│   │   │   │   │   ├── encoder.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── pwff.rs
│   │   │   │   └── unfold.rs
│   │   │   └── padding.rs
│   │   └── tests/
│   │       └── quantize.rs
│   ├── burn-no-std-tests/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── burnpack.rs
│   │   │   ├── conv.rs
│   │   │   ├── lib.rs
│   │   │   ├── mlp.rs
│   │   │   ├── model.rs
│   │   │   └── safetensors.rs
│   │   └── tests/
│   │       ├── burnpack_tests.rs
│   │       ├── safetensors_tests.rs
│   │       └── test_integration.rs
│   ├── burn-optim/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── grad_clipping/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── lr_scheduler/
│   │       │   ├── base.rs
│   │       │   ├── composed.rs
│   │       │   ├── constant.rs
│   │       │   ├── cosine.rs
│   │       │   ├── exponential.rs
│   │       │   ├── linear.rs
│   │       │   ├── mod.rs
│   │       │   ├── noam.rs
│   │       │   └── step.rs
│   │       └── optim/
│   │           ├── adagrad.rs
│   │           ├── adam.rs
│   │           ├── adamw.rs
│   │           ├── base.rs
│   │           ├── decay.rs
│   │           ├── grad_accum.rs
│   │           ├── grads.rs
│   │           ├── lbfgs.rs
│   │           ├── mod.rs
│   │           ├── momentum.rs
│   │           ├── muon.rs
│   │           ├── rmsprop.rs
│   │           ├── sgd.rs
│   │           ├── simple/
│   │           │   ├── adaptor.rs
│   │           │   ├── base.rs
│   │           │   ├── mod.rs
│   │           │   └── record/
│   │           │       ├── base.rs
│   │           │       ├── mod.rs
│   │           │       └── v1.rs
│   │           └── visitor.rs
│   ├── burn-remote/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── client/
│   │       │   ├── base.rs
│   │       │   ├── channel.rs
│   │       │   ├── mod.rs
│   │       │   ├── runner.rs
│   │       │   └── worker.rs
│   │       ├── lib.rs
│   │       ├── server/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── processor.rs
│   │       │   ├── session.rs
│   │       │   └── stream.rs
│   │       └── shared/
│   │           ├── mod.rs
│   │           └── task.rs
│   ├── burn-rl/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── environment/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── policy/
│   │       │   ├── async_policy.rs
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       └── transition_buffer/
│   │           ├── base.rs
│   │           ├── mod.rs
│   │           └── slice_access.rs
│   ├── burn-rocm/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-router/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── bridge/
│   │       │   ├── base.rs
│   │       │   ├── byte.rs
│   │       │   └── mod.rs
│   │       ├── channel/
│   │       │   ├── base.rs
│   │       │   ├── direct.rs
│   │       │   └── mod.rs
│   │       ├── client/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── binary.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── unary.rs
│   │       ├── runner.rs
│   │       ├── tensor.rs
│   │       └── types.rs
│   ├── burn-std/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── id.rs
│   │       ├── lib.rs
│   │       ├── network.rs
│   │       └── tensor/
│   │           ├── dtype.rs
│   │           ├── mod.rs
│   │           ├── quantization.rs
│   │           ├── shape.rs
│   │           └── slice.rs
│   ├── burn-store/
│   │   ├── Cargo.toml
│   │   ├── MIGRATION.md
│   │   ├── README.md
│   │   ├── benches/
│   │   │   ├── download_resnet18.py
│   │   │   ├── generate_unified_models.py
│   │   │   ├── resnet18_loading.rs
│   │   │   ├── unified_loading.rs
│   │   │   ├── unified_saving.rs
│   │   │   └── zero_copy_loading.rs
│   │   ├── examples/
│   │   │   ├── burnpack_inspect.rs
│   │   │   └── half_precision.rs
│   │   ├── pytorch-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── src/
│   │   │   │   └── lib.rs
│   │   │   └── tests/
│   │   │       ├── backend.rs
│   │   │       ├── batch_norm/
│   │   │       │   ├── batch_norm2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── boolean/
│   │   │       │   ├── boolean.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── buffer/
│   │   │       │   ├── buffer.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── complex_nested/
│   │   │       │   ├── complex_nested.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── config/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── mod.rs
│   │   │       │   └── weights_with_config.pt
│   │   │       ├── conv1d/
│   │   │       │   ├── conv1d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv2d/
│   │   │       │   ├── conv2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv_transpose1d/
│   │   │       │   ├── conv_transpose1d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv_transpose2d/
│   │   │       │   ├── conv_transpose2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── debug_test.pt
│   │   │       ├── embedding/
│   │   │       │   ├── embedding.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── enum_module/
│   │   │       │   ├── enum_depthwise_false.pt
│   │   │       │   ├── enum_depthwise_true.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── group_norm/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── group_norm.pt
│   │   │       │   └── mod.rs
│   │   │       ├── integer/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── integer.pt
│   │   │       │   └── mod.rs
│   │   │       ├── key_remap/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── key_remap.pt
│   │   │       │   └── mod.rs
│   │   │       ├── key_remap_chained/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── key_remap.pt
│   │   │       │   └── mod.rs
│   │   │       ├── layer_norm/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── layer_norm.pt
│   │   │       │   └── mod.rs
│   │   │       ├── linear/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── linear.pt
│   │   │       │   ├── linear_with_bias.pt
│   │   │       │   └── mod.rs
│   │   │       ├── missing_module_field/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── missing_module_field.pt
│   │   │       │   └── mod.rs
│   │   │       ├── non_contiguous_indexes/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── mod.rs
│   │   │       │   └── non_contiguous_indexes.pt
│   │   │       ├── test_int.pt
│   │   │       ├── test_mod.rs
│   │   │       └── top_level_key/
│   │   │           ├── export_weights.py
│   │   │           ├── mod.rs
│   │   │           └── top_level_key.pt
│   │   ├── safetensors-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── src/
│   │   │   │   └── lib.rs
│   │   │   └── tests/
│   │   │       ├── backend.rs
│   │   │       ├── multi_layer/
│   │   │       │   ├── mod.rs
│   │   │       │   ├── multi_layer.py
│   │   │       │   └── multi_layer.safetensors
│   │   │       └── test_mod.rs
│   │   └── src/
│   │       ├── adapter.rs
│   │       ├── applier.rs
│   │       ├── apply_result.rs
│   │       ├── burnpack/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── reader.rs
│   │       │   ├── store.rs
│   │       │   ├── tests/
│   │       │   │   ├── alignment.rs
│   │       │   │   ├── edge_cases.rs
│   │       │   │   ├── header.rs
│   │       │   │   ├── helpers.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── reader.rs
│   │       │   │   ├── round_trip.rs
│   │       │   │   ├── store.rs
│   │       │   │   ├── writer.rs
│   │       │   │   └── zero_copy.rs
│   │       │   └── writer.rs
│   │       ├── collector.rs
│   │       ├── filter.rs
│   │       ├── keyremapper.rs
│   │       ├── lib.rs
│   │       ├── pytorch/
│   │       │   ├── lazy_data.rs
│   │       │   ├── mod.rs
│   │       │   ├── pickle_reader.rs
│   │       │   ├── reader.rs
│   │       │   ├── store.rs
│   │       │   └── tests/
│   │       │       ├── mod.rs
│   │       │       ├── reader/
│   │       │       │   ├── create_legacy_with_offsets.py
│   │       │       │   ├── create_tar_format.py
│   │       │       │   ├── mod.rs
│   │       │       │   ├── simple_legacy.py
│   │       │       │   ├── test_data/
│   │       │       │   │   ├── bfloat16.pt
│   │       │       │   │   ├── bool.pt
│   │       │       │   │   ├── broken.pt
│   │       │       │   │   ├── buffers.pt
│   │       │       │   │   ├── checkpoint.pt
│   │       │       │   │   ├── complex_structure.pt
│   │       │       │   │   ├── empty.pt
│   │       │       │   │   ├── extreme_values.pt
│   │       │       │   │   ├── float16.pt
│   │       │       │   │   ├── float32.pt
│   │       │       │   │   ├── float64.pt
│   │       │       │   │   ├── int16.pt
│   │       │       │   │   ├── int32.pt
│   │       │       │   │   ├── int64.pt
│   │       │       │   │   ├── int8.pt
│   │       │       │   │   ├── large_shape.pt
│   │       │       │   │   ├── legacy_shared_storage.pt
│   │       │       │   │   ├── legacy_with_offsets.pt
│   │       │       │   │   ├── mixed_types.pt
│   │       │       │   │   ├── nested_dict.pt
│   │       │       │   │   ├── parameter.pt
│   │       │       │   │   ├── scalar.pt
│   │       │       │   │   ├── simple_legacy.pt
│   │       │       │   │   ├── special_values.pt
│   │       │       │   │   ├── state_dict.pt
│   │       │       │   │   ├── tensor_2d.pt
│   │       │       │   │   ├── tensor_3d.pt
│   │       │       │   │   ├── tensor_4d.pt
│   │       │       │   │   └── uint8.pt
│   │       │       │   └── test_data.py
│   │       │       └── store/
│   │       │           ├── mod.rs
│   │       │           └── test_data/
│   │       │               ├── generate_enum_test.py
│   │       │               └── model_without_enum_variants.pt
│   │       ├── safetensors/
│   │       │   ├── mod.rs
│   │       │   ├── store.rs
│   │       │   └── tests/
│   │       │       ├── adapter.rs
│   │       │       ├── direct_access.rs
│   │       │       ├── error_handling.rs
│   │       │       ├── file_io.rs
│   │       │       ├── filtering.rs
│   │       │       ├── integration.rs
│   │       │       ├── metadata.rs
│   │       │       ├── mixed_datatypes.rs
│   │       │       ├── mod.rs
│   │       │       ├── multi_layer_verify.rs
│   │       │       ├── pytorch_import.rs
│   │       │       └── round_trip.rs
│   │       ├── tensor_snapshot.rs
│   │       └── traits.rs
│   ├── burn-tch/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── bin/
│   │       │   ├── cpu.rs
│   │       │   ├── cuda.rs
│   │       │   └── mps.rs
│   │       ├── cuda_hack/
│   │       │   ├── dummy_cuda_dependency.cpp
│   │       │   └── fake_cuda_dependency.cpp
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       └── tensor.rs
│   ├── burn-tensor/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── device.rs
│   │       ├── lib.rs
│   │       └── tensor/
│   │           ├── activation/
│   │           │   ├── base.rs
│   │           │   └── mod.rs
│   │           ├── api/
│   │           │   ├── autodiff.rs
│   │           │   ├── base.rs
│   │           │   ├── bool.rs
│   │           │   ├── cartesian_grid.rs
│   │           │   ├── check.rs
│   │           │   ├── float.rs
│   │           │   ├── fmod.rs
│   │           │   ├── int.rs
│   │           │   ├── mod.rs
│   │           │   ├── numeric.rs
│   │           │   ├── options.rs
│   │           │   ├── orderable.rs
│   │           │   ├── pad.rs
│   │           │   ├── take.rs
│   │           │   ├── transaction.rs
│   │           │   └── trunc.rs
│   │           ├── grid/
│   │           │   ├── affine_grid.rs
│   │           │   ├── meshgrid.rs
│   │           │   └── mod.rs
│   │           ├── linalg/
│   │           │   ├── cosine_similarity.rs
│   │           │   ├── diag.rs
│   │           │   ├── lu_decomposition.rs
│   │           │   ├── matvec.rs
│   │           │   ├── mod.rs
│   │           │   ├── outer.rs
│   │           │   ├── trace.rs
│   │           │   └── vector_norm.rs
│   │           ├── loss/
│   │           │   └── mod.rs
│   │           ├── mod.rs
│   │           ├── module.rs
│   │           ├── quantization.rs
│   │           ├── report.rs
│   │           └── stats/
│   │               └── mod.rs
│   ├── burn-tensor-testgen/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-train/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── checkpoint/
│   │       │   ├── async_checkpoint.rs
│   │       │   ├── base.rs
│   │       │   ├── file.rs
│   │       │   ├── mod.rs
│   │       │   └── strategy/
│   │       │       ├── base.rs
│   │       │       ├── composed.rs
│   │       │       ├── lastn.rs
│   │       │       ├── metric.rs
│   │       │       └── mod.rs
│   │       ├── components.rs
│   │       ├── evaluator/
│   │       │   ├── base.rs
│   │       │   ├── builder.rs
│   │       │   ├── components.rs
│   │       │   └── mod.rs
│   │       ├── learner/
│   │       │   ├── application_logger.rs
│   │       │   ├── base.rs
│   │       │   ├── classification.rs
│   │       │   ├── early_stopping.rs
│   │       │   ├── mod.rs
│   │       │   ├── regression.rs
│   │       │   ├── rl/
│   │       │   │   ├── checkpointer.rs
│   │       │   │   ├── components.rs
│   │       │   │   ├── env_runner/
│   │       │   │   │   ├── async_runner.rs
│   │       │   │   │   ├── base.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── off_policy.rs
│   │       │   │   ├── output.rs
│   │       │   │   ├── paradigm.rs
│   │       │   │   └── strategy.rs
│   │       │   ├── sequence.rs
│   │       │   ├── summary.rs
│   │       │   ├── supervised/
│   │       │   │   ├── mod.rs
│   │       │   │   ├── paradigm.rs
│   │       │   │   ├── step/
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── train.rs
│   │       │   │   └── strategies/
│   │       │   │       ├── base.rs
│   │       │   │       ├── ddp/
│   │       │   │       │   ├── README.md
│   │       │   │       │   ├── epoch.rs
│   │       │   │       │   ├── mod.rs
│   │       │   │       │   ├── strategy.rs
│   │       │   │       │   └── worker.rs
│   │       │   │       ├── mod.rs
│   │       │   │       ├── multi/
│   │       │   │       │   ├── epoch.rs
│   │       │   │       │   ├── mod.rs
│   │       │   │       │   └── strategy.rs
│   │       │   │       └── single/
│   │       │   │           ├── epoch.rs
│   │       │   │           ├── mod.rs
│   │       │   │           └── strategy.rs
│   │       │   └── train_val.rs
│   │       ├── lib.rs
│   │       ├── logger/
│   │       │   ├── async_logger.rs
│   │       │   ├── base.rs
│   │       │   ├── file.rs
│   │       │   ├── in_memory.rs
│   │       │   ├── metric.rs
│   │       │   └── mod.rs
│   │       ├── metric/
│   │       │   ├── acc.rs
│   │       │   ├── auroc.rs
│   │       │   ├── base.rs
│   │       │   ├── cer.rs
│   │       │   ├── classification.rs
│   │       │   ├── confusion_stats.rs
│   │       │   ├── cpu_temp.rs
│   │       │   ├── cpu_use.rs
│   │       │   ├── cuda.rs
│   │       │   ├── fbetascore.rs
│   │       │   ├── hamming.rs
│   │       │   ├── iteration.rs
│   │       │   ├── learning_rate.rs
│   │       │   ├── loss.rs
│   │       │   ├── memory_use.rs
│   │       │   ├── mod.rs
│   │       │   ├── perplexity.rs
│   │       │   ├── precision.rs
│   │       │   ├── processor/
│   │       │   │   ├── async_wrapper.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── full.rs
│   │       │   │   ├── metrics.rs
│   │       │   │   ├── minimal.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── rl_metrics.rs
│   │       │   │   └── rl_processor.rs
│   │       │   ├── recall.rs
│   │       │   ├── rl/
│   │       │   │   ├── cum_reward.rs
│   │       │   │   ├── ep_len.rs
│   │       │   │   ├── exploration_rate.rs
│   │       │   │   └── mod.rs
│   │       │   ├── state.rs
│   │       │   ├── store/
│   │       │   │   ├── aggregate.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── client.rs
│   │       │   │   ├── log.rs
│   │       │   │   └── mod.rs
│   │       │   ├── top_k_acc.rs
│   │       │   ├── vision/
│   │       │   │   ├── dice.rs
│   │       │   │   ├── dists/
│   │       │   │   │   ├── l2pool.rs
│   │       │   │   │   ├── metric.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── vgg16_l2pool.rs
│   │       │   │   │   └── weights.rs
│   │       │   │   ├── lpips/
│   │       │   │   │   ├── alexnet.rs
│   │       │   │   │   ├── metric.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── squeezenet.rs
│   │       │   │   │   ├── vgg.rs
│   │       │   │   │   └── weights.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ms_ssim.rs
│   │       │   │   ├── psnr.rs
│   │       │   │   └── ssim.rs
│   │       │   └── wer.rs
│   │       └── renderer/
│   │           ├── base.rs
│   │           ├── cli.rs
│   │           ├── mod.rs
│   │           └── tui/
│   │               ├── base.rs
│   │               ├── controls.rs
│   │               ├── full_history.rs
│   │               ├── metric_numeric.rs
│   │               ├── metric_text.rs
│   │               ├── mod.rs
│   │               ├── plot_utils.rs
│   │               ├── popup.rs
│   │               ├── progress.rs
│   │               ├── recent_history.rs
│   │               ├── renderer.rs
│   │               └── status.rs
│   ├── burn-vision/
│   │   ├── Cargo.toml
│   │   ├── src/
│   │   │   ├── backends/
│   │   │   │   ├── cpu/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── connected_components/
│   │   │   │   │   │   ├── spaghetti/
│   │   │   │   │   │   │   ├── Spaghetti_center_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_first_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_forest_labels.rs
│   │   │   │   │   │   │   ├── Spaghetti_last_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_single_line_forest_code.rs
│   │   │   │   │   │   │   └── mod.rs
│   │   │   │   │   │   └── spaghetti_4c/
│   │   │   │   │   │       ├── Spaghetti4C_center_line_forest_code.rs
│   │   │   │   │   │       ├── Spaghetti4C_first_line_forest_code.rs
│   │   │   │   │   │       ├── Spaghetti4C_forest_labels.rs
│   │   │   │   │   │       └── mod.rs
│   │   │   │   │   ├── connected_components.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── morphology/
│   │   │   │   │   │   ├── filter.rs
│   │   │   │   │   │   ├── filter_engine.rs
│   │   │   │   │   │   └── mod.rs
│   │   │   │   │   ├── nms.rs
│   │   │   │   │   └── ops.rs
│   │   │   │   ├── cube/
│   │   │   │   │   ├── connected_components/
│   │   │   │   │   │   ├── hardware_accelerated.rs
│   │   │   │   │   │   ├── mod.rs
│   │   │   │   │   │   └── prefix_sum.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── ops.rs
│   │   │   │   └── mod.rs
│   │   │   ├── base.rs
│   │   │   ├── lib.rs
│   │   │   ├── ops/
│   │   │   │   ├── base.rs
│   │   │   │   └── mod.rs
│   │   │   ├── tensor.rs
│   │   │   ├── tests/
│   │   │   │   └── mod.rs
│   │   │   ├── transform/
│   │   │   │   ├── mod.rs
│   │   │   │   └── transform2d.rs
│   │   │   └── utils/
│   │   │       ├── mod.rs
│   │   │       └── save.rs
│   │   └── tests/
│   │       ├── common/
│   │       │   └── mod.rs
│   │       ├── connected_components.rs
│   │       ├── morphology.rs
│   │       └── nms.rs
│   └── burn-wgpu/
│       ├── Cargo.toml
│       ├── README.md
│       └── src/
│           └── lib.rs
├── deny.toml
├── docs/
│   └── katex-header.html
├── examples/
│   ├── custom-csv-dataset/
│   │   ├── .gitignore
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── custom-csv-dataset.rs
│   │   │   └── dataframe-dataset.rs
│   │   └── src/
│   │       ├── dataframe_dataset.rs
│   │       ├── dataset.rs
│   │       ├── diabetes_patient.rs
│   │       ├── lib.rs
│   │       └── utils.rs
│   ├── custom-cubecl-kernel/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-cubecl-kernel.rs
│   │   └── src/
│   │       ├── backward.rs
│   │       ├── forward.rs
│   │       ├── kernel.rs
│   │       └── lib.rs
│   ├── custom-image-dataset/
│   │   ├── .gitignore
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── custom-image-dataset.rs
│   │   └── src/
│   │       ├── data.rs
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── custom-learning-strategy/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-learning-strategy.rs
│   │   └── src/
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── custom-renderer/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-renderer.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── custom-training-loop/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-training-loop.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── custom-wgpu-kernel/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-wgpu-kernel.rs
│   │   └── src/
│   │       ├── backward.rs
│   │       ├── forward.rs
│   │       ├── kernel.wgsl
│   │       └── lib.rs
│   ├── dop_timer/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── event_utils.rs
│   │       ├── main.rs
│   │       ├── parsers.rs
│   │       └── workers.rs
│   ├── dqn-agent/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── dqn-agent.rs
│   │   └── src/
│   │       ├── agent.rs
│   │       ├── env.rs
│   │       ├── lib.rs
│   │       ├── training.rs
│   │       └── utils.rs
│   ├── guide/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── guide.rs
│   │   └── src/
│   │       ├── bin/
│   │       │   ├── infer.rs
│   │       │   ├── print.rs
│   │       │   └── train.rs
│   │       ├── data.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── import-model-weights/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── bin/
│   │   │   │   ├── burnpack.rs
│   │   │   │   ├── convert.rs
│   │   │   │   ├── pytorch.rs
│   │   │   │   └── safetensors.rs
│   │   │   ├── inference.rs
│   │   │   ├── lib.rs
│   │   │   └── model.rs
│   │   └── weights/
│   │       ├── mnist.pt
│   │       ├── mnist.safetensors
│   │       └── mnist_train_export.py
│   ├── mnist/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   └── mnist.rs
│   │   └── src/
│   │       ├── data.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── mnist-inference-web/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build-for-web.sh
│   │   ├── index.html
│   │   ├── index.js
│   │   ├── run-server.sh
│   │   └── src/
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       ├── state.rs
│   │       └── web.rs
│   ├── modern-lstm/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── lstm-infer.rs
│   │   │   └── lstm-train.rs
│   │   └── src/
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── multi-gpus/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── multi-gpus.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── notebook/
│   │   ├── README.md
│   │   ├── autodiff.ipynb
│   │   └── basic-tensor-op.ipynb
│   ├── server/
│   │   ├── Cargo.toml
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   └── server.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── simple-regression/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── regression.rs
│   │   └── src/
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── text-classification/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   ├── ag-news-infer.rs
│   │   │   ├── ag-news-train.rs
│   │   │   ├── db-pedia-infer.rs
│   │   │   └── db-pedia-train.rs
│   │   └── src/
│   │       ├── data/
│   │       │   ├── batcher.rs
│   │       │   ├── dataset.rs
│   │       │   ├── mod.rs
│   │       │   └── tokenizer.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── text-generation/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── text-generation.rs
│   │   └── src/
│   │       ├── data/
│   │       │   ├── batcher.rs
│   │       │   ├── dataset.rs
│   │       │   ├── mod.rs
│   │       │   └── tokenizer.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   └── wgan/
│       ├── Cargo.toml
│       ├── README.md
│       ├── examples/
│       │   ├── wgan-generate.rs
│       │   └── wgan-mnist.rs
│       └── src/
│           ├── dataset.rs
│           ├── infer.rs
│           ├── lib.rs
│           ├── model.rs
│           └── training.rs
├── rustfmt.toml
└── xtask/
    ├── Cargo.toml
    └── src/
        ├── commands/
        │   ├── books.rs
        │   ├── build.rs
        │   ├── doc.rs
        │   ├── mod.rs
        │   ├── test.rs
        │   └── validate.rs
        └── main.rs

================================================
FILE CONTENTS
================================================

================================================
FILE: .cargo/audit.toml
================================================
# Audit config file
#
# It may be located in the user home (`~/.cargo/audit.toml`) or in the project
# root (`.cargo/audit.toml`).
#
# All of the options which can be passed via CLI arguments can also be
# permanently specified in this file.

[advisories]
ignore = [
    "RUSTSEC-2024-0436", # `paste` is used to generate macros; it should be removed at some point.
    "RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples.
    "RUSTSEC-2025-0141", # `bincode` is no longer maintained.
    "RUSTSEC-2024-0388", # `derivative` dependency in the DQN example is unmaintained.
] # advisory IDs to ignore e.g. ["RUSTSEC-2019-0001", ...]
informational_warnings = [
    "unmaintained",
] # warn for categories of informational advisories
severity_threshold = "low" # CVSS severity ("none", "low", "medium", "high", "critical")

# Output Configuration
[output]
deny = ["unmaintained"] # exit on error if unmaintained dependencies are found
format = "terminal"     # "terminal" (human readable report) or "json"
quiet = false           # Only print information on error
show_tree = true        # Show inverse dependency trees along with advisories (default: true)

[yanked]
enabled = true      # Warn for yanked crates in Cargo.lock (default: true)
update_index = true # Auto-update the crates.io index (default: true)


================================================
FILE: .cargo/config.toml
================================================
[alias]
xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
run-checks = "xtask -c all validate --release"

================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
<!-- A clear and concise description of what the bug is. -->

**To Reproduce**
<!-- 
 Steps to reproduce the behavior:
 1. Go to '...'
 2. Click on '....'
 3. Scroll down to '....'
 4. See error
-->

**Expected behavior**
<!-- A clear and concise description of what you expected to happen. -->

**Screenshots**
<!-- If applicable, add screenshots to help explain your problem. -->

**Desktop (please complete the following information):**
 - OS: [e.g. iOS]
 - Browser [e.g. chrome, safari]
 - Version [e.g. 22]

**Smartphone (please complete the following information):**
 - Device: [e.g. iPhone6]
 - OS: [e.g. iOS8.1]
 - Browser [e.g. stock browser, safari]
 - Version [e.g. 22]

**Additional context**
<!-- Add any other context about the problem here. -->

================================================
FILE: .github/ISSUE_TEMPLATE/doc_request.md
================================================
---
name: Documentation request
about: Flag incoherent or missing documentation, including use case examples.
title: ''
labels: ''
assignees: ''

---

<!-- Please search existing issues to avoid creating duplicates -->


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

<!-- Please search existing issues to avoid creating duplicates -->

### Feature description

<!-- Describe the feature you'd like -->

### Feature motivation

<!-- Why do you want this? -->

### (Optional) Suggest a Solution

<!--
  How do you think we should implement this feature? 
  Things to address include:
    * Details of the technical implementation
    * Tradeoffs made in design decisions
    * Caveats and considerations for the future
-->


================================================
FILE: .github/PULL_REQUEST_TEMPLATE/template.md
================================================
* **Please check if the PR fulfills these requirements**
- [ ] The commit message follows our guidelines
- [ ] Docs have been added / updated (for bug fixes / features)


* **What kind of change does this PR introduce?** (Bug fix, feature, docs update, ...)


* **Does this PR introduce a breaking change?** (What changes might users need to make in their application due to this PR?)


* **Other information**:

================================================
FILE: .github/dependabot.yml
================================================
version: 2

updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "daily"
    ignore:
      - dependency-name: "tracel-ai/github-actions*"

  - package-ecosystem: "cargo"
    directories:
      - "/"
      - "crates/burn"
      - "crates/burn-*"
      - "crates/burn-import/*-tests"
      - "examples/*"
      - "xtask"
    schedule:
      interval: "weekly"



================================================
FILE: .github/pull_request_template.md
================================================
## Pull Request Template

### Checklist

- [ ] Confirmed that `cargo run-checks` command has been executed.
- [ ] Made sure the book is up to date with changes in this PR.

### Related Issues/PRs

_Provide links to relevant issues and dependent PRs._

### Changes

_Summarize the problem being addressed and your solution._

### Testing

_Describe how these changes have been tested._


================================================
FILE: .github/workflows/combine-dependabot-prs.yml
================================================
name: Combine Dependabot PRs

on:
  schedule:
    - cron: '0 6 * * MON' # Monday at 6:00am UTC
  workflow_dispatch:

permissions:
  contents: write
  pull-requests: write
  checks: read

jobs:
  combine-prs:
    runs-on: ubuntu-latest
    steps:
      - name: combine-prs
        id: combine-prs
        uses: github/combine-prs@v5.2.0
        with:
          labels: dependencies,automated


================================================
FILE: .github/workflows/dependencies.yml
================================================
name: dependencies

on:
  schedule:
    - cron: '0 21 * * TUE' # Run every Tuesday at 21:00 (UTC)
  push:
    tags:
      - 'v*.*.*' # Run when a new version is being published

env:
  UDEPS_VERSION: "0.1.57"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  dependencies:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        checks:
          - licenses
          - bans sources
    continue-on-error: ${{ matrix.checks == 'licenses' }} # failed licenses don't abort
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Audit Rust dependencies
        # If a vulnerability is found, a new issue will automatically be opened
        # since this action runs on the main branch
        uses: actions-rust-lang/audit@v1
      # --------------------------------------------------------------------------------
      - name: Detect multiple versions of the same crate
        uses: EmbarkStudios/cargo-deny-action@v2
        with:
          command: check ${{ matrix.checks }}
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt
      # --------------------------------------------------------------------------------
      - name: Install cargo-udeps
        env:
          UDEPS_LINK: https://github.com/est31/cargo-udeps/releases/download
        run: |
          curl -L "$UDEPS_LINK/v$UDEPS_VERSION/cargo-udeps-v$UDEPS_VERSION-x86_64-unknown-linux-gnu.tar.gz" |
          tar xz -C $HOME/.cargo/bin --strip-components 2
      # --------------------------------------------------------------------------------
      - name: Run cargo-udeps
        run: |
          cargo +nightly udeps --all-targets


================================================
FILE: .github/workflows/publish.yml
================================================
name: publish

on:
  push:
    tags:
      - "v*"
  workflow_dispatch:
    inputs:
      dry-run-only:
        description: "Run xtask publish in dry-run mode (no publish)"
        type: boolean
        required: false
        default: false

jobs:
  publish-burn-rl:
    needs:
      - publish-burn-core
      - publish-burn-optim
      # dev dependencies
      - publish-burn-ndarray
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-rl
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-vision:
    needs:
      - publish-burn-autodiff
      - publish-burn-candle
      - publish-burn-fusion
      - publish-burn-cubecl-fusion
      - publish-burn-cubecl
      - publish-burn-ndarray
      - publish-burn-tch
      - publish-burn-tensor
      - publish-burn-ir
      - publish-burn-tensor-testgen
      # dev dependencies
      - publish-burn-wgpu
      - publish-burn-cuda
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-vision
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-router:
    needs:
      - publish-burn-ir
      - publish-burn-std
      - publish-burn-tensor
      # dev dependencies
      - publish-burn-autodiff
      - publish-burn-ndarray
      - publish-burn-wgpu
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-router
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-remote:
    needs:
      - publish-burn-ir
      - publish-burn-std
      - publish-burn-tensor
      - publish-burn-router
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-remote
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-derive:
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-derive
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-dataset:
    needs:
      - publish-burn-std
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-dataset
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-std:
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-std
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-tensor-testgen:
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-tensor-testgen
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-tensor:
    needs:
      - publish-burn-tensor-testgen
      - publish-burn-std
      - publish-burn-backend
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-tensor
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-backend:
    needs:
      - publish-burn-std
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-backend
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-ir:
    needs:
      - publish-burn-tensor
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-ir
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-fusion:
    needs:
      - publish-burn-ir
      - publish-burn-tensor
      - publish-burn-std
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-fusion
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-cubecl-fusion:
    needs:
      - publish-burn-ir
      - publish-burn-std
      - publish-burn-fusion
      - publish-burn-tensor
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-cubecl-fusion
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-cubecl:
    needs:
      - publish-burn-ir
      - publish-burn-std
      - publish-burn-fusion
      - publish-burn-cubecl-fusion
      - publish-burn-tensor
      - publish-burn-ndarray
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-cubecl
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-autodiff:
    needs:
      - publish-burn-tensor
      - publish-burn-tensor-testgen
      - publish-burn-std
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-autodiff
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-tch:
    needs:
      - publish-burn-tensor
      - publish-burn-autodiff
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-tch
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-ndarray:
    needs:
      - publish-burn-ir
      - publish-burn-tensor
      - publish-burn-autodiff
      - publish-burn-std
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-ndarray
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-wgpu:
    needs:
      - publish-burn-tensor
      - publish-burn-autodiff
      - publish-burn-ndarray
      - publish-burn-std
      - publish-burn-cubecl
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-wgpu
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-cpu:
    needs:
      - publish-burn-tensor
      - publish-burn-fusion
      - publish-burn-cubecl
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-cpu
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-cuda:
    needs:
      - publish-burn-tensor
      - publish-burn-autodiff
      - publish-burn-ndarray
      - publish-burn-std
      - publish-burn-cubecl
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-cuda
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-rocm:
    needs:
      - publish-burn-tensor
      - publish-burn-autodiff
      - publish-burn-ndarray
      - publish-burn-std
      - publish-burn-cubecl
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-rocm
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-candle:
    needs:
      - publish-burn-tensor
      - publish-burn-autodiff
      - publish-burn-tch
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-candle
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-collective:
    needs:
      - publish-burn-std
      - publish-burn-tensor
      - publish-burn-communication
      # dev dependencies
      - publish-burn-wgpu
      - publish-burn-ndarray
      - publish-burn-cuda
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-collective
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-communication:
    needs:
      - publish-burn-std
      - publish-burn-tensor
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-communication
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-core:
    needs:
      - publish-burn-dataset
      - publish-burn-std
      - publish-burn-derive
      - publish-burn-tensor
      - publish-burn-vision
      # dev dependencies
      - publish-burn-autodiff
      - publish-burn-wgpu
      - publish-burn-tch
      - publish-burn-cuda
      - publish-burn-ndarray
      - publish-burn-candle
      - publish-burn-remote
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-core
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-nn:
    needs:
      - publish-burn-core
      # dev dependencies
      - publish-burn-autodiff
      - publish-burn-wgpu
      - publish-burn-tch
      - publish-burn-ndarray
      - publish-burn-candle
      - publish-burn-remote
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-nn
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-optim:
    needs:
      - publish-burn-core
      - publish-burn-collective
      # dev dependencies
      - publish-burn-autodiff
      - publish-burn-wgpu
      - publish-burn-tch
      - publish-burn-ndarray
      - publish-burn-candle
      - publish-burn-remote
      - publish-burn-nn
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-optim
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-train:
    needs:
      - publish-burn-core
      - publish-burn-optim
      - publish-burn-collective
      - publish-burn-rl
      - publish-burn-ndarray
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-train
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-dispatch:
    needs:
      - publish-burn-std
      - publish-burn-backend
      - publish-burn-autodiff
      - publish-burn-cpu
      - publish-burn-cuda
      - publish-burn-rocm
      - publish-burn-wgpu
      - publish-burn-ndarray
      - publish-burn-tch
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-dispatch
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn:
    needs:
      - publish-burn-core
      - publish-burn-nn
      - publish-burn-optim
      - publish-burn-collective
      - publish-burn-store
      - publish-burn-train
      - publish-burn-cpu
      - publish-burn-dispatch
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

  publish-burn-store:
    needs:
      - publish-burn-core
      - publish-burn-nn
      - publish-burn-tensor
    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9
    with:
      crate: burn-store
      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
    secrets:
      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}


================================================
FILE: .github/workflows/stale-pr.yml
================================================
name: Stale Pull Requests

on:
  schedule:
    - cron: '0 12 * * *' # Run every day at 12:00 (UTC)

# The minimum permissions required to run this Action
permissions:
  contents: write # only for delete-branch option
  issues: write
  pull-requests: write

jobs:
  stale-pr:
    runs-on: ubuntu-latest
    steps:
    - name: checkout
      uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
    - name: Stale pull requests
      uses: actions/stale@v10
      with:
        # The idle number of days before marking issues stale.
        #
        #  With a negative number like -1, no issues
        #  will be marked as stale automatically.
        days-before-issue-stale: -1
        # The idle number of days before marking pull requests stale
        days-before-pr-stale: 30
        # The idle number of days before closing
        # the stale pull requests (due to the stale label).
        #
        # With a negative number like -1, the pull requests
        # will never be closed automatically.
        days-before-pr-close: -1
        # Label to apply to stale pull requests
        stale-pr-label: 'stale'
        # The message that will be added as a comment to the pull request
        stale-pr-message: 'This PR has been marked as stale because it has not been updated for over a month'
        # Remove `stale` label from pull requests on updates/comments
        remove-pr-stale-when-updated: true


================================================
FILE: .github/workflows/test-gpu.yml
================================================
name: CI GPU

on:
  workflow_dispatch:
    inputs:
      pr_number:
        description: "Number of the pull request that triggers this run if any"
        type: number
        required: false

# important to set the run name to this format so that the CI server
# can track the PR number from the workflow_run events.
run-name: ${{ github.workflow }}:${{ github.repository }}#${{ inputs.pr_number }}

env:
  # Note: It is not possible to define top level env vars and pass them to composite actions.
  # To work around this issue we use inputs and define all the env vars here.

  RUST_PREVIOUS_VERSION: 1.92.0

  # Dependency versioning
  # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml

  # GCP runners
  GCP_RUNNERS_IMAGE_FAMILY: "tracel-ci-ubuntu-2404-amd64-nvidia"
  GCP_RUNNERS_MACHINE_TYPE: "g2-standard-4"
  GCP_RUNNERS_ZONE: "us-east1-c"

  # Test in release mode (make it an empty string to test in debug mode)
  TEST_RELEASE_FLAG: "--release"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  prepare-checks:
    runs-on: ubuntu-latest
    outputs:
      rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }}
      gcp_runners_image_family: ${{ env.GCP_RUNNERS_IMAGE_FAMILY }}
      gcp_runners_machine_type: ${{ env.GCP_RUNNERS_MACHINE_TYPE }}
      gcp_runners_zone: ${{ env.GCP_RUNNERS_ZONE }}
    steps:
      - name: Do Nothing
        if: false
        run: echo

  linux-std-cuda-tests:
    needs: [prepare-checks]
    timeout-minutes: 60
    # '@id:' label must be unique within this workflow
    runs-on:
      [
        "@id:burn-cuda-job-${{github.run_id}}-${{github.run_attempt}}",
        "@pr_number:${{ inputs.pr_number }}",
        "@organization:tracel-ai",
        "@repository:burn",
        "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}",
        "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}",
        "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}",
        "@gpu:true",
      ]
    env:
      LD_LIBRARY_PATH: "/usr/local/cuda/lib64"
      # disable incremental compilation (reduces artifact size)
      CARGO_PROFILE_TEST_INCREMENTAL: "false"
    # Keep the strategy matrix so new Rust versions can easily be added if required
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          enable-cache: false
      # --------------------------------------------------------------------------------
      - name: Tests (burn-cuda)
        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-cuda-runner

  linux-std-vulkan-tests:
    needs: [prepare-checks]
    timeout-minutes: 60
    # '@id:' label must be unique within this workflow
    runs-on:
      [
        "@id:burn-vulkan-job-${{github.run_id}}-${{github.run_attempt}}",
        "@pr_number:${{ inputs.pr_number }}",
        "@organization:tracel-ai",
        "@repository:burn",
        "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}",
        "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}",
        "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}",
        "@gpu:true",
      ]
    env:
      # disable incremental compilation (reduces artifact size)
      CARGO_PROFILE_TEST_INCREMENTAL: "false"
    # Keep the strategy matrix so new Rust versions can easily be added if required
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          enable-cache: false
      # --------------------------------------------------------------------------------
      - name: Tests (burn-vulkan)
        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-vulkan-runner

  linux-std-wgpu-tests:
    needs: [prepare-checks]
    timeout-minutes: 60
    # '@id:' label must be unique within this workflow
    runs-on:
      [
        "@id:burn-wgpu-job-${{github.run_id}}-${{github.run_attempt}}",
        "@pr_number:${{ inputs.pr_number }}",
        "@organization:tracel-ai",
        "@repository:burn",
        "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}",
        "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}",
        "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}",
        "@gpu:true",
      ]
    env:
      # disable incremental compilation (reduces artifact size)
      CARGO_PROFILE_TEST_INCREMENTAL: "false"
    # Keep the strategy matrix so new Rust versions can easily be added if required
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          enable-cache: false
      # --------------------------------------------------------------------------------
      - name: Tests (burn-wgpu)
        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-wgpu-runner


================================================
FILE: .github/workflows/test.yml
================================================
name: CI

on:
  push:
    branches:
      - main
    paths:
      - "Cargo.lock"
      - "**.rs"
      - "**.sh"
      - "**.ps1"
      - "**.yml"
      - "**.toml"
      - "!**.md"
      - "!LICENSE-APACHE"
      - "!LICENSE-MIT"
  pull_request:
    types: [opened, synchronize]
    paths:
      - "Cargo.lock"
      - "**.rs"
      - "**.sh"
      - "**.ps1"
      - "**.yml"
      - "**.toml"
      - "!**.md"
      - "!LICENSE-APACHE"
      - "!LICENSE-MIT"

env:
  # Note: It is not possible to define top level env vars and pass them to composite actions.
  # To work around this issue we use inputs and define all the env vars here.

  RUST_PREVIOUS_VERSION: 1.92.0

  # Dependency versioning
  # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml

  # Mozilla Grcov
  GRCOV_LINK: "https://github.com/mozilla/grcov/releases/download"
  GRCOV_VERSION: "0.8.19"

  # Test in release mode (make it an empty string to test in debug mode)
  TEST_RELEASE_FLAG: "--release"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  prepare-checks:
    runs-on: ubuntu-latest
    outputs:
      rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }}
    steps:
      - name: Do Nothing
        if: false
        run: echo

  code-quality:
    runs-on: ubuntu-22.04
    needs: prepare-checks
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-linux
      # --------------------------------------------------------------------------------
      - name: Audit
        run: cargo xtask check audit
      # --------------------------------------------------------------------------------
      - name: Format
        shell: bash
        env:
          # work around for colors
          # see: https://github.com/rust-lang/rustfmt/issues/3385
          TERM: xterm-256color
        run: cargo xtask check format
      # --------------------------------------------------------------------------------
      - name: Lint
        run: cargo xtask check lint
      # --------------------------------------------------------------------------------
      - name: Typos
        uses: tracel-ai/github-actions/check-typos@v9

  documentation:
    runs-on: ubuntu-22.04
    needs: prepare-checks
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-linux
      # --------------------------------------------------------------------------------
      - name: Documentation Build
        run: cargo xtask doc build
      # --------------------------------------------------------------------------------
      - name: Documentation Tests
        run: cargo xtask doc tests

  linux-std-tests:
    runs-on: ubuntu-22.04
    needs: [prepare-checks, code-quality]
    env:
      DISABLE_WGPU_SPIRV: "1"
      # disable incremental compilation (reduces artifact size)
      CARGO_PROFILE_TEST_INCREMENTAL: "false"
    strategy:
      matrix:
        rust: [stable, prev]
        include:
          - rust: stable
            toolchain: stable
            coverage: --enable-coverage
          - rust: prev
            toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-linux
          # Disable cache on linux-std (stable) runner which currently always runs out of disk space with tests + coverage
          enable-cache: ${{ matrix.rust != 'stable' }}
      # --------------------------------------------------------------------------------
      - name: Install grcov
        if: matrix.rust == 'stable'
        shell: bash
        run: |
          curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" |
          tar xj -C $HOME/.cargo/bin
          cargo xtask coverage install
      # --------------------------------------------------------------------------------
      - name: Tests
        run: cargo xtask ${{ matrix.coverage }} test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner
      # --------------------------------------------------------------------------------
      - name: Generate lcov.info
        if: matrix.rust == 'stable'
        # The "/*" pattern excludes absolute paths (e.g. std library sources) from coverage analysis
        run: cargo xtask coverage generate --ignore "/*,xtask/*,examples/*" --profile release
      # --------------------------------------------------------------------------------
      - name: Codecov upload lcov.info
        if: matrix.rust == 'stable'
        uses: codecov/codecov-action@v5
        with:
          files: lcov.info
          token: ${{ secrets.CODECOV_TOKEN }}

  linux-no-std-tests:
    runs-on: ubuntu-22.04
    needs: [prepare-checks, code-quality]
    strategy:
      matrix:
        rust: [stable, prev]
        include:
          - rust: stable
            toolchain: stable
          - rust: prev
            toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-linux-no-std
      # --------------------------------------------------------------------------------
      - name: Crates Build
        run: cargo xtask --context no-std build --ci
      # --------------------------------------------------------------------------------
      - name: Crates Tests
        run: cargo xtask --context no-std test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner

  windows-std-tests:
    runs-on: windows-2022
    needs: [prepare-checks, code-quality]
    # Keep the strategy matrix so new Rust versions can easily be added if required
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-windows
      # --------------------------------------------------------------------------------
      - name: Tests
        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner

  macos-std-tests:
    runs-on: blaze/macos-15
    needs: [prepare-checks, code-quality]
    timeout-minutes: 60
    # Keep the strategy matrix so new Rust versions can easily be added if required
    strategy:
      matrix:
        rust: [stable]
        include:
          - rust: stable
            toolchain: stable
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Setup Rust
        uses: tracel-ai/github-actions/install-rust@v9
        with:
          rust-toolchain: ${{ matrix.toolchain }}
          cache-key: ${{ matrix.rust }}-macos
      # --------------------------------------------------------------------------------
      - name: Device check
        run: system_profiler SPHardwareDataType
      # --------------------------------------------------------------------------------
      - name: Tests
        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-mac-runner


================================================
FILE: .github/workflows/valgrind.yml
================================================
name: valgrind

on:
  schedule:
    - cron: '0 23 * * WED' # Run every Wednesday at 23:00 (UTC)

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  valgrind:
    runs-on: [
      '@id:burn-linux-valgrind-${{ github.run_id }}-${{ github.run_attempt }}',
      '@image-family:ubuntu-2404-lts-amd64',
      '@image-project:ubuntu-os-cloud',
      '@disk-size:100',
      '@keep-alive:false',
      '@machine-type:n2-standard-16',
      '@os:linux',
      '@zones:northamerica-northeast1-b'
      ]
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Install valgrind
        run: |
          sudo apt-get install valgrind
      # --------------------------------------------------------------------------------
      - name: Run cargo-valgrind
        env:
          CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: "valgrind -s --leak-check=full --show-leak-kinds=all --error-exitcode=1"
        # Looking for memory errors and leaks
        run: |
          cargo test


================================================
FILE: .github/workflows/vulnerabilities.yml
================================================
name: vulnerabilities

on:
  schedule:
    - cron: '0 21 * * WED' # Run every Wednesday at 21:00 (UTC)
  push:
    tags:
      - 'v*.*.*'

env:
  CAREFUL_VERSION: "0.4.9"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  cargo-careful:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt, rust-src
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Install cargo-careful
        env:
          CAREFUL_LINK: https://github.com/RalfJung/cargo-careful/releases/download
        run: |
          curl -L "$CAREFUL_LINK/v$CAREFUL_VERSION/cargo-careful.x86_64-unknown-linux-musl" \
          --output $HOME/.cargo/bin/cargo-careful
          chmod +x $HOME/.cargo/bin/cargo-careful
      # --------------------------------------------------------------------------------
      - name: Run cargo-careful
        # Looking for undefined behaviours
        run: cargo +nightly careful test

  address-sanitizer:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt, rust-src
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Run AddressSanitizer
        env:
          RUSTFLAGS: -Zsanitizer=address -Copt-level=3
          RUSTDOCFLAGS: -Zsanitizer=address
        # Looking for memory vulnerabilities
        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture

  thread-sanitizer:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt, rust-src
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Run ThreadSanitizer
        env:
          RUSTFLAGS: -Zsanitizer=thread -Copt-level=3
          RUSTDOCFLAGS: -Zsanitizer=thread
        # Looking for data race among threads
        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture

  memory-sanitizer:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt, rust-src
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Run MemorySanitizer
        env:
          RUSTFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins -Copt-level=3
          RUSTDOCFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins
        # Looking for uses of uninitialized memory.
        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture

  safe-stack:
    runs-on: ubuntu-latest
    steps:
      - name: checkout
        uses: actions/checkout@v6
      # --------------------------------------------------------------------------------
      - name: Install Rust nightly
        uses: dtolnay/rust-toolchain@nightly
        with:
          toolchain: nightly
          components: rustfmt, rust-src
      # --------------------------------------------------------------------------------
      - name: Install Mesa
        uses: tracel-ai/github-actions/install-mesa@v9
      # --------------------------------------------------------------------------------
      - name: Run SafeStack
        env:
          RUSTFLAGS: -Zsanitizer=safestack -Copt-level=3
          RUSTDOCFLAGS: -Zsanitizer=safestack
        # Provides backward edge control flow protection
        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture


================================================
FILE: .gitignore
================================================
target
# These are backup files generated by rustfmt
**/*.rs.bk
.DS_Store

.dir-locals.el
.idea
.vscode
.vs
.fleet
.ipynb_checkpoints/

# Build output directory
out

# Virtual Environment of Python
.venv
uv.lock

# Nix direnv
.envrc
.direnv


================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Simard"
    given-names: "Nathaniel"
    email: "nathaniel.simard.42@gmail.com"
  - family-names: "Fortier-Dubois"
    given-names: "Louis"
    email: "louisfd94@gmail.com"
  - family-names: "Tadjibaev"
    given-names: "Dilshod"
    email: "dilshod@gmail.com"
  - family-names: "Lagrange"
    given-names: "Guillaume"
    email: "lagrange.guillaume.1@gmail.com"
  - name: "Burn Framework Contributors"
title: "Burn"
version: 0.14.0
date-released: 2024-08-27
url: "https://burn.dev/"
repository-code: "https://github.com/tracel-ai/burn"
license:
  - MIT
  - Apache-2.0
abstract: "Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals."
keywords:
  - scientific-computing
  - deep-learning
  - machine-learning
  - neural-networks
  - rust
  - high-performance-computing
  - portability
  - compute-efficiency


================================================
FILE: CODE-OF-CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
nathaniel.simard.42@gmail.com.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to Burn

Welcome to the Burn community! We're glad you're interested in contributing.

## How to Contribute

The best way to get started is to look at [open issues](https://github.com/tracel-ai/burn/issues)
and find one that interests you. Issues labeled `good first issue` are a great starting point for
new contributors.

If you have an idea that isn't covered by an existing issue, open one first to discuss the approach.
This helps align expectations and avoids wasted effort on both sides.

For questions, discussions, or just to say hello, join us on
[Discord](https://discord.gg/uPEBbYYDB6). The [Contributor Book](https://burn.dev/contributor-book/)
covers architecture, environment setup, and guides for common tasks.

## Pull Requests

Every pull request should have a descriptive title, a description covering what you changed, why,
how you tested it, and a link to the relevant issue (if applicable). Prefer small, focused PRs over
large ones that bundle unrelated changes.

Draft pull requests are considered not yet ready for review.

CI checks should pass before requesting review, though the signal isn't always accurate. If you have
questions or need early feedback, let us know on the PR or on
[Discord](https://discord.gg/uPEBbYYDB6).

### Change Ownership

The core principle behind all contributions: **PR authors must understand, justify, and explain
every change they propose.** After a PR is accepted, both the reviewer and the author should be
confident it improves the codebase.

This applies equally whether you wrote the code from scratch, adapted it from another project, or
used AI tools to help generate it. The origin of the code doesn't matter; what matters is that you
own it intellectually and can stand behind it during review.

### AI-Assisted Contributions

Using LLMs and AI tools to generate code that is part of a contribution is allowed.

That said, the [Change Ownership](#change-ownership) principle applies fully. You are the author,
not your AI tool. This means:

- Read and understand every line before submitting.
- Review AI-generated code for correctness, style consistency, and relevance.
- Test your changes locally and confirm they work as intended.
- Be prepared to explain the rationale behind any change during review.

Do not use "AI generated" as a justification for low-quality code.

### Before You Open a PR

1. **Check for an existing issue.** If there isn't one, open an issue first to discuss the approach.
   This is especially important for large changes or refactors.
2. **Read the codebase.** Understand the architecture and conventions already in place. The
   [Contributor Book](https://burn.dev/contributor-book/) covers architecture, environment setup,
   and guides for common tasks.
3. **Keep it focused.** One PR should address one concern. If you spot an unrelated issue while
   working, open a separate PR for it.
4. **Run validation.** Run `cargo run-checks` before submitting. This runs formatting, linting, and
   the full test suite. All checks must pass.

### Code Quality Standards

- Follow existing code style and project conventions.
- Write idiomatic Rust. If you are new to the codebase, study existing patterns before contributing.
- Keep dependencies minimal. Don't introduce new crates without discussion.
- Document public APIs. Non-trivial logic should have comments explaining _why_, not just _what_.
- Prefer clarity over cleverness.
- Bug fixes should include a regression test.

### Large Pull Requests

Large, complex PRs are harder to review effectively and carry more risk. To help both yourself and
reviewers, consider breaking substantial changes into smaller, incremental PRs. Each should be
valuable on its own, even if the full picture spans multiple PRs.

Large efforts that are ultimately rejected are frustrating for everyone involved. If you're planning
a substantial change, open an issue or start a discussion first. It's much easier to course-correct
early than after the work is done.

### Review Process

- Maintainers review PRs as time allows. Please be patient.
- Be responsive to feedback. If changes are requested, address them or explain your reasoning.
- Reviewers may ask clarifying questions about any part of your PR. This is a normal part of
  collaborative review and helps ensure shared understanding.
- Don't force-push to rewrite history during an active review without notice.
- If a PR goes stale for more than 14 days without a response from the author, it may be closed.

## Getting Help

If you're stuck or unsure about something, don't hesitate to ask. Open an issue, start a discussion,
or reach out on [Discord](https://discord.gg/uPEBbYYDB6). We're happy to help.


================================================
FILE: Cargo.toml
================================================
[workspace]
# Require resolver version 2 to avoid "feature" additiveness for dev-dependencies
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"

members = [
    "crates/*",
    "crates/burn-store/pytorch-tests",
    "crates/burn-store/safetensors-tests",
    "crates/burn-collective/multinode-tests",
    "examples/*",
    "xtask",
]

exclude = [
    "examples/notebook",
    "examples/raspberry-pi-pico",
    "examples/dqn-agent",         # gym-rs
]

[workspace.package]
edition = "2024"
license = "MIT OR Apache-2.0"
readme = "README.md"
version = "0.21.0-pre.2"

[workspace.lints.clippy]

[workspace.lints.rustdoc]
broken_intra_doc_links = "deny"
invalid_html_tags = "deny"

[workspace.dependencies]
atomic_float = "1"
axum = "0.8.8"
bytemuck = "1.25.0"
bytes = { version = "1.11.1", default-features = false }
candle-core = { version = "0.9.2" }
ciborium = { version = "0.2", default-features = false }
clap = { version = "4.6.0", features = ["derive"] }
colored = "3.0.0"
console_error_panic_hook = "0.1.7"
const-random = "0.1"
csv = "1.3.1"
dashmap = "6.1.0"
data-encoding = { version = "2.10.0", default-features = false, features = [
    "alloc",
] }
dirs = "6.0.0"
encoding_rs = "0.8.33"
enumset = { version = "1.1.10", default-features = false }
fake = "5.1.0"
flate2 = "1.1.9"
float-cmp = "0.10.0"
futures = "0.3"
futures-util = "0.3"
gix-tempfile = { version = "21.0.0", features = ["signals"] }
globwalk = "0.9.1"
hashbrown = "0.16"
hound = "3.5.1"
image = "0.25.9"
indicatif = "0.18.0"
insta = "1.45.0"
js-sys = "0.3.77"
libm = "0.2.15"
log = { default-features = false, version = "0.4.29" }
lzma-rust2 = "0.16.2"
opentelemetry = "0.31.0"
opentelemetry-aws = "0.19.0"
opentelemetry-otlp = "0.31.0"
opentelemetry_sdk = "0.31.0"
parking_lot = { version = "0.12.5", default-features = false }
paste = "1"
planus = { version = "=1.1" }
polars = { version = "0.53.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.106"
quote = "1.0.45"
r2d2 = "0.8.10"
r2d2_sqlite = "0.31.0"
rayon = "1.10.0"
regex = { version = "1.12.3", default-features = false, features = [
    "perf",
    "unicode",
] }
reqwest = { version = "0.12.23", default-features = false, features = [
    "rustls-tls",
] }
rmp-serde = { version = "1.3.1", default-features = false }
rstest = "0.26.1"
rusqlite = "0.37.0"
sanitize-filename = "0.6.0"
serde_bytes = { version = "0.11.18", default-features = false, features = [
    "alloc",
] } # alloc for no_std
serde_rusqlite = "0.40.0"
serial_test = "3.2.0"
spin = { version = "0.10.0", features = [
    "mutex",
    "spin_mutex",
    "portable-atomic",
] }
strum = { version = "0.28.0", features = ["derive"] }
syn = { version = "2.0.111", features = ["full", "extra-traits"] }
tar = "0.4.44"
tempfile = "3.24.0"
textdistance = { version = "1.1.1", default-features = false }
thiserror = { version = "2", default-features = false }
tokio = { version = "1.50.0", features = ["rt", "macros"] }
tokio-tungstenite = "0.28"
tokio-util = "0.7"
tracing = { version = "0.1.44", default-features = false }
tracing-appender = "0.2.3"
tracing-core = { version = "0.1.36", default-features = false }
tracing-opentelemetry = "0.32.0"
tracing-subscriber = "0.3.23"
zip = "8.2.0"

# Persist related
memmap2 = { version = "0.9" }
safetensors = { version = "0.7.0", default-features = false }

# Async handling
async-channel = "2.5"
futures-lite = { version = "2.6.1", default-features = false }

# Terminal UI
ratatui = "0.30.0"

# WGPU stuff
text_placeholder = "0.5.1"

bincode = { version = "2.0.1", features = [
    "alloc",
    "serde",
], default-features = false }

#
# Many of the following packages disable their default features (including "std")
# for no_std compatibility.
#
cfg-if = "1.0.1"
derive-new = { version = "0.7.0", default-features = false }

blas-src = { version = "0.14.0", default-features = false }
bon = "3.8.2"
half = { version = "2.7.1", features = [
    "alloc",
    "num-traits",
    "serde",
], default-features = false }
macerator = { version = "0.3.0" }
matrixmultiply = { version = "0.3.10", default-features = false }
ndarray = { version = "0.17.2", default-features = false }
num-traits = { version = "0.2.19", default-features = false, features = [
    "libm",
] } # libm is for no_std
openblas-src = "0.10.14"
rand = { version = "0.10.0", default-features = false, features = ["std_rng"] }
rand_distr = { version = "0.6.0", default-features = false }
serde = { version = "1.0.228", default-features = false, features = [
    "derive",
    "alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.148", default-features = false }
smallvec = { version = "1", features = ["const_generics", "const_new"] }
uuid = { version = "1.22.0", default-features = false }

byteorder = { version = "1.5.0", default-features = false }
libc = "0.2.182"
nvml-wrapper = "0.12.0"
sysinfo = "0.38.0"
systemstat = "0.2.6"
tch = "0.22.0"
torch-sys = "0.22.0"                                        # matches what tch is using, required for lib detection

ahash = { version = "0.8.12", default-features = false }
portable-atomic = { version = "1.13.1" }
portable-atomic-util = { version = "0.2.6", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" }
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "01ed48e1abb5ed117df33f4394f2c5a91c3eb97e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false }
# cubek = { path = "../cubek/crates/cubek", default-features = false }
### For the release. ###
# cubecl = { version = "=0.10.0-pre.2", default-features = false }
# cubecl-common = { version = "=0.10.0-pre.2", default-features = false }
# cubecl-zspace = { version = "=0.10.0-pre.2", default-features = false }
# cubek = { version = "=0.2.0-pre.2", default-features = false }

### For xtask crate ###
tracel-xtask = "=4.13.5"
# ### For local development. ###
# tracel-xtask = { path = "../xtask/crates/tracel-xtask", default-features = false }

[profile.dev]
debug = 1 # Limited debug info (no type/variable info) speeds up compilation; full debug info is not needed.


================================================
FILE: LICENSE-APACHE
================================================
                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "[]"
   replaced with your own identifying information. (Don't include
   the brackets!)  The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright 2022 Nathaniel Simard & Burn Framework Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


================================================
FILE: LICENSE-MIT
================================================
MIT License

Copyright (c) 2022 Nathaniel Simard & Burn Framework Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: NOTICES.md
================================================
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this
repository copied or derived from.

## PyTorch MNIST Example

**Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py

License: BSD 3-Clause License

Copyright (c) 2017,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


## wgpu

**Source:** https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml

MIT License

Copyright (c) 2021 The gfx-rs developers

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


## BSL 1.0

**Source**:
- https://github.com/DoumanAsh/error-code
- https://github.com/DoumanAsh/clipboard-win


Boost Software License - Version 1.0 - August 17th, 2003

Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:

The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.


## num-traits

**Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs

MIT License

Copyright (c) 2014 The Rust Project Developers

Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

## RP

**Source**:
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/Cargo.toml
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/build.rs
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/memory.x

                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "[]"
   replaced with your own identifying information. (Don't include
   the brackets!)  The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright (c) Embassy project contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

MIT license

Copyright (c) Embassy project contributors

Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

## github-device-flow

**Source**:
- Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs
- https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs

MIT License

Copyright (c) 2022 Jake Wilkins

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


## Candle - Pickle Reader

**Source**: https://github.com/huggingface/candle/blob/main/candle-core/src/pickle.rs

This project includes code from Candle by Hugging Face, licensed under both MIT and Apache 2.0 licenses.

**MIT License**: https://github.com/huggingface/candle/blob/main/LICENSE-MIT

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

**Apache License 2.0**: https://github.com/huggingface/candle/blob/main/LICENSE-APACHE

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


## ICU

UNICODE LICENSE V3

COPYRIGHT AND PERMISSION NOTICE

Copyright © 2016-2024 Unicode, Inc.

NOTICE TO USER: Carefully read the following legal agreement. BY
DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR
SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE
TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT
DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE.

Permission is hereby granted, free of charge, to any person obtaining a
copy of data files and any associated documentation (the "Data Files") or
software and any associated documentation (the "Software") to deal in the
Data Files or Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, and/or sell
copies of the Data Files or Software, and to permit persons to whom the
Data Files or Software are furnished to do so, provided that either (a)
this copyright and permission notice appear with all copies of the Data
Files or Software, or (b) this copyright and permission notice appear in
associated Documentation.

THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF
THIRD PARTY RIGHTS.

IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE
BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES,
OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,
ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA
FILES OR SOFTWARE.

Except as contained in this notice, the name of a copyright holder shall
not be used in advertising or otherwise to promote the sale, use or other
dealings in these Data Files or Software without prior written
authorization of the copyright holder.


================================================
FILE: POEM.md
================================================
# BURN: Burn Unstoppable Rusty Neurons

In the realm of circuits and code,  
A fiery forge ignites to bear its load,  
A framework born, BURN it be named,  
Unstoppable Rusty Neurons, untamed.

From silicon synapses, connections spire,  
A digital cortex, setting minds afire,  
In the vast expanse of deep learning's sea,  
A beacon of progress, BURN comes to be.

Oh, rusty neurons, forged in the flame,  
Unyielding in purpose, undaunted by name,  
Through layers of logic and intricate art,  
You weave and entwine, each playing its part.

With algorithms profound, and data refined,  
In ceaseless pursuit of knowledge to find,  
BURN paves a path to enlightenment, bright,  
A testament to the wonders of human foresight.

In neural networks deep, where wisdom resides,  
The dance of nodes and edges presides,  
With loss and gradients, BURN takes its stride,  
A journey towards truth, with AI as our guide.

No barriers hold back the curious mind,  
As BURN seeks the answers we yearn to find,  
Unstoppable, relentless, in pursuit of the unknown,  
Our collective intellect, within it, has grown.

So sing we the praises of BURN's fiery might,  
An ode to the sparks that set the dark alight,  
To the rusty neurons, unstoppable and true,  
A testament to the power of dreams, to breakthrough.

(ChatGPT (model=gpt-4) with prompt:
Write a poem about "BURN: Burn Unstoppable Rusty Neurons" deep
learning neural network framework)


================================================
FILE: README.md
================================================
<div align="center">
<img src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/logo-burn-neutral.webp" width="350px"/>

[![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/uPEBbYYDB6)
[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)
[![Minimum Supported Rust Version](https://img.shields.io/crates/msrv/burn)](https://crates.io/crates/burn)
[![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://burn.dev/docs/burn)
[![Test Status](https://github.com/tracel-ai/burn/actions/workflows/test.yml/badge.svg)](https://github.com/tracel-ai/burn/actions/workflows/test.yml)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](#license)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tracel-ai/burn)

[<img src="https://www.runblaze.dev/ci-blaze-powered.png" width="125px"/>](https://www.runblaze.dev)

---

**Burn is a next generation Tensor Library and Deep Learning Framework that doesn't compromise on
<br /> flexibility, efficiency and portability.**

<br/>
</div>

<div align="left">

Burn is both a tensor library and a deep learning framework optimized for numerical computing, model
inference and model training. Burn leverages Rust to perform optimizations normally only available
in static-graph frameworks, offering optimal speed without impacting flexibility.

## Backend

<div align="left">
<img align="right" src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/backend-chip.png" height="96px"/>

Burn strives to be as fast as possible on as much hardware as possible, with robust
implementations. We believe this flexibility is crucial for modern needs where you may train your
models in the cloud, then deploy on customer hardware, which varies from user to user.

</div>

### Supported Backends

Most backends support all operating systems, so we don't mention them in the tables below.

**GPU Backends:**

|          | CUDA | ROCm | Metal | Vulkan | WebGPU | LibTorch |
| -------- | ---- | ---- | ----- | ------ | ------ | -------- |
| Nvidia   | ☑️   | -    | -     | ☑️     | ☑️     | ☑️       |
| AMD      | -    | ☑️   | -     | ☑️     | ☑️     | ☑️       |
| Apple    | -    | -    | ☑️    | -      | ☑️     | ☑️       |
| Intel    | -    | -    | -     | ☑️     | ☑️     | -        |
| Qualcomm | -    | -    | -     | ☑️     | ☑️     | -        |
| Wasm     | -    | -    | -     | -      | ☑️     | -        |

**CPU Backends:**

|        | Cpu (CubeCL) | NdArray | LibTorch |
| ------ | ------------ | ------- | -------- |
| X86    | ☑️           | ☑️      | ☑️       |
| Arm    | ☑️           | ☑️      | ☑️       |
| Wasm   | -            | ☑️      | -        |
| no-std | -            | ☑️      | -        |

<br />

Compared to other frameworks, Burn has a very different approach to supporting many backends. By
design, most code is generic over the Backend trait, which allows us to build Burn with swappable
backends. This makes composing backends possible, augmenting them with additional functionality
such as autodifferentiation and automatic kernel fusion.

<details>
<summary>
Autodiff: Backend decorator that brings backpropagation to any backend 🔄
</summary>
<br />

Contrary to the aforementioned backends, Autodiff is actually a backend _decorator_. This means that
it cannot exist by itself; it must encapsulate another backend.

The simple act of wrapping a base backend with Autodiff transparently equips it with
autodifferentiation support, making it possible to call backward on your model.

```rust
use burn::backend::{Autodiff, Wgpu};
use burn::tensor::{Distribution, Tensor};

fn main() {
    type Backend = Autodiff<Wgpu>;

    let device = Default::default();

    let x: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default, &device);
    let y: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default, &device).require_grad();

    let tmp = x.clone() + y.clone();
    let tmp = tmp.matmul(x);
    let tmp = tmp.exp();

    let grads = tmp.backward();
    let y_grad = y.grad(&grads).unwrap();
    println!("{y_grad}");
}
```

Of note, it is impossible to make the mistake of calling backward on a model that runs on a backend
that does not support autodiff (for inference), as this method is only offered by an Autodiff
backend.

See the [Autodiff Backend README](./crates/burn-autodiff/README.md) for more details.

</details>

<details>
<summary>
Fusion: Backend decorator that brings kernel fusion to all first-party backends
</summary>
<br />

This backend decorator enhances a backend with kernel fusion, provided that the inner backend
supports it. Note that you can compose this backend with other backend decorators such as Autodiff.
All first-party accelerated backends (like WGPU and CUDA) use Fusion by default (`burn/fusion`
feature flag), so you typically don't need to apply it manually.

```rust
#[cfg(not(feature = "fusion"))]
pub type Cuda<F = f32, I = i32> = CubeBackend<CudaRuntime, F, I, u8>;

#[cfg(feature = "fusion")]
pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<CubeBackend<CudaRuntime, F, I, u8>>;
```

Of note, we plan to implement automatic gradient checkpointing based on compute bound and memory
bound operations, which will work gracefully with the fusion backend to make your code run even
faster during training, see [this issue](https://github.com/tracel-ai/burn/issues/936).

See the [Fusion Backend README](./crates/burn-fusion/README.md) for more details.

</details>

<details>
<summary>
Router (Beta): Backend decorator that composes multiple backends into a single one
</summary>
<br />

This backend simplifies hardware interoperability, for instance when you want to execute some
operations on the CPU and other operations on the GPU.

```rust
use burn::tensor::{Distribution, Tensor};
use burn::backend::{
    NdArray, Router, Wgpu, ndarray::NdArrayDevice, router::duo::MultiDevice, wgpu::WgpuDevice,
};

fn main() {
    type Backend = Router<(Wgpu, NdArray)>;

    let device_0 = MultiDevice::B1(WgpuDevice::DiscreteGpu(0));
    let device_1 = MultiDevice::B2(NdArrayDevice::Cpu);

    let tensor_gpu =
        Tensor::<Backend, 2>::random([3, 3], burn::tensor::Distribution::Default, &device_0);
    let tensor_cpu =
        Tensor::<Backend, 2>::random([3, 3], burn::tensor::Distribution::Default, &device_1);
}

```

</details>

<details>
<summary>
Remote (Beta): Backend decorator for remote backend execution, useful for distributed computations
</summary>
<br />

This backend has two parts: a client and a server. The client sends tensor operations over the
network to a remote compute backend. You can use any first-party backend as server in a single line
of code:

```rust
fn main_server() {
    // Start a server on port 3000.
    burn::server::start::<burn::backend::Cuda>(Default::default(), 3000);
}

fn main_client() {
    // Create a client that communicates with the server on port 3000.
    use burn::backend::{Autodiff, RemoteBackend};

    type Backend = Autodiff<RemoteBackend>;

    let device = RemoteDevice::new("ws://localhost:3000");
    let tensor_gpu =
        Tensor::<Backend, 2>::random([3, 3], Distribution::Default, &device);
}

```

</details>

<br />

## Training & Inference

<div align="left">
<img align="right" src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-wall.png" height="96px"/>

The whole deep learning workflow is made easy with Burn, as you can monitor your training progress
with an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU
clusters.

Burn was built from the ground up with training and inference in mind. It's also worth noting how
Burn, in comparison to frameworks like PyTorch, simplifies the transition from training to
deployment, eliminating the need for code changes.

</div>

<div align="center">

<br />

<a href="https://www.youtube.com/watch?v=N9RM5CQbNQc" target="_blank">
    <img src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/burn-train-tui.png" alt="Burn Train TUI" width="75%">
  </a>
</div>

<br />

**Click on the following sections to expand 👇**

<details>
<summary>
Training Dashboard 📈
</summary>
<br />

As you can see in the previous video (click on the picture!), a new terminal UI dashboard based on
the [Ratatui](https://github.com/ratatui-org/ratatui) crate allows users to follow their training
with ease without having to connect to any external application.

You can visualize your training and validation metrics updating in real-time and analyze the
lifelong progression or recent history of any registered metrics using only the arrow keys. Break
from the training loop without crashing, allowing potential checkpoints to be fully written or
important pieces of code to complete without interruption 🛡

</details>

<details>
<summary>
ONNX Support 🐫
</summary>
<br />

Burn supports importing ONNX (Open Neural Network Exchange) models through the
[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate, allowing you to easily port models from
TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses Burn's native
APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly) and benefit
from all of Burn's optimizations like automatic kernel fusion.

Our ONNX support is further described in
[this section of the Burn Book 🔥](https://burn.dev/books/burn/onnx-import.html).

> **Note**: This crate is in active development and currently supports a
> [limited set of ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).

</details>

<details>
<summary>
Importing PyTorch or Safetensors Models 🚚
</summary>
<br />

You can load weights from PyTorch or Safetensors formats directly into your Burn-defined models.
This makes it easy to reuse existing models while benefiting from Burn's performance and deployment
features.

Learn more in the [Saving & Loading Models](https://burn.dev/books/burn/saving-and-loading.html)
section of the Burn Book.

</details>

<details>
<summary>
Inference in the Browser 🌐
</summary>
<br />

Several of our backends can run in WebAssembly environments: NdArray for CPU execution, and WGPU for
GPU acceleration via WebGPU. This means that you can run inference directly within a browser. We
provide several examples of this:

- [MNIST](./examples/mnist-inference-web) where you can draw digits and a small convnet tries to
  find which one it is! 2️⃣ 7️⃣ 😰
- [Image Classification](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web)
  where you can upload images and classify them! 🌄

</details>

<details>
<summary>
Embedded: <i>no_std</i> support ⚙️
</summary>
<br />

Burn's core components support [no_std](https://docs.rust-embedded.org/book/intro/no-std.html). This
means it can run in bare-metal environments such as embedded devices without an operating system.

> As of now, only the NdArray backend can be used in a _no_std_ environment.

</details>

<br />

### Benchmarks

To evaluate performance across different backends and track improvements over time, we provide a
dedicated benchmarking suite.

Run and compare benchmarks using [burn-bench](https://github.com/tracel-ai/burn-bench).

> ⚠️ **Warning** When using one of the `wgpu` backends, you may encounter compilation errors related
> to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency
> chain. To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs`
> file:
>
> ```rust
> #![recursion_limit = "256"]
> ```
>
> The default recursion limit (128) is often just below the required depth (typically 130-150) due
> to deeply nested associated types and trait bounds.

## Getting Started

<div align="left">
<img align="right" src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-walking.png" height="96px"/>

Just heard of Burn? You are in the right place! Just continue reading this section and we hope you
can get on board really quickly.

</div>

<details>
<summary>
The Burn Book 🔥
</summary>
<br />

To begin working effectively with Burn, it is crucial to understand its key components and
philosophy. This is why we highly recommend that new users read the first sections of
[The Burn Book 🔥](https://burn.dev/books/burn/). It provides detailed examples and explanations
covering every facet of the framework, including building blocks like tensors, modules, and
optimizers, all the way to advanced usage, like coding your own GPU kernels.

> The project is constantly evolving, and we try as much as possible to keep the book up to date
> with new additions. However, we might miss some details sometimes, so if you see something weird,
> let us know! We also gladly accept Pull Requests 😄

</details>

<details>
<summary>
Examples 🙏
</summary>
<br />

Let's start with a code snippet that shows how intuitive the framework is to use! In the following,
we declare a neural network module with some parameters along with its forward pass.

```rust
use burn::nn;
use burn::module::Module;
use burn::tensor::{Tensor, backend::Backend};

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: nn::Linear<B>,
    linear_outer: nn::Linear<B>,
    dropout: nn::Dropout,
    gelu: nn::Gelu,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}
```

We have a fairly large number of [examples](./examples) in the repository that show how to use
the framework in different scenarios.

Following [the book](https://burn.dev/books/burn/):

- [Basic Workflow](./examples/guide) : Creates a custom CNN `Module` to train on the MNIST dataset
  and use for inference.
- [Custom Training Loop](./examples/custom-training-loop) : Implements a basic training loop instead
  of using the `Learner`.
- [Custom WGPU Kernel](./examples/custom-wgpu-kernel) : Learn how to create your own custom
  operation with the WGPU backend.

Additional examples:

- [Custom CSV Dataset](./examples/custom-csv-dataset) : Implements a dataset to parse CSV data for a
  regression task.
- [Regression](./examples/simple-regression) : Trains a simple MLP on the California Housing dataset
  to predict the median house value for a district.
- [Custom Image Dataset](./examples/custom-image-dataset) : Trains a simple CNN on custom image
  dataset following a simple folder structure.
- [Custom Renderer](./examples/custom-renderer) : Implements a custom renderer to display the
  [`Learner`](./burn-book/src/building-blocks/learner.md) progress.
- [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web)
  : Image classification web browser demo using Burn, WGPU and WebAssembly.
- [MNIST Inference on Web](./examples/mnist-inference-web) : An interactive MNIST inference demo in
  the browser. The demo is available [online](https://burn.dev/demo/).
- [MNIST Training](./examples/mnist) : Demonstrates how to train a custom `Module` (MLP) with the
  `Learner` configured to log metrics and keep training checkpoints.
- [PyTorch Import Inference](./examples/import-model-weights) : Imports a PyTorch model pre-trained
  on MNIST to perform inference on a sample image with Burn.
- [Text Classification](./examples/text-classification) : Trains a text classification transformer
  model on the AG News or DbPedia dataset. The trained model can then be used to classify a text
  sample.
- [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the
  DbPedia dataset.
- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits
  based on MNIST.

For more practical insights, you can clone the repository and run any of them directly on your
computer!

</details>

<details>
<summary>
Pre-trained Models 🤖
</summary>
<br />

We keep an updated and curated list of models and examples built with Burn, see the
[tracel-ai/models repository](https://github.com/tracel-ai/models) for more details.

Don't see the model you want? Don't hesitate to open an issue, and we may prioritize it. Built a
model using Burn and want to share it? You can also open a Pull Request and add your model under the
community section!

</details>

<details>
<summary>
Why use Rust for Deep Learning? 🦀
</summary>
<br />

Deep Learning is a special form of software where you need very high level abstractions as well as
extremely fast execution time. Rust is the perfect candidate for that use case since it provides
zero-cost abstractions to easily create neural network modules, and fine-grained control over memory
to optimize every detail.

It's important that a framework be easy to use at a high level so that its users can focus on
innovating in the AI field. However, since running models relies so heavily on computations,
performance can't be neglected.

To this day, the mainstream solution to this problem has been to offer APIs in Python, but rely on
bindings to low-level languages such as C/C++. This reduces portability, increases complexity and
creates friction between researchers and engineers. We feel like Rust's approach to abstractions
makes it versatile enough to tackle this two-language dichotomy.

Rust also comes with the Cargo package manager, which makes it incredibly easy to build, test, and
deploy from any environment, which is usually a pain in Python.

Although Rust has the reputation of being a difficult language at first, we strongly believe it
leads to more reliable, bug-free solutions built faster (after some practice 😅)!

</details>

<br />

> **Deprecation Note**<br />Since `0.14.0`, the internal structure for tensor data has changed. The
> previous `Data` struct was deprecated and officially removed in `0.17.0` in favor of the new
> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and
> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to
> `TensorData`.

<!-- >
> In the event that you are trying to load a model record saved in a previous version, make sure to
> enable the `record-backward-compat` feature using a previous version of burn (<=0.16.0). Otherwise,
> the record won't be deserialized correctly and you will get an error message (which will also point
> you to the backward compatible feature flag). The backward compatibility was maintained for
> deserialization (loading), so as soon as you have saved the record again it will be saved according
> to the new structure and you will be able to upgrade to this version. Please note that binary formats
> are not backward compatible. Thus, you will need to load your record in a previous version and save it
> to another of the self-describing record formats before using a compatible version (as described) with the
> `record-backward-compat` feature flag. -->

<details id="deprecation">
<summary>
Loading Model Records From Previous Versions ⚠️
</summary>
<br />

In the event that you are trying to load a model record saved in a version older than `0.14.0`, make
sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat`
feature flag.

```
features = [..., "record-backward-compat"]
```

Otherwise, the record won't be deserialized correctly and you will get an error message. This error
will also point you to the backward compatible feature flag.

The backward compatibility was maintained for deserialization when loading records. Therefore, as
soon as you have saved the record again it will be saved according to the new structure and you can
upgrade back to the current version.

Please note that binary formats are not backward compatible. Thus, you will need to load your record
in a previous version and save it in any of the other self-describing record formats (e.g., using the
`NamedMpkFileRecorder`) before using a compatible version (as described) with the
`record-backward-compat` feature flag.

</details>

## Community

<div align="left">
<img align="right" src="https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-community.png" height="96px"/>

If you are excited about the project, don't hesitate to join our
[Discord](https://discord.gg/uPEBbYYDB6)! We try to be as welcoming as possible to everybody from
any background. You can ask your questions and share what you built with the community!

</div>

<br/>

**Contributing**

Before contributing, please read the [Contributing Guidelines](./CONTRIBUTING.md) and our
[Code of Conduct](./CODE-OF-CONDUCT.md). The [Contributor Book](https://burn.dev/contributor-book/)
covers architecture, environment setup, and guides for common tasks.

## Status

Burn is currently in active development, and there will be breaking changes. While any resulting
issues are likely to be easy to fix, there are no guarantees at this stage.

## License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull
request is assumed to signal agreement with these licensing terms.

</div>


================================================
FILE: _typos.toml
================================================
[default]
extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"]

[default.extend-identifiers]
UE4M3 = "UE4M3"
UE8M0 = "UE8M0"
ue8m0 = "ue8m0"

[files]
extend-exclude = [
    "*.onnx",
    "*.proto",
    "assets/ModuleSerialization.xml",
]

[default.extend-words]
# Don't correct "arange" which is intentional
arange = "arange"
# Don't correct "convnet" (convolutional network)
convnet = "convnet"


================================================
FILE: benchmarks.toml
================================================
[environment]
gcp_gpu_attached = true
gcp_image_family = "tracel-ci-ubuntu-2404-amd64-nvidia"
# https://cloud.google.com/compute/docs/accelerator-optimized-machines
# list the faster machine first for potentially quicker 'Benchmarks Started' feedback in PRs
gcp_machine_types = [
  "a2-highgpu-1g", # 1 A100 40GB (listed as a2 standard)
  "g2-standard-4", # 1 L4 24GB
]
# define the available zones for each machine type
# be sure to check what machine types are available in each region
# https://cloud.google.com/compute/docs/gpus/gpu-regions-zones#view-using-table
gcp_zones = [
  # a2-highgpu-1g
  [
  "asia-northeast1-a",
  "asia-northeast1-c",
  "asia-northeast3-b",
  "asia-southeast1-b",
  "asia-southeast1-c",
  "europe-west4-a",
  "europe-west4-b",
  "us-central1-a",
  "us-central1-b",
  "us-central1-c",
  "us-central1-f",
  "us-east1-b",
  "us-west1-b",
  "us-west3-b",
  "us-west4-b"
  ],
  # g2-standard-4
  [
  "northamerica-northeast2-a",
  "northamerica-northeast2-b",
  "us-central1-a",
  "us-central1-b",
  "us-central1-c",
  "us-east1-b",
  "us-east1-c",
  "us-east1-d",
  "us-east4-a",
  "us-east4-c",
  "us-west1-a",
  "us-west1-b",
  "us-west1-c",
  "us-west4-a",
  "us-west4-c"
  ],
]
repo_full = "tracel-ai/burn"
rust_toolchain = "stable"
rust_version = "stable"

[burn-bench]
github_organization = "tracel-ai"
github_repository = "burn-bench"
github_branch = "main"
github_workflow = "benchmarks.yml"
# vulkan autotune seems to take ages, disabling it for now
# backends = ["cuda-fusion", "vulkan-fusion", "wgpu-fusion"]
backends = ["cuda-fusion", "cuda"]
benches = ["autodiff",
  "binary",
  "bool_select",
  "conv-transpose2d",
  "conv-transpose3d",
  "conv2d",
  "conv3d",
  "custom-gelu",
  "data",
  "load-record",
  "matmul-fused",
  "matmul",
  "max-pool2d",
  "random",
  "reduce",
  "softmax",
  "transformer-encoder",
  "unary"
]
dtypes = ["f16"]


================================================
FILE: burn-book/.gitignore
================================================
target

# MacOS temp file
.DS_Store

book-test
guide/book

.vscode
tests/burn-book/book/
book/

# Ignore Jetbrains specific files.
.idea/

# Ignore Vim temporary and swap files.
*.sw?
*~

================================================
FILE: burn-book/.prettierrc.json
================================================
{
    "printWidth": 100,
    "proseWrap": "always"
}

================================================
FILE: burn-book/book.toml
================================================
[book]
authors = [
    "Wouter Doppenberg",
    "Nathaniel Simard",
    "Louis Fortier-Dubois",
    "Dilshod Tadjibaev",
    "Guillaume Lagrange",
    "Sylvain Benner",
    "Bjorn Beishline"
]
language = "en"
src = "src"
title = "The Burn Book 🔥"

[output.html]
mathjax-support = true


================================================
FILE: burn-book/src/SUMMARY.md
================================================
- [Overview](./overview.md)
- [Why Burn?](./motivation.md)
- [Getting started](./getting-started.md)
  - [Examples](./examples.md)
- [Basic Workflow: From Training to Inference](./basic-workflow/README.md)
  - [Model](./basic-workflow/model.md)
  - [Data](./basic-workflow/data.md)
  - [Training](./basic-workflow/training.md)
  - [Backend](./basic-workflow/backend.md)
  - [Inference](./basic-workflow/inference.md)
- [Building Blocks](./building-blocks/README.md)
  - [Backend](./building-blocks/backend.md)
  - [Tensor](./building-blocks/tensor.md)
  - [Autodiff](./building-blocks/autodiff.md)
  - [Module](./building-blocks/module.md)
  - [Learner](./building-blocks/learner.md)
  - [Metric](./building-blocks/metric.md)
  - [Config](./building-blocks/config.md)
  - [Record](./building-blocks/record.md)
  - [Dataset](./building-blocks/dataset.md)
- [Performance](./performance/README.md)
  - [Good practices](./performance/good-practices/README.md)
    - [Asynchronous Execution](./performance/good-practices/asynchronous-execution.md)
    - [Kernel Fusion](./performance/good-practices/kernel-fusion.md)
    - [Kernel Selection](./performance/good-practices/kernel-selection.md)
  - [Quantization](./performance/quantization.md)
  - [Distributed Computing](./performance/distributed-computing.md)
- [Custom Training Loop](./custom-training-loop.md)
- [Saving & Loading Models](./saving-and-loading.md)
- [ONNX Import](./onnx-import.md)
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
- [Advanced](./advanced/README.md)
  - [Backend Extension](./advanced/backend-extension/README.md)
    - [Custom `CubeCL` Kernel](./advanced/backend-extension/custom-cubecl-kernel.md)
    - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
  - [Custom Optimizer]()
  - [WebAssembly](./advanced/web-assembly.md)
  - [No-Std](./advanced/no-std.md)


================================================
FILE: burn-book/src/advanced/README.md
================================================
# Advanced

In this section, we will go into advanced topics that extend beyond basic usage. Given Burn's
exceptional flexibility, a lot of advanced use cases become possible.

Before going through this section, we strongly recommend exploring the
[basic workflow](../basic-workflow/) section and the
[building blocks](../building-blocks/) section. Establishing a solid understanding of how
the framework operates is crucial to comprehending the advanced concepts presented here. While you
have the freedom to explore the advanced sections in any order you prefer, it's important to note
that, unlike the preceding sections, this section is not intended to be read linearly. Instead, it
serves as a repository of use cases that you can refer to for guidance as needed.


================================================
FILE: burn-book/src/advanced/backend-extension/README.md
================================================
# Backend Extension

Burn aims to be the most flexible deep learning framework. While it's crucial to maintain
compatibility with a wide variety of backends, Burn provides the ability to extend the functionality
of a backend implementation to suit your modeling requirements. This versatility is advantageous in
numerous ways, such as supporting custom operations like flash attention or manually fusing
operations for enhanced performance.

In this section, we will go into the process of extending a backend, providing multiple examples.
But before we proceed, let's establish the fundamental principles that will empower you to craft
your own backend extensions.

As you can observe, most types in Burn are generic over the Backend trait. This might give the
impression that Burn operates at a high level over the backend layer. However, making the trait
explicit instead of being chosen via a compilation flag was a thoughtful design decision. This
explicitness does not imply that all backends must be identical; rather, it offers a great deal of
flexibility when composing backends. The autodifferentiation backend trait (see
[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has
been extended to enable gradient computation with backpropagation. Furthermore, this design allows
you to create your own backend extension. To achieve this, you need to design your own backend trait
specifying which functions should be supported.

```rust, ignore
pub trait Backend: burn::tensor::backend::Backend {
    fn my_new_function(tensor: B::TensorPrimitive<2>) -> B::TensorPrimitive<2> {
        // You can define a basic implementation reusing the Burn Backend API.
        // This can be useful since all backends will now automatically support
        // your model. But performance can be improved for this new
        // operation by implementing this block in specific backends.
    }
}
```

You can then implement your new custom backend trait for any backend that you want to support:

```rust, ignore
impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
   fn my_new_function(tensor: TchTensor<E, 2>) -> TchTensor<E, 2> {
      // My Tch implementation
   }
}

impl<E: NdArrayElement> Backend for burn_ndarray::NdArray<E> {
    // No specific implementation, but the backend can still be used.
}
```

You can support the backward pass using the same pattern.

```rust, ignore
impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {
    // No specific implementation; autodiff will work with the default
    // implementation. Useful if you still want to train your model, but
    // observe performance gains mostly during inference.
}

impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {
   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {
      // My own backward implementation, generic over my custom Backend trait.
      //
      // You can add a new method `my_new_function_backward` to your custom backend
      // trait if you want to invoke a custom kernel during the backward pass.
   }
}

impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {
   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {
      // My own backward implementation, generic over a backend implementation.
      //
      // This is another way to call a custom kernel for the backward pass that
      // doesn't require the addition of a new `backward` function in the custom backend.
      // This is useful if you don't want all backends to support training, reducing
      // the need for extra code when you know your model will only be trained on one
      // specific backend.
   }
}
```

The specifics of each implementation will be covered by the examples provided in this section. The
`cubecl` compiler frontend is the recommended method of implementing custom kernels, since it
supports multiple backends, including `wgpu` and `CUDA`, and is the way first-party `burn` kernels
are written.


================================================
FILE: burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md
================================================
# Custom CubeCL Kernel

In this section, you will learn how to create your own custom operation by writing your own kernel
with the cubecl compiler frontend. We will take the example of a common workflow in the deep
learning field, where we create a kernel to fuse multiple operations together. Note that `burn` does
this automatically, but a manual implementation might be more efficient in some cases. We will fuse
a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-cubecl-kernel).

> Note: CubeCL is in active development, so this section may be outdated.

## Custom Backend Trait

First, we need to determine the type signature of our newly created operation by defining our custom
backend traits. As we will use the float tensor primitive associated type of the `Backend` trait,
which encapsulates the underlying tensor implementation of the backend, we will use the
`FloatTensor<B>` type alias to avoid the verbose disambiguation syntax for associated types.

```rust, ignore
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
```

In our project, we can use these traits instead of the
`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs
typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.
Therefore, we can encapsulate our newly defined backend traits with functions that expose new
operations while maintaining a consistent API.

```rust, ignore
/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

```

Note that we also provide a reference implementation for testing purposes, which allows us to easily
validate our new implementation. While not mandatory, having a reference implementation can be
valuable, especially in projects where creating a reference implementation solely using basic tensor
operations is feasible.

## Forward Kernel

Now, let's proceed to write the fused kernel using the `cubecl` compiler frontend. To keep things
simple, we'll create a straightforward matmul kernel without employing any intricate techniques. We
won't delve into the details of the `cube` macro, but if you're interested in learning more, please
see the
[`cubecl` Book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book).
The actual matmul, add and relu computations are found at the end, after an extensive prelude that
serves to correctly map each compute unit to the data it is responsible for, with support for
batches.

```rust, ignore
use cubecl::{cube, prelude::*};

/// Fused `relu(lhs @ rhs + bias)` kernel where each unit computes one output element.
///
/// Unit mapping: `ABSOLUTE_POS_X` selects the output row, `ABSOLUTE_POS_Y` the column
/// and `ABSOLUTE_POS_Z` the batch. The matrices are indexed as contiguous row-major
/// data (the launcher in this guide makes every input contiguous first). `bias` is
/// read at the output index, so it is expected to have the output's shape.
#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
    lhs: &Tensor<F>,
    rhs: &Tensor<F>,
    bias: &Tensor<F>,
    output: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X;
    let col = ABSOLUTE_POS_Y;
    let batch = ABSOLUTE_POS_Z;

    let n_rows = output.shape(output.rank() - 2);
    let n_cols = output.shape(output.rank() - 1);
    // Shared (inner) dimension: lhs is [.., n_rows, k] and rhs is [.., k, n_cols].
    // rhs's *last* dimension is `n_cols`, which only equals `k` by coincidence.
    let dim_k = lhs.shape(lhs.rank() - 1);

    // The cube count is rounded up, so some units fall outside the output.
    if row >= n_rows || col >= n_cols {
        return;
    }

    // Map the flat batch offset onto each input's batch dimensions. Reducing modulo
    // the input's own size makes size-1 (broadcast) batch dims reuse the same data.
    let offset_output = batch * n_rows * n_cols;
    let mut offset_lhs = 0;
    let mut offset_rhs = 0;

    let batch_dims = output.rank() - 2;
    for dim in 0..batch_dims {
        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
    }

    // Dot product of lhs row `row` with rhs column `col`.
    let mut sum = F::new(0.0);
    for k in 0..dim_k {
        let lhs_index = row * dim_k + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let out_index = row * n_cols + col;
    let index = offset_output + out_index;

    // Add the bias, then apply ReLU.
    output[index] = F::max(sum + bias[index], F::new(0.0));
}
```

Now, let's move on to the next step, which involves implementing the remaining code to launch the
kernel. We'll go into implementing our custom backend trait for the generic JIT backend. This
automatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion.

```rust, ignore
/// Implement our custom backend trait for the generic `CubeBackend`.
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Backend
    for CubeBackend<R, F, I, BT>
{
    /// Launches `fused_matmul_add_relu_kernel` with one unit per output element.
    ///
    /// All inputs must be on the same device and are made contiguous before the
    /// launch. Each output batch dimension is the larger of the lhs/rhs sizes
    /// (broadcasting), and the last two dimensions are `[lhs rows, rhs cols]`.
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is contiguous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape[ndims - 2];
        let num_cols = rhs.shape[ndims - 1];

        // Compute shape of output, while tracking number of batches.
        // Iterate over the batch dimension *indices*; each one broadcasts to the
        // larger of the lhs/rhs sizes.
        let mut num_batches = 1;
        let mut shape_out = vec![0; ndims];
        for (i, dim_out) in shape_out.iter_mut().take(ndims - 2).enumerate() {
            *dim_out = usize::max(lhs.shape[i], rhs.shape[i]);
            num_batches *= *dim_out;
        }
        shape_out[ndims - 2] = num_rows;
        shape_out[ndims - 1] = num_cols;
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output = CubeTensor::new_contiguous(
            lhs.client.clone(),
            lhs.device.clone(),
            shape_out,
            buffer,
            F::dtype(),
        );

        // Number of cubes along x (rows), y (columns) and z (batches), rounded up
        // so that every output element is covered.
        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers. For
        // simplicity, no vectorization is performed
        fused_matmul_add_relu_kernel::launch::<F, R>(
            &lhs.client,
            cube_count,
            cube_dim,
            lhs.into_tensor_arg(),
            rhs.into_tensor_arg(),
            bias.into_tensor_arg(),
            output.clone().into_tensor_arg(),
        );

        // Return the output tensor.
        output
    }
}
```

In the preceding code block, we demonstrated how to launch the kernel that modifies the correct
buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the
capability to execute any mutable operation on any buffer. While this isn't a problem in the
previous scenario where we only modify the newly created output buffer, it is wise to keep this in
mind.

## Backward

Now that the custom backend trait is implemented for the JIT backend, you can use it to invoke the
`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.
If your use case does not extend beyond inference, there is no need to implement any of the
following code.

For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is
actually generic over the backend. Instead of crafting our own `cubecl` kernel for the backward
pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic
operations.

```rust, ignore
// Implement our custom backend trait for the autodiff decorator, for any inner backend `B`
// that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    /// Forward pass delegated to the inner backend `B`; a backward step is registered
    /// only when at least one of the inputs is tracked.
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B, the node gradient
        // with three other gradients to calculate (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // Our state that we must build during the forward pass to compute the backward pass:
            // (lhs checkpoint id, rhs checkpoint id, forward output, bias shape).
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeId, NodeId, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable, in the order they were given to `prepare`.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch the gradient for the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Unpack our state and retrieve the checkpointed lhs/rhs forward values.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

                // Fetch shapes of our tensor to support broadcasting.
                let shape_lhs = lhs.shape();
                let shape_rhs = rhs.shape();

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul with support for
                // broadcasting: grad_output @ rhs^T.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul with support for
                // broadcasting: lhs^T @ grad_output.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we must checkpoint their ids to retrieve
                // their node output at the beginning of the backward pass. We can also save
                // utility data such as the bias shape. If we also need this operation's output,
                // we can either save it in the state or recompute it during the backward pass.
                // Here we choose to save it in the state because it's a compute bound operation.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = bias.primitive.shape();

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}
```

The previous code is self-documented to make it clearer, but here is what it does in summary:

We define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to
benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be
implemented. This is achieved using a comprehensive match statement block where computation is
delegated to the inner backend, while keeping track of a state. The state comprises any information
relevant to the backward pass, such as input and output tensors, along with the bias shape. When an
operation isn't tracked (meaning there won't be a backward pass for this specific operation in the
graph), storing a state becomes unnecessary, and we simply perform the forward computation.

The backward pass uses the gradient obtained from the preceding node in the computation graph. It
calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the
derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in
gradients for both input tensors and the bias, which are registered for consumption by subsequent
operation nodes.

The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.

```rust, ignore
/// Mark the autodiff-decorated `CubeBackend` as implementing our custom
/// `AutodiffBackend` trait; that trait declares no items, so the body is empty.
impl<R, F, I, BT> AutodiffBackend for Autodiff<CubeBackend<R, F, I, BT>>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
}
```

## Conclusion

In this guide, we've implemented a fused kernel using the `cubecl` compiler frontend, enabling
execution on any GPU and any `cubecl` backend. By delving into the inner workings of both the JIT
backend and the autodiff backend, we've gained a deeper understanding of these systems.

While extending a backend may be harder than working with straightforward tensors, the benefits can
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.

As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.


================================================
FILE: burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md
================================================
# Custom WGPU Kernel

In this section, you will learn how to create your own custom operation by writing your own kernel
with the WGPU backend. We will take the example of a common workflow in the deep learning field,
where we create a kernel to fuse multiple operations together. Note that `burn` does this
automatically, but a manual implementation might be more efficient in some cases. We will fuse a
matmul kernel followed by an addition and the ReLU activation function, which is commonly found in
various models. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-wgpu-kernel).

## Custom Backend Trait

First, we need to determine the type signature of our newly created operation by defining our custom
backend traits. As we will use the float tensor primitive associated type of the `Backend` trait,
which encapsulates the underlying tensor implementation of the backend, we will use the
`FloatTensor<B>` type alias to avoid the verbose disambiguation syntax for associated types.

```rust, ignore
/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
```

In our project, we can use these traits instead of the
`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs
typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.
Therefore, we can encapsulate our newly defined backend traits with functions that expose new
operations while maintaining a consistent API.

```rust, ignore
/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

```

Note that we also provide a reference implementation for testing purposes, which allows us to easily
validate our new implementation. While not mandatory, having a reference implementation can be
valuable, especially in projects where creating a reference implementation solely using basic tensor
operations is feasible.

## Forward Kernel

Now, let's proceed to write the fused kernel using the WGSL shading language. To keep things simple,
we'll create a straightforward matmul kernel without employing any intricate techniques. Although we
won't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we
still provide the implementation below for readers who are curious. The actual matmul, add and relu
computations are found at the end, after an extensive prelude that serves to correctly map each
compute unit to the data it is responsible for, with support for batches.

```wgsl, ignore
// Fused matmul + bias add + ReLU kernel, written as a template: `{{ elem }}`,
// `{{ workgroup_size_x }}` and `{{ workgroup_size_y }}` are replaced before compilation.
//
// Layout of `info`, as implied by the indexing in `main` (`dim` is the tensor rank and
// `d` is a 1-based dimension index):
//   info[0]             rank (`dim`)
//   info[d]             lhs strides
//   info[d + dim]       rhs strides
//   info[d + 2u * dim]  output strides
//   info[d + 3u * dim]  lhs shape
//   info[d + 4u * dim]  rhs shape
//   info[d + 5u * dim]  output shape (only the last two entries are read)
@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read_write> rhs: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read_write> bias: array<{{ elem }}>;

@group(0)
@binding(3)
var<storage, read_write> output: array<{{ elem }}>;

@group(0)
@binding(4)
var<storage, read_write> info: array<u32>;

// Side length of the output tile handled by one workgroup. Row and column are both
// derived from the flat local index with this value, which assumes a square
// workgroup (workgroup_size_x == workgroup_size_y) — TODO confirm the launcher
// guarantees this.
const BLOCK_SIZE = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // Indices: one invocation per output element; z selects the batch.
    let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
    let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
    let batch = global_id.z;

    // Basic information: output rows/cols are the last two output shape entries, and
    // K is rhs's second-to-last shape entry (the shared inner dimension).
    let dim = info[0];
    let n_rows = info[6u * dim - 1u];
    let n_cols = info[6u * dim];
    let K = info[5u * dim - 1u];

    // Returns if outside the output dimension
    if row >= n_rows || col >= n_cols {
        return;
    }

    // Calculate the corresponding offsets with support for broadcasting: reducing
    // modulo each input's own shape makes size-1 batch dims reuse the same data.
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = 0u;
    var offset_rhs: u32 = 0u;

    let batch_dims = dim - 2u;
    for (var b: u32 = 1u; b <= batch_dims; b++) {
        let stride_lhs = info[b];
        let stride_rhs = info[b + dim];
        let stride_output = info[b + 2u * dim];
        let shape_lhs = info[b + 3u * dim];
        let shape_rhs = info[b + 4u * dim];

        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    // Basic matmul implementation (row-major indexing within each batch)
    var sum = 0.0;
    for (var k: u32 = 0u; k < K; k++) {
        let lhs_index = row * K + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let output_index = row * n_cols + col;
    let index = offset_output + output_index;

    // Add and ReLU (bias is read at the output index, so it has the output's layout)
    output[index] = max(sum + bias[index], 0.0);
}
```

Now, let's move on to the next step, which involves implementing the remaining code to launch the
kernel. The initial part entails loading the template and populating it with the appropriate
variables. The `register(name, value)` method simply replaces occurrences of `{{ name }}` in the
above WGSL code with `value` before it is compiled. In order to use templating utilities,
you will have to activate the `template` feature of Burn in your `Cargo.toml`.

```rust, ignore
// Source the kernel written in WGSL.
kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl");

/// Kernel type for the fused matmul + add + ReLU operation, generic over the float element
/// type and carrying the cube (workgroup) dimensions it is launched with.
#[derive(new, Debug)]
struct FusedMatmulAddRelu<E: FloatElement> {
    cube_dim: CubeDim,
    _elem: PhantomData<E>,
}

// Provide the kernel source: the raw WGSL template with its `{{ ... }}` placeholders filled in.
impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
    fn source(&self) -> SourceTemplate {
        let workgroup_x = self.cube_dim.x.to_string();
        let workgroup_y = self.cube_dim.y.to_string();

        let template = FusedMatmulAddReluRaw::new().source();
        template
            .register("workgroup_size_x", workgroup_x)
            .register("workgroup_size_y", workgroup_y)
            .register("elem", E::type_name())
            .register("int", "i32")
    }

    fn id(&self) -> cubecl::KernelId {
        // The cube dimensions are part of the id because they are baked into the source.
        let kernel_id = cubecl::KernelId::new::<Self>();
        kernel_id.info(self.cube_dim)
    }
}
```

Subsequently, we'll go into implementing our custom backend trait for the WGPU backend. Note that we
won't go into supporting the `fusion` feature flag in this tutorial, so we implement the trait for
the raw `CubeBackend<WgpuRuntime, F, I, BT>` type, which has no fusion layer.

```rust, ignore
/// Implement our custom backend trait for the existing `CubeBackend` running on the `WgpuRuntime`.
impl<F: FloatElement, I: IntElement, BT: BoolElement> Backend
    for CubeBackend<WgpuRuntime, F, I, BT>
{
    /// Launches the fused WGSL kernel computing `relu(lhs @ rhs + bias)`.
    ///
    /// The leading (batch) dimensions of `lhs` and `rhs` are broadcast against each other;
    /// `bias` is read with the output's indices, so it should already have the output's shape.
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is contiguous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape[ndims - 2];
        let num_cols = rhs.shape[ndims - 1];

        // Compute the output shape. Each leading (batch) dimension is the larger of the lhs and
        // rhs sizes to support broadcasting, and the number of batches is their product
        // (1 when there are no batch dimensions).
        let mut shape_out: Vec<usize> = (0..ndims - 2)
            .map(|i| usize::max(lhs.shape[i], rhs.shape[i]))
            .collect();
        let num_batches: usize = shape_out.iter().product();
        shape_out.push(num_rows);
        shape_out.push(num_cols);
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output = CubeTensor::new_contiguous(
            lhs.client.clone(),
            lhs.device.clone(),
            shape_out,
            buffer,
            F::dtype(),
        );

        // Create the kernel.
        let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);

        // Build info buffer with tensor information needed by the kernel, such as shapes and strides.
        let info = build_info::<_, F>(&[&lhs, &rhs, &output]);
        let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

        // Declare the wgsl workgroup with the number of cubes in x, y and z.
        // Integer ceiling division avoids the precision loss of a float round-trip.
        let cubes_needed_in_x = num_rows.div_ceil(cube_dim.x as usize) as u32;
        let cubes_needed_in_y = num_cols.div_ceil(cube_dim.y as usize) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers.
        lhs.client.execute(
            Box::new(SourceKernel::new(kernel, cube_dim)),
            cube_count,
            Bindings::new().with_buffers(vec![
                lhs.handle.binding(),
                rhs.handle.binding(),
                bias.handle.binding(),
                output.handle.clone().binding(),
                info_handle.binding(),
            ]),
        );

        // Return the output tensor.
        output
    }
}
```

In the preceding code block, we demonstrated how to launch the kernel that modifies the correct
buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the
capability to execute any mutable operation on any buffer. While this isn't a problem in the
previous scenario where we only modify the newly created output buffer, it is wise to keep this in
mind.

## Backward

Now that the custom backend trait is implemented for the WGPU backend, you can use it to invoke the
`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.
If your use case does not extend beyond inference, there is no need to implement any of the
following code.

For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is
actually generic over the backend. Instead of crafting our own WGSL kernel for the backward pass, we
will use our fused kernel only for the forward pass, and compute the gradient using basic
operations.

```rust, ignore
// Implement our custom backend trait for the autodiff decorator over any backend that also
// implements our custom backend trait.
//
// Note that we could implement the backend trait only for the Wgpu backend instead of any backend that
// also implements our own API. This would allow us to call any function only implemented for Wgpu
// and potentially call a custom kernel crafted only for this task.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B, the node gradient
        // with three other gradients to calculate (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // The state built during the forward pass and consumed by the backward pass:
            // (lhs checkpoint id, rhs checkpoint id, forward output, bias shape).
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeId, NodeId, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable, in the order given to `prepare`.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch (via `consume`) the gradient flowing into the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Unpack the state saved during the forward pass and retrieve the checkpointed inputs.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);
                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);

                // Record the input shapes; `broadcast_shape` below uses them to map each
                // gradient back to its input's shape (support for broadcasting).
                let shape_lhs = lhs.shape();
                let shape_rhs = rhs.shape();

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul
                // (grad_output · rhsᵀ) with support for broadcasting.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul
                // (lhsᵀ · grad_output) with support for broadcasting.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we checkpoint their ids so their node outputs
                // can be retrieved at the beginning of the backward pass. We also save utility data
                // such as the bias shape.
                // If this operation's output is also needed, it can either be saved in the state or
                // recomputed during the backward pass. Here we save it in the state because the
                // operation is compute bound.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = bias.primitive.shape();

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}
```

The previous code is self-documented to make it clearer, but here is what it does in summary.

We define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to
benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be
implemented. This is achieved using a comprehensive match statement block where computation is
delegated to the inner backend, while keeping track of a state. The state comprises any information
relevant to the backward pass: here, checkpoint ids for the input tensors, the output tensor, and the
bias shape. When an
operation isn't tracked (meaning there won't be a backward pass for this specific operation in the
graph), storing a state becomes unnecessary, and we simply perform the forward computation.

The backward pass uses the gradient obtained from the preceding node in the computation graph. It
calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the
derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in
gradients for both input tensors and the bias, which are registered for consumption by subsequent
operation nodes.

The only remaining part is to implement our autodiff-decorated backend trait for our WGPU Backend.

```rust, ignore
// Mark the autodiff-decorated WGPU backend as implementing our `AutodiffBackend` trait.
// No items are provided here, so the impl body is empty.
impl<F: FloatElement, I: IntElement, BT: BoolElement> AutodiffBackend
    for Autodiff<CubeBackend<WgpuRuntime, F, I, BT>>
{
}
```

## Conclusion

In this guide, we've implemented a fused kernel using the WGPU backend, enabling execution on any
GPU. By delving into the inner workings of both the WGPU backend and the autodiff backend, we've
gained a deeper understanding of these systems.

While extending a backend may be harder than working with straightforward tensors, the benefits can
be worth it. This approach enables the crafting of custom models with greater control over
execution, which can potentially greatly enhance the performance of your models.

As we conclude this guide, we hope that you have gained insights into Burn's world of backend
extensions, and that it will help you to unleash the full potential of your projects.


================================================
FILE: burn-book/src/advanced/no-std.md
================================================
# No Standard Library

In this section, you will learn how to run an ONNX inference model on an embedded system, with no
standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other
platforms. All the code can be found in the
[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico).

## Step-by-Step Guide

Let's walk through the process of running an embedded ONNX model:

### Setup
Follow the [embassy guide](https://embassy.dev/book/#_getting_started) for your specific environment. Once set up, you should have something similar to the following.
```
./inference
├── Cargo.lock
├── Cargo.toml
├── build.rs
├── memory.x
└── src
    └── main.rs
```

Then add the following dependencies:
```toml
[dependencies]
embedded-alloc = "0.6.0" # Only if there is no default allocator for your chip
burn = { version = "0.21", default-features = false, features = ["ndarray"] } # Backend must be ndarray
burn-store = { version = "0.21", default-features = false, features = ["burnpack"] }

[build-dependencies]
burn-onnx = { version = "0.21" } # Used to auto generate the rust code to import the model
```

### Import the Model
Follow the directions in [ONNX Import](../onnx-import.md).

Use the following `ModelGen` configuration in your `build.rs`:
```rs
ModelGen::new()
    .input(my_model)
    .out_dir("model/")
    .embed_states(true)
    .run_from_script();
```

### Global Allocator
First, define a global allocator (only needed if your no_std target has no default allocator).

```rs
use embedded_alloc::LlffHeap as Heap;

// Global allocator backing every heap allocation made by the program.
#[global_allocator]
static HEAP: Heap = Heap::empty();

#[embassy_executor::main]
async fn main(_spawner: Spawner) {
    // Initialize the heap at the very start of `main`, before anything allocates.
    {
        use core::mem::MaybeUninit;
        // Watch out for this: if it is too big or too small for your model, the
        // program may crash. The size is in bytes, so this is 100 KiB in total.
        const HEAP_SIZE: usize = 100 * 1024;
        static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
        unsafe { HEAP.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) } // Initialize the heap
    }
}
```

### Define Backend
We are using ndarray, so we just need to define the NdArray backend as usual
```rs
use burn::{backend::NdArray, tensor::Tensor};

// The NdArray backend with `f32` floats, as required for this no_std setup.
type Backend = NdArray<f32>;
// The device type associated with the chosen backend.
type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;
```

Then inside the `main` function add
```rs
use your_model::Model;

// Get a default device for the backend
let device = BackendDevice::default();

// Create a new model and load the state
let model: Model<Backend> = Model::default();
```

### Running the Model
To run the model, just call it as you would normally
```rs
// Define the tensor
let input = Tensor::<Backend, 2>::from_floats([[input]], &device);

// Run the model on the input
let output = model.forward(input);
```

## Conclusion
Running a model in a no_std environment is pretty much identical to a normal environment. All that is needed is a global allocator.


================================================
FILE: burn-book/src/advanced/web-assembly.md
================================================
# WebAssembly

Burn supports WebAssembly (WASM) execution using the `NdArray` and `WebGpu` backends, allowing
models to run directly in the browser.

Check out the following examples:

- [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web)
- [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web)

When targeting WebAssembly, certain dependencies require additional configuration. In particular,
the `getrandom` crate requires explicit configuration when using `WebGpu`.


================================================
FILE: burn-book/src/basic-workflow/README.md
================================================
# Guide

This guide will walk you through the process of creating a custom model built with Burn. We will
train a simple convolutional neural network model on the MNIST dataset and prepare it for inference.

For clarity, we sometimes omit imports in our code snippets. For more details, please refer to the
corresponding code in the `examples/guide` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide).
We reproduce this example in a step-by-step fashion, from dataset creation to modeling and training
in the following sections. It is recommended to use the capabilities of your IDE or text editor to
automatically add the missing imports as you add the code snippets to your code.

<div class="warning">

Be sure to check out the git branch corresponding to the version of Burn you are using to follow
this guide.

The current version of Burn is `0.21` and the corresponding branch to check out is `main`.
</div>

The code for this demo can be executed from Burn's base directory using the command:

```bash
cargo run --example guide
```

## Key Learnings

- Creating a project
- Creating neural network models
- Importing and preparing datasets
- Training models on data
- Choosing a backend
- Using a model for inference


================================================
FILE: burn-book/src/basic-workflow/backend.md
================================================
# Backend

We have effectively written most of the necessary code to train our model. However, we have not
explicitly designated the backend to be used at any point. This will be defined in the main
entrypoint of our program, namely the `main` function defined in `src/main.rs`.

```rust , ignore
# #![recursion_limit = "256"]
# mod data;
# mod model;
# mod training;
#
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
    backend::{Autodiff, Wgpu},
#     data::dataset::Dataset,
    optim::AdamConfig,
};

/// Entry point: trains the guide's model with the autodiff-enabled WGPU backend.
fn main() {
    // Base backend and the autodiff-enabled wrapper used for training.
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let artifact_dir = "/tmp/guide";
    let device = burn::backend::wgpu::WgpuDevice::default();

    // 10 digit classes, hidden size 512, default Adam optimizer.
    let model_config = ModelConfig::new(10, 512);
    let config = TrainingConfig::new(model_config, AdamConfig::new());

    crate::training::train::<MyAutodiffBackend>(artifact_dir, config, device.clone());
}
```

In this code snippet, we use the `Wgpu` backend, which is compatible with any operating system and will
use the GPU. For other options, see the Burn README. This backend type takes the float type and the int
type as generic arguments (here `Wgpu<f32, i32>`) that will be used during the training. The autodiff
backend is simply the same backend, wrapped within the `Autodiff` struct which imparts differentiability
to any backend.

We call the `train` function defined earlier with a directory for artifacts, the configuration of
the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer
configuration which in our case will be the default Adam configuration, and the device which can be
obtained from the backend.

You can now train your freshly created model with the command:

```console
cargo run --release
```

When running your project with the command above, you should see the training progression through a
basic CLI dashboard:

<img title="a title" alt="Alt text" src="./training-output.png">


================================================
FILE: burn-book/src/basic-workflow/data.md
================================================
# Data

Typically, one trains a model on some dataset. Burn provides a library of very useful dataset
sources and transformations, such as Hugging Face dataset utilities that allow you to download data and
store it in an SQLite database for extremely efficient data streaming and storage. For this guide
though, we will use the MNIST dataset from `burn::data::dataset::vision` which requires no external
dependency.

To iterate over a dataset efficiently, we will define a struct which will implement the `Batcher`
trait. The goal of a batcher is to map individual dataset items into a batched tensor that can be
used as input to our previously defined model.

Let us start by defining our dataset functionalities in a file `src/data.rs`. We shall omit some of
the imports for brevity, but the full code for following this guide can be found at
`examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide).

```rust , ignore
use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};


/// Batcher that turns MNIST items into tensor batches; it holds no state of its own.
#[derive(Clone, Default)]
pub struct MnistBatcher {}
```

This batcher is pretty straightforward, as it only defines a struct that will implement the
`Batcher` trait. The trait is generic over the `Backend` trait, which includes an associated type
for the device, as not all backends expose the same devices. As an example, the Libtorch-based
backend exposes `Cuda(gpu_index)`, `Cpu`, `Vulkan` and `Metal` devices, while the ndarray backend
only exposes the `Cpu` device.

Next, we need to actually implement the batching logic.

```rust , ignore
# use burn::{
#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
#     prelude::*,
# };
#
# #[derive(Clone, Default)]
# pub struct MnistBatcher {}
#
/// A batch of MNIST samples ready to be fed to the model.
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    /// Normalized images, shaped `[batch_size, 28, 28]`.
    pub images: Tensor<B, 3>,
    /// Digit class index of each image, shaped `[batch_size]`.
    pub targets: Tensor<B, 1, Int>,
}

// Batch MNIST items for any backend `B`: images become a `[batch_size, 28, 28]` float tensor and
// labels a `[batch_size]` int tensor, both created on `device`.
impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            // Normalize: scale between [0,1] and make the mean=0 and std=1
            // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
            .collect();

        // One single-element int tensor per label, converted to the backend's int element type.
        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)
            })
            .collect();

        // Concatenate the per-item tensors along dimension 0 to form the batch.
        let images = Tensor::cat(images, 0);
        let targets = Tensor::cat(targets, 0);

        MnistBatch { images, targets }
    }
}
```

<details>
<summary><strong>🦀 Iterators and Closures</strong></summary>

The iterator pattern allows you to perform some tasks on a sequence of items in turn.

In this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the
`iter` method.

_Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by
changing some aspect of the original iterator. Here, the `map` method is called in a chain to
transform the original data before consuming the final iterator with `collect` to obtain the
`images` and `targets` vectors. Both vectors are then concatenated into a single tensor for the
current batch.

You probably noticed that each call to `map` is different, as it defines a function to execute on
the iterator items at each step. These anonymous functions are called
[_closures_](https://doc.rust-lang.org/book/ch13-01-closures.html) in Rust. They're easy to
recognize due to their syntax which uses vertical bars `||`. The vertical bars capture the input
variables (if applicable) while the rest of the expression defines the function to execute.

If we go back to the example, we can break down and comment the expression used to process the
images.

```rust, ignore
let images = items                                                       // take items Vec<MnistItem>
    .iter()                                                              // create an iterator over it
    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())  // for each item, convert the image to float data struct
    .map(|data| Tensor::<B, 2>::from_data(data, device))                 // for each data struct, create a tensor on the device
    .map(|tensor| tensor.reshape([1, 28, 28]))                           // for each tensor, reshape to the image dimensions [C, H, W]
    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)                    // for each image tensor, apply normalization
    .collect();                                                          // consume the resulting iterator & collect the values into a new vector
```

For more information on iterators and closures, be sure to check out the
[corresponding chapter](https://doc.rust-lang.org/book/ch13-00-functional-features.html) in the Rust
Book.

</details><br>

In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a
single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with
a targets tensor that contains the indexes of the correct digit class. The first step is to parse
the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate
tensor storage information without being specific for a backend. When creating a tensor from data,
we often need to convert the data precision to the current backend in use. This can be done with the
`.convert()` method (in this example, the data is converted to the backend's float element type
`B::FloatElem`). After importing the `burn::tensor::ElementConversion` trait, you can call `.elem()`
on a specific number to convert it to the current backend element type in use.


================================================
FILE: burn-book/src/basic-workflow/inference.md
================================================
# Inference

Now that we have trained our model, the next natural step is to use it for inference.

You need two things in order to load weights for a model: the model's record and the model's config.
Since parameters in Burn are lazily initialized, no allocations are made and no GPU/CPU kernels are
executed by the `ModelConfig::init` function. The weights are initialized when used for the first time, therefore
you can safely use `config.init(device).load_record(record)` without any meaningful performance
cost. Let's create a simple `infer` method in a new file `src/inference.rs` which we will use to
load our trained model.

```rust , ignore
# use crate::{data::MnistBatcher, training::TrainingConfig};
# use burn::{
#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
#     prelude::*,
#     record::{CompactRecorder, Recorder},
# };
#
/// Loads the trained model from `artifact_dir` and prints its prediction for a single MNIST item.
pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    // Recover the configuration saved during training to rebuild the same model.
    let config_path = format!("{artifact_dir}/config.json");
    let config = TrainingConfig::load(config_path)
        .expect("Config should exist for the model; run train first");

    // Load the trained weights with the same recorder that saved them.
    let model_path = format!("{artifact_dir}/model");
    let record = CompactRecorder::new()
        .load(model_path.into(), &device)
        .expect("Trained model should exist; run train first");
    let model = config.model.init::<B>(&device).load_record(record);

    // Read the label before `item` is moved into the batch.
    let label = item.label;
    let batch = MnistBatcher::default().batch(vec![item], &device);

    // Index of the highest-scoring class, extracted as a scalar.
    let predicted = model
        .forward(batch.images)
        .argmax(1)
        .flatten::<1>(0, 1)
        .into_scalar();

    println!("Predicted {predicted} Expected {label}");
}
```

The first step is to load the configuration of the training to fetch the correct model
configuration. Then we can fetch the record using the same recorder as we used during training.
Finally, we can initialize the model with the configuration and the record. For simplicity, we can use
the same batcher used during training to convert an `MnistItem` into a tensor.

By running the infer function, you should see the predictions of your model!

Add the call to `infer` to the `main.rs` file after the `train` function call:

```rust , ignore
# #![recursion_limit = "256"]
# mod data;
# mod inference;
# mod model;
# mod training;
#
# use crate::{model::ModelConfig, training::TrainingConfig};
# use burn::{
#     backend::{Autodiff, Wgpu},
#     data::dataset::Dataset,
#     optim::AdamConfig,
# };
#
# fn main() {
#     type MyBackend = Wgpu<f32, i32>;
#     type MyAutodiffBackend = Autodiff<MyBackend>;
#
#     let device = burn::backend::wgpu::WgpuDevice::default();
#     let artifact_dir = "/tmp/guide";
#     crate::training::train::<MyAutodiffBackend>(
#         artifact_dir,
#         TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
#         device.clone(),
#     );
    // Run inference on image 42 of the MNIST test set, using the backend without autodiff.
    crate::inference::infer::<MyBackend>(
        artifact_dir,
        device,
        burn::data::dataset::vision::MnistDataset::test()
            .get(42)
            .unwrap(),
    );
# }
```

The number `42` is the index of the image in the MNIST dataset. You can explore and verify them
using this [MNIST viewer](https://observablehq.com/@davidalber/mnist-viewer).

---

In this short guide, we've introduced you to the fundamental building blocks for getting started
with Burn. While there's still plenty to explore, our goal has been to provide you with the
essential knowledge to kickstart your productivity within the framework.


================================================
FILE: burn-book/src/basic-workflow/model.md
================================================
# Model

The first step is to create a project and add the different Burn dependencies. Start by creating a
new project with Cargo:

```console
cargo new guide
```

As [mentioned previously](../getting-started.md#creating-a-burn-application), this will initialize
your `guide` project directory with a `Cargo.toml` and a `src/main.rs` file.

In the `Cargo.toml` file, add the `burn` dependency with `train`, `vision` and `wgpu` features.
Since we disable the default features, we also want to enable `std`, `tui` (for the dashboard) and
`fusion` for wgpu. Then run `cargo build` to build the project and import all the dependencies.

```toml
[package]
name = "guide"
version = "0.1.0"
edition = "2024"

[dependencies]
# Disable autotune default for convolutions
burn = { version = "~0.21", features = ["std", "tui", "train", "vision", "wgpu", "fusion"], default-features = false }
# burn = { version = "~0.21", features = ["train", "vision", "wgpu"] }
```

Our goal will be to create a basic convolutional neural network used for image classification. We
will keep the model simple by using two convolution layers followed by two linear layers, some
pooling and ReLU activations. We will also use dropout to reduce overfitting and improve
generalization.

Let us start by defining our model struct in a new file `src/model.rs`.

```rust , ignore
use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}
```

There are two major things going on in this code sample.

1. You can create a deep learning module with the `#[derive(Module)]` attribute on top of a struct.
   This will generate the necessary code so that the struct implements the `Module` trait. This
   trait will make your module both trainable and (de)serializable while adding related
   functionalities. As with other commonly derived Rust traits such as `Clone`, `PartialEq` or
   `Debug`, each field within the struct must also implement the derived trait, here `Module`.

   <details>
   <summary><strong>🦀 Trait</strong></summary>

   Traits are a powerful and flexible Rust language feature. They define functionality that a
   particular type has and can share with other types.

   A type's behavior consists of the methods called on that type. Since all `Module`s should
   implement the same functionality, it is defined as a trait. Implementing a trait on a particular
   type usually requires the user to implement the defined behaviors of the trait for their types,
   though that is not the case here as explained above with the `derive` attribute. Check out the
   [explainer below](#derive-attribute) to learn why.

   For more details on traits, take a look at the
   [associated chapter](https://doc.rust-lang.org/book/ch10-02-traits.html) in the Rust Book.
   </details><br>

   <details id="derive-attribute">
   <summary><strong>🦀 Derive Macro</strong></summary>

   The `derive` attribute allows traits to be implemented easily by generating code that will
   implement a trait with its own default implementation on the type that was annotated with the
   `derive` syntax.

   This is accomplished through a feature of Rust called
   [procedural macros](https://doc.rust-lang.org/reference/procedural-macros.html), which allow us
   to run code at compile time that operates over Rust syntax, both consuming and producing Rust
   syntax. Using the attribute `#[my_macro]`, you can effectively extend the provided code. You will
   see that the derive macro is very frequently employed to recursively implement traits, where the
   implementation consists of the composition of all fields.

   In this example, we want to derive the [`Module`](../building-blocks/module.md) and `Debug`
   traits.

   ```rust, ignore
   #[derive(Module, Debug)]
   pub struct MyCustomModule<B: Backend> {
       linear1: Linear<B>,
       linear2: Linear<B>,
       activation: Relu,
   }
   ```

   The basic `Debug` implementation is provided by the compiler to format a value using the `{:?}`
   formatter. For ease of use, the `Module` trait implementation is automatically handled by Burn so
   you don't have to do anything special. It essentially acts as a parameter container.

   For more details on derivable traits, take a look at the Rust
   [appendix](https://doc.rust-lang.org/book/appendix-03-derivable-traits.html),
   [reference](https://doc.rust-lang.org/reference/attributes/derive.html) or
   [example](https://doc.rust-lang.org/rust-by-example/trait/derive.html).
   </details><br>

2. Note that the struct is generic over the [`Backend`](../building-blocks/backend.md) trait. The
   backend trait abstracts the underlying low-level implementations of tensor operations, allowing
   your new model to run on any backend. Unlike in other frameworks, the backend abstraction isn't
   determined by a compilation flag or a device type. This is important because you can extend the
   functionalities of a specific backend (see the
   [backend extension section](../advanced/backend-extension/README.md)), and it allows for an
   innovative [autodiff system](../building-blocks/autodiff.md). You can also change backend during
   runtime, for instance to compute training metrics on a CPU backend while using a GPU one only to
   train the model. In our example, the backend in use will be determined later on.

   <details>
   <summary><strong>🦀 Trait Bounds</strong></summary>

   Trait bounds provide a way for generic items to restrict which types are used as their
   parameters. The trait bounds stipulate what functionality a type implements. Therefore, bounding
   restricts the generic to types that conform to the bounds. It also allows generic instances to
   access the methods of traits specified in the bounds.

   For a simple but concrete example, check out the
   [Rust By Example on bounds](https://doc.rust-lang.org/rust-by-example/generics/bounds.html).

   In Burn, the `Backend` trait enables you to run tensor operations using different implementations
   as it abstracts tensor, device and element types. The
   [getting started example](../getting-started.md#writing-a-code-snippet) illustrates the advantage
   of having a simple API that works for different backend implementations. While it used the WGPU
   backend, you could easily swap it with any other supported backend.

   ```rust, ignore
   // Choose from any of the supported backends.
   // type Backend = Candle<f32, i64>;
   // type Backend = LibTorch<f32>;
   // type Backend = NdArray<f32>;
   type Backend = Wgpu;

   // Creation of two tensors.
   let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
   let tensor_2 = Tensor::<Backend, 2>::ones_like(&tensor_1);

   // Print the element-wise addition (done with the selected backend) of the two tensors.
   println!("{}", tensor_1 + tensor_2);
   ```

   For more details on trait bounds, check out the Rust
   [trait bound section](https://doc.rust-lang.org/book/ch10-02-traits.html#trait-bound-syntax) or
   [reference](https://doc.rust-lang.org/reference/items/traits.html#trait-bounds).

   </details><br>

Note that each time you create a new file in the `src` directory you also need to explicitly add
this module to the `main.rs` file. For instance after creating the `model.rs`, you need to add the
following at the top of the main file:

```rust , ignore
mod model;
#
# fn main() {
# }
```

Next, we need to instantiate the model for training.

```rust , ignore
# use burn::{
#     nn::{
#         conv::{Conv2d, Conv2dConfig},
#         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
#         Dropout, DropoutConfig, Linear, LinearConfig, Relu,
#     },
#     prelude::*,
# };
#
# #[derive(Module, Debug)]
# pub struct Model<B: Backend> {
#     conv1: Conv2d<B>,
#     conv2: Conv2d<B>,
#     pool: AdaptiveAvgPool2d,
#     dropout: Dropout,
#     linear1: Linear<B>,
#     linear2: Linear<B>,
#     activation: Relu,
# }
#
#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}
```

At a glance, you can view the model configuration by printing the model instance:

```rust , ignore
#![recursion_limit = "256"]
mod model;

use crate::model::ModelConfig;
use burn::backend::Wgpu;

fn main() {
    type MyBackend = Wgpu<f32, i32>;

    let device = Default::default();
    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);

    println!("{model}");
}
```

Output:

```rust , ignore
Model {
  conv1: Conv2d {ch_in: 1, ch_out: 8, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}
  conv2: Conv2d {ch_in: 8, ch_out: 16, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}
  pool: AdaptiveAvgPool2d {output_size: [8, 8]}
  dropout: Dropout {prob: 0.5}
  linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}
  linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}
  activation: Relu
  params: 531178
}
```

<details>
<summary><strong>🦀 References</strong></summary>

In the previous example, the `init()` method signature uses `&` to indicate that the parameter types
are references: `&self`, a reference to the current receiver (`ModelConfig`), and
`device: &B::Device`, a reference to the backend device.

```rust, ignore
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
    Model {
        // ...
    }
}
```

References in Rust allow us to point to a resource to access its data without owning it. Ownership
is a core concept in Rust and is worth
[reading up on](https://doc.rust-lang.org/book/ch04-00-understanding-ownership.html).

In a language like C, memory management is explicit and up to the programmer, which means it is easy
to make mistakes. In a language like Java or Python, memory management is automatic with the help of
a garbage collector. This is very safe and straightforward, but also incurs a runtime cost.

In Rust, memory management is rather unique. Aside from simple types that implement
[`Copy`](https://doc.rust-lang.org/std/marker/trait.Copy.html) (e.g.,
[primitives](https://doc.rust-lang.org/rust-by-example/primitives.html) like integers, floats,
booleans and `char`), every value is _owned_ by some variable called the _owner_. Ownership can be
transferred from one variable to another and sometimes a value can be _borrowed_. Once the _owner_
variable goes out of scope, the value is _dropped_, which means that any memory it allocated can be
freed safely.

Because the method does not own the `self` and `device` variables, the values the references point
to will not be dropped when the reference stops being used (i.e., the scope of the method).

For more information on references and borrowing, be sure to read the
[corresponding chapter](https://doc.rust-lang.org/book/ch04-02-references-and-borrowing.html) in the
Rust Book.

</details><br>

When creating a custom neural network module, it is often a good idea to create a config alongside
the model struct. This allows you to define default values for your network, thanks to the `Config`
derive macro. The benefit of this macro is that it makes the configuration serializable, enabling
you to painlessly save your model hyperparameters, enhancing your experimentation process. Note that
a constructor will automatically be generated for your configuration, which takes as input the
parameters that do not have default values:
`let config = ModelConfig::new(num_classes, hidden_size);`. The default values can be overridden
easily with builder-like methods (e.g., `config.with_dropout(0.2)`).

The first implementation block is related to the initialization method. As we can see, all fields
are set using the configuration of the corresponding neural network's underlying module. In this
specific case, we have chosen to expand the tensor channels from 1 to 8 with the first layer, then
from 8 to 16 with the second layer, using a kernel size of 3 on all dimensions. We also use the
adaptive average pooling module to reduce the dimensionality of the images to an 8 by 8 matrix,
which we will flatten in the forward pass to have a 1024 (16 * 8 * 8) resulting tensor.
Download .txt
gitextract_82tcaglc/

├── .cargo/
│   ├── audit.toml
│   └── config.toml
├── .github/
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.md
│   │   ├── doc_request.md
│   │   └── feature_request.md
│   ├── PULL_REQUEST_TEMPLATE/
│   │   └── template.md
│   ├── dependabot.yml
│   ├── pull_request_template.md
│   └── workflows/
│       ├── combine-dependabot-prs.yml
│       ├── dependencies.yml
│       ├── publish.yml
│       ├── stale-pr.yml
│       ├── test-gpu.yml
│       ├── test.yml
│       ├── valgrind.yml
│       └── vulnerabilities.yml
├── .gitignore
├── CITATION.cff
├── CODE-OF-CONDUCT.md
├── CONTRIBUTING.md
├── Cargo.toml
├── LICENSE-APACHE
├── LICENSE-MIT
├── NOTICES.md
├── POEM.md
├── README.md
├── _typos.toml
├── benchmarks.toml
├── burn-book/
│   ├── .gitignore
│   ├── .prettierrc.json
│   ├── book.toml
│   └── src/
│       ├── SUMMARY.md
│       ├── advanced/
│       │   ├── README.md
│       │   ├── backend-extension/
│       │   │   ├── README.md
│       │   │   ├── custom-cubecl-kernel.md
│       │   │   └── custom-wgpu-kernel.md
│       │   ├── no-std.md
│       │   └── web-assembly.md
│       ├── basic-workflow/
│       │   ├── README.md
│       │   ├── backend.md
│       │   ├── data.md
│       │   ├── inference.md
│       │   ├── model.md
│       │   └── training.md
│       ├── building-blocks/
│       │   ├── README.md
│       │   ├── autodiff.md
│       │   ├── backend.md
│       │   ├── config.md
│       │   ├── dataset.md
│       │   ├── learner.md
│       │   ├── metric.md
│       │   ├── module.md
│       │   ├── record.md
│       │   └── tensor.md
│       ├── custom-training-loop.md
│       ├── distributed-computing.md
│       ├── examples.md
│       ├── getting-started.md
│       ├── models-and-pretrained-weights.md
│       ├── motivation.md
│       ├── onnx-import.md
│       ├── overview.md
│       ├── performance/
│       │   ├── README.md
│       │   ├── distributed-computing.md
│       │   ├── good-practices/
│       │   │   ├── README.md
│       │   │   ├── asynchronous-execution.md
│       │   │   ├── kernel-fusion.md
│       │   │   └── kernel-selection.md
│       │   └── quantization.md
│       └── saving-and-loading.md
├── codecov.yml
├── contributor-book/
│   ├── .gitignore
│   ├── .prettierrc.json
│   ├── book.toml
│   └── src/
│       ├── SUMMARY.md
│       ├── frequently-encountered-issues/
│       │   ├── README.md
│       │   └── issues-while-adding-ops.md
│       ├── getting-started/
│       │   ├── README.md
│       │   ├── configuring-your-editor.md
│       │   ├── setting-up-the-environment.md
│       │   └── testing.md
│       ├── guides/
│       │   ├── README.md
│       │   ├── adding-a-new-operation-to-burn.md
│       │   └── submitting-examples.md
│       ├── how-to-read-this-book.md
│       ├── overview.md
│       └── project-architecture/
│           ├── README.md
│           ├── backend.md
│           ├── module.md
│           ├── serialization.md
│           └── tensor.md
├── crates/
│   ├── burn/
│   │   ├── Cargo.toml
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── collective.rs
│   │       └── lib.rs
│   ├── burn-autodiff/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── checkpoint/
│   │       │   ├── base.rs
│   │       │   ├── builder.rs
│   │       │   ├── mod.rs
│   │       │   ├── retro_forward.rs
│   │       │   ├── state.rs
│   │       │   └── strategy.rs
│   │       ├── grads.rs
│   │       ├── graph/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── node.rs
│   │       │   ├── requirement.rs
│   │       │   └── traversal.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── backward.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── maxmin.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── sort.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── runtime/
│   │       │   ├── client.rs
│   │       │   ├── graph.rs
│   │       │   ├── memory_management.rs
│   │       │   ├── mod.rs
│   │       │   └── server.rs
│   │       ├── tensor.rs
│   │       └── utils.rs
│   ├── burn-backend/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend/
│   │       │   ├── base.rs
│   │       │   ├── device.rs
│   │       │   ├── mod.rs
│   │       │   ├── ops/
│   │       │   │   ├── activation.rs
│   │       │   │   ├── argwhere.rs
│   │       │   │   ├── bool_tensor.rs
│   │       │   │   ├── cat.rs
│   │       │   │   ├── int_tensor.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── modules/
│   │       │   │   │   ├── attention.rs
│   │       │   │   │   ├── base.rs
│   │       │   │   │   ├── conv.rs
│   │       │   │   │   ├── grid_sample.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── pool.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   ├── qtensor.rs
│   │       │   │   ├── repeat_dim.rs
│   │       │   │   ├── sort.rs
│   │       │   │   ├── tensor.rs
│   │       │   │   └── transaction.rs
│   │       │   └── primitive.rs
│   │       ├── data/
│   │       │   ├── compare.rs
│   │       │   ├── mod.rs
│   │       │   └── tensor.rs
│   │       ├── distribution.rs
│   │       ├── element/
│   │       │   ├── base.rs
│   │       │   ├── cast.rs
│   │       │   ├── mod.rs
│   │       │   └── scalar.rs
│   │       ├── lib.rs
│   │       └── tensor/
│   │           ├── alias.rs
│   │           ├── container.rs
│   │           ├── kind.rs
│   │           ├── mod.rs
│   │           ├── ops/
│   │           │   ├── autodiff.rs
│   │           │   ├── base.rs
│   │           │   ├── bool.rs
│   │           │   ├── float.rs
│   │           │   ├── int.rs
│   │           │   ├── mod.rs
│   │           │   ├── numeric.rs
│   │           │   └── ordered.rs
│   │           └── quantization/
│   │               ├── calibration.rs
│   │               ├── mod.rs
│   │               ├── parameters.rs
│   │               └── scheme.rs
│   ├── burn-backend-tests/
│   │   ├── .cargo/
│   │   │   └── config.toml
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── src/
│   │   │   └── lib.rs
│   │   └── tests/
│   │       ├── autodiff/
│   │       │   ├── abs.rs
│   │       │   ├── adaptive_avgpool1d.rs
│   │       │   ├── adaptive_avgpool2d.rs
│   │       │   ├── add.rs
│   │       │   ├── aggregation.rs
│   │       │   ├── avgpool1d.rs
│   │       │   ├── avgpool2d.rs
│   │       │   ├── backward.rs
│   │       │   ├── bridge.rs
│   │       │   ├── broadcast.rs
│   │       │   ├── cast.rs
│   │       │   ├── cat.rs
│   │       │   ├── ceil.rs
│   │       │   ├── checkpoint.rs
│   │       │   ├── complex.rs
│   │       │   ├── conv1d.rs
│   │       │   ├── conv2d.rs
│   │       │   ├── conv3d.rs
│   │       │   ├── conv_transpose1d.rs
│   │       │   ├── conv_transpose2d.rs
│   │       │   ├── conv_transpose3d.rs
│   │       │   ├── cross.rs
│   │       │   ├── cross_entropy.rs
│   │       │   ├── cummax.rs
│   │       │   ├── cummin.rs
│   │       │   ├── cumprod.rs
│   │       │   ├── cumsum.rs
│   │       │   ├── deform_conv2d.rs
│   │       │   ├── div.rs
│   │       │   ├── erf.rs
│   │       │   ├── exp.rs
│   │       │   ├── expand.rs
│   │       │   ├── flip.rs
│   │       │   ├── floor.rs
│   │       │   ├── gather_scatter.rs
│   │       │   ├── gelu.rs
│   │       │   ├── gradients.rs
│   │       │   ├── log.rs
│   │       │   ├── log1p.rs
│   │       │   ├── log_sigmoid.rs
│   │       │   ├── mask.rs
│   │       │   ├── matmul.rs
│   │       │   ├── maxmin.rs
│   │       │   ├── maxpool1d.rs
│   │       │   ├── maxpool2d.rs
│   │       │   ├── memory_management.rs
│   │       │   ├── mod.rs
│   │       │   ├── mul.rs
│   │       │   ├── multithread.rs
│   │       │   ├── nearest_interpolate.rs
│   │       │   ├── neg.rs
│   │       │   ├── nonzero.rs
│   │       │   ├── permute.rs
│   │       │   ├── pow.rs
│   │       │   ├── recip.rs
│   │       │   ├── relu.rs
│   │       │   ├── remainder.rs
│   │       │   ├── repeat_dim.rs
│   │       │   ├── reshape.rs
│   │       │   ├── round.rs
│   │       │   ├── select.rs
│   │       │   ├── sigmoid.rs
│   │       │   ├── sign.rs
│   │       │   ├── slice.rs
│   │       │   ├── slice_assign.rs
│   │       │   ├── softmax.rs
│   │       │   ├── sort.rs
│   │       │   ├── sqrt.rs
│   │       │   ├── sub.rs
│   │       │   ├── transpose.rs
│   │       │   ├── trig.rs
│   │       │   └── unfold.rs
│   │       ├── autodiff.rs
│   │       ├── common/
│   │       │   ├── autodiff.rs
│   │       │   ├── backend.rs
│   │       │   └── tensor.rs
│   │       ├── cubecl/
│   │       │   ├── avg_pool2d.rs
│   │       │   ├── bernoulli.rs
│   │       │   ├── cast.rs
│   │       │   ├── cat.rs
│   │       │   ├── clamp.rs
│   │       │   ├── contiguous.rs
│   │       │   ├── conv2d.rs
│   │       │   ├── conv3d.rs
│   │       │   ├── conv_transpose2d.rs
│   │       │   ├── conv_transpose3d.rs
│   │       │   ├── cross.rs
│   │       │   ├── gather.rs
│   │       │   ├── mask_fill.rs
│   │       │   ├── mask_where.rs
│   │       │   ├── max_pool2d.rs
│   │       │   ├── max_pool2d_backward.rs
│   │       │   ├── mod.rs
│   │       │   ├── normal.rs
│   │       │   ├── quantization.rs
│   │       │   ├── reduce.rs
│   │       │   ├── repeat_dim.rs
│   │       │   ├── scatter.rs
│   │       │   ├── select.rs
│   │       │   ├── select_assign.rs
│   │       │   ├── slice.rs
│   │       │   ├── slice_assign.rs
│   │       │   ├── unary.rs
│   │       │   └── uniform.rs
│   │       ├── cubecl.rs
│   │       ├── fused_ops/
│   │       │   ├── mod.rs
│   │       │   └── reduce_broadcasted.rs
│   │       ├── fusion.rs
│   │       ├── tensor/
│   │       │   ├── bool/
│   │       │   │   ├── mod.rs
│   │       │   │   └── ops/
│   │       │   │       ├── all.rs
│   │       │   │       ├── any.rs
│   │       │   │       ├── argwhere_nonzero.rs
│   │       │   │       ├── cat.rs
│   │       │   │       ├── comparison.rs
│   │       │   │       ├── create_like.rs
│   │       │   │       ├── expand.rs
│   │       │   │       ├── flip.rs
│   │       │   │       ├── full.rs
│   │       │   │       ├── gather_scatter.rs
│   │       │   │       ├── init.rs
│   │       │   │       ├── logical.rs
│   │       │   │       ├── mask.rs
│   │       │   │       ├── mod.rs
│   │       │   │       ├── movedim.rs
│   │       │   │       ├── permute.rs
│   │       │   │       ├── repeat.rs
│   │       │   │       ├── repeat_dim.rs
│   │       │   │       ├── reshape.rs
│   │       │   │       ├── select.rs
│   │       │   │       ├── stack.rs
│   │       │   │       ├── take.rs
│   │       │   │       ├── transpose.rs
│   │       │   │       ├── tri_mask.rs
│   │       │   │       └── unfold.rs
│   │       │   ├── clone_invariance.rs
│   │       │   ├── float/
│   │       │   │   ├── activation/
│   │       │   │   │   ├── celu.rs
│   │       │   │   │   ├── elu.rs
│   │       │   │   │   ├── gelu.rs
│   │       │   │   │   ├── glu.rs
│   │       │   │   │   ├── hard_sigmoid.rs
│   │       │   │   │   ├── leaky_relu.rs
│   │       │   │   │   ├── log_sigmoid.rs
│   │       │   │   │   ├── mish.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── prelu.rs
│   │       │   │   │   ├── quiet_softmax.rs
│   │       │   │   │   ├── relu.rs
│   │       │   │   │   ├── selu.rs
│   │       │   │   │   ├── sigmoid.rs
│   │       │   │   │   ├── silu.rs
│   │       │   │   │   ├── softmax.rs
│   │       │   │   │   ├── softmin.rs
│   │       │   │   │   ├── softplus.rs
│   │       │   │   │   ├── softsign.rs
│   │       │   │   │   ├── tanh_activation.rs
│   │       │   │   │   └── thresholded_relu.rs
│   │       │   │   ├── grid/
│   │       │   │   │   ├── affine_grid.rs
│   │       │   │   │   ├── meshgrid.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   ├── linalg/
│   │       │   │   │   ├── cosine_similarity.rs
│   │       │   │   │   ├── diag.rs
│   │       │   │   │   ├── lu_decomposition.rs
│   │       │   │   │   ├── matvec.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── outer.rs
│   │       │   │   │   ├── trace.rs
│   │       │   │   │   └── vector_norm.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── module/
│   │       │   │   │   ├── adaptive_avgpool1d.rs
│   │       │   │   │   ├── adaptive_avgpool2d.rs
│   │       │   │   │   ├── attention.rs
│   │       │   │   │   ├── avgpool1d.rs
│   │       │   │   │   ├── avgpool2d.rs
│   │       │   │   │   ├── bicubic_interpolate.rs
│   │       │   │   │   ├── bilinear_interpolate.rs
│   │       │   │   │   ├── conv1d.rs
│   │       │   │   │   ├── conv2d.rs
│   │       │   │   │   ├── conv3d.rs
│   │       │   │   │   ├── conv_transpose1d.rs
│   │       │   │   │   ├── conv_transpose2d.rs
│   │       │   │   │   ├── conv_transpose3d.rs
│   │       │   │   │   ├── deform_conv2d.rs
│   │       │   │   │   ├── forward.rs
│   │       │   │   │   ├── lanczos3_interpolate.rs
│   │       │   │   │   ├── linear.rs
│   │       │   │   │   ├── maxpool1d.rs
│   │       │   │   │   ├── maxpool2d.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── nearest_interpolate.rs
│   │       │   │   │   └── unfold4d.rs
│   │       │   │   ├── ops/
│   │       │   │   │   ├── abs.rs
│   │       │   │   │   ├── add.rs
│   │       │   │   │   ├── aggregation.rs
│   │       │   │   │   ├── all.rs
│   │       │   │   │   ├── any.rs
│   │       │   │   │   ├── arg.rs
│   │       │   │   │   ├── cast.rs
│   │       │   │   │   ├── cat.rs
│   │       │   │   │   ├── ceil.rs
│   │       │   │   │   ├── chunk.rs
│   │       │   │   │   ├── clamp.rs
│   │       │   │   │   ├── close.rs
│   │       │   │   │   ├── comparison.rs
│   │       │   │   │   ├── create_like.rs
│   │       │   │   │   ├── cross.rs
│   │       │   │   │   ├── cumulative.rs
│   │       │   │   │   ├── div.rs
│   │       │   │   │   ├── dot.rs
│   │       │   │   │   ├── erf.rs
│   │       │   │   │   ├── exp.rs
│   │       │   │   │   ├── expand.rs
│   │       │   │   │   ├── finite.rs
│   │       │   │   │   ├── flatten.rs
│   │       │   │   │   ├── flip.rs
│   │       │   │   │   ├── floor.rs
│   │       │   │   │   ├── fmod.rs
│   │       │   │   │   ├── full.rs
│   │       │   │   │   ├── gather_scatter.rs
│   │       │   │   │   ├── grid_sample.rs
│   │       │   │   │   ├── inf.rs
│   │       │   │   │   ├── init.rs
│   │       │   │   │   ├── iter_dim.rs
│   │       │   │   │   ├── log.rs
│   │       │   │   │   ├── log1p.rs
│   │       │   │   │   ├── mask.rs
│   │       │   │   │   ├── matmul.rs
│   │       │   │   │   ├── maxmin.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── movedim.rs
│   │       │   │   │   ├── mul.rs
│   │       │   │   │   ├── nan.rs
│   │       │   │   │   ├── narrow.rs
│   │       │   │   │   ├── neg.rs
│   │       │   │   │   ├── one_hot.rs
│   │       │   │   │   ├── padding.rs
│   │       │   │   │   ├── permute.rs
│   │       │   │   │   ├── powf.rs
│   │       │   │   │   ├── powf_scalar.rs
│   │       │   │   │   ├── prod.rs
│   │       │   │   │   ├── random.rs
│   │       │   │   │   ├── recip.rs
│   │       │   │   │   ├── remainder.rs
│   │       │   │   │   ├── repeat.rs
│   │       │   │   │   ├── repeat_dim.rs
│   │       │   │   │   ├── reshape.rs
│   │       │   │   │   ├── round.rs
│   │       │   │   │   ├── select.rs
│   │       │   │   │   ├── sign.rs
│   │       │   │   │   ├── slice.rs
│   │       │   │   │   ├── slice_assign.rs
│   │       │   │   │   ├── sort_argsort.rs
│   │       │   │   │   ├── split.rs
│   │       │   │   │   ├── sqrt.rs
│   │       │   │   │   ├── square.rs
│   │       │   │   │   ├── squeeze.rs
│   │       │   │   │   ├── stack.rs
│   │       │   │   │   ├── sub.rs
│   │       │   │   │   ├── take.rs
│   │       │   │   │   ├── topk.rs
│   │       │   │   │   ├── transaction.rs
│   │       │   │   │   ├── transpose.rs
│   │       │   │   │   ├── tri.rs
│   │       │   │   │   ├── trig.rs
│   │       │   │   │   ├── trunc.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   ├── primitive.rs
│   │       │   │   ├── quantization/
│   │       │   │   │   ├── calibration.rs
│   │       │   │   │   ├── data.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── ops/
│   │       │   │   │   │   ├── extended/
│   │       │   │   │   │   │   ├── abs.rs
│   │       │   │   │   │   │   ├── add.rs
│   │       │   │   │   │   │   ├── aggregation.rs
│   │       │   │   │   │   │   ├── all.rs
│   │       │   │   │   │   │   ├── any.rs
│   │       │   │   │   │   │   ├── arg.rs
│   │       │   │   │   │   │   ├── cat.rs
│   │       │   │   │   │   │   ├── ceil.rs
│   │       │   │   │   │   │   ├── chunk.rs
│   │       │   │   │   │   │   ├── clamp.rs
│   │       │   │   │   │   │   ├── cos.rs
│   │       │   │   │   │   │   ├── cosh.rs
│   │       │   │   │   │   │   ├── div.rs
│   │       │   │   │   │   │   ├── erf.rs
│   │       │   │   │   │   │   ├── exp.rs
│   │       │   │   │   │   │   ├── expand.rs
│   │       │   │   │   │   │   ├── flip.rs
│   │       │   │   │   │   │   ├── floor.rs
│   │       │   │   │   │   │   ├── gather_scatter.rs
│   │       │   │   │   │   │   ├── log.rs
│   │       │   │   │   │   │   ├── log1p.rs
│   │       │   │   │   │   │   ├── map_comparison.rs
│   │       │   │   │   │   │   ├── mask.rs
│   │       │   │   │   │   │   ├── maxmin.rs
│   │       │   │   │   │   │   ├── mod.rs
│   │       │   │   │   │   │   ├── mul.rs
│   │       │   │   │   │   │   ├── narrow.rs
│   │       │   │   │   │   │   ├── neg.rs
│   │       │   │   │   │   │   ├── permute.rs
│   │       │   │   │   │   │   ├── powf.rs
│   │       │   │   │   │   │   ├── powf_scalar.rs
│   │       │   │   │   │   │   ├── recip.rs
│   │       │   │   │   │   │   ├── remainder.rs
│   │       │   │   │   │   │   ├── repeat_dim.rs
│   │       │   │   │   │   │   ├── reshape.rs
│   │       │   │   │   │   │   ├── round.rs
│   │       │   │   │   │   │   ├── select.rs
│   │       │   │   │   │   │   ├── sin.rs
│   │       │   │   │   │   │   ├── sinh.rs
│   │       │   │   │   │   │   ├── slice.rs
│   │       │   │   │   │   │   ├── sort_argsort.rs
│   │       │   │   │   │   │   ├── split.rs
│   │       │   │   │   │   │   ├── sqrt.rs
│   │       │   │   │   │   │   ├── stack.rs
│   │       │   │   │   │   │   ├── sub.rs
│   │       │   │   │   │   │   ├── tan.rs
│   │       │   │   │   │   │   ├── tanh.rs
│   │       │   │   │   │   │   ├── topk.rs
│   │       │   │   │   │   │   └── transpose.rs
│   │       │   │   │   │   ├── matmul.rs
│   │       │   │   │   │   ├── mod.rs
│   │       │   │   │   │   └── quantize.rs
│   │       │   │   │   └── scheme.rs
│   │       │   │   └── stats/
│   │       │   │       ├── cov.rs
│   │       │   │       ├── display.rs
│   │       │   │       ├── eye.rs
│   │       │   │       ├── median.rs
│   │       │   │       ├── mod.rs
│   │       │   │       └── var.rs
│   │       │   ├── int/
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ops/
│   │       │   │   │   ├── abs.rs
│   │       │   │   │   ├── add.rs
│   │       │   │   │   ├── aggregation.rs
│   │       │   │   │   ├── all.rs
│   │       │   │   │   ├── any.rs
│   │       │   │   │   ├── arange.rs
│   │       │   │   │   ├── arange_step.rs
│   │       │   │   │   ├── arg.rs
│   │       │   │   │   ├── bitwise.rs
│   │       │   │   │   ├── cartesian_grid.rs
│   │       │   │   │   ├── cast.rs
│   │       │   │   │   ├── cat.rs
│   │       │   │   │   ├── chunk.rs
│   │       │   │   │   ├── comparison.rs
│   │       │   │   │   ├── create_like.rs
│   │       │   │   │   ├── cumulative.rs
│   │       │   │   │   ├── div.rs
│   │       │   │   │   ├── expand.rs
│   │       │   │   │   ├── flip.rs
│   │       │   │   │   ├── full.rs
│   │       │   │   │   ├── gather_scatter.rs
│   │       │   │   │   ├── init.rs
│   │       │   │   │   ├── mask.rs
│   │       │   │   │   ├── matmul.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── movedim.rs
│   │       │   │   │   ├── mul.rs
│   │       │   │   │   ├── one_hot.rs
│   │       │   │   │   ├── permute.rs
│   │       │   │   │   ├── random.rs
│   │       │   │   │   ├── remainder.rs
│   │       │   │   │   ├── repeat.rs
│   │       │   │   │   ├── repeat_dim.rs
│   │       │   │   │   ├── reshape.rs
│   │       │   │   │   ├── roll.rs
│   │       │   │   │   ├── select.rs
│   │       │   │   │   ├── sign.rs
│   │       │   │   │   ├── slice.rs
│   │       │   │   │   ├── slice_assign.rs
│   │       │   │   │   ├── sort_argsort.rs
│   │       │   │   │   ├── stack.rs
│   │       │   │   │   ├── sub.rs
│   │       │   │   │   ├── take.rs
│   │       │   │   │   ├── topk.rs
│   │       │   │   │   ├── transpose.rs
│   │       │   │   │   ├── tri.rs
│   │       │   │   │   └── unfold.rs
│   │       │   │   └── primitive.rs
│   │       │   ├── mod.rs
│   │       │   └── multi_threads.rs
│   │       └── tensor.rs
│   ├── burn-candle/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── candle_utils.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── utils.rs
│   │       └── tensor.rs
│   ├── burn-collective/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── multinode-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── README.md
│   │   │   └── src/
│   │   │       ├── bin/
│   │   │       │   ├── global.rs
│   │   │       │   ├── node.rs
│   │   │       │   └── test_launcher.rs
│   │   │       ├── lib.rs
│   │   │       └── shared.rs
│   │   └── src/
│   │       ├── api.rs
│   │       ├── config.rs
│   │       ├── global/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── node/
│   │       │   │   ├── base.rs
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ring.rs
│   │       │   │   ├── sync.rs
│   │       │   │   ├── tree.rs
│   │       │   │   └── worker.rs
│   │       │   ├── orchestrator/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── state.rs
│   │       │   └── shared.rs
│   │       ├── lib.rs
│   │       ├── local/
│   │       │   ├── all_reduce/
│   │       │   │   ├── base.rs
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   ├── ring.rs
│   │       │   │   └── tree.rs
│   │       │   ├── broadcast/
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   └── tree.rs
│   │       │   ├── client.rs
│   │       │   ├── mod.rs
│   │       │   ├── reduce/
│   │       │   │   ├── centralized.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── op.rs
│   │       │   │   └── tree.rs
│   │       │   ├── server.rs
│   │       │   └── tensor_map.rs
│   │       └── tests/
│   │           ├── all_reduce.rs
│   │           ├── broadcast.rs
│   │           ├── mod.rs
│   │           └── reduce.rs
│   ├── burn-communication/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── base.rs
│   │       ├── data_service.rs
│   │       ├── lib.rs
│   │       ├── util.rs
│   │       └── websocket/
│   │           ├── base.rs
│   │           ├── client.rs
│   │           ├── mod.rs
│   │           └── server.rs
│   ├── burn-core/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── config.rs
│   │   │   ├── data/
│   │   │   │   ├── dataloader/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── batch.rs
│   │   │   │   │   ├── batcher.rs
│   │   │   │   │   ├── builder.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── multithread.rs
│   │   │   │   │   ├── split.rs
│   │   │   │   │   └── strategy.rs
│   │   │   │   └── mod.rs
│   │   │   ├── lib.rs
│   │   │   ├── module/
│   │   │   │   ├── base.rs
│   │   │   │   ├── display.rs
│   │   │   │   ├── initializer.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── param/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── constant.rs
│   │   │   │   │   ├── id.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── primitive.rs
│   │   │   │   │   ├── running.rs
│   │   │   │   │   ├── tensor.rs
│   │   │   │   │   └── visitor.rs
│   │   │   │   ├── quantize.rs
│   │   │   │   └── reinit.rs
│   │   │   ├── record/
│   │   │   │   ├── base.rs
│   │   │   │   ├── file.rs
│   │   │   │   ├── memory.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── primitive.rs
│   │   │   │   ├── recorder.rs
│   │   │   │   ├── serde/
│   │   │   │   │   ├── adapter.rs
│   │   │   │   │   ├── data.rs
│   │   │   │   │   ├── de.rs
│   │   │   │   │   ├── error.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── ser.rs
│   │   │   │   ├── settings.rs
│   │   │   │   └── tensor.rs
│   │   │   ├── tensor.rs
│   │   │   └── vision.rs
│   │   └── tests/
│   │       ├── test_derive_config.rs
│   │       ├── test_derive_module.rs
│   │       ├── test_derive_record.rs
│   │       └── test_record_resilience.rs
│   ├── burn-cpu/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-cubecl/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── fusion.rs
│   │       ├── kernel/
│   │       │   ├── attention/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune.rs
│   │       │   ├── binary.rs
│   │       │   ├── binary_float.rs
│   │       │   ├── binary_int.rs
│   │       │   ├── cast/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bool_cast.rs
│   │       │   │   └── mod.rs
│   │       │   ├── clamp.rs
│   │       │   ├── comparison.rs
│   │       │   ├── contiguous.rs
│   │       │   ├── conv/
│   │       │   │   ├── backward_data/
│   │       │   │   │   ├── fallback.rs
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── backward_weight/
│   │       │   │   │   ├── fallback.rs
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── conv_transpose2d/
│   │       │   │   │   ├── base.rs
│   │       │   │   │   ├── col2im.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── transpose_direct.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── conv_transpose3d.rs
│   │       │   │   ├── deform_conv2d.rs
│   │       │   │   ├── deform_conv_transpose2d.rs
│   │       │   │   ├── direct.rs
│   │       │   │   ├── forward/
│   │       │   │   │   ├── implicit_gemm/
│   │       │   │   │   │   ├── launch.rs
│   │       │   │   │   │   └── mod.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── tune.rs
│   │       │   │   ├── im2col.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune_key.rs
│   │       │   ├── cross.rs
│   │       │   ├── grid_sample/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bilinear.rs
│   │       │   │   └── mod.rs
│   │       │   ├── index/
│   │       │   │   ├── flip.rs
│   │       │   │   ├── gather.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── repeat_dim.rs
│   │       │   │   ├── scatter.rs
│   │       │   │   ├── select.rs
│   │       │   │   ├── select_assign.rs
│   │       │   │   ├── slice.rs
│   │       │   │   └── slice_assign.rs
│   │       │   ├── interpolate/
│   │       │   │   ├── base.rs
│   │       │   │   ├── bicubic.rs
│   │       │   │   ├── bilinear.rs
│   │       │   │   ├── lanczos3.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── nearest.rs
│   │       │   │   └── nearest_backward.rs
│   │       │   ├── mask/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mask_fill.rs
│   │       │   │   ├── mask_where.rs
│   │       │   │   └── mod.rs
│   │       │   ├── matmul/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── tune/
│   │       │   │   │   ├── base.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   └── utils.rs
│   │       │   ├── mod.rs
│   │       │   ├── pool/
│   │       │   │   ├── adaptive_avg_pool2d.rs
│   │       │   │   ├── adaptive_avg_pool2d_backward.rs
│   │       │   │   ├── avg_pool2d.rs
│   │       │   │   ├── avg_pool2d_backward.rs
│   │       │   │   ├── max_pool2d.rs
│   │       │   │   ├── max_pool2d_backward.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── pool2d.rs
│   │       │   ├── prng/
│   │       │   │   ├── bernoulli.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── normal.rs
│   │       │   │   └── uniform.rs
│   │       │   ├── quantization/
│   │       │   │   ├── dequantize.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── quantize.rs
│   │       │   ├── reduce/
│   │       │   │   ├── base.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── tune.rs
│   │       │   ├── unary_float.rs
│   │       │   ├── unary_int.rs
│   │       │   ├── unary_numeric.rs
│   │       │   └── utils.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── numeric.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── template/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   └── source.rs
│   │       ├── tensor/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   └── quantization.rs
│   │       └── tune_key.rs
│   ├── burn-cubecl-fusion/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── base.rs
│   │       ├── engine/
│   │       │   ├── codegen/
│   │       │   │   ├── base.rs
│   │       │   │   ├── io.rs
│   │       │   │   ├── ir.rs
│   │       │   │   ├── kernel.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── tensor.rs
│   │       │   │   └── view.rs
│   │       │   ├── fuser.rs
│   │       │   ├── launch/
│   │       │   │   ├── base.rs
│   │       │   │   ├── executor.rs
│   │       │   │   ├── input.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── output.rs
│   │       │   │   ├── plan.rs
│   │       │   │   ├── runner.rs
│   │       │   │   └── vectorization/
│   │       │   │       ├── base.rs
│   │       │   │       ├── mod.rs
│   │       │   │       └── planner.rs
│   │       │   ├── mod.rs
│   │       │   ├── scoring.rs
│   │       │   ├── settings.rs
│   │       │   └── trace/
│   │       │       ├── base.rs
│   │       │       ├── block.rs
│   │       │       ├── fuser.rs
│   │       │       └── mod.rs
│   │       ├── lib.rs
│   │       ├── optim/
│   │       │   ├── base.rs
│   │       │   ├── elemwise/
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── optimization.rs
│   │       │   ├── matmul/
│   │       │   │   ├── args.rs
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── optimization.rs
│   │       │   │   └── tune.rs
│   │       │   ├── mod.rs
│   │       │   ├── reduce/
│   │       │   │   ├── args.rs
│   │       │   │   ├── fuser.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── optimization.rs
│   │       │   │   └── tune.rs
│   │       │   └── reduce_broadcasted/
│   │       │       ├── fuser/
│   │       │       │   ├── base.rs
│   │       │       │   ├── block.rs
│   │       │       │   ├── full.rs
│   │       │       │   ├── full_analyzer.rs
│   │       │       │   └── mod.rs
│   │       │       ├── launch.rs
│   │       │       ├── mod.rs
│   │       │       ├── optimization.rs
│   │       │       ├── tune.rs
│   │       │       └── unit.rs
│   │       └── tune.rs
│   ├── burn-cuda/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-dataset/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── hf_dataset.rs
│   │   │   └── speech_commands.rs
│   │   ├── src/
│   │   │   ├── audio/
│   │   │   │   ├── mod.rs
│   │   │   │   └── speech_commands.rs
│   │   │   ├── dataset/
│   │   │   │   ├── base.rs
│   │   │   │   ├── dataframe.rs
│   │   │   │   ├── fake.rs
│   │   │   │   ├── in_memory.rs
│   │   │   │   ├── iterator.rs
│   │   │   │   ├── mod.rs
│   │   │   │   └── sqlite.rs
│   │   │   ├── lib.rs
│   │   │   ├── nlp/
│   │   │   │   ├── ag_news.rs
│   │   │   │   ├── mod.rs
│   │   │   │   └── text_folder.rs
│   │   │   ├── source/
│   │   │   │   ├── huggingface/
│   │   │   │   │   ├── downloader.rs
│   │   │   │   │   ├── importer.py
│   │   │   │   │   └── mod.rs
│   │   │   │   └── mod.rs
│   │   │   ├── transform/
│   │   │   │   ├── composed.rs
│   │   │   │   ├── mapper.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── options.rs
│   │   │   │   ├── partial.rs
│   │   │   │   ├── sampler.rs
│   │   │   │   ├── selection.rs
│   │   │   │   ├── shuffle.rs
│   │   │   │   └── window.rs
│   │   │   └── vision/
│   │   │       ├── cifar.rs
│   │   │       ├── image_folder.rs
│   │   │       ├── mnist.rs
│   │   │       └── mod.rs
│   │   └── tests/
│   │       └── data/
│   │           ├── dataset-fmt.csv
│   │           ├── dataset.csv
│   │           ├── dataset.json
│   │           ├── dataset_coco.json
│   │           ├── segmask_folder/
│   │           │   └── annotations/
│   │           │       ├── mask_checkerboard.txt
│   │           │       ├── mask_random_2colors.txt
│   │           │       └── mask_random_3colors.txt
│   │           └── text_folder/
│   │               ├── negative/
│   │               │   ├── sample1.txt
│   │               │   └── sample2.txt
│   │               └── positive/
│   │                   ├── sample1.txt
│   │                   └── sample2.txt
│   ├── burn-derive/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── config/
│   │       │   ├── analyzer.rs
│   │       │   ├── analyzer_enum.rs
│   │       │   ├── analyzer_struct.rs
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── module/
│   │       │   ├── base.rs
│   │       │   ├── codegen.rs
│   │       │   ├── codegen_enum.rs
│   │       │   ├── codegen_struct.rs
│   │       │   ├── display.rs
│   │       │   ├── generics.rs
│   │       │   ├── mod.rs
│   │       │   ├── record.rs
│   │       │   ├── record_enum.rs
│   │       │   └── record_struct.rs
│   │       ├── record/
│   │       │   ├── base.rs
│   │       │   ├── codegen.rs
│   │       │   ├── item/
│   │       │   │   ├── codegen.rs
│   │       │   │   ├── codegen_enum.rs
│   │       │   │   ├── codegen_struct.rs
│   │       │   │   └── mod.rs
│   │       │   └── mod.rs
│   │       └── shared/
│   │           ├── attribute.rs
│   │           ├── enum_variant.rs
│   │           ├── field.rs
│   │           ├── generics.rs
│   │           └── mod.rs
│   ├── burn-dispatch/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── device.rs
│   │       ├── lib.rs
│   │       ├── macros.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       └── tensor.rs
│   ├── burn-fusion/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── client.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── binary.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── unary.rs
│   │       ├── search/
│   │       │   ├── block.rs
│   │       │   ├── merging.rs
│   │       │   ├── mod.rs
│   │       │   └── optimization/
│   │       │       ├── blocks.rs
│   │       │       ├── mod.rs
│   │       │       └── stream.rs
│   │       ├── server.rs
│   │       ├── stream/
│   │       │   ├── base.rs
│   │       │   ├── context.rs
│   │       │   ├── execution/
│   │       │   │   ├── base.rs
│   │       │   │   ├── explorer.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ordering.rs
│   │       │   │   ├── policy.rs
│   │       │   │   ├── processor.rs
│   │       │   │   ├── tests.rs
│   │       │   │   └── validator.rs
│   │       │   ├── memory_checks.rs
│   │       │   ├── mod.rs
│   │       │   ├── multi.rs
│   │       │   ├── queue/
│   │       │   │   ├── base.rs
│   │       │   │   ├── execution.rs
│   │       │   │   └── mod.rs
│   │       │   ├── shared_tensors.rs
│   │       │   └── store/
│   │       │       ├── base.rs
│   │       │       ├── index.rs
│   │       │       └── mod.rs
│   │       └── tensor.rs
│   ├── burn-ir/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── builder.rs
│   │       ├── handle.rs
│   │       ├── lib.rs
│   │       ├── operation.rs
│   │       ├── scalar.rs
│   │       └── tensor.rs
│   ├── burn-ndarray/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── adaptive_avgpool.rs
│   │       │   ├── avgpool.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── conv.rs
│   │       │   ├── deform_conv.rs
│   │       │   ├── grid_sample.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── interpolate.rs
│   │       │   ├── macros.rs
│   │       │   ├── matmul.rs
│   │       │   ├── maxpool.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── padding.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── quantization.rs
│   │       │   ├── simd/
│   │       │   │   ├── avgpool.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── binary.rs
│   │       │   │   ├── binary_elemwise.rs
│   │       │   │   ├── cmp.rs
│   │       │   │   ├── conv.rs
│   │       │   │   ├── maxpool.rs
│   │       │   │   ├── mod.rs
│   │       │   │   └── unary.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       ├── parallel.rs
│   │       ├── rand.rs
│   │       ├── sharing.rs
│   │       ├── storage.rs
│   │       └── tensor.rs
│   ├── burn-nn/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── activation/
│   │   │   │   ├── activation_wrapper.rs
│   │   │   │   ├── celu.rs
│   │   │   │   ├── elu.rs
│   │   │   │   ├── gelu.rs
│   │   │   │   ├── glu.rs
│   │   │   │   ├── hard_shrink.rs
│   │   │   │   ├── hard_sigmoid.rs
│   │   │   │   ├── hard_swish.rs
│   │   │   │   ├── leaky_relu.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── prelu.rs
│   │   │   │   ├── relu.rs
│   │   │   │   ├── selu.rs
│   │   │   │   ├── shrink.rs
│   │   │   │   ├── sigmoid.rs
│   │   │   │   ├── soft_shrink.rs
│   │   │   │   ├── softplus.rs
│   │   │   │   ├── softsign.rs
│   │   │   │   ├── swiglu.rs
│   │   │   │   ├── tanh.rs
│   │   │   │   └── thresholded_relu.rs
│   │   │   ├── lib.rs
│   │   │   ├── loss/
│   │   │   │   ├── binary_cross_entropy.rs
│   │   │   │   ├── cosine_embedding.rs
│   │   │   │   ├── cross_entropy.rs
│   │   │   │   ├── ctc.rs
│   │   │   │   ├── huber.rs
│   │   │   │   ├── kldiv.rs
│   │   │   │   ├── lp_loss.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── mse.rs
│   │   │   │   ├── poisson.rs
│   │   │   │   ├── pretrained/
│   │   │   │   │   ├── gram_matrix/
│   │   │   │   │   │   ├── gram_matrix_loss.rs
│   │   │   │   │   │   ├── mod.rs
│   │   │   │   │   │   ├── vgg19.rs
│   │   │   │   │   │   └── weights.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── reduction.rs
│   │   │   │   ├── rnnt.rs
│   │   │   │   └── smooth_l1.rs
│   │   │   ├── modules/
│   │   │   │   ├── attention/
│   │   │   │   │   ├── cross_attention.rs
│   │   │   │   │   ├── mask.rs
│   │   │   │   │   ├── mha.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── cache/
│   │   │   │   │   ├── autoregressive.rs
│   │   │   │   │   ├── base.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── conv/
│   │   │   │   │   ├── checks.rs
│   │   │   │   │   ├── conv1d.rs
│   │   │   │   │   ├── conv2d.rs
│   │   │   │   │   ├── conv3d.rs
│   │   │   │   │   ├── conv_transpose1d.rs
│   │   │   │   │   ├── conv_transpose2d.rs
│   │   │   │   │   ├── conv_transpose3d.rs
│   │   │   │   │   ├── deform_conv2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── dropout.rs
│   │   │   │   ├── embedding.rs
│   │   │   │   ├── interpolate/
│   │   │   │   │   ├── interpolate1d.rs
│   │   │   │   │   ├── interpolate2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── linear.rs
│   │   │   │   ├── mod.rs
│   │   │   │   ├── noise.rs
│   │   │   │   ├── norm/
│   │   │   │   │   ├── batch.rs
│   │   │   │   │   ├── group.rs
│   │   │   │   │   ├── instance.rs
│   │   │   │   │   ├── layer.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── normalization_wrapper.rs
│   │   │   │   │   └── rms.rs
│   │   │   │   ├── pool/
│   │   │   │   │   ├── adaptive_avg_pool1d.rs
│   │   │   │   │   ├── adaptive_avg_pool2d.rs
│   │   │   │   │   ├── avg_pool1d.rs
│   │   │   │   │   ├── avg_pool2d.rs
│   │   │   │   │   ├── max_pool1d.rs
│   │   │   │   │   ├── max_pool2d.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── pos_encoding.rs
│   │   │   │   ├── rnn/
│   │   │   │   │   ├── basic.rs
│   │   │   │   │   ├── gate_controller.rs
│   │   │   │   │   ├── gru.rs
│   │   │   │   │   ├── lstm.rs
│   │   │   │   │   └── mod.rs
│   │   │   │   ├── rope_encoding.rs
│   │   │   │   ├── transformer/
│   │   │   │   │   ├── decoder.rs
│   │   │   │   │   ├── encoder.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── pwff.rs
│   │   │   │   └── unfold.rs
│   │   │   └── padding.rs
│   │   └── tests/
│   │       └── quantize.rs
│   ├── burn-no-std-tests/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── burnpack.rs
│   │   │   ├── conv.rs
│   │   │   ├── lib.rs
│   │   │   ├── mlp.rs
│   │   │   ├── model.rs
│   │   │   └── safetensors.rs
│   │   └── tests/
│   │       ├── burnpack_tests.rs
│   │       ├── safetensors_tests.rs
│   │       └── test_integration.rs
│   ├── burn-optim/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── grad_clipping/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── lr_scheduler/
│   │       │   ├── base.rs
│   │       │   ├── composed.rs
│   │       │   ├── constant.rs
│   │       │   ├── cosine.rs
│   │       │   ├── exponential.rs
│   │       │   ├── linear.rs
│   │       │   ├── mod.rs
│   │       │   ├── noam.rs
│   │       │   └── step.rs
│   │       └── optim/
│   │           ├── adagrad.rs
│   │           ├── adam.rs
│   │           ├── adamw.rs
│   │           ├── base.rs
│   │           ├── decay.rs
│   │           ├── grad_accum.rs
│   │           ├── grads.rs
│   │           ├── lbfgs.rs
│   │           ├── mod.rs
│   │           ├── momentum.rs
│   │           ├── muon.rs
│   │           ├── rmsprop.rs
│   │           ├── sgd.rs
│   │           ├── simple/
│   │           │   ├── adaptor.rs
│   │           │   ├── base.rs
│   │           │   ├── mod.rs
│   │           │   └── record/
│   │           │       ├── base.rs
│   │           │       ├── mod.rs
│   │           │       └── v1.rs
│   │           └── visitor.rs
│   ├── burn-remote/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── client/
│   │       │   ├── base.rs
│   │       │   ├── channel.rs
│   │       │   ├── mod.rs
│   │       │   ├── runner.rs
│   │       │   └── worker.rs
│   │       ├── lib.rs
│   │       ├── server/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── processor.rs
│   │       │   ├── session.rs
│   │       │   └── stream.rs
│   │       └── shared/
│   │           ├── mod.rs
│   │           └── task.rs
│   ├── burn-rl/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── environment/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── policy/
│   │       │   ├── async_policy.rs
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       └── transition_buffer/
│   │           ├── base.rs
│   │           ├── mod.rs
│   │           └── slice_access.rs
│   ├── burn-rocm/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-router/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── bridge/
│   │       │   ├── base.rs
│   │       │   ├── byte.rs
│   │       │   └── mod.rs
│   │       ├── channel/
│   │       │   ├── base.rs
│   │       │   ├── direct.rs
│   │       │   └── mod.rs
│   │       ├── client/
│   │       │   ├── base.rs
│   │       │   └── mod.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── binary.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   ├── transaction.rs
│   │       │   └── unary.rs
│   │       ├── runner.rs
│   │       ├── tensor.rs
│   │       └── types.rs
│   ├── burn-std/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── id.rs
│   │       ├── lib.rs
│   │       ├── network.rs
│   │       └── tensor/
│   │           ├── dtype.rs
│   │           ├── mod.rs
│   │           ├── quantization.rs
│   │           ├── shape.rs
│   │           └── slice.rs
│   ├── burn-store/
│   │   ├── Cargo.toml
│   │   ├── MIGRATION.md
│   │   ├── README.md
│   │   ├── benches/
│   │   │   ├── download_resnet18.py
│   │   │   ├── generate_unified_models.py
│   │   │   ├── resnet18_loading.rs
│   │   │   ├── unified_loading.rs
│   │   │   ├── unified_saving.rs
│   │   │   └── zero_copy_loading.rs
│   │   ├── examples/
│   │   │   ├── burnpack_inspect.rs
│   │   │   └── half_precision.rs
│   │   ├── pytorch-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── src/
│   │   │   │   └── lib.rs
│   │   │   └── tests/
│   │   │       ├── backend.rs
│   │   │       ├── batch_norm/
│   │   │       │   ├── batch_norm2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── boolean/
│   │   │       │   ├── boolean.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── buffer/
│   │   │       │   ├── buffer.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── complex_nested/
│   │   │       │   ├── complex_nested.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── config/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── mod.rs
│   │   │       │   └── weights_with_config.pt
│   │   │       ├── conv1d/
│   │   │       │   ├── conv1d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv2d/
│   │   │       │   ├── conv2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv_transpose1d/
│   │   │       │   ├── conv_transpose1d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── conv_transpose2d/
│   │   │       │   ├── conv_transpose2d.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── debug_test.pt
│   │   │       ├── embedding/
│   │   │       │   ├── embedding.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── enum_module/
│   │   │       │   ├── enum_depthwise_false.pt
│   │   │       │   ├── enum_depthwise_true.pt
│   │   │       │   ├── export_weights.py
│   │   │       │   └── mod.rs
│   │   │       ├── group_norm/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── group_norm.pt
│   │   │       │   └── mod.rs
│   │   │       ├── integer/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── integer.pt
│   │   │       │   └── mod.rs
│   │   │       ├── key_remap/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── key_remap.pt
│   │   │       │   └── mod.rs
│   │   │       ├── key_remap_chained/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── key_remap.pt
│   │   │       │   └── mod.rs
│   │   │       ├── layer_norm/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── layer_norm.pt
│   │   │       │   └── mod.rs
│   │   │       ├── linear/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── linear.pt
│   │   │       │   ├── linear_with_bias.pt
│   │   │       │   └── mod.rs
│   │   │       ├── missing_module_field/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── missing_module_field.pt
│   │   │       │   └── mod.rs
│   │   │       ├── non_contiguous_indexes/
│   │   │       │   ├── export_weights.py
│   │   │       │   ├── mod.rs
│   │   │       │   └── non_contiguous_indexes.pt
│   │   │       ├── test_int.pt
│   │   │       ├── test_mod.rs
│   │   │       └── top_level_key/
│   │   │           ├── export_weights.py
│   │   │           ├── mod.rs
│   │   │           └── top_level_key.pt
│   │   ├── safetensors-tests/
│   │   │   ├── Cargo.toml
│   │   │   ├── src/
│   │   │   │   └── lib.rs
│   │   │   └── tests/
│   │   │       ├── backend.rs
│   │   │       ├── multi_layer/
│   │   │       │   ├── mod.rs
│   │   │       │   ├── multi_layer.py
│   │   │       │   └── multi_layer.safetensors
│   │   │       └── test_mod.rs
│   │   └── src/
│   │       ├── adapter.rs
│   │       ├── applier.rs
│   │       ├── apply_result.rs
│   │       ├── burnpack/
│   │       │   ├── base.rs
│   │       │   ├── mod.rs
│   │       │   ├── reader.rs
│   │       │   ├── store.rs
│   │       │   ├── tests/
│   │       │   │   ├── alignment.rs
│   │       │   │   ├── edge_cases.rs
│   │       │   │   ├── header.rs
│   │       │   │   ├── helpers.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── reader.rs
│   │       │   │   ├── round_trip.rs
│   │       │   │   ├── store.rs
│   │       │   │   ├── writer.rs
│   │       │   │   └── zero_copy.rs
│   │       │   └── writer.rs
│   │       ├── collector.rs
│   │       ├── filter.rs
│   │       ├── keyremapper.rs
│   │       ├── lib.rs
│   │       ├── pytorch/
│   │       │   ├── lazy_data.rs
│   │       │   ├── mod.rs
│   │       │   ├── pickle_reader.rs
│   │       │   ├── reader.rs
│   │       │   ├── store.rs
│   │       │   └── tests/
│   │       │       ├── mod.rs
│   │       │       ├── reader/
│   │       │       │   ├── create_legacy_with_offsets.py
│   │       │       │   ├── create_tar_format.py
│   │       │       │   ├── mod.rs
│   │       │       │   ├── simple_legacy.py
│   │       │       │   ├── test_data/
│   │       │       │   │   ├── bfloat16.pt
│   │       │       │   │   ├── bool.pt
│   │       │       │   │   ├── broken.pt
│   │       │       │   │   ├── buffers.pt
│   │       │       │   │   ├── checkpoint.pt
│   │       │       │   │   ├── complex_structure.pt
│   │       │       │   │   ├── empty.pt
│   │       │       │   │   ├── extreme_values.pt
│   │       │       │   │   ├── float16.pt
│   │       │       │   │   ├── float32.pt
│   │       │       │   │   ├── float64.pt
│   │       │       │   │   ├── int16.pt
│   │       │       │   │   ├── int32.pt
│   │       │       │   │   ├── int64.pt
│   │       │       │   │   ├── int8.pt
│   │       │       │   │   ├── large_shape.pt
│   │       │       │   │   ├── legacy_shared_storage.pt
│   │       │       │   │   ├── legacy_with_offsets.pt
│   │       │       │   │   ├── mixed_types.pt
│   │       │       │   │   ├── nested_dict.pt
│   │       │       │   │   ├── parameter.pt
│   │       │       │   │   ├── scalar.pt
│   │       │       │   │   ├── simple_legacy.pt
│   │       │       │   │   ├── special_values.pt
│   │       │       │   │   ├── state_dict.pt
│   │       │       │   │   ├── tensor_2d.pt
│   │       │       │   │   ├── tensor_3d.pt
│   │       │       │   │   ├── tensor_4d.pt
│   │       │       │   │   └── uint8.pt
│   │       │       │   └── test_data.py
│   │       │       └── store/
│   │       │           ├── mod.rs
│   │       │           └── test_data/
│   │       │               ├── generate_enum_test.py
│   │       │               └── model_without_enum_variants.pt
│   │       ├── safetensors/
│   │       │   ├── mod.rs
│   │       │   ├── store.rs
│   │       │   └── tests/
│   │       │       ├── adapter.rs
│   │       │       ├── direct_access.rs
│   │       │       ├── error_handling.rs
│   │       │       ├── file_io.rs
│   │       │       ├── filtering.rs
│   │       │       ├── integration.rs
│   │       │       ├── metadata.rs
│   │       │       ├── mixed_datatypes.rs
│   │       │       ├── mod.rs
│   │       │       ├── multi_layer_verify.rs
│   │       │       ├── pytorch_import.rs
│   │       │       └── round_trip.rs
│   │       ├── tensor_snapshot.rs
│   │       └── traits.rs
│   ├── burn-tch/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build.rs
│   │   └── src/
│   │       ├── backend.rs
│   │       ├── bin/
│   │       │   ├── cpu.rs
│   │       │   ├── cuda.rs
│   │       │   └── mps.rs
│   │       ├── cuda_hack/
│   │       │   ├── dummy_cuda_dependency.cpp
│   │       │   └── fake_cuda_dependency.cpp
│   │       ├── element.rs
│   │       ├── lib.rs
│   │       ├── ops/
│   │       │   ├── activation.rs
│   │       │   ├── base.rs
│   │       │   ├── bool_tensor.rs
│   │       │   ├── int_tensor.rs
│   │       │   ├── mod.rs
│   │       │   ├── module.rs
│   │       │   ├── qtensor.rs
│   │       │   ├── tensor.rs
│   │       │   └── transaction.rs
│   │       └── tensor.rs
│   ├── burn-tensor/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── device.rs
│   │       ├── lib.rs
│   │       └── tensor/
│   │           ├── activation/
│   │           │   ├── base.rs
│   │           │   └── mod.rs
│   │           ├── api/
│   │           │   ├── autodiff.rs
│   │           │   ├── base.rs
│   │           │   ├── bool.rs
│   │           │   ├── cartesian_grid.rs
│   │           │   ├── check.rs
│   │           │   ├── float.rs
│   │           │   ├── fmod.rs
│   │           │   ├── int.rs
│   │           │   ├── mod.rs
│   │           │   ├── numeric.rs
│   │           │   ├── options.rs
│   │           │   ├── orderable.rs
│   │           │   ├── pad.rs
│   │           │   ├── take.rs
│   │           │   ├── transaction.rs
│   │           │   └── trunc.rs
│   │           ├── grid/
│   │           │   ├── affine_grid.rs
│   │           │   ├── meshgrid.rs
│   │           │   └── mod.rs
│   │           ├── linalg/
│   │           │   ├── cosine_similarity.rs
│   │           │   ├── diag.rs
│   │           │   ├── lu_decomposition.rs
│   │           │   ├── matvec.rs
│   │           │   ├── mod.rs
│   │           │   ├── outer.rs
│   │           │   ├── trace.rs
│   │           │   └── vector_norm.rs
│   │           ├── loss/
│   │           │   └── mod.rs
│   │           ├── mod.rs
│   │           ├── module.rs
│   │           ├── quantization.rs
│   │           ├── report.rs
│   │           └── stats/
│   │               └── mod.rs
│   ├── burn-tensor-testgen/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       └── lib.rs
│   ├── burn-train/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── checkpoint/
│   │       │   ├── async_checkpoint.rs
│   │       │   ├── base.rs
│   │       │   ├── file.rs
│   │       │   ├── mod.rs
│   │       │   └── strategy/
│   │       │       ├── base.rs
│   │       │       ├── composed.rs
│   │       │       ├── lastn.rs
│   │       │       ├── metric.rs
│   │       │       └── mod.rs
│   │       ├── components.rs
│   │       ├── evaluator/
│   │       │   ├── base.rs
│   │       │   ├── builder.rs
│   │       │   ├── components.rs
│   │       │   └── mod.rs
│   │       ├── learner/
│   │       │   ├── application_logger.rs
│   │       │   ├── base.rs
│   │       │   ├── classification.rs
│   │       │   ├── early_stopping.rs
│   │       │   ├── mod.rs
│   │       │   ├── regression.rs
│   │       │   ├── rl/
│   │       │   │   ├── checkpointer.rs
│   │       │   │   ├── components.rs
│   │       │   │   ├── env_runner/
│   │       │   │   │   ├── async_runner.rs
│   │       │   │   │   ├── base.rs
│   │       │   │   │   └── mod.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── off_policy.rs
│   │       │   │   ├── output.rs
│   │       │   │   ├── paradigm.rs
│   │       │   │   └── strategy.rs
│   │       │   ├── sequence.rs
│   │       │   ├── summary.rs
│   │       │   ├── supervised/
│   │       │   │   ├── mod.rs
│   │       │   │   ├── paradigm.rs
│   │       │   │   ├── step/
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   └── train.rs
│   │       │   │   └── strategies/
│   │       │   │       ├── base.rs
│   │       │   │       ├── ddp/
│   │       │   │       │   ├── README.md
│   │       │   │       │   ├── epoch.rs
│   │       │   │       │   ├── mod.rs
│   │       │   │       │   ├── strategy.rs
│   │       │   │       │   └── worker.rs
│   │       │   │       ├── mod.rs
│   │       │   │       ├── multi/
│   │       │   │       │   ├── epoch.rs
│   │       │   │       │   ├── mod.rs
│   │       │   │       │   └── strategy.rs
│   │       │   │       └── single/
│   │       │   │           ├── epoch.rs
│   │       │   │           ├── mod.rs
│   │       │   │           └── strategy.rs
│   │       │   └── train_val.rs
│   │       ├── lib.rs
│   │       ├── logger/
│   │       │   ├── async_logger.rs
│   │       │   ├── base.rs
│   │       │   ├── file.rs
│   │       │   ├── in_memory.rs
│   │       │   ├── metric.rs
│   │       │   └── mod.rs
│   │       ├── metric/
│   │       │   ├── acc.rs
│   │       │   ├── auroc.rs
│   │       │   ├── base.rs
│   │       │   ├── cer.rs
│   │       │   ├── classification.rs
│   │       │   ├── confusion_stats.rs
│   │       │   ├── cpu_temp.rs
│   │       │   ├── cpu_use.rs
│   │       │   ├── cuda.rs
│   │       │   ├── fbetascore.rs
│   │       │   ├── hamming.rs
│   │       │   ├── iteration.rs
│   │       │   ├── learning_rate.rs
│   │       │   ├── loss.rs
│   │       │   ├── memory_use.rs
│   │       │   ├── mod.rs
│   │       │   ├── perplexity.rs
│   │       │   ├── precision.rs
│   │       │   ├── processor/
│   │       │   │   ├── async_wrapper.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── full.rs
│   │       │   │   ├── metrics.rs
│   │       │   │   ├── minimal.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── rl_metrics.rs
│   │       │   │   └── rl_processor.rs
│   │       │   ├── recall.rs
│   │       │   ├── rl/
│   │       │   │   ├── cum_reward.rs
│   │       │   │   ├── ep_len.rs
│   │       │   │   ├── exploration_rate.rs
│   │       │   │   └── mod.rs
│   │       │   ├── state.rs
│   │       │   ├── store/
│   │       │   │   ├── aggregate.rs
│   │       │   │   ├── base.rs
│   │       │   │   ├── client.rs
│   │       │   │   ├── log.rs
│   │       │   │   └── mod.rs
│   │       │   ├── top_k_acc.rs
│   │       │   ├── vision/
│   │       │   │   ├── dice.rs
│   │       │   │   ├── dists/
│   │       │   │   │   ├── l2pool.rs
│   │       │   │   │   ├── metric.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── vgg16_l2pool.rs
│   │       │   │   │   └── weights.rs
│   │       │   │   ├── lpips/
│   │       │   │   │   ├── alexnet.rs
│   │       │   │   │   ├── metric.rs
│   │       │   │   │   ├── mod.rs
│   │       │   │   │   ├── squeezenet.rs
│   │       │   │   │   ├── vgg.rs
│   │       │   │   │   └── weights.rs
│   │       │   │   ├── mod.rs
│   │       │   │   ├── ms_ssim.rs
│   │       │   │   ├── psnr.rs
│   │       │   │   └── ssim.rs
│   │       │   └── wer.rs
│   │       └── renderer/
│   │           ├── base.rs
│   │           ├── cli.rs
│   │           ├── mod.rs
│   │           └── tui/
│   │               ├── base.rs
│   │               ├── controls.rs
│   │               ├── full_history.rs
│   │               ├── metric_numeric.rs
│   │               ├── metric_text.rs
│   │               ├── mod.rs
│   │               ├── plot_utils.rs
│   │               ├── popup.rs
│   │               ├── progress.rs
│   │               ├── recent_history.rs
│   │               ├── renderer.rs
│   │               └── status.rs
│   ├── burn-vision/
│   │   ├── Cargo.toml
│   │   ├── src/
│   │   │   ├── backends/
│   │   │   │   ├── cpu/
│   │   │   │   │   ├── base.rs
│   │   │   │   │   ├── connected_components/
│   │   │   │   │   │   ├── spaghetti/
│   │   │   │   │   │   │   ├── Spaghetti_center_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_first_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_forest_labels.rs
│   │   │   │   │   │   │   ├── Spaghetti_last_line_forest_code.rs
│   │   │   │   │   │   │   ├── Spaghetti_single_line_forest_code.rs
│   │   │   │   │   │   │   └── mod.rs
│   │   │   │   │   │   └── spaghetti_4c/
│   │   │   │   │   │       ├── Spaghetti4C_center_line_forest_code.rs
│   │   │   │   │   │       ├── Spaghetti4C_first_line_forest_code.rs
│   │   │   │   │   │       ├── Spaghetti4C_forest_labels.rs
│   │   │   │   │   │       └── mod.rs
│   │   │   │   │   ├── connected_components.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   ├── morphology/
│   │   │   │   │   │   ├── filter.rs
│   │   │   │   │   │   ├── filter_engine.rs
│   │   │   │   │   │   └── mod.rs
│   │   │   │   │   ├── nms.rs
│   │   │   │   │   └── ops.rs
│   │   │   │   ├── cube/
│   │   │   │   │   ├── connected_components/
│   │   │   │   │   │   ├── hardware_accelerated.rs
│   │   │   │   │   │   ├── mod.rs
│   │   │   │   │   │   └── prefix_sum.rs
│   │   │   │   │   ├── mod.rs
│   │   │   │   │   └── ops.rs
│   │   │   │   └── mod.rs
│   │   │   ├── base.rs
│   │   │   ├── lib.rs
│   │   │   ├── ops/
│   │   │   │   ├── base.rs
│   │   │   │   └── mod.rs
│   │   │   ├── tensor.rs
│   │   │   ├── tests/
│   │   │   │   └── mod.rs
│   │   │   ├── transform/
│   │   │   │   ├── mod.rs
│   │   │   │   └── transform2d.rs
│   │   │   └── utils/
│   │   │       ├── mod.rs
│   │   │       └── save.rs
│   │   └── tests/
│   │       ├── common/
│   │       │   └── mod.rs
│   │       ├── connected_components.rs
│   │       ├── morphology.rs
│   │       └── nms.rs
│   └── burn-wgpu/
│       ├── Cargo.toml
│       ├── README.md
│       └── src/
│           └── lib.rs
├── deny.toml
├── docs/
│   └── katex-header.html
├── examples/
│   ├── custom-csv-dataset/
│   │   ├── .gitignore
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── custom-csv-dataset.rs
│   │   │   └── dataframe-dataset.rs
│   │   └── src/
│   │       ├── dataframe_dataset.rs
│   │       ├── dataset.rs
│   │       ├── diabetes_patient.rs
│   │       ├── lib.rs
│   │       └── utils.rs
│   ├── custom-cubecl-kernel/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-cubecl-kernel.rs
│   │   └── src/
│   │       ├── backward.rs
│   │       ├── forward.rs
│   │       ├── kernel.rs
│   │       └── lib.rs
│   ├── custom-image-dataset/
│   │   ├── .gitignore
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── custom-image-dataset.rs
│   │   └── src/
│   │       ├── data.rs
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── custom-learning-strategy/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-learning-strategy.rs
│   │   └── src/
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── custom-renderer/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-renderer.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── custom-training-loop/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-training-loop.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── custom-wgpu-kernel/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── custom-wgpu-kernel.rs
│   │   └── src/
│   │       ├── backward.rs
│   │       ├── forward.rs
│   │       ├── kernel.wgsl
│   │       └── lib.rs
│   ├── dop_timer/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   └── src/
│   │       ├── event_utils.rs
│   │       ├── main.rs
│   │       ├── parsers.rs
│   │       └── workers.rs
│   ├── dqn-agent/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── dqn-agent.rs
│   │   └── src/
│   │       ├── agent.rs
│   │       ├── env.rs
│   │       ├── lib.rs
│   │       ├── training.rs
│   │       └── utils.rs
│   ├── guide/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── guide.rs
│   │   └── src/
│   │       ├── bin/
│   │       │   ├── infer.rs
│   │       │   ├── print.rs
│   │       │   └── train.rs
│   │       ├── data.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── import-model-weights/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── src/
│   │   │   ├── bin/
│   │   │   │   ├── burnpack.rs
│   │   │   │   ├── convert.rs
│   │   │   │   ├── pytorch.rs
│   │   │   │   └── safetensors.rs
│   │   │   ├── inference.rs
│   │   │   ├── lib.rs
│   │   │   └── model.rs
│   │   └── weights/
│   │       ├── mnist.pt
│   │       ├── mnist.safetensors
│   │       └── mnist_train_export.py
│   ├── mnist/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   └── mnist.rs
│   │   └── src/
│   │       ├── data.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── mnist-inference-web/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── build-for-web.sh
│   │   ├── index.html
│   │   ├── index.js
│   │   ├── run-server.sh
│   │   └── src/
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       ├── state.rs
│   │       └── web.rs
│   ├── modern-lstm/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   ├── lstm-infer.rs
│   │   │   └── lstm-train.rs
│   │   └── src/
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── multi-gpus/
│   │   ├── Cargo.toml
│   │   ├── examples/
│   │   │   └── multi-gpus.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── notebook/
│   │   ├── README.md
│   │   ├── autodiff.ipynb
│   │   └── basic-tensor-op.ipynb
│   ├── server/
│   │   ├── Cargo.toml
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   └── server.rs
│   │   └── src/
│   │       └── lib.rs
│   ├── simple-regression/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── regression.rs
│   │   └── src/
│   │       ├── dataset.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── text-classification/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── cubecl.toml
│   │   ├── examples/
│   │   │   ├── ag-news-infer.rs
│   │   │   ├── ag-news-train.rs
│   │   │   ├── db-pedia-infer.rs
│   │   │   └── db-pedia-train.rs
│   │   └── src/
│   │       ├── data/
│   │       │   ├── batcher.rs
│   │       │   ├── dataset.rs
│   │       │   ├── mod.rs
│   │       │   └── tokenizer.rs
│   │       ├── inference.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   ├── text-generation/
│   │   ├── Cargo.toml
│   │   ├── README.md
│   │   ├── examples/
│   │   │   └── text-generation.rs
│   │   └── src/
│   │       ├── data/
│   │       │   ├── batcher.rs
│   │       │   ├── dataset.rs
│   │       │   ├── mod.rs
│   │       │   └── tokenizer.rs
│   │       ├── lib.rs
│   │       ├── model.rs
│   │       └── training.rs
│   └── wgan/
│       ├── Cargo.toml
│       ├── README.md
│       ├── examples/
│       │   ├── wgan-generate.rs
│       │   └── wgan-mnist.rs
│       └── src/
│           ├── dataset.rs
│           ├── infer.rs
│           ├── lib.rs
│           ├── model.rs
│           └── training.rs
├── rustfmt.toml
└── xtask/
    ├── Cargo.toml
    └── src/
        ├── commands/
        │   ├── books.rs
        │   ├── build.rs
        │   ├── doc.rs
        │   ├── mod.rs
        │   ├── test.rs
        │   └── validate.rs
        └── main.rs
Download .txt
Showing preview only (1,384K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (15549 symbols across 1308 files)

FILE: crates/burn-autodiff/src/backend.rs
  type Autodiff (line 18) | pub struct Autodiff<B, C = NoCheckpointing> {
  type Device (line 24) | type Device = B::Device;
  type FloatTensorPrimitive (line 26) | type FloatTensorPrimitive = AutodiffTensor<B>;
  type FloatElem (line 27) | type FloatElem = B::FloatElem;
  type IntTensorPrimitive (line 29) | type IntTensorPrimitive = B::IntTensorPrimitive;
  type IntElem (line 30) | type IntElem = B::IntElem;
  type BoolTensorPrimitive (line 32) | type BoolTensorPrimitive = B::BoolTensorPrimitive;
  type BoolElem (line 33) | type BoolElem = B::BoolElem;
  type QuantizedTensorPrimitive (line 35) | type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
  method ad_enabled (line 37) | fn ad_enabled(_device: &Self::Device) -> bool {
  method name (line 41) | fn name(device: &Self::Device) -> String {
  method seed (line 45) | fn seed(device: &B::Device, seed: u64) {
  method sync (line 49) | fn sync(device: &B::Device) -> Result<(), ExecutionError> {
  method memory_persistent_allocations (line 53) | fn memory_persistent_allocations<
  method memory_cleanup (line 65) | fn memory_cleanup(device: &Self::Device) {
  method staging (line 69) | fn staging<'a, Iter>(data: Iter, device: &Self::Device)
  method supports_dtype (line 76) | fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {
  method dtype_usage (line 80) | fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_ba...
  type InnerBackend (line 86) | type InnerBackend = B;
  type Gradients (line 87) | type Gradients = Gradients;
  method backward (line 89) | fn backward(tensor: AutodiffTensor<B>) -> Gradients {
  method grad (line 93) | fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::Floa...
  method grad_remove (line 97) | fn grad_remove(
  method inner (line 103) | fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {
  method from_inner (line 107) | fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {
  method grad_replace (line 111) | fn grad_replace(
  method int_inner (line 119) | fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {
  method bool_inner (line 123) | fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {
  method int_from_inner (line 127) | fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Se...
  method bool_from_inner (line 131) | fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor...
  method q_inner (line 135) | fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::Inner...
  method q_from_inner (line 139) | fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> Quantize...

FILE: crates/burn-autodiff/src/checkpoint/base.rs
  type NodeTree (line 12) | pub(crate) struct NodeTree {
    method parents (line 18) | pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {
  type Checkpointer (line 25) | pub struct Checkpointer {
    method retrieve_node_output (line 34) | pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T
    method topological_sort (line 52) | fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {
    method is_empty (line 79) | pub fn is_empty(&self) -> bool {

FILE: crates/burn-autodiff/src/checkpoint/builder.rs
  type CheckpointingAction (line 19) | pub enum CheckpointingAction {
    method id (line 41) | pub fn id(&self) -> NodeId {
  type CheckpointerBuilder (line 58) | pub struct CheckpointerBuilder {
    method checkpoint (line 77) | pub(crate) fn checkpoint<B: Backend>(
    method extend (line 102) | pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {
    method build (line 111) | pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
    method find_stop_nodes (line 135) | fn find_stop_nodes(&self) -> Vec<NodeId> {
    method build_n_required_map (line 156) | fn build_n_required_map(
    method insert_checkpoints (line 197) | fn insert_checkpoints(
    method update_n_required_of_parents (line 249) | fn update_n_required_of_parents(
    method checkpoint_compute (line 277) | fn checkpoint_compute(
    method checkpoint_lazy (line 293) | fn checkpoint_lazy(
  type ActionType (line 66) | pub(crate) enum ActionType {

FILE: crates/burn-autodiff/src/checkpoint/retro_forward.rs
  type RetroForward (line 12) | pub trait RetroForward: Debug + Send + 'static {
    method forward (line 14) | fn forward(&self, states: &mut BackwardStates, out_node: NodeId);
  type RetroForwards (line 19) | pub(crate) struct RetroForwards {
    method execute_retro_forward (line 26) | pub(crate) fn execute_retro_forward(
    method is_empty (line 41) | pub(crate) fn is_empty(&self) -> bool {

FILE: crates/burn-autodiff/src/checkpoint/state.rs
  type StateContent (line 8) | pub(crate) type StateContent = Box<dyn Any + Send>;
  type State (line 15) | pub(crate) enum State {
    method to_state_content (line 27) | pub(crate) fn to_state_content(&self) -> &StateContent {
    method into_state_content (line 42) | pub(crate) fn into_state_content(self) -> StateContent {
    method n_required (line 57) | pub(crate) fn n_required(&self) -> usize {
  type BackwardStates (line 70) | pub struct BackwardStates {
    method get_state (line 78) | pub fn get_state<T>(&mut self, node_id: &NodeId) -> T
    method get_state_ref (line 117) | pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {
    method insert_state (line 122) | pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {
    method save (line 127) | pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)
    method is_empty (line 141) | pub(crate) fn is_empty(&self) -> bool {

FILE: crates/burn-autodiff/src/checkpoint/strategy.rs
  type CheckpointStrategy (line 14) | pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sy...
    method compute_property (line 16) | fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingPro...
    method checkpoint_parents (line 19) | fn checkpoint_parents<'a, B2, A>(
    method compute_property (line 42) | fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingPr...
    method checkpoint_parents (line 48) | fn checkpoint_parents<'a, B2, A>(
    method compute_property (line 68) | fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingPro...
    method checkpoint_parents (line 77) | fn checkpoint_parents<'a, B2, A>(
  type CheckpointingError (line 30) | pub enum CheckpointingError {
  type NoCheckpointing (line 38) | pub struct NoCheckpointing {}
  type BalancedCheckpointing (line 63) | pub struct BalancedCheckpointing {}

FILE: crates/burn-autodiff/src/grads.rs
  type GradID (line 13) | pub type GradID = u64;
  type Gradients (line 16) | pub struct Gradients {
    method new (line 22) | pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>...
    method consume (line 41) | pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {
    method remove (line 58) | pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Op...
    method get (line 65) | pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<Fl...
    method register (line 74) | pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTe...

FILE: crates/burn-autodiff/src/graph/base.rs
  type Step (line 6) | pub trait Step: Send + core::fmt::Debug {
    method step (line 8) | fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Che...
    method depth (line 10) | fn depth(&self) -> usize;
    method node (line 12) | fn node(&self) -> NodeId;
    method parents (line 14) | fn parents(&self) -> &[Parent];
  type StepBoxed (line 17) | pub type StepBoxed = Box<dyn Step>;

FILE: crates/burn-autodiff/src/graph/node.rs
  type ComputingProperty (line 14) | pub enum ComputingProperty {
  type Node (line 33) | pub struct Node {
    method clone_if_require_grad (line 50) | pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {
  type NodeRef (line 41) | pub type NodeRef = Arc<Node>;
  type Parent (line 44) | pub struct Parent {
  type NodeId (line 60) | pub struct NodeId {
    method fmt (line 66) | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    method new (line 73) | pub fn new() -> Self {
  method default (line 84) | fn default() -> Self {

FILE: crates/burn-autodiff/src/graph/requirement.rs
  type Requirement (line 5) | pub enum Requirement {
    method is_none (line 16) | pub fn is_none(&self) -> bool {
    method from_nodes (line 20) | pub fn from_nodes(nodes: &[NodeRef]) -> Self {
    method infer (line 32) | fn infer(&self, other: &Self) -> Self {

FILE: crates/burn-autodiff/src/graph/traversal.rs
  type BreadthFirstSearch (line 10) | pub struct BreadthFirstSearch;
    method traverse (line 22) | pub fn traverse<F, I>(
  type TraversalItem (line 12) | pub trait TraversalItem {
    method id (line 13) | fn id(&self) -> NodeId;
    method parents (line 14) | fn parents(&self) -> &[Parent];
    method parent_nodes (line 15) | fn parent_nodes(&self) -> Vec<NodeId> {
    method id (line 67) | fn id(&self) -> NodeId {
    method parents (line 71) | fn parents(&self) -> &[Parent] {

FILE: crates/burn-autodiff/src/ops/activation.rs
  function gelu (line 17) | fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function relu (line 55) | fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function sigmoid (line 92) | fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function log_sigmoid (line 130) | fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {

FILE: crates/burn-autodiff/src/ops/backward.rs
  type Backward (line 17) | pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
    method backward (line 26) | fn backward(
    method prepare (line 34) | fn prepare<C: CheckpointStrategy>(
  function binary (line 50) | pub fn binary<B, FLhs, FRhs>(
  function unary (line 76) | pub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: ...

FILE: crates/burn-autodiff/src/ops/base.rs
  type OpsPrep (line 22) | pub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {
  type Init (line 35) | pub struct Init;
  type MemoryBound (line 37) | pub struct MemoryBound;
  type MemoryBoundRetroForward (line 39) | pub struct MemoryBoundRetroForward;
  type ComputePropertyDone (line 41) | pub struct ComputePropertyDone;
  type Tracked (line 43) | pub struct Tracked;
  type UnTracked (line 45) | pub struct UnTracked;
  function compute_bound (line 54) | pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDon...
  function memory_bound (line 66) | pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {
  function retro_forward (line 84) | pub fn retro_forward<R: RetroForward>(
  function parents (line 105) | pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, ...
  function stateless (line 132) | pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
  function stateful (line 147) | pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {
  function finish (line 174) | pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {
  function finish (line 197) | pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<...
  function checkpoint (line 211) | pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {
  type OpsKind (line 220) | pub enum OpsKind<BO, B, S, C, const N: usize> {
  type Ops (line 229) | pub struct Ops<S, const N: usize> {
  type OpsStep (line 240) | struct OpsStep<B, T, SB, const N: usize>
  method step (line 257) | fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Check...
  method node (line 261) | fn node(&self) -> NodeId {
  method parents (line 265) | fn parents(&self) -> &[Parent] {
  method depth (line 269) | fn depth(&self) -> usize {
  type UntrackedOpsStep (line 275) | struct UntrackedOpsStep<const N: usize> {
  method step (line 280) | fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Che...
  method node (line 284) | fn node(&self) -> NodeId {
  method parents (line 288) | fn parents(&self) -> &[Parent] {
  method depth (line 291) | fn depth(&self) -> usize {
  function broadcast_shape (line 300) | pub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Sha...

FILE: crates/burn-autodiff/src/ops/bool_tensor.rs
  function bool_from_data (line 12) | fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {
  function bool_into_data (line 16) | async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, Exe...
  function bool_into_int (line 20) | fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {
  function bool_to_device (line 24) | fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTens...
  function bool_device (line 28) | fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {
  function bool_reshape (line 32) | fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
  function bool_slice (line 36) | fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> Bool...
  function bool_empty (line 40) | fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
  function bool_zeros (line 44) | fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
  function bool_ones (line 48) | fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {
  function bool_slice_assign (line 52) | fn bool_slice_assign(
  function bool_cat (line 60) | fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
  function bool_equal (line 64) | fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
  function bool_not (line 68) | fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {
  function bool_and (line 72) | fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
  function bool_or (line 76) | fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
  function bool_xor (line 80) | fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
  function bool_into_float (line 84) | fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::F...
  function bool_swap_dims (line 88) | fn bool_swap_dims(
  function bool_permute (line 96) | fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<...
  function bool_flip (line 100) | fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {
  function bool_argwhere (line 104) | async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {
  function bool_expand (line 108) | fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {
  function bool_repeat_dim (line 112) | fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> B...
  function bool_unfold (line 116) | fn bool_unfold(
  function bool_mask_where (line 125) | fn bool_mask_where(
  function bool_mask_fill (line 133) | fn bool_mask_fill(
  function bool_gather (line 141) | fn bool_gather(
  function bool_scatter_or (line 149) | fn bool_scatter_or(
  function bool_equal_elem (line 158) | fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Sel...

FILE: crates/burn-autodiff/src/ops/int_tensor.rs
  function int_from_data (line 12) | fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
  function int_into_data (line 16) | async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, Execu...
  function int_to_device (line 20) | fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTens...
  function int_device (line 24) | fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
  function int_reshape (line 28) | fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
  function int_slice (line 32) | fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTen...
  function int_empty (line 36) | fn int_empty(
  function int_slice_assign (line 44) | fn int_slice_assign(
  function int_cat (line 52) | fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
  function int_equal (line 56) | fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
  function int_equal_elem (line 60) | fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
  function int_add (line 64) | fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_add_scalar (line 68) | fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
  function int_clamp_min (line 72) | fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
  function int_clamp_max (line 76) | fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
  function int_clamp (line 80) | fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTenso...
  function int_sub (line 84) | fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_sub_scalar (line 88) | fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
  function int_mul (line 92) | fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_mul_scalar (line 96) | fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
  function int_div (line 100) | fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_div_scalar (line 104) | fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
  function int_remainder (line 108) | fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_remainder_scalar (line 112) | fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
  function int_matmul (line 116) | fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
  function int_neg (line 120) | fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
  function int_zeros (line 124) | fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> In...
  function int_ones (line 128) | fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> Int...
  function int_full (line 132) | fn int_full(
  function int_sum (line 141) | fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
  function int_sum_dim (line 145) | fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_mean (line 149) | fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
  function int_mean_dim (line 153) | fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_cumsum (line 157) | fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_cumprod (line 161) | fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_cummin (line 165) | fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_cummax (line 169) | fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_repeat_dim (line 173) | fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> Int...
  function int_greater (line 177) | fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
  function int_greater_elem (line 181) | fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
  function int_greater_equal (line 185) | fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor...
  function int_greater_equal_elem (line 189) | fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<...
  function int_lower (line 193) | fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
  function int_lower_elem (line 197) | fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
  function int_lower_equal (line 201) | fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
  function int_lower_equal_elem (line 205) | fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {
  function int_gather (line 209) | fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -...
  function int_scatter_add (line 213) | fn int_scatter_add(
  function int_select (line 222) | fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -...
  function int_select_add (line 226) | fn int_select_add(
  function int_mask_where (line 235) | fn int_mask_where(
  function int_mask_fill (line 243) | fn int_mask_fill(
  function int_argmax (line 251) | fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_argmin (line 254) | fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
  function int_max (line 257) | fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
  function int_max_dim (line 260) | fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTenso...
  function int_max_dim_with_indices (line 263) | fn int_max_dim_with_indices(
  function int_min (line 269) | fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
  function int_min_dim (line 272) | fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTenso...
  function int_min_dim_with_indices (line 275) | fn int_min_dim_with_indices(
  function int_abs (line 281) | fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
  function int_into_float (line 284) | fn int_into_float(
  function int_swap_dims (line 290) | fn int_swap_dims(
  function int_random (line 298) | fn int_random(
  function int_arange (line 306) | fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> In...
  function int_permute (line 310) | fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Sel...
  function int_flip (line 314) | fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
  function int_sign (line 318) | fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
  function int_prod (line 322) | fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
  function int_prod_dim (line 326) | fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
  function int_expand (line 330) | fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
  function int_sort (line 334) | fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> In...
  function int_sort_with_indices (line 338) | fn int_sort_with_indices(
  function int_argsort (line 346) | fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) ->...
  function bitwise_and (line 350) | fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<...
  function bitwise_and_scalar (line 354) | fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Se...
  function bitwise_or (line 358) | fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<S...
  function bitwise_or_scalar (line 362) | fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Sel...
  function bitwise_xor (line 366) | fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<...
  function bitwise_xor_scalar (line 370) | fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Se...
  function bitwise_not (line 374) | fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
  function bitwise_left_shift (line 378) | fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> Int...
  function bitwise_left_shift_scalar (line 382) | fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTe...
  function bitwise_right_shift (line 386) | fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> In...
  function bitwise_right_shift_scalar (line 390) | fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntT...
  function int_cast (line 394) | fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
  function int_unfold (line 398) | fn int_unfold(

FILE: crates/burn-autodiff/src/ops/maxmin.rs
  type MaxMinDim (line 7) | pub(crate) struct MaxMinDim;
    type State (line 10) | type State = (B::IntTensorPrimitive, Shape, usize);
    method backward (line 12) | fn backward(

FILE: crates/burn-autodiff/src/ops/module.rs
  function embedding (line 17) | fn embedding(weights: AutodiffTensor<B>, indices: IntTensor<B>) -> Autod...
  function embedding_backward (line 51) | fn embedding_backward(
  function conv1d (line 59) | fn conv1d(
  function conv_transpose1d (line 182) | fn conv_transpose1d(
  function conv2d (line 317) | fn conv2d(
  function deform_conv2d (line 443) | fn deform_conv2d(
  function deform_conv2d_backward (line 768) | fn deform_conv2d_backward(
  function conv_transpose2d (line 780) | fn conv_transpose2d(
  function conv3d (line 917) | fn conv3d(
  function conv_transpose3d (line 1043) | fn conv_transpose3d(
  function avg_pool1d (line 1195) | fn avg_pool1d(
  function avg_pool2d (line 1273) | fn avg_pool2d(
  function avg_pool2d_backward (line 1351) | fn avg_pool2d_backward(
  function max_pool1d (line 1363) | fn max_pool1d(
  function max_pool1d_with_indices (line 1410) | fn max_pool1d_with_indices(
  function max_pool1d_with_indices_backward (line 1465) | fn max_pool1d_with_indices_backward(
  function max_pool2d (line 1488) | fn max_pool2d(
  function max_pool2d_with_indices (line 1535) | fn max_pool2d_with_indices(
  function max_pool2d_with_indices_backward (line 1591) | fn max_pool2d_with_indices_backward(
  function adaptive_avg_pool1d (line 1603) | fn adaptive_avg_pool1d(x: AutodiffTensor<B>, output_size: usize) -> Auto...
  function adaptive_avg_pool2d (line 1642) | fn adaptive_avg_pool2d(x: AutodiffTensor<B>, output_size: [usize; 2]) ->...
  function adaptive_avg_pool2d_backward (line 1681) | fn adaptive_avg_pool2d_backward(
  function interpolate (line 1688) | fn interpolate(
  function interpolate_backward (line 1733) | fn interpolate_backward(
  function attention (line 1742) | fn attention(
  type MaxPool1D (line 1755) | struct MaxPool1D;
    type State (line 1758) | type State = (NodeId, IntTensor<B>, usize, usize, usize, usize, bool);
    method backward (line 1760) | fn backward(
  type MaxPool2D (line 1789) | struct MaxPool2D;
    type State (line 1792) | type State = (
    method backward (line 1802) | fn backward(

FILE: crates/burn-autodiff/src/ops/qtensor.rs
  function q_from_data (line 14) | fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTe...
  function quantize (line 18) | fn quantize(
  function quantize_dynamic (line 26) | fn quantize_dynamic(
  function dequantize (line 33) | fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
  function q_device (line 37) | fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
  function q_to_device (line 41) | fn q_to_device(
  function q_reshape (line 48) | fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTe...
  function q_into_data (line 52) | async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData...
  function q_swap_dims (line 56) | fn q_swap_dims(
  function q_permute (line 64) | fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> Quantiz...
  function q_flip (line 68) | fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedT...
  function q_gather (line 72) | fn q_gather(
  function q_select (line 80) | fn q_select(
  function q_slice (line 88) | fn q_slice(
  function q_argmax (line 95) | fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
  function q_argmin (line 99) | fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
  function q_expand (line 103) | fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedT...

FILE: crates/burn-autodiff/src/ops/sort.rs
  type SortDim (line 7) | pub(crate) struct SortDim;
    type State (line 10) | type State = (B::IntTensorPrimitive, Shape, usize);
    method backward (line 12) | fn backward(

FILE: crates/burn-autodiff/src/ops/tensor.rs
  function unsqueeze_like (line 33) | fn unsqueeze_like<B: Backend>(
  function float_from_data (line 54) | fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTens...
  function float_random (line 58) | fn float_random(
  function float_zeros (line 66) | fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -...
  function float_ones (line 70) | fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) ->...
  function float_into_data (line 83) | async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData...
  function float_device (line 87) | fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
  function float_to_device (line 100) | fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> ...
  function float_empty (line 132) | fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -...
  function float_add (line 136) | fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTen...
  function float_add_scalar (line 178) | fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<...
  function float_sub (line 205) | fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTen...
  function float_sub_scalar (line 247) | fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<...
  function float_mul (line 274) | fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTen...
  function float_mul_scalar (line 333) | fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<...
  function float_div (line 366) | fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTen...
  function float_div_scalar (line 434) | fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<...
  function float_remainder (line 468) | fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> Fl...
  function float_remainder_scalar (line 534) | fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatT...
  function float_matmul (line 561) | fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> Float...
  function float_cross (line 619) | fn float_cross(
  function float_neg (line 672) | fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_recip (line 698) | fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_swap_dims (line 738) | fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) ...
  function float_permute (line 792) | fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTens...
  function float_flip (line 847) | fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<...
  function float_reshape (line 897) | fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor...
  function float_gather (line 958) | fn float_gather(
  function float_scatter_add (line 1004) | fn float_scatter_add(
  function float_select (line 1053) | fn float_select(
  function float_select_add (line 1116) | fn float_select_add(
  function float_slice (line 1188) | fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTens...
  function float_slice_assign (line 1244) | fn float_slice_assign(
  function float_mask_where (line 1321) | fn float_mask_where(
  function float_mask_fill (line 1383) | fn float_mask_fill(
  function float_equal (line 1421) | fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTe...
  function float_equal_elem (line 1425) | fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {
  function float_greater (line 1429) | fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> Bool...
  function float_greater_elem (line 1433) | fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor...
  function float_greater_equal (line 1437) | fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -...
  function float_greater_equal_elem (line 1441) | fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> Bool...
  function float_lower (line 1445) | fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTe...
  function float_lower_elem (line 1449) | fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {
  function float_lower_equal (line 1453) | fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> ...
  function float_lower_equal_elem (line 1457) | fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTe...
  function float_is_nan (line 1461) | fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
  function float_is_inf (line 1465) | fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
  function float_detach (line 1469) | fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_set_require_grad (line 1481) | fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool)...
  function float_is_require_grad (line 1489) | fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
  function float_mean (line 1493) | fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_sum (line 1526) | fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_mean_dim (line 1557) | fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<...
  function float_sum_dim (line 1596) | fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<S...
  function float_cumsum (line 1633) | fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Se...
  function float_cumprod (line 1670) | fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<S...
  function float_cummin (line 1728) | fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Se...
  function float_cummax (line 1793) | fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Se...
  function float_argmax (line 1858) | fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<B> {
  function float_argmin (line 1862) | fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<B> {
  function float_exp (line 1866) | fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_log (line 1904) | fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_log1p (line 1942) | fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_powf_scalar_impl (line 1982) | fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> F...
  function float_sqrt (line 2037) | fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_abs (line 2079) | fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_cos (line 2117) | fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_sin (line 2156) | fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_tanh (line 2194) | fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_cosh (line 2236) | fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_sinh (line 2273) | fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_tan (line 2310) | fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_asin (line 2350) | fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_acos (line 2390) | fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_atan (line 2431) | fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_asinh (line 2471) | fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_acosh (line 2512) | fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_atanh (line 2553) | fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_atan2 (line 2594) | fn float_atan2(y: FloatTensor<Self>, x: FloatTensor<Self>) -> FloatTenso...
  function float_round (line 2672) | fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_floor (line 2708) | fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_ceil (line 2744) | fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_trunc (line 2780) | fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_erf (line 2816) | fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_cat (line 2858) | fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor...
  function float_max_dim (line 2952) | fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<S...
  function float_max_dim_with_indices (line 2966) | fn float_max_dim_with_indices(
  function float_min_dim (line 2991) | fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<S...
  function float_min_dim_with_indices (line 3005) | fn float_min_dim_with_indices(
  function float_into_int (line 3030) | fn float_into_int(tensor: FloatTensor<Self>) -> <Autodiff<B> as Backend>...
  function float_powf (line 3034) | fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTe...
  function float_sign (line 3110) | fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
  function float_expand (line 3139) | fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<...
  function float_sort (line 3211) | fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -...
  function float_sort_with_indices (line 3229) | fn float_sort_with_indices(
  function float_argsort (line 3257) | fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool...
  function float_repeat_dim (line 3261) | fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize)...
  function float_cast (line 3326) | fn float_cast(tensor: FloatTensor<Self>, dtype: burn_std::FloatDType) ->...
  function float_unfold (line 3363) | fn float_unfold(
  type BinaryOpsBroadcast (line 3447) | enum BinaryOpsBroadcast {
    method new (line 3453) | fn new<B: Backend>(lhs: &B::FloatTensorPrimitive, rhs: &B::FloatTensor...
    method backward_lhs (line 3467) | fn backward_lhs<B: Backend>(&self, grad: B::FloatTensorPrimitive) -> B...
    method backward_rhs (line 3474) | fn backward_rhs<B: Backend>(&self, grad: B::FloatTensorPrimitive) -> B...

FILE: crates/burn-autodiff/src/ops/transaction.rs
  function tr_execute (line 9) | async fn tr_execute(

FILE: crates/burn-autodiff/src/runtime/client.rs
  type AutodiffClient (line 10) | pub trait AutodiffClient: Send + Clone {
    method register (line 12) | fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: Ch...
    method backward (line 14) | fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;
  type AutodiffClientImpl (line 18) | pub type AutodiffClientImpl = super::graph::GraphMutexClient;

FILE: crates/burn-autodiff/src/runtime/graph.rs
  type GraphMutexClient (line 32) | pub struct GraphMutexClient;
    method graph (line 84) | fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {
  type GraphLocator (line 43) | pub struct GraphLocator {
    method select (line 188) | pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> A...
    method analyse (line 203) | fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalys...
    method merge (line 242) | fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Grap...
    method register_key (line 263) | fn register_key(&mut self, origin: NodeId, key: NodeId) {
    method merge_two (line 276) | fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>...
    method new_graph (line 300) | fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {
    method remove_entry (line 310) | fn remove_entry(&mut self, node: &NodeId) {
  type Graph (line 55) | pub(crate) struct Graph {
    method fmt (line 66) | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
  type GraphState (line 61) | struct GraphState {
  method register (line 100) | fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: ...
  method backward (line 108) | fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {
  type GraphCleaner (line 124) | struct GraphCleaner<'a> {
  function cleanup_orphaned_entries (line 129) | fn cleanup_orphaned_entries() {
  method init (line 164) | fn init() -> Self {
  method clean (line 169) | fn clean(&mut self, node: &NodeId) {
  type GraphAnalysis (line 330) | enum GraphAnalysis {

FILE: crates/burn-autodiff/src/runtime/memory_management.rs
  type GraphMemoryManagement (line 11) | pub struct GraphMemoryManagement {
    method extend (line 25) | pub fn extend(&mut self, other: Self) {
    method register (line 32) | pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {
    method consume_node (line 45) | pub fn consume_node(&mut self, node_id: NodeId) {
    method free_unavailable_nodes (line 56) | pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: imp...
    method free_unused_roots (line 92) | pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnM...
    method clear_unused_roots (line 102) | fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {
    method unavailable_propagation (line 118) | fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemorySt...
    method useful_propagation (line 148) | fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {
    method identify_leaves_and_deletables (line 218) | fn identify_leaves_and_deletables(
    method is_referenced (line 257) | fn is_referenced(&self, node_id: NodeId) -> bool {
    method maybe_useful (line 264) | pub(crate) fn maybe_useful(&self) -> bool {
  type NodeMemoryStatus (line 18) | enum NodeMemoryStatus {
  type PopNodeSet (line 271) | struct PopNodeSet {
    method pop (line 277) | fn pop(&mut self) -> Option<NodeId> {
    method contains (line 286) | fn contains(&self, node_id: &NodeId) -> bool {
    method insert (line 291) | fn insert(&mut self, node_id: NodeId) {

FILE: crates/burn-autodiff/src/runtime/server.rs
  type AutodiffServer (line 16) | pub struct AutodiffServer {
    method extend (line 31) | pub fn extend(&mut self, other: AutodiffServer) {
    method register (line 37) | pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions:...
    method backward (line 47) | pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id:...
    method free_unused_roots (line 74) | pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnM...
    method build_tape (line 82) | fn build_tape(
    method execute_steps (line 122) | fn execute_steps(
    method maybe_useful (line 140) | pub(crate) fn maybe_useful(&self) -> bool {
  type NodeCleaner (line 23) | pub trait NodeCleaner {
    method init (line 25) | fn init() -> Self;
    method clean (line 27) | fn clean(&mut self, node: &NodeId);

FILE: crates/burn-autodiff/src/tensor.rs
  type AutodiffTensor (line 11) | pub struct AutodiffTensor<B: Backend> {
  method dtype (line 18) | fn dtype(&self) -> burn_std::DType {
  method shape (line 22) | fn shape(&self) -> burn_std::Shape {
  method rank (line 26) | fn rank(&self) -> usize {
  type NodeRefCount (line 31) | pub type NodeRefCount = Arc<NodeId>;
  type RootStep (line 34) | pub(crate) struct RootStep {
  method step (line 39) | fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Che...
  method node (line 43) | fn node(&self) -> NodeId {
  method parents (line 47) | fn parents(&self) -> &[Parent] {
  method depth (line 51) | fn depth(&self) -> usize {
  function new (line 58) | pub fn new(primitive: B::FloatTensorPrimitive) -> Self {
  function is_tracked (line 77) | pub fn is_tracked(&self) -> bool {
  function require_grad (line 86) | pub fn require_grad(mut self) -> Self {
  function from_parents (line 110) | pub fn from_parents(
  function register_step (line 154) | pub fn register_step<S: Step + 'static>(
  function into_primitive (line 167) | pub fn into_primitive(self) -> B::FloatTensorPrimitive {
  function backward (line 171) | pub fn backward(self) -> Gradients {
  function grad (line 177) | pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {
  function grad_remove (line 181) | pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTenso...
  function grad_replace (line 185) | pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPr...

FILE: crates/burn-autodiff/src/utils.rs
  function duplicate (line 12) | pub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(

FILE: crates/burn-backend-tests/tests/autodiff.rs
  type FloatElemType (line 10) | pub type FloatElemType = f32;
  type IntElemType (line 12) | pub type IntElemType = i32;

FILE: crates/burn-backend-tests/tests/autodiff/abs.rs
  function should_diff_abs (line 5) | fn should_diff_abs() {
  function should_diff_abs_no_nans (line 32) | fn should_diff_abs_no_nans() {

FILE: crates/burn-backend-tests/tests/autodiff/adaptive_avgpool1d.rs
  function test_avg_pool1d_simple (line 6) | fn test_avg_pool1d_simple() {
  type AdaptiveAvgPool1dTestCase (line 23) | struct AdaptiveAvgPool1dTestCase {
    method assert_output (line 31) | fn assert_output(self, x_grad: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/autodiff/adaptive_avgpool2d.rs
  function test_avg_pool2d_simple (line 6) | fn test_avg_pool2d_simple() {
  function test_avg_pool2d_output_1 (line 38) | fn test_avg_pool2d_output_1() {
  type AdaptiveAvgPool2dTestCase (line 67) | struct AdaptiveAvgPool2dTestCase {
    method assert_output (line 77) | fn assert_output(self, x_grad: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/autodiff/add.rs
  function should_diff_add (line 5) | fn should_diff_add() {
  function should_diff_add_scalar (line 28) | fn should_diff_add_scalar() {
  function test_add_complex_1 (line 45) | fn test_add_complex_1() {

FILE: crates/burn-backend-tests/tests/autodiff/aggregation.rs
  function should_diff_mean (line 5) | fn should_diff_mean() {
  function should_diff_sum_1 (line 32) | fn should_diff_sum_1() {
  function should_diff_sum_2 (line 59) | fn should_diff_sum_2() {
  function should_diff_mean_dim (line 87) | fn should_diff_mean_dim() {
  function should_diff_sum_dim (line 114) | fn should_diff_sum_dim() {

FILE: crates/burn-backend-tests/tests/autodiff/avgpool1d.rs
  function test_avg_pool1d_simple (line 6) | fn test_avg_pool1d_simple() {
  function test_avg_pool1d_complex (line 24) | fn test_avg_pool1d_complex() {
  function test_avg_pool1d_complex_dont_count_pad (line 45) | fn test_avg_pool1d_complex_dont_count_pad() {
  type AvgPool1dTestCase (line 65) | struct AvgPool1dTestCase {
    method assert_output (line 76) | fn assert_output(self, x_grad: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/autodiff/avgpool2d.rs
  function test_avg_pool2d_simple (line 6) | fn test_avg_pool2d_simple() {
  function test_avg_pool2d_complex (line 35) | fn test_avg_pool2d_complex() {
  function test_avg_pool2d_complex_dont_include_pad (line 62) | fn test_avg_pool2d_complex_dont_include_pad() {
  type AvgPool2dTestCase (line 88) | struct AvgPool2dTestCase {
    method assert_output (line 103) | fn assert_output(self, x_grad: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/autodiff/backward.rs
  function test_embedding_backward (line 5) | fn test_embedding_backward() {

FILE: crates/burn-backend-tests/tests/autodiff/bridge.rs
  function test_full_precision (line 5) | fn test_full_precision() {

FILE: crates/burn-backend-tests/tests/autodiff/broadcast.rs
  function mul_broadcast (line 4) | fn mul_broadcast() {
  function div_broadcast (line 9) | fn div_broadcast() {
  function sub_broadcast (line 14) | fn sub_broadcast() {
  function add_broadcast (line 19) | fn add_broadcast() {
  function matmul_broadcast (line 24) | fn matmul_broadcast() {
  function mask_where_broadcast (line 29) | fn mask_where_broadcast() {
  function test_ops_broadcast_backward (line 36) | fn test_ops_broadcast_backward<F>(func: F)

FILE: crates/burn-backend-tests/tests/autodiff/cast.rs
  function cast_keeps_gradient_flow (line 10) | fn cast_keeps_gradient_flow() {

FILE: crates/burn-backend-tests/tests/autodiff/cat.rs
  function should_diff_cat (line 6) | fn should_diff_cat() {
  function should_diff_cat_more_than_1_dim (line 61) | fn should_diff_cat_more_than_1_dim() {
  function should_slice_grads_correctly_when_some_inputs_not_tracked (line 83) | fn should_slice_grads_correctly_when_some_inputs_not_tracked() {

FILE: crates/burn-backend-tests/tests/autodiff/ceil.rs
  function should_diff_ceil (line 5) | fn should_diff_ceil() {

FILE: crates/burn-backend-tests/tests/autodiff/checkpoint.rs
  function test_autodiff_checkpoint_complicated_computation (line 5) | fn test_autodiff_checkpoint_complicated_computation() {
  function test_autodiff_checkpoint_with_missing_requirement (line 32) | fn test_autodiff_checkpoint_with_missing_requirement() {
  function test_autodiff_checkpoint_with_many_duplicates (line 52) | fn test_autodiff_checkpoint_with_many_duplicates() {
  function test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound (line 76) | fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {
  function test_autodiff_checkpoint_half_sub_graph_not_tracked (line 100) | fn test_autodiff_checkpoint_half_sub_graph_not_tracked() {
  function test_autodiff_checkpoint_very_complex (line 128) | fn test_autodiff_checkpoint_very_complex() {
  function assert_checkpoint (line 163) | fn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {
  function memory_bound_eager (line 172) | fn memory_bound_eager<const D: usize>(
  function memory_bound_eager_scalar (line 178) | fn memory_bound_eager_scalar<const D: usize>(
  function compute_bound_eager (line 186) | fn compute_bound_eager<const D: usize>(
  function compute_bound_eager_scalar (line 193) | fn compute_bound_eager_scalar<const D: usize>(
  function memory_bound_lazy (line 202) | fn memory_bound_lazy<const D: usize>(
  function compute_bound_lazy (line 210) | fn compute_bound_lazy<const D: usize>(

FILE: crates/burn-backend-tests/tests/autodiff/complex.rs
  function should_diff_full_complex_1 (line 5) | fn should_diff_full_complex_1() {
  function should_diff_full_complex_2 (line 31) | fn should_diff_full_complex_2() {
  function should_diff_full_complex_3 (line 57) | fn should_diff_full_complex_3() {

FILE: crates/burn-backend-tests/tests/autodiff/conv1d.rs
  function test_conv1d_basic (line 5) | fn test_conv1d_basic() {
  function test_conv1d_different_channels (line 39) | fn test_conv1d_different_channels() {
  function test_conv1d_with_padding (line 74) | fn test_conv1d_with_padding() {
  function test_conv1d_with_stride (line 108) | fn test_conv1d_with_stride() {
  function test_conv1d_dilation (line 142) | fn test_conv1d_dilation() {
  function test_conv1d_groups (line 176) | fn test_conv1d_groups() {
  type Conv1dTestCase (line 203) | struct Conv1dTestCase {
    method assert_grads (line 222) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 215) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/conv2d.rs
  function test_conv2d_basic (line 5) | fn test_conv2d_basic() {
  function test_conv2d_different_channels (line 76) | fn test_conv2d_different_channels() {
  function test_conv2d_different_kernel_size (line 151) | fn test_conv2d_different_kernel_size() {
  function test_conv2d_different_padding (line 222) | fn test_conv2d_different_padding() {
  function test_conv2d_different_width (line 277) | fn test_conv2d_different_width() {
  function test_conv2d_stride_2 (line 332) | fn test_conv2d_stride_2() {
  function test_conv2d_different_stride (line 391) | fn test_conv2d_different_stride() {
  function test_conv2d_dilation_2 (line 462) | fn test_conv2d_dilation_2() {
  function test_conv2d_different_dilation (line 521) | fn test_conv2d_different_dilation() {
  function test_conv2d_groups (line 580) | fn test_conv2d_groups() {
  function test_conv2d_groups_stride_2 (line 631) | fn test_conv2d_groups_stride_2() {
  function test_conv2d_groups_different_channels (line 694) | fn test_conv2d_groups_different_channels() {
  function test_conv2d_complex (line 753) | fn test_conv2d_complex() {
  function test_conv2d_groups_stride_2_no_pad (line 812) | fn test_conv2d_groups_stride_2_no_pad() {
  type Conv2dTestCase (line 878) | struct Conv2dTestCase {
    method assert_grads (line 902) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 895) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/conv3d.rs
  function test_conv3d_basic (line 5) | fn test_conv3d_basic() {
  function test_conv3d_complex (line 225) | fn test_conv3d_complex() {
  function test_conv3d_groups_stride_2_no_pad (line 427) | fn test_conv3d_groups_stride_2_no_pad() {
  type Conv3dTestCase (line 594) | struct Conv3dTestCase {
    method assert_grads (line 623) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 616) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose1d.rs
  function test_conv_transpose1d_basic (line 5) | fn test_conv_transpose1d_basic() {
  function test_conv_transpose1d_padding (line 39) | fn test_conv_transpose1d_padding() {
  function test_conv_transpose1d_stride (line 73) | fn test_conv_transpose1d_stride() {
  function test_conv_transpose1d_stride_padding_out (line 107) | fn test_conv_transpose1d_stride_padding_out() {
  function test_conv_transpose1d_dilation (line 141) | fn test_conv_transpose1d_dilation() {
  function test_conv_transpose1d_complex (line 175) | fn test_conv_transpose1d_complex() {
  type ConvTranspose1dTestCase (line 214) | struct ConvTranspose1dTestCase {
    method assert_grads (line 233) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 226) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose2d.rs
  function test_conv_transpose2d_basic (line 5) | fn test_conv_transpose2d_basic() {
  function test_conv_transpose2d_padding (line 79) | fn test_conv_transpose2d_padding() {
  function test_conv_transpose2d_stride (line 112) | fn test_conv_transpose2d_stride() {
  function test_conv_transpose2d_stride_padding_out (line 145) | fn test_conv_transpose2d_stride_padding_out() {
  function test_conv_transpose2d_dilation (line 178) | fn test_conv_transpose2d_dilation() {
  function test_conv_transpose2d_channels (line 211) | fn test_conv_transpose2d_channels() {
  function test_conv_transpose2d_kernel_size (line 263) | fn test_conv_transpose2d_kernel_size() {
  function test_conv_transpose2d_groups (line 302) | fn test_conv_transpose2d_groups() {
  function test_conv_transpose2d_complex_no_groups (line 346) | fn test_conv_transpose2d_complex_no_groups() {
  function test_conv_transpose2d_complex_no_groups_2 (line 446) | fn test_conv_transpose2d_complex_no_groups_2() {
  function test_conv_transpose2d_complex_groups (line 540) | fn test_conv_transpose2d_complex_groups() {
  type ConvTranspose2dTestCase (line 621) | struct ConvTranspose2dTestCase {
    method assert_grads (line 640) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 633) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose3d.rs
  function test_conv_transpose3d_basic (line 5) | fn test_conv_transpose3d_basic() {
  function test_conv_transpose3d_complex_groups (line 215) | fn test_conv_transpose3d_complex_groups() {
  type ConvTranspose3dTestCase (line 622) | struct ConvTranspose3dTestCase {
    method assert_grads (line 641) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 634) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/cross.rs
  function backward_basic (line 8) | fn backward_basic() {
  function backward_after_sum (line 38) | fn backward_after_sum() {
  function different_dim (line 68) | fn different_dim() {

FILE: crates/burn-backend-tests/tests/autodiff/cross_entropy.rs
  function test_cross_entropy_loss_grad (line 5) | fn test_cross_entropy_loss_grad() {

FILE: crates/burn-backend-tests/tests/autodiff/cummax.rs
  function should_diff_cummax (line 5) | fn should_diff_cummax() {
  function should_diff_cummax_2d (line 22) | fn should_diff_cummax_2d() {
  function should_diff_cummax_duplicate_values (line 42) | fn should_diff_cummax_duplicate_values() {
  function should_diff_cummax_all_same (line 63) | fn should_diff_cummax_all_same() {
  function should_diff_cummax_increasing (line 81) | fn should_diff_cummax_increasing() {
  function should_diff_cummax_2d_duplicates (line 100) | fn should_diff_cummax_2d_duplicates() {

FILE: crates/burn-backend-tests/tests/autodiff/cummin.rs
  function should_diff_cummin (line 5) | fn should_diff_cummin() {
  function should_diff_cummin_2d (line 22) | fn should_diff_cummin_2d() {
  function should_diff_cummin_duplicate_values (line 42) | fn should_diff_cummin_duplicate_values() {
  function should_diff_cummin_all_same (line 63) | fn should_diff_cummin_all_same() {
  function should_diff_cummin_decreasing (line 81) | fn should_diff_cummin_decreasing() {
  function should_diff_cummin_2d_duplicates (line 100) | fn should_diff_cummin_2d_duplicates() {

FILE: crates/burn-backend-tests/tests/autodiff/cumprod.rs
  function should_diff_cumprod (line 5) | fn should_diff_cumprod() {
  function should_diff_cumprod_2d (line 22) | fn should_diff_cumprod_2d() {
  function should_diff_cumprod_zero_in_middle (line 60) | fn should_diff_cumprod_zero_in_middle() {
  function should_diff_cumprod_zero_at_start (line 79) | fn should_diff_cumprod_zero_at_start() {
  function should_diff_cumprod_zero_at_end (line 98) | fn should_diff_cumprod_zero_at_end() {
  function should_diff_cumprod_multiple_zeros (line 117) | fn should_diff_cumprod_multiple_zeros() {

FILE: crates/burn-backend-tests/tests/autodiff/cumsum.rs
  function should_diff_cumsum_dim0 (line 5) | fn should_diff_cumsum_dim0() {
  function should_diff_cumsum_dim1 (line 34) | fn should_diff_cumsum_dim1() {
  function should_diff_cumsum_complex (line 63) | fn should_diff_cumsum_complex() {

FILE: crates/burn-backend-tests/tests/autodiff/deform_conv2d.rs
  function test_deform_conv2d_basic (line 6) | fn test_deform_conv2d_basic() {
  function test_deform_conv2d_batched (line 127) | fn test_deform_conv2d_batched() {
  function test_deform_conv2d_different_kernel_size (line 300) | fn test_deform_conv2d_different_kernel_size() {
  function test_deform_conv2d_different_padding (line 598) | fn test_deform_conv2d_different_padding() {
  type Conv2dTestCase (line 1653) | struct Conv2dTestCase {
    method assert_grads (line 1680) | fn assert_grads(self, expected_grads: Grads) {
  type Grads (line 1671) | struct Grads {

FILE: crates/burn-backend-tests/tests/autodiff/div.rs
  function should_diff_div (line 5) | fn should_diff_div() {
  function should_diff_div_scalar (line 31) | fn should_diff_div_scalar() {
  function test_div_complex_1 (line 45) | fn test_div_complex_1() {
  function test_div_complex_2 (line 80) | fn test_div_complex_2() {

FILE: crates/burn-backend-tests/tests/autodiff/erf.rs
  function should_diff_erf (line 5) | fn should_diff_erf() {

FILE: crates/burn-backend-tests/tests/autodiff/exp.rs
  function should_diff_exp (line 5) | fn should_diff_exp() {

FILE: crates/burn-backend-tests/tests/autodiff/expand.rs
  function should_diff_expand (line 5) | fn should_diff_expand() {

FILE: crates/burn-backend-tests/tests/autodiff/flip.rs
  function should_diff_flip (line 6) | fn should_diff_flip() {

FILE: crates/burn-backend-tests/tests/autodiff/floor.rs
  function should_diff_floor (line 5) | fn should_diff_floor() {

FILE: crates/burn-backend-tests/tests/autodiff/gather_scatter.rs
  function test_gather_grad (line 5) | fn test_gather_grad() {
  function test_scatter_grad (line 32) | fn test_scatter_grad() {
  function test_scatter_add_grad_partial_indices (line 70) | fn test_scatter_add_grad_partial_indices() {

FILE: crates/burn-backend-tests/tests/autodiff/gelu.rs
  function should_diff_gelu (line 5) | fn should_diff_gelu() {

FILE: crates/burn-backend-tests/tests/autodiff/gradients.rs
  function should_update_tensor_when_grad_replace (line 5) | fn should_update_tensor_when_grad_replace() {

FILE: crates/burn-backend-tests/tests/autodiff/log.rs
  function should_diff_log (line 5) | fn should_diff_log() {

FILE: crates/burn-backend-tests/tests/autodiff/log1p.rs
  function should_diff_log1p (line 6) | fn should_diff_log1p() {

FILE: crates/burn-backend-tests/tests/autodiff/log_sigmoid.rs
  function should_diff_log_sigmoid (line 6) | fn should_diff_log_sigmoid() {

FILE: crates/burn-backend-tests/tests/autodiff/mask.rs
  function should_diff_mask_fill (line 6) | fn should_diff_mask_fill() {
  function should_diff_mask_where (line 32) | fn should_diff_mask_where() {

FILE: crates/burn-backend-tests/tests/autodiff/matmul.rs
  function should_diff_matmul (line 5) | fn should_diff_matmul() {
  function test_matmul_complex_1 (line 31) | fn test_matmul_complex_1() {
  function test_matmul_complex_2 (line 58) | fn test_matmul_complex_2() {

FILE: crates/burn-backend-tests/tests/autodiff/maxmin.rs
  function should_diff_max_dim (line 6) | fn should_diff_max_dim() {
  function should_diff_min_dim (line 32) | fn should_diff_min_dim() {
  function should_diff_min_dim_3d_dim1 (line 58) | fn should_diff_min_dim_3d_dim1() {

FILE: crates/burn-backend-tests/tests/autodiff/maxpool1d.rs
  function test_max_pool1d_simple (line 6) | fn test_max_pool1d_simple() {
  function test_max_pool1d_with_dilation (line 32) | fn test_max_pool1d_with_dilation() {
  function test_max_pool1d_complex (line 67) | fn test_max_pool1d_complex() {
  function test_max_pool1d_complex_with_padding (line 102) | fn test_max_pool1d_complex_with_padding() {

FILE: crates/burn-backend-tests/tests/autodiff/maxpool2d.rs
  function test_max_pool2d_simple_1 (line 6) | fn test_max_pool2d_simple_1() {
  function test_max_pool2d_simple_2 (line 55) | fn test_max_pool2d_simple_2() {
  function test_max_pool2d_with_dilation (line 104) | fn test_max_pool2d_with_dilation() {
  function test_max_pool2d_complex (line 153) | fn test_max_pool2d_complex() {
  function test_max_pool2d_ceil_mode (line 204) | fn test_max_pool2d_ceil_mode() {

FILE: crates/burn-backend-tests/tests/autodiff/memory_management.rs
  function test_mm_independent_trees (line 5) | fn test_mm_independent_trees() {
  function test_mm_crossover_trees_root_unavailable (line 40) | fn test_mm_crossover_trees_root_unavailable() {
  function test_mm_crossover_trees_with_referred_subtree (line 66) | fn test_mm_crossover_trees_with_referred_subtree() {
  function test_mm_three_crossover_trees_last_still_usable (line 92) | fn test_mm_three_crossover_trees_last_still_usable() {
  function test_mm_three_crossover_trees_middle_one_unavailable (line 125) | fn test_mm_three_crossover_trees_middle_one_unavailable() {
  function test_mm_self_referencing_tree (line 157) | fn test_mm_self_referencing_tree() {
  function test_mm_with_non_impacting_detach (line 174) | fn test_mm_with_non_impacting_detach() {
  function test_mm_with_missing_require_grad_after_cleanup (line 191) | fn test_mm_with_missing_require_grad_after_cleanup() {
  function test_mm_with_detach_after_cleanup (line 215) | fn test_mm_with_detach_after_cleanup() {
  function test_mm_deletables_propagate_well (line 242) | fn test_mm_deletables_propagate_well() {
  function test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper (line 263) | fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_a...

FILE: crates/burn-backend-tests/tests/autodiff/mul.rs
  function should_diff_mul (line 5) | fn should_diff_mul() {
  function should_diff_mul_scalar (line 26) | fn should_diff_mul_scalar() {
  function test_mul_complex_1 (line 43) | fn test_mul_complex_1() {

FILE: crates/burn-backend-tests/tests/autodiff/multithread.rs
  function should_behave_the_same_with_multithread (line 5) | fn should_behave_the_same_with_multithread() {

FILE: crates/burn-backend-tests/tests/autodiff/nearest_interpolate.rs
  function test_upsample_interpolation (line 8) | fn test_upsample_interpolation() {
  function test_downsample_interpolation (line 41) | fn test_downsample_interpolation() {
  type InterpolateTestCase (line 63) | struct InterpolateTestCase {
    method assert_output (line 73) | fn assert_output(self, x_grad: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/autodiff/neg.rs
  function should_diff_neg (line 5) | fn should_diff_neg() {

FILE: crates/burn-backend-tests/tests/autodiff/nonzero.rs
  function should_diff_nonzero (line 5) | fn should_diff_nonzero() {

FILE: crates/burn-backend-tests/tests/autodiff/permute.rs
  function should_diff_permute (line 6) | fn should_diff_permute() {

FILE: crates/burn-backend-tests/tests/autodiff/pow.rs
  function should_diff_powf_scalar (line 6) | fn should_diff_powf_scalar() {
  function should_diff_powf (line 34) | fn should_diff_powf() {
  function should_diff_powf_with_untracked_lhs (line 62) | fn should_diff_powf_with_untracked_lhs() {
  function should_diff_powf_with_untracked_rhs (line 79) | fn should_diff_powf_with_untracked_rhs() {

FILE: crates/burn-backend-tests/tests/autodiff/recip.rs
  function should_diff_recip (line 6) | fn should_diff_recip() {

FILE: crates/burn-backend-tests/tests/autodiff/relu.rs
  function should_diff_relu (line 5) | fn should_diff_relu() {

FILE: crates/burn-backend-tests/tests/autodiff/remainder.rs
  function should_diff_remainder (line 6) | fn should_diff_remainder() {

FILE: crates/burn-backend-tests/tests/autodiff/repeat_dim.rs
  function should_diff_repeat (line 5) | fn should_diff_repeat() {
  function should_diff_repeat_multi_dim (line 26) | fn should_diff_repeat_multi_dim() {

FILE: crates/burn-backend-tests/tests/autodiff/reshape.rs
  function should_diff_reshape (line 5) | fn should_diff_reshape() {

FILE: crates/burn-backend-tests/tests/autodiff/round.rs
  function should_diff_round (line 5) | fn should_diff_round() {

FILE: crates/burn-backend-tests/tests/autodiff/select.rs
  function test_select_grad (line 5) | fn test_select_grad() {
  function test_select_add_grad (line 30) | fn test_select_add_grad() {
  function test_select_add_grad_different_shapes (line 67) | fn test_select_add_grad_different_shapes() {

FILE: crates/burn-backend-tests/tests/autodiff/sigmoid.rs
  function should_diff_sigmoid (line 6) | fn should_diff_sigmoid() {
  function small_neg_val_should_not_cause_grad_overflow (line 22) | fn small_neg_val_should_not_cause_grad_overflow() {

FILE: crates/burn-backend-tests/tests/autodiff/sign.rs
  function should_diff_sign (line 26) | fn should_diff_sign() {

FILE: crates/burn-backend-tests/tests/autodiff/slice.rs
  function should_diff_matmul_with_slice (line 5) | fn should_diff_matmul_with_slice() {
  function should_diff_matmul_with_slice_stepped (line 30) | fn should_diff_matmul_with_slice_stepped() {
  function should_panic_on_slice_with_step (line 58) | fn should_panic_on_slice_with_step() {

FILE: crates/burn-backend-tests/tests/autodiff/slice_assign.rs
  function should_diff_matmul_with_slice_assign (line 6) | fn should_diff_matmul_with_slice_assign() {
  function should_diff_matmul_with_slice_assign_complex (line 34) | fn should_diff_matmul_with_slice_assign_complex() {
  function slice_assign_diff_should_give_same_results_as_cat (line 68) | fn slice_assign_diff_should_give_same_results_as_cat() {
  function should_diff_slice_assign_with_step (line 107) | fn should_diff_slice_assign_with_step() {
  function should_diff_slice_assign_with_negative_step (line 136) | fn should_diff_slice_assign_with_negative_step() {

FILE: crates/burn-backend-tests/tests/autodiff/softmax.rs
  function test_softmax_grad (line 6) | fn test_softmax_grad() {
  function test_log_softmax_grad (line 33) | fn test_log_softmax_grad() {
  function test_quiet_softmax_grad (line 63) | fn test_quiet_softmax_grad() {

FILE: crates/burn-backend-tests/tests/autodiff/sort.rs
  function should_diff_sort (line 6) | fn should_diff_sort() {
  function should_diff_sort_with_indices (line 32) | fn should_diff_sort_with_indices() {
  function should_diff_sort_3d_dim1 (line 59) | fn should_diff_sort_3d_dim1() {

FILE: crates/burn-backend-tests/tests/autodiff/sqrt.rs
  function should_diff_sqrt (line 6) | fn should_diff_sqrt() {

FILE: crates/burn-backend-tests/tests/autodiff/sub.rs
  function should_diff_sub (line 5) | fn should_diff_sub() {
  function should_diff_sub_scalar (line 32) | fn should_diff_sub_scalar() {
  function test_sub_complex_1 (line 48) | fn test_sub_complex_1() {

FILE: crates/burn-backend-tests/tests/autodiff/transpose.rs
  function should_diff_transpose (line 6) | fn should_diff_transpose() {
  function should_diff_swap_dims (line 32) | fn should_diff_swap_dims() {

FILE: crates/burn-backend-tests/tests/autodiff/trig.rs
  function should_diff_cos (line 6) | fn should_diff_cos() {
  function should_diff_sin (line 36) | fn should_diff_sin() {
  function should_diff_tanh (line 66) | fn should_diff_tanh() {
  function should_diff_cosh (line 94) | fn should_diff_cosh() {
  function should_diff_sinh (line 121) | fn should_diff_sinh() {
  function should_diff_tan (line 148) | fn should_diff_tan() {
  function should_diff_asin (line 175) | fn should_diff_asin() {
  function should_diff_acos (line 202) | fn should_diff_acos() {
  function should_diff_atan (line 229) | fn should_diff_atan() {
  function should_diff_asinh (line 256) | fn should_diff_asinh() {
  function should_diff_acosh (line 283) | fn should_diff_acosh() {
  function should_diff_atanh (line 310) | fn should_diff_atanh() {
  function should_diff_atan2 (line 337) | fn should_diff_atan2() {

FILE: crates/burn-backend-tests/tests/autodiff/unfold.rs
  function unfold_backward_accumulates_overlaps (line 5) | fn unfold_backward_accumulates_overlaps() {

FILE: crates/burn-backend-tests/tests/common/autodiff.rs
  type TestAutodiffBackend (line 13) | pub type TestAutodiffBackend = Autodiff<TestBackend, BalancedCheckpointi...
  type TestAutodiffTensor (line 14) | pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend...

FILE: crates/burn-backend-tests/tests/common/backend.rs
  type TestBackend (line 6) | pub type TestBackend = burn_ndarray::NdArray<FloatElemType>;
  type TestBackend (line 9) | pub type TestBackend = burn_tch::LibTorch<FloatElemType>;
  type TestBackend (line 12) | pub type TestBackend = burn_cuda::Cuda<FloatElemType, super::IntElemType>;
  type TestBackend (line 15) | pub type TestBackend = burn_rocm::Rocm<FloatElemType, super::IntElemType>;
  type TestBackend (line 18) | pub type TestBackend = burn_wgpu::Wgpu<FloatElemType, super::IntElemType>;
  type TestBackend (line 21) | pub type TestBackend = burn_cpu::Cpu<FloatElemType, super::IntElemType>;
  type TestBackend (line 24) | pub type TestBackend = burn_router::BackendRouter<
  type TestTensor (line 35) | pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
  type TestTensorInt (line 36) | pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_ten...
  type TestTensorBool (line 37) | pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_te...
  type FloatElem (line 39) | pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;
  type IntElem (line 40) | pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;
  type TestAutodiffBackend (line 42) | pub type TestAutodiffBackend = Autodiff<TestBackend>;
  type TestAutodiffTensor (line 43) | pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend...

FILE: crates/burn-backend-tests/tests/cubecl.rs
  type FloatElemType (line 6) | type FloatElemType = f32;
  type IntElemType (line 7) | type IntElemType = i32;
  type ReferenceBackend (line 11) | pub type ReferenceBackend = burn_ndarray::NdArray<FloatElemType>;

FILE: crates/burn-backend-tests/tests/cubecl/avg_pool2d.rs
  function avg_pool2d_should_match_reference_backend (line 8) | fn avg_pool2d_should_match_reference_backend() {
  function avg_pool2d_backward_should_match_reference_backend (line 44) | fn avg_pool2d_backward_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/bernoulli.rs
  function number_of_1_proportional_to_prob (line 13) | fn number_of_1_proportional_to_prob() {
  function wald_wolfowitz_runs_test (line 33) | fn wald_wolfowitz_runs_test() {

FILE: crates/burn-backend-tests/tests/cubecl/cast.rs
  function should_cast_int_to_float (line 5) | fn should_cast_int_to_float() {
  function should_cast_bool_to_int (line 24) | fn should_cast_bool_to_int() {
  function should_cast_bool_to_float (line 36) | fn should_cast_bool_to_float() {

FILE: crates/burn-backend-tests/tests/cubecl/cat.rs
  function cat_should_match_reference_backend_dim0 (line 6) | fn cat_should_match_reference_backend_dim0() {
  function cat_should_match_reference_backend_dim1 (line 11) | fn cat_should_match_reference_backend_dim1() {
  function cat_should_support_uneven_launch (line 16) | fn cat_should_support_uneven_launch() {
  function test_same_as_reference (line 20) | fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: us...

FILE: crates/burn-backend-tests/tests/cubecl/clamp.rs
  function clamp_should_match_reference (line 6) | fn clamp_should_match_reference() {

FILE: crates/burn-backend-tests/tests/cubecl/contiguous.rs
  function into_contiguous_match_reference_backend_1 (line 6) | pub fn into_contiguous_match_reference_backend_1() {
  function get_combinations (line 33) | fn get_combinations(n: usize) -> impl Iterator<Item = (usize, usize)> {

FILE: crates/burn-backend-tests/tests/cubecl/conv2d.rs
  function conv2d_should_match_reference_backend (line 7) | fn conv2d_should_match_reference_backend() {
  function conv2d_should_match_reference_backend_implicit (line 31) | fn conv2d_should_match_reference_backend_implicit() {
  function conv2d_should_match_reference_backend_bias_regression (line 57) | fn conv2d_should_match_reference_backend_bias_regression() {
  function conv2d_weight_backward_should_run (line 82) | fn conv2d_weight_backward_should_run() {

FILE: crates/burn-backend-tests/tests/cubecl/conv3d.rs
  function conv3d_should_match_reference_backend (line 6) | fn conv3d_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/conv_transpose2d.rs
  function conv_transpose2d_should_match_reference_backend (line 6) | fn conv_transpose2d_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/conv_transpose3d.rs
  function conv_transpose3d_should_match_reference_backend (line 6) | fn conv_transpose3d_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/cross.rs
  function test_cross_product (line 6) | fn test_cross_product() {
  function test_cross_product_zeros (line 26) | fn test_cross_product_zeros() {
  function test_cross_product_batch (line 43) | fn test_cross_product_batch() {
  function test_cross_product_invalid_dimension (line 64) | fn test_cross_product_invalid_dimension() {
  function test_cross_product_parallel_vectors (line 73) | fn test_cross_product_parallel_vectors() {
  function test_cross_product_3d_tensor (line 89) | fn test_cross_product_3d_tensor() {
  function test_cross_product_with_padding_awareness (line 125) | fn test_cross_product_with_padding_awareness() {

FILE: crates/burn-backend-tests/tests/cubecl/gather.rs
  function gather_should_work_with_multiple_workgroups_dim0 (line 6) | fn gather_should_work_with_multiple_workgroups_dim0() {
  function gather_should_work_with_multiple_workgroups_dim1 (line 11) | fn gather_should_work_with_multiple_workgroups_dim1() {
  function test_same_as_ref (line 15) | fn test_same_as_ref<const D: usize>(shape: [usize; D], dim: usize) {

FILE: crates/burn-backend-tests/tests/cubecl/mask_fill.rs
  function mask_fill_should_match_reference_backend (line 8) | fn mask_fill_should_match_reference_backend() {
  function mask_fill_inplace_should_match_reference_backend (line 28) | fn mask_fill_inplace_should_match_reference_backend() {
  function inputs_mask_fill (line 48) | fn inputs_mask_fill() -> (

FILE: crates/burn-backend-tests/tests/cubecl/mask_where.rs
  function mask_where_should_match_reference_backend (line 7) | fn mask_where_should_match_reference_backend() {
  function mask_where_inplace_lhs_should_match_reference_backend (line 18) | fn mask_where_inplace_lhs_should_match_reference_backend() {
  function mask_where_inplace_rhs_should_match_reference_backend (line 37) | fn mask_where_inplace_rhs_should_match_reference_backend() {
  function inputs_mask_where (line 56) | fn inputs_mask_where() -> (

FILE: crates/burn-backend-tests/tests/cubecl/max_pool2d.rs
  function max_pool2d_should_match_reference_backends (line 6) | pub fn max_pool2d_should_match_reference_backends() {
  function max_pool2d_with_indices_should_match_reference_backend (line 28) | pub fn max_pool2d_with_indices_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/max_pool2d_backward.rs
  function max_pool2d_with_indices_backward_should_match_reference_backend (line 6) | pub fn max_pool2d_with_indices_backward_should_match_reference_backend() {

FILE: crates/burn-backend-tests/tests/cubecl/normal.rs
  function empirical_mean_close_to_expectation (line 8) | fn empirical_mean_close_to_expectation() {
  function normal_respects_68_95_99_rule (line 23) | fn normal_respects_68_95_99_rule() {

FILE: crates/burn-backend-tests/tests/cubecl/quantization.rs
  function should_quantize_dequantize_symmetric_arange (line 9) | fn should_quantize_dequantize_symmetric_arange<S: Into<Shape>>(
  function should_quantize_dequantize_symmetric_per_block_arange (line 39) | fn should_quantize_dequantize_symmetric_per_block_arange<S: Into<Shape>>(
  function should_quantize_dequantize_symmetric_per_block (line 71) | fn should_quantize_dequantize_symmetric_per_block(
  function supports_native (line 110) | fn supports_native() -> bool {
  function should_quantize_dequantize_symmetric_arange_q8s_packed (line 123) | fn should_quantize_dequantize_symmetric_arange_q8s_packed() {
  function should_quantize_dequantize_symmetric_arange_q8f_packed (line 128) | fn should_quantize_dequantize_symmetric_arange_q8f_packed() {
  function should_quantize_dequantize_symmetric_arange_q4s_packed (line 133) | fn should_quantize_dequantize_symmetric_arange_q4s_packed() {
  function should_quantize_dequantize_symmetric_arange_q4f_packed (line 138) | fn should_quantize_dequantize_symmetric_arange_q4f_packed() {
  function should_quantize_dequantize_symmetric_arange_q2s_packed (line 143) | fn should_quantize_dequantize_symmetric_arange_q2s_packed() {
  function should_quantize_dequantize_symmetric_arange_q2f_packed (line 148) | fn should_quantize_dequantize_symmetric_arange_q2f_packed() {
  function should_quantize_dequantize_symmetric_per_block_q8s_packed (line 153) | fn should_quantize_dequantize_symmetric_per_block_q8s_packed() {
  function should_quantize_dequantize_symmetric_per_block_q4s_packed (line 158) | fn should_quantize_dequantize_symmetric_per_block_q4s_packed() {
  function should_panic_when_block_size_cannot_store_num_quants (line 164) | fn should_panic_when_block_size_cannot_store_num_quants() {
  function should_quantize_dequantize_symmetric_per_block_q2s_packed (line 170) | fn should_quantize_dequantize_symmetric_per_block_q2s_packed() {
  function should_quantize_dequantize_symmetric_arange_q8s_native (line 175) | fn should_quantize_dequantize_symmetric_arange_q8s_native() {
  function should_quantize_dequantize_symmetric_per_block_q8s_native (line 182) | fn should_quantize_dequantize_symmetric_per_block_q8s_native() {
  function should_quantize_dequantize_symmetric_per_block_arange_q8s_packed (line 189) | fn should_quantize_dequantize_symmetric_per_block_arange_q8s_packed() {
  function should_quantize_dequantize_symmetric_per_block_arange_q8s_native (line 199) | fn should_quantize_dequantize_symmetric_per_block_arange_q8s_native() {
  function should_quantize_dequantize_symmetric_arange_128x256_q8s_native (line 211) | fn should_quantize_dequantize_symmetric_arange_128x256_q8s_native() {
  function should_quantize_dequantize_symmetric_arange_128x256_q8s_packed (line 222) | fn should_quantize_dequantize_symmetric_arange_128x256_q8s_packed() {
  function should_panic_when_shape_cannot_store_quants (line 233) | fn should_panic_when_shape_cannot_store_quants() {

FILE: crates/burn-backend-tests/tests/cubecl/reduce.rs
  constant RANK (line 5) | const RANK: usize = 4;
  constant SHAPE (line 6) | const SHAPE: [usize; RANK] = [2, 4, 8, 16];
  function reduction_argmax_should_match_reference_backend (line 9) | fn reduction_argmax_should_match_reference_backend() {
  function reduction_argmin_should_match_reference_backend (line 24) | fn reduction_argmin_should_match_reference_backend() {
  function reduction_mean_dim_should_match_reference_backend (line 39) | fn reduction_mean_dim_should_match_reference_backend() {
  function reduction_mean_should_match_reference_backend (line 57) | fn reduction_mean_should_match_reference_backend() {
  function reduction_prod_dim_should_match_reference_backend (line 73) | fn reduction_prod_dim_should_match_reference_backend() {
  function reduction_prod_should_match_reference_backend (line 91) | fn reduction_prod_should_match_reference_backend() {
  function reduction_sum_dim_should_match_reference_backend (line 107) | fn reduction_sum_dim_should_match_reference_backend() {
  function reduction_sum_should_match_reference_backend (line 125) | fn reduction_sum_should_match_reference_backend() {
  function reduction_sum_should_match_reference_backend_64bit (line 139) | fn reduction_sum_should_match_reference_backend_64bit() {

FILE: crates/burn-backend-tests/tests/cubecl/repeat_dim.rs
  function repeat_dim_0_few_times (line 6) | fn repeat_dim_0_few_times() {
  function repeat_dim_1_few_times (line 23) | fn repeat_dim_1_few_times() {
  function repeat_dim_2_few_times (line 40) | fn repeat_dim_2_few_times() {
  function repeat_dim_2_many_times (line 57) | fn repeat_dim_2_many_times() {

FILE: crates/burn-backend-tests/tests/cubecl/scatter.rs
  function scatter_should_work_with_multiple_workgroups_2d_dim0 (line 6) | fn scatter_should_work_with_multiple_workgroups_2d_dim0() {
  function scatter_should_work_with_multiple_workgroups_2d_dim1 (line 11) | fn scatter_should_work_with_multiple_workgroups_2d_dim1() {
  function scatter_should_work_with_multiple_workgroups_3d_dim0 (line 16) | fn scatter_should_work_with_multiple_workgroups_3d_dim0() {
  function scatter_should_work_with_multiple_workgroups_3d_dim1 (line 21) | fn scatter_should_work_with_multiple_workgroups_3d_dim1() {
  function scatter_should_work_with_multiple_workgroups_3d_dim2 (line 26) | fn scatter_should_work_with_multiple_workgroups_3d_dim2() {
  function scatter_should_work_with_multiple_workgroups_diff_shapes (line 31) | fn scatter_should_work_with_multiple_workgroups_diff_shapes() {
  function same_as_reference_diff_shape (line 35) | fn same_as_reference_diff_shape<const D: usize>(
  function same_as_reference_same_shape (line 64) | fn same_as_reference_same_shape<const D: usize>(dim: usize, shape: [usiz...

FILE: crates/burn-backend-tests/tests/cubecl/select.rs
  function select_should_work_with_multiple_workgroups (line 6) | fn select_should_work_with_multiple_workgroups() {

FILE: crates/burn-backend-tests/tests/cubecl/select_assign.rs
  function select_add_should_work_with_multiple_workgroups_2d_dim0 (line 6) | fn select_add_should_work_with_multiple_workgroups_2d_dim0() {
  function select_add_should_work_with_multiple_workgroups_2d_dim1 (line 11) | fn select_add_should_work_with_multiple_workgroups_2d_dim1() {
  function select_add_should_work_with_multiple_workgroups_3d_dim0 (line 16) | fn select_add_should_work_with_multiple_workgroups_3d_dim0() {
  function select_add_should_work_with_multiple_workgroups_3d_dim1 (line 21) | fn select_add_should_work_with_multiple_workgroups_3d_dim1() {
  function select_add_should_work_with_multiple_workgroups_3d_dim2 (line 26) | fn select_add_should_work_with_multiple_workgroups_3d_dim2() {
  function select_add_same_as_ref (line 30) | fn select_add_same_as_ref<const D: usize>(dim: usize, shape: [usize; D]) {

FILE: crates/burn-backend-tests/tests/cubecl/slice.rs
  function slice_should_work_with_multiple_workgroups (line 6) | fn slice_should_work_with_multiple_workgroups() {

FILE: crates/burn-backend-tests/tests/cubecl/slice_assign.rs
  function slice_assign_should_work_with_multiple_workgroups (line 5) | fn slice_assign_should_work_with_multiple_workgroups() {

FILE: crates/burn-backend-tests/tests/cubecl/unary.rs
  function tanh_should_not_have_numerical_bugs_on_macos (line 5) | fn tanh_should_not_have_numerical_bugs_on_macos() {

FILE: crates/burn-backend-tests/tests/cubecl/uniform.rs
  function values_all_within_interval_default (line 11) | fn values_all_within_interval_default() {
  function values_all_within_interval_uniform (line 24) | fn values_all_within_interval_uniform() {
  function at_least_one_value_per_bin_uniform (line 37) | fn at_least_one_value_per_bin_uniform() {
  function runs_test (line 51) | fn runs_test() {
  function int_values_all_within_interval_uniform (line 65) | fn int_values_all_within_interval_uniform() {
  function at_least_one_value_per_bin_int_uniform (line 78) | fn at_least_one_value_per_bin_int_uniform() {
  function should_not_fail_on_non_float_autotune (line 94) | fn should_not_fail_on_non_float_autotune() {
  function test_seed_reproducibility (line 104) | fn test_seed_reproducibility() {

FILE: crates/burn-backend-tests/tests/fused_ops/reduce_broadcasted.rs
  function test_reduce_broadcasted_1 (line 5) | fn test_reduce_broadcasted_1() {
  function test_reduce_broadcasted_2 (line 38) | fn test_reduce_broadcasted_2() {
  function test_reduce_broadcasted_3 (line 76) | fn test_reduce_broadcasted_3() {
  function test_reduce_broadcasted_4_reused_partial (line 117) | fn test_reduce_broadcasted_4_reused_partial() {

FILE: crates/burn-backend-tests/tests/fusion.rs
  type FloatElemType (line 13) | pub type FloatElemType = f32;
  type IntElemType (line 14) | pub type IntElemType = i32;
  type TestBackend (line 24) | pub type TestBackend = burn_fusion::Fusion<backend::TestBackend>;
  type TestTensor (line 25) | pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;
  type TestTensorInt (line 26) | pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_ten...
  type TestTensorBool (line 27) | pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_te...

FILE: crates/burn-backend-tests/tests/tensor.rs
  type FloatElemType (line 6) | pub type FloatElemType = f32;
  type IntElemType (line 8) | pub type IntElemType = i32;

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/all.rs
  function test_all (line 5) | fn test_all() {
  function test_all_dim (line 18) | fn test_all_dim() {
  function test_all_with_bool_from_lower_equal (line 26) | fn test_all_with_bool_from_lower_equal() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/any.rs
  function test_any (line 5) | fn test_any() {
  function test_any_dim (line 18) | fn test_any_dim() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/argwhere_nonzero.rs
  function test_argwhere_1d (line 6) | fn test_argwhere_1d() {
  function test_argwhere_2d (line 16) | fn test_argwhere_2d() {
  function test_argwhere_3d (line 26) | fn test_argwhere_3d() {
  function test_nonzero_1d (line 40) | fn test_nonzero_1d() {
  function test_nonzero_2d (line 53) | fn test_nonzero_2d() {
  function test_nonzero_3d (line 70) | fn test_nonzero_3d() {
  function test_nonzero_empty (line 94) | fn test_nonzero_empty() {
  function test_argwhere_empty (line 102) | fn test_argwhere_empty() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/cat.rs
  function should_support_cat_ops_bool (line 5) | fn should_support_cat_ops_bool() {
  function should_support_cat_with_empty_tensor_bool (line 19) | fn should_support_cat_with_empty_tensor_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/comparison.rs
  function should_support_bool_equal (line 5) | fn should_support_bool_equal() {
  function should_support_bool_not_equal (line 21) | fn should_support_bool_not_equal() {
  function should_support_bool_not (line 37) | fn should_support_bool_not() {
  function test_bool_equal_elem (line 50) | fn test_bool_equal_elem() {
  function test_bool_not_equal_elem (line 62) | fn test_bool_not_equal_elem() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/create_like.rs
  function should_support_zeros_like (line 5) | fn should_support_zeros_like() {
  function should_support_ones_like (line 21) | fn should_support_ones_like() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/expand.rs
  function expand_2d_bool (line 5) | fn expand_2d_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/flip.rs
  function flip_bool (line 5) | fn flip_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/full.rs
  function test_tensor_full (line 5) | fn test_tensor_full() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/gather_scatter.rs
  function should_scatter_1d_bool (line 5) | fn should_scatter_1d_bool() {
  function should_gather_1d_dim0_bool (line 19) | fn should_gather_1d_dim0_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/init.rs
  function should_support_bool_empty (line 5) | fn should_support_bool_empty() {
  function should_support_bool_zeros (line 12) | fn should_support_bool_zeros() {
  function should_support_bool_ones (line 23) | fn should_support_bool_ones() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/logical.rs
  function test_bool_and (line 5) | fn test_bool_and() {
  function test_bool_or (line 14) | fn test_bool_or() {
  function test_bool_xor (line 23) | fn test_bool_xor() {
  function test_bool_or_vec (line 32) | fn test_bool_or_vec() {
  function test_bool_and_vec (line 42) | fn test_bool_and_vec() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/mask.rs
  function should_support_bool_mask_where_ops (line 5) | fn should_support_bool_mask_where_ops() {
  function should_support_bool_mask_fill_ops (line 20) | fn should_support_bool_mask_fill_ops() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/movedim.rs
  function movedim_bool (line 5) | fn movedim_bool() {
  function vec_input_bool (line 32) | fn vec_input_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/permute.rs
  function permute_bool (line 5) | fn permute_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/repeat.rs
  function should_support_bool_repeat_ops_one_dimension (line 5) | fn should_support_bool_repeat_ops_one_dimension() {
  function should_support_bool_repeat_on_many_dimension (line 20) | fn should_support_bool_repeat_on_many_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/repeat_dim.rs
  function should_support_bool_repeat_ops (line 5) | fn should_support_bool_repeat_ops() {
  function should_support_bool_repeat_on_dims_larger_than_1 (line 20) | fn should_support_bool_repeat_on_dims_larger_than_1() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/reshape.rs
  function should_support_reshape_bool (line 5) | fn should_support_reshape_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/select.rs
  function should_select_bool_tensor_1d (line 5) | fn should_select_bool_tensor_1d() {
  function should_select_bool_tensor_2d (line 18) | fn should_select_bool_tensor_2d() {
  function should_select_add_bool_tensor (line 32) | fn should_select_add_bool_tensor() {
  function should_select_add_bool_overlapping_indices (line 50) | fn should_select_add_bool_overlapping_indices() {
  function should_select_add_bool_false_to_true_case (line 65) | fn should_select_add_bool_false_to_true_case() {
  function should_select_add_bool_true_or_true_accumulation (line 79) | fn should_select_add_bool_true_or_true_accumulation() {
  function should_match_default_implementation_behavior (line 93) | fn should_match_default_implementation_behavior() {
  function should_select_add_bool_overlapping_indices_vs_default (line 117) | fn should_select_add_bool_overlapping_indices_vs_default() {
  function should_select_add_bool_true_or_true_accumulation_vs_default (line 140) | fn should_select_add_bool_true_or_true_accumulation_vs_default() {
  function should_select_add_bool_false_to_true_case_vs_default (line 163) | fn should_select_add_bool_false_to_true_case_vs_default() {
  function should_select_add_bool_tensor_vs_default (line 186) | fn should_select_add_bool_tensor_vs_default() {
  function should_fail_if_replacement_semantics_were_used (line 210) | fn should_fail_if_replacement_semantics_were_used() {
  function should_fail_if_replacement_semantics_were_used_vs_default (line 225) | fn should_fail_if_replacement_semantics_were_used_vs_default() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/stack.rs
  function should_support_stack_ops_bool (line 6) | fn should_support_stack_ops_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/take.rs
  function should_take_bool_tensor (line 5) | fn should_take_bool_tensor() {
  function should_take_bool_tensor_with_2d_indices (line 18) | fn should_take_bool_tensor_with_2d_indices() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/transpose.rs
  function should_support_transpose_bool (line 5) | fn should_support_transpose_bool() {
  function should_support_swap_dims_bool (line 24) | fn should_support_swap_dims_bool() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/tri_mask.rs
  function square_diag (line 5) | fn square_diag() {
  function square_diag_offset (line 17) | fn square_diag_offset() {
  function square_tri_upper (line 26) | fn square_tri_upper() {
  function square_tri_upper_offset (line 38) | fn square_tri_upper_offset() {
  function square_tri_lower (line 50) | fn square_tri_lower() {
  function square_tri_lower_offset (line 63) | fn square_tri_lower_offset() {
  function rect_diag (line 76) | fn rect_diag() {

FILE: crates/burn-backend-tests/tests/tensor/bool/ops/unfold.rs
  function test_unfold_bool (line 6) | fn test_unfold_bool() {

FILE: crates/burn-backend-tests/tests/tensor/clone_invariance.rs
  type CloneInvarianceTest (line 14) | pub trait CloneInvarianceTest<const D: usize> {
    method args (line 17) | fn args(&self) -> Self::Args;
    method run (line 19) | fn run(&self, args: &Self::Args, inplace: bool) -> TensorData;
    method check (line 21) | fn check(&self) {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/celu.rs
  function test_celu_d2 (line 6) | fn test_celu_d2() {
  function test_celu_with_alpha (line 22) | fn test_celu_with_alpha() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/elu.rs
  function test_elu (line 6) | fn test_elu() {
  function test_elu_alpha (line 20) | fn test_elu_alpha() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/gelu.rs
  function test_gelu (line 6) | fn test_gelu() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/glu.rs
  function test_glu_d3 (line 5) | fn test_glu_d3() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/hard_sigmoid.rs
  function test_hard_sigmoid (line 6) | fn test_hard_sigmoid() {
  function test_hard_sigmoid_overflow (line 18) | fn test_hard_sigmoid_overflow() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/leaky_relu.rs
  function test_leaky_relu_d2 (line 6) | fn test_leaky_relu_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/log_sigmoid.rs
  function test_log_sigmoid (line 6) | fn test_log_sigmoid() {
  function test_log_sigmoid_numerical_stability (line 19) | fn test_log_sigmoid_numerical_stability() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/mish.rs
  function test_mish (line 6) | fn test_mish() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/prelu.rs
  function test_prelu_2_dimension (line 6) | fn test_prelu_2_dimension() {
  function test_prelu_2_dimension_scalar_weight (line 23) | fn test_prelu_2_dimension_scalar_weight() {
  function test_prelu_positives (line 41) | fn test_prelu_positives() {
  function test_prelu_zero_weight (line 56) | fn test_prelu_zero_weight() {
  function test_prelu_some_weight (line 69) | fn test_prelu_some_weight() {
  function test_prelu_single_dim_multi_weight (line 82) | fn test_prelu_single_dim_multi_weight() {
  function test_prelu_multi_dim_wrong_weights (line 94) | fn test_prelu_multi_dim_wrong_weights() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/quiet_softmax.rs
  function test_quiet_softmax_d2 (line 6) | fn test_quiet_softmax_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/relu.rs
  function test_relu_d2 (line 5) | fn test_relu_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/selu.rs
  function test_selu (line 6) | fn test_selu() {
  function test_selu_zero (line 28) | fn test_selu_zero() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/sigmoid.rs
  function test_sigmoid (line 6) | fn test_sigmoid() {
  function test_sigmoid_overflow (line 18) | fn test_sigmoid_overflow() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/silu.rs
  function test_silu (line 6) | fn test_silu() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/softmax.rs
  function test_softmax_d2 (line 6) | fn test_softmax_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/softmin.rs
  function test_softmin_d2 (line 6) | fn test_softmin_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/softplus.rs
  function test_softplus_d2 (line 6) | fn test_softplus_d2() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/softsign.rs
  function test_softsign (line 6) | fn test_softsign() {
  function test_softsign_zero (line 18) | fn test_softsign_zero() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/tanh_activation.rs
  function test_tanh (line 6) | fn test_tanh() {

FILE: crates/burn-backend-tests/tests/tensor/float/activation/thresholded_relu.rs
  function test_thresholded_relu_d2 (line 5) | fn test_thresholded_relu_d2() {
  function test_thresholded_relu_d2_alpha (line 17) | fn test_thresholded_relu_d2_alpha() {

FILE: crates/burn-backend-tests/tests/tensor/float/grid/affine_grid.rs
  function create_identity_transform (line 4) | fn create_identity_transform(batch_size: usize) -> TestTensor<3> {
  function test_affine_grid_identity (line 10) | fn test_affine_grid_identity() {
  function test_affine_grid_scaling (line 32) | fn test_affine_grid_scaling() {
  function test_affine_grid_translation (line 53) | fn test_affine_grid_translation() {

FILE: crates/burn-backend-tests/tests/tensor/float/grid/meshgrid.rs
  function assert_tensors_equal (line 10) | fn assert_tensors_equal<const N: usize, B: Backend, K>(
  function test_meshgrid (line 24) | fn test_meshgrid() {
  function test_meshgrid_stack (line 129) | fn test_meshgrid_stack() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/cosine_similarity.rs
  function test_cosine_similarity_basic (line 6) | fn test_cosine_similarity_basic() {
  function test_cosine_similarity_orthogonal (line 24) | fn test_cosine_similarity_orthogonal() {
  function test_cosine_similarity_parallel (line 37) | fn test_cosine_similarity_parallel() {
  function test_cosine_similarity_opposite (line 50) | fn test_cosine_similarity_opposite() {
  function test_cosine_similarity_different_dimension (line 63) | fn test_cosine_similarity_different_dimension() {
  function test_cosine_similarity_near_zero (line 85) | fn test_cosine_similarity_near_zero() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/diag.rs
  function test_diag_2d_square (line 5) | fn test_diag_2d_square() {
  function test_diag_2d_tall (line 15) | fn test_diag_2d_tall() {
  function test_diag_2d_wide (line 28) | fn test_diag_2d_wide() {
  function test_diag_3d_batch_square (line 40) | fn test_diag_3d_batch_square() {
  function test_diag_3d_batch_tall (line 55) | fn test_diag_3d_batch_tall() {
  function test_diag_3d_batch_wide (line 73) | fn test_diag_3d_batch_wide() {
  function test_diag_4d_batch_channel_square (line 91) | fn test_diag_4d_batch_channel_square() {
  function test_diag_4d_batch_channel_tall (line 109) | fn test_diag_4d_batch_channel_tall() {
  function test_diag_4d_batch_channel_wide (line 127) | fn test_diag_4d_batch_channel_wide() {
  function test_diag_1x1 (line 145) | fn test_diag_1x1() {
  function test_diag_single_row (line 157) | fn test_diag_single_row() {
  function test_diag_single_column (line 169) | fn test_diag_single_column() {
  function test_diag_zeros (line 181) | fn test_diag_zeros() {
  function test_diag_batch_single_element (line 193) | fn test_diag_batch_single_element() {
  function test_diag_batch_mixed_zeros (line 205) | fn test_diag_batch_mixed_zeros() {
  function test_diag_int_tensor (line 220) | fn test_diag_int_tensor() {
  function test_diag_int_3x3 (line 232) | fn test_diag_int_3x3() {
  function test_diag_1d_should_panic (line 245) | fn test_diag_1d_should_panic() {
  function test_diag_wrong_output_rank_should_panic (line 254) | fn test_diag_wrong_output_rank_should_panic() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/lu_decomposition.rs
  function test_lu_2x2_decomposition (line 7) | fn test_lu_2x2_decomposition() {
  function test_lu_3x3_decomposition (line 16) | fn test_lu_3x3_decomposition() {
  function test_lu_singular_matrix (line 44) | fn test_lu_singular_matrix() {
  function test_lu_non_square_matrix (line 52) | fn test_lu_non_square_matrix() {
  function test_lu_1x1_element_matrix (line 59) | fn test_lu_1x1_element_matrix() {
  function test_lu_identity_matrix (line 69) | fn test_lu_identity_matrix() {
  function test_lu_50x50_random_matrix (line 79) | fn test_lu_50x50_random_matrix() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/matvec.rs
  function test_matvec_basic_float (line 5) | fn test_matvec_basic_float() {
  function test_matvec_basic_int (line 19) | fn test_matvec_basic_int() {
  function test_matvec_batched (line 31) | fn test_matvec_batched() {
  function test_matvec_vector_broadcasts_over_batches (line 51) | fn test_matvec_vector_broadcasts_over_batches() {
  function test_matvec_matrix_broadcasts_over_vector_batches (line 71) | fn test_matvec_matrix_broadcasts_over_vector_batches() {
  function test_matvec_invalid_inner_dim_panics (line 86) | fn test_matvec_invalid_inner_dim_panics() {
  function test_matvec_mismatched_batches_panics (line 96) | fn test_matvec_mismatched_batches_panics() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/outer.rs
  function test_outer_basic (line 8) | fn test_outer_basic() {
  function test_outer_shapes_only (line 19) | fn test_outer_shapes_only() {
  function test_outer_asymmetry_and_shapes (line 28) | fn test_outer_asymmetry_and_shapes() {
  function test_outer_zero_left (line 40) | fn test_outer_zero_left() {
  function test_outer_zero_right (line 52) | fn test_outer_zero_right() {
  function test_outer_signs (line 64) | fn test_outer_signs() {
  function test_outer_integer_inputs (line 75) | fn test_outer_integer_inputs() {
  function test_outer_equivalence_to_matmul (line 86) | fn test_outer_equivalence_to_matmul() {
  function test_outer_vector_identity_right_mult (line 100) | fn test_outer_vector_identity_right_mult() {
  function test_outer_length_one_vectors (line 116) | fn test_outer_length_one_vectors() {
  function test_outer_large_values (line 127) | fn test_outer_large_values() {
  function test_outer_nan_propagation (line 140) | fn test_outer_nan_propagation() {
  function test_outer_batched_basic (line 159) | fn test_outer_batched_basic() {
  function test_outer_batched_shapes (line 173) | fn test_outer_batched_shapes() {
  function test_outer_batched_zero_left (line 182) | fn test_outer_batched_zero_left() {
  function test_outer_batched_zero_right (line 193) | fn test_outer_batched_zero_right() {
  function test_outer_batched_signs (line 204) | fn test_outer_batched_signs() {
  function test_outer_batched_equivalence_to_per_sample_outer (line 215) | fn test_outer_batched_equivalence_to_per_sample_outer() {
  function test_outer_batched_mismatched_batches_panics (line 246) | fn test_outer_batched_mismatched_batches_panics() {
  function test_outer_dim (line 254) | fn test_outer_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/trace.rs
  function test_trace_2d_square (line 5) | fn test_trace_2d_square() {
  function test_trace_2d_rectangular_wide (line 16) | fn test_trace_2d_rectangular_wide() {
  function test_trace_2d_rectangular_tall (line 26) | fn test_trace_2d_rectangular_tall() {
  function test_trace_3d_batch (line 36) | fn test_trace_3d_batch() {
  function test_trace_4d_batch (line 50) | fn test_trace_4d_batch() {
  function test_trace_single_element (line 69) | fn test_trace_single_element() {
  function test_trace_zeros (line 79) | fn test_trace_zeros() {
  function test_trace_negative_values (line 89) | fn test_trace_negative_values() {
  function test_trace_1d_should_panic (line 100) | fn test_trace_1d_should_panic() {
  function test_trace_wrong_output_rank_should_panic (line 109) | fn test_trace_wrong_output_rank_should_panic() {

FILE: crates/burn-backend-tests/tests/tensor/float/linalg/vector_norm.rs
  function test_max_min_abs (line 8) | fn test_max_min_abs() {
  function test_l0_norm (line 68) | fn test_l0_norm() {
  function test_l1_norm (line 99) | fn test_l1_norm() {
  function test_lp_norm (line 120) | fn test_lp_norm() {
  function test_l2_norm (line 209) | fn test_l2_norm() {
  function test_normalize (line 231) | fn test_normalize() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool1d.rs
  function test_adaptive_avg_pool1d_simple (line 7) | fn test_adaptive_avg_pool1d_simple() {
  function test_adaptive_avg_pool1d_dyn_filter_size (line 22) | fn test_adaptive_avg_pool1d_dyn_filter_size() {
  function test_adaptive_avg_pool1d_bigger_output (line 34) | fn test_adaptive_avg_pool1d_bigger_output() {
  type AdaptiveAvgPool1dTestCase (line 48) | struct AdaptiveAvgPool1dTestCase {
    method assert_output (line 56) | fn assert_output(self, y: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool2d.rs
  function test_adaptive_avg_pool2d_simple (line 7) | fn test_adaptive_avg_pool2d_simple() {
  function test_adaptive_avg_pool2d_dyn_filter_size (line 34) | fn test_adaptive_avg_pool2d_dyn_filter_size() {
  function test_adaptive_avg_pool2d_bigger_output (line 51) | fn test_adaptive_avg_pool2d_bigger_output() {
  type AdaptiveAvgPool2dTestCase (line 79) | struct AdaptiveAvgPool2dTestCase {
    method assert_output (line 89) | fn assert_output(self, y: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/attention.rs
  function test_attention_no_mask (line 9) | fn test_attention_no_mask() {
  function test_attention_custom_scale (line 59) | fn test_attention_custom_scale() {
  function test_attention_attn_bias (line 101) | fn test_attention_attn_bias() {
  function test_attention_softcap (line 144) | fn test_attention_softcap() {
  function test_attention_is_causal (line 186) | fn test_attention_is_causal() {
  function test_attention_cross_attention_with_bias (line 229) | fn test_attention_cross_attention_with_bias() {
  function test_attention_softcap_preserves_causal_mask (line 280) | fn test_attention_softcap_preserves_causal_mask() {
  function test_attention_all_options (line 319) | fn test_attention_all_options() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/avgpool1d.rs
  function test_avg_pool1d_simple (line 7) | fn test_avg_pool1d_simple() {
  function test_avg_pool1d_complex (line 22) | fn test_avg_pool1d_complex() {
  function test_avg_pool1d_complex_dont_count_pad (line 40) | fn test_avg_pool1d_complex_dont_count_pad() {
  type AvgPool1dTestCase (line 57) | struct AvgPool1dTestCase {
    method assert_output (line 68) | fn assert_output(self, y: TestTensor<3>) {
  function test_avg_pool1d_ceil_mode (line 92) | fn test_avg_pool1d_ceil_mode() {
  function test_avg_pool1d_ceil_mode_count_include_pad (line 139) | fn test_avg_pool1d_ceil_mode_count_include_pad() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/avgpool2d.rs
  function test_avg_pool2d_simple (line 7) | fn test_avg_pool2d_simple() {
  function test_avg_pool2d_complex (line 31) | fn test_avg_pool2d_complex() {
  function test_avg_pool2d_complex_dont_include_pad (line 55) | fn test_avg_pool2d_complex_dont_include_pad() {
  type AvgPool2dTestCase (line 78) | struct AvgPool2dTestCase {
    method assert_output (line 93) | fn assert_output(self, y: TestTensor<4>) {
  function test_avg_pool2d_ceil_mode (line 117) | fn test_avg_pool2d_ceil_mode() {
  function test_avg_pool2d_ceil_mode_count_include_pad (line 179) | fn test_avg_pool2d_ceil_mode_count_include_pad() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/bicubic_interpolate.rs
  function test_upsample_interpolation (line 8) | fn test_upsample_interpolation() {
  function test_downsample_interpolation (line 69) | fn test_downsample_interpolation() {
  function test_1d_bicubic (line 88) | fn test_1d_bicubic() {
  type InterpolateTestCase (line 127) | struct InterpolateTestCase {
    method assert_output (line 137) | fn assert_output(self, y: TestTensor<4>) {
    method assert_output_with_align_corners (line 141) | fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corn...
  function test_upsample_half_pixel (line 161) | fn test_upsample_half_pixel() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/bilinear_interpolate.rs
  function test_upsample_interpolation (line 8) | fn test_upsample_interpolation() {
  function test_downsample_interpolation (line 69) | fn test_downsample_interpolation() {
  function test_1d_bilinear (line 88) | fn test_1d_bilinear() {
  function test_interpolate_coord_float_precision_boundary (line 136) | fn test_interpolate_coord_float_precision_boundary() {
  function should_interpolate_cast (line 175) | fn should_interpolate_cast() {
  type InterpolateTestCase (line 213) | struct InterpolateTestCase {
    method assert_output (line 223) | fn assert_output(self, y: TestTensor<4>) {
    method assert_output_with_align_corners (line 227) | fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corn...
  function test_upsample_half_pixel (line 247) | fn test_upsample_half_pixel() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv1d.rs
  function test_conv1d_simple (line 8) | fn test_conv1d_simple() {
  function test_conv1d_dilation (line 28) | fn test_conv1d_dilation() {
  function test_conv1d_groups (line 48) | fn test_conv1d_groups() {
  function test_conv1d_complex (line 68) | fn test_conv1d_complex() {
  type Conv1dTestCase (line 90) | struct Conv1dTestCase {
    method assert_output (line 103) | fn assert_output(self, y: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv2d.rs
  function test_conv2d_simple (line 10) | fn test_conv2d_simple() {
  function test_conv2d_simple_implicit (line 45) | fn test_conv2d_simple_implicit() {
  function test_conv2d_implicit_padded_in_channels (line 164) | fn test_conv2d_implicit_padded_in_channels() {
  function test_conv2d_groups_channels_out (line 283) | fn test_conv2d_groups_channels_out() {
  function test_conv2d_groups (line 402) | fn test_conv2d_groups() {
  function test_conv2d_groups_multiple_channels (line 431) | fn test_conv2d_groups_multiple_channels() {
  function test_conv2d_complex (line 474) | fn test_conv2d_complex() {
  type Conv2dTestCase (line 508) | struct Conv2dTestCase {
    method assert_output (line 526) | fn assert_output(self, y: TestTensor<4>) {
  function conv2d_weight (line 566) | fn conv2d_weight() -> TensorData {
  function test_conv2d_binary_broadcasted (line 574) | fn test_conv2d_binary_broadcasted() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv3d.rs
  function test_conv3d_simple (line 8) | fn test_conv3d_simple() {
  function test_conv3d_groups (line 88) | fn test_conv3d_groups() {
  function test_conv3d_complex (line 150) | fn test_conv3d_complex() {
  type Conv3dTestCase (line 231) | struct Conv3dTestCase {
    method assert_output (line 254) | fn assert_output(self, y: TestTensor<5>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose1d.rs
  function test_conv_transpose1d_diff_channels (line 8) | fn test_conv_transpose1d_diff_channels() {
  function test_conv_transpose1d_stride (line 29) | fn test_conv_transpose1d_stride() {
  function test_conv_transpose1d_dilation (line 50) | fn test_conv_transpose1d_dilation() {
  function test_conv_transpose1d_groups (line 71) | fn test_conv_transpose1d_groups() {
  type ConvTranspose1dTestCase (line 91) | struct ConvTranspose1dTestCase {
    method assert_output (line 105) | fn assert_output(self, y: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose2d.rs
  function test_conv_transpose2d_simple_1 (line 8) | fn test_conv_transpose2d_simple_1() {
  function test_conv_transpose2d_simple_2 (line 32) | fn test_conv_transpose2d_simple_2() {
  function test_conv_transpose2d_simple_3 (line 75) | fn test_conv_transpose2d_simple_3() {
  function test_conv_transpose2d_stride_2 (line 103) | fn test_conv_transpose2d_stride_2() {
  function test_conv_transpose2d_dilation_2 (line 132) | fn test_conv_transpose2d_dilation_2() {
  function test_conv_transpose2d_stride2_out_padding (line 171) | fn test_conv_transpose2d_stride2_out_padding() {
  function test_conv_transpose2d_groups_2 (line 216) | fn test_conv_transpose2d_groups_2() {
  function test_conv_transpose2d_groups_different_channels (line 243) | fn test_conv_transpose2d_groups_different_channels() {
  type ConvTranspose2dTestCase (line 303) | struct ConvTranspose2dTestCase {
    method assert_output (line 323) | fn assert_output(self, y: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose3d.rs
  function test_conv_transpose3d_simple_1 (line 8) | fn test_conv_transpose3d_simple_1() {
  function test_conv_transpose3d_simple_2 (line 40) | fn test_conv_transpose3d_simple_2() {
  function test_conv_transpose3d_stride_2 (line 149) | fn test_conv_transpose3d_stride_2() {
  function test_conv_transpose3d_dilation_2 (line 204) | fn test_conv_transpose3d_dilation_2() {
  function test_conv_transpose3d_stride2_out_padding (line 309) | fn test_conv_transpose3d_stride2_out_padding() {
  function test_conv_transpose3d_groups_2 (line 456) | fn test_conv_transpose3d_groups_2() {
  function test_conv_transpose3d_groups_different_channels (line 492) | fn test_conv_transpose3d_groups_different_channels() {
  type ConvTranspose3dTestCase (line 678) | struct ConvTranspose3dTestCase {
    method assert_output (line 704) | fn assert_output(self, y: TestTensor<5>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/deform_conv2d.rs
  function test_deform_conv2d_simple (line 8) | fn test_deform_conv2d_simple() {
  function test_deform_conv2d_batched (line 37) | fn test_deform_conv2d_batched() {
  function test_deform_conv2d_weight_groups (line 75) | fn test_deform_conv2d_weight_groups() {
  function test_deform_conv2d_offset_groups (line 105) | fn test_deform_conv2d_offset_groups() {
  function test_deform_conv2d_different_kernel_size (line 135) | fn test_deform_conv2d_different_kernel_size() {
  function test_deform_conv2d_different_padding_size (line 162) | fn test_deform_conv2d_different_padding_size() {
  function test_deform_conv2d_different_stride (line 246) | fn test_deform_conv2d_different_stride() {
  function test_deform_conv2d_different_dilation (line 273) | fn test_deform_conv2d_different_dilation() {
  function test_deform_conv2d_different_width (line 300) | fn test_deform_conv2d_different_width() {
  type DeformConv2dTestCase (line 341) | struct DeformConv2dTestCase {
    method assert_output (line 360) | fn assert_output(self, y: Tensor<TestBackend, 4>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/forward.rs
  function test_embedding_forward (line 5) | fn test_embedding_forward() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/lanczos3_interpolate.rs
  function test_upsample_interpolation (line 8) | fn test_upsample_interpolation() {
  function test_downsample_interpolation (line 69) | fn test_downsample_interpolation() {
  function test_upsample_2x (line 88) | fn test_upsample_2x() {
  function test_upsample_half_pixel (line 127) | fn test_upsample_half_pixel() {
  function test_1d_lanczos3 (line 169) | fn test_1d_lanczos3() {
  type InterpolateTestCase (line 204) | struct InterpolateTestCase {
    method assert_output (line 214) | fn assert_output(self, y: TestTensor<4>) {
    method assert_output_with_align_corners (line 218) | fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corn...

FILE: crates/burn-backend-tests/tests/tensor/float/module/linear.rs
  function test_linear_1d (line 7) | fn test_linear_1d() {
  function test_linear_1d_one_element_output (line 20) | fn test_linear_1d_one_element_output() {
  function test_linear_forward_no_bias (line 33) | fn test_linear_forward_no_bias() {
  function test_linear_forward_with_bias (line 47) | fn test_linear_forward_with_bias() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/maxpool1d.rs
  function test_max_pool1d_simple (line 7) | fn test_max_pool1d_simple() {
  function test_max_pool1d_different_padding_stride_kernel (line 29) | fn test_max_pool1d_different_padding_stride_kernel() {
  function test_max_pool1d_with_neg (line 45) | fn test_max_pool1d_with_neg() {
  function test_max_pool1d_with_dilation (line 61) | fn test_max_pool1d_with_dilation() {
  function test_max_pool1d_with_indices (line 83) | fn test_max_pool1d_with_indices() {
  function test_max_pool1d_complex (line 102) | fn test_max_pool1d_complex() {
  function test_max_pool1d_ceil_mode (line 121) | fn test_max_pool1d_ceil_mode() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/maxpool2d.rs
  function test_max_pool2d_simple (line 7) | fn test_max_pool2d_simple() {
  function test_max_pool2d_different_padding_stride_kernel (line 108) | fn test_max_pool2d_different_padding_stride_kernel() {
  function test_max_pool2d_with_neg (line 149) | fn test_max_pool2d_with_neg() {
  function test_max_pool2d_with_dilation (line 191) | fn test_max_pool2d_with_dilation() {
  function test_max_pool2d_with_indices (line 230) | fn test_max_pool2d_with_indices() {
  function test_max_pool2d_complex (line 276) | fn test_max_pool2d_complex() {
  function test_max_pool2d_ceil_mode (line 324) | fn test_max_pool2d_ceil_mode() {
  function test_max_pool2d_ceil_mode_with_indices (line 404) | fn test_max_pool2d_ceil_mode_with_indices() {
  function test_max_pool2d_ceil_mode_with_indices_and_padding (line 464) | fn test_max_pool2d_ceil_mode_with_indices_and_padding() {

FILE: crates/burn-backend-tests/tests/tensor/float/module/nearest_interpolate.rs
  function test_upsample_interpolation (line 8) | fn test_upsample_interpolation() {
  function test_downsample_interpolation (line 43) | fn test_downsample_interpolation() {
  function test_1d_nearest (line 62) | fn test_1d_nearest() {
  type InterpolateTestCase (line 100) | struct InterpolateTestCase {
    method assert_output (line 110) | fn assert_output(self, y: TestTensor<4>) {

FILE: crates/burn-backend-tests/tests/tensor/float/module/unfold4d.rs
  function test_unfold4d_shape (line 8) | fn test_unfold4d_shape() {
  function test_unfold4d_simple (line 24) | fn test_unfold4d_simple() {
  function test_unfold4d_complex (line 49) | fn test_unfold4d_complex() {
  type Unfold4dTestCase (line 77) | struct Unfold4dTestCase {
    method assert_shape (line 89) | fn assert_shape(self, expected_shape: [usize; 3]) {
    method assert_output (line 111) | fn assert_output(self, expected: TestTensor<3>) {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/abs.rs
  function should_support_abs_ops_float (line 5) | fn should_support_abs_ops_float() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/add.rs
  function test_add_d2 (line 5) | fn test_add_d2() {
  function test_add_broadcast (line 18) | fn test_add_broadcast() {
  function test_add_different_strides_rhs (line 31) | fn test_add_different_strides_rhs() {
  function test_add_different_strides_lhs (line 45) | fn test_add_different_strides_lhs() {
  function test_add_different_strides_broadcast (line 59) | fn test_add_different_strides_broadcast() {
  function should_support_add_scalar_ops (line 73) | fn should_support_add_scalar_ops() {
  function add_maybe_fused_not_contiguous (line 85) | fn add_maybe_fused_not_contiguous() {
  function add_maybe_fused_not_contiguous_broadcasted (line 103) | fn add_maybe_fused_not_contiguous_broadcasted() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/aggregation.rs
  function test_should_mean (line 7) | fn test_should_mean() {
  function test_should_sum (line 19) | fn test_should_sum() {
  function test_should_sum_dim_maybe_fused (line 30) | fn test_should_sum_dim_maybe_fused() {
  function test_should_mean_last_dim (line 50) | fn test_should_mean_last_dim() {
  function test_should_sum_last_dim (line 68) | fn test_should_sum_last_dim() {
  function test_should_sum_first_dim (line 79) | fn test_should_sum_first_dim() {
  function test_should_mean_first_dim (line 90) | fn test_should_mean_first_dim() {
  function test_should_sum_mid_dim_3d_non_contiguous_1 (line 102) | fn test_should_sum_mid_dim_3d_non_contiguous_1() {
  function test_should_sum_mid_dim_3d_non_contiguous_2 (line 117) | fn test_should_sum_mid_dim_3d_non_contiguous_2() {
  function test_prod_float (line 132) | fn test_prod_float() {
  function test_prod_dim_float (line 151) | fn test_prod_dim_float() {
  function test_sum_dim_2d (line 170) | fn test_sum_dim_2d() {
  function test_sum_dims_2d (line 186) | fn test_sum_dims_2d() {
  function test_sum_and_squeeze_dims (line 210) | fn test_sum_and_squeeze_dims() {
  function test_sum_dim_1_reshape_maybe_fused (line 226) | fn test_sum_dim_1_reshape_maybe_fused() {
  function test_sum_dim_1_swap_dims_maybe_fused (line 238) | fn test_sum_dim_1_swap_dims_maybe_fused() {
  function test_sum_dim_2_reshape_maybe_fused_broadcast (line 251) | fn test_sum_dim_2_reshape_maybe_fused_broadcast() {
  function test_sum_dim_2_maybe_fused_on_write (line 263) | fn test_sum_dim_2_maybe_fused_on_write() {
  function test_sum_dim_3_maybe_fused_on_read_not_contiguous (line 278) | fn test_sum_dim_3_maybe_fused_on_read_not_contiguous() {
  function test_sum_dim_4_maybe_fused_on_read_not_contiguous_mixed (line 296) | fn test_sum_dim_4_maybe_fused_on_read_not_contiguous_mixed() {
  function test_sum_dim_5_maybe_fused_on_read_not_contiguous_mixed (line 316) | fn test_sum_dim_5_maybe_fused_on_read_not_contiguous_mixed() {
  function test_sum_dim_6_maybe_fused_on_read_not_contiguous_broadcasted (line 336) | fn test_sum_dim_6_maybe_fused_on_read_not_contiguous_broadcasted() {
  function test_sum_dim_7_maybe_fused_on_read_reshaped (line 366) | fn test_sum_dim_7_maybe_fused_on_read_reshaped() {
  function test_mean_dim_fused_on_read_on_write (line 382) | fn test_mean_dim_fused_on_read_on_write() {
  function test_mean_dim_2d (line 401) | fn test_mean_dim_2d() {
  function test_mean_dims_2d (line 421) | fn test_mean_dims_2d() {
  function test_multiple_reduce_dims_permuted (line 445) | fn test_multiple_reduce_dims_permuted() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/all.rs
  function test_all (line 5) | fn test_all() {
  function test_all_dim (line 13) | fn test_all_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/any.rs
  function test_any (line 5) | fn test_any() {
  function test_any_dim (line 41) | fn test_any_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/arg.rs
  function test_argmax_2d_dim0 (line 5) | fn test_argmax_2d_dim0() {
  function test_argmin_2d_dim0 (line 16) | fn test_argmin_2d_dim0() {
  function test_argmax_2d_dim1 (line 27) | fn test_argmax_2d_dim1() {
  function test_argmin_2d_dim1 (line 38) | fn test_argmin_2d_dim1() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/cast.rs
  function cast_float_to_bool (line 6) | fn cast_float_to_bool() {
  function cast_float_to_int (line 14) | fn cast_float_to_int() {
  function cast_int_to_float_tensor (line 22) | fn cast_int_to_float_tensor() {
  function cast_bool_to_float_tensor (line 31) | fn cast_bool_to_float_tensor() {
  function cast_float_precision (line 40) | fn cast_float_precision() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/cat.rs
  function should_support_cat_ops_2d_dim0 (line 7) | fn should_support_cat_ops_2d_dim0() {
  function should_support_cat_ops_2d_dim1 (line 21) | fn should_support_cat_ops_2d_dim1() {
  function should_support_cat_ops_3d (line 35) | fn should_support_cat_ops_3d() {
  function should_panic_when_dimensions_are_not_the_same (line 50) | fn should_panic_when_dimensions_are_not_the_same() {
  function should_panic_when_list_of_vectors_is_empty (line 60) | fn should_panic_when_list_of_vectors_is_empty() {
  function should_panic_when_cat_exceeds_dimension (line 67) | fn should_panic_when_cat_exceeds_dimension() {
  function should_support_cat_ops_cast_dtype (line 76) | fn should_support_cat_ops_cast_dtype() {
  function should_support_cat_with_empty_tensor (line 92) | fn should_support_cat_with_empty_tensor() {
  function should_support_cat_with_empty_tensor_first (line 107) | fn should_support_cat_with_empty_tensor_first() {
  function should_support_cat_with_multiple_empty_tensors (line 122) | fn should_support_cat_with_multiple_empty_tensors() {
  function should_support_cat_all_empty_tensors (line 139) | fn should_support_cat_all_empty_tensors() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/ceil.rs
  function should_support_ceil_ops (line 6) | fn should_support_ceil_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/chunk.rs
  function test_chunk_evenly_divisible (line 5) | fn test_chunk_evenly_divisible() {
  function test_chunk_not_evenly_divisible (line 26) | fn test_chunk_not_evenly_divisible() {
  function test_chunk_not_evenly_divisible_remains_several (line 47) | fn test_chunk_not_evenly_divisible_remains_several() {
  function test_chunk_not_divisible (line 61) | fn test_chunk_not_divisible() {
  function test_invalid_dim (line 83) | fn test_invalid_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/clamp.rs
  function clamp_min (line 5) | fn clamp_min() {
  function clamp_max (line 28) | fn clamp_max() {
  function clamp_min_max (line 51) | fn clamp_min_max() {
  function clamp_min_max_vec_should_compile (line 73) | fn clamp_min_max_vec_should_compile() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/close.rs
  function test_is_close (line 5) | fn test_is_close() {
  function test_all_close (line 24) | fn test_all_close() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/comparison.rs
  function test_equal_inf (line 5) | fn test_equal_inf() {
  function test_not_equal_inf (line 21) | fn test_not_equal_inf() {
  function test_equal (line 37) | fn test_equal() {
  function test_not_equal (line 50) | fn test_not_equal() {
  function test_equal_elem (line 63) | fn test_equal_elem() {
  function test_not_equal_elem (line 75) | fn test_not_equal_elem() {
  function greater_elem (line 87) | fn greater_elem() {
  function test_greater_equal_elem (line 99) | fn test_greater_equal_elem() {
  function test_greater (line 111) | fn test_greater() {
  function test_greater_equal (line 124) | fn test_greater_equal() {
  function test_lower_elem (line 137) | fn test_lower_elem() {
  function test_lower_equal_elem (line 149) | fn test_lower_equal_elem() {
  function test_lower (line 161) | fn test_lower() {
  function test_lower_equal (line 174) | fn test_lower_equal() {
  function test_greater_broadcast (line 187) | fn test_greater_broadcast() {
  function test_greater_equal_broadcast (line 212) | fn test_greater_equal_broadcast() {
  function test_lower_broadcast (line 232) | fn test_lower_broadcast() {
  function test_lower_equal_broadcast (line 253) | fn test_lower_equal_broadcast() {
  function test_equal_broadcast (line 268) | fn test_equal_broadcast() {
  function test_not_equal_broadcast (line 287) | fn test_not_equal_broadcast() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/create_like.rs
  function should_support_zeros_like (line 6) | fn should_support_zeros_like() {
  function should_support_ones_like (line 24) | fn should_support_ones_like() {
  function should_support_randoms_like (line 42) | fn should_support_randoms_like() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/cross.rs
  function test_cross_3d_last_dim (line 8) | fn test_cross_3d_last_dim() {
  function test_cross_3d_non_contiguous_last_dim (line 21) | fn test_cross_3d_non_contiguous_last_dim() {
  function test_cross_3d_dim0 (line 36) | fn test_cross_3d_dim0() {
  function test_cross_3d_broadcast (line 49) | fn test_cross_3d_broadcast() {
  function test_cross_4d_last_dim (line 62) | fn test_cross_4d_last_dim() {
  function manual_cross (line 75) | fn manual_cross(a: &[[f32; 3]], b: &[[f32; 3]]) -> Vec<[f32; 3]> {
  function forward_matches_manual_cross (line 89) | fn forward_matches_manual_cross() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/cumulative.rs
  function test_cumsum_float_dim_0 (line 5) | fn test_cumsum_float_dim_0() {
  function test_cumsum_float_dim_1 (line 16) | fn test_cumsum_float_dim_1() {
  function test_cumsum_non_contiguous (line 28) | fn test_cumsum_non_contiguous() {
  function test_cumsum_float_3d (line 39) | fn test_cumsum_float_3d() {
  function test_cumprod_float_dim_0 (line 51) | fn test_cumprod_float_dim_0() {
  function test_cumprod_float_dim_1 (line 63) | fn test_cumprod_float_dim_1() {
  function test_cumprod_float_3d (line 75) | fn test_cumprod_float_3d() {
  function test_cummin_float_dim_0 (line 87) | fn test_cummin_float_dim_0() {
  function test_cummin_float_dim_1 (line 98) | fn test_cummin_float_dim_1() {
  function test_cummin_float_3d (line 109) | fn test_cummin_float_3d() {
  function test_cummax_float_dim_0 (line 121) | fn test_cummax_float_dim_0() {
  function test_cummax_float_dim_1 (line 132) | fn test_cummax_float_dim_1() {
  function test_cummax_float_3d (line 143) | fn test_cummax_float_3d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/div.rs
  function should_support_div_ops (line 6) | fn should_support_div_ops() {
  function test_div_broadcast (line 22) | fn test_div_broadcast() {
  function should_support_div_scalar_ops (line 38) | fn should_support_div_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/dot.rs
  function test_float (line 5) | fn test_float() {
  function test_int (line 17) | fn test_int() {
  function test_panics_for_different_sizes (line 30) | fn test_panics_for_different_sizes() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/erf.rs
  function should_support_erf_ops (line 6) | fn should_support_erf_ops() {
  function should_support_erf_ops_with_negative_number (line 20) | fn should_support_erf_ops_with_negative_number() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/exp.rs
  function should_support_exp_ops (line 6) | fn should_support_exp_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/expand.rs
  function expand_2d (line 5) | fn expand_2d() {
  function expand_3d (line 24) | fn expand_3d() {
  function expand_higher_dimensions (line 37) | fn expand_higher_dimensions() {
  function expand_sum_3d (line 57) | fn expand_sum_3d() {
  function broadcast_single (line 66) | fn broadcast_single() {
  function should_fail_expand_incompatible_shapes (line 77) | fn should_fail_expand_incompatible_shapes() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/finite.rs
  function is_finite (line 4) | fn is_finite() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/flatten.rs
  function should_flatten_to_1d (line 6) | fn should_flatten_to_1d() {
  function should_flatten_middle (line 15) | fn should_flatten_middle() {
  function should_flatten_begin (line 24) | fn should_flatten_begin() {
  function should_flatten_end (line 33) | fn should_flatten_end() {
  function should_flatten_end_negative_indices (line 42) | fn should_flatten_end_negative_indices() {
  function should_flatten_panic (line 52) | fn should_flatten_panic() {
  function not_enough_destination_dimension (line 59) | fn not_enough_destination_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/flip.rs
  function flip_float (line 5) | fn flip_float() {
  function flip_duplicated_axes (line 32) | fn flip_duplicated_axes() {
  function flip_out_of_bound_axis (line 42) | fn flip_out_of_bound_axis() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/floor.rs
  function should_support_floor_ops (line 6) | fn should_support_floor_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/fmod.rs
  function should_support_fmod_ops (line 9) | fn should_support_fmod_ops() {
  function should_support_fmod_scalar (line 25) | fn should_support_fmod_scalar() {
  function should_handle_positive_dividend_positive_divisor (line 38) | fn should_handle_positive_dividend_positive_divisor() {
  function should_handle_negative_dividend (line 54) | fn should_handle_negative_dividend() {
  function should_handle_mixed_signs (line 70) | fn should_handle_mixed_signs() {
  function should_handle_infinity_dividend (line 87) | fn should_handle_infinity_dividend() {
  function should_handle_zero_divisor (line 112) | fn should_handle_zero_divisor() {
  function should_handle_infinity_divisor (line 132) | fn should_handle_infinity_divisor() {
  function should_handle_nan_arguments (line 154) | fn should_handle_nan_arguments() {
  function should_handle_negative_zero (line 173) | fn should_handle_negative_zero() {
  function should_support_fmod_broadcasting_2d (line 194) | fn should_support_fmod_broadcasting_2d() {
  function should_support_fmod_broadcasting_3d (line 215) | fn should_support_fmod_broadcasting_3d() {
  function should_support_fmod_scalar_broadcasting (line 235) | fn should_support_fmod_scalar_broadcasting() {
  function should_handle_edge_case_values (line 249) | fn should_handle_edge_case_values() {
  function should_handle_special_scalar_cases (line 266) | fn should_handle_special_scalar_cases() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/full.rs
  function test_data_full (line 5) | fn test_data_full() {
  function test_tensor_full (line 12) | fn test_tensor_full() {
  function test_tensor_full_options (line 21) | fn test_tensor_full_options() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/gather_scatter.rs
  function should_gather_1d_dim0 (line 5) | fn should_gather_1d_dim0() {
  function should_gather_2d_dim0 (line 18) | fn should_gather_2d_dim0() {
  function should_gather_2d_dim1 (line 31) | fn should_gather_2d_dim1() {
  function should_gather_3d_dim1 (line 45) | fn should_gather_3d_dim1() {
  function should_gather_2d_only_1dim (line 67) | fn should_gather_2d_only_1dim() {
  function should_scatter_add_1d (line 80) | fn should_scatter_add_1d() {
  function should_scatter_add_2d_dim0 (line 94) | fn should_scatter_add_2d_dim0() {
  function should_scatter_add_2d_dim1 (line 108) | fn should_scatter_add_2d_dim1() {
  function should_scatter_add_3d_dim1 (line 122) | fn should_scatter_add_3d_dim1() {
  function should_scatter_add_2d_dim1_diff_shape (line 151) | fn should_scatter_add_2d_dim1_diff_shape() {
  function scatter_should_panic_on_mismatch_of_shapes (line 166) | fn scatter_should_panic_on_mismatch_of_shapes() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/grid_sample.rs
  function should_grid_sample_2d_default (line 15) | fn should_grid_sample_2d_default() {
  function should_grid_sample_2d_align_corners_border (line 39) | fn should_grid_sample_2d_align_corners_border() {
  function should_pad_zeros_grid_sample_2d (line 65) | fn should_pad_zeros_grid_sample_2d() {
  function should_pad_border_grid_sample_2d (line 84) | fn should_pad_border_grid_sample_2d() {
  function should_pad_reflection_grid_sample_2d (line 106) | fn should_pad_reflection_grid_sample_2d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/inf.rs
  function is_inf (line 4) | fn is_inf() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/init.rs
  function should_support_float_empty (line 5) | fn should_support_float_empty() {
  function should_support_float_empty_options (line 12) | fn should_support_float_empty_options() {
  function should_support_float_zeros (line 19) | fn should_support_float_zeros() {
  function should_support_float_zeros_options (line 30) | fn should_support_float_zeros_options() {
  function should_support_float_ones (line 42) | fn should_support_float_ones() {
  function should_support_float_ones_options (line 53) | fn should_support_float_ones_options() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/iter_dim.rs
  function test_1d_iter_last_item (line 5) | fn test_1d_iter_last_item() {
  function test_too_high_dimension (line 19) | fn test_too_high_dimension() {
  function test_transposed (line 24) | fn test_transposed() {
  function test_2d_iter_dim (line 40) | fn test_2d_iter_dim() {
  function test_2d_iter_dim1 (line 60) | fn test_2d_iter_dim1() {
  function test_3d_iter_dim (line 85) | fn test_3d_iter_dim() {
  function test_3d_iter_dim1 (line 101) | fn test_3d_iter_dim1() {
  function test_3d_iter_dim2 (line 129) | fn test_3d_iter_dim2() {
  function test_iteration_over_low_dim (line 167) | fn test_iteration_over_low_dim() {
  function test_iter_dim_double_end (line 185) | fn test_iter_dim_double_end() {
  function test_iter_dim_single_element (line 228) | fn test_iter_dim_single_element() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/log.rs
  function should_support_log_ops (line 6) | fn should_support_log_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/log1p.rs
  function should_support_exp_log1p (line 6) | fn should_support_exp_log1p() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/mask.rs
  function should_support_mask_fill_swap_dims (line 6) | fn should_support_mask_fill_swap_dims() {
  function should_support_mask_where_ops (line 26) | fn should_support_mask_where_ops() {
  function should_support_mask_where_broadcast (line 40) | fn should_support_mask_where_broadcast() {
  function should_support_mask_where_broadcast_value_small (line 67) | fn should_support_mask_where_broadcast_value_small() {
  function should_handle_mask_where_nans (line 80) | fn should_handle_mask_where_nans() {
  function should_support_mask_fill_ops (line 116) | fn should_support_mask_fill_ops() {
  function should_support_mask_fill_broadcasted (line 129) | fn should_support_mask_fill_broadcasted() {
  function float_mask_fill_infinite (line 149) | fn float_mask_fill_infinite() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/matmul.rs
  function test_float_matmul_d2 (line 6) | fn test_float_matmul_d2() {
  function test_float_matmul_d3 (line 18) | fn test_float_matmul_d3() {
  function test_float_matmul_broadcast_1 (line 30) | fn test_float_matmul_broadcast_1() {
  function test_float_matmul_broadcast_4d (line 45) | fn test_float_matmul_broadcast_4d() {
  function test_float_matmul_simple_1 (line 69) | fn test_float_matmul_simple_1() {
  function test_float_matmul_4_3 (line 81) | fn test_float_matmul_4_3() {
  function test_float_matmul_batch_vec_mat (line 99) | fn test_float_matmul_batch_vec_mat() {
  function test_float_matmul_trivial (line 122) | fn test_float_matmul_trivial() {
  function test_float_matmul_trivial_transposed (line 143) | fn test_float_matmul_trivial_transposed() {
  function test_float_matmul_vecmat_transposed_fused (line 165) | fn test_float_matmul_vecmat_transposed_fused() {
  function test_float_matmul_4_8 (line 201) | fn test_float_matmul_4_8() {
  function test_float_matmul_simple_2 (line 222) | fn test_float_matmul_simple_2() {
  function test_float_matmul_simple_3 (line 234) | fn test_float_matmul_simple_3() {
  function float_should_panic_when_inner_dimensions_are_not_equal (line 258) | fn float_should_panic_when_inner_dimensions_are_not_equal() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/maxmin.rs
  function test_max_dim_2d (line 5) | fn test_max_dim_2d() {
  function test_max_dims_2d (line 37) | fn test_max_dims_2d() {
  function test_max_dim_with_indices_2d_with_dim_0th (line 57) | fn test_max_dim_with_indices_2d_with_dim_0th() {
  function test_max_dim_with_indices_2d (line 74) | fn test_max_dim_with_indices_2d() {
  function test_max_dim_2d_with_0th_dim (line 88) | fn test_max_dim_2d_with_0th_dim() {
  function test_max_pair (line 99) | fn test_max_pair() {
  function test_min_dim_2d (line 110) | fn test_min_dim_2d() {
  function test_min_dims_2d (line 142) | fn test_min_dims_2d() {
  function test_min_dim_with_indices_2d (line 162) | fn test_min_dim_with_indices_2d() {
  function test_min_dim_2d_with_0th_dim (line 176) | fn test_min_dim_2d_with_0th_dim() {
  function test_min_dim_with_indices_2d_with_0th_dim (line 187) | fn test_min_dim_with_indices_2d_with_0th_dim() {
  function test_min_pair (line 204) | fn test_min_pair() {
  function test_max_abs (line 215) | fn test_max_abs() {
  function test_max_abs_dim_2d_dim_0 (line 225) | fn test_max_abs_dim_2d_dim_0() {
  function test_max_abs_dims_2d (line 239) | fn test_max_abs_dims_2d() {
  function test_max_abs_dim_2d_dim_1 (line 262) | fn test_max_abs_dim_2d_dim_1() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/movedim.rs
  function movedim_float (line 5) | fn movedim_float() {
  function vec_input_float (line 32) | fn vec_input_float() {
  function different_input_types (line 59) | fn different_input_types() {
  function edge_different_sizes (line 87) | fn edge_different_sizes() {
  function edge_out_of_bound_axis (line 97) | fn edge_out_of_bound_axis() {
  function edge_vec_is_not_a_set (line 107) | fn edge_vec_is_not_a_set() {
  function edge_out_of_bound_axis_vec (line 117) | fn edge_out_of_bound_axis_vec() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/mul.rs
  function should_support_mul_ops (line 5) | fn should_support_mul_ops() {
  function test_mul_broadcast (line 19) | fn test_mul_broadcast() {
  function test_mul_broadcast_2_dims (line 33) | fn test_mul_broadcast_2_dims() {
  function should_support_mul_scalar_ops (line 45) | fn should_support_mul_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/nan.rs
  function is_nan (line 5) | fn is_nan() {
  function contains_nan (line 18) | fn contains_nan() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/narrow.rs
  function test_narrow_1 (line 6) | fn test_narrow_1() {
  function test_narrow_2 (line 22) | fn test_narrow_2() {
  function test_narrow_3 (line 37) | fn test_narrow_3() {
  function test_narrow_invalid_dim (line 58) | fn test_narrow_invalid_dim() {
  function test_narrow_invalid_start (line 69) | fn test_narrow_invalid_start() {
  function test_narrow_invalid_zero_length (line 80) | fn test_narrow_invalid_zero_length() {
  function test_narrow_invalid_length (line 91) | fn test_narrow_invalid_length() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/neg.rs
  function should_support_neg_ops (line 5) | fn should_support_neg_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/one_hot.rs
  function float_should_support_one_hot (line 5) | fn float_should_support_one_hot() {
  function float_should_support_one_hot_index (line 17) | fn float_should_support_one_hot_index() {
  function float_one_hot_should_panic_when_index_exceeds_number_of_classes (line 26) | fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() {
  function float_one_hot_should_panic_when_number_of_classes_is_zero (line 33) | fn float_one_hot_should_panic_when_number_of_classes_is_zero() {
  function one_hot_fill_with_negative_axis_and_indices (line 39) | fn one_hot_fill_with_negative_axis_and_indices() {
  function one_hot_fill_with_negative_indices (line 52) | fn one_hot_fill_with_negative_indices() {
  function one_hot_fill_should_panic_when_axis_out_range_of_rank (line 67) | fn one_hot_fill_should_panic_when_axis_out_range_of_rank() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/padding.rs
  function padding_constant_2d_test (line 5) | fn padding_constant_2d_test() {
  function padding_constant_4d_test (line 23) | fn padding_constant_4d_test() {
  function padding_constant_asymmetric_test (line 42) | fn padding_constant_asymmetric_test() {
  function padding_reflect_2d_test (line 64) | fn padding_reflect_2d_test() {
  function padding_reflect_width_only_test (line 91) | fn padding_reflect_width_only_test() {
  function padding_reflect_4d_test (line 106) | fn padding_reflect_4d_test() {
  function padding_edge_2d_test (line 123) | fn padding_edge_2d_test() {
  function padding_edge_width_only_test (line 140) | fn padding_edge_width_only_test() {
  function padding_edge_4d_test (line 155) | fn padding_edge_4d_test() {
  function padding_constant_default_test (line 171) | fn padding_constant_default_test() {
  function padding_reflect_max_valid_test (line 187) | fn padding_reflect_max_valid_test() {
  function padding_reflect_asymmetric_test (line 204) | fn padding_reflect_asymmetric_test() {
  function padding_reflect_exceeds_dimension_test (line 224) | fn padding_reflect_exceeds_dimension_test() {
  function padding_edge_asymmetric_test (line 233) | fn padding_edge_asymmetric_test() {
  function padding_zero_padding_test (line 252) | fn padding_zero_padding_test() {
  function padding_empty_tensor_constant_test (line 267) | fn padding_empty_tensor_constant_test() {
  function padding_empty_tensor_edge_panics_test (line 289) | fn padding_empty_tensor_edge_panics_test() {
  function padding_empty_tensor_reflect_panics_test (line 299) | fn padding_empty_tensor_reflect_panics_test() {
  function padding_constant_pairs_2d_test (line 310) | fn padding_constant_pairs_2d_test() {
  function padding_constant_single_dim_test (line 329) | fn padding_constant_single_dim_test() {
  function padding_constant_all_dims_4d_test (line 340) | fn padding_constant_all_dims_4d_test() {
  function padding_constant_batch_dim_only_test (line 372) | fn padding_constant_batch_dim_only_test() {
  function padding_reflect_pairs_test (line 390) | fn padding_reflect_pairs_test() {
  function padding_edge_pairs_test (line 407) | fn padding_edge_pairs_test() {
  function padding_reflect_batch_dim_3d_test (line 423) | fn padding_reflect_batch_dim_3d_test() {
  function padding_edge_batch_dim_3d_test (line 445) | fn padding_edge_batch_dim_3d_test() {
  function padding_too_many_pairs_panics_test (line 465) | fn padding_too_many_pairs_panics_test() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/permute.rs
  function permute_float_a (line 5) | fn permute_float_a() {
  function permute_float (line 32) | fn permute_float() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/powf.rs
  function should_support_powf_ops (line 6) | fn should_support_powf_ops() {
  function should_support_neg_power (line 21) | fn should_support_neg_power() {
  function should_support_neg_values_with_even_power (line 36) | fn should_support_neg_values_with_even_power() {
  function should_support_neg_values_with_odd_power (line 51) | fn should_support_neg_values_with_odd_power() {
  function should_support_powf_broadcasted (line 66) | fn should_support_powf_broadcasted() {
  function outer (line 84) | fn outer(a: TestTensor<1>, b: TestTensor<1>) -> TestTensor<2> {
  function should_support_powf_scalar_tensor (line 89) | fn should_support_powf_scalar_tensor() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/powf_scalar.rs
  function should_support_powf_ops (line 6) | fn should_support_powf_ops() {
  function should_support_neg_power (line 19) | fn should_support_neg_power() {
  function should_support_neg_values_with_even_power (line 32) | fn should_support_neg_values_with_even_power() {
  function should_support_neg_values_with_odd_power (line 45) | fn should_support_neg_values_with_odd_power() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/prod.rs
  function test_prod_float (line 5) | fn test_prod_float() {
  function test_prod_dim_2d (line 16) | fn test_prod_dim_2d() {
  function test_prod_dims_2d (line 31) | fn test_prod_dims_2d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/random.rs
  function rand_default (line 5) | fn rand_default() {
  function rand_uniform (line 20) | fn rand_uniform() {
  function rand_bernoulli (line 33) | fn rand_bernoulli() {
  function test_seed_reproducibility (line 44) | fn test_seed_reproducibility() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/recip.rs
  function should_support_recip_ops (line 6) | fn should_support_recip_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/remainder.rs
  function should_support_remainder_basic (line 7) | fn should_support_remainder_basic() {
  function should_support_remainder_basic_scalar (line 21) | fn should_support_remainder_basic_scalar() {
  function should_support_remainder_float (line 35) | fn should_support_remainder_float() {
  function should_support_remainder_float_scalar (line 57) | fn should_support_remainder_float_scalar() {
  function should_be_zero (line 71) | fn should_be_zero() {
  function should_be_zero_scalar (line 85) | fn should_be_zero_scalar() {
  function should_have_no_remainder (line 99) | fn should_have_no_remainder() {
  function should_have_no_remainder_scalar (line 122) | fn should_have_no_remainder_scalar() {
  function should_be_negative (line 136) | fn should_be_negative() {
  function should_be_negative_scalar (line 151) | fn should_be_negative_scalar() {
  function should_support_fp_dividends (line 165) | fn should_support_fp_dividends() {
  function should_support_large_divisor (line 181) | fn should_support_large_divisor() {
  function should_support_large_divisor_scalar (line 201) | fn should_support_large_divisor_scalar() {
  function should_support_remainder_op (line 215) | fn should_support_remainder_op() {
  function should_support_remainder_scalar_op (line 230) | fn should_support_remainder_scalar_op() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/repeat.rs
  function should_support_repeat_ops_one_dimension (line 5) | fn should_support_repeat_ops_one_dimension() {
  function should_support_float_repeat_repeating_on_many_dimensions (line 21) | fn should_support_float_repeat_repeating_on_many_dimensions() {
  function should_repeat_0_times_empty (line 102) | fn should_repeat_0_times_empty() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/repeat_dim.rs
  function should_support_repeat_ops (line 5) | fn should_support_repeat_ops() {
  function should_support_float_repeat_on_dims_larger_than_1 (line 21) | fn should_support_float_repeat_on_dims_larger_than_1() {
  function repeat_dim_swap_dims_1 (line 54) | fn repeat_dim_swap_dims_1() {
  function repeat_dim_swap_dims_2 (line 92) | fn repeat_dim_swap_dims_2() {
  function repeat_dim_swap_dims_3 (line 134) | fn repeat_dim_swap_dims_3() {
  function should_repeat_dim_0_times_empty (line 160) | fn should_repeat_dim_0_times_empty() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/reshape.rs
  function should_support_rank (line 5) | fn should_support_rank() {
  function should_support_reshape_1d (line 16) | fn should_support_reshape_1d() {
  function should_support_reshape_2d (line 27) | fn should_support_reshape_2d() {
  function should_support_dim_infererence (line 38) | fn should_support_dim_infererence() {
  function should_not_corrupt_after_slice (line 65) | fn should_not_corrupt_after_slice() {
  function multiple_neg_ones (line 78) | fn multiple_neg_ones() {
  function neg_value (line 86) | fn neg_value() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/round.rs
  function should_support_round_ops (line 6) | fn should_support_round_ops() {
  function should_round_ties_even (line 19) | fn should_round_ties_even() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/select.rs
  function should_select_1d (line 5) | fn should_select_1d() {
  function should_select_2d_dim0_same_num_dim (line 17) | fn should_select_2d_dim0_same_num_dim() {
  function should_select_2d_dim0_more_num_dim (line 29) | fn should_select_2d_dim0_more_num_dim() {
  function should_select_2d_dim0_vec (line 46) | fn should_select_2d_dim0_vec() {
  function should_select_2d_dim1 (line 59) | fn should_select_2d_dim1() {
  function should_select_add_1d (line 71) | fn should_select_add_1d() {
  function should_select_add_1d_int (line 84) | fn should_select_add_1d_int() {
  function should_select_add_2d_dim0 (line 97) | fn should_select_add_2d_dim0() {
  function should_select_add_2d_dim1 (line 110) | fn should_select_add_2d_dim1() {
  function should_select_3d_dim1_vec (line 123) | fn should_select_3d_dim1_vec() {
  function should_select_panic_invalid_dimension (line 145) | fn should_select_panic_invalid_dimension() {
  function should_match_default_implementation_behavior (line 154) | fn should_match_default_implementation_behavior() {
  function should_select_with_negative_dim_2d (line 178) | fn should_select_with_negative_dim_2d() {
  function should_select_add_with_negative_dim_2d (line 195) | fn should_select_add_with_negative_dim_2d() {
  function should_panic_select_negative_dim_out_of_bounds (line 216) | fn should_panic_select_negative_dim_out_of_bounds() {
  function should_panic_select_add_negative_dim_out_of_bounds (line 227) | fn should_panic_select_add_negative_dim_out_of_bounds() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/sign.rs
  function should_support_sign_ops_float (line 5) | fn should_support_sign_ops_float() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/slice.rs
  function should_support_slice_dim_1d (line 5) | fn should_support_slice_dim_1d() {
  function should_panic_when_slice_dim_1d_bad_dim (line 25) | fn should_panic_when_slice_dim_1d_bad_dim() {
  function should_support_slice_dim_2d (line 33) | fn should_support_slice_dim_2d() {
  function should_support_slice_dim_with_step (line 44) | fn should_support_slice_dim_with_step() {
  function should_support_slice_dim_with_negative_step (line 63) | fn should_support_slice_dim_with_negative_step() {
  function should_support_full_sliceing_1d (line 75) | fn should_support_full_sliceing_1d() {
  function should_support_full_sliceing_vec (line 85) | fn should_support_full_sliceing_vec() {
  function should_support_partial_sliceing_1d (line 99) | fn should_support_partial_sliceing_1d() {
  function should_support_full_sliceing_2d (line 110) | fn should_support_full_sliceing_2d() {
  function should_support_partial_sliceing_2d (line 122) | fn should_support_partial_sliceing_2d() {
  function should_support_slice_range_first_dim (line 133) | fn should_support_slice_range_first_dim() {
  function should_support_partial_sliceing_3d (line 144) | fn should_support_partial_sliceing_3d() {
  function should_support_partial_sliceing_3d_non_contiguous (line 160) | fn should_support_partial_sliceing_3d_non_contiguous() {
  function should_support_slice_fill_1d (line 176) | fn should_support_slice_fill_1d() {
  function should_support_slice_fill_vec (line 189) | fn should_support_slice_fill_vec() {
  function should_support_slice_fill_cast_f32 (line 204) | fn should_support_slice_fill_cast_f32() {
  function should_support_slice_fill_cast_f64 (line 218) | fn should_support_slice_fill_cast_f64() {
  function should_support_slice_fill_1d_neg (line 230) | fn should_support_slice_fill_1d_neg() {
  function should_support_slice_fill_2d (line 243) | fn should_support_slice_fill_2d() {
  function should_support_slice_fill_with_positive_step (line 256) | fn should_support_slice_fill_with_positive_step() {
  function should_support_slice_fill_with_negative_step (line 288) | fn should_support_slice_fill_with_negative_step() {
  function should_support_slice_fill_with_mixed_steps (line 308) | fn should_support_slice_fill_with_mixed_steps() {
  function clamp_when_slice_exceeds_dimension (line 340) | fn clamp_when_slice_exceeds_dimension() {
  function negative_dimensions (line 349) | fn negative_dimensions() {
  function missing_dimensions (line 367) | fn missing_dimensions() {
  function should_slice_aggregation_result (line 395) | fn should_slice_aggregation_result() {
  function should_panic_when_slice_with_too_many_dimensions (line 404) | fn should_panic_when_slice_with_too_many_dimensions() {
  function should_support_descending_slice_as_empty (line 411) | fn should_support_descending_slice_as_empty() {
  function should_support_empty_slice (line 423) | fn should_support_empty_slice() {
  function should_support_empty_slice_2d (line 436) | fn should_support_empty_slice_2d() {
  function test_slice_with_positive_step (line 451) | fn test_slice_with_positive_step() {
  function test_slice_with_negative_step (line 479) | fn test_slice_with_negative_step() {
  function test_slice_with_mixed_steps (line 520) | fn test_slice_with_mixed_steps() {
  function test_slice_with_steps_1d (line 543) | fn test_slice_with_steps_1d() {
  function test_slice_with_steps_3d (line 565) | fn test_slice_with_steps_3d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/slice_assign.rs
  function should_support_slice_assign_1d (line 5) | fn should_support_slice_assign_1d() {
  function should_support_slice_assign_2d (line 20) | fn should_support_slice_assign_2d() {
  function should_support_slice_assign_vec (line 35) | fn should_support_slice_assign_vec() {
  function slice_assign_now_supports_non_unit_step (line 52) | fn slice_assign_now_supports_non_unit_step() {
  function test_slice_assign_with_positive_step_1d (line 74) | fn test_slice_assign_with_positive_step_1d() {
  function test_slice_assign_with_positive_step_2d (line 87) | fn test_slice_assign_with_positive_step_2d() {
  function test_slice_assign_with_negative_step_1d (line 145) | fn test_slice_assign_with_negative_step_1d() {
  function test_slice_assign_with_negative_step_2d (line 158) | fn test_slice_assign_with_negative_step_2d() {
  function test_slice_assign_with_mixed_steps (line 205) | fn test_slice_assign_with_mixed_steps() {
  function test_slice_assign_3d_with_steps (line 252) | fn test_slice_assign_3d_with_steps() {
  function test_slice_assign_partial_with_steps (line 302) | fn test_slice_assign_partial_with_steps() {
  function should_support_slice_assign_empty_range (line 329) | fn should_support_slice_assign_empty_range() {
  function should_support_slice_assign_empty_range_1d (line 342) | fn should_support_slice_assign_empty_range_1d() {
  function should_support_slice_assign_single_dim_slice (line 355) | fn should_support_slice_assign_single_dim_slice() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/sort_argsort.rs
  function test_sort_1d_float (line 6) | fn test_sort_1d_float() {
  function test_argsort_1d_float (line 23) | fn test_argsort_1d_float() {
  function test_sort_with_indices_descending_float (line 36) | fn test_sort_with_indices_descending_float() {
  function test_sort_float (line 77) | fn test_sort_float() {
  function test_sort_with_indices_float (line 118) | fn test_sort_with_indices_float() {
  function test_argsort_float (line 167) | fn test_argsort_float() {
  function test_sort_float_nan (line 193) | fn test_sort_float_nan() {
  function test_sort_descending_1d (line 206) | fn test_sort_descending_1d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/split.rs
  function test_split_evenly_divisible (line 5) | fn test_split_evenly_divisible() {
  function test_split_not_evenly_divisible (line 25) | fn test_split_not_evenly_divisible() {
  function test_split_along_dim1 (line 44) | fn test_split_along_dim1() {
  function test_split_split_size_larger_than_tensor_size (line 62) | fn test_split_split_size_larger_than_tensor_size() {
  function test_split_with_zero_split_size_zero_tensor_size (line 77) | fn test_split_with_zero_split_size_zero_tensor_size() {
  function test_split_zero_sized_tensor (line 87) | fn test_split_zero_sized_tensor() {
  function test_split_with_zero_split_size_non_zero_tensor (line 100) | fn test_split_with_zero_split_size_non_zero_tensor() {
  function test_split_invalid_dim (line 109) | fn test_split_invalid_dim() {
  function test_split_3d_tensor_along_dim0 (line 117) | fn test_split_3d_tensor_along_dim0() {
  function test_split_3d_tensor_along_dim1 (line 143) | fn test_split_3d_tensor_along_dim1() {
  function test_split_with_sizes (line 164) | fn test_split_with_sizes() {
  function test_split_with_sizes_invalid_sum (line 186) | fn test_split_with_sizes_invalid_sum() {
  function test_split_with_sizes_zero_length (line 194) | fn test_split_with_sizes_zero_length() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/sqrt.rs
  function should_support_sqrt_ops (line 7) | fn should_support_sqrt_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/square.rs
  function should_support_sqrt_ops (line 6) | fn should_support_sqrt_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/squeeze.rs
  function should_squeeze_dim (line 6) | fn should_squeeze_dim() {
  function should_squeeze (line 14) | fn should_squeeze() {
  function should_squeeze_first (line 23) | fn should_squeeze_first() {
  function should_squeeze_last (line 31) | fn should_squeeze_last() {
  function should_squeeze_panic (line 40) | fn should_squeeze_panic() {
  function should_squeeze_dims_with_empty_slice (line 47) | fn should_squeeze_dims_with_empty_slice() {
  function should_squeeze_all_dims (line 55) | fn should_squeeze_all_dims() {
  function should_squeeze_dims_with_positive_indices (line 64) | fn should_squeeze_dims_with_positive_indices() {
  function should_squeeze_dims_with_negative_indices (line 73) | fn should_squeeze_dims_with_negative_indices() {
  function should_squeeze_dims_work_if_non_singleton (line 83) | fn should_squeeze_dims_work_if_non_singleton() {
  function should_panic_squeeze_consumes_all_singleton (line 92) | fn should_panic_squeeze_consumes_all_singleton() {
  function should_squeeze_dims_panic_on_too_many_dimensions (line 100) | fn should_squeeze_dims_panic_on_too_many_dimensions() {
  function should_squeeze_dims_dimension_mismatch_panic (line 108) | fn should_squeeze_dims_dimension_mismatch_panic() {
  function should_unsqueeze_dim (line 115) | fn should_unsqueeze_dim() {
  function should_unsqueeze_dim_first (line 124) | fn should_unsqueeze_dim_first() {
  function should_unsqueeze_dim_last (line 133) | fn should_unsqueeze_dim_last() {
  function should_unsqueeze_dim_panic (line 143) | fn should_unsqueeze_dim_panic() {
  function should_unsqueeze_dims_support_dim_inference (line 149) | fn should_unsqueeze_dims_support_dim_inference() {
  function should_unsqueeze_dims_handle_first_last (line 157) | fn should_unsqueeze_dims_handle_first_last() {
  function should_unsqueeze_dims_work_with_single_dim (line 165) | fn should_unsqueeze_dims_work_with_single_dim() {
  function should_unsqueeze_dims_multiple_trailing_negatives (line 174) | fn should_unsqueeze_dims_multiple_trailing_negatives() {
  function should_unsqueeze_dims_panic (line 183) | fn should_unsqueeze_dims_panic() {
  function squeeze_all_singleton_not_supported (line 190) | fn squeeze_all_singleton_not_supported() {
  function squeeze_dim_singleton_not_supported (line 197) | fn squeeze_dim_singleton_not_supported() {
  function squeeze_dims_all_singleton_not_supported (line 204) | fn squeeze_dims_all_singleton_not_supported() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/stack.rs
  function should_support_stack_ops_2d_dim0 (line 6) | fn should_support_stack_ops_2d_dim0() {
  function should_support_stack_ops_2d_dim1 (line 18) | fn should_support_stack_ops_2d_dim1() {
  function should_support_stack_ops_3d (line 30) | fn should_support_stack_ops_3d() {
  function should_panic_when_dimensions_are_not_the_same (line 46) | fn should_panic_when_dimensions_are_not_the_same() {
  function should_panic_when_list_of_vectors_is_empty (line 56) | fn should_panic_when_list_of_vectors_is_empty() {
  function should_panic_when_stack_exceeds_dimension (line 63) | fn should_panic_when_stack_exceeds_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/sub.rs
  function should_support_sub_ops (line 5) | fn should_support_sub_ops() {
  function test_sub_broadcast (line 19) | fn test_sub_broadcast() {
  function should_support_sub_scalar_ops (line 33) | fn should_support_sub_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/take.rs
  function should_take_1d (line 5) | fn should_take_1d() {
  function should_take_2d_dim0 (line 18) | fn should_take_2d_dim0() {
  function should_take_2d_dim1 (line 36) | fn should_take_2d_dim1() {
  function take_and_select_should_be_equivalent (line 49) | fn take_and_select_should_be_equivalent() {
  function should_take_with_2d_indices (line 72) | fn should_take_with_2d_indices() {
  function should_take_with_2d_indices_dim1 (line 98) | fn should_take_with_2d_indices_dim1() {
  function should_take_3d_tensor (line 114) | fn should_take_3d_tensor() {
  function should_take_with_3d_indices (line 139) | fn should_take_with_3d_indices() {
  function should_panic_take_invalid_dimension (line 159) | fn should_panic_take_invalid_dimension() {
  function should_take_with_single_index (line 169) | fn should_take_with_single_index() {
  function should_take_with_negative_dim_2d (line 181) | fn should_take_with_negative_dim_2d() {
  function should_panic_take_negative_dim_out_of_bounds (line 199) | fn should_panic_take_negative_dim_out_of_bounds() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/topk.rs
  function test_topk_with_indices_3d (line 6) | fn test_topk_with_indices_3d() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/transaction.rs
  function should_support_transaction (line 6) | fn should_support_transaction() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/transpose.rs
  function should_support_transpose_ops (line 6) | fn should_support_transpose_ops() {
  function should_support_transpose_maybe_fused_with_one (line 29) | fn should_support_transpose_maybe_fused_with_one() {
  function should_support_swap_dims_no_op (line 54) | fn should_support_swap_dims_no_op() {
  function should_support_swap_dims (line 75) | fn should_support_swap_dims() {
  function should_support_swap_dims_neg_index (line 97) | fn should_support_swap_dims_neg_index() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/tri.rs
  function test_triu (line 5) | fn test_triu() {
  function test_triu_positive_diagonal (line 14) | fn test_triu_positive_diagonal() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/trig.rs
  function should_support_cos_ops (line 9) | fn should_support_cos_ops() {
  function should_support_cosh_ops (line 25) | fn should_support_cosh_ops() {
  function should_support_sin_ops (line 38) | fn should_support_sin_ops() {
  function should_support_sinh_ops (line 51) | fn should_support_sinh_ops() {
  function should_support_tan_ops (line 64) | fn should_support_tan_ops() {
  function should_support_tanh_ops (line 77) | fn should_support_tanh_ops() {
  function should_support_asin_ops (line 90) | fn should_support_asin_ops() {
  function should_support_acos_ops (line 103) | fn should_support_acos_ops() {
  function should_support_atan_ops (line 119) | fn should_support_atan_ops() {
  function should_support_asinh_ops (line 132) | fn should_support_asinh_ops() {
  function should_support_acosh_ops (line 145) | fn should_support_acosh_ops() {
  function should_support_atanh_ops (line 158) | fn should_support_atanh_ops() {
  function should_support_atan2_ops (line 171) | fn should_support_atan2_ops() {
  function should_support_deg2rad_ops (line 187) | fn should_support_deg2rad_ops() {
  function should_support_rad2deg_ops (line 216) | fn should_support_rad2deg_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/trunc.rs
  function should_support_trunc_ops (line 6) | fn should_support_trunc_ops() {
  function should_truncate_positive_values_like_floor (line 19) | fn should_truncate_positive_values_like_floor() {
  function should_truncate_negative_values_like_ceil (line 32) | fn should_truncate_negative_values_like_ceil() {
  function should_handle_special_cases (line 45) | fn should_handle_special_cases() {

FILE: crates/burn-backend-tests/tests/tensor/float/ops/unfold.rs
  function test_unfold_float (line 6) | fn test_unfold_float() {

FILE: crates/burn-backend-tests/tests/tensor/float/primitive.rs
  function should_support_float_dtype (line 5) | fn should_support_float_dtype() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/calibration.rs
  function min_max_calibration_range_per_tensor (line 10) | fn min_max_calibration_range_per_tensor() {
  function min_max_calibration_range_per_block (line 27) | fn min_max_calibration_range_per_block() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/data.rs
  function should_support_per_tensor_symmetric_int8 (line 7) | fn should_support_per_tensor_symmetric_int8() {
  function should_support_per_block_symmetric_int8 (line 25) | fn should_support_per_block_symmetric_int8() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/mod.rs
  type QTensor (line 20) | pub struct QTensor<B: Backend, const D: usize> {
  function int8 (line 27) | pub fn int8<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {
  function int8_block (line 32) | pub fn int8_block<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {
  function int8_symmetric (line 41) | pub fn int8_symmetric<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/abs.rs
  function should_support_abs_ops (line 7) | fn should_support_abs_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/add.rs
  function test_add_d2 (line 7) | fn test_add_d2() {
  function test_add_broadcast (line 23) | fn test_add_broadcast() {
  function test_add_different_strides_rhs (line 39) | fn test_add_different_strides_rhs() {
  function test_add_different_strides_lhs (line 57) | fn test_add_different_strides_lhs() {
  function test_add_different_strides_broadcast (line 75) | fn test_add_different_strides_broadcast() {
  function should_support_add_scalar_ops (line 93) | fn should_support_add_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/aggregation.rs
  function test_should_mean (line 7) | fn test_should_mean() {
  function test_should_sum (line 19) | fn test_should_sum() {
  function test_should_mean_last_dim (line 31) | fn test_should_mean_last_dim() {
  function test_should_sum_last_dim (line 44) | fn test_should_sum_last_dim() {
  function test_should_sum_first_dim (line 59) | fn test_should_sum_first_dim() {
  function test_should_mean_first_dim (line 74) | fn test_should_mean_first_dim() {
  function test_should_sum_mid_dim_3d_non_contiguous_1 (line 89) | fn test_should_sum_mid_dim_3d_non_contiguous_1() {
  function test_should_sum_mid_dim_3d_non_contiguous_2 (line 107) | fn test_should_sum_mid_dim_3d_non_contiguous_2() {
  function test_prod_float (line 125) | fn test_prod_float() {
  function test_prod_dim_float (line 145) | fn test_prod_dim_float() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/all.rs
  function test_all (line 6) | fn test_all() {
  function test_all_dim (line 19) | fn test_all_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/any.rs
  function test_any (line 6) | fn test_any() {
  function test_any_dim (line 19) | fn test_any_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/arg.rs
  function test_argmax_2d_dim0 (line 6) | fn test_argmax_2d_dim0() {
  function test_argmin_2d_dim0 (line 17) | fn test_argmin_2d_dim0() {
  function test_argmax_2d_dim1 (line 28) | fn test_argmax_2d_dim1() {
  function test_argmin_2d_dim1 (line 39) | fn test_argmin_2d_dim1() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cat.rs
  function should_support_cat_ops_2d_dim0 (line 8) | fn should_support_cat_ops_2d_dim0() {
  function should_support_cat_ops_2d_dim1 (line 22) | fn should_support_cat_ops_2d_dim1() {
  function should_support_cat_ops_3d (line 36) | fn should_support_cat_ops_3d() {
  function should_panic_when_dimensions_are_not_the_same (line 51) | fn should_panic_when_dimensions_are_not_the_same() {
  function should_panic_when_cat_exceeds_dimension (line 60) | fn should_panic_when_cat_exceeds_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/ceil.rs
  function should_support_ceil_ops (line 7) | fn should_support_ceil_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/chunk.rs
  function test_chunk_evenly_divisible (line 8) | fn test_chunk_evenly_divisible() {
  function test_chunk_not_evenly_divisible (line 29) | fn test_chunk_not_evenly_divisible() {
  function test_chunk_not_divisible (line 51) | fn test_chunk_not_divisible() {
  function test_chunk_multi_dimension (line 75) | fn test_chunk_multi_dimension() {
  function test_invalid_dim (line 96) | fn test_invalid_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/clamp.rs
  function clamp_min (line 7) | fn clamp_min() {
  function clamp_max (line 22) | fn clamp_max() {
  function clamp_min_max (line 37) | fn clamp_min_max() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cos.rs
  function should_support_cos_ops (line 7) | fn should_support_cos_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cosh.rs
  function should_support_cosh_ops (line 7) | fn should_support_cosh_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/div.rs
  function should_support_div_ops (line 7) | fn should_support_div_ops() {
  function test_div_broadcast (line 21) | fn test_div_broadcast() {
  function should_support_div_scalar_ops (line 37) | fn should_support_div_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/erf.rs
  function should_support_erf_ops (line 7) | fn should_support_erf_ops() {
  function should_support_erf_ops_with_negative_number (line 20) | fn should_support_erf_ops_with_negative_number() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/exp.rs
  function should_support_exp_ops (line 7) | fn should_support_exp_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/expand.rs
  function expand_2d (line 7) | fn expand_2d() {
  function expand_3d (line 33) | fn expand_3d() {
  function expand_higher_dimensions (line 50) | fn expand_higher_dimensions() {
  function broadcast_single (line 74) | fn broadcast_single() {
  function should_fail_expand_incompatible_shapes (line 87) | fn should_fail_expand_incompatible_shapes() {
  function should_all_negative_one (line 93) | fn should_all_negative_one() {
  function should_panic_negative_one_on_non_existing_dim (line 109) | fn should_panic_negative_one_on_non_existing_dim() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/flip.rs
  function flip_float (line 7) | fn flip_float() {
  function flip_duplicated_axes (line 25) | fn flip_duplicated_axes() {
  function flip_out_of_bound_axis (line 34) | fn flip_out_of_bound_axis() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/floor.rs
  function should_support_floor_ops (line 7) | fn should_support_floor_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/gather_scatter.rs
  function should_gather_1d_dim0 (line 8) | fn should_gather_1d_dim0() {
  function should_gather_2d_dim0 (line 24) | fn should_gather_2d_dim0() {
  function should_gather_2d_dim1 (line 40) | fn should_gather_2d_dim1() {
  function should_gather_3d_dim1 (line 56) | fn should_gather_3d_dim1() {
  function should_gather_2d_only_1dim (line 79) | fn should_gather_2d_only_1dim() {
  function should_scatter_1d (line 95) | fn should_scatter_1d() {
  function should_scatter_2d_dim0 (line 112) | fn should_scatter_2d_dim0() {
  function should_scatter_2d_dim1 (line 129) | fn should_scatter_2d_dim1() {
  function should_scatter_3d_dim1 (line 146) | fn should_scatter_3d_dim1() {
  function should_scatter_2d_dim1_diff_shape (line 174) | fn should_scatter_2d_dim1_diff_shape() {
  function scatter_should_panic_on_mismatch_of_shapes (line 192) | fn scatter_should_panic_on_mismatch_of_shapes() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log.rs
  function should_support_log_ops (line 7) | fn should_support_log_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log1p.rs
  function should_support_exp_log1p (line 7) | fn should_support_exp_log1p() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/map_comparison.rs
  function test_equal (line 6) | fn test_equal() {
  function test_not_equal (line 19) | fn test_not_equal() {
  function test_equal_elem (line 33) | fn test_equal_elem() {
  function test_not_equal_elem (line 46) | fn test_not_equal_elem() {
  function test_greater_elem (line 59) | fn test_greater_elem() {
  function test_greater_equal_elem (line 71) | fn test_greater_equal_elem() {
  function test_greater (line 83) | fn test_greater() {
  function test_greater_equal (line 96) | fn test_greater_equal() {
  function test_lower_elem (line 109) | fn test_lower_elem() {
  function test_lower_equal_elem (line 122) | fn test_lower_equal_elem() {
  function test_lower (line 134) | fn test_lower() {
  function test_lower_equal (line 147) | fn test_lower_equal() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mask.rs
  function should_support_mask_where_ops (line 7) | fn should_support_mask_where_ops() {
  function should_support_mask_fill_ops (line 25) | fn should_support_mask_fill_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/maxmin.rs
  function test_max_dim_2d (line 7) | fn test_max_dim_2d() {
  function test_max_dim_with_indices_2d_with_dim_0th (line 20) | fn test_max_dim_with_indices_2d_with_dim_0th() {
  function test_max_dim_with_indices_2d (line 36) | fn test_max_dim_with_indices_2d() {
  function test_min_dim_2d (line 52) | fn test_min_dim_2d() {
  function test_min_dim_with_indices_2d (line 66) | fn test_min_dim_with_indices_2d() {
  function test_min_dim_2d_with_0th_dim (line 82) | fn test_min_dim_2d_with_0th_dim() {
  function test_max_dim_2d_with_0th_dim (line 95) | fn test_max_dim_2d_with_0th_dim() {
  function test_min_dim_with_indices_2d_with_0th_dim (line 108) | fn test_min_dim_with_indices_2d_with_0th_dim() {
  function test_maximum_pair (line 124) | fn test_maximum_pair() {
  function test_minimum_pair (line 138) | fn test_minimum_pair() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mul.rs
  function should_support_mul_ops (line 7) | fn should_support_mul_ops() {
  function test_mul_broadcast (line 21) | fn test_mul_broadcast() {
  function test_mul_broadcast_2_dims (line 35) | fn test_mul_broadcast_2_dims() {
  function should_support_mul_scalar_ops (line 49) | fn should_support_mul_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/narrow.rs
  function test_narrow (line 7) | fn test_narrow() {
  function test_narrow_invalid_dim (line 30) | fn test_narrow_invalid_dim() {
  function test_narrow_invalid_start (line 38) | fn test_narrow_invalid_start() {
  function test_narrow_invalid_zero_length (line 46) | fn test_narrow_invalid_zero_length() {
  function test_narrow_invalid_length (line 54) | fn test_narrow_invalid_length() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/neg.rs
  function should_support_neg_ops (line 7) | fn should_support_neg_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/permute.rs
  function permute_float (line 6) | fn permute_float() {
  function edge_repeated_axes (line 46) | fn edge_repeated_axes() {
  function edge_out_of_bound_axis (line 58) | fn edge_out_of_bound_axis() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf.rs
  function should_support_powf_ops (line 7) | fn should_support_powf_ops() {
  function should_support_neg_power (line 21) | fn should_support_neg_power() {
  function should_support_neg_values_with_even_power (line 35) | fn should_support_neg_values_with_even_power() {
  function should_support_neg_values_with_odd_power (line 49) | fn should_support_neg_values_with_odd_power() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf_scalar.rs
  function should_support_powf_ops (line 7) | fn should_support_powf_ops() {
  function should_support_neg_power (line 20) | fn should_support_neg_power() {
  function should_support_neg_values_with_even_power (line 33) | fn should_support_neg_values_with_even_power() {
  function should_support_neg_values_with_odd_power (line 46) | fn should_support_neg_values_with_odd_power() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/recip.rs
  function should_support_recip_ops (line 7) | fn should_support_recip_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/remainder.rs
  function should_support_remainder_basic (line 7) | fn should_support_remainder_basic() {
  function should_support_remainder_basic_scalar (line 22) | fn should_support_remainder_basic_scalar() {
  function should_support_remainder_float (line 35) | fn should_support_remainder_float() {
  function should_support_remainder_float_scalar (line 49) | fn should_support_remainder_float_scalar() {
  function should_be_zero (line 62) | fn should_be_zero() {
  function should_be_zero_scalar (line 76) | fn should_be_zero_scalar() {
  function should_have_no_remainder (line 89) | fn should_have_no_remainder() {
  function should_have_no_remainder_scalar (line 103) | fn should_have_no_remainder_scalar() {
  function should_be_negative (line 116) | fn should_be_negative() {
  function should_be_negative_scalar (line 130) | fn should_be_negative_scalar() {
  function should_support_fp_dividends (line 143) | fn should_support_fp_dividends() {
  function should_support_large_divisor (line 156) | fn should_support_large_divisor() {
  function should_support_large_divisor_scalar (line 170) | fn should_support_large_divisor_scalar() {
  function should_support_remainder_op (line 183) | fn should_support_remainder_op() {
  function should_support_remainder_scalar_op (line 198) | fn should_support_remainder_scalar_op() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/repeat_dim.rs
  function should_support_repeat_ops (line 6) | fn should_support_repeat_ops() {
  function should_support_repeat_on_dims_larger_than_1 (line 24) | fn should_support_repeat_on_dims_larger_than_1() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/reshape.rs
  function should_support_reshape_1d (line 6) | fn should_support_reshape_1d() {
  function should_support_reshape_2d (line 19) | fn should_support_reshape_2d() {
  function should_support_dim_infererence (line 32) | fn should_support_dim_infererence() {
  function should_not_corrupt_after_slice (line 56) | fn should_not_corrupt_after_slice() {
  function multiple_neg_ones (line 69) | fn multiple_neg_ones() {
  function neg_value (line 76) | fn neg_value() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/round.rs
  function should_support_round_ops (line 7) | fn should_support_round_ops() {
  function should_round_ties_even (line 20) | fn should_round_ties_even() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/select.rs
  function should_select_1d (line 8) | fn should_select_1d() {
  function should_select_2d_dim0_same_num_dim (line 22) | fn should_select_2d_dim0_same_num_dim() {
  function should_select_2d_dim0_more_num_dim (line 36) | fn should_select_2d_dim0_more_num_dim() {
  function should_select_2d_dim1 (line 55) | fn should_select_2d_dim1() {
  function should_select_assign_1d (line 69) | fn should_select_assign_1d() {
  function should_select_assign_2d_dim0 (line 84) | fn should_select_assign_2d_dim0() {
  function should_select_assign_2d_dim1 (line 99) | fn should_select_assign_2d_dim1() {
  function should_select_panic_invalid_dimension (line 115) | fn should_select_panic_invalid_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sin.rs
  function should_support_sin_ops (line 7) | fn should_support_sin_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sinh.rs
  function should_support_sinh_ops (line 7) | fn should_support_sinh_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/slice.rs
  function should_support_full_sliceing_1d (line 7) | fn should_support_full_sliceing_1d() {
  function should_support_partial_sliceing_1d (line 17) | fn should_support_partial_sliceing_1d() {
  function should_support_full_sliceing_2d (line 30) | fn should_support_full_sliceing_2d() {
  function should_support_partial_sliceing_2d (line 42) | fn should_support_partial_sliceing_2d() {
  function should_support_partial_sliceing_3d (line 55) | fn should_support_partial_sliceing_3d() {
  function should_support_partial_sliceing_3d_non_contiguous (line 71) | fn should_support_partial_sliceing_3d_non_contiguous() {
  function should_support_slice_assign_1d (line 87) | fn should_support_slice_assign_1d() {
  function should_support_slice_assign_2d (line 101) | fn should_support_slice_assign_2d() {
  function slice_should_not_corrupt_potentially_inplace_operations (line 115) | fn slice_should_not_corrupt_potentially_inplace_operations() {
  function slice_assign_should_not_corrupt_potentially_inplace_operations (line 128) | fn slice_assign_should_not_corrupt_potentially_inplace_operations() {
  function clamp_when_slice_exceeds_dimension (line 151) | fn clamp_when_slice_exceeds_dimension() {
  function negative_dimensions (line 160) | fn negative_dimensions() {
  function missing_dimensions (line 184) | fn missing_dimensions() {
  function should_panic_when_slice_with_too_many_dimensions (line 225) | fn should_panic_when_slice_with_too_many_dimensions() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sort_argsort.rs
  function test_sort_1d_float (line 7) | fn test_sort_1d_float() {
  function test_argsort_1d_float (line 27) | fn test_argsort_1d_float() {
  function test_sort_with_indices_descending_float (line 40) | fn test_sort_with_indices_descending_float() {
  function test_sort_float (line 86) | fn test_sort_float() {
  function test_sort_with_indices_float (line 133) | fn test_sort_with_indices_float() {
  function test_argsort_float (line 188) | fn test_argsort_float() {
  function test_sort_descending_1d (line 214) | fn test_sort_descending_1d() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/split.rs
  function test_split_evenly_divisible (line 8) | fn test_split_evenly_divisible() {
  function test_split_not_evenly_divisible (line 29) | fn test_split_not_evenly_divisible() {
  function test_split_along_dim1 (line 51) | fn test_split_along_dim1() {
  function test_split_split_size_larger_than_tensor_size (line 71) | fn test_split_split_size_larger_than_tensor_size() {
  function test_split_with_zero_split_size_non_zero_tensor (line 91) | fn test_split_with_zero_split_size_non_zero_tensor() {
  function test_split_invalid_dim (line 99) | fn test_split_invalid_dim() {
  function test_split_with_sizes (line 106) | fn test_split_with_sizes() {
  function test_split_with_sizes_invalid_sum (line 130) | fn test_split_with_sizes_invalid_sum() {
  function test_split_with_sizes_zero_length (line 137) | fn test_split_with_sizes_zero_length() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sqrt.rs
  function should_support_sqrt_ops (line 8) | fn should_support_sqrt_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/stack.rs
  function should_support_stack_ops_2d_dim0 (line 9) | fn should_support_stack_ops_2d_dim0() {
  function should_support_stack_ops_2d_dim1 (line 23) | fn should_support_stack_ops_2d_dim1() {
  function should_support_stack_ops_3d (line 37) | fn should_support_stack_ops_3d() {
  function should_panic_when_dimensions_are_not_the_same (line 55) | fn should_panic_when_dimensions_are_not_the_same() {
  function should_panic_when_stack_exceeds_dimension (line 64) | fn should_panic_when_stack_exceeds_dimension() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sub.rs
  function should_support_sub_ops (line 7) | fn should_support_sub_ops() {
  function test_sub_broadcast (line 21) | fn test_sub_broadcast() {
  function should_support_sub_scalar_ops (line 35) | fn should_support_sub_scalar_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tan.rs
  function should_support_tan_ops (line 7) | fn should_support_tan_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tanh.rs
  function should_support_tanh_ops (line 7) | fn should_support_tanh_ops() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/topk.rs
  function test_topk_1d (line 7) | fn test_topk_1d() {
  function test_topk (line 20) | fn test_topk() {
  function test_topk_with_indices (line 36) | fn test_topk_with_indices() {
  function test_topk_with_indices_3d (line 53) | fn test_topk_with_indices_3d() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/transpose.rs
  function should_support_transpose_ops (line 7) | fn should_support_transpose_ops() {
  function should_support_swap_dims (line 26) | fn should_support_swap_dims() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/matmul.rs
  function test_matmul_vectors (line 8) | fn test_matmul_vectors() {
  function test_matmul_2d (line 22) | fn test_matmul_2d() {
  function test_matmul_2d_aligned (line 34) | fn test_matmul_2d_aligned() {
  function test_matmul_2d_aligned_fused (line 59) | fn test_matmul_2d_aligned_fused() {
  function test_matmul_3d (line 86) | fn test_matmul_3d() {
  function test_matmul_broadcast_4d (line 100) | fn test_matmul_broadcast_4d() {
  function test_matmul_broadcast (line 120) | fn test_matmul_broadcast() {
  function should_panic_when_inner_dimensions_are_not_equal (line 135) | fn should_panic_when_inner_dimensions_are_not_equal() {
  function test_matmul_lhs_float_rhs_quantized (line 144) | fn test_matmul_lhs_float_rhs_quantized() {
  function test_matmul_mixed_block_scale (line 174) | fn test_matmul_mixed_block_scale() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/quantize.rs
  function get_q_params (line 10) | fn get_q_params(data: TensorData) -> QParams<Vec<f32>> {
  function should_support_quantize_symmetric_int8 (line 26) | fn should_support_quantize_symmetric_int8() {
  function should_support_quantize_dynamic_int8 (line 66) | fn should_support_quantize_dynamic_int8() {
  function should_quantize_dequantize_symmetric_single_with_transform (line 86) | fn should_quantize_dequantize_symmetric_single_with_transform() {
  function should_quantize_dequantize_symmetric_arange_16x16 (line 104) | fn should_quantize_dequantize_symmetric_arange_16x16() {
  function should_quantize_dequantize_symmetric_per_block_arange_16x16 (line 122) | fn should_quantize_dequantize_symmetric_per_block_arange_16x16() {
  function should_quantize_transposed (line 141) | fn should_quantize_transposed<const D: usize>(tensor: Tensor<TestBackend...
  function should_quantize_symmetric_int8_transposed_8x32 (line 153) | fn should_quantize_symmetric_int8_transposed_8x32() {
  function should_quantize_symmetric_int8_transposed_48x64 (line 164) | fn should_quantize_symmetric_int8_transposed_48x64() {
  function should_quantize_symmetric_per_block_int8_transposed_32x64 (line 175) | fn should_quantize_symmetric_per_block_int8_transposed_32x64() {
  function should_quantize_symmetric_int8_permuted_batch_dims (line 188) | fn should_quantize_symmetric_int8_permuted_batch_dims() {

FILE: crates/burn-backend-tests/tests/tensor/float/quantization/scheme.rs
  function per_tensor_symmetric_int8 (line 10) | fn per_tensor_symmetric_int8() {
  function per_block_symmetric_int8 (line 27) | fn per_block_symmetric_int8() {
  function quant_scheme_should_inhibit_by_default (line 46) | fn quant_scheme_should_inhibit_by_default() {

FILE: crates/burn-backend-tests/tests/tensor/float/stats/cov.rs
  function test_cov_1 (line 5) | fn test_cov_1() {
  function test_cov_4 (line 20) | fn test_cov_4() {
  function test_cov_2 (line 34) | fn test_cov_2() {
  function test_cov_3 (line 54) | fn test_cov_3() {

FILE: crates/burn-backend-tests/tests/tensor/float/stats/display.rs
  type FloatElem (line 5) | type FloatElem = <TestBackend as Backend>::FloatElem;
  type IntElem (line 6) | type IntElem = <TestBackend as Backend>::IntElem;
  function skip_precision_not_f32 (line 9) | fn skip_precision_not_f32() -> bool {
  function test_display_2d_int_tensor (line 14) | fn test_display_2d_int_tensor() {
  function test_display_2d_float_tensor (line 39) | fn test_display_2d_float_tensor() {
  function test_display_2d_bool_tensor (line 67) | fn test_display_2d_bool_tensor() {
  function test_display_3d_tensor (line 102) | fn test_display_3d_tensor() {
  function test_display_4d_tensor (line 133) | fn test_display_4d_tensor() {
  function test_display_tensor_summarize_1 (line 167) | fn test_display_tensor_summarize_1() {
  function test_display_tensor_summarize_2 (line 196) | fn test_display_tensor_summarize_2() {
  function test_display_tensor_summarize_3 (line 245) | fn test_display_tensor_summarize_3() {
  function test_display_precision (line 293) | fn test_display_precision() {

FILE: crates/burn-backend-tests/tests/tensor/float/stats/eye.rs
  function test_eye_float (line 4) | fn test_eye_float() {
  function test_eye_int (line 12) | fn test_eye_int() {

FILE: crates/burn-backend-tests/tests/tensor/float/stats/median.rs
  function test_median_even (line 5) | fn test_median_even() {
  function test_median_odd (line 25) | fn test_median_odd() {
  function test_median_with_indices (line 49) | fn test_median_with_indices() {
  function test_median_all_elements (line 75) | fn test_median_all_elements() {

FILE: crates/burn-backend-tests/tests/tensor/float/stats/var.rs
  function test_var (line 6) | fn test_var() {
  function test_var_mean (line 21) | fn test_var_mean() {
  function test_var_bias (line 39) | fn test_var_bias() {
  function test_var_mean_bias (line 54) | fn test_var_mean_bias() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/abs.rs
  function should_support_abs_ops_int (line 5) | fn should_support_abs_ops_int() {
  function should_support_abs_ops_int_signed_min (line 16) | fn should_support_abs_ops_int_signed_min() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/add.rs
  function test_add_d2_int (line 5) | fn test_add_d2_int() {
  function test_add_broadcast_int (line 17) | fn test_add_broadcast_int() {
  function should_support_add_scalar_ops_int (line 29) | fn should_support_add_scalar_ops_int() {
  function scalar_add_not_contiguous (line 41) | fn scalar_add_not_contiguous() {
  function scalar_add_not_contiguous_int (line 56) | fn scalar_add_not_contiguous_int() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/aggregation.rs
  function test_should_mean_int (line 5) | fn test_should_mean_int() {
  function test_should_mean_last_dim_int (line 14) | fn test_should_mean_last_dim_int() {
  function test_should_sum_last_dim_int (line 25) | fn test_should_sum_last_dim_int() {
  function test_should_sum_int (line 36) | fn test_should_sum_int() {
  function test_prod_int (line 46) | fn test_prod_int() {
  function test_prod_dim_int (line 62) | fn test_prod_dim_int() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/all.rs
  function test_all (line 5) | fn test_all() {
  function test_all_dim (line 18) | fn test_all_dim() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/any.rs
  function test_any (line 5) | fn test_any() {
  function test_any_dim (line 18) | fn test_any_dim() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/arange.rs
  function test_arange (line 6) | fn test_arange() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/arange_step.rs
  function test_arange_step (line 6) | fn test_arange_step() {
  function should_panic_when_step_is_zero (line 41) | fn should_panic_when_step_is_zero() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/arg.rs
  function test_argmax_2d_dim0_int (line 5) | fn test_argmax_2d_dim0_int() {
  function test_argmin_2d_dim0_int (line 16) | fn test_argmin_2d_dim0_int() {

FILE: crates/burn-backend-tests/tests/tensor/int/ops/bitwise.rs
  function should_apply_bitwise_and_2d (line 5) | fn should_apply_bitwise_and_2d() {
  function should_apply_bitwise_and_1d (line 17) | fn should_apply_bitwise_and_1d() {
  function should_apply_bitwise_and_scalar_2d (line 29) | fn should_apply_bitwise_and_scalar_2d() {
  function should_apply_bitwise_not_2d (line 41) | fn should_apply_bitwise_not_2d() {
  function should_apply_bitwise_or_scalar_2d (line 52) | fn should_apply_bitwise_or_scalar_2d() {
  function should_apply_bitwise_or_2d (line 64) | fn should_apply_bitwise_or_2d() {
  function should_apply_bitwise_or_1d (line 76) | fn should_apply_bitwise_or_1d() {
  function should_apply_bitwise_xor_scalar_2d (line 88) | fn should_apply_bitwise_xor_scalar_2d() {
  function should_apply_bitwise_xor_2d (line 100) | fn should_apply_bitwise_xor_2d() {
  function should_apply_bitwise_xor_1d (line 112) | fn should_apply_bitwise_xor_1d() {
  function should_apply_bitwise_left_shift_2d (line 124) | fn should_apply_bitwise_left_shift_2d() {
  function should_apply_bitwise_left_shift_scalar_2d (line 140) | fn should_apply_bitwise_left_shift_scalar_2d() {
  function should_apply_bitwise_right_shift_2d (line 152) | fn should_apply_bitwise_right_shift_2d() {
  function should_apply_bitwise_right_shift_sca
Copy disabled (too large) Download .json
Condensed preview — 1817 files, each showing path, character count, and a content snippet. Download the .json file for the full structured content (10,420K chars).
[
  {
    "path": ".cargo/audit.toml",
    "chars": 1330,
    "preview": "# Audit config file\n#\n# It may be located in the user home (`~/.cargo/audit.toml`) or in the project\n# root (`.cargo/aud"
  },
  {
    "path": ".cargo/config.toml",
    "chars": 140,
    "preview": "[alias]\nxtask = \"run --target-dir target/xtask --color always --package xtask --bin xtask --\"\nrun-checks = \"xtask -c all"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "chars": 884,
    "preview": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the b"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/doc_request.md",
    "chars": 219,
    "preview": "---\nname: Documentation request\nabout: Flag incoherent or missing documentation, including use case examples.\ntitle: ''\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "chars": 561,
    "preview": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n<!-- Please s"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/template.md",
    "chars": 411,
    "preview": "* **Please check if the PR fulfills these requirements**\n- [ ] The commit message follows our guidelines\n- [ ] Docs have"
  },
  {
    "path": ".github/dependabot.yml",
    "chars": 404,
    "preview": "version: 2\n\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"daily\"\n  "
  },
  {
    "path": ".github/pull_request_template.md",
    "chars": 385,
    "preview": "## Pull Request Template\n\n### Checklist\n\n- [ ] Confirmed that `cargo run-checks` command has been executed.\n- [ ] Made s"
  },
  {
    "path": ".github/workflows/combine-dependabot-prs.yml",
    "chars": 391,
    "preview": "name: Combine Dependabot PRs\n\non:\n  schedule:\n    - cron: '0 6 * * MON' # Monday at 6:00am UTC\n  workflow_dispatch:\n\nper"
  },
  {
    "path": ".github/workflows/dependencies.yml",
    "chars": 1964,
    "preview": "name: dependencies\n\non:\n  schedule:\n    - cron: '0 21 * * TUE' # Run every Tuesday at 21:00 (UTC)\n  push:\n    tags:\n    "
  },
  {
    "path": ".github/workflows/publish.yml",
    "chars": 13986,
    "preview": "name: publish\n\non:\n  push:\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n    inputs:\n      dry-run-only:\n        descripti"
  },
  {
    "path": ".github/workflows/stale-pr.yml",
    "chars": 1479,
    "preview": "name: Stale Pull Requests\n\non:\n  schedule:\n    - cron: '0 12 * * *' # Run every day at 12:00 (UTC)\n\n# The minimum permis"
  },
  {
    "path": ".github/workflows/test-gpu.yml",
    "chars": 5897,
    "preview": "name: CI GPU\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: \"Number of the pull request tha"
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 8560,
    "preview": "name: CI\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"Cargo.lock\"\n      - \"**.rs\"\n      - \"**.sh\"\n      -"
  },
  {
    "path": ".github/workflows/valgrind.yml",
    "chars": 1295,
    "preview": "name: valgrind\n\non:\n  schedule:\n    - cron: '0 23 * * WED' # Run every Wednesday at 23:00 (UTC)\n\nconcurrency:\n  group: $"
  },
  {
    "path": ".github/workflows/vulnerabilities.yml",
    "chars": 5122,
    "preview": "name: vulnerabilities\n\non:\n  schedule:\n    - cron: '0 21 * * WED' # Run every Wednesday at 21:00 (UTC)\n  push:\n    tags:"
  },
  {
    "path": ".gitignore",
    "chars": 241,
    "preview": "target\n# These are backup files generated by rustfmt\n**/*.rs.bk\n.DS_Store\n\n.dir-locals.el\n.idea\n.vscode\n.vs\n.fleet\n.ipyn"
  },
  {
    "path": "CITATION.cff",
    "chars": 1023,
    "preview": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n  - family-names: \"Simard\"\n   "
  },
  {
    "path": "CODE-OF-CONDUCT.md",
    "chars": 5231,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 4718,
    "preview": "# Contributing to Burn\n\nWelcome to the Burn community! We're glad you're interested in contributing.\n\n## How to Contribu"
  },
  {
    "path": "Cargo.toml",
    "chars": 6682,
    "preview": "[workspace]\n# Try\n# require version 2 to avoid \"feature\" additiveness for dev-dependencies\n# https://doc.rust-lang.org/c"
  },
  {
    "path": "LICENSE-APACHE",
    "chars": 10866,
    "preview": "                              Apache License\n                        Version 2.0, January 2004\n                     http"
  },
  {
    "path": "LICENSE-MIT",
    "chars": 1103,
    "preview": "MIT License\n\nCopyright (c) 2022 Nathaniel Simard & Burn Framework Contributors\n\nPermission is hereby granted, free of ch"
  },
  {
    "path": "NOTICES.md",
    "chars": 23076,
    "preview": "# NOTICES AND INFORMATION\n\nThis file contains notices and information required by libraries that this\nrepository copied "
  },
  {
    "path": "POEM.md",
    "chars": 1439,
    "preview": "# BURN: Burn Unstoppable Rusty Neurons\n\nIn the realm of circuits and code,  \nA fiery forge ignites to bear its load,  \nA"
  },
  {
    "path": "README.md",
    "chars": 21441,
    "preview": "<div align=\"center\">\n<img src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/logo-burn-neutral.webp\" widt"
  },
  {
    "path": "_typos.toml",
    "chars": 415,
    "preview": "[default]\nextend-ignore-identifiers-re = [\"ratatui\", \"Ratatui\", \"NdArray*\", \"ND\"]\n\n[default.extend-identifiers]\nUE4M3 = "
  },
  {
    "path": "benchmarks.toml",
    "chars": 1888,
    "preview": "[environment]\ngcp_gpu_attached = true\ngcp_image_family = \"tracel-ci-ubuntu-2404-amd64-nvidia\"\n# https://cloud.google.com"
  },
  {
    "path": "burn-book/.gitignore",
    "chars": 186,
    "preview": "target\n\n# MacOS temp file\n.DS_Store\n\nbook-test\nguide/book\n\n.vscode\ntests/burn-book/book/\nbook/\n\n# Ignore Jetbrains speci"
  },
  {
    "path": "burn-book/.prettierrc.json",
    "chars": 52,
    "preview": "{\n    \"printWidth\": 100,\n    \"proseWrap\": \"always\"\n}"
  },
  {
    "path": "burn-book/book.toml",
    "chars": 285,
    "preview": "[book]\nauthors = [\n    \"Wouter Doppenberg\",\n    \"Nathaniel Simard\",\n    \"Louis Fortier-Dubois\",\n    \"Dilshod Tadjibaev\","
  },
  {
    "path": "burn-book/src/SUMMARY.md",
    "chars": 1882,
    "preview": "- [Overview](./overview.md)\n- [Why Burn?](./motivation.md)\n- [Getting started](./getting-started.md)\n  - [Examples](./ex"
  },
  {
    "path": "burn-book/src/advanced/README.md",
    "chars": 762,
    "preview": "# Advanced\n\nIn this section, we will go into advanced topics that extend beyond basic usage. Given Burn's\nexceptional fl"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/README.md",
    "chars": 4019,
    "preview": "# Backend Extension\n\nBurn aims to be the most flexible deep learning framework. While it's crucial to maintain\ncompatibi"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md",
    "chars": 16784,
    "preview": "# Custom CubeCL Kernel\n\nIn this section, you will learn how to create your own custom operation by writing your own kern"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md",
    "chars": 19509,
    "preview": "# Custom WGPU Kernel\n\nIn this section, you will learn how to create your own custom operation by writing your own kernel"
  },
  {
    "path": "burn-book/src/advanced/no-std.md",
    "chars": 3006,
    "preview": "# No Standard Library\n\nIn this section, you will learn how to run an ONNX inference model on an embedded system, with no"
  },
  {
    "path": "burn-book/src/advanced/web-assembly.md",
    "chars": 567,
    "preview": "# WebAssembly\n\nBurn supports WebAssembly (WASM) execution using the `NdArray` and `WebGpu` backends, allowing\nmodels to "
  },
  {
    "path": "burn-book/src/basic-workflow/README.md",
    "chars": 1243,
    "preview": "# Guide\n\nThis guide will walk you through the process of creating a custom model built with Burn. We will\ntrain a simple"
  },
  {
    "path": "burn-book/src/basic-workflow/backend.md",
    "chars": 1943,
    "preview": "# Backend\n\nWe have effectively written most of the necessary code to train our model. However, we have not\nexplicitly de"
  },
  {
    "path": "burn-book/src/basic-workflow/data.md",
    "chars": 6189,
    "preview": "# Data\n\nTypically, one trains a model on some dataset. Burn provides a library of very useful dataset\nsources and transf"
  },
  {
    "path": "burn-book/src/basic-workflow/inference.md",
    "chars": 3453,
    "preview": "# Inference\n\nNow that we have trained our model, the next natural step is to use it for inference.\n\nYou need two things "
  },
  {
    "path": "burn-book/src/basic-workflow/model.md",
    "chars": 16620,
    "preview": "# Model\n\nThe first step is to create a project and add the different Burn dependencies. Start by creating a\nnew project "
  },
  {
    "path": "burn-book/src/basic-workflow/training.md",
    "chars": 12819,
    "preview": "# Training\n\nWe are now ready to write the necessary code to train our model on the MNIST dataset. We shall\ndefine the co"
  },
  {
    "path": "burn-book/src/building-blocks/README.md",
    "chars": 406,
    "preview": "# Building Blocks\n\nIn this section, we'll guide you through the core elements that make up Burn. We'll walk you through\n"
  },
  {
    "path": "burn-book/src/building-blocks/autodiff.md",
    "chars": 4009,
    "preview": "# Autodiff\n\nBurn's tensor also supports auto-differentiation, which is an essential part of any deep learning\nframework."
  },
  {
    "path": "burn-book/src/building-blocks/backend.md",
    "chars": 990,
    "preview": "# Backend\n\nNearly everything in Burn is based on the `Backend` trait, which enables you to run tensor\noperations using d"
  },
  {
    "path": "burn-book/src/building-blocks/config.md",
    "chars": 2072,
    "preview": "# Config\n\nWhen writing scientific code, you normally have a lot of values that are set, and Deep Learning is\nno exceptio"
  },
  {
    "path": "burn-book/src/building-blocks/dataset.md",
    "chars": 22194,
    "preview": "# Dataset\n\nAt its core, a dataset is a collection of data typically related to a specific analysis or\nprocessing task. T"
  },
  {
    "path": "burn-book/src/building-blocks/learner.md",
    "chars": 4548,
    "preview": "# Learner\n\nThe [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates\nmultiple u"
  },
  {
    "path": "burn-book/src/building-blocks/metric.md",
    "chars": 9308,
    "preview": "# Metric\n\nWhen working with the learner, you have the option to record metrics that will be monitored\nthroughout the tra"
  },
  {
    "path": "burn-book/src/building-blocks/module.md",
    "chars": 15345,
    "preview": "# Module\n\nThe `Module` derive allows you to create your own neural network modules, similar to PyTorch. The\nderive funct"
  },
  {
    "path": "burn-book/src/building-blocks/record.md",
    "chars": 4359,
    "preview": "# Record\n\nRecords are how states are saved with Burn. Compared to most other frameworks, Burn has its own\nadvanced savin"
  },
  {
    "path": "burn-book/src/building-blocks/tensor.md",
    "chars": 40351,
    "preview": "# Tensor\n\nAs previously explained in the [model section](../basic-workflow/model.md), the Tensor struct has 3\ngeneric ar"
  },
  {
    "path": "burn-book/src/custom-training-loop.md",
    "chars": 10301,
    "preview": "# Custom Training Loops\n\nEven though Burn comes with a project dedicated to simplifying training, it doesn't mean that y"
  },
  {
    "path": "burn-book/src/distributed-computing.md",
    "chars": 24,
    "preview": "# Distributed Computing\n"
  },
  {
    "path": "burn-book/src/examples.md",
    "chars": 7767,
    "preview": "# Examples\n\nIn the [next chapter](./basic-workflow) you'll have the opportunity to implement the whole Burn\n`guide` exam"
  },
  {
    "path": "burn-book/src/getting-started.md",
    "chars": 8153,
    "preview": "# Getting Started\n\nBurn is a deep learning framework in the Rust programming language. Therefore, it goes without\nsaying"
  },
  {
    "path": "burn-book/src/models-and-pretrained-weights.md",
    "chars": 1648,
    "preview": "# Models and Pre-Trained Weights\n\n## Models Repository\n\nThe [`models`](https://github.com/tracel-ai/models) repository c"
  },
  {
    "path": "burn-book/src/motivation.md",
    "chars": 1945,
    "preview": "# Why Burn?\n\nWhy bother with the effort of creating an entirely new deep learning framework from scratch when\nPyTorch, T"
  },
  {
    "path": "burn-book/src/onnx-import.md",
    "chars": 10035,
    "preview": "# ONNX Import\n\n## Introduction\n\nAs deep learning evolves, interoperability between frameworks becomes crucial. Burn prov"
  },
  {
    "path": "burn-book/src/overview.md",
    "chars": 1921,
    "preview": "# Overview\n\nWelcome to The Burn Book 👋\n\nThis book will help you get started with the Burn deep learning framework, wheth"
  },
  {
    "path": "burn-book/src/performance/README.md",
    "chars": 122,
    "preview": "# Performance\n\nThis section covers the key concepts you need to understand to get the most out of Burn and your\nhardware"
  },
  {
    "path": "burn-book/src/performance/distributed-computing.md",
    "chars": 135,
    "preview": "# Distributed Computing\n\nDistributed computing support was introduced in Burn 0.19. Documentation and examples will be\na"
  },
  {
    "path": "burn-book/src/performance/good-practices/README.md",
    "chars": 567,
    "preview": "# Performance - Best Practices\n\nThis section provides valuable insights into the performance characteristics of Burn and"
  },
  {
    "path": "burn-book/src/performance/good-practices/asynchronous-execution.md",
    "chars": 2616,
    "preview": "# Asynchronous Execution\n\nMost Burn backends execute tensor operations in an asynchronous manner. However, the async not"
  },
  {
    "path": "burn-book/src/performance/good-practices/kernel-fusion.md",
    "chars": 4132,
    "preview": "# Kernel Fusion\n\nAn interesting property of async execution is that it allows performance optimizations like kernel\nfusi"
  },
  {
    "path": "burn-book/src/performance/good-practices/kernel-selection.md",
    "chars": 1526,
    "preview": "# Kernel Selection\n\nAs mentioned earlier, complex compute-bound operations are highly non-trivial and require many\ntrick"
  },
  {
    "path": "burn-book/src/performance/quantization.md",
    "chars": 6634,
    "preview": "# Quantization\n\nQuantization techniques perform computations and store tensors in lower precision data types like\n8-bit "
  },
  {
    "path": "burn-book/src/saving-and-loading.md",
    "chars": 14700,
    "preview": "# Saving and Loading Models\n\nSaving your trained machine learning model is quite easy, no matter the output format you c"
  },
  {
    "path": "codecov.yml",
    "chars": 276,
    "preview": "coverage:\n  status:\n    project:\n      default:\n        # https://docs.codecov.com/docs/commit-status#informational\n    "
  },
  {
    "path": "contributor-book/.gitignore",
    "chars": 186,
    "preview": "target\n\n# MacOS temp file\n.DS_Store\n\nbook-test\nguide/book\n\n.vscode\ntests/burn-book/book/\nbook/\n\n# Ignore Jetbrains speci"
  },
  {
    "path": "contributor-book/.prettierrc.json",
    "chars": 52,
    "preview": "{\n    \"printWidth\": 100,\n    \"proseWrap\": \"always\"\n}"
  },
  {
    "path": "contributor-book/book.toml",
    "chars": 302,
    "preview": "[book]\nauthors = [\n    \"Wouter Doppenberg\",\n    \"Nathaniel Simard\",\n    \"Louis Fortier-Dubois\",\n    \"Dilshod Tadjibaev\","
  },
  {
    "path": "contributor-book/src/SUMMARY.md",
    "chars": 983,
    "preview": "- [Overview](./overview.md)\n- [How to Read This Book](./how-to-read-this-book.md)\n- [Getting Started](./getting-started/"
  },
  {
    "path": "contributor-book/src/frequently-encountered-issues/README.md",
    "chars": 297,
    "preview": "# Frequently Encountered Issues\n\nThis is a collection of issues people have encountered and asked about on the\n[Discord "
  },
  {
    "path": "contributor-book/src/frequently-encountered-issues/issues-while-adding-ops.md",
    "chars": 1361,
    "preview": "# Issues encountered while adding ops\n\nBelow are some of the issues that were encountered while adding ops to the projec"
  },
  {
    "path": "contributor-book/src/getting-started/README.md",
    "chars": 323,
    "preview": "# Getting Started\n\nThis section is for setting up the environment and how to do basic development tasks such as running\n"
  },
  {
    "path": "contributor-book/src/getting-started/configuring-your-editor.md",
    "chars": 1669,
    "preview": "# Configuring your editor\n\nThese steps are not required, and most of this isn't specific to Burn, but it's definitely he"
  },
  {
    "path": "contributor-book/src/getting-started/setting-up-the-environment.md",
    "chars": 2593,
    "preview": "# Setting up the environment\n\nDepending on what part of the project you plan on contributing to, there are a couple of t"
  },
  {
    "path": "contributor-book/src/getting-started/testing.md",
    "chars": 2233,
    "preview": "# Testing\n\n## Test for Tensor Operations\n\nTest for tensor operations (generally of the form: given this input, expect it"
  },
  {
    "path": "contributor-book/src/guides/README.md",
    "chars": 144,
    "preview": "# Guides for Contributors\n\nThe following guides are meant to help contributors accomplish specific tasks, such as adding"
  },
  {
    "path": "contributor-book/src/guides/adding-a-new-operation-to-burn.md",
    "chars": 17887,
    "preview": "# Adding a New Operation to burn\n\nLet's discuss how one might go about adding new operators to Burn, using the example o"
  },
  {
    "path": "contributor-book/src/guides/submitting-examples.md",
    "chars": 3150,
    "preview": "# Submitting Examples to Burn\n\nThis guide explains how to create and submit new examples to the Burn repository. Example"
  },
  {
    "path": "contributor-book/src/how-to-read-this-book.md",
    "chars": 1297,
    "preview": "# How to read this book\n\nThroughout this book, we maintain the following structure.\n\n## Linking\n\nWhen referring to struc"
  },
  {
    "path": "contributor-book/src/overview.md",
    "chars": 1535,
    "preview": "# Overview\n\nWelcome to The Burn Contributor's Book 👋\n\nThis book will help you get acquainted with the internals of the B"
  },
  {
    "path": "contributor-book/src/project-architecture/README.md",
    "chars": 634,
    "preview": "# Project Architecture\n\nThis section documents most major architectural decisions with the reasoning behind them.\n\n**Sec"
  },
  {
    "path": "contributor-book/src/project-architecture/backend.md",
    "chars": 2144,
    "preview": "# Backend\n\nThe Backend trait abstracts multiple things:\n\n- Device type\n- Float tensor type\n- Bool tensor type\n- Int tens"
  },
  {
    "path": "contributor-book/src/project-architecture/module.md",
    "chars": 5350,
    "preview": "# Module\n\nModules are a way of creating neural network structures that can be easily optimized, saved, and\nloaded with l"
  },
  {
    "path": "contributor-book/src/project-architecture/serialization.md",
    "chars": 5863,
    "preview": "# Serialization\n\nAn important aspect of a deep learning framework is the ability to save and load models from disk.\nDesp"
  },
  {
    "path": "contributor-book/src/project-architecture/tensor.md",
    "chars": 4690,
    "preview": "# Tensor\n\nA proper deep learning framework should have a fast tensor implementation with autodiff support, and\nBurn is n"
  },
  {
    "path": "crates/burn/Cargo.toml",
    "chars": 6725,
    "preview": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \""
  },
  {
    "path": "crates/burn/src/backend.rs",
    "chars": 1396,
    "preview": "#[cfg(feature = \"ndarray\")]\npub use burn_ndarray as ndarray;\n\n#[cfg(feature = \"ndarray\")]\npub use ndarray::NdArray;\n\n#[c"
  },
  {
    "path": "crates/burn/src/collective.rs",
    "chars": 28,
    "preview": "pub use burn_collective::*;\n"
  },
  {
    "path": "crates/burn/src/lib.rs",
    "chars": 7553,
    "preview": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n\n//! # Burn\n//!\n//! Burn is a new comprehensive dynami"
  },
  {
    "path": "crates/burn-autodiff/Cargo.toml",
    "chars": 1331,
    "preview": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Automati"
  },
  {
    "path": "crates/burn-autodiff/README.md",
    "chars": 399,
    "preview": "# Burn Autodiff\n\n> [Burn](https://github.com/tracel-ai/burn) autodiff backend\n\n[![Current Crates.io Version](https://img"
  },
  {
    "path": "crates/burn-autodiff/src/backend.rs",
    "chars": 3829,
    "preview": "use crate::{\n    checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},\n    grads::Gradients,\n    tensor::Autodiff"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/base.rs",
    "chars": 2968,
    "preview": "use super::{\n    retro_forward::RetroForwards,\n    state::{BackwardStates, State},\n};\nuse crate::collections::HashMap;\nu"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/builder.rs",
    "chars": 10397,
    "preview": "use crate::{\n    collections::HashMap,\n    graph::{ComputingProperty, NodeId},\n    tensor::AutodiffTensor,\n};\nuse alloc:"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/mod.rs",
    "chars": 198,
    "preview": "/// Checkpointer module\npub mod base;\npub(crate) mod builder;\n/// RetroForward module\npub mod retro_forward;\n/// Backwar"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/retro_forward.rs",
    "chars": 3633,
    "preview": "use crate::collections::HashMap;\nuse crate::graph::NodeId;\n\nuse alloc::sync::Arc;\nuse core::fmt::Debug;\n\nuse super::stat"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/state.rs",
    "chars": 4901,
    "preview": "use core::any::Any;\n\nuse crate::collections::HashMap;\nuse crate::graph::NodeId;\nuse alloc::boxed::Box;\n\n/// In order to "
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/strategy.rs",
    "chars": 3425,
    "preview": "use core::fmt::Debug;\n\nuse burn_backend::Backend;\n\nuse crate::{graph::ComputingProperty, tensor::AutodiffTensor};\nuse al"
  },
  {
    "path": "crates/burn-autodiff/src/grads.rs",
    "chars": 2997,
    "preview": "use burn_backend::{\n    Backend, TensorMetadata, TensorPrimitive,\n    tensor::{FloatTensor, TensorContainer},\n};\n\nuse cr"
  },
  {
    "path": "crates/burn-autodiff/src/graph/base.rs",
    "chars": 645,
    "preview": "use super::NodeId;\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};\nuse alloc::boxed::Box;\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/mod.rs",
    "chars": 116,
    "preview": "mod base;\nmod node;\nmod requirement;\n\npub mod traversal;\n\npub use base::*;\npub use node::*;\npub use requirement::*;\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/node.rs",
    "chars": 2481,
    "preview": "use alloc::{sync::Arc, vec::Vec};\n\n#[cfg(target_has_atomic = \"64\")]\nuse core::sync::atomic::{AtomicU64, Ordering};\n#[cfg"
  },
  {
    "path": "crates/burn-autodiff/src/graph/requirement.rs",
    "chars": 1114,
    "preview": "use super::NodeRef;\n\n/// Requirement for each tensor in the graph.\n#[derive(Debug, Clone, Copy, PartialEq, Eq)]\npub enum"
  },
  {
    "path": "crates/burn-autodiff/src/graph/traversal.rs",
    "chars": 1727,
    "preview": "use super::{Step, StepBoxed};\nuse crate::{\n    NodeId,\n    collections::{HashMap, HashSet},\n    graph::Parent,\n};\nuse al"
  },
  {
    "path": "crates/burn-autodiff/src/lib.rs",
    "chars": 1082,
    "preview": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! # Burn Aut"
  },
  {
    "path": "crates/burn-autodiff/src/ops/activation.rs",
    "chars": 5290,
    "preview": "use core::marker::PhantomData;\n\nuse crate::{\n    Autodiff,\n    checkpoint::{\n        base::Checkpointer, retro_forward::"
  },
  {
    "path": "crates/burn-autodiff/src/ops/backward.rs",
    "chars": 2570,
    "preview": "use super::{Ops, OpsPrep};\nuse crate::{\n    checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::Che"
  },
  {
    "path": "crates/burn-autodiff/src/ops/base.rs",
    "chars": 9511,
    "preview": "use super::Backward;\nuse crate::{\n    checkpoint::{\n        base::Checkpointer,\n        builder::{ActionType, Checkpoint"
  },
  {
    "path": "crates/burn-autodiff/src/ops/bool_tensor.rs",
    "chars": 4599,
    "preview": "use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};\nuse alloc::vec::Vec;\n\nuse burn_"
  },
  {
    "path": "crates/burn-autodiff/src/ops/int_tensor.rs",
    "chars": 12085,
    "preview": "use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};\nuse alloc::vec::Vec;\n\nuse burn_"
  },
  {
    "path": "crates/burn-autodiff/src/ops/maxmin.rs",
    "chars": 813,
    "preview": "use super::{Backward, Ops, unary};\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients};\nuse burn_backend::{Bac"
  },
  {
    "path": "crates/burn-autodiff/src/ops/mod.rs",
    "chars": 211,
    "preview": "mod activation;\nmod backward;\nmod base;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transa"
  },
  {
    "path": "crates/burn-autodiff/src/ops/module.rs",
    "chars": 66642,
    "preview": "use crate::Autodiff;\nuse crate::checkpoint::base::Checkpointer;\nuse crate::checkpoint::strategy::CheckpointStrategy;\nuse"
  },
  {
    "path": "crates/burn-autodiff/src/ops/qtensor.rs",
    "chars": 2795,
    "preview": "use burn_backend::{\n    Backend, ExecutionError, TensorData,\n    ops::QTensorOps,\n    tensor::{\n        Device, FloatTen"
  },
  {
    "path": "crates/burn-autodiff/src/ops/sort.rs",
    "chars": 809,
    "preview": "use super::{Backward, Ops, unary};\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients};\nuse burn_backend::{Bac"
  },
  {
    "path": "crates/burn-autodiff/src/ops/tensor.rs",
    "chars": 122003,
    "preview": "use alloc::{boxed::Box, vec, vec::Vec};\nuse core::marker::PhantomData;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_impo"
  },
  {
    "path": "crates/burn-autodiff/src/ops/transaction.rs",
    "chars": 743,
    "preview": "use burn_backend::{\n    Backend, ExecutionError,\n    ops::{TransactionOps, TransactionPrimitive},\n};\n\nuse crate::{Autodi"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/client.rs",
    "chars": 627,
    "preview": "use crate::{\n    checkpoint::builder::CheckpointerBuilder,\n    grads::Gradients,\n    graph::StepBoxed,\n    tensor::{Auto"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/graph.rs",
    "chars": 11114,
    "preview": "use super::{AutodiffClient, server::AutodiffServer};\nuse crate::{\n    NodeId,\n    checkpoint::builder::CheckpointerBuild"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/memory_management.rs",
    "chars": 10034,
    "preview": "use crate::{\n    NodeId,\n    collections::{HashMap, HashSet},\n    graph::Parent,\n    tensor::NodeRefCount,\n};\nuse alloc:"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/mod.rs",
    "chars": 82,
    "preview": "mod client;\nmod memory_management;\nmod server;\n\npub mod graph;\npub use client::*;\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/server.rs",
    "chars": 4374,
    "preview": "use super::memory_management::GraphMemoryManagement;\nuse crate::{\n    NodeId,\n    checkpoint::{\n        base::{Checkpoin"
  },
  {
    "path": "crates/burn-autodiff/src/tensor.rs",
    "chars": 4894,
    "preview": "use crate::{\n    checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},\n    grads::Gradients,\n    graph::{Compu"
  },
  {
    "path": "crates/burn-autodiff/src/utils.rs",
    "chars": 685,
    "preview": "use alloc::vec::Vec;\n\nuse crate::graph::NodeRef;\n/// Duplicate the given object for each node that requires gradients.\n/"
  },
  {
    "path": "crates/burn-backend/Cargo.toml",
    "chars": 1635,
    "preview": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \""
  },
  {
    "path": "crates/burn-backend/README.md",
    "chars": 125,
    "preview": "# Burn Backend\n\nThis crate includes the core backend interfaces and data structures for executing tensor operations\nin B"
  },
  {
    "path": "crates/burn-backend/src/backend/base.rs",
    "chars": 13830,
    "preview": "use burn_std::DType;\npub use burn_std::backtrace::BackTrace;\n\nuse alloc::string::String;\nuse enumset::{EnumSet, EnumSetT"
  },
  {
    "path": "crates/burn-backend/src/backend/device.rs",
    "chars": 533,
    "preview": "pub use burn_std::device::*;\n\n/// Device trait for all burn backend devices.\npub trait DeviceOps: Clone + Default + Part"
  },
  {
    "path": "crates/burn-backend/src/backend/mod.rs",
    "chars": 145,
    "preview": "mod base;\nmod device;\nmod primitive;\n\npub use base::*;\npub use device::*;\npub use primitive::*;\n\n/// Backend operations "
  },
  {
    "path": "crates/burn-backend/src/backend/ops/activation.rs",
    "chars": 9249,
    "preview": "use crate::tensor::FloatTensor;\nuse crate::{Backend, Scalar, TensorMetadata};\nuse core::f64::consts::SQRT_2;\n\n/// Activa"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/argwhere.rs",
    "chars": 1895,
    "preview": "use crate::tensor::{Device, IntTensor};\nuse crate::{Backend, TensorData, element::ElementConversion};\nuse alloc::vec::Ve"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/bool_tensor.rs",
    "chars": 18060,
    "preview": "use super::{\n    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,\n};\nuse crate"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/cat.rs",
    "chars": 1327,
    "preview": "use crate::{\n    Backend, TensorMetadata,\n    tensor::{BasicOps, TensorKind},\n};\nuse alloc::vec::Vec;\nuse burn_std::Slic"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/int_tensor.rs",
    "chars": 42175,
    "preview": "use super::cat::cat_with_slice_assign;\nuse super::repeat_dim::repeat_with_slice_assign;\nuse super::sort::{argsort, sort,"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/mod.rs",
    "chars": 352,
    "preview": "mod activation;\nmod bool_tensor;\nmod int_tensor;\nmod modules;\nmod qtensor;\nmod tensor;\nmod transaction;\n\npub(crate) mod "
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/attention.rs",
    "chars": 4030,
    "preview": "use core::f32;\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\nuse burn_std::Shape;\n\nuse crate::{\n    Backend, Ten"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/base.rs",
    "chars": 37265,
    "preview": "use super::{conv, pool};\nuse crate::ops::unfold::unfold4d_using_conv2d;\nuse crate::tensor::{BoolTensor, FloatTensor, Int"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/conv.rs",
    "chars": 44527,
    "preview": "#![allow(clippy::single_range_in_vec_init)]\nuse super::{ConvOptions, ConvTransposeOptions};\nuse crate::{Backend, TensorM"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/grid_sample.rs",
    "chars": 12890,
    "preview": "use crate::{\n    Backend, TensorMetadata,\n    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},\n    tens"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/mod.rs",
    "chars": 304,
    "preview": "/// Module with convolution operations.\npub mod conv;\n\n/// Module with attention operations.\npub mod attention;\n\n/// Mod"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/pool.rs",
    "chars": 5183,
    "preview": "use crate::tensor::{FloatTensor, IntTensor};\nuse crate::{Backend, TensorMetadata};\nuse burn_std::Shape;\n\nuse super::{Max"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/unfold.rs",
    "chars": 4581,
    "preview": "use super::{ConvOptions, UnfoldOptions};\nuse crate::tensor::FloatTensor;\nuse crate::{Backend, TensorData, TensorMetadata"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/qtensor.rs",
    "chars": 42223,
    "preview": "use alloc::vec::Vec;\nuse burn_std::{\n    Shape, Slice,\n    quantization::{QuantPropagation, QuantScheme},\n};\n\nuse crate:"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/repeat_dim.rs",
    "chars": 1164,
    "preview": "use crate::{\n    Backend, TensorMetadata,\n    tensor::{BasicOps, TensorKind},\n};\nuse alloc::vec::Vec;\nuse burn_std::Slic"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/sort.rs",
    "chars": 12906,
    "preview": "use core::cmp::Ordering;\n\nuse crate::{\n    Backend, DType, TensorData,\n    element::{ElementConversion, ElementOrdered},"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/tensor.rs",
    "chars": 53513,
    "preview": "use super::cat::cat_with_slice_assign;\nuse super::grid_sample::float_grid_sample_2d_ref;\nuse super::repeat_dim::repeat_w"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/transaction.rs",
    "chars": 4925,
    "preview": "use alloc::vec::Vec;\nuse core::future::Future;\n\nuse crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}"
  },
  {
    "path": "crates/burn-backend/src/backend/primitive.rs",
    "chars": 2243,
    "preview": "use crate::Backend;\nuse burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};\nuse burn_std::{DType, Shape};"
  },
  {
    "path": "crates/burn-backend/src/data/compare.rs",
    "chars": 14606,
    "preview": "use alloc::format;\nuse alloc::string::String;\nuse burn_std::{BoolStore, DType, bf16, f16};\nuse num_traits::{Float, ToPri"
  },
  {
    "path": "crates/burn-backend/src/data/mod.rs",
    "chars": 65,
    "preview": "mod compare;\nmod tensor;\n\npub use compare::*;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-backend/src/data/tensor.rs",
    "chars": 34948,
    "preview": "use core::f32;\n\nuse alloc::boxed::Box;\nuse alloc::format;\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse bytemuck::"
  },
  {
    "path": "crates/burn-backend/src/distribution.rs",
    "chars": 3714,
    "preview": "//! Random value distributions used to initialize and populate tensor data.\n\nuse rand::{Rng, RngExt, distr::StandardUnif"
  },
  {
    "path": "crates/burn-backend/src/element/base.rs",
    "chars": 7969,
    "preview": "use core::cmp::Ordering;\nuse rand::Rng;\n\nuse crate::distribution::Distribution;\nuse burn_std::{BoolStore, DType, bf16, f"
  },
  {
    "path": "crates/burn-backend/src/element/cast.rs",
    "chars": 19649,
    "preview": "use core::mem::size_of;\n\nuse burn_std::{bf16, f16};\n\n/// A generic trait for converting a value to a number.\n/// Adapted"
  },
  {
    "path": "crates/burn-backend/src/element/mod.rs",
    "chars": 174,
    "preview": "//! Traits and helpers for working with element types and conversions.\n\nmod base;\nmod scalar;\n\n/// Tensor element castin"
  },
  {
    "path": "crates/burn-backend/src/element/scalar.rs",
    "chars": 2967,
    "preview": "use burn_std::{DType, bf16, f16};\nuse num_traits::ToPrimitive;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nus"
  },
  {
    "path": "crates/burn-backend/src/lib.rs",
    "chars": 2820,
    "preview": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! This libra"
  },
  {
    "path": "crates/burn-backend/src/tensor/alias.rs",
    "chars": 1001,
    "preview": "use crate::backend::Backend;\n\n// We provide some type aliases to improve the readability of using associated types witho"
  },
  {
    "path": "crates/burn-backend/src/tensor/container.rs",
    "chars": 2215,
    "preview": "use alloc::boxed::Box;\nuse core::any::Any;\n\n#[cfg(not(feature = \"std\"))]\nuse alloc::vec::Vec;\n#[cfg(not(feature = \"std\")"
  },
  {
    "path": "crates/burn-backend/src/tensor/kind.rs",
    "chars": 1087,
    "preview": "use crate::{Backend, TensorMetadata, TensorPrimitive};\n\n/// A type-level representation of the kind of a float tensor\n#["
  },
  {
    "path": "crates/burn-backend/src/tensor/mod.rs",
    "chars": 174,
    "preview": "mod alias;\nmod container;\nmod kind;\nmod ops;\n\npub use alias::*;\npub use container::*;\npub use kind::*;\npub use ops::*;\n\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/autodiff.rs",
    "chars": 1934,
    "preview": "use crate::{\n    AutodiffBackend,\n    tensor::{BasicOps, TensorKind},\n};\n\n/// Trait that list all operations that can be"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/base.rs",
    "chars": 32454,
    "preview": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    Backend, ExecutionError, Scalar, TensorData,"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/bool.rs",
    "chars": 6697,
    "preview": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, ExecutionError, Sc"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/float.rs",
    "chars": 25414,
    "preview": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, Distribution, Exec"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/int.rs",
    "chars": 12509,
    "preview": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, Distribution, Exec"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/mod.rs",
    "chars": 452,
    "preview": "mod autodiff;\nmod base;\nmod bool;\nmod float;\nmod int;\nmod numeric;\nmod ordered;\n\npub use autodiff::*;\npub use base::*;\np"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/numeric.rs",
    "chars": 20966,
    "preview": "use burn_std::Shape;\n\nuse crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps};\n\n/// Trait that li"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/ordered.rs",
    "chars": 28607,
    "preview": "use crate::{\n    Backend, Scalar,\n    tensor::{IntTensor, Numeric},\n};\n\n/// Trait that list all operations that can be a"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/calibration.rs",
    "chars": 185,
    "preview": "/// Calibration method used to compute the quantization range mapping.\npub enum Calibration {\n    /// Computes quantizat"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/mod.rs",
    "chars": 112,
    "preview": "mod calibration;\nmod parameters;\nmod scheme;\n\npub use calibration::*;\npub use parameters::*;\npub use scheme::*;\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/parameters.rs",
    "chars": 503,
    "preview": "use crate::Backend;\n\npub use burn_std::quantization::{QParamTensor, QParams};\n\n/// The quantization parameters primitive"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/scheme.rs",
    "chars": 2571,
    "preview": "pub use burn_std::{QPARAM_ALIGN, params_shape};\nuse burn_std::{QuantLevel, QuantMode, QuantScheme, Shape};\n\nuse super::{"
  },
  {
    "path": "crates/burn-backend-tests/.cargo/config.toml",
    "chars": 659,
    "preview": "[alias]\ntest-cpu = \"test --release --no-default-features --features cpu,std\"\ntest-cuda = \"test --release --no-default-fe"
  },
  {
    "path": "crates/burn-backend-tests/Cargo.toml",
    "chars": 3709,
    "preview": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \""
  },
  {
    "path": "crates/burn-backend-tests/README.md",
    "chars": 3659,
    "preview": "# Burn Backend Tests\n\nThis crate provides a comprehensive suite of tests for Burn backends, covering:\n\n- Tensor operatio"
  },
  {
    "path": "crates/burn-backend-tests/cubecl.toml",
    "chars": 328,
    "preview": "[profiling]\nlogger = { file = \"target/profiling.log\", level = \"disabled\" }\n\n[autotune]\nlogger = { file = \"target/autotun"
  },
  {
    "path": "crates/burn-backend-tests/src/lib.rs",
    "chars": 608,
    "preview": "extern crate alloc;\n\n#[cfg(feature = \"std\")]\npub use burn_tensor_testgen::might_panic;\n\n/// Generate a test module with "
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/abs.rs",
    "chars": 2058,
    "preview": "use super::*;\nuse burn_tensor::{TensorData, Tolerance, cast::ToElement};\n\n#[test]\nfn should_diff_abs() {\n    let data_1 "
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/adaptive_avgpool1d.rs",
    "chars": 1430,
    "preview": "use super::*;\nuse burn_tensor::module::adaptive_avg_pool1d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_po"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/adaptive_avgpool2d.rs",
    "chars": 2745,
    "preview": "use super::*;\nuse burn_tensor::module::adaptive_avg_pool2d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_po"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/add.rs",
    "chars": 2352,
    "preview": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_add() {\n    let device = Default::default();\n    let "
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/aggregation.rs",
    "chars": 4990,
    "preview": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_mean() {\n    let data_1 = TensorData::fr"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/avgpool1d.rs",
    "chars": 2614,
    "preview": "use super::*;\nuse burn_tensor::module::avg_pool1d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool1d_simp"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/avgpool2d.rs",
    "chars": 3654,
    "preview": "use super::*;\nuse burn_tensor::module::avg_pool2d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool2d_simp"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/backward.rs",
    "chars": 949,
    "preview": "use super::*;\nuse burn_tensor::{Int, Tensor, TensorData, module::embedding};\n\n#[test]\nfn test_embedding_backward() {\n   "
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/bridge.rs",
    "chars": 761,
    "preview": "use super::*;\nuse burn_tensor::{DType, Distribution, Tensor};\n\n#[test]\nfn test_full_precision() {\n    let device = Defau"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/broadcast.rs",
    "chars": 1393,
    "preview": "use super::*;\n\n#[test]\nfn mul_broadcast() {\n    test_ops_broadcast_backward(|x, y| x * y);\n}\n\n#[test]\nfn div_broadcast()"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cast.rs",
    "chars": 717,
    "preview": "// Skip on metal - F64 not supported\n#![cfg(all(feature = \"std\", not(feature = \"metal\")))]\n\nuse super::*;\nuse burn_backe"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cat.rs",
    "chars": 3849,
    "preview": "use super::*;\n\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_cat() {\n    let device = Default::default();\n    let "
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/ceil.rs",
    "chars": 589,
    "preview": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_ceil() {\n    let data = TensorData::from([\n        [-"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/checkpoint.rs",
    "chars": 9658,
    "preview": "use super::*;\nuse burn_tensor::{Bool, Tensor, TensorData};\n\n#[test]\nfn test_autodiff_checkpoint_complicated_computation("
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/complex.rs",
    "chars": 2813,
    "preview": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_full_complex_1() {\n    let data_1 = TensorData::from("
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv1d.rs",
    "chars": 7469,
    "preview": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};\n\n#[test]\nfn test_conv1d_basic() {\n "
  }
]

// ... and 1617 more files (download for full content)

About this extraction

This page contains the full source code of the tracel-ai/burn GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 1817 files (9.5 MB), approximately 2.6M tokens, and a symbol index with 15549 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!