Repository: tracel-ai/burn Branch: main Commit: bcaabad860c5 Files: 1817 Total size: 9.5 MB Directory structure: gitextract_82tcaglc/ ├── .cargo/ │ ├── audit.toml │ └── config.toml ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── doc_request.md │ │ └── feature_request.md │ ├── PULL_REQUEST_TEMPLATE/ │ │ └── template.md │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ ├── combine-dependabot-prs.yml │ ├── dependencies.yml │ ├── publish.yml │ ├── stale-pr.yml │ ├── test-gpu.yml │ ├── test.yml │ ├── valgrind.yml │ └── vulnerabilities.yml ├── .gitignore ├── CITATION.cff ├── CODE-OF-CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── NOTICES.md ├── POEM.md ├── README.md ├── _typos.toml ├── benchmarks.toml ├── burn-book/ │ ├── .gitignore │ ├── .prettierrc.json │ ├── book.toml │ └── src/ │ ├── SUMMARY.md │ ├── advanced/ │ │ ├── README.md │ │ ├── backend-extension/ │ │ │ ├── README.md │ │ │ ├── custom-cubecl-kernel.md │ │ │ └── custom-wgpu-kernel.md │ │ ├── no-std.md │ │ └── web-assembly.md │ ├── basic-workflow/ │ │ ├── README.md │ │ ├── backend.md │ │ ├── data.md │ │ ├── inference.md │ │ ├── model.md │ │ └── training.md │ ├── building-blocks/ │ │ ├── README.md │ │ ├── autodiff.md │ │ ├── backend.md │ │ ├── config.md │ │ ├── dataset.md │ │ ├── learner.md │ │ ├── metric.md │ │ ├── module.md │ │ ├── record.md │ │ └── tensor.md │ ├── custom-training-loop.md │ ├── distributed-computing.md │ ├── examples.md │ ├── getting-started.md │ ├── models-and-pretrained-weights.md │ ├── motivation.md │ ├── onnx-import.md │ ├── overview.md │ ├── performance/ │ │ ├── README.md │ │ ├── distributed-computing.md │ │ ├── good-practices/ │ │ │ ├── README.md │ │ │ ├── asynchronous-execution.md │ │ │ ├── kernel-fusion.md │ │ │ └── kernel-selection.md │ │ └── quantization.md │ └── saving-and-loading.md ├── codecov.yml ├── contributor-book/ │ ├── .gitignore │ ├── .prettierrc.json │ ├── book.toml │ └── src/ │ ├── SUMMARY.md │ ├── frequently-encountered-issues/ │ │ ├── README.md │ │ └── issues-while-adding-ops.md │ ├── getting-started/ │ │ ├── README.md │ │ ├── configuring-your-editor.md │ │ ├── setting-up-the-environment.md │ │ └── testing.md │ ├── guides/ │ │ ├── README.md │ │ ├── adding-a-new-operation-to-burn.md │ │ └── submitting-examples.md │ ├── how-to-read-this-book.md │ ├── overview.md │ └── project-architecture/ │ ├── README.md │ ├── backend.md │ ├── module.md │ ├── serialization.md │ └── tensor.md ├── crates/ │ ├── burn/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── backend.rs │ │ ├── collective.rs │ │ └── lib.rs │ ├── burn-autodiff/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── checkpoint/ │ │ │ ├── base.rs │ │ │ ├── builder.rs │ │ │ ├── mod.rs │ │ │ ├── retro_forward.rs │ │ │ ├── state.rs │ │ │ └── strategy.rs │ │ ├── grads.rs │ │ ├── graph/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── node.rs │ │ │ ├── requirement.rs │ │ │ └── traversal.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── backward.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── maxmin.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── sort.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── runtime/ │ │ │ ├── client.rs │ │ │ ├── graph.rs │ │ │ ├── memory_management.rs │ │ │ ├── mod.rs │ │ │ └── server.rs │ │ ├── tensor.rs │ │ └── utils.rs │ ├── burn-backend/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend/ │ │ │ ├── base.rs │ │ │ ├── device.rs │ │ │ ├── mod.rs │ │ │ ├── ops/ │ │ │ │ ├── activation.rs │ │ │ │ ├── argwhere.rs │ │ │ │ ├── bool_tensor.rs │ │ │ │ ├── cat.rs │ │ │ │ ├── int_tensor.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── modules/ │ │ │ │ │ ├── attention.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── conv.rs │ │ │ │ │ ├── grid_sample.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── pool.rs │ │ │ │ │ └── unfold.rs │ │ │ │ ├── qtensor.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── sort.rs │ │ │ │ ├── tensor.rs │ │ │ │ └── transaction.rs │ │ │ └── primitive.rs │ │ ├── data/ │ │ │ ├── compare.rs │ │ │ ├── mod.rs │ │ │ └── tensor.rs │ │ ├── distribution.rs │ │ ├── element/ │ │ │ ├── base.rs │ │ │ ├── cast.rs │ │ │ ├── mod.rs │ │ │ └── scalar.rs │ │ ├── lib.rs │ │ └── tensor/ │ │ ├── alias.rs │ │ ├── container.rs │ │ ├── kind.rs │ │ ├── mod.rs │ │ ├── ops/ │ │ │ ├── autodiff.rs │ │ │ ├── base.rs │ │ │ ├── bool.rs │ │ │ ├── float.rs │ │ │ ├── int.rs │ │ │ ├── mod.rs │ │ │ ├── numeric.rs │ │ │ └── ordered.rs │ │ └── quantization/ │ │ ├── calibration.rs │ │ ├── mod.rs │ │ ├── parameters.rs │ │ └── scheme.rs │ ├── burn-backend-tests/ │ │ ├── .cargo/ │ │ │ └── config.toml │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── src/ │ │ │ └── lib.rs │ │ └── tests/ │ │ ├── autodiff/ │ │ │ ├── abs.rs │ │ │ ├── adaptive_avgpool1d.rs │ │ │ ├── adaptive_avgpool2d.rs │ │ │ ├── add.rs │ │ │ ├── aggregation.rs │ │ │ ├── avgpool1d.rs │ │ │ ├── avgpool2d.rs │ │ │ ├── backward.rs │ │ │ ├── bridge.rs │ │ │ ├── broadcast.rs │ │ │ ├── cast.rs │ │ │ ├── cat.rs │ │ │ ├── ceil.rs │ │ │ ├── checkpoint.rs │ │ │ ├── complex.rs │ │ │ ├── conv1d.rs │ │ │ ├── conv2d.rs │ │ │ ├── conv3d.rs │ │ │ ├── conv_transpose1d.rs │ │ │ ├── conv_transpose2d.rs │ │ │ ├── conv_transpose3d.rs │ │ │ ├── cross.rs │ │ │ ├── cross_entropy.rs │ │ │ ├── cummax.rs │ │ │ ├── cummin.rs │ │ │ ├── cumprod.rs │ │ │ ├── cumsum.rs │ │ │ ├── deform_conv2d.rs │ │ │ ├── div.rs │ │ │ ├── erf.rs │ │ │ ├── exp.rs │ │ │ ├── expand.rs │ │ │ ├── flip.rs │ │ │ ├── floor.rs │ │ │ ├── gather_scatter.rs │ │ │ ├── gelu.rs │ │ │ ├── gradients.rs │ │ │ ├── log.rs │ │ │ ├── log1p.rs │ │ │ ├── log_sigmoid.rs │ │ │ ├── mask.rs │ │ │ ├── matmul.rs │ │ │ ├── maxmin.rs │ │ │ ├── maxpool1d.rs │ │ │ ├── maxpool2d.rs │ │ │ ├── memory_management.rs │ │ │ ├── mod.rs │ │ │ ├── mul.rs │ │ │ ├── multithread.rs │ │ │ ├── nearest_interpolate.rs │ │ │ ├── neg.rs │ │ │ ├── nonzero.rs │ │ │ ├── permute.rs │ │ │ ├── pow.rs │ │ │ ├── recip.rs │ │ │ ├── relu.rs │ │ │ ├── remainder.rs │ │ │ ├── repeat_dim.rs │ │ │ ├── reshape.rs │ │ │ ├── round.rs │ │ │ ├── select.rs │ │ │ ├── sigmoid.rs │ │ │ ├── sign.rs │ │ │ ├── slice.rs │ │ │ ├── slice_assign.rs │ │ │ ├── softmax.rs │ │ │ ├── sort.rs │ │ │ ├── sqrt.rs │ │ │ ├── sub.rs │ │ │ ├── transpose.rs │ │ │ ├── trig.rs │ │ │ └── unfold.rs │ │ ├── autodiff.rs │ │ ├── common/ │ │ │ ├── autodiff.rs │ │ │ ├── backend.rs │ │ │ └── tensor.rs │ │ ├── cubecl/ │ │ │ ├── avg_pool2d.rs │ │ │ ├── bernoulli.rs │ │ │ ├── cast.rs │ │ │ ├── cat.rs │ │ │ ├── clamp.rs │ │ │ ├── contiguous.rs │ │ │ ├── conv2d.rs │ │ │ ├── conv3d.rs │ │ │ ├── conv_transpose2d.rs │ │ │ ├── conv_transpose3d.rs │ │ │ ├── cross.rs │ │ │ ├── gather.rs │ │ │ ├── mask_fill.rs │ │ │ ├── mask_where.rs │ │ │ ├── max_pool2d.rs │ │ │ ├── max_pool2d_backward.rs │ │ │ ├── mod.rs │ │ │ ├── normal.rs │ │ │ ├── quantization.rs │ │ │ ├── reduce.rs │ │ │ ├── repeat_dim.rs │ │ │ ├── scatter.rs │ │ │ ├── select.rs │ │ │ ├── select_assign.rs │ │ │ ├── slice.rs │ │ │ ├── slice_assign.rs │ │ │ ├── unary.rs │ │ │ └── uniform.rs │ │ ├── cubecl.rs │ │ ├── fused_ops/ │ │ │ ├── mod.rs │ │ │ └── reduce_broadcasted.rs │ │ ├── fusion.rs │ │ ├── tensor/ │ │ │ ├── bool/ │ │ │ │ ├── mod.rs │ │ │ │ └── ops/ │ │ │ │ ├── all.rs │ │ │ │ ├── any.rs │ │ │ │ ├── argwhere_nonzero.rs │ │ │ │ ├── cat.rs │ │ │ │ ├── comparison.rs │ │ │ │ ├── create_like.rs │ │ │ │ ├── expand.rs │ │ │ │ ├── flip.rs │ │ │ │ ├── full.rs │ │ │ │ ├── gather_scatter.rs │ │ │ │ ├── init.rs │ │ │ │ ├── logical.rs │ │ │ │ ├── mask.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── movedim.rs │ │ │ │ ├── permute.rs │ │ │ │ ├── repeat.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── reshape.rs │ │ │ │ ├── select.rs │ │ │ │ ├── stack.rs │ │ │ │ ├── take.rs │ │ │ │ ├── transpose.rs │ │ │ │ ├── tri_mask.rs │ │ │ │ └── unfold.rs │ │ │ ├── clone_invariance.rs │ │ │ ├── float/ │ │ │ │ ├── activation/ │ │ │ │ │ ├── celu.rs │ │ │ │ │ ├── elu.rs │ │ │ │ │ ├── gelu.rs │ │ │ │ │ ├── glu.rs │ │ │ │ │ ├── hard_sigmoid.rs │ │ │ │ │ ├── leaky_relu.rs │ │ │ │ │ ├── log_sigmoid.rs │ │ │ │ │ ├── mish.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── prelu.rs │ │ │ │ │ ├── quiet_softmax.rs │ │ │ │ │ ├── relu.rs │ │ │ │ │ ├── selu.rs │ │ │ │ │ ├── sigmoid.rs │ │ │ │ │ ├── silu.rs │ │ │ │ │ ├── softmax.rs │ │ │ │ │ ├── softmin.rs │ │ │ │ │ ├── softplus.rs │ │ │ │ │ ├── softsign.rs │ │ │ │ │ ├── tanh_activation.rs │ │ │ │ │ └── thresholded_relu.rs │ │ │ │ ├── grid/ │ │ │ │ │ ├── affine_grid.rs │ │ │ │ │ ├── meshgrid.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── linalg/ │ │ │ │ │ ├── cosine_similarity.rs │ │ │ │ │ ├── diag.rs │ │ │ │ │ ├── lu_decomposition.rs │ │ │ │ │ ├── matvec.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── outer.rs │ │ │ │ │ ├── trace.rs │ │ │ │ │ └── vector_norm.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── module/ │ │ │ │ │ ├── adaptive_avgpool1d.rs │ │ │ │ │ ├── adaptive_avgpool2d.rs │ │ │ │ │ ├── attention.rs │ │ │ │ │ ├── avgpool1d.rs │ │ │ │ │ ├── avgpool2d.rs │ │ │ │ │ ├── bicubic_interpolate.rs │ │ │ │ │ ├── bilinear_interpolate.rs │ │ │ │ │ ├── conv1d.rs │ │ │ │ │ ├── conv2d.rs │ │ │ │ │ ├── conv3d.rs │ │ │ │ │ ├── conv_transpose1d.rs │ │ │ │ │ ├── conv_transpose2d.rs │ │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ │ ├── deform_conv2d.rs │ │ │ │ │ ├── forward.rs │ │ │ │ │ ├── lanczos3_interpolate.rs │ │ │ │ │ ├── linear.rs │ │ │ │ │ ├── maxpool1d.rs │ │ │ │ │ ├── maxpool2d.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── nearest_interpolate.rs │ │ │ │ │ └── unfold4d.rs │ │ │ │ ├── ops/ │ │ │ │ │ ├── abs.rs │ │ │ │ │ ├── add.rs │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ ├── all.rs │ │ │ │ │ ├── any.rs │ │ │ │ │ ├── arg.rs │ │ │ │ │ ├── cast.rs │ │ │ │ │ ├── cat.rs │ │ │ │ │ ├── ceil.rs │ │ │ │ │ ├── chunk.rs │ │ │ │ │ ├── clamp.rs │ │ │ │ │ ├── close.rs │ │ │ │ │ ├── comparison.rs │ │ │ │ │ ├── create_like.rs │ │ │ │ │ ├── cross.rs │ │ │ │ │ ├── cumulative.rs │ │ │ │ │ ├── div.rs │ │ │ │ │ ├── dot.rs │ │ │ │ │ ├── erf.rs │ │ │ │ │ ├── exp.rs │ │ │ │ │ ├── expand.rs │ │ │ │ │ ├── finite.rs │ │ │ │ │ ├── flatten.rs │ │ │ │ │ ├── flip.rs │ │ │ │ │ ├── floor.rs │ │ │ │ │ ├── fmod.rs │ │ │ │ │ ├── full.rs │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ ├── grid_sample.rs │ │ │ │ │ ├── inf.rs │ │ │ │ │ ├── init.rs │ │ │ │ │ ├── iter_dim.rs │ │ │ │ │ ├── log.rs │ │ │ │ │ ├── log1p.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── matmul.rs │ │ │ │ │ ├── maxmin.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── movedim.rs │ │ │ │ │ ├── mul.rs │ │ │ │ │ ├── nan.rs │ │ │ │ │ ├── narrow.rs │ │ │ │ │ ├── neg.rs │ │ │ │ │ ├── one_hot.rs │ │ │ │ │ ├── padding.rs │ │ │ │ │ ├── permute.rs │ │ │ │ │ ├── powf.rs │ │ │ │ │ ├── powf_scalar.rs │ │ │ │ │ ├── prod.rs │ │ │ │ │ ├── random.rs │ │ │ │ │ ├── recip.rs │ │ │ │ │ ├── remainder.rs │ │ │ │ │ ├── repeat.rs │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ ├── reshape.rs │ │ │ │ │ ├── round.rs │ │ │ │ │ ├── select.rs │ │ │ │ │ ├── sign.rs │ │ │ │ │ ├── slice.rs │ │ │ │ │ ├── slice_assign.rs │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ ├── split.rs │ │ │ │ │ ├── sqrt.rs │ │ │ │ │ ├── square.rs │ │ │ │ │ ├── squeeze.rs │ │ │ │ │ ├── stack.rs │ │ │ │ │ ├── sub.rs │ │ │ │ │ ├── take.rs │ │ │ │ │ ├── topk.rs │ │ │ │ │ ├── transaction.rs │ │ │ │ │ ├── transpose.rs │ │ │ │ │ ├── tri.rs │ │ │ │ │ ├── trig.rs │ │ │ │ │ ├── trunc.rs │ │ │ │ │ └── unfold.rs │ │ │ │ ├── primitive.rs │ │ │ │ ├── quantization/ │ │ │ │ │ ├── calibration.rs │ │ │ │ │ ├── data.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── extended/ │ │ │ │ │ │ │ ├── abs.rs │ │ │ │ │ │ │ ├── add.rs │ │ │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ │ │ ├── all.rs │ │ │ │ │ │ │ ├── any.rs │ │ │ │ │ │ │ ├── arg.rs │ │ │ │ │ │ │ ├── cat.rs │ │ │ │ │ │ │ ├── ceil.rs │ │ │ │ │ │ │ ├── chunk.rs │ │ │ │ │ │ │ ├── clamp.rs │ │ │ │ │ │ │ ├── cos.rs │ │ │ │ │ │ │ ├── cosh.rs │ │ │ │ │ │ │ ├── div.rs │ │ │ │ │ │ │ ├── erf.rs │ │ │ │ │ │ │ ├── exp.rs │ │ │ │ │ │ │ ├── expand.rs │ │ │ │ │ │ │ ├── flip.rs │ │ │ │ │ │ │ ├── floor.rs │ │ │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ │ │ ├── log.rs │ │ │ │ │ │ │ ├── log1p.rs │ │ │ │ │ │ │ ├── map_comparison.rs │ │ │ │ │ │ │ ├── mask.rs │ │ │ │ │ │ │ ├── maxmin.rs │ │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ │ ├── mul.rs │ │ │ │ │ │ │ ├── narrow.rs │ │ │ │ │ │ │ ├── neg.rs │ │ │ │ │ │ │ ├── permute.rs │ │ │ │ │ │ │ ├── powf.rs │ │ │ │ │ │ │ ├── powf_scalar.rs │ │ │ │ │ │ │ ├── recip.rs │ │ │ │ │ │ │ ├── remainder.rs │ │ │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ │ │ ├── reshape.rs │ │ │ │ │ │ │ ├── round.rs │ │ │ │ │ │ │ ├── select.rs │ │ │ │ │ │ │ ├── sin.rs │ │ │ │ │ │ │ ├── sinh.rs │ │ │ │ │ │ │ ├── slice.rs │ │ │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ │ │ ├── split.rs │ │ │ │ │ │ │ ├── sqrt.rs │ │ │ │ │ │ │ ├── stack.rs │ │ │ │ │ │ │ ├── sub.rs │ │ │ │ │ │ │ ├── tan.rs │ │ │ │ │ │ │ ├── tanh.rs │ │ │ │ │ │ │ ├── topk.rs │ │ │ │ │ │ │ └── transpose.rs │ │ │ │ │ │ ├── matmul.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ └── quantize.rs │ │ │ │ │ └── scheme.rs │ │ │ │ └── stats/ │ │ │ │ ├── cov.rs │ │ │ │ ├── display.rs │ │ │ │ ├── eye.rs │ │ │ │ ├── median.rs │ │ │ │ ├── mod.rs │ │ │ │ └── var.rs │ │ │ ├── int/ │ │ │ │ ├── mod.rs │ │ │ │ ├── ops/ │ │ │ │ │ ├── abs.rs │ │ │ │ │ ├── add.rs │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ ├── all.rs │ │ │ │ │ ├── any.rs │ │ │ │ │ ├── arange.rs │ │ │ │ │ ├── arange_step.rs │ │ │ │ │ ├── arg.rs │ │ │ │ │ ├── bitwise.rs │ │ │ │ │ ├── cartesian_grid.rs │ │ │ │ │ ├── cast.rs │ │ │ │ │ ├── cat.rs │ │ │ │ │ ├── chunk.rs │ │ │ │ │ ├── comparison.rs │ │ │ │ │ ├── create_like.rs │ │ │ │ │ ├── cumulative.rs │ │ │ │ │ ├── div.rs │ │ │ │ │ ├── expand.rs │ │ │ │ │ ├── flip.rs │ │ │ │ │ ├── full.rs │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ ├── init.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── matmul.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── movedim.rs │ │ │ │ │ ├── mul.rs │ │ │ │ │ ├── one_hot.rs │ │ │ │ │ ├── permute.rs │ │ │ │ │ ├── random.rs │ │ │ │ │ ├── remainder.rs │ │ │ │ │ ├── repeat.rs │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ ├── reshape.rs │ │ │ │ │ ├── roll.rs │ │ │ │ │ ├── select.rs │ │ │ │ │ ├── sign.rs │ │ │ │ │ ├── slice.rs │ │ │ │ │ ├── slice_assign.rs │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ ├── stack.rs │ │ │ │ │ ├── sub.rs │ │ │ │ │ ├── take.rs │ │ │ │ │ ├── topk.rs │ │ │ │ │ ├── transpose.rs │ │ │ │ │ ├── tri.rs │ │ │ │ │ └── unfold.rs │ │ │ │ └── primitive.rs │ │ │ ├── mod.rs │ │ │ └── multi_threads.rs │ │ └── tensor.rs │ ├── burn-candle/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── candle_utils.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── utils.rs │ │ └── tensor.rs │ ├── burn-collective/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── multinode-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── README.md │ │ │ └── src/ │ │ │ ├── bin/ │ │ │ │ ├── global.rs │ │ │ │ ├── node.rs │ │ │ │ └── test_launcher.rs │ │ │ ├── lib.rs │ │ │ └── shared.rs │ │ └── src/ │ │ ├── api.rs │ │ ├── config.rs │ │ ├── global/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── node/ │ │ │ │ ├── base.rs │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ring.rs │ │ │ │ ├── sync.rs │ │ │ │ ├── tree.rs │ │ │ │ └── worker.rs │ │ │ ├── orchestrator/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── state.rs │ │ │ └── shared.rs │ │ ├── lib.rs │ │ ├── local/ │ │ │ ├── all_reduce/ │ │ │ │ ├── base.rs │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ ├── ring.rs │ │ │ │ └── tree.rs │ │ │ ├── broadcast/ │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ └── tree.rs │ │ │ ├── client.rs │ │ │ ├── mod.rs │ │ │ ├── reduce/ │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ └── tree.rs │ │ │ ├── server.rs │ │ │ └── tensor_map.rs │ │ └── tests/ │ │ ├── all_reduce.rs │ │ ├── broadcast.rs │ │ ├── mod.rs │ │ └── reduce.rs │ ├── burn-communication/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── base.rs │ │ ├── data_service.rs │ │ ├── lib.rs │ │ ├── util.rs │ │ └── websocket/ │ │ ├── base.rs │ │ ├── client.rs │ │ ├── mod.rs │ │ └── server.rs │ ├── burn-core/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── config.rs │ │ │ ├── data/ │ │ │ │ ├── dataloader/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── batch.rs │ │ │ │ │ ├── batcher.rs │ │ │ │ │ ├── builder.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── multithread.rs │ │ │ │ │ ├── split.rs │ │ │ │ │ └── strategy.rs │ │ │ │ └── mod.rs │ │ │ ├── lib.rs │ │ │ ├── module/ │ │ │ │ ├── base.rs │ │ │ │ ├── display.rs │ │ │ │ ├── initializer.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── param/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── constant.rs │ │ │ │ │ ├── id.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── primitive.rs │ │ │ │ │ ├── running.rs │ │ │ │ │ ├── tensor.rs │ │ │ │ │ └── visitor.rs │ │ │ │ ├── quantize.rs │ │ │ │ └── reinit.rs │ │ │ ├── record/ │ │ │ │ ├── base.rs │ │ │ │ ├── file.rs │ │ │ │ ├── memory.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── primitive.rs │ │ │ │ ├── recorder.rs │ │ │ │ ├── serde/ │ │ │ │ │ ├── adapter.rs │ │ │ │ │ ├── data.rs │ │ │ │ │ ├── de.rs │ │ │ │ │ ├── error.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── ser.rs │ │ │ │ ├── settings.rs │ │ │ │ └── tensor.rs │ │ │ ├── tensor.rs │ │ │ └── vision.rs │ │ └── tests/ │ │ ├── test_derive_config.rs │ │ ├── test_derive_module.rs │ │ ├── test_derive_record.rs │ │ └── test_record_resilience.rs │ ├── burn-cpu/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-cubecl/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── fusion.rs │ │ ├── kernel/ │ │ │ ├── attention/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune.rs │ │ │ ├── binary.rs │ │ │ ├── binary_float.rs │ │ │ ├── binary_int.rs │ │ │ ├── cast/ │ │ │ │ ├── base.rs │ │ │ │ ├── bool_cast.rs │ │ │ │ └── mod.rs │ │ │ ├── clamp.rs │ │ │ ├── comparison.rs │ │ │ ├── contiguous.rs │ │ │ ├── conv/ │ │ │ │ ├── backward_data/ │ │ │ │ │ ├── fallback.rs │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── backward_weight/ │ │ │ │ │ ├── fallback.rs │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── base.rs │ │ │ │ ├── conv_transpose2d/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── col2im.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── transpose_direct.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ ├── deform_conv2d.rs │ │ │ │ ├── deform_conv_transpose2d.rs │ │ │ │ ├── direct.rs │ │ │ │ ├── forward/ │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── im2col.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune_key.rs │ │ │ ├── cross.rs │ │ │ ├── grid_sample/ │ │ │ │ ├── base.rs │ │ │ │ ├── bilinear.rs │ │ │ │ └── mod.rs │ │ │ ├── index/ │ │ │ │ ├── flip.rs │ │ │ │ ├── gather.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── scatter.rs │ │ │ │ ├── select.rs │ │ │ │ ├── select_assign.rs │ │ │ │ ├── slice.rs │ │ │ │ └── slice_assign.rs │ │ │ ├── interpolate/ │ │ │ │ ├── base.rs │ │ │ │ ├── bicubic.rs │ │ │ │ ├── bilinear.rs │ │ │ │ ├── lanczos3.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── nearest.rs │ │ │ │ └── nearest_backward.rs │ │ │ ├── mask/ │ │ │ │ ├── base.rs │ │ │ │ ├── mask_fill.rs │ │ │ │ ├── mask_where.rs │ │ │ │ └── mod.rs │ │ │ ├── matmul/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── tune/ │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ └── utils.rs │ │ │ ├── mod.rs │ │ │ ├── pool/ │ │ │ │ ├── adaptive_avg_pool2d.rs │ │ │ │ ├── adaptive_avg_pool2d_backward.rs │ │ │ │ ├── avg_pool2d.rs │ │ │ │ ├── avg_pool2d_backward.rs │ │ │ │ ├── max_pool2d.rs │ │ │ │ ├── max_pool2d_backward.rs │ │ │ │ ├── mod.rs │ │ │ │ └── pool2d.rs │ │ │ ├── prng/ │ │ │ │ ├── bernoulli.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── normal.rs │ │ │ │ └── uniform.rs │ │ │ ├── quantization/ │ │ │ │ ├── dequantize.rs │ │ │ │ ├── mod.rs │ │ │ │ └── quantize.rs │ │ │ ├── reduce/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune.rs │ │ │ ├── unary_float.rs │ │ │ ├── unary_int.rs │ │ │ ├── unary_numeric.rs │ │ │ └── utils.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── numeric.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── template/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── source.rs │ │ ├── tensor/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── quantization.rs │ │ └── tune_key.rs │ ├── burn-cubecl-fusion/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── base.rs │ │ ├── engine/ │ │ │ ├── codegen/ │ │ │ │ ├── base.rs │ │ │ │ ├── io.rs │ │ │ │ ├── ir.rs │ │ │ │ ├── kernel.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── tensor.rs │ │ │ │ └── view.rs │ │ │ ├── fuser.rs │ │ │ ├── launch/ │ │ │ │ ├── base.rs │ │ │ │ ├── executor.rs │ │ │ │ ├── input.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── output.rs │ │ │ │ ├── plan.rs │ │ │ │ ├── runner.rs │ │ │ │ └── vectorization/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── planner.rs │ │ │ ├── mod.rs │ │ │ ├── scoring.rs │ │ │ ├── settings.rs │ │ │ └── trace/ │ │ │ ├── base.rs │ │ │ ├── block.rs │ │ │ ├── fuser.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── optim/ │ │ │ ├── base.rs │ │ │ ├── elemwise/ │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ └── optimization.rs │ │ │ ├── matmul/ │ │ │ │ ├── args.rs │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── optimization.rs │ │ │ │ └── tune.rs │ │ │ ├── mod.rs │ │ │ ├── reduce/ │ │ │ │ ├── args.rs │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── optimization.rs │ │ │ │ └── tune.rs │ │ │ └── reduce_broadcasted/ │ │ │ ├── fuser/ │ │ │ │ ├── base.rs │ │ │ │ ├── block.rs │ │ │ │ ├── full.rs │ │ │ │ ├── full_analyzer.rs │ │ │ │ └── mod.rs │ │ │ ├── launch.rs │ │ │ ├── mod.rs │ │ │ ├── optimization.rs │ │ │ ├── tune.rs │ │ │ └── unit.rs │ │ └── tune.rs │ ├── burn-cuda/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-dataset/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── hf_dataset.rs │ │ │ └── speech_commands.rs │ │ ├── src/ │ │ │ ├── audio/ │ │ │ │ ├── mod.rs │ │ │ │ └── speech_commands.rs │ │ │ ├── dataset/ │ │ │ │ ├── base.rs │ │ │ │ ├── dataframe.rs │ │ │ │ ├── fake.rs │ │ │ │ ├── in_memory.rs │ │ │ │ ├── iterator.rs │ │ │ │ ├── mod.rs │ │ │ │ └── sqlite.rs │ │ │ ├── lib.rs │ │ │ ├── nlp/ │ │ │ │ ├── ag_news.rs │ │ │ │ ├── mod.rs │ │ │ │ └── text_folder.rs │ │ │ ├── source/ │ │ │ │ ├── huggingface/ │ │ │ │ │ ├── downloader.rs │ │ │ │ │ ├── importer.py │ │ │ │ │ └── mod.rs │ │ │ │ └── mod.rs │ │ │ ├── transform/ │ │ │ │ ├── composed.rs │ │ │ │ ├── mapper.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── options.rs │ │ │ │ ├── partial.rs │ │ │ │ ├── sampler.rs │ │ │ │ ├── selection.rs │ │ │ │ ├── shuffle.rs │ │ │ │ └── window.rs │ │ │ └── vision/ │ │ │ ├── cifar.rs │ │ │ ├── image_folder.rs │ │ │ ├── mnist.rs │ │ │ └── mod.rs │ │ └── tests/ │ │ └── data/ │ │ ├── dataset-fmt.csv │ │ ├── dataset.csv │ │ ├── dataset.json │ │ ├── dataset_coco.json │ │ ├── segmask_folder/ │ │ │ └── annotations/ │ │ │ ├── mask_checkerboard.txt │ │ │ ├── mask_random_2colors.txt │ │ │ └── mask_random_3colors.txt │ │ └── text_folder/ │ │ ├── negative/ │ │ │ ├── sample1.txt │ │ │ └── sample2.txt │ │ └── positive/ │ │ ├── sample1.txt │ │ └── sample2.txt │ ├── burn-derive/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── config/ │ │ │ ├── analyzer.rs │ │ │ ├── analyzer_enum.rs │ │ │ ├── analyzer_struct.rs │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── module/ │ │ │ ├── base.rs │ │ │ ├── codegen.rs │ │ │ ├── codegen_enum.rs │ │ │ ├── codegen_struct.rs │ │ │ ├── display.rs │ │ │ ├── generics.rs │ │ │ ├── mod.rs │ │ │ ├── record.rs │ │ │ ├── record_enum.rs │ │ │ └── record_struct.rs │ │ ├── record/ │ │ │ ├── base.rs │ │ │ ├── codegen.rs │ │ │ ├── item/ │ │ │ │ ├── codegen.rs │ │ │ │ ├── codegen_enum.rs │ │ │ │ ├── codegen_struct.rs │ │ │ │ └── mod.rs │ │ │ └── mod.rs │ │ └── shared/ │ │ ├── attribute.rs │ │ ├── enum_variant.rs │ │ ├── field.rs │ │ ├── generics.rs │ │ └── mod.rs │ ├── burn-dispatch/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── device.rs │ │ ├── lib.rs │ │ ├── macros.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ └── tensor.rs │ ├── burn-fusion/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── client.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── binary.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── unary.rs │ │ ├── search/ │ │ │ ├── block.rs │ │ │ ├── merging.rs │ │ │ ├── mod.rs │ │ │ └── optimization/ │ │ │ ├── blocks.rs │ │ │ ├── mod.rs │ │ │ └── stream.rs │ │ ├── server.rs │ │ ├── stream/ │ │ │ ├── base.rs │ │ │ ├── context.rs │ │ │ ├── execution/ │ │ │ │ ├── base.rs │ │ │ │ ├── explorer.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ordering.rs │ │ │ │ ├── policy.rs │ │ │ │ ├── processor.rs │ │ │ │ ├── tests.rs │ │ │ │ └── validator.rs │ │ │ ├── memory_checks.rs │ │ │ ├── mod.rs │ │ │ ├── multi.rs │ │ │ ├── queue/ │ │ │ │ ├── base.rs │ │ │ │ ├── execution.rs │ │ │ │ └── mod.rs │ │ │ ├── shared_tensors.rs │ │ │ └── store/ │ │ │ ├── base.rs │ │ │ ├── index.rs │ │ │ └── mod.rs │ │ └── tensor.rs │ ├── burn-ir/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── builder.rs │ │ ├── handle.rs │ │ ├── lib.rs │ │ ├── operation.rs │ │ ├── scalar.rs │ │ └── tensor.rs │ ├── burn-ndarray/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── adaptive_avgpool.rs │ │ │ ├── avgpool.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── conv.rs │ │ │ ├── deform_conv.rs │ │ │ ├── grid_sample.rs │ │ │ ├── int_tensor.rs │ │ │ ├── interpolate.rs │ │ │ ├── macros.rs │ │ │ ├── matmul.rs │ │ │ ├── maxpool.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── padding.rs │ │ │ ├── qtensor.rs │ │ │ ├── quantization.rs │ │ │ ├── simd/ │ │ │ │ ├── avgpool.rs │ │ │ │ ├── base.rs │ │ │ │ ├── binary.rs │ │ │ │ ├── binary_elemwise.rs │ │ │ │ ├── cmp.rs │ │ │ │ ├── conv.rs │ │ │ │ ├── maxpool.rs │ │ │ │ ├── mod.rs │ │ │ │ └── unary.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── parallel.rs │ │ ├── rand.rs │ │ ├── sharing.rs │ │ ├── storage.rs │ │ └── tensor.rs │ ├── burn-nn/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── activation/ │ │ │ │ ├── activation_wrapper.rs │ │ │ │ ├── celu.rs │ │ │ │ ├── elu.rs │ │ │ │ ├── gelu.rs │ │ │ │ ├── glu.rs │ │ │ │ ├── hard_shrink.rs │ │ │ │ ├── hard_sigmoid.rs │ │ │ │ ├── hard_swish.rs │ │ │ │ ├── leaky_relu.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── prelu.rs │ │ │ │ ├── relu.rs │ │ │ │ ├── selu.rs │ │ │ │ ├── shrink.rs │ │ │ │ ├── sigmoid.rs │ │ │ │ ├── soft_shrink.rs │ │ │ │ ├── softplus.rs │ │ │ │ ├── softsign.rs │ │ │ │ ├── swiglu.rs │ │ │ │ ├── tanh.rs │ │ │ │ └── thresholded_relu.rs │ │ │ ├── lib.rs │ │ │ ├── loss/ │ │ │ │ ├── binary_cross_entropy.rs │ │ │ │ ├── cosine_embedding.rs │ │ │ │ ├── cross_entropy.rs │ │ │ │ ├── ctc.rs │ │ │ │ ├── huber.rs │ │ │ │ ├── kldiv.rs │ │ │ │ ├── lp_loss.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── mse.rs │ │ │ │ ├── poisson.rs │ │ │ │ ├── pretrained/ │ │ │ │ │ ├── gram_matrix/ │ │ │ │ │ │ ├── gram_matrix_loss.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ ├── vgg19.rs │ │ │ │ │ │ └── weights.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── reduction.rs │ │ │ │ ├── rnnt.rs │ │ │ │ └── smooth_l1.rs │ │ │ ├── modules/ │ │ │ │ ├── attention/ │ │ │ │ │ ├── cross_attention.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── mha.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── cache/ │ │ │ │ │ ├── autoregressive.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── conv/ │ │ │ │ │ ├── checks.rs │ │ │ │ │ ├── conv1d.rs │ │ │ │ │ ├── conv2d.rs │ │ │ │ │ ├── conv3d.rs │ │ │ │ │ ├── conv_transpose1d.rs │ │ │ │ │ ├── conv_transpose2d.rs │ │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ │ ├── deform_conv2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── dropout.rs │ │ │ │ ├── embedding.rs │ │ │ │ ├── interpolate/ │ │ │ │ │ ├── interpolate1d.rs │ │ │ │ │ ├── interpolate2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── linear.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── noise.rs │ │ │ │ ├── norm/ │ │ │ │ │ ├── batch.rs │ │ │ │ │ ├── group.rs │ │ │ │ │ ├── instance.rs │ │ │ │ │ ├── layer.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── normalization_wrapper.rs │ │ │ │ │ └── rms.rs │ │ │ │ ├── pool/ │ │ │ │ │ ├── adaptive_avg_pool1d.rs │ │ │ │ │ ├── adaptive_avg_pool2d.rs │ │ │ │ │ ├── avg_pool1d.rs │ │ │ │ │ ├── avg_pool2d.rs │ │ │ │ │ ├── max_pool1d.rs │ │ │ │ │ ├── max_pool2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── pos_encoding.rs │ │ │ │ ├── rnn/ │ │ │ │ │ ├── basic.rs │ │ │ │ │ ├── gate_controller.rs │ │ │ │ │ ├── gru.rs │ │ │ │ │ ├── lstm.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── rope_encoding.rs │ │ │ │ ├── transformer/ │ │ │ │ │ ├── decoder.rs │ │ │ │ │ ├── encoder.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── pwff.rs │ │ │ │ └── unfold.rs │ │ │ └── padding.rs │ │ └── tests/ │ │ └── quantize.rs │ ├── burn-no-std-tests/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── burnpack.rs │ │ │ ├── conv.rs │ │ │ ├── lib.rs │ │ │ ├── mlp.rs │ │ │ ├── model.rs │ │ │ └── safetensors.rs │ │ └── tests/ │ │ ├── burnpack_tests.rs │ │ ├── safetensors_tests.rs │ │ └── test_integration.rs │ ├── burn-optim/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── grad_clipping/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── lr_scheduler/ │ │ │ ├── base.rs │ │ │ ├── composed.rs │ │ │ ├── constant.rs │ │ │ ├── cosine.rs │ │ │ ├── exponential.rs │ │ │ ├── linear.rs │ │ │ ├── mod.rs │ │ │ ├── noam.rs │ │ │ └── step.rs │ │ └── optim/ │ │ ├── adagrad.rs │ │ ├── adam.rs │ │ ├── adamw.rs │ │ ├── base.rs │ │ ├── decay.rs │ │ ├── grad_accum.rs │ │ ├── grads.rs │ │ ├── lbfgs.rs │ │ ├── mod.rs │ │ ├── momentum.rs │ │ ├── muon.rs │ │ ├── rmsprop.rs │ │ ├── sgd.rs │ │ ├── simple/ │ │ │ ├── adaptor.rs │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── record/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── v1.rs │ │ └── visitor.rs │ ├── burn-remote/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── client/ │ │ │ ├── base.rs │ │ │ ├── channel.rs │ │ │ ├── mod.rs │ │ │ ├── runner.rs │ │ │ └── worker.rs │ │ ├── lib.rs │ │ ├── server/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── processor.rs │ │ │ ├── session.rs │ │ │ └── stream.rs │ │ └── shared/ │ │ ├── mod.rs │ │ └── task.rs │ ├── burn-rl/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── environment/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── policy/ │ │ │ ├── async_policy.rs │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ └── transition_buffer/ │ │ ├── base.rs │ │ ├── mod.rs │ │ └── slice_access.rs │ ├── burn-rocm/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-router/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── bridge/ │ │ │ ├── base.rs │ │ │ ├── byte.rs │ │ │ └── mod.rs │ │ ├── channel/ │ │ │ ├── base.rs │ │ │ ├── direct.rs │ │ │ └── mod.rs │ │ ├── client/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── binary.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── unary.rs │ │ ├── runner.rs │ │ ├── tensor.rs │ │ └── types.rs │ ├── burn-std/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── id.rs │ │ ├── lib.rs │ │ ├── network.rs │ │ └── tensor/ │ │ ├── dtype.rs │ │ ├── mod.rs │ │ ├── quantization.rs │ │ ├── shape.rs │ │ └── slice.rs │ ├── burn-store/ │ │ ├── Cargo.toml │ │ ├── MIGRATION.md │ │ ├── README.md │ │ ├── benches/ │ │ │ ├── download_resnet18.py │ │ │ ├── generate_unified_models.py │ │ │ ├── resnet18_loading.rs │ │ │ ├── unified_loading.rs │ │ │ ├── unified_saving.rs │ │ │ └── zero_copy_loading.rs │ │ ├── examples/ │ │ │ ├── burnpack_inspect.rs │ │ │ └── half_precision.rs │ │ ├── pytorch-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── src/ │ │ │ │ └── lib.rs │ │ │ └── tests/ │ │ │ ├── backend.rs │ │ │ ├── batch_norm/ │ │ │ │ ├── batch_norm2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── boolean/ │ │ │ │ ├── boolean.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── buffer/ │ │ │ │ ├── buffer.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── complex_nested/ │ │ │ │ ├── complex_nested.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── config/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── mod.rs │ │ │ │ └── weights_with_config.pt │ │ │ ├── conv1d/ │ │ │ │ ├── conv1d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv2d/ │ │ │ │ ├── conv2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv_transpose1d/ │ │ │ │ ├── conv_transpose1d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv_transpose2d/ │ │ │ │ ├── conv_transpose2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── debug_test.pt │ │ │ ├── embedding/ │ │ │ │ ├── embedding.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── enum_module/ │ │ │ │ ├── enum_depthwise_false.pt │ │ │ │ ├── enum_depthwise_true.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── group_norm/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── group_norm.pt │ │ │ │ └── mod.rs │ │ │ ├── integer/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── integer.pt │ │ │ │ └── mod.rs │ │ │ ├── key_remap/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── key_remap.pt │ │ │ │ └── mod.rs │ │ │ ├── key_remap_chained/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── key_remap.pt │ │ │ │ └── mod.rs │ │ │ ├── layer_norm/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── layer_norm.pt │ │ │ │ └── mod.rs │ │ │ ├── linear/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── linear.pt │ │ │ │ ├── linear_with_bias.pt │ │ │ │ └── mod.rs │ │ │ ├── missing_module_field/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── missing_module_field.pt │ │ │ │ └── mod.rs │ │ │ ├── non_contiguous_indexes/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── mod.rs │ │ │ │ └── non_contiguous_indexes.pt │ │ │ ├── test_int.pt │ │ │ ├── test_mod.rs │ │ │ └── top_level_key/ │ │ │ ├── export_weights.py │ │ │ ├── mod.rs │ │ │ └── top_level_key.pt │ │ ├── safetensors-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── src/ │ │ │ │ └── lib.rs │ │ │ └── tests/ │ │ │ ├── backend.rs │ │ │ ├── multi_layer/ │ │ │ │ ├── mod.rs │ │ │ │ ├── multi_layer.py │ │ │ │ └── multi_layer.safetensors │ │ │ └── test_mod.rs │ │ └── src/ │ │ ├── adapter.rs │ │ ├── applier.rs │ │ ├── apply_result.rs │ │ ├── burnpack/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── reader.rs │ │ │ ├── store.rs │ │ │ ├── tests/ │ │ │ │ ├── alignment.rs │ │ │ │ ├── edge_cases.rs │ │ │ │ ├── header.rs │ │ │ │ ├── helpers.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── reader.rs │ │ │ │ ├── round_trip.rs │ │ │ │ ├── store.rs │ │ │ │ ├── writer.rs │ │ │ │ └── zero_copy.rs │ │ │ └── writer.rs │ │ ├── collector.rs │ │ ├── filter.rs │ │ ├── keyremapper.rs │ │ ├── lib.rs │ │ ├── pytorch/ │ │ │ ├── lazy_data.rs │ │ │ ├── mod.rs │ │ │ ├── pickle_reader.rs │ │ │ ├── reader.rs │ │ │ ├── store.rs │ │ │ └── tests/ │ │ │ ├── mod.rs │ │ │ ├── reader/ │ │ │ │ ├── create_legacy_with_offsets.py │ │ │ │ ├── create_tar_format.py │ │ │ │ ├── mod.rs │ │ │ │ ├── simple_legacy.py │ │ │ │ ├── test_data/ │ │ │ │ │ ├── bfloat16.pt │ │ │ │ │ ├── bool.pt │ │ │ │ │ ├── broken.pt │ │ │ │ │ ├── buffers.pt │ │ │ │ │ ├── checkpoint.pt │ │ │ │ │ ├── complex_structure.pt │ │ │ │ │ ├── empty.pt │ │ │ │ │ ├── extreme_values.pt │ │ │ │ │ ├── float16.pt │ │ │ │ │ ├── float32.pt │ │ │ │ │ ├── float64.pt │ │ │ │ │ ├── int16.pt │ │ │ │ │ ├── int32.pt │ │ │ │ │ ├── int64.pt │ │ │ │ │ ├── int8.pt │ │ │ │ │ ├── large_shape.pt │ │ │ │ │ ├── legacy_shared_storage.pt │ │ │ │ │ ├── legacy_with_offsets.pt │ │ │ │ │ ├── mixed_types.pt │ │ │ │ │ ├── nested_dict.pt │ │ │ │ │ ├── parameter.pt │ │ │ │ │ ├── scalar.pt │ │ │ │ │ ├── simple_legacy.pt │ │ │ │ │ ├── special_values.pt │ │ │ │ │ ├── state_dict.pt │ │ │ │ │ ├── tensor_2d.pt │ │ │ │ │ ├── tensor_3d.pt │ │ │ │ │ ├── tensor_4d.pt │ │ │ │ │ └── uint8.pt │ │ │ │ └── test_data.py │ │ │ └── store/ │ │ │ ├── mod.rs │ │ │ └── test_data/ │ │ │ ├── generate_enum_test.py │ │ │ └── model_without_enum_variants.pt │ │ ├── safetensors/ │ │ │ ├── mod.rs │ │ │ ├── store.rs │ │ │ └── tests/ │ │ │ ├── adapter.rs │ │ │ ├── direct_access.rs │ │ │ ├── error_handling.rs │ │ │ ├── file_io.rs │ │ │ ├── filtering.rs │ │ │ ├── integration.rs │ │ │ ├── metadata.rs │ │ │ ├── mixed_datatypes.rs │ │ │ ├── mod.rs │ │ │ ├── multi_layer_verify.rs │ │ │ ├── pytorch_import.rs │ │ │ └── round_trip.rs │ │ ├── tensor_snapshot.rs │ │ └── traits.rs │ ├── burn-tch/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── bin/ │ │ │ ├── cpu.rs │ │ │ ├── cuda.rs │ │ │ └── mps.rs │ │ ├── cuda_hack/ │ │ │ ├── dummy_cuda_dependency.cpp │ │ │ └── fake_cuda_dependency.cpp │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ └── tensor.rs │ ├── burn-tensor/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── device.rs │ │ ├── lib.rs │ │ └── tensor/ │ │ ├── activation/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── api/ │ │ │ ├── autodiff.rs │ │ │ ├── base.rs │ │ │ ├── bool.rs │ │ │ ├── cartesian_grid.rs │ │ │ ├── check.rs │ │ │ ├── float.rs │ │ │ ├── fmod.rs │ │ │ ├── int.rs │ │ │ ├── mod.rs │ │ │ ├── numeric.rs │ │ │ ├── options.rs │ │ │ ├── orderable.rs │ │ │ ├── pad.rs │ │ │ ├── take.rs │ │ │ ├── transaction.rs │ │ │ └── trunc.rs │ │ ├── grid/ │ │ │ ├── affine_grid.rs │ │ │ ├── meshgrid.rs │ │ │ └── mod.rs │ │ ├── linalg/ │ │ │ ├── cosine_similarity.rs │ │ │ ├── diag.rs │ │ │ ├── lu_decomposition.rs │ │ │ ├── matvec.rs │ │ │ ├── mod.rs │ │ │ ├── outer.rs │ │ │ ├── trace.rs │ │ │ └── vector_norm.rs │ │ ├── loss/ │ │ │ └── mod.rs │ │ ├── mod.rs │ │ ├── module.rs │ │ ├── quantization.rs │ │ ├── report.rs │ │ └── stats/ │ │ └── mod.rs │ ├── burn-tensor-testgen/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-train/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── checkpoint/ │ │ │ ├── async_checkpoint.rs │ │ │ ├── base.rs │ │ │ ├── file.rs │ │ │ ├── mod.rs │ │ │ └── strategy/ │ │ │ ├── base.rs │ │ │ ├── composed.rs │ │ │ ├── lastn.rs │ │ │ ├── metric.rs │ │ │ └── mod.rs │ │ ├── components.rs │ │ ├── evaluator/ │ │ │ ├── base.rs │ │ │ ├── builder.rs │ │ │ ├── components.rs │ │ │ └── mod.rs │ │ ├── learner/ │ │ │ ├── application_logger.rs │ │ │ ├── base.rs │ │ │ ├── classification.rs │ │ │ ├── early_stopping.rs │ │ │ ├── mod.rs │ │ │ ├── regression.rs │ │ │ ├── rl/ │ │ │ │ ├── checkpointer.rs │ │ │ │ ├── components.rs │ │ │ │ ├── env_runner/ │ │ │ │ │ ├── async_runner.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── off_policy.rs │ │ │ │ ├── output.rs │ │ │ │ ├── paradigm.rs │ │ │ │ └── strategy.rs │ │ │ ├── sequence.rs │ │ │ ├── summary.rs │ │ │ ├── supervised/ │ │ │ │ ├── mod.rs │ │ │ │ ├── paradigm.rs │ │ │ │ ├── step/ │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── train.rs │ │ │ │ └── strategies/ │ │ │ │ ├── base.rs │ │ │ │ ├── ddp/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── epoch.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── strategy.rs │ │ │ │ │ └── worker.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── multi/ │ │ │ │ │ ├── epoch.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── strategy.rs │ │ │ │ └── single/ │ │ │ │ ├── epoch.rs │ │ │ │ ├── mod.rs │ │ │ │ └── strategy.rs │ │ │ └── train_val.rs │ │ ├── lib.rs │ │ ├── logger/ │ │ │ ├── async_logger.rs │ │ │ ├── base.rs │ │ │ ├── file.rs │ │ │ ├── in_memory.rs │ │ │ ├── metric.rs │ │ │ └── mod.rs │ │ ├── metric/ │ │ │ ├── acc.rs │ │ │ ├── auroc.rs │ │ │ ├── base.rs │ │ │ ├── cer.rs │ │ │ ├── classification.rs │ │ │ ├── confusion_stats.rs │ │ │ ├── cpu_temp.rs │ │ │ ├── cpu_use.rs │ │ │ ├── cuda.rs │ │ │ ├── fbetascore.rs │ │ │ ├── hamming.rs │ │ │ ├── iteration.rs │ │ │ ├── learning_rate.rs │ │ │ ├── loss.rs │ │ │ ├── memory_use.rs │ │ │ ├── mod.rs │ │ │ ├── perplexity.rs │ │ │ ├── precision.rs │ │ │ ├── processor/ │ │ │ │ ├── async_wrapper.rs │ │ │ │ ├── base.rs │ │ │ │ ├── full.rs │ │ │ │ ├── metrics.rs │ │ │ │ ├── minimal.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── rl_metrics.rs │ │ │ │ └── rl_processor.rs │ │ │ ├── recall.rs │ │ │ ├── rl/ │ │ │ │ ├── cum_reward.rs │ │ │ │ ├── ep_len.rs │ │ │ │ ├── exploration_rate.rs │ │ │ │ └── mod.rs │ │ │ ├── state.rs │ │ │ ├── store/ │ │ │ │ ├── aggregate.rs │ │ │ │ ├── base.rs │ │ │ │ ├── client.rs │ │ │ │ ├── log.rs │ │ │ │ └── mod.rs │ │ │ ├── top_k_acc.rs │ │ │ ├── vision/ │ │ │ │ ├── dice.rs │ │ │ │ ├── dists/ │ │ │ │ │ ├── l2pool.rs │ │ │ │ │ ├── metric.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── vgg16_l2pool.rs │ │ │ │ │ └── weights.rs │ │ │ │ ├── lpips/ │ │ │ │ │ ├── alexnet.rs │ │ │ │ │ ├── metric.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── squeezenet.rs │ │ │ │ │ ├── vgg.rs │ │ │ │ │ └── weights.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ms_ssim.rs │ │ │ │ ├── psnr.rs │ │ │ │ └── ssim.rs │ │ │ └── wer.rs │ │ └── renderer/ │ │ ├── base.rs │ │ ├── cli.rs │ │ ├── mod.rs │ │ └── tui/ │ │ ├── base.rs │ │ ├── controls.rs │ │ ├── full_history.rs │ │ ├── metric_numeric.rs │ │ ├── metric_text.rs │ │ ├── mod.rs │ │ ├── plot_utils.rs │ │ ├── popup.rs │ │ ├── progress.rs │ │ ├── recent_history.rs │ │ ├── renderer.rs │ │ └── status.rs │ ├── burn-vision/ │ │ ├── Cargo.toml │ │ ├── src/ │ │ │ ├── backends/ │ │ │ │ ├── cpu/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── connected_components/ │ │ │ │ │ │ ├── spaghetti/ │ │ │ │ │ │ │ ├── Spaghetti_center_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_first_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_forest_labels.rs │ │ │ │ │ │ │ ├── Spaghetti_last_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_single_line_forest_code.rs │ │ │ │ │ │ │ └── mod.rs │ │ │ │ │ │ └── spaghetti_4c/ │ │ │ │ │ │ ├── Spaghetti4C_center_line_forest_code.rs │ │ │ │ │ │ ├── Spaghetti4C_first_line_forest_code.rs │ │ │ │ │ │ ├── Spaghetti4C_forest_labels.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── connected_components.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── morphology/ │ │ │ │ │ │ ├── filter.rs │ │ │ │ │ │ ├── filter_engine.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── nms.rs │ │ │ │ │ └── ops.rs │ │ │ │ ├── cube/ │ │ │ │ │ ├── connected_components/ │ │ │ │ │ │ ├── hardware_accelerated.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ └── prefix_sum.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── ops.rs │ │ │ │ └── mod.rs │ │ │ ├── base.rs │ │ │ ├── lib.rs │ │ │ ├── ops/ │ │ │ │ ├── base.rs │ │ │ │ └── mod.rs │ │ │ ├── tensor.rs │ │ │ ├── tests/ │ │ │ │ └── mod.rs │ │ │ ├── transform/ │ │ │ │ ├── mod.rs │ │ │ │ └── transform2d.rs │ │ │ └── utils/ │ │ │ ├── mod.rs │ │ │ └── save.rs │ │ └── tests/ │ │ ├── common/ │ │ │ └── mod.rs │ │ ├── connected_components.rs │ │ ├── morphology.rs │ │ └── nms.rs │ └── burn-wgpu/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ └── lib.rs ├── deny.toml ├── docs/ │ └── katex-header.html ├── examples/ │ ├── custom-csv-dataset/ │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── custom-csv-dataset.rs │ │ │ └── dataframe-dataset.rs │ │ └── src/ │ │ ├── dataframe_dataset.rs │ │ ├── dataset.rs │ │ ├── diabetes_patient.rs │ │ ├── lib.rs │ │ └── utils.rs │ ├── custom-cubecl-kernel/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-cubecl-kernel.rs │ │ └── src/ │ │ ├── backward.rs │ │ ├── forward.rs │ │ ├── kernel.rs │ │ └── lib.rs │ ├── custom-image-dataset/ │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── custom-image-dataset.rs │ │ └── src/ │ │ ├── data.rs │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── custom-learning-strategy/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-learning-strategy.rs │ │ └── src/ │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── custom-renderer/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-renderer.rs │ │ └── src/ │ │ └── lib.rs │ ├── custom-training-loop/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-training-loop.rs │ │ └── src/ │ │ └── lib.rs │ ├── custom-wgpu-kernel/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-wgpu-kernel.rs │ │ └── src/ │ │ ├── backward.rs │ │ ├── forward.rs │ │ ├── kernel.wgsl │ │ └── lib.rs │ ├── dop_timer/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── event_utils.rs │ │ ├── main.rs │ │ ├── parsers.rs │ │ └── workers.rs │ ├── dqn-agent/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── dqn-agent.rs │ │ └── src/ │ │ ├── agent.rs │ │ ├── env.rs │ │ ├── lib.rs │ │ ├── training.rs │ │ └── utils.rs │ ├── guide/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── guide.rs │ │ └── src/ │ │ ├── bin/ │ │ │ ├── infer.rs │ │ │ ├── print.rs │ │ │ └── train.rs │ │ ├── data.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── import-model-weights/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── bin/ │ │ │ │ ├── burnpack.rs │ │ │ │ ├── convert.rs │ │ │ │ ├── pytorch.rs │ │ │ │ └── safetensors.rs │ │ │ ├── inference.rs │ │ │ ├── lib.rs │ │ │ └── model.rs │ │ └── weights/ │ │ ├── mnist.pt │ │ ├── mnist.safetensors │ │ └── mnist_train_export.py │ ├── mnist/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ └── mnist.rs │ │ └── src/ │ │ ├── data.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── mnist-inference-web/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build-for-web.sh │ │ ├── index.html │ │ ├── index.js │ │ ├── run-server.sh │ │ └── src/ │ │ ├── lib.rs │ │ ├── model.rs │ │ ├── state.rs │ │ └── web.rs │ ├── modern-lstm/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── lstm-infer.rs │ │ │ └── lstm-train.rs │ │ └── src/ │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── multi-gpus/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── multi-gpus.rs │ │ └── src/ │ │ └── lib.rs │ ├── notebook/ │ │ ├── README.md │ │ ├── autodiff.ipynb │ │ └── basic-tensor-op.ipynb │ ├── server/ │ │ ├── Cargo.toml │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ └── server.rs │ │ └── src/ │ │ └── lib.rs │ ├── simple-regression/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── regression.rs │ │ └── src/ │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── text-classification/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ ├── ag-news-infer.rs │ │ │ ├── ag-news-train.rs │ │ │ ├── db-pedia-infer.rs │ │ │ └── db-pedia-train.rs │ │ └── src/ │ │ ├── data/ │ │ │ ├── batcher.rs │ │ │ ├── dataset.rs │ │ │ ├── mod.rs │ │ │ └── tokenizer.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── text-generation/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── text-generation.rs │ │ └── src/ │ │ ├── data/ │ │ │ ├── batcher.rs │ │ │ ├── dataset.rs │ │ │ ├── mod.rs │ │ │ └── tokenizer.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ └── wgan/ │ ├── Cargo.toml │ ├── README.md │ ├── examples/ │ │ ├── wgan-generate.rs │ │ └── wgan-mnist.rs │ └── src/ │ ├── dataset.rs │ ├── infer.rs │ ├── lib.rs │ ├── model.rs │ └── training.rs ├── rustfmt.toml └── xtask/ ├── Cargo.toml └── src/ ├── commands/ │ ├── books.rs │ ├── build.rs │ ├── doc.rs │ ├── mod.rs │ ├── test.rs │ └── validate.rs └── main.rs ================================================ FILE CONTENTS ================================================ ================================================ FILE: .cargo/audit.toml ================================================ # Audit config file # # It may be located in the user home (`~/.cargo/audit.toml`) or in the project # root (`.cargo/audit.toml`). # # All of the options which can be passed via CLI arguments can also be # permanently specified in this file. [advisories] ignore = [ "RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point. "RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples. "RUSTSEC-2025-0141", # `bincode` is no longer maintained. "RUSTSEC-2024-0388", # `derivative` dependancy in the DQN example is unmaintained. ] # advisory IDs to ignore e.g. ["RUSTSEC-2019-0001", ...] informational_warnings = [ "unmaintained", ] # warn for categories of informational advisories severity_threshold = "low" # CVSS severity ("none", "low", "medium", "high", "critical") # Output Configuration [output] deny = ["unmaintained"] # exit on error if unmaintained dependencies are found format = "terminal" # "terminal" (human readable report) or "json" quiet = false # Only print information on error show_tree = true # Show inverse dependency trees along with advisories (default: true) [yanked] enabled = true # Warn for yanked crates in Cargo.lock (default: true) update_index = true # Auto-update the crates.io index (default: true) ================================================ FILE: .cargo/config.toml ================================================ [alias] xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --" run-checks = "xtask -c all validate --release" ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** **To Reproduce** **Expected behavior** **Screenshots** **Desktop (please complete the following information):** - OS: [e.g. iOS] - Browser [e.g. chrome, safari] - Version [e.g. 22] **Smartphone (please complete the following information):** - Device: [e.g. iPhone6] - OS: [e.g. iOS8.1] - Browser [e.g. stock browser, safari] - Version [e.g. 22] **Additional context** ================================================ FILE: .github/ISSUE_TEMPLATE/doc_request.md ================================================ --- name: Documentation request about: Flag incoherent or missing documentation, including use case examples. title: '' labels: '' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- ### Feature description ### Feature motivation ### (Optional) Suggest a Solution ================================================ FILE: .github/PULL_REQUEST_TEMPLATE/template.md ================================================ * **Please check if the PR fulfills these requirements** - [ ] The commit message follows our guidelines - [ ] Docs have been added / updated (for bug fixes / features) * **What kind of change does this PR introduce?** (Bug fix, feature, docs update, ...) * **Does this PR introduce a breaking change?** (What changes might users need to make in their application due to this PR?) * **Other information**: ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" ignore: - dependency-name: "tracel-ai/github-actions*" - package-ecosystem: "cargo" directories: - "/" - "crates/burn" - "crates/burn-*" - "crates/burn-import/*-tests" - "examples/*" - "xtask" schedule: interval: "weekly" ================================================ FILE: .github/pull_request_template.md ================================================ ## Pull Request Template ### Checklist - [ ] Confirmed that `cargo run-checks` command has been executed. - [ ] Made sure the book is up to date with changes in this PR. ### Related Issues/PRs _Provide links to relevant issues and dependent PRs._ ### Changes _Summarize the problem being addressed and your solution._ ### Testing _Describe how these changes have been tested._ ================================================ FILE: .github/workflows/combine-dependabot-prs.yml ================================================ name: Combine Dependabot PRs on: schedule: - cron: '0 6 * * MON' # Monday at 6:00am UTC workflow_dispatch: permissions: contents: write pull-requests: write checks: read jobs: combine-prs: runs-on: ubuntu-latest steps: - name: combine-prs id: combine-prs uses: github/combine-prs@v5.2.0 with: labels: dependencies,automated ================================================ FILE: .github/workflows/dependencies.yml ================================================ name: dependencies on: schedule: - cron: '0 21 * * TUE' # Run every Tuesday at 21:00 (UTC) push: tags: - 'v*.*.*' # Run when a new version is being published env: UDEPS_VERSION: "0.1.57" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: dependencies: runs-on: ubuntu-latest strategy: matrix: checks: - licenses - bans sources continue-on-error: ${{ matrix.checks == 'licenses' }} # failed licenses don't abort steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Audit Rust dependencies # If a vulnerability is found, a new issue will automatically be opened # since this action runs on main branch uses: actions-rust-lang/audit@v1 # -------------------------------------------------------------------------------- - name: Detect multiple versions of the same crate uses: EmbarkStudios/cargo-deny-action@v2 with: command: check ${{ matrix.checks }} # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt # -------------------------------------------------------------------------------- - name: Install cargo-udeps env: UDEPS_LINK: https://github.com/est31/cargo-udeps/releases/download run: | curl -L "$UDEPS_LINK/v$UDEPS_VERSION/cargo-udeps-v$UDEPS_VERSION-x86_64-unknown-linux-gnu.tar.gz" | tar xz -C $HOME/.cargo/bin --strip-components 2 # -------------------------------------------------------------------------------- - name: Run cargo-udeps run: | cargo +nightly udeps --all-targets ================================================ FILE: .github/workflows/publish.yml ================================================ name: publish on: push: tags: - "v*" workflow_dispatch: inputs: dry-run-only: description: "Run xtask publish in dry-run mode (no publish)" type: boolean required: false default: false jobs: publish-burn-rl: needs: - publish-burn-core - publish-burn-optim # dev dependencies - publish-burn-ndarray uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-rl dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-vision: needs: - publish-burn-autodiff - publish-burn-candle - publish-burn-fusion - publish-burn-cubecl-fusion - publish-burn-cubecl - publish-burn-ndarray - publish-burn-tch - publish-burn-tensor - publish-burn-ir - publish-burn-tensor-testgen # dev dependencies - publish-burn-wgpu - publish-burn-cuda uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-vision dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-router: needs: - publish-burn-ir - publish-burn-std - publish-burn-tensor # dev dependencies - publish-burn-autodiff - publish-burn-ndarray - publish-burn-wgpu uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-router dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-remote: needs: - publish-burn-ir - publish-burn-std - publish-burn-tensor - publish-burn-router uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-remote dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-derive: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-derive dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-dataset: needs: - publish-burn-std uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-dataset dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-std: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-std dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-tensor-testgen: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-tensor-testgen dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-tensor: needs: - publish-burn-tensor-testgen - publish-burn-std - publish-burn-backend uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-tensor dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-backend: needs: - publish-burn-std uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-backend dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-ir: needs: - publish-burn-tensor uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-ir dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-fusion: needs: - publish-burn-ir - publish-burn-tensor - publish-burn-std uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-fusion dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-cubecl-fusion: needs: - publish-burn-ir - publish-burn-std - publish-burn-fusion - publish-burn-tensor uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-cubecl-fusion dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-cubecl: needs: - publish-burn-ir - publish-burn-std - publish-burn-fusion - publish-burn-cubecl-fusion - publish-burn-tensor - publish-burn-ndarray uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-cubecl dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-autodiff: needs: - publish-burn-tensor - publish-burn-tensor-testgen - publish-burn-std uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-autodiff dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-tch: needs: - publish-burn-tensor - publish-burn-autodiff uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-tch dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-ndarray: needs: - publish-burn-ir - publish-burn-tensor - publish-burn-autodiff - publish-burn-std uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-ndarray dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-wgpu: needs: - publish-burn-tensor - publish-burn-autodiff - publish-burn-ndarray - publish-burn-std - publish-burn-cubecl uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-wgpu dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-cpu: needs: - publish-burn-tensor - publish-burn-fusion - publish-burn-cubecl uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-cpu dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-cuda: needs: - publish-burn-tensor - publish-burn-autodiff - publish-burn-ndarray - publish-burn-std - publish-burn-cubecl uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-cuda dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-rocm: needs: - publish-burn-tensor - publish-burn-autodiff - publish-burn-ndarray - publish-burn-std - publish-burn-cubecl uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-rocm dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-candle: needs: - publish-burn-tensor - publish-burn-autodiff - publish-burn-tch uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-candle dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-collective: needs: - publish-burn-std - publish-burn-tensor - publish-burn-communication # dev dependencies - publish-burn-wgpu - publish-burn-ndarray - publish-burn-cuda uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-collective dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-communication: needs: - publish-burn-std - publish-burn-tensor uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-communication dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-core: needs: - publish-burn-dataset - publish-burn-std - publish-burn-derive - publish-burn-tensor - publish-burn-vision # dev dependencies - publish-burn-autodiff - publish-burn-wgpu - publish-burn-tch - publish-burn-cuda - publish-burn-ndarray - publish-burn-candle - publish-burn-remote uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-core dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-nn: needs: - publish-burn-core # dev dependencies - publish-burn-autodiff - publish-burn-wgpu - publish-burn-tch - publish-burn-ndarray - publish-burn-candle - publish-burn-remote uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-nn dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-optim: needs: - publish-burn-core - publish-burn-collective # dev dependencies - publish-burn-autodiff - publish-burn-wgpu - publish-burn-tch - publish-burn-ndarray - publish-burn-candle - publish-burn-remote - publish-burn-nn uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-optim dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-train: needs: - publish-burn-core - publish-burn-optim - publish-burn-collective - publish-burn-rl - publish-burn-ndarray uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-train dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-dispatch: needs: - publish-burn-std - publish-burn-backend - publish-burn-autodiff - publish-burn-cpu - publish-burn-cuda - publish-burn-rocm - publish-burn-wgpu - publish-burn-ndarray - publish-burn-tch uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-dispatch dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn: needs: - publish-burn-core - publish-burn-nn - publish-burn-optim - publish-burn-collective - publish-burn-store - publish-burn-train - publish-burn-cpu - publish-burn-dispatch uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-burn-store: needs: - publish-burn-core - publish-burn-nn - publish-burn-tensor uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9 with: crate: burn-store dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} secrets: CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} ================================================ FILE: .github/workflows/stale-pr.yml ================================================ name: Stale Pull Requests on: schedule: - cron: '0 12 * * *' # Run every day at 12:00 (UTC) # The minimum permissions required to run this Action permissions: contents: write # only for delete-branch option issues: write pull-requests: write jobs: stale-pr: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Stale pull requests uses: actions/stale@v10 with: # The idle number of days before marking issues stale. # # With a negative number like -1, no issues # will be marked as stale automatically. days-before-issue-stale: -1 # The idle number of days before marking pull requests stale days-before-pr-stale: 30 # The idle number of days before closing # the stale pull requests (due to the stale label). # # With a negative number like -1, the pull requests # will never be closed automatically. days-before-pr-close: -1 # Label to apply on staled pull requests stale-pr-label: 'stale' # The message that will be added as a comment to the pull request stale-pr-message: 'This PR has been marked as stale because it has not been updated for over a month' # Remove `stale` label from pull requests on updates/comments remove-pr-stale-when-updated: true ================================================ FILE: .github/workflows/test-gpu.yml ================================================ name: CI GPU on: workflow_dispatch: inputs: pr_number: description: "Number of the pull request that triggers this run if any" type: number required: false # important to set the run name to this format so that the CI server # can track the PR number from the workflow_run events. run-name: ${{ github.workflow }}:${{ github.repository }}#${{ inputs.pr_number }} env: # Note: It is not possible to define top level env vars and pass them to composite actions. # To work around this issue we use inputs and define all the env vars here. RUST_PREVIOUS_VERSION: 1.92.0 # Dependency versioning # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml # GCP runners GCP_RUNNERS_IMAGE_FAMILY: "tracel-ci-ubuntu-2404-amd64-nvidia" GCP_RUNNERS_MACHINE_TYPE: "g2-standard-4" GCP_RUNNERS_ZONE: "us-east1-c" # Test in release mode (make it an empty string to test in debug mode) TEST_RELEASE_FLAG: "--release" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: prepare-checks: runs-on: ubuntu-latest outputs: rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }} gcp_runners_image_family: ${{ env.GCP_RUNNERS_IMAGE_FAMILY }} gcp_runners_machine_type: ${{ env.GCP_RUNNERS_MACHINE_TYPE }} gcp_runners_zone: ${{ env.GCP_RUNNERS_ZONE }} steps: - name: Do Nothing if: false run: echo linux-std-cuda-tests: needs: [prepare-checks] timeout-minutes: 60 # '@id:' label must be unique within this worklow runs-on: [ "@id:burn-cuda-job-${{github.run_id}}-${{github.run_attempt}}", "@pr_number:${{ inputs.pr_number }}", "@organization:tracel-ai", "@repository:burn", "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}", "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}", "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}", "@gpu:true", ] env: LD_LIBRARY_PATH: "/usr/local/cuda/lib64" # disable incremental compilation (reduces artifact size) CARGO_PROFILE_TEST_INCREMENTAL: "false" # Keep the stragegy to be able to easily add new rust versions if required strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} enable-cache: false # -------------------------------------------------------------------------------- - name: Tests (burn-cuda) run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-cuda-runner linux-std-vulkan-tests: needs: [prepare-checks] timeout-minutes: 60 # '@id:' label must be unique within this worklow runs-on: [ "@id:burn-vulkan-job-${{github.run_id}}-${{github.run_attempt}}", "@pr_number:${{ inputs.pr_number }}", "@organization:tracel-ai", "@repository:burn", "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}", "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}", "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}", "@gpu:true", ] env: # disable incremental compilation (reduces artifact size) CARGO_PROFILE_TEST_INCREMENTAL: "false" # Keep the stragegy to be able to easily add new rust versions if required strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} enable-cache: false # -------------------------------------------------------------------------------- - name: Tests (burn-vulkan) run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-vulkan-runner linux-std-wgpu-tests: needs: [prepare-checks] timeout-minutes: 60 # '@id:' label must be unique within this worklow runs-on: [ "@id:burn-wgpu-job-${{github.run_id}}-${{github.run_attempt}}", "@pr_number:${{ inputs.pr_number }}", "@organization:tracel-ai", "@repository:burn", "@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}", "@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}", "@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}", "@gpu:true", ] env: # disable incremental compilation (reduces artifact size) CARGO_PROFILE_TEST_INCREMENTAL: "false" # Keep the stragegy to be able to easily add new rust versions if required strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} enable-cache: false # -------------------------------------------------------------------------------- - name: Tests (burn-wgpu) run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-wgpu-runner ================================================ FILE: .github/workflows/test.yml ================================================ name: CI on: push: branches: - main paths: - "Cargo.lock" - "**.rs" - "**.sh" - "**.ps1" - "**.yml" - "**.toml" - "!**.md" - "!LICENSE-APACHE" - "!LICENSE-MIT" pull_request: types: [opened, synchronize] paths: - "Cargo.lock" - "**.rs" - "**.sh" - "**.ps1" - "**.yml" - "**.toml" - "!**.md" - "!LICENSE-APACHE" - "!LICENSE-MIT" env: # Note: It is not possible to define top level env vars and pass them to composite actions. # To work around this issue we use inputs and define all the env vars here. RUST_PREVIOUS_VERSION: 1.92.0 # Dependency versioning # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml # Mozilla Grcov GRCOV_LINK: "https://github.com/mozilla/grcov/releases/download" GRCOV_VERSION: "0.8.19" # Test in release mode (make it an empty string to test in debug mode) TEST_RELEASE_FLAG: "--release" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: prepare-checks: runs-on: ubuntu-latest outputs: rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }} steps: - name: Do Nothing if: false run: echo code-quality: runs-on: ubuntu-22.04 needs: prepare-checks strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux # -------------------------------------------------------------------------------- - name: Audit run: cargo xtask check audit # -------------------------------------------------------------------------------- - name: Format shell: bash env: # work around for colors # see: https://github.com/rust-lang/rustfmt/issues/3385 TERM: xterm-256color run: cargo xtask check format # -------------------------------------------------------------------------------- - name: Lint run: cargo xtask check lint # -------------------------------------------------------------------------------- - name: Typos uses: tracel-ai/github-actions/check-typos@v9 documentation: runs-on: ubuntu-22.04 needs: prepare-checks strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux # -------------------------------------------------------------------------------- - name: Documentation Build run: cargo xtask doc build # -------------------------------------------------------------------------------- - name: Documentation Tests run: cargo xtask doc tests linux-std-tests: runs-on: ubuntu-22.04 needs: [prepare-checks, code-quality] env: DISABLE_WGPU_SPIRV: "1" # disable incremental compilation (reduces artifact size) CARGO_PROFILE_TEST_INCREMENTAL: "false" strategy: matrix: rust: [stable, prev] include: - rust: stable toolchain: stable coverage: --enable-coverage - rust: prev toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }} steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux # Disable cache on linux-std (stable) runner which currently always runs out of disk space with tests + coverage enable-cache: ${{ matrix.rust != 'stable' }} # # -------------------------------------------------------------------------------- - name: Install grcov if: matrix.rust == 'stable' shell: bash run: | curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" | tar xj -C $HOME/.cargo/bin cargo xtask coverage install # -------------------------------------------------------------------------------- - name: Tests run: cargo xtask ${{ matrix.coverage }} test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner # -------------------------------------------------------------------------------- - name: Generate lcov.info if: matrix.rust == 'stable' # /* is to exclude std library code coverage from analysis run: cargo xtask coverage generate --ignore "/*,xtask/*,examples/*" --profile release # -------------------------------------------------------------------------------- - name: Codecov upload lcov.info if: matrix.rust == 'stable' uses: codecov/codecov-action@v5 with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} linux-no-std-tests: runs-on: ubuntu-22.04 needs: [prepare-checks, code-quality] strategy: matrix: rust: [stable, prev] include: - rust: stable toolchain: stable - rust: prev toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }} steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux-no-std # -------------------------------------------------------------------------------- - name: Crates Build run: cargo xtask --context no-std build --ci # -------------------------------------------------------------------------------- - name: Crates Tests run: cargo xtask --context no-std test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner windows-std-tests: runs-on: windows-2022 needs: [prepare-checks, code-quality] # Keep the stragegy to be able to easily add new rust versions if required strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-windows # -------------------------------------------------------------------------------- - name: Tests run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner macos-std-tests: runs-on: blaze/macos-15 needs: [prepare-checks, code-quality] timeout-minutes: 60 # Keep the stragegy to be able to easily add new rust versions if required strategy: matrix: rust: [stable] include: - rust: stable toolchain: stable steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Setup Rust uses: tracel-ai/github-actions/install-rust@v9 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-macos # -------------------------------------------------------------------------------- - name: Device check run: system_profiler SPHardwareDataType # -------------------------------------------------------------------------------- - name: Tests run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-mac-runner ================================================ FILE: .github/workflows/valgrind.yml ================================================ name: valgrind on: schedule: - cron: '0 23 * * WED' # Run every Wednesday at 23:00 (UTC) concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: valgrind: runs-on: [ '@id:burn-linux-valgrind-${{ github.run_id }}-${{ github.run_attempt }}', '@image-family:ubuntu-2404-lts-amd64', '@image-project:ubuntu-os-cloud', '@disk-size:100', '@keep-alive:false', '@machine-type:n2-standard-16', '@os:linux', '@zones:northamerica-northeast1-b' ] steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Install valgrind run: | sudo apt-get install valgrind # -------------------------------------------------------------------------------- - name: Run cargo-valgrind env: CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: "valgrind -s --leak-check=full --show-leak-kinds=all --error-exitcode=1" # Looking for vulnerabilities run: | cargo test ================================================ FILE: .github/workflows/vulnerabilities.yml ================================================ name: vulnerabilities on: schedule: - cron: '0 21 * * WED' # Run every Wednesday at 21:00 (UTC) push: tags: - 'v*.*.*' env: CAREFUL_VERSION: "0.4.9" concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: cargo-careful: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt, rust-src # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Install cargo-careful env: CAREFUL_LINK: https://github.com/RalfJung/cargo-careful/releases/download run: | curl -L "$CAREFUL_LINK/v$CAREFUL_VERSION/cargo-careful.x86_64-unknown-linux-musl" \ --output $HOME/.cargo/bin/cargo-careful chmod +x $HOME/.cargo/bin/cargo-careful # -------------------------------------------------------------------------------- - name: Run cargo-careful # Looking for undefined behaviours run: cargo +nightly careful test address-sanitizer: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt, rust-src # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Run AddressSanitizer env: RUSTFLAGS: -Zsanitizer=address -Copt-level=3 RUSTDOCFLAGS: -Zsanitizer=address # Looking for memory vulnerabilities run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture thread-sanitizer: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt, rust-src # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Run ThreadSanitizer env: RUSTFLAGS: -Zsanitizer=thread -Copt-level=3 RUSTDOCFLAGS: -Zsanitizer=thread # Looking for data race among threads run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture memory-sanitizer: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt, rust-src # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Run MemorySanitizer env: RUSTFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins -Copt-level=3 RUSTDOCFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins # Looking for unitialized memory. run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture safe-stack: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v6 # -------------------------------------------------------------------------------- - name: Install Rust nightly uses: dtolnay/rust-toolchain@nightly with: toolchain: nightly components: rustfmt, rust-src # -------------------------------------------------------------------------------- - name: Install Mesa uses: tracel-ai/github-actions/install-mesa@v9 # -------------------------------------------------------------------------------- - name: Run SafeStack env: RUSTFLAGS: -Zsanitizer=safestack -Copt-level=3 RUSTDOCFLAGS: -Zsanitizer=safestack # Provides backward edge control flow protection run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture ================================================ FILE: .gitignore ================================================ target # These are backup files generated by rustfmt **/*.rs.bk .DS_Store .dir-locals.el .idea .vscode .vs .fleet .ipynb_checkpoints/ # Build output directory out # Virtual Environment of Python .venv uv.lock # Nix direnv .envrc .direnv ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, please cite it as below." authors: - family-names: "Simard" given-names: "Nathaniel" email: "nathaniel.simard.42@gmail.com" - family-names: "Fortier-Dubois" given-names: "Louis" email: "louisfd94@gmail.com" - family-names: "Tadjibaev" given-names: "Dilshod" email: "dilshod@gmail.com" - family-names: "Lagrange" given-names: "Guillaume" email: "lagrange.guillaume.1@gmail.com" - name: "Burn Framework Contributors" title: "Burn" version: 0.14.0 date-released: 2024-08-27 url: "https://burn.dev/" repository-code: "https://github.com/tracel-ai/burn" license: - MIT - Apache-2.0 abstract: "Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals." keywords: - scientific-computing - deep-learning - machine-learning - neural-networks - rust - high-performance-computing - portability - compute-efficiency ================================================ FILE: CODE-OF-CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at nathaniel.simard.42@gmail.com. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Burn Welcome to the Burn community! We're glad you're interested in contributing. ## How to Contribute The best way to get started is to look at [open issues](https://github.com/tracel-ai/burn/issues) and find one that interests you. Issues labeled `good first issue` are a great starting point for new contributors. If you have an idea that isn't covered by an existing issue, open one first to discuss the approach. This helps align expectations and avoids wasted effort on both sides. For questions, discussions, or just to say hello, join us on [Discord](https://discord.gg/uPEBbYYDB6). The [Contributor Book](https://burn.dev/contributor-book/) covers architecture, environment setup, and guides for common tasks. ## Pull Requests Every pull request should have a descriptive title, a description covering what you changed, why, how you tested it, and a link to the relevant issue (if applicable). Prefer small, focused PRs over large ones that bundle unrelated changes. Draft pull requests are considered not yet ready for review. CI checks should pass before requesting review, though the signal isn't always accurate. If you have questions or need early feedback, let us know on the PR or on [Discord](https://discord.gg/uPEBbYYDB6). ### Change Ownership The core principle behind all contributions: **PR authors must understand, justify, and explain every change they propose.** After a PR is accepted, both the reviewer and the author should be confident it improves the codebase. This applies equally whether you wrote the code from scratch, adapted it from another project, or used AI tools to help generate it. The origin of the code doesn't matter; what matters is that you own it intellectually and can stand behind it during review. ### AI-Assisted Contributions Using LLMs and AI tools to generate code that is part of a contribution is allowed. That said, the [Change Ownership](#change-ownership) principle applies fully. You are the author, not your AI tool. This means: - Read and understand every line before submitting. - Review AI-generated code for correctness, style consistency, and relevance. - Test your changes locally and confirm they work as intended. - Be prepared to explain the rationale behind any change during review. Do not use "AI generated" as a justification for low-quality code. ### Before You Open a PR 1. **Check for an existing issue.** If there isn't one, open an issue first to discuss the approach. This is especially important for large changes or refactors. 2. **Read the codebase.** Understand the architecture and conventions already in place. The [Contributor Book](https://burn.dev/contributor-book/) covers architecture, environment setup, and guides for common tasks. 3. **Keep it focused.** One PR should address one concern. If you spot an unrelated issue while working, open a separate PR for it. 4. **Run validation.** Run `cargo run-checks` before submitting. This runs formatting, linting, and the full test suite. All checks must pass. ### Code Quality Standards - Follow existing code style and project conventions. - Write idiomatic Rust. If you are new to the codebase, study existing patterns before contributing. - Keep dependencies minimal. Don't introduce new crates without discussion. - Document public APIs. Non-trivial logic should have comments explaining _why_, not just _what_. - Prefer clarity over cleverness. - Bug fixes should include a regression test. ### Large Pull Requests Large, complex PRs are harder to review effectively and carry more risk. To help both yourself and reviewers, consider breaking substantial changes into smaller, incremental PRs. Each should be valuable on its own, even if the full picture spans multiple PRs. Large efforts that are ultimately rejected are frustrating for everyone involved. If you're planning a substantial change, open an issue or start a discussion first. It's much easier to course-correct early than after the work is done. ### Review Process - Maintainers review PRs as time allows. Please be patient. - Be responsive to feedback. If changes are requested, address them or explain your reasoning. - Reviewers may ask clarifying questions about any part of your PR. This is a normal part of collaborative review and helps ensure shared understanding. - Don't force-push to rewrite history during an active review without notice. - If a PR goes stale for more than 14 days without a response from the author, it may be closed. ## Getting Help If you're stuck or unsure about something, don't hesitate to ask. Open an issue, start a discussion, or reach out on [Discord](https://discord.gg/uPEBbYYDB6). We're happy to help. ================================================ FILE: Cargo.toml ================================================ [workspace] # Try # require version 2 to avoid "feature" additiveness for dev-dependencies # https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 resolver = "2" members = [ "crates/*", "crates/burn-store/pytorch-tests", "crates/burn-store/safetensors-tests", "crates/burn-collective/multinode-tests", "examples/*", "xtask", ] exclude = [ "examples/notebook", "examples/raspberry-pi-pico", "examples/dqn-agent", # gym-rs ] [workspace.package] edition = "2024" license = "MIT OR Apache-2.0" readme = "README.md" version = "0.21.0-pre.2" [workspace.lints.clippy] [workspace.lints.rustdoc] broken_intra_doc_links = "deny" invalid_html_tags = "deny" [workspace.dependencies] atomic_float = "1" axum = "0.8.8" bytemuck = "1.25.0" bytes = { version = "1.11.1", default-features = false } candle-core = { version = "0.9.2" } ciborium = { version = "0.2", default-features = false } clap = { version = "4.6.0", features = ["derive"] } colored = "3.0.0" console_error_panic_hook = "0.1.7" const-random = "0.1" csv = "1.3.1" dashmap = "6.1.0" data-encoding = { version = "2.10.0", default-features = false, features = [ "alloc", ] } dirs = "6.0.0" encoding_rs = "0.8.33" enumset = { version = "1.1.10", default-features = false } fake = "5.1.0" flate2 = "1.1.9" float-cmp = "0.10.0" futures = "0.3" futures-util = "0.3" gix-tempfile = { version = "21.0.0", features = ["signals"] } globwalk = "0.9.1" hashbrown = "0.16" hound = "3.5.1" image = "0.25.9" indicatif = "0.18.0" insta = "1.45.0" js-sys = "0.3.77" libm = "0.2.15" log = { default-features = false, version = "0.4.29" } lzma-rust2 = "0.16.2" opentelemetry = "0.31.0" opentelemetry-aws = "0.19.0" opentelemetry-otlp = "0.31.0" opentelemetry_sdk = "0.31.0" parking_lot = { version = "0.12.5", default-features = false } paste = "1" planus = { version = "=1.1" } polars = { version = "0.53.0", features = ["lazy"] } pretty_assertions = "1.4.1" proc-macro2 = "1.0.106" quote = "1.0.45" r2d2 = "0.8.10" r2d2_sqlite = "0.31.0" rayon = "1.10.0" regex = { version = "1.12.3", default-features = false, features = [ "perf", "unicode", ] } reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls", ] } rmp-serde = { version = "1.3.1", default-features = false } rstest = "0.26.1" rusqlite = "0.37.0" sanitize-filename = "0.6.0" serde_bytes = { version = "0.11.18", default-features = false, features = [ "alloc", ] } # alloc for no_std serde_rusqlite = "0.40.0" serial_test = "3.2.0" spin = { version = "0.10.0", features = [ "mutex", "spin_mutex", "portable-atomic", ] } strum = { version = "0.28.0", features = ["derive"] } syn = { version = "2.0.111", features = ["full", "extra-traits"] } tar = "0.4.44" tempfile = "3.24.0" textdistance = { version = "1.1.1", default-features = false } thiserror = { version = "2", default-features = false } tokio = { version = "1.50.0", features = ["rt", "macros"] } tokio-tungstenite = "0.28" tokio-util = "0.7" tracing = { version = "0.1.44", default-features = false } tracing-appender = "0.2.3" tracing-core = { version = "0.1.36", default-features = false } tracing-opentelemetry = "0.32.0" tracing-subscriber = "0.3.23" zip = "8.2.0" # Persist related memmap2 = { version = "0.9" } safetensors = { version = "0.7.0", default-features = false } # Async handling async-channel = "2.5" futures-lite = { version = "2.6.1", default-features = false } # Terminal UI ratatui = "0.30.0" # WGPU stuff text_placeholder = "0.5.1" bincode = { version = "2.0.1", features = [ "alloc", "serde", ], default-features = false } # # The following packages disable the "std" feature for no_std compatibility # cfg-if = "1.0.1" derive-new = { version = "0.7.0", default-features = false } blas-src = { version = "0.14.0", default-features = false } bon = "3.8.2" half = { version = "2.7.1", features = [ "alloc", "num-traits", "serde", ], default-features = false } macerator = { version = "0.3.0" } matrixmultiply = { version = "0.3.10", default-features = false } ndarray = { version = "0.17.2", default-features = false } num-traits = { version = "0.2.19", default-features = false, features = [ "libm", ] } # libm is for no_std openblas-src = "0.10.14" rand = { version = "0.10.0", default-features = false, features = ["std_rng"] } rand_distr = { version = "0.6.0", default-features = false } serde = { version = "1.0.228", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed serde_json = { version = "1.0.148", default-features = false } smallvec = { version = "1", features = ["const_generics", "const_new"] } uuid = { version = "1.22.0", default-features = false } byteorder = { version = "1.5.0", default-features = false } libc = "0.2.182" nvml-wrapper = "0.12.0" sysinfo = "0.38.0" systemstat = "0.2.6" tch = "0.22.0" torch-sys = "0.22.0" # matches what tch is using, required for lib detection ahash = { version = "0.8.12", default-features = false } portable-atomic = { version = "1.13.1" } portable-atomic-util = { version = "0.2.6", features = ["alloc"] } ### For the main burn branch. ### cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" } cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" } cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "20585bb73e19b16c5fb84b39923a49011b329a70" } cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "01ed48e1abb5ed117df33f4394f2c5a91c3eb97e" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } # cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false } # cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.10.0-pre.2", default-features = false } # cubecl-common = { version = "=0.10.0-pre.2", default-features = false } # cubecl-zspace = { version = "=0.10.0-pre.2", default-features = false } # cubek = { version = "=0.2.0-pre.2", default-features = false } ### For xtask crate ### tracel-xtask = "=4.13.5" # ### For local development. ### # tracel-xtask = { path = "../xtask/crates/tracel-xtask", default-features = false } [profile.dev] debug = 1 # Speed up compilation time and not necessary. ================================================ FILE: LICENSE-APACHE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2022 Nathaniel Simard & Burn Framework Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: LICENSE-MIT ================================================ MIT License Copyright (c) 2022 Nathaniel Simard & Burn Framework Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: NOTICES.md ================================================ # NOTICES AND INFORMATION This file contains notices and information required by libraries that this repository copied or derived from. ## PyTorch MNIST Example **Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py License: BSD 3-Clause License Copyright (c) 2017, All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ## wgpu **Source:** https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml MIT License Copyright (c) 2021 The gfx-rs developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ## BSL 1.0 **Source**: - https://github.com/DoumanAsh/error-code - https://github.com/DoumanAsh/clipboard-win Boost Software License - Version 1.0 - August 17th, 2003 Permission is hereby granted, free of charge, to any person or organization obtaining a copy of the software and accompanying documentation covered by this license (the "Software") to use, reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative works of the Software, and to permit third-parties to whom the Software is furnished to do so, all subject to the following: The copyright notices in the Software and this entire statement, including the above license grant, this restriction and the following disclaimer, must be included in all copies of the Software, in whole or in part, and all derivative works of the Software, unless such copies or derivative works are solely in the form of machine-executable object code generated by a source language processor. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ## num-traits **Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs MIT License Copyright (c) 2014 The Rust Project Developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ## RP **Source**: - https://github.com/embassy-rs/embassy/blob/main/examples/rp/Cargo.toml - https://github.com/embassy-rs/embassy/blob/main/examples/rp/build.rs - https://github.com/embassy-rs/embassy/blob/main/examples/rp/memory.x Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright (c) Embassy project contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. MIT license Copyright (c) Embassy project contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ## github-device-flow **Source**: - Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs - https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs MIT License Copyright (c) 2022 Jake Wilkins Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ## Candle - Pickle Reader **Source**: https://github.com/huggingface/candle/blob/main/candle-core/src/pickle.rs This project includes code from Candle by Hugging Face, licensed under both MIT and Apache 2.0 licenses. **MIT License**: https://github.com/huggingface/candle/blob/main/LICENSE-MIT MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. **Apache License 2.0**: https://github.com/huggingface/candle/blob/main/LICENSE-APACHE Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ## ICU UNICODE LICENSE V3 COPYRIGHT AND PERMISSION NOTICE Copyright © 2016-2024 Unicode, Inc. NOTICE TO USER: Carefully read the following legal agreement. BY DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE. Permission is hereby granted, free of charge, to any person obtaining a copy of data files and any associated documentation (the "Data Files") or software and any associated documentation (the "Software") to deal in the Data Files or Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sell copies of the Data Files or Software, and to permit persons to whom the Data Files or Software are furnished to do so, provided that either (a) this copyright and permission notice appear with all copies of the Data Files or Software, or (b) this copyright and permission notice appear in associated Documentation. THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA FILES OR SOFTWARE. Except as contained in this notice, the name of a copyright holder shall not be used in advertising or otherwise to promote the sale, use or other dealings in these Data Files or Software without prior written authorization of the copyright holder. ================================================ FILE: POEM.md ================================================ # BURN: Burn Unstoppable Rusty Neurons In the realm of circuits and code, A fiery forge ignites to bear its load, A framework born, BURN it be named, Unstoppable Rusty Neurons, untamed. From silicon synapses, connections spire, A digital cortex, setting minds afire, In the vast expanse of deep learning's sea, A beacon of progress, BURN comes to be. Oh, rusty neurons, forged in the flame, Unyielding in purpose, undaunted by name, Through layers of logic and intricate art, You weave and entwine, each playing its part. With algorithms profound, and data refined, In ceaseless pursuit of knowledge to find, BURN paves a path to enlightenment, bright, A testament to the wonders of human foresight. In neural networks deep, where wisdom resides, The dance of nodes and edges presides, With loss and gradients, BURN takes its stride, A journey towards truth, with AI as our guide. No barriers hold back the curious mind, As BURN seeks the answers we yearn to find, Unstoppable, relentless, in pursuit of the unknown, Our collective intellect, within it, has grown. So sing we the praises of BURN's fiery might, An ode to the sparks that set the dark alight, To the rusty neurons, unstoppable and true, A testament to the power of dreams, to breakthrough. (ChatGPT (model=gpt-4) with prompt: Write a poem about "BURN: Burn Unstoppable Rusty Neurons" deep learning neural network framework) ================================================ FILE: README.md ================================================
[![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/uPEBbYYDB6) [![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn) [![Minimum Supported Rust Version](https://img.shields.io/crates/msrv/burn)](https://crates.io/crates/burn) [![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://burn.dev/docs/burn) [![Test Status](https://github.com/tracel-ai/burn/actions/workflows/test.yml/badge.svg)](https://github.com/tracel-ai/burn/actions/workflows/test.yml) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](#license) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tracel-ai/burn) [](https://www.runblaze.dev) --- **Burn is a next generation Tensor Library and Deep Learning Framework that doesn't compromise on
flexibility, efficiency and portability.**
Burn is both a tensor library and a deep learning framework optimized for numerical computing, model inference and model training. Burn leverages Rust to perform optimizations normally only available in static-graph frameworks, offering optimal speed without impacting flexibility. ## Backend
Burn strives to be as fast as possible on as many hardwares as possible, with robust implementations. We believe this flexibility is crucial for modern needs where you may train your models in the cloud, then deploy on customer hardwares, which vary from user to user.
### Supported Backends Most backends support all operating systems, so we don't mention them in the tables below. **GPU Backends:** | | CUDA | ROCm | Metal | Vulkan | WebGPU | LibTorch | | ------- | ---- | ---- | ----- | ------ | ------ | -------- | | Nvidia | ☑️ | - | - | ☑️ | ☑️ | ☑️ | | AMD | - | ☑️ | - | ☑️ | ☑️ | ☑️ | | Apple | - | - | ☑️ | - | ☑️ | ☑️ | | Intel | - | - | - | ☑️ | ☑️ | - | | Qualcom | - | - | - | ☑️ | ☑️ | - | | Wasm | - | - | - | - | ☑️ | - | **CPU Backends:** | | Cpu (CubeCL) | NdArray | LibTorch | | ------ | ------------ | ------- | -------- | | X86 | ☑️ | ☑️ | ☑️ | | Arm | ☑️ | ☑️ | ☑️ | | Wasm | - | ☑️ | - | | no-std | - | ☑️ | - |
Compared to other frameworks, Burn has a very different approach to supporting many backends. By design, most code is generic over the Backend trait, which allows us to build Burn with swappable backends. This makes composing backend possible, augmenting them with additional functionalities such as autodifferentiation and automatic kernel fusion.
Autodiff: Backend decorator that brings backpropagation to any backend 🔄
Contrary to the aforementioned backends, Autodiff is actually a backend _decorator_. This means that it cannot exist by itself; it must encapsulate another backend. The simple act of wrapping a base backend with Autodiff transparently equips it with autodifferentiation support, making it possible to call backward on your model. ```rust use burn::backend::{Autodiff, Wgpu}; use burn::tensor::{Distribution, Tensor}; fn main() { type Backend = Autodiff; let device = Default::default(); let x: Tensor = Tensor::random([32, 32], Distribution::Default, &device); let y: Tensor = Tensor::random([32, 32], Distribution::Default, &device).require_grad(); let tmp = x.clone() + y.clone(); let tmp = tmp.matmul(x); let tmp = tmp.exp(); let grads = tmp.backward(); let y_grad = y.grad(&grads).unwrap(); println!("{y_grad}"); } ``` Of note, it is impossible to make the mistake of calling backward on a model that runs on a backend that does not support autodiff (for inference), as this method is only offered by an Autodiff backend. See the [Autodiff Backend README](./crates/burn-autodiff/README.md) for more details.
Fusion: Backend decorator that brings kernel fusion to all first-party backends
This backend decorator enhances a backend with kernel fusion, provided that the inner backend supports it. Note that you can compose this backend with other backend decorators such as Autodiff. All first-party accelerated backends (like WGPU and CUDA) use Fusion by default (`burn/fusion` feature flag), so you typically don't need to apply it manually. ```rust #[cfg(not(feature = "fusion"))] pub type Cuda = CubeBackend; #[cfg(feature = "fusion")] pub type Cuda = burn_fusion::Fusion>; ``` Of note, we plan to implement automatic gradient checkpointing based on compute bound and memory bound operations, which will work gracefully with the fusion backend to make your code run even faster during training, see [this issue](https://github.com/tracel-ai/burn/issues/936). See the [Fusion Backend README](./crates/burn-fusion/README.md) for more details.
Router (Beta): Backend decorator that composes multiple backends into a single one
That backend simplifies hardware operability, if for instance you want to execute some operations on the CPU and other operations on the GPU. ```rust use burn::tensor::{Distribution, Tensor}; use burn::backend::{ NdArray, Router, Wgpu, ndarray::NdArrayDevice, router::duo::MultiDevice, wgpu::WgpuDevice, }; fn main() { type Backend = Router<(Wgpu, NdArray)>; let device_0 = MultiDevice::B1(WgpuDevice::DiscreteGpu(0)); let device_1 = MultiDevice::B2(NdArrayDevice::Cpu); let tensor_gpu = Tensor::::random([3, 3], burn::tensor::Distribution::Default, &device_0); let tensor_cpu = Tensor::::random([3, 3], burn::tensor::Distribution::Default, &device_1); } ```
Remote (Beta): Backend decorator for remote backend execution, useful for distributed computations
That backend has two parts, one client and one server. The client sends tensor operations over the network to a remote compute backend. You can use any first-party backend as server in a single line of code: ```rust fn main_server() { // Start a server on port 3000. burn::server::start::(Default::default(), 3000); } fn main_client() { // Create a client that communicate with the server on port 3000. use burn::backend::{Autodiff, RemoteBackend}; type Backend = Autodiff; let device = RemoteDevice::new("ws://localhost:3000"); let tensor_gpu = Tensor::::random([3, 3], Distribution::Default, &device); } ```

## Training & Inference
The whole deep learning workflow is made easy with Burn, as you can monitor your training progress with an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU clusters. Burn was built from the ground up with training and inference in mind. It's also worth noting how Burn, in comparison to frameworks like PyTorch, simplifies the transition from training to deployment, eliminating the need for code changes.

Burn Train TUI

**Click on the following sections to expand 👇**
Training Dashboard 📈
As you can see in the previous video (click on the picture!), a new terminal UI dashboard based on the [Ratatui](https://github.com/ratatui-org/ratatui) crate allows users to follow their training with ease without having to connect to any external application. You can visualize your training and validation metrics updating in real-time and analyze the lifelong progression or recent history of any registered metrics using only the arrow keys. Break from the training loop without crashing, allowing potential checkpoints to be fully written or important pieces of code to complete without interruption 🛡
ONNX Support 🐫
Burn supports importing ONNX (Open Neural Network Exchange) models through the [burn-onnx](https://github.com/tracel-ai/burn-onnx) crate, allowing you to easily port models from TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses Burn's native APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly) and benefit from all of Burn's optimizations like automatic kernel fusion. Our ONNX support is further described in [this section of the Burn Book 🔥](https://burn.dev/books/burn/onnx-import.html). > **Note**: This crate is in active development and currently supports a > [limited set of ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).
Importing PyTorch or Safetensors Models 🚚
You can load weights from PyTorch or Safetensors formats directly into your Burn-defined models. This makes it easy to reuse existing models while benefiting from Burn's performance and deployment features. Learn more in the [Saving & Loading Models](https://burn.dev/books/burn/saving-and-loading.html) section of the Burn Book.
Inference in the Browser 🌐
Several of our backends can run in WebAssembly environments: NdArray for CPU execution, and WGPU for GPU acceleration via WebGPU. This means that you can run inference directly within a browser. We provide several examples of this: - [MNIST](./examples/mnist-inference-web) where you can draw digits and a small convnet tries to find which one it is! 2️⃣ 7️⃣ 😰 - [Image Classification](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) where you can upload images and classify them! 🌄
Embedded: no_std support ⚙️
Burn's core components support [no_std](https://docs.rust-embedded.org/book/intro/no-std.html). This means it can run in bare metal environment such as embedded devices without an operating system. > As of now, only the NdArray backend can be used in a _no_std_ environment.

### Benchmarks To evaluate performance across different backends and track improvements over time, we provide a dedicated benchmarking suite. Run and compare benchmarks using [burn-bench](https://github.com/tracel-ai/burn-bench). > ⚠️ **Warning** When using one of the `wgpu` backends, you may encounter compilation errors related > to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency > chain. To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` > file: > > ```rust > #![recursion_limit = "256"] > ``` > > The default recursion limit (128) is often just below the required depth (typically 130-150) due > to deeply nested associated types and trait bounds. ## Getting Started
Just heard of Burn? You are at the right place! Just continue reading this section and we hope you can get on board really quickly.
The Burn Book 🔥
To begin working effectively with Burn, it is crucial to understand its key components and philosophy. This is why we highly recommend new users to read the first sections of [The Burn Book 🔥](https://burn.dev/books/burn/). It provides detailed examples and explanations covering every facet of the framework, including building blocks like tensors, modules, and optimizers, all the way to advanced usage, like coding your own GPU kernels. > The project is constantly evolving, and we try as much as possible to keep the book up to date > with new additions. However, we might miss some details sometimes, so if you see something weird, > let us know! We also gladly accept Pull Requests 😄
Examples 🙏
Let's start with a code snippet that shows how intuitive the framework is to use! In the following, we declare a neural network module with some parameters along with its forward pass. ```rust use burn::nn; use burn::module::Module; use burn::tensor::backend::Backend; #[derive(Module, Debug)] pub struct PositionWiseFeedForward { linear_inner: nn::Linear, linear_outer: nn::Linear, dropout: nn::Dropout, gelu: nn::Gelu, } impl PositionWiseFeedForward { pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input); let x = self.gelu.forward(x); let x = self.dropout.forward(x); self.linear_outer.forward(x) } } ``` We have a somewhat large amount of [examples](./examples) in the repository that shows how to use the framework in different scenarios. Following [the book](https://burn.dev/books/burn/): - [Basic Workflow](./examples/guide) : Creates a custom CNN `Module` to train on the MNIST dataset and use for inference. - [Custom Training Loop](./examples/custom-training-loop) : Implements a basic training loop instead of using the `Learner`. - [Custom WGPU Kernel](./examples/custom-wgpu-kernel) : Learn how to create your own custom operation with the WGPU backend. Additional examples: - [Custom CSV Dataset](./examples/custom-csv-dataset) : Implements a dataset to parse CSV data for a regression task. - [Regression](./examples/simple-regression) : Trains a simple MLP on the California Housing dataset to predict the median house value for a district. - [Custom Image Dataset](./examples/custom-image-dataset) : Trains a simple CNN on custom image dataset following a simple folder structure. - [Custom Renderer](./examples/custom-renderer) : Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress. - [Image Classification Web](./examples/image-classification-web) : Image classification web browser demo using Burn, WGPU and WebAssembly. - [MNIST Inference on Web](./examples/mnist-inference-web) : An interactive MNIST inference demo in the browser. The demo is available [online](https://burn.dev/demo/). - [MNIST Training](./examples/mnist) : Demonstrates how to train a custom `Module` (MLP) with the `Learner` configured to log metrics and keep training checkpoints. - [PyTorch Import Inference](./examples/import-model-weights) : Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. - [Text Classification](./examples/text-classification) : Trains a text classification transformer model on the AG News or DbPedia dataset. The trained model can then be used to classify a text sample. - [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the DbPedia dataset. - [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits based on MNIST. For more practical insights, you can clone the repository and run any of them directly on your computer!
Pre-trained Models 🤖
We keep an updated and curated list of models and examples built with Burn, see the [tracel-ai/models repository](https://github.com/tracel-ai/models) for more details. Don't see the model you want? Don't hesitate to open an issue, and we may prioritize it. Built a model using Burn and want to share it? You can also open a Pull Request and add your model under the community section!
Why use Rust for Deep Learning? 🦀
Deep Learning is a special form of software where you need very high level abstractions as well as extremely fast execution time. Rust is the perfect candidate for that use case since it provides zero-cost abstractions to easily create neural network modules, and fine-grained control over memory to optimize every detail. It's important that a framework be easy to use at a high level so that its users can focus on innovating in the AI field. However, since running models relies so heavily on computations, performance can't be neglected. To this day, the mainstream solution to this problem has been to offer APIs in Python, but rely on bindings to low-level languages such as C/C++. This reduces portability, increases complexity and creates frictions between researchers and engineers. We feel like Rust's approach to abstractions makes it versatile enough to tackle this two languages dichotomy. Rust also comes with the Cargo package manager, which makes it incredibly easy to build, test, and deploy from any environment, which is usually a pain in Python. Although Rust has the reputation of being a difficult language at first, we strongly believe it leads to more reliable, bug-free solutions built faster (after some practice 😅)!

> **Deprecation Note**
Since `0.14.0`, the internal structure for tensor data has changed. The > previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new > `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and > keeping the data type as a field. If you are using `Data` in your code, make sure to switch to > `TensorData`.
Loading Model Records From Previous Versions ⚠️
In the event that you are trying to load a model record saved in a version older than `0.14.0`, make sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat` feature flag. ``` features = [..., "record-backward-compat"] ``` Otherwise, the record won't be deserialized correctly and you will get an error message. This error will also point you to the backward compatible feature flag. The backward compatibility was maintained for deserialization when loading records. Therefore, as soon as you have saved the record again it will be saved according to the new structure and you can upgrade back to the current version Please note that binary formats are not backward compatible. Thus, you will need to load your record in a previous version and save it in any of the other self-describing record format (e.g., using the `NamedMpkFileRecorder`) before using a compatible version (as described) with the `record-backward-compat` feature flag.
## Community
If you are excited about the project, don't hesitate to join our [Discord](https://discord.gg/uPEBbYYDB6)! We try to be as welcoming as possible to everybody from any background. You can ask your questions and share what you built with the community!

**Contributing** Before contributing, please read the [Contributing Guidelines](./CONTRIBUTING.md) and our [Code of Conduct](./CODE-OF-CONDUCT.md). The [Contributor Book](https://burn.dev/contributor-book/) covers architecture, environment setup, and guides for common tasks. ## Status Burn is currently in active development, and there will be breaking changes. While any resulting issues are likely to be easy to fix, there are no guarantees at this stage. ## License Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull request is assumed to signal agreement with these licensing terms.
================================================ FILE: _typos.toml ================================================ [default] extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"] [default.extend-identifiers] UE4M3 = "UE4M3" UE8M0 = "UE8M0" ue8m0 = "ue8m0" [files] extend-exclude = [ "*.onnx", "*.proto", "assets/ModuleSerialization.xml", ] [default.extend-words] # Don't correct "arange" which is intentional arange = "arange" # Don't correct "convnet" (convolutional network) convnet = "convnet" ================================================ FILE: benchmarks.toml ================================================ [environment] gcp_gpu_attached = true gcp_image_family = "tracel-ci-ubuntu-2404-amd64-nvidia" # https://cloud.google.com/compute/docs/accelerator-optimized-machines # put the faster machine on first place for possibly faster 'Benchmarks Started' feedback in PRs gcp_machine_types = [ "a2-highgpu-1g", # 1 A100 40GB (listed as a2 standard) "g2-standard-4", # 1 L4 24GB ] # define the available zones for each machine type # be sure to check what machine types are available in each region # https://cloud.google.com/compute/docs/gpus/gpu-regions-zones#view-using-table gcp_zones = [ # a2-highgpu-1g [ "asia-northeast1-a", "asia-northeast1-c", "asia-northeast3-b", "asia-southeast1-b", "asia-southeast1-c", "europe-west4-a", "europe-west4-b", "us-central1-a", "us-central1-b", "us-central1-c", "us-central1-f", "us-east1-b", "us-west1-b", "us-west3-b", "us-west4-b" ], # g2-standard-4 [ "northamerica-northeast2-a", "northamerica-northeast2-b", "us-central1-a", "us-central1-b", "us-central1-c", "us-east1-b", "us-east1-c", "us-east1-d", "us-east4-a", "us-east4-c", "us-west1-a", "us-west1-b", "us-west1-c", "us-west4-a", "us-west4-c" ], ] repo_full = "tracel-ai/burn" rust_toolchain = "stable" rust_version = "stable" [burn-bench] github_organization = "tracel-ai" github_repository = "burn-bench" github_branch = "main" github_workflow = "benchmarks.yml" # vulkan autotune seems to take ages, disabling it for now # backends = ["cuda-fusion", "vulkan-fusion", "wgpu-fusion"] backends = ["cuda-fusion", "cuda"] benches = ["autodiff", "binary", "bool_select", "conv-transpose2d", "conv-transpose3d", "conv2d", "conv3d", "custom-gelu", "data", "load-record", "matmul-fused", "matmul", "max-pool2d", "random", "reduce", "softmax", "transformer-encoder", "unary" ] dtypes = ["f16"] ================================================ FILE: burn-book/.gitignore ================================================ target # MacOS temp file .DS_Store book-test guide/book .vscode tests/burn-book/book/ book/ # Ignore Jetbrains specific files. .idea/ # Ignore Vim temporary and swap files. *.sw? *~ ================================================ FILE: burn-book/.prettierrc.json ================================================ { "printWidth": 100, "proseWrap": "always" } ================================================ FILE: burn-book/book.toml ================================================ [book] authors = [ "Wouter Doppenberg", "Nathaniel Simard", "Louis Fortier-Dubois", "Dilshod Tadjibaev", "Guillaume Lagrange", "Sylvain Benner", "Bjorn Beishline" ] language = "en" src = "src" title = "The Burn Book 🔥" [output.html] mathjax-support = true ================================================ FILE: burn-book/src/SUMMARY.md ================================================ - [Overview](./overview.md) - [Why Burn?](./motivation.md) - [Getting started](./getting-started.md) - [Examples](./examples.md) - [Basic Workflow: From Training to Inference](./basic-workflow/README.md) - [Model](./basic-workflow/model.md) - [Data](./basic-workflow/data.md) - [Training](./basic-workflow/training.md) - [Backend](./basic-workflow/backend.md) - [Inference](./basic-workflow/inference.md) - [Building Blocks](./building-blocks/README.md) - [Backend](./building-blocks/backend.md) - [Tensor](./building-blocks/tensor.md) - [Autodiff](./building-blocks/autodiff.md) - [Module](./building-blocks/module.md) - [Learner](./building-blocks/learner.md) - [Metric](./building-blocks/metric.md) - [Config](./building-blocks/config.md) - [Record](./building-blocks/record.md) - [Dataset](./building-blocks/dataset.md) - [Performance](./performance/README.md) - [Good practices](./performance/good-practices/README.md) - [Asynchronous Execution](./performance/good-practices/asynchronous-execution.md) - [Kernel Fusion](./performance/good-practices/kernel-fusion.md) - [Kernel Selection](./performance/good-practices/kernel-selection.md) - [Quantization](./performance/quantization.md) - [Distributed Computing](./performance/distributed-computing.md) - [Custom Training Loop](./custom-training-loop.md) - [Saving & Loading Models](./saving-and-loading.md) - [ONNX Import](./onnx-import.md) - [Models & Pre-Trained Weights](./models-and-pretrained-weights.md) - [Advanced](./advanced/README.md) - [Backend Extension](./advanced/backend-extension/README.md) - [Custom `CubeCL` Kernel](./advanced/backend-extension/custom-cubecl-kernel.md) - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md) - [Custom Optimizer]() - [WebAssembly](./advanced/web-assembly.md) - [No-Std](./advanced/no-std.md) ================================================ FILE: burn-book/src/advanced/README.md ================================================ # Advanced In this section, we will go into advanced topics that extend beyond basic usage. Given Burn's exceptional flexibility, a lot of advanced use cases become possible. Before going through this section, we strongly recommend exploring the [basic workflow](../basic-workflow/) section and the [building blocks](../building-blocks/) section. Establishing a solid understanding of how the framework operates is crucial to comprehending the advanced concepts presented here. While you have the freedom to explore the advanced sections in any order you prefer, it's important to note that this section is not intended to be linear, contrary to preceding sections. Instead, it serves as a repository of use cases that you can refer to for guidance as needed. ================================================ FILE: burn-book/src/advanced/backend-extension/README.md ================================================ # Backend Extension Burn aims to be the most flexible deep learning framework. While it's crucial to maintain compatibility with a wide variety of backends, Burn provides the ability to extend the functionality of a backend implementation to suit your modeling requirements. This versatility is advantageous in numerous ways, such as supporting custom operations like flash attention or manually fusing operations for enhanced performance. In this section, we will go into the process of extending a backend, providing multiple examples. But before we proceed, let's establish the fundamental principles that will empower you to craft your own backend extensions. As you can observe, most types in Burn are generic over the Backend trait. This might give the impression that Burn operates at a high level over the backend layer. However, making the trait explicit instead of being chosen via a compilation flag was a thoughtful design decision. This explicitness does not imply that all backends must be identical; rather, it offers a great deal of flexibility when composing backends. The autodifferentiation backend trait (see [autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has been extended to enable gradient computation with backpropagation. Furthermore, this design allows you to create your own backend extension. To achieve this, you need to design your own backend trait specifying which functions should be supported. ```rust, ignore pub trait Backend: burn::tensor::backend::Backend { fn my_new_function(tensor: B::TensorPrimitive<2>) -> B::TensorPrimitive<2> { // You can define a basic implementation reusing the Burn Backend API. // This can be useful since all backends will now automatically support // your model. But performance can be improved for this new // operation by implementing this block in specific backends. } } ``` You can then implement your new custom backend trait for any backend that you want to support: ```rust, ignore impl Backend for burn_tch::LibTorch { fn my_new_function(tensor: TchTensor) -> TchTensor { // My Tch implementation } } impl Backend for burn_ndarray::NdArray { // No specific implementation, but the backend can still be used. } ``` You can support the backward pass using the same pattern. ```rust, ignore impl Backend for burn_autodiff::Autodiff { // No specific implementation; autodiff will work with the default // implementation. Useful if you still want to train your model, but // observe performance gains mostly during inference. } impl Backend for burn_autodiff::Autodiff { fn my_new_function(tensor: AutodiffTensor) -> AutodiffTensor { // My own backward implementation, generic over my custom Backend trait. // // You can add a new method `my_new_function_backward` to your custom backend // trait if you want to invoke a custom kernel during the backward pass. } } impl Backend for burn_autodiff::Autodiff> { fn my_new_function(tensor: AutodiffTensor) -> AutodiffTensor { // My own backward implementation, generic over a backend implementation. // // This is another way to call a custom kernel for the backward pass that // doesn't require the addition of a new `backward` function in the custom backend. // This is useful if you don't want all backends to support training, reducing // the need for extra code when you know your model will only be trained on one // specific backend. } } ``` The specifics of each implementation will be covered by the examples provided in this section. The `cubecl` compiler frontend is the recommended method of implementing custom kernels, since it supports multiple backends, including `wgpu` and `CUDA`, and is the way first-party `burn` kernels are written. ================================================ FILE: burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md ================================================ # Custom CubeCL Kernel In this section, you will learn how to create your own custom operation by writing your own kernel with the cubecl compiler frontend. We will take the example of a common workflow in the deep learning field, where we create a kernel to fuse multiple operations together. Note that `burn` does this automatically, but a manual implementation might be more efficient in some cases. We will fuse a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in various models. All the code can be found under the [examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-cubecl-kernel). > Note: CubeCL is in active development, so this section may be outdated. ## Custom Backend Trait First, we need to determine the type signature of our newly created operation by defining our custom backend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which encapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid the ugly disambiguation with associated types. ```rust, ignore /// We create our own Backend trait that extends the Burn backend trait. pub trait Backend: burn::tensor::backend::Backend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor; } /// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} ``` In our project, we can use these traits instead of the `burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types. Therefore, we can encapsulate our newly defined backend traits with functions that expose new operations while maintaining a consistent API. ```rust, ignore /// We define our custom implementation using the added function on our custom backend. pub fn matmul_add_relu_custom( lhs: Tensor, rhs: Tensor, bias: Tensor, ) -> Tensor { let output = B::fused_matmul_add_relu( lhs.into_primitive().tensor(), rhs.into_primitive().tensor(), bias.into_primitive().tensor(), ); Tensor::from_primitive(TensorPrimitive::Float(output)) } /// We define a reference implementation using basic tensor operations. pub fn matmul_add_relu_reference( lhs: Tensor, rhs: Tensor, bias: Tensor, ) -> Tensor { let x = lhs.matmul(rhs) + bias; activation::relu(x) } ``` Note that we also provide a reference implementation for testing purposes, which allows us to easily validate our new implementation. While not mandatory, having a reference implementation can be valuable, especially in projects where creating a reference implementation solely using basic tensor operations is feasible. ## Forward Kernel Now, let's proceed to write the fused kernel using the `cubecl` compiler frontend. To keep things simple, we'll create a straightforward matmul kernel without employing any intricate techniques. We won't delve into the details of the `cube` macro, but if you're interested to learn more, please see [`cubecl` Book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book). The actual matmul, add and relu computations are found at the end, after an extensive prelude that serves to correctly map each compute unit to the data it is responsible for, with support for batches. ```rust, ignore use cubecl::{cube, prelude::*}; #[cube(launch)] pub fn fused_matmul_add_relu_kernel( lhs: &Tensor, rhs: &Tensor, bias: &Tensor, output: &mut Tensor, ) { let row = ABSOLUTE_POS_X; let col = ABSOLUTE_POS_Y; let batch = ABSOLUTE_POS_Z; let n_rows = output.shape(output.rank() - 2); let n_cols = output.shape(output.rank() - 1); let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { return; } let offset_output = batch * n_rows * n_cols; let mut offset_lhs = 0; let mut offset_rhs = 0; let batch_dims = output.rank() - 2; for dim in 0..batch_dims { offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim); offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim); } let mut sum = F::new(0.0); for k in 0..dim_k { let lhs_index = row * dim_k + k; let rhs_index = k * n_cols + col; sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index]; } let out_index = row * n_cols + col; let index = offset_output + out_index; output[index] = F::max(sum + bias[index], F::new(0.0)); } ``` Now, let's move on to the next step, which involves implementing the remaining code to launch the kernel. We'll go into implementing our custom backend trait for the generic JIT backend. This automatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion. ```rust, ignore /// Implement our custom backend trait for the generic `CubeBackend`. impl Backend for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor { // Define cube dim, hardcoded for simplicity. let cube_dim = CubeDim { x: 16, y: 16, z: 1 }; lhs.assert_is_on_same_device(&rhs); lhs.assert_is_on_same_device(&bias); // For simplicity, make sure each tensor is continuous. let lhs = into_contiguous(lhs); let rhs = into_contiguous(rhs); let bias = into_contiguous(bias); // Get the matmul relevant shapes. let ndims = lhs.shape.num_dims(); let num_rows = lhs.shape[ndims - 2]; let num_cols = rhs.shape[ndims - 1]; // Compute shape of output, while tracking number of batches. let mut num_batches = 1; let mut shape_out = vec![0; ndims]; for i in shape_out.clone().into_iter().take(ndims - 2) { shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]); num_batches *= shape_out[i]; } shape_out[ndims - 2] = num_rows; shape_out[ndims - 1] = num_cols; let shape_out = Shape::from(shape_out); // Create a buffer for the output tensor. let buffer = lhs .client .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. let output = CubeTensor::new_contiguous( lhs.client.clone(), lhs.device.clone(), shape_out, buffer, F::dtype(), ); // Declare the wgsl workgroup with the number of cubes in x, y and z. let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32; let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32; let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32); // Execute lazily the kernel with the launch information and the given buffers. For // simplicity, no vectorization is performed fused_matmul_add_relu_kernel::launch::( &lhs.client, cube_count, cube_dim, lhs.into_tensor_arg(), rhs.into_tensor_arg(), bias.into_tensor_arg(), output.clone().into_tensor_arg(), ); // Return the output tensor. output } } ``` In the preceding code block, we demonstrated how to launch the kernel that modifies the correct buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the capability to execute any mutable operation on any buffer. While this isn't a problem in the previous scenario where we only modify the newly created output buffer, it is wise to keep this in mind. ## Backward Now that the custom backend trait is implemented for the JIT backend, you can use it to invoke the `matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage. If your use case does not extend beyond inference, there is no need to implement any of the following code. For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is actually generic over the backend. Instead of crafting our own `cubecl` kernel for the backward pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic operations. ```rust, ignore // Implement our custom backend trait for any backend that also implements our custom backend trait. impl Backend for Autodiff { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor { // Create our zero-sized type that will implement the Backward trait. #[derive(Debug)] struct FusedMatmulAddReluBackward; // Implement the backward trait for the given backend B, the node gradient // with three other gradients to calculate (lhs, rhs, and bias). impl Backward for FusedMatmulAddReluBackward { // Our state that we must build during the forward pass to compute the backward pass. // // Note that we could improve the performance further by only keeping the state of // tensors that are tracked, improving memory management, but for simplicity, we avoid // that part. type State = (NodeId, NodeId, FloatTensor, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { // Get the nodes of each variable. let [node_lhs, node_rhs, node_bias] = ops.parents; // Fetch the gradient for the current node. let grad = grads.consume::(&ops.node); // Set our state. let (lhs_state, rhs_state, output, shape_bias) = ops.state; let lhs: FloatTensor = checkpointer.retrieve_node_output(lhs_state); let rhs: FloatTensor = checkpointer.retrieve_node_output(rhs_state); // Fetch shapes of our tensor to support broadcasting. let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); // Compute the gradient of the output using the already existing `relu_backward` // function in the basic Burn backend trait. let grad_output = B::relu_backward(output, grad); // Compute the lhs gradient, which is the derivative of matmul with support for // broadcasting. let grad_lhs = broadcast_shape::( B::float_matmul(grad_output.clone(), B::float_transpose(rhs)), &shape_lhs, ); // Compute the rhs gradient, which is the derivative of matmul with support for // broadcasting. let grad_rhs = broadcast_shape::( B::float_matmul(B::float_transpose(lhs), grad_output.clone()), &shape_rhs, ); // The add derivative is only 1, so we just need to support broadcasting to // compute the bias gradient. let grad_bias = broadcast_shape::(grad_output, &shape_bias); // Register the gradient for each variable based on whether they are marked as // `tracked`. if let Some(node) = node_bias { grads.register::(node.id, grad_bias); } if let Some(node) = node_lhs { grads.register::(node.id, grad_lhs); } if let Some(node) = node_rhs { grads.register::(node.id, grad_rhs); } } } // Prepare a stateful operation with each variable node and corresponding graph. // // Each node can be fetched with `ops.parents` in the same order as defined here. match FusedMatmulAddReluBackward .prepare::([lhs.node.clone(), rhs.node.clone(), bias.node.clone()]) // Marks the operation as compute bound, meaning it will save its // state instead of recomputing itself during checkpointing .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { // When at least one node is tracked, we should register our backward step. // The state consists of what will be needed for this operation's backward pass. // Since we need the parents' outputs, we must checkpoint their ids to retrieve // their node output at the beginning of the backward pass. We can also save // utility data such as the bias shape. If we also need this operation's output, // we can either save it in the state or recompute it. // during the backward pass. Here we choose to save it in the state because it's a // compute bound operation. let lhs_state = prep.checkpoint(&lhs); let rhs_state = prep.checkpoint(&rhs); let bias_shape = bias.primitive.shape(); let output = B::fused_matmul_add_relu( lhs.primitive.clone(), rhs.primitive.clone(), bias.primitive, ); let state = (lhs_state, rhs_state, output.clone(), bias_shape); prep.finish(state, output) } OpsKind::UnTracked(prep) => { // When no node is tracked, we can just compute the original operation without // keeping any state. let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); prep.finish(output) } } } } ``` The previous code is self-documented to make it clearer, but here is what it does in summary: We define `fused_matmul_add_relu` within `Autodiff`, allowing any autodiff-decorated backend to benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be implemented. This is achieved using a comprehensive match statement block where computation is delegated to the inner backend, while keeping track of a state. The state comprises any information relevant to the backward pass, such as input and output tensors, along with the bias shape. When an operation isn't tracked (meaning there won't be a backward pass for this specific operation in the graph), storing a state becomes unnecessary, and we simply perform the forward computation. The backward pass uses the gradient obtained from the preceding node in the computation graph. It calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in gradients for both input tensors and the bias, which are registered for consumption by subsequent operation nodes. The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend. ```rust, ignore impl AutodiffBackend for Autodiff> { } ``` ## Conclusion In this guide, we've implemented a fused kernel using the `cubecl` compiler frontend, enabling execution on any GPU and any `cubecl` backend. By delving into the inner workings of both the JIT backend and the autodiff backend, we've gained a deeper understanding of these systems. While extending a backend may be harder than working with straightforward tensors, the benefits can be worth it. This approach enables the crafting of custom models with greater control over execution, which can potentially greatly enhance the performance of your models. As we conclude this guide, we hope that you have gained insights into Burn's world of backend extensions, and that it will help you to unleash the full potential of your projects. ================================================ FILE: burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md ================================================ # Custom WGPU Kernel In this section, you will learn how to create your own custom operation by writing your own kernel with the WGPU backend. We will take the example of a common workflow in the deep learning field, where we create a kernel to fuse multiple operations together. Note that `burn` does this automatically, but a manual implementation might be more efficient in some cases. We will fuse a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in various models. All the code can be found under the [examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-wgpu-kernel). ## Custom Backend Trait First, we need to determine the type signature of our newly created operation by defining our custom backend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which encapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid the ugly disambiguation with associated types. ```rust, ignore /// We create our own Backend trait that extends the Burn backend trait. pub trait Backend: burn::tensor::backend::Backend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor; } /// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait. pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {} ``` In our project, we can use these traits instead of the `burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs typically make use of the `Tensor` struct rather than dealing directly with primitive tensor types. Therefore, we can encapsulate our newly defined backend traits with functions that expose new operations while maintaining a consistent API. ```rust, ignore /// We define our custom implementation using the added function on our custom backend. pub fn matmul_add_relu_custom( lhs: Tensor, rhs: Tensor, bias: Tensor, ) -> Tensor { let output = B::fused_matmul_add_relu( lhs.into_primitive().tensor(), rhs.into_primitive().tensor(), bias.into_primitive().tensor(), ); Tensor::from_primitive(TensorPrimitive::Float(output)) } /// We define a reference implementation using basic tensor operations. pub fn matmul_add_relu_reference( lhs: Tensor, rhs: Tensor, bias: Tensor, ) -> Tensor { let x = lhs.matmul(rhs) + bias; activation::relu(x) } ``` Note that we also provide a reference implementation for testing purposes, which allows us to easily validate our new implementation. While not mandatory, having a reference implementation can be valuable, especially in projects where creating a reference implementation solely using basic tensor operations is feasible. ## Forward Kernel Now, let's proceed to write the fused kernel using the WGSL shading language. To keep things simple, we'll create a straightforward matmul kernel without employing any intricate techniques. Although we won't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we still provide the implementation below for readers who are curious. The actual matmul, add and relu computations are found at the end, after an extensive overhead whose use is to correctly map each compute unit to the data it is responsible of, with support for batches. ```wgsl, ignore @group(0) @binding(0) var lhs: array<{{ elem }}>; @group(0) @binding(1) var rhs: array<{{ elem }}>; @group(0) @binding(2) var bias: array<{{ elem }}>; @group(0) @binding(3) var output: array<{{ elem }}>; @group(0) @binding(4) var info: array; const BLOCK_SIZE = {{ workgroup_size_x }}u; @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) fn main( @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_index) local_idx: u32, @builtin(workgroup_id) workgroup_id: vec3, ) { // Indices let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE); let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE); let batch = global_id.z; // Basic information let dim = info[0]; let n_rows = info[6u * dim - 1u]; let n_cols = info[6u * dim]; let K = info[5u * dim - 1u]; // Returns if outside the output dimension if row >= n_rows || col >= n_cols { return; } // Calculate the corresponding offsets with support for broadcasting. let offset_output = batch * n_rows * n_cols; var offset_lhs: u32 = 0u; var offset_rhs: u32 = 0u; let batch_dims = dim - 2u; for (var b: u32 = 1u; b <= batch_dims; b++) { let stride_lhs = info[b]; let stride_rhs = info[b + dim]; let stride_output = info[b + 2u * dim]; let shape_lhs = info[b + 3u * dim]; let shape_rhs = info[b + 4u * dim]; offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs; offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs; } // Basic matmul implementation var sum = 0.0; for (var k: u32 = 0u; k < K; k++) { let lhs_index = row * K + k; let rhs_index = k * n_cols + col; sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index]; } let output_index = row * n_cols + col; let index = offset_output + output_index; // Add and ReLU output[index] = max(sum + bias[index], 0.0); } ``` Now, let's move on to the next step, which involves implementing the remaining code to launch the kernel. The initial part entails loading the template and populating it with the appropriate variables. The `register(name, value)` method simply replaces occurrences of `{{ name }}` in the above WGSL code with some other string before it is compiled. In order to use templating utilities, you will have to activate the `template` feature of Burn in your `cargo.toml`. ```rust, ignore // Source the kernel written in WGSL. kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl"); // Define our kernel type with cube information. #[derive(new, Debug)] struct FusedMatmulAddRelu { cube_dim: CubeDim, _elem: PhantomData, } // Implement the dynamic kernel trait for our kernel type. impl KernelSource for FusedMatmulAddRelu { fn source(&self) -> SourceTemplate { // Extend our raw kernel with cube size information using the // `SourceTemplate` trait. FusedMatmulAddReluRaw::new() .source() .register("workgroup_size_x", self.cube_dim.x.to_string()) .register("workgroup_size_y", self.cube_dim.y.to_string()) .register("elem", E::type_name()) .register("int", "i32") } fn id(&self) -> cubecl::KernelId { cubecl::KernelId::new::().info(self.cube_dim) } } ``` Subsequently, we'll go into implementing our custom backend trait for the WGPU backend. Note that we won't go into supporting the `fusion` feature flag in this tutorial, so we implement the trait for the raw `WgpuBackend` type. ```rust, ignore /// Implement our custom backend trait for the existing backend `WgpuBackend`. impl Backend for CubeBackend { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor { // Define cube dim, hardcoded for simplicity. let cube_dim = CubeDim { x: 16, y: 16, z: 1 }; lhs.assert_is_on_same_device(&rhs); lhs.assert_is_on_same_device(&bias); // For simplicity, make sure each tensor is continuous. let lhs = into_contiguous(lhs); let rhs = into_contiguous(rhs); let bias = into_contiguous(bias); // Get the matmul relevant shapes. let ndims = lhs.shape.num_dims(); let num_rows = lhs.shape[ndims - 2]; let num_cols = rhs.shape[ndims - 1]; // Compute shape of output, while tracking number of batches. let mut num_batches = 1; let mut shape_out = vec![0; ndims]; for i in shape_out.clone().into_iter().take(ndims - 2) { shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]); num_batches *= shape_out[i]; } shape_out[ndims - 2] = num_rows; shape_out[ndims - 1] = num_cols; let shape_out = Shape::from(shape_out); // Create a buffer for the output tensor. let buffer = lhs .client .empty(shape_out.num_elements() * core::mem::size_of::()); // Create the output tensor primitive. let output = CubeTensor::new_contiguous( lhs.client.clone(), lhs.device.clone(), shape_out, buffer, F::dtype(), ); // Create the kernel. let kernel = FusedMatmulAddRelu::::new(cube_dim); // Build info buffer with tensor information needed by the kernel, such as shapes and strides. let info = build_info::<_, F>(&[&lhs, &rhs, &output]); let info_handle = lhs.client.create(bytemuck::cast_slice(&info)); // Declare the wgsl workgroup with the number of cubes in x, y and z. let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32; let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32; let cube_count = CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32); // Execute lazily the kernel with the launch information and the given buffers. lhs.client.execute( Box::new(SourceKernel::new(kernel, cube_dim)), cube_count, Bindings::new().with_buffers(vec![ lhs.handle.binding(), rhs.handle.binding(), bias.handle.binding(), output.handle.clone().binding(), info_handle.binding(), ]), ); // Return the output tensor. output } } ``` In the preceding code block, we demonstrated how to launch the kernel that modifies the correct buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the capability to execute any mutable operation on any buffer. While this isn't a problem in the previous scenario where we only modify the newly created output buffer, it is wise to keep this in mind. ## Backward Now that the custom backend trait is implemented for the WGPU backend, you can use it to invoke the `matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage. If your use case does not extend beyond inference, there is no need to implement any of the following code. For the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is actually generic over the backend. Instead of crafting our own WGSL kernel for the backward pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic operations. ```rust, ignore // Implement our custom backend trait for any backend that also implements our custom backend trait. // // Note that we could implement the backend trait only for the Wgpu backend instead of any backend that // also implements our own API. This would allow us to call any function only implemented for Wgpu // and potentially call a custom kernel crafted only for this task. impl Backend for Autodiff { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor, bias: FloatTensor, ) -> FloatTensor { // Create our zero-sized type that will implement the Backward trait. #[derive(Debug)] struct FusedMatmulAddReluBackward; // Implement the backward trait for the given backend B, the node gradient // with three other gradients to calculate (lhs, rhs, and bias). impl Backward for FusedMatmulAddReluBackward { // Our state that we must build during the forward pass to compute the backward pass. // // Note that we could improve the performance further by only keeping the state of // tensors that are tracked, improving memory management, but for simplicity, we avoid // that part. type State = (NodeId, NodeId, FloatTensor, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { // Get the nodes of each variable. let [node_lhs, node_rhs, node_bias] = ops.parents; // Fetch the gradient for the current node. let grad = grads.consume::(&ops.node); // Set our state. let (lhs_state, rhs_state, output, shape_bias) = ops.state; let lhs: FloatTensor = checkpointer.retrieve_node_output(lhs_state); let rhs: FloatTensor = checkpointer.retrieve_node_output(rhs_state); // Fetch shapes of our tensor to support broadcasting. let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); // Compute the gradient of the output using the already existing `relu_backward` // function in the basic Burn backend trait. let grad_output = B::relu_backward(output, grad); // Compute the lhs gradient, which is the derivative of matmul with support for // broadcasting. let grad_lhs = broadcast_shape::( B::float_matmul(grad_output.clone(), B::float_transpose(rhs)), &shape_lhs, ); // Compute the rhs gradient, which is the derivative of matmul with support for // broadcasting. let grad_rhs = broadcast_shape::( B::float_matmul(B::float_transpose(lhs), grad_output.clone()), &shape_rhs, ); // The add derivative is only 1, so we just need to support broadcasting to // compute the bias gradient. let grad_bias = broadcast_shape::(grad_output, &shape_bias); // Register the gradient for each variable based on whether they are marked as // `tracked`. if let Some(node) = node_bias { grads.register::(node.id, grad_bias); } if let Some(node) = node_lhs { grads.register::(node.id, grad_lhs); } if let Some(node) = node_rhs { grads.register::(node.id, grad_rhs); } } } // Prepare a stateful operation with each variable node and corresponding graph. // // Each node can be fetched with `ops.parents` in the same order as defined here. match FusedMatmulAddReluBackward .prepare::([lhs.node.clone(), rhs.node.clone(), bias.node.clone()]) // Marks the operation as compute bound, meaning it will save its // state instead of recomputing itself during checkpointing .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { // When at least one node is tracked, we should register our backward step. // The state consists of what will be needed for this operation's backward pass. // Since we need the parents' outputs, we must checkpoint their ids to retrieve their node // output at the beginning of the backward. We can also save utility data such as the bias shape // If we also need this operation's output, we can either save it in the state or recompute it // during the backward pass. Here we choose to save it in the state because it's a compute bound operation. let lhs_state = prep.checkpoint(&lhs); let rhs_state = prep.checkpoint(&rhs); let bias_shape = bias.primitive.shape(); let output = B::fused_matmul_add_relu( lhs.primitive.clone(), rhs.primitive.clone(), bias.primitive, ); let state = (lhs_state, rhs_state, output.clone(), bias_shape); prep.finish(state, output) } OpsKind::UnTracked(prep) => { // When no node is tracked, we can just compute the original operation without // keeping any state. let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive); prep.finish(output) } } } } ``` The previous code is self-documented to make it clearer, but here is what it does in summary. We define `fused_matmul_add_relu` within `Autodiff`, allowing any autodiff-decorated backend to benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be implemented. This is achieved using a comprehensive match statement block where computation is delegated to the inner backend, while keeping track of a state. The state comprises any information relevant to the backward pass, such as input and output tensors, along with the bias shape. When an operation isn't tracked (meaning there won't be a backward pass for this specific operation in the graph), storing a state becomes unnecessary, and we simply perform the forward computation. The backward pass uses the gradient obtained from the preceding node in the computation graph. It calculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the derivative is one), and `matmul` (another `matmul` with transposed inputs). This results in gradients for both input tensors and the bias, which are registered for consumption by subsequent operation nodes. The only remaining part is to implement our autodiff-decorated backend trait for our WGPU Backend. ```rust, ignore impl AutodiffBackend for Autodiff> { } ``` ## Conclusion In this guide, we've implemented a fused kernel using the WGPU backend, enabling execution on any GPU. By delving into the inner workings of both the WGPU backend and the autodiff backend, we've gained a deeper understanding of these systems. While extending a backend may be harder than working with straightforward tensors, the benefits can be worth it. This approach enables the crafting of custom models with greater control over execution, which can potentially greatly enhance the performance of your models. As we conclude this guide, we hope that you have gained insights into Burn's world of backend extensions, and that it will help you to unleash the full potential of your projects. ================================================ FILE: burn-book/src/advanced/no-std.md ================================================ # No Standard Library In this section, you will learn how to run an ONNX inference model on an embedded system, with no standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other platforms. All the code can be found in the [burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico). ## Step-by-Step Guide Let's walk through the process of running an embedded ONNX model: ### Setup Follow the [embassy guide](https://embassy.dev/book/#_getting_started) for your specific environment. Once setup, you should have something similar to the following. ``` ./inference ├── Cargo.lock ├── Cargo.toml ├── build.rs ├── memory.x └── src └── main.rs ``` Some other dependencies have to be added ```toml [dependencies] embedded-alloc = "0.6.0" # Only if there is no default allocator for your chip burn = { version = "0.21", default-features = false, features = ["ndarray"] } # Backend must be ndarray burn-store = { version = "0.21", default-features = false, features = ["burnpack"] } [build-dependencies] burn-onnx = { version = "0.21" } # Used to auto generate the rust code to import the model ``` ### Import the Model Follow the directions in [ONNX Import](../onnx-import.md). Use the following ModelGen config ```rs ModelGen::new() .input(my_model) .out_dir("model/") .embed_states(true) .run_from_script(); ``` ### Global Allocator First define a global allocator (if you are on a no_std system without alloc). ```rs use embedded_alloc::LlffHeap as Heap; #[global_allocator] static HEAP: Heap = Heap::empty(); #[embassy_executor::main] async fn main(_spawner: Spawner) { { use core::mem::MaybeUninit; // Watch out for this, if it is too big or small for your model, the // program may crash. This is in u8 bytes, as such this is a total of 100kb const HEAP_SIZE: usize = 100 * 1024; static mut HEAP_MEM: [MaybeUninit; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE]; unsafe { HEAP.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) } // Initialize the heap } } ``` ### Define Backend We are using ndarray, so we just need to define the NdArray backend as usual ```rs use burn::{backend::NdArray, tensor::Tensor}; type Backend = NdArray; type BackendDevice = ::Device; ``` Then inside the `main` function add ```rs use your_model::Model; // Get a default device for the backend let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); ``` ### Running the Model To run the model, just call it as you would normally ```rs // Define the tensor let input = Tensor::::from_floats([[input]], &device); // Run the model on the input let output = model.forward(input); ``` ## Conclusion Running a model in a no_std environment is pretty much identical to a normal environment. All that is needed is a global allocator. ================================================ FILE: burn-book/src/advanced/web-assembly.md ================================================ # WebAssembly Burn supports WebAssembly (WASM) execution using the `NdArray` and `WebGpu` backends, allowing models to run directly in the browser. Check out the following examples: - [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) - [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web) When targeting WebAssembly, certain dependencies require additional configuration. In particular, the `getrandom` crate requires explicit setting when using `WebGpu`. ================================================ FILE: burn-book/src/basic-workflow/README.md ================================================ # Guide This guide will walk you through the process of creating a custom model built with Burn. We will train a simple convolutional neural network model on the MNIST dataset and prepare it for inference. For clarity, we sometimes omit imports in our code snippets. For more details, please refer to the corresponding code in the `examples/guide` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide). We reproduce this example in a step-by-step fashion, from dataset creation to modeling and training in the following sections. It is recommended to use the capabilities of your IDE or text editor to automatically add the missing imports as you add the code snippets to your code.
Be sure to checkout the git branch corresponding to the version of Burn you are using to follow this guide. The current version of Burn is `0.21` and the corresponding branch to checkout is `main`.
The code for this demo can be executed from Burn's base directory using the command: ```bash cargo run --example guide ``` ## Key Learnings - Creating a project - Creating neural network models - Importing and preparing datasets - Training models on data - Choosing a backend - Using a model for inference ================================================ FILE: burn-book/src/basic-workflow/backend.md ================================================ # Backend We have effectively written most of the necessary code to train our model. However, we have not explicitly designated the backend to be used at any point. This will be defined in the main entrypoint of our program, namely the `main` function defined in `src/main.rs`. ```rust , ignore # #![recursion_limit = "256"] # mod data; # mod model; # mod training; # use crate::{model::ModelConfig, training::TrainingConfig}; use burn::{ backend::{Autodiff, Wgpu}, # data::dataset::Dataset, optim::AdamConfig, }; fn main() { type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; let device = burn::backend::wgpu::WgpuDevice::default(); let artifact_dir = "/tmp/guide"; crate::training::train::( artifact_dir, TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); } ``` In this code snippet, we use the `Wgpu` backend which is compatible with any operating system and will use the GPU. For other options, see the Burn README. This backend type takes the graphics API, the float type and the int type as generic arguments that will be used during the training. The autodiff backend is simply the same backend, wrapped within the `Autodiff` struct which imparts differentiability to any backend. We call the `train` function defined earlier with a directory for artifacts, the configuration of the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer configuration which in our case will be the default Adam configuration, and the device which can be obtained from the backend. You can now train your freshly created model with the command: ```console cargo run --release ``` When running your project with the command above, you should see the training progression through a basic CLI dashboard: Alt text ================================================ FILE: burn-book/src/basic-workflow/data.md ================================================ # Data Typically, one trains a model on some dataset. Burn provides a library of very useful dataset sources and transformations, such as Hugging Face dataset utilities that allow to download and store data into an SQLite database for extremely efficient data streaming and storage. For this guide though, we will use the MNIST dataset from `burn::data::dataset::vision` which requires no external dependency. To iterate over a dataset efficiently, we will define a struct which will implement the `Batcher` trait. The goal of a batcher is to map individual dataset items into a batched tensor that can be used as input to our previously defined model. Let us start by defining our dataset functionalities in a file `src/data.rs`. We shall omit some of the imports for brevity, but the full code for following this guide can be found at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide). ```rust , ignore use burn::{ data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, prelude::*, }; #[derive(Clone, Default)] pub struct MnistBatcher {} ``` This batcher is pretty straightforward, as it only defines a struct that will implement the `Batcher` trait. The trait is generic over the `Backend` trait, which includes an associated type for the device, as not all backends expose the same devices. As an example, the Libtorch-based backend exposes `Cuda(gpu_index)`, `Cpu`, `Vulkan` and `Metal` devices, while the ndarray backend only exposes the `Cpu` device. Next, we need to actually implement the batching logic. ```rust , ignore # use burn::{ # data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, # prelude::*, # }; # # #[derive(Clone, Default)] # pub struct MnistBatcher {} # #[derive(Clone, Debug)] pub struct MnistBatch { pub images: Tensor, pub targets: Tensor, } impl Batcher> for MnistBatcher { fn batch(&self, items: Vec, device: &B::Device) -> MnistBatch { let images = items .iter() .map(|item| TensorData::from(item.image).convert::()) .map(|data| Tensor::::from_data(data, device)) .map(|tensor| tensor.reshape([1, 28, 28])) // Normalize: scale between [0,1] and make the mean=0 and std=1 // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) .collect(); let targets = items .iter() .map(|item| { Tensor::::from_data([(item.label as i64).elem::()], device) }) .collect(); let images = Tensor::cat(images, 0); let targets = Tensor::cat(targets, 0); MnistBatch { images, targets } } } ```
🦀 Iterators and Closures The iterator pattern allows you to perform some tasks on a sequence of items in turn. In this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the `iter` method. _Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by changing some aspect of the original iterator. Here, the `map` method is called in a chain to transform the original data before consuming the final iterator with `collect` to obtain the `images` and `targets` vectors. Both vectors are then concatenated into a single tensor for the current batch. You probably noticed that each call to `map` is different, as it defines a function to execute on the iterator items at each step. These anonymous functions are called [_closures_](https://doc.rust-lang.org/book/ch13-01-closures.html) in Rust. They're easy to recognize due to their syntax which uses vertical bars `||`. The vertical bars capture the input variables (if applicable) while the rest of the expression defines the function to execute. If we go back to the example, we can break down and comment the expression used to process the images. ```rust, ignore let images = items // take items Vec .iter() // create an iterator over it .map(|item| TensorData::from(item.image).convert::()) // for each item, convert the image to float data struct .map(|data| Tensor::::from_data(data, device)) // for each data struct, create a tensor on the device .map(|tensor| tensor.reshape([1, 28, 28])) // for each tensor, reshape to the image dimensions [C, H, W] .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) // for each image tensor, apply normalization .collect(); // consume the resulting iterator & collect the values into a new vector ``` For more information on iterators and closures, be sure to check out the [corresponding chapter](https://doc.rust-lang.org/book/ch13-00-functional-features.html) in the Rust Book.

In the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a single `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with a targets tensor that contains the indexes of the correct digit class. The first step is to parse the image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate tensor storage information without being specific for a backend. When creating a tensor from data, we often need to convert the data precision to the current backend in use. This can be done with the `.convert()` method (in this example, the data is converted backend's float element type `B::FloatElem`). While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()` on a specific number to convert it to the current backend element type in use. ================================================ FILE: burn-book/src/basic-workflow/inference.md ================================================ # Inference Now that we have trained our model, the next natural step is to use it for inference. You need two things in order to load weights for a model: the model's record and the model's config. Since parameters in Burn are lazy initialized, no allocation and GPU/CPU kernels are executed by the `ModelConfig::init` function. The weights are initialized when used for the first time, therefore you can safely use `config.init(device).load_record(record)` without any meaningful performance cost. Let's create a simple `infer` method in a new file `src/inference.rs` which we will use to load our trained model. ```rust , ignore # use crate::{data::MnistBatcher, training::TrainingConfig}; # use burn::{ # data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, # prelude::*, # record::{CompactRecorder, Recorder}, # }; # pub fn infer(artifact_dir: &str, device: B::Device, item: MnistItem) { let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) .expect("Config should exist for the model; run train first"); let record = CompactRecorder::new() .load(format!("{artifact_dir}/model").into(), &device) .expect("Trained model should exist; run train first"); let model = config.model.init::(&device).load_record(record); let label = item.label; let batcher = MnistBatcher::default(); let batch = batcher.batch(vec![item], &device); let output = model.forward(batch.images); let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar(); println!("Predicted {predicted} Expected {label}"); } ``` The first step is to load the configuration of the training to fetch the correct model configuration. Then we can fetch the record using the same recorder as we used during training. Finally we can init the model with the configuration and the record. For simplicity we can use the same batcher used during the training to pass from a MnistItem to a tensor. By running the infer function, you should see the predictions of your model! Add the call to `infer` to the `main.rs` file after the `train` function call: ```rust , ignore # #![recursion_limit = "256"] # mod data; # mod inference; # mod model; # mod training; # # use crate::{model::ModelConfig, training::TrainingConfig}; # use burn::{ # backend::{Autodiff, Wgpu}, # data::dataset::Dataset, # optim::AdamConfig, # }; # # fn main() { # type MyBackend = Wgpu; # type MyAutodiffBackend = Autodiff; # # let device = burn::backend::wgpu::WgpuDevice::default(); # let artifact_dir = "/tmp/guide"; # crate::training::train::( # artifact_dir, # TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()), # device.clone(), # ); crate::inference::infer::( artifact_dir, device, burn::data::dataset::vision::MnistDataset::test() .get(42) .unwrap(), ); # } ``` The number `42` is the index of the image in the MNIST dataset. You can explore and verify them using this [MNIST viewer](https://observablehq.com/@davidalber/mnist-viewer). --- In this short guide, we've introduced you to the fundamental building blocks for getting started with Burn. While there's still plenty to explore, our goal has been to provide you with the essential knowledge to kickstart your productivity within the framework. ================================================ FILE: burn-book/src/basic-workflow/model.md ================================================ # Model The first step is to create a project and add the different Burn dependencies. Start by creating a new project with Cargo: ```console cargo new guide ``` As [mentioned previously](../getting-started.md#creating-a-burn-application), this will initialize your `guide` project directory with a `Cargo.toml` and a `src/main.rs` file. In the `Cargo.toml` file, add the `burn` dependency with `train`, `vision` and `wgpu` features. Since we disable the default features, we also want to enable `std`, `tui` (for the dashboard) and `fusion` for wgpu. Then run `cargo build` to build the project and import all the dependencies. ```toml [package] name = "guide" version = "0.1.0" edition = "2024" [dependencies] # Disable autotune default for convolutions burn = { version = "~0.21", features = ["std", "tui", "train", "vision", "wgpu", "fusion"], default-features = false } # burn = { version = "~0.21", features = ["train", "vision", "wgpu"] } ``` Our goal will be to create a basic convolutional neural network used for image classification. We will keep the model simple by using two convolution layers followed by two linear layers, some pooling and ReLU activations. We will also use dropout to improve training performance. Let us start by defining our model struct in a new file `src/model.rs`. ```rust , ignore use burn::{ nn::{ conv::{Conv2d, Conv2dConfig}, pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, Dropout, DropoutConfig, Linear, LinearConfig, Relu, }, prelude::*, }; #[derive(Module, Debug)] pub struct Model { conv1: Conv2d, conv2: Conv2d, pool: AdaptiveAvgPool2d, dropout: Dropout, linear1: Linear, linear2: Linear, activation: Relu, } ``` There are two major things going on in this code sample. 1. You can create a deep learning module with the `#[derive(Module)]` attribute on top of a struct. This will generate the necessary code so that the struct implements the `Module` trait. This trait will make your module both trainable and (de)serializable while adding related functionalities. Like other attributes often used in Rust, such as `Clone`, `PartialEq` or `Debug`, each field within the struct must also implement the `Module` trait.
🦀 Trait Traits are a powerful and flexible Rust language feature. They provide a way to define shared behavior for a particular type, which can be shared with other types. A type's behavior consists of the methods called on that type. Since all `Module`s should implement the same functionality, it is defined as a trait. Implementing a trait on a particular type usually requires the user to implement the defined behaviors of the trait for their types, though that is not the case here as explained above with the `derive` attribute. Check out the [explainer below](#derive-attribute) to learn why. For more details on traits, take a look at the [associated chapter](https://doc.rust-lang.org/book/ch10-02-traits.html) in the Rust Book.

🦀 Derive Macro The `derive` attribute allows traits to be implemented easily by generating code that will implement a trait with its own default implementation on the type that was annotated with the `derive` syntax. This is accomplished through a feature of Rust called [procedural macros](https://doc.rust-lang.org/reference/procedural-macros.html), which allow us to run code at compile time that operates over Rust syntax, both consuming and producing Rust syntax. Using the attribute `#[my_macro]`, you can effectively extend the provided code. You will see that the derive macro is very frequently employed to recursively implement traits, where the implementation consists of the composition of all fields. In this example, we want to derive the [`Module`](../building-blocks/module.md) and `Debug` traits. ```rust, ignore #[derive(Module, Debug)] pub struct MyCustomModule { linear1: Linear, linear2: Linear, activation: Relu, } ``` The basic `Debug` implementation is provided by the compiler to format a value using the `{:?}` formatter. For ease of use, the `Module` trait implementation is automatically handled by Burn so you don't have to do anything special. It essentially acts as parameter container. For more details on derivable traits, take a look at the Rust [appendix](https://doc.rust-lang.org/book/appendix-03-derivable-traits.html), [reference](https://doc.rust-lang.org/reference/attributes/derive.html) or [example](https://doc.rust-lang.org/rust-by-example/trait/derive.html).

2. Note that the struct is generic over the [`Backend`](../building-blocks/backend.md) trait. The backend trait abstracts the underlying low level implementations of tensor operations, allowing your new model to run on any backend. Contrary to other frameworks, the backend abstraction isn't determined by a compilation flag or a device type. This is important because you can extend the functionalities of a specific backend (see [backend extension section](../advanced/backend-extension)), and it allows for an innovative [autodiff system](../building-blocks/autodiff.md). You can also change backend during runtime, for instance to compute training metrics on a cpu backend while using a gpu one only to train the model. In our example, the backend in use will be determined later on.
🦀 Trait Bounds Trait bounds provide a way for generic items to restrict which types are used as their parameters. The trait bounds stipulate what functionality a type implements. Therefore, bounding restricts the generic to types that conform to the bounds. It also allows generic instances to access the methods of traits specified in the bounds. For a simple but concrete example, check out the [Rust By Example on bounds](https://doc.rust-lang.org/rust-by-example/generics/bounds.html). In Burn, the `Backend` trait enables you to run tensor operations using different implementations as it abstracts tensor, device and element types. The [getting started example](../getting-started.md#writing-a-code-snippet) illustrates the advantage of having a simple API that works for different backend implementations. While it used the WGPU backend, you could easily swap it with any other supported backend. ```rust, ignore // Choose from any of the supported backends. // type Backend = Candle; // type Backend = LibTorch; // type Backend = NdArray; type Backend = Wgpu; // Creation of two tensors. let tensor_1 = Tensor::::from_data([[2., 3.], [4., 5.]], &device); let tensor_2 = Tensor::::ones_like(&tensor_1); // Print the element-wise addition (done with the selected backend) of the two tensors. println!("{}", tensor_1 + tensor_2); ``` For more details on trait bounds, check out the Rust [trait bound section](https://doc.rust-lang.org/book/ch10-02-traits.html#trait-bound-syntax) or [reference](https://doc.rust-lang.org/reference/items/traits.html#trait-bounds).

Note that each time you create a new file in the `src` directory you also need to explicitly add this module to the `main.rs` file. For instance after creating the `model.rs`, you need to add the following at the top of the main file: ```rust , ignore mod model; # # fn main() { # } ``` Next, we need to instantiate the model for training. ```rust , ignore # use burn::{ # nn::{ # conv::{Conv2d, Conv2dConfig}, # pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, # Dropout, DropoutConfig, Linear, LinearConfig, Relu, # }, # prelude::*, # }; # # #[derive(Module, Debug)] # pub struct Model { # conv1: Conv2d, # conv2: Conv2d, # pool: AdaptiveAvgPool2d, # dropout: Dropout, # linear1: Linear, # linear2: Linear, # activation: Relu, # } # #[derive(Config, Debug)] pub struct ModelConfig { num_classes: usize, hidden_size: usize, #[config(default = "0.5")] dropout: f64, } impl ModelConfig { /// Returns the initialized model. pub fn init(&self, device: &B::Device) -> Model { Model { conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device), conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device), pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), activation: Relu::new(), linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device), linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device), dropout: DropoutConfig::new(self.dropout).init(), } } } ``` At a glance, you can view the model configuration by printing the model instance: ```rust , ignore #![recursion_limit = "256"] mod model; use crate::model::ModelConfig; use burn::backend::Wgpu; fn main() { type MyBackend = Wgpu; let device = Default::default(); let model = ModelConfig::new(10, 512).init::(&device); println!("{model}"); } ``` Output: ```rust , ignore Model { conv1: Conv2d {ch_in: 1, ch_out: 8, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80} conv2: Conv2d {ch_in: 8, ch_out: 16, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168} pool: AdaptiveAvgPool2d {output_size: [8, 8]} dropout: Dropout {prob: 0.5} linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800} linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130} activation: Relu params: 531178 } ```
🦀 References In the previous example, the `init()` method signature uses `&` to indicate that the parameter types are references: `&self`, a reference to the current receiver (`ModelConfig`), and `device: &B::Device`, a reference to the backend device. ```rust, ignore pub fn init(&self, device: &B::Device) -> Model { Model { // ... } } ``` References in Rust allow us to point to a resource to access its data without owning it. The idea of ownership is quite core to Rust and is worth [reading up on](https://doc.rust-lang.org/book/ch04-00-understanding-ownership.html). In a language like C, memory management is explicit and up to the programmer, which means it is easy to make mistakes. In a language like Java or Python, memory management is automatic with the help of a garbage collector. This is very safe and straightforward, but also incurs a runtime cost. In Rust, memory management is rather unique. Aside from simple types that implement [`Copy`](https://doc.rust-lang.org/std/marker/trait.Copy.html) (e.g., [primitives](https://doc.rust-lang.org/rust-by-example/primitives.html) like integers, floats, booleans and `char`), every value is _owned_ by some variable called the _owner_. Ownership can be transferred from one variable to another and sometimes a value can be _borrowed_. Once the _owner_ variable goes out of scope, the value is _dropped_, which means that any memory it allocated can be freed safely. Because the method does not own the `self` and `device` variables, the values the references point to will not be dropped when the reference stops being used (i.e., the scope of the method). For more information on references and borrowing, be sure to read the [corresponding chapter](https://doc.rust-lang.org/book/ch04-02-references-and-borrowing.html) in the Rust Book.

When creating a custom neural network module, it is often a good idea to create a config alongside the model struct. This allows you to define default values for your network, thanks to the `Config` attribute. The benefit of this attribute is that it makes the configuration serializable, enabling you to painlessly save your model hyperparameters, enhancing your experimentation process. Note that a constructor will automatically be generated for your configuration, which will take in as input values the parameters which do not have default values: `let config = ModelConfig::new(num_classes, hidden_size);`. The default values can be overridden easily with builder-like methods: (e.g `config.with_dropout(0.2);`) The first implementation block is related to the initialization method. As we can see, all fields are set using the configuration of the corresponding neural network's underlying module. In this specific case, we have chosen to expand the tensor channels from 1 to 8 with the first layer, then from 8 to 16 with the second layer, using a kernel size of 3 on all dimensions. We also use the adaptive average pooling module to reduce the dimensionality of the images to an 8 by 8 matrix, which we will flatten in the forward pass to have a 1024 (16 * 8 * 8) resulting tensor. Now let's see how the forward pass is defined. ```rust , ignore # use burn::{ # nn::{ # conv::{Conv2d, Conv2dConfig}, # pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, # Dropout, DropoutConfig, Linear, LinearConfig, Relu, # }, # prelude::*, # }; # # #[derive(Module, Debug)] # pub struct Model { # conv1: Conv2d, # conv2: Conv2d, # pool: AdaptiveAvgPool2d, # dropout: Dropout, # linear1: Linear, # linear2: Linear, # activation: Relu, # } # # #[derive(Config, Debug)] # pub struct ModelConfig { # num_classes: usize, # hidden_size: usize, # #[config(default = "0.5")] # dropout: f64, # } # # impl ModelConfig { # /// Returns the initialized model. # pub fn init(&self, device: &B::Device) -> Model { # Model { # conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device), # conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device), # pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), # activation: Relu::new(), # linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device), # linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device), # dropout: DropoutConfig::new(self.dropout).init(), # } # } # } # impl Model { /// # Shapes /// - Images [batch_size, height, width] /// - Output [batch_size, num_classes] pub fn forward(&self, images: Tensor) -> Tensor { let [batch_size, height, width] = images.dims(); // Create a channel at the second dimension. let x = images.reshape([batch_size, 1, height, width]); let x = self.conv1.forward(x); // [batch_size, 8, _, _] let x = self.dropout.forward(x); let x = self.conv2.forward(x); // [batch_size, 16, _, _] let x = self.dropout.forward(x); let x = self.activation.forward(x); let x = self.pool.forward(x); // [batch_size, 16, 8, 8] let x = x.reshape([batch_size, 16 * 8 * 8]); let x = self.linear1.forward(x); let x = self.dropout.forward(x); let x = self.activation.forward(x); self.linear2.forward(x) // [batch_size, num_classes] } } ``` For former PyTorch users, this might feel very intuitive, as each module is directly incorporated into the code using an eager API. Note that no abstraction is imposed for the forward method. You are free to define multiple forward functions with the names of your liking. Most of the neural network modules already built with Burn use the `forward` nomenclature, simply because it is standard in the field. Similar to neural network modules, the [`Tensor`](../building-blocks/tensor.md) struct given as a parameter also takes the Backend trait as a generic argument, alongside its dimensionality. Even if it is not used in this specific example, it is possible to add the kind of the tensor as a third generic argument. For example, a 3-dimensional Tensor of different data types(float, int, bool) would be defined as following: ```rust , ignore Tensor // Float tensor (default) Tensor // Float tensor (explicit) Tensor // Int tensor Tensor // Bool tensor ``` Note that the specific element type, such as `f16`, `f32` and the likes, will be defined later with the backend. ================================================ FILE: burn-book/src/basic-workflow/training.md ================================================ # Training We are now ready to write the necessary code to train our model on the MNIST dataset. We shall define the code for this training section in the file: `src/training.rs`. Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose responsibility is to apply an optimizer to the model. The output struct is used for all metrics calculated during the training. Therefore it should include all the necessary information to calculate any metric that you want for a task. Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement the necessary trait to be used with metrics. It is possible to create your own item, but it is beyond the scope of this guide. Since the MNIST task is a classification problem, we will use the `ClassificationOutput` type. ```rust , ignore # use crate::{ # data::{MnistBatch, MnistBatcher}, # model::{Model, ModelConfig}, # }; # use burn::{ # data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, # nn::loss::CrossEntropyLossConfig, # optim::AdamConfig, # prelude::*, # record::CompactRecorder, # tensor::backend::AutodiffBackend, # train::{ # ClassificationOutput, Learner, SupervisedTraining, TrainOutput, TrainStep, InferenceStep, # metric::{AccuracyMetric, LossMetric}, # }, # }; # impl Model { pub fn forward_classification( &self, images: Tensor, targets: Tensor, ) -> ClassificationOutput { let output = self.forward(images); let loss = CrossEntropyLossConfig::new() .init(&output.device()) .forward(output.clone(), targets.clone()); ClassificationOutput::new(loss, output, targets) } } ``` As evident from the preceding code block, we employ the cross-entropy loss module for loss calculation, without the inclusion of any padding token. We then return the classification output containing the loss, the output tensor with all logits and the targets. Please take note that tensor operations receive owned tensors as input. For reusing a tensor multiple times, you need to use the `clone()` function. There's no need to worry; this process won't involve actual copying of the tensor data. Instead, it will simply indicate that the tensor is employed in multiple instances, implying that certain operations won't be performed in place. In summary, our API has been designed with owned tensors to optimize performance. Moving forward, we will proceed with the implementation of both the training and validation steps for our model. ```rust , ignore # use crate::{ # data::{MnistBatch, MnistBatcher}, # model::{Model, ModelConfig}, # }; # use burn::{ # data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, # nn::loss::CrossEntropyLossConfig, # optim::AdamConfig, # prelude::*, # record::CompactRecorder, # tensor::backend::AutodiffBackend, # train::{ # ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep, # metric::{AccuracyMetric, LossMetric}, # }, # }; # # impl Model { # pub fn forward_classification( # &self, # images: Tensor, # targets: Tensor, # ) -> ClassificationOutput { # let output = self.forward(images); # let loss = CrossEntropyLossConfig::new() # .init(&output.device()) # .forward(output.clone(), targets.clone()); # # ClassificationOutput::new(loss, output, targets) # } # } impl TrainStep for Model { type Input = MnistBatch; type Output = ClassificationOutput; fn step(&self, batch: MnistBatch) -> TrainOutput> { let item = self.forward_classification(batch.images, batch.targets); TrainOutput::new(self, item.loss.backward(), item) } } impl InferenceStep for Model { type Input = MnistBatch; type Output = ClassificationOutput; fn step(&self, batch: MnistBatch) -> ClassificationOutput { self.forward_classification(batch.images, batch.targets) } } ``` Here we define the input and output types as generic arguments in the `TrainStep` and `InferenceStep`. We will call them `MnistBatch` and `ClassificationOutput`. In the training step, the computation of gradients is straightforward, necessitating a simple invocation of `backward()` on the loss. Note that contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather returned by the backward pass, as such: `let gradients = loss.backward();`. The gradient of a parameter can be obtained with the grad function: `let grad = tensor.grad(&gradients);`. Although it is not necessary when using the learner struct and the optimizers, it can prove to be quite useful when debugging or writing custom training loops. One of the differences between the training and the validation steps is that the former requires the backend to implement `AutodiffBackend` and not just `Backend`. Otherwise, the `backward` function is not available, as the backend does not support autodiff. We will see later how to create a backend with autodiff support.
🦀 Generic Type Constraints in Method Definitions Although generic data types, trait and trait bounds were already introduced in previous sections of this guide, the previous code snippet might be a lot to take in at first. In the example above, we implement the `TrainStep` and `InferenceStep` trait for our `Model` struct, which is generic over the `Backend` trait as has been covered before. These traits are provided by `burn::train` and define a common `step` method that should be implemented for all structs. Since the trait is generic over the input and output types, the trait implementation must specify the concrete types used. This is where the additional type constraints appear `, ClassificationOutput>`. As we saw previously, the concrete input type for the batch is `MnistBatch`, and the output of the forward pass is `ClassificationOutput`. The `step` method signature matches the concrete input and output types. For more details specific to constraints on generic types when defining methods, take a look at [this section](https://doc.rust-lang.org/book/ch10-01-syntax.html#in-method-definitions) of the Rust Book.

Let us move on to establishing the practical training configuration. ```rust , ignore # use crate::{ # data::{MnistBatch, MnistBatcher}, # model::{Model, ModelConfig}, # }; # use burn::{ # data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, # nn::loss::CrossEntropyLossConfig, # optim::AdamConfig, # prelude::*, # record::CompactRecorder, # tensor::backend::AutodiffBackend, # train::{ # ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep, # metric::{AccuracyMetric, LossMetric}, # }, # }; # # impl Model { # pub fn forward_classification( # &self, # images: Tensor, # targets: Tensor, # ) -> ClassificationOutput { # let output = self.forward(images); # let loss = CrossEntropyLossConfig::new() # .init(&output.device()) # .forward(output.clone(), targets.clone()); # # ClassificationOutput::new(loss, output, targets) # } # } # impl TrainStep for Model { # type Input = MnistBatch; # type Output = ClassificationOutput; # # fn step(&self, batch: MnistBatch) -> TrainOutput> { # let item = self.forward_classification(batch.images, batch.targets); # # TrainOutput::new(self, item.loss.backward(), item) # } # } # # impl InferenceStep for Model { # type Input = MnistBatch; # type Output = ClassificationOutput; # # fn step(&self, batch: MnistBatch) -> ClassificationOutput { # self.forward_classification(batch.images, batch.targets) # } # } # #[derive(Config, Debug)] pub struct TrainingConfig { pub model: ModelConfig, pub optimizer: AdamConfig, #[config(default = 10)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 4)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, #[config(default = 1.0e-4)] pub learning_rate: f64, } fn create_artifact_dir(artifact_dir: &str) { // Remove existing artifacts before to get an accurate learner summary std::fs::remove_dir_all(artifact_dir).ok(); std::fs::create_dir_all(artifact_dir).ok(); } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { create_artifact_dir(artifact_dir); config .save(format!("{artifact_dir}/config.json")) .expect("Config should be saved successfully"); B::seed(&device, config.seed); let batcher = MnistBatcher::default(); let dataloader_train = DataLoaderBuilder::new(batcher.clone()) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::train()); let dataloader_test = DataLoaderBuilder::new(batcher) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::test()); let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) .metrics((AccuracyMetric::new(), LossMetric::new())) .with_file_checkpointer(CompactRecorder::new()) .num_epochs(config.num_epochs) .summary(); let model = config.model.init::(&device); let result = training.launch(Learner::new( model, config.optimizer.init(), config.learning_rate, )); result .model .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) .expect("Trained model should be saved successfully"); } ``` It is a good practice to use the `Config` derive to create the experiment configuration. In the `train` function, the first thing we are doing is making sure the `artifact_dir` exists, using the standard rust library for file manipulation. All checkpoints, logging and metrics will be stored under this directory. We initialize the dataloaders using the previously created batcher. Since no automatic differentiation is needed during the validation phase, the `training.launch(...)` method defines the necessary backend bounds on the data loader for `B::InnerBackend` (see [Backend](./backend.md)). The autodiff capabilities are available through a type system, making it nearly impossible to forget to deactivate gradient calculation. Next, we create a supervised training runner with the dataloaders for training and validation and we register the accuracy and loss metric on both training and validation steps. We also configure the checkpointer using the `CompactRecorder` to indicate how weights should be stored. This struct implements the `Recorder` trait, which makes it capable of saving records for persistency. For the sake of simplicity in this example, we employ the test set as the validation set; however, we do not recommend this practice for actual usage. We create the learner containing the model, the optimizer and the learning rate. Notably, the third argument of the learner's `new` function should actually be a learning rate _scheduler_. When provided with a float as in our example, it is automatically transformed into a _constant_ learning rate scheduler. The learning rate is not part of the optimizer config as it is often done in other frameworks, but rather passed as a parameter when executing the optimizer step. This avoids having to mutate the state of the optimizer and is therefore more functional. It makes no difference when using the learner struct, but it will be an essential nuance to grasp if you implement your own training loop. Once the learner and supervised training instance are created, we can call `training.launch` and provide the learner. Finally, the trained model is returned by the `launch` method. The trained weights are then saved using the `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for floats and `i16` for integers. Other recorders are available, offering support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can load recorded data of any kind. ================================================ FILE: burn-book/src/building-blocks/README.md ================================================ # Building Blocks In this section, we'll guide you through the core elements that make up Burn. We'll walk you through the key components that serve as the building blocks of the framework and your future projects. As you explore Burn, you might notice that we occasionally draw comparisons to PyTorch. We believe it can provide a smoother learning curve and help you grasp the nuances more effectively. ================================================ FILE: burn-book/src/building-blocks/autodiff.md ================================================ # Autodiff Burn's tensor also supports auto-differentiation, which is an essential part of any deep learning framework. We introduced the `Backend` trait in the [previous section](./backend.md), but Burn also has another trait for autodiff: `AutodiffBackend`. However, not all tensors support auto-differentiation; you need a backend that implements both the `Backend` and `AutodiffBackend` traits. Fortunately, you can add auto-differentiation capabilities to any backend using a backend decorator: `type MyAutodiffBackend = Autodiff`. This decorator implements both the `AutodiffBackend` and `Backend` traits by maintaining a dynamic computational graph and utilizing the inner backend to execute tensor operations. The `AutodiffBackend` trait adds new operations on float tensors that can't be called otherwise. It also provides a new associated type, `B::Gradients`, where each calculated gradient resides. ```rust, ignore fn calculate_gradients(tensor: Tensor) -> B::Gradients { let mut gradients = tensor.clone().backward(); let tensor_grad = tensor.grad(&gradients); // get let tensor_grad = tensor.grad_remove(&mut gradients); // pop gradients } ``` Note that some functions will always be available even if the backend doesn't implement the `AutodiffBackend` trait. In such cases, those functions will do nothing. | Burn API | PyTorch Equivalent | | --------------------------------------- | ----------------------------- | | `tensor.detach()` | `tensor.detach()` | | `tensor.require_grad()` | `tensor.requires_grad()` | | `tensor.is_require_grad()` | `tensor.requires_grad` | | `tensor.set_require_grad(require_grad)` | `tensor.requires_grad(False)` | However, you're unlikely to make any mistakes since you can't call `backward` on a tensor that is on a backend that doesn't implement `AutodiffBackend`. Additionally, you can't retrieve the gradient of a tensor without an autodiff backend. ## Difference with PyTorch The way Burn handles gradients is different from PyTorch. First, when calling `backward`, each parameter doesn't have its `grad` field updated. Instead, the backward pass returns all the calculated gradients in a container. This approach offers numerous benefits, such as the ability to easily send gradients to other threads. You can also retrieve the gradient for a specific parameter using the `grad` method on a tensor. Since this method takes the gradients as input, it's hard to forget to call `backward` beforehand. Note that sometimes, using `grad_remove` can improve performance by allowing inplace operations. In PyTorch, when you don't need gradients for inference or validation, you typically need to scope your code using a block. ```python # Inference mode torch.inference(): # your code ... # Or no grad torch.no_grad(): # your code ... ``` With Burn, you don't need to wrap the backend with the `Autodiff` for inference, and you can call `inner()` to obtain the inner tensor, which is useful for validation. ```rust, ignore /// Use `B: AutodiffBackend` fn example_validation(tensor: Tensor) { let inner_tensor: Tensor = tensor.inner(); let _ = inner_tensor + 5; } /// Use `B: Backend` fn example_inference(tensor: Tensor) { let _ = tensor + 5; ... } ``` **Gradients with Optimizers** We've seen how gradients can be used with tensors, but the process is a bit different when working with optimizers from `burn-core`. To work with the `Module` trait, a translation step is required to link tensor parameters with their gradients. This step is necessary to easily support gradient accumulation and training on multiple devices, where each module can be forked and run on different devices in parallel. We'll explore deeper into this topic in the [Module](./module.md) section. ================================================ FILE: burn-book/src/building-blocks/backend.md ================================================ # Backend Nearly everything in Burn is based on the `Backend` trait, which enables you to run tensor operations using different implementations without having to modify your code. While a backend may not necessarily have autodiff capabilities, the `AutodiffBackend` trait specifies when autodiff is needed. This trait not only abstracts operations but also tensor, device, and element types, providing each backend the flexibility they need. It's worth noting that the trait assumes eager mode since burn fully supports dynamic graphs. However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code. Users are not expected to directly use the backend trait methods, as it is primarily designed with backend developers in mind rather than Burn users. Therefore, most Burn userland APIs are generic across backends. This approach helps users discover the API more organically with proper autocomplete and documentation. ================================================ FILE: burn-book/src/building-blocks/config.md ================================================ # Config When writing scientific code, you normally have a lot of values that are set, and Deep Learning is no exception. Python has the possibility to define default parameters for functions, which helps improve the developer experience. However, this has the downside of potentially breaking your code when upgrading to a new version, as the default values might change without your knowledge, making debugging very challenging. With that in mind, we came up with the Config system. It's a simple Rust derive that you can apply to your types, allowing you to define default values with ease. Additionally, all configs can be serialized, reducing potential bugs when upgrading versions and improving reproducibility. ```rust , ignore use burn::config::Config; #[derive(Config)] pub struct MyModuleConfig { d_model: usize, d_ff: usize, #[config(default = 0.1)] dropout: f64, } ``` The derive also adds useful `with_` methods for every attribute of your config, similar to a builder pattern, along with a `save` method. ```rust, ignore fn main() { let config = MyModuleConfig::new(512, 2048); println!("{}", config.d_model); // 512 println!("{}", config.d_ff); // 2048 println!("{}", config.dropout); // 0.1 let config = config.with_dropout(0.2); println!("{}", config.dropout); // 0.2 config.save("config.json").unwrap(); } ``` ## Good practices By using the config type it is easy to create new module instances. The initialization method should be implemented on the config type with the device as argument. ```rust, ignore impl MyModuleConfig { /// Create a module on the given device. pub fn init(&self, device: &B::Device) -> MyModule { MyModule { linear: LinearConfig::new(self.d_model, self.d_ff).init(device), dropout: DropoutConfig::new(self.dropout).init(), } } } ``` Then we could add this line to the above `main`: ```rust, ignore use burn::backend::Wgpu; let device = Default::default(); let my_module = config.init::(&device); ``` ================================================ FILE: burn-book/src/building-blocks/dataset.md ================================================ # Dataset At its core, a dataset is a collection of data typically related to a specific analysis or processing task. The data modality can vary depending on the task, but most datasets primarily consist of images, texts, audio or videos. This data source represents an integral part of machine learning to successfully train a model. Thus, it is essential to provide a convenient and performant API to handle your data. Since this process varies wildly from one problem to another, it is defined as a trait that should be implemented on your type. The dataset trait is quite similar to the dataset abstract class in PyTorch: ```rust, ignore pub trait Dataset: Send + Sync { fn get(&self, index: usize) -> Option; fn len(&self) -> usize; } ``` The dataset trait assumes a fixed-length set of items that can be randomly accessed in constant time. This is a major difference from datasets that use Apache Arrow underneath to improve streaming performance. Datasets in Burn don't assume _how_ they are going to be accessed; it's just a collection of items. However, you can compose multiple dataset transformations to lazily obtain what you want with zero pre-processing, so that your training can start instantly! ## Transformation Transformations in Burn are all lazy and modify one or multiple input datasets. The goal of these transformations is to provide you with the necessary tools so that you can model complex data distributions. | Transformation | Description | | ------------------ | ------------------------------------------------------------------------------------------------------------------------ | | `SamplerDataset` | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size. | | `SelectionDataset` | Selects a subset of items by index from a dataset. Can be randomly shuffled; can be re-shuffled. | | `ShuffledDataset` | Shuffles a wrapped dataset; This is a thin wrapper around `SelectionDataset`. | | `PartialDataset` | Returns a view of the input dataset with a specified range. | | `MapperDataset` | Computes a transformation lazily on the input dataset. | | `ComposedDataset` | Composes multiple datasets together to create a larger one without copying any data. | | `WindowsDataset` | Dataset designed to work with overlapping windows of data extracted from an input dataset. | Let us look at the basic usages of each dataset transform and how they can be composed together. These transforms are lazy by default except when specified, reducing the need for unnecessary intermediate allocations and improving performance. The full documentation of each transform can be found at the [API reference](https://burn.dev/docs/burn/data/dataset/transform/index.html). - **SamplerDataset**: This transform can be used to sample items from a dataset with (default) or without replacement. Transform is initialized with a sampling size which can be bigger or smaller than the input dataset size. This is particularly useful in cases where we want to checkpoint larger datasets more often during training and smaller datasets less often as the size of an epoch is now controlled by the sampling size. Sample usage: ```rust, ignore type DbPedia = SqliteDataset; let dataset: DbPedia = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset("train"). .unwrap(); let dataset = SamplerDataset::new(dataset, 10000); ``` - **SelectionDataset**: This transform can be used to select a subset of items from a dataset by index. It can be initialized with a list of indices to select from the input dataset. This is particularly useful when you want to create a smaller dataset from a larger one, for example, to create a validation set from a training set. The `SelectionDataset` can also be initialized with a random seed to shuffle the indices before selection. This is useful when you want to randomly select a subset of items from the dataset. Base dataset items may be included more than once in the selection. ```rust, ignore let explicit = SelectionDataset::from_indices_checked(dataset.clone(), vec![0, 1, 2, 0]); let shuffled = SelectionDataset::new_shuffled(dataset.clone(), &mut rng); let shuffled = SelectionDataset::new_shuffled(dataset.clone(), 42); let mut mutable = SelectionDataset::new_select_all(dataset.clone(), vec![0, 1, 2, 0]); mutable.shuffle(42); mutable.shuffle(&mut rng); ``` - **ShuffledDataset**: This transform can be used to shuffle the items of a dataset. Particularly useful before splitting the raw dataset into train/test splits. Can be initialized with a seed to ensure reproducibility. The `ShuffledDataset` is a thin wrapper around the `SelectionDataset`. ```rust, ignore let dataset = ShuffledDataset::new(dataset, &mut rng); let dataset = ShuffledDataset::new(dataset, 42); ``` - **PartialDataset**: This transform is useful to return a view of the dataset with specified start and end indices. Used to create train/val/test splits. In the example below, we show how to chain ShuffledDataset and PartialDataset to create splits. ```rust, ignore // define chained dataset type here for brevity type PartialData = PartialDataset>; let len = dataset.len(); let split = "train"; // or "val"/"test" let data_split = match split { "train" => PartialData::new(dataset, 0, len * 8 / 10), // Get first 80% dataset "test" => PartialData::new(dataset, len * 8 / 10, len), // Take remaining 20% _ => panic!("Invalid split type"), // Handle unexpected split types }; ``` - **MapperDataset**: This transform is useful to apply a transformation on each of the items of a dataset. Particularly useful for normalization of image data when channel means are known. - **ComposedDataset**: This transform is useful to compose multiple datasets downloaded from multiple sources (say different HuggingfaceDatasetLoader sources) into a single bigger dataset which can be sampled from one source. - **WindowsDataset**: This transform is useful to create overlapping windows of a dataset. Particularly useful for sequential Time series Data, for example when working with an LSTM. ## Storage There are multiple dataset storage options available for you to choose from. The choice of the dataset to use should be based on the dataset's size as well as its intended purpose. | Storage | Description | | ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------- | | `InMemDataset` | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets. | | `SqliteDataset` | Dataset that uses [SQLite](https://www.sqlite.org/) to index items that can be saved in a simple SQL database file. Well-suited for larger datasets. | | `DataframeDataset` | Dataset that uses [Polars](https://www.pola.rs/) dataframe to store and manage data. Well-suited for efficient data manipulation and analysis. | ## Sources For now, there are only a couple of dataset sources available with Burn, but more to come! ### Hugging Face You can easily import any Hugging Face dataset with Burn. We use SQLite as the storage to avoid downloading the model each time or starting a Python process. You need to know the format of each item in the dataset beforehand. Here's an example with the [dbpedia dataset](https://huggingface.co/datasets/dbpedia_14). ```rust, ignore #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { pub title: String, pub content: String, pub label: usize, } fn main() { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset("train") // The training split. .unwrap(); } ``` We see that items must derive `serde::Serialize`, `serde::Deserialize`, `Clone`, and `Debug`, but those are the only requirements.
The `HuggingfaceDatasetLoader` relies on the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download datasets. This is a Python library, so you must have an existing Python installation to use this loader.
### Images `ImageFolderDataset` is a generic vision dataset used to load images from disk. It is currently available for multi-class and multi-label classification tasks as well as semantic segmentation and object detection tasks. ```rust, ignore // Create an image classification dataset from the root folder, // where images for each class are stored in their respective folder. // // For example: // root/dog/dog1.png // root/dog/dog2.png // ... // root/cat/cat1.png let dataset = ImageFolderDataset::new_classification("path/to/dataset/root").unwrap(); ``` ```rust, ignore // Create a multi-label image classification dataset from a list of items, // where each item is a tuple `(image path, labels)`, and a list of classes // in the dataset. // // For example: let items = vec![ ("root/dog/dog1.png", vec!["animal".to_string(), "dog".to_string()]), ("root/cat/cat1.png", vec!["animal".to_string(), "cat".to_string()]), ]; let dataset = ImageFolderDataset::new_multilabel_classification_with_items( items, &["animal", "cat", "dog"], ) .unwrap(); ``` ```rust, ignore // Create a segmentation mask dataset from a list of items, where each // item is a tuple `(image path, mask path)` and a list of classes // corresponding to the integer values in the mask. let items = vec![ ( "path/to/images/image0.png", "path/to/annotations/mask0.png", ), ( "path/to/images/image1.png", "path/to/annotations/mask1.png", ), ( "path/to/images/image2.png", "path/to/annotations/mask2.png", ), ]; let dataset = ImageFolderDataset::new_segmentation_with_items( items, &[ "cat", // 0 "dog", // 1 "background", // 2 ], ) .unwrap(); ``` ```rust, ignore // Create an object detection dataset from a COCO dataset. Currently only // the import of object detection data (bounding boxes) is supported. // // COCO offers separate annotation and image archives for training and // validation, paths to the unpacked files need to be passed as parameters: let dataset = ImageFolderDataset::new_coco_detection( "/path/to/coco/instances_train2017.json", "/path/to/coco/images/train2017" ) .unwrap(); ``` ### Comma-Separated Values (CSV) Loading records from a simple CSV file in-memory is simple with the `InMemDataset`: ```rust, ignore // Build dataset from csv with tab ('\t') delimiter. // The reader can be configured for your particular file. let mut rdr = csv::ReaderBuilder::new(); let rdr = rdr.delimiter(b'\t'); let dataset = InMemDataset::from_csv("path/to/csv", rdr).unwrap(); ``` Note that this requires the `csv` crate. **What about streaming datasets?** There is no streaming dataset API with Burn, and this is by design! The learner struct will iterate multiple times over the dataset and only checkpoint when done. You can consider the length of the dataset as the number of iterations before performing checkpointing and running the validation. There is nothing stopping you from returning different items even when called with the same `index` multiple times. ## How Is The Dataset Used? During training, the dataset is used to access the data samples and, for most use cases in supervised learning, their corresponding ground-truth labels. Remember that the `Dataset` trait implementation is responsible to retrieve the data from its source, usually some sort of data storage. At this point, the dataset could be naively iterated over to provide the model a single sample to process at a time, but this is not very efficient. Instead, we collect multiple samples that the model can process as a _batch_ to fully leverage modern hardware (e.g., GPUs - which have impressive parallel processing capabilities). Since each data sample in the dataset can be collected independently, the data loading is typically done in parallel to further speed things up. In this case, we parallelize the data loading using a multi-threaded `BatchDataLoader` to obtain a sequence of items from the `Dataset` implementation. Finally, the sequence of items is combined into a batched tensor that can be used as input to a model with the `Batcher` trait implementation. Other tensor operations can be performed during this step to prepare the batch data, as is done [in the basic workflow guide](../basic-workflow/data.md). The process is illustrated in the figure below for the MNIST dataset. Burn Data Loading Pipeline Although we have conveniently implemented the [`MnistDataset`](https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/vision/mnist.rs) used in the guide, we'll go over its implementation to demonstrate how the `Dataset` and `Batcher` traits are used. The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a \\(28 \times 28\\) pixels black-and-white image (stored as raw bytes) with its corresponding label (a digit between \\(0\\) and \\(9\\)). This is defined by the `MnistItemRaw` struct. ```rust, ignore # #[derive(Deserialize, Debug, Clone)] struct MnistItemRaw { pub image_bytes: Vec, pub label: u8, } ``` With single-channel images of such low resolution, the entire training and test sets can be loaded in memory at once. Therefore, we leverage the already existing `InMemDataset` to retrieve the raw images and labels data. At this point, the image data is still just a bunch of bytes, but we want to retrieve the _structured_ image data in its intended form. For that, we can define a `MapperDataset` that transforms the raw image bytes to a 2D array image (which we convert to float while we're at it). ```rust, ignore const WIDTH: usize = 28; const HEIGHT: usize = 28; # /// MNIST item. # #[derive(Deserialize, Serialize, Debug, Clone)] pub struct MnistItem { /// Image as a 2D array of floats. pub image: [[f32; WIDTH]; HEIGHT], /// Label of the image. pub label: u8, } struct BytesToImage; impl Mapper for BytesToImage { /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). fn map(&self, item: &MnistItemRaw) -> MnistItem { // Ensure the image dimensions are correct. debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT); // Convert the image to a 2D array of floats. let mut image_array = [[0f32; WIDTH]; HEIGHT]; for (i, pixel) in item.image_bytes.iter().enumerate() { let x = i % WIDTH; let y = i / HEIGHT; image_array[y][x] = *pixel as f32; } MnistItem { image: image_array, label: item.label, } } } type MappedDataset = MapperDataset, BytesToImage, MnistItemRaw>; # /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000 # /// images per class. There are 60,000 training images and 10,000 test images. # /// # /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist). pub struct MnistDataset { dataset: MappedDataset, } ``` To construct the `MnistDataset`, the data source must be parsed into the expected `MappedDataset` type. Since both the train and test sets use the same file format, we can separate the functionality to load the `train()` and `test()` dataset. ```rust, ignore impl MnistDataset { /// Creates a new train dataset. pub fn train() -> Self { Self::new("train") } /// Creates a new test dataset. pub fn test() -> Self { Self::new("test") } fn new(split: &str) -> Self { // Download dataset let root = MnistDataset::download(split); // Parse data as vector of images bytes and vector of labels let images: Vec> = MnistDataset::read_images(&root, split); let labels: Vec = MnistDataset::read_labels(&root, split); // Collect as vector of MnistItemRaw let items: Vec<_> = images .into_iter() .zip(labels) .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label }) .collect(); // Create the MapperDataset for InMemDataset to transform // items (MnistItemRaw -> MnistItem) let dataset = InMemDataset::new(items); let dataset = MapperDataset::new(dataset, BytesToImage); Self { dataset } } # /// Download the MNIST dataset files from the web. # /// Panics if the download cannot be completed or the content of the file cannot be written to disk. # fn download(split: &str) -> PathBuf { # // Dataset files are stored in the burn-dataset cache directory # let cache_dir = dirs::cache_dir() # .expect("Could not get cache directory") # .join("burn-dataset"); # let split_dir = cache_dir.join("mnist").join(split); # # if !split_dir.exists() { # create_dir_all(&split_dir).expect("Failed to create base directory"); # } # # // Download split files # match split { # "train" => { # MnistDataset::download_file(TRAIN_IMAGES, &split_dir); # MnistDataset::download_file(TRAIN_LABELS, &split_dir); # } # "test" => { # MnistDataset::download_file(TEST_IMAGES, &split_dir); # MnistDataset::download_file(TEST_LABELS, &split_dir); # } # _ => panic!("Invalid split specified {}", split), # }; # # split_dir # } # # /// Download a file from the MNIST dataset URL to the destination directory. # /// File download progress is reported with the help of a [progress bar](indicatif). # fn download_file>(name: &str, dest_dir: &P) -> PathBuf { # // Output file name # let file_name = dest_dir.as_ref().join(name); # # if !file_name.exists() { # // Download gzip file # let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name); # # // Create file to write the downloaded content to # let mut output_file = File::create(&file_name).unwrap(); # # // Decode gzip file content and write to disk # let mut gz_buffer = GzDecoder::new(&bytes[..]); # std::io::copy(&mut gz_buffer, &mut output_file).unwrap(); # } # # file_name # } # # /// Read images at the provided path for the specified split. # /// Each image is a vector of bytes. # fn read_images>(root: &P, split: &str) -> Vec> { # let file_name = if split == "train" { # TRAIN_IMAGES # } else { # TEST_IMAGES # }; # let file_name = root.as_ref().join(file_name); # # // Read number of images from 16-byte header metadata # let mut f = File::open(file_name).unwrap(); # let mut buf = [0u8; 4]; # let _ = f.seek(SeekFrom::Start(4)).unwrap(); # f.read_exact(&mut buf) # .expect("Should be able to read image file header"); # let size = u32::from_be_bytes(buf); # # let mut buf_images: Vec = vec![0u8; WIDTH * HEIGHT * (size as usize)]; # let _ = f.seek(SeekFrom::Start(16)).unwrap(); # f.read_exact(&mut buf_images) # .expect("Should be able to read image file header"); # # buf_images # .chunks(WIDTH * HEIGHT) # .map(|chunk| chunk.to_vec()) # .collect() # } # # /// Read labels at the provided path for the specified split. # fn read_labels>(root: &P, split: &str) -> Vec { # let file_name = if split == "train" { # TRAIN_LABELS # } else { # TEST_LABELS # }; # let file_name = root.as_ref().join(file_name); # # // Read number of labels from 8-byte header metadata # let mut f = File::open(file_name).unwrap(); # let mut buf = [0u8; 4]; # let _ = f.seek(SeekFrom::Start(4)).unwrap(); # f.read_exact(&mut buf) # .expect("Should be able to read label file header"); # let size = u32::from_be_bytes(buf); # # let mut buf_labels: Vec = vec![0u8; size as usize]; # let _ = f.seek(SeekFrom::Start(8)).unwrap(); # f.read_exact(&mut buf_labels) # .expect("Should be able to read labels from file"); # # buf_labels # } } ``` Since the `MnistDataset` simply wraps a `MapperDataset` instance with `InMemDataset`, we can easily implement the `Dataset` trait. ```rust, ignore impl Dataset for MnistDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } ``` The only thing missing now is the `Batcher`, which we already went over [in the basic workflow guide](../basic-workflow/data.md). The `Batcher` takes a list of `MnistItem` retrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their targets. ================================================ FILE: burn-book/src/building-blocks/learner.md ================================================ # Learner The [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates multiple utilities for training deep learning models. The goal of the crate is to provide users with a well-crafted and flexible training loop, so that projects do not have to write such components from the ground up. Most of the interactions with `burn-train` will be with the `SupervisedTraining` struct, briefly presented in the previous [training section](../basic-workflow/training.md). This struct enables you to configure the training loop, offering support for registering metrics, enabling logging, checkpointing states, using multiple devices, and so on. There are still some assumptions in the current provided APIs, which may make them inappropriate for your learning requirements. Indeed, they assume your model will learn from a training dataset and be validated against another dataset. This is the most common paradigm, allowing users to do both supervised and unsupervised learning as well as fine-tuning. However, for more complex requirements, creating a [custom training loop](../custom-training-loop.md) might be what you need. ## Usage The `SupervisedLearning` struct must be created with the training and validation dataloaders. It provides you with numerous options when it comes to configurations. | Configuration | Description | | ---------------------- | ------------------------------------------------------------------------------ | | Training Metric | Register a training metric | | Validation Metric | Register a validation metric | | Training Metric Plot | Register a training metric with plotting (requires the metric to be numeric) | | Validation Metric Plot | Register a validation metric with plotting (requires the metric to be numeric) | | Metric Logger | Configure the metric loggers (default is saving them to files) | | Renderer | Configure how to render metrics (default is CLI) | | Grad Accumulation | Configure the number of steps before applying gradients | | File Checkpointer | Configure how the model, optimizer and scheduler states are saved | | Num Epochs | Set the number of epochs | | Devices | Set the devices to be used | | Checkpoint | Restart training from a checkpoint | | Application logging | Configure the application logging installer (default is writing to `experiment.log`) | | Training Strategy | Use a custom training strategy, allowing you to use your own training loop with all the capabilities of the `SupervisedTraining` struct | When the training is configured to your liking, you can then move forward to running the training. The `launch` method requires a learner object providing: the model, the optimizer and the learning rate scheduler. Note that the latter can be a simple float if you want it to be constant during training. The `launch` method will start the training and return the trained model once finished. Again, please refer to the [training section](../basic-workflow/training.md) for a relevant code snippet. ## Artifacts When creating a `SupervisedTraining` instance, all the collected data will be saved under the directory provided as the argument to the `new` method. Here is an example of the data layout for a model recorded using the compressed message pack format, with the accuracy and loss metrics registered: ``` ├── experiment.log ├── checkpoint │   ├── model-1.mpk.gz │   ├── optim-1.mpk.gz │   └── scheduler-1.mpk.gz │   ├── model-2.mpk.gz │   ├── optim-2.mpk.gz │   └── scheduler-2.mpk.gz ├── train │   ├── epoch-1 │   │   ├── Accuracy.log │   │   └── Loss.log │   └── epoch-2 │   ├── Accuracy.log │   └── Loss.log └── valid ├── epoch-1 │   ├── Accuracy.log │   └── Loss.log └── epoch-2 ├── Accuracy.log └── Loss.log ``` You can choose to save or synchronize that local directory with a remote file system, if desired. The file checkpointer is capable of automatically deleting old checkpoints according to a specified configuration. ================================================ FILE: burn-book/src/building-blocks/metric.md ================================================ # Metric When working with the learner, you have the option to record metrics that will be monitored throughout the training process. We currently offer a restricted range of metrics. | Metric | Description | | ------------------- | ------------------------------------------------------------------------------------------- | | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | | Precision | Calculate precision in percentage | | Recall | Calculate recall in percentage | | FBetaScore | Calculate Fβ score in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CharErrorRate (CER) | Calculate Character Error Rate in percentage | | WordErrorRate (WER) | Calculate Word Error Rate in percentage | | HammingScore | Calculate hamming score (also known as multi-label or label-based accuracy) in percentage | | Perplexity | Calculate perplexity which is a measure of how well a probability model predicts samples | | IterationSpeed | Tracks the training iteration speed, measuring how many iterations are completed per second | | CPU Temperature | Fetch the temperature of CPUs | | CPU Usage | Fetch the CPU utilization | | CPU Memory Usage | Fetch the CPU RAM usage | | Learning Rate | Fetch the current learning rate for each optimizer step | | CUDA | Fetch general CUDA metrics such as utilization | | Vision Metric | Description | | ------------- | ---------------------------------------------------------------------------------------------------- | | Dice | Computes the Dice-Sorenson coefficient (DSC) for evaluating overlap between binary masks | | DISTS | Computes the Deep Image Structure and Texture Similarity (DISTS) metric for image quality assessment | | LPIPS | Computes the Learned Perceptual Image Patch Similarity (LPIPS) for image quality assessment | | MS-SSIM | Computes the Multi-scale Structural Similarity index measure (MS-SSIM) for image quality assessment | | PSNR | Computes the Peak Signal-to-Noise Ratio (PSNR) for image quality assessment | | SSIM | Computes the Structural Similarity index measure (SSIM) for image quality assessment | ## Using Metrics with the Learner In order to use a metric, the output of your training step must implement the `Adaptor` trait from `burn-train::metric` for each metric's corresponding input type. The `Adaptor` trait simply converts your output struct into the input type the metric expects. Burn provides four built-in output structs that cover common tasks. Each one already implements `Adaptor` for a set of metrics, so in many cases you can use them directly without writing any adaptor code yourself. - `ClassificationOutput`: - Use case: Single-label classification - Fields: `loss: Tensor`, `output: Tensor`, `targets: Tensor` - Adapted metrics: Accuracy, TopKAccuracy, Perplexity, Precision\*, Recall\*, FBetaScore\*, AUROC\*, Loss - `MultiLabelClassificationOutput`: - Use case: Multi-label classification - Fields: `loss: Tensor`, `output: Tensor`, `targets: Tensor` - Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, Loss - `RegressionOutput`: - Use case: Regression tasks - Fields: `loss: Tensor`, `output: Tensor`, `targets: Tensor` - Adapted metrics: Loss - `SequenceOutput`: - Use case: Sequence prediction - Fields: `loss: Tensor`, `logits: Tensor`, `predictions: Option>`, `targets: Tensor` - Adapted metrics: Accuracy, TopKAccuracy, Perplexity, CER, WER, Loss \* Precision, Recall, and FBetaScore all use `ConfusionStatsInput` as its input type so these three metrics are automatically (implicitly) adapted since `ConfusionStatsInput` is adapted. If your metric isn't already adapted for the appropriate output struct, you can implement `Adaptor` yourself. For example, here is how `ClassificationOutput` adapts to `AccuracyInput`: ```rust,ignore impl Adaptor> for ClassificationOutput { fn adapt(&self) -> AccuracyInput { AccuracyInput::new(self.output.clone(), self.targets.clone()) } } ``` If your task type is not covered by the built-in output structs, you can create an output struct for your data and then adapt your metric for the output struct: ```rust,ignore #[derive(new)] pub struct ClassificationOutput { /// The loss. pub loss: Tensor, /// The output. pub output: Tensor, /// The targets. pub targets: Tensor, } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> AccuracyInput { AccuracyInput::new(self.output.clone(), self.targets.clone()) } } ``` You can also open an issue on the [GitHub repository](https://github.com/tracel-ai/burn) when your task type is not covered by the built-in output structs. However, since creating an output struct for your data is simple, it is recommended to try creating your own output struct first. # Custom Metric Generating your own custom metrics is done by implementing the `Metric` trait. ```rust , ignore /// Metric trait. /// /// # Notes /// /// Implementations should define their own input type only used by the metric. /// This is important since some conflict may happen when the model output is adapted for each /// metric's input type. pub trait Metric: Send + Sync + Clone { /// The input type of the metric. type Input; /// The parameterized name of the metric. /// /// This should be unique, so avoid using short generic names, prefer using the long name. /// /// For a metric that can exist at different parameters (e.g., top-k accuracy for different /// values of k), the name should be unique for each instance. fn name(&self) -> MetricName; /// Update the metric state and returns the current metric entry. fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry; /// Clear the metric state. fn clear(&mut self); } ``` As an example, let's see how the loss metric is implemented. ```rust, ignore /// The loss metric. #[derive(Clone)] pub struct LossMetric { name: Arc, state: NumericMetricState, _b: B, } /// The [loss metric](LossMetric) input type. #[derive(new)] pub struct LossInput { tensor: Tensor, } impl Default for LossMetric { fn default() -> Self { Self::new() } } impl LossMetric { /// Create the metric. pub fn new() -> Self { Self { name: Arc::new("Loss".to_string()), state: NumericMetricState::default(), _b: Default::default(), } } } impl Metric for LossMetric { type Input = LossInput; fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let [batch_size] = loss.tensor.dims(); let loss = loss .tensor .clone() .mean() .into_data() .iter::() .next() .unwrap(); self.state.update( loss, batch_size, FormatOptions::new(self.name()).precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: false, } .into() } } ``` When the metric you are implementing is numeric in nature, you may want to also implement the `Numeric` trait. This will allow your metric to be plotted. ```rust, ignore impl Numeric for LossMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ``` ================================================ FILE: burn-book/src/building-blocks/module.md ================================================ # Module The `Module` derive allows you to create your own neural network modules, similar to PyTorch. The derive function only generates the necessary methods to essentially act as a parameter container for your type, it makes no assumptions about how the forward pass is declared. ```rust, ignore use burn::module::Module; use burn::tensor::backend::Backend; #[derive(Module, Debug)] pub struct PositionWiseFeedForward { linear_inner: Linear, linear_outer: Linear, dropout: Dropout, gelu: Gelu, } impl PositionWiseFeedForward { /// Normal method added to a struct. pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input); let x = self.gelu.forward(x); let x = self.dropout.forward(x); self.linear_outer.forward(x) } } ``` Note that all fields declared in the struct must also implement the `Module` trait. ## Tensor If you want to create your own module that contains tensors, and not just other modules defined with the `Module` derive, you need to be careful to achieve the behavior you want. - `Param>`: If you want the tensor to be included as a parameter of your modules, you need to wrap the tensor in a `Param` struct. This will create an ID that will be used to identify this parameter. This is essential when performing module optimization and when saving states such as optimizer and module checkpoints. Note that a module's record only contains parameters. - `Param>.set_require_grad(false)`: If you want the tensor to be included as a parameter of your modules, and therefore saved with the module's weights, but you don't want it to be updated by the optimizer. - `Tensor`: If you want the tensor to act as a constant that can be recreated when instantiating a module. This can be useful when generating sinusoidal embeddings, for example. ## Methods These methods are available for all modules. | Burn API | PyTorch Equivalent | | --------------------------------------- | ---------------------------------------- | | `module.devices()` | N/A | | `module.fork(device)` | Similar to `module.to(device).detach()` | | `module.to_device(device)` | `module.to(device)` | | `module.no_grad()` | `module.require_grad_(False)` | | `module.num_params()` | N/A | | `module.visit(visitor)` | N/A | | `module.map(mapper)` | N/A | | `module.into_record()` | Similar to `state_dict` | | `module.load_record(record)` | Similar to `load_state_dict(state_dict)` | | `module.save_file(file_path, recorder)` | N/A | | `module.load_file(file_path, recorder)` | N/A | Similar to the backend trait, there is also the `AutodiffModule` trait to signify a module with autodiff support. | Burn API | PyTorch Equivalent | | ---------------- | ------------------ | | `module.valid()` | `module.eval()` | ## Visitor & Mapper As mentioned earlier, modules primarily function as parameter containers. Therefore, we naturally offer several ways to perform functions on each parameter. This is distinct from PyTorch, where extending module functionalities is not as straightforward. The `map` and `visitor` methods are quite similar but serve different purposes. Mapping is used for potentially mutable operations where each parameter of a module can be updated to a new value. In Burn, optimizers are essentially just sophisticated module mappers. Visitors, on the other hand, are used when you don't intend to modify the module but need to retrieve specific information from it, such as the number of parameters or a list of devices in use. You can implement your own mapper or visitor by implementing these simple traits: ```rust, ignore /// Module visitor trait. pub trait ModuleVisitor { /// Visit a float tensor in the module. fn visit_float(&mut self, id: ParamId, tensor: &Tensor); /// Visit an int tensor in the module. fn visit_int(&mut self, id: ParamId, tensor: &Tensor); /// Visit a bool tensor in the module. fn visit_bool(&mut self, id: ParamId, tensor: &Tensor); } /// Module mapper trait. pub trait ModuleMapper { /// Map a float tensor in the module. fn map_float(&mut self, id: ParamId, tensor: Tensor) -> Tensor; /// Map an int tensor in the module. fn map_int(&mut self, id: ParamId, tensor: Tensor) -> Tensor; /// Map a bool tensor in the module. fn map_bool(&mut self, id: ParamId, tensor: Tensor) -> Tensor; } ``` Note that the trait doesn't require all methods to be implemented as they are already defined to perform no operation. If you're only interested in float tensors (like the majority of use cases), then you can simply implement `map_float` or `visit_float`. For example, the `ModuleMapper` trait could be implemented to clamp all parameters into the range `[min, max]`. ```rust, ignore /// Clamp parameters into the range `[min, max]`. pub struct Clamp { /// Lower-bound of the range. pub min: f32, /// Upper-bound of the range. pub max: f32, } // Clamp all floating-point parameter tensors between `[min, max]`. impl ModuleMapper for Clamp { fn map_float( &mut self, _id: burn::module::ParamId, tensor: burn::prelude::Tensor, ) -> burn::prelude::Tensor { tensor.clamp(self.min, self.max) } } // Clamp module mapper into the range `[-0.5, 0.5]` let mut clamp = Clamp { min: -0.5, max: 0.5, }; let model = model.map(&mut clamp); ``` If you want to use this during training to constrain your model parameters, make sure that the parameter tensors are still tracked for autodiff. This can be done with a simple adjustment to the implementation. ```rust, ignore impl ModuleMapper for Clamp { fn map_float( &mut self, _id: burn::module::ParamId, tensor: burn::prelude::Tensor, ) -> burn::prelude::Tensor { let is_require_grad = tensor.is_require_grad(); let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); if is_require_grad { tensor = tensor.require_grad(); } tensor } } ``` ## Module Display Burn provides a simple way to display the structure of a module and its configuration at a glance. You can print the module to see its structure, which is useful for debugging and tracking changes across different versions of a module. (See the print output of the [Basic Workflow Model](../basic-workflow/model.md) example.) To customize the display of a module, you can implement the `ModuleDisplay` trait for your module. This will change the default display settings for the module and its children. Note that `ModuleDisplay` is automatically implemented for all modules, but you can override it to customize the display by annotating the module with `#[module(custom_display)]`. ```rust #[derive(Module, Debug)] #[module(custom_display)] pub struct PositionWiseFeedForward { linear_inner: Linear, linear_outer: Linear, dropout: Dropout, gelu: Gelu, } impl ModuleDisplay for PositionWiseFeedForward { /// Custom settings for the display of the module. /// If `None` is returned, the default settings will be used. fn custom_settings(&self) -> Option { DisplaySettings::new() // Will show all attributes (default is false) .with_show_all_attributes(false) // Will show each attribute on a new line (default is true) .with_new_line_after_attribute(true) // Will show the number of parameters (default is true) .with_show_num_parameters(true) // Will indent by 2 spaces (default is 2) .with_indentation_size(2) // Will show the parameter ID (default is false) .with_show_param_id(false) // Convenience method to wrap settings in Some() .optional() } /// Custom content to be displayed. /// If `None` is returned, the default content will be used /// (all attributes of the module) fn custom_content(&self, content: Content) -> Option { content .add("linear_inner", &self.linear_inner) .add("linear_outer", &self.linear_outer) .add("anything", "anything_else") .optional() } } ``` ## Built-in Modules Burn comes with built-in modules that you can use to build your own modules. ### General | Burn API | PyTorch Equivalent | | ----------------- | --------------------------------------------- | | `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | | `Celu` | `nn.CELU` | | `Dropout` | `nn.Dropout` | | `Elu` | `nn.ELU` | | `Embedding` | `nn.Embedding` | | `GaussianNoise` | _No direct equivalent_ | | `Gelu` | `nn.Gelu` | | `Glu` | `nn.Glu` | | `GroupNorm` | `nn.GroupNorm` | | `HardShrink` | `nn.Hardshrink` | | `HardSigmoid` | `nn.Hardsigmoid` | | `HardSwish` | `nn.Hardswish` | | `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. | | `LayerNorm` | `nn.LayerNorm` | | `LeakyRelu` | `nn.LeakyReLU` | | `Linear` | `nn.Linear` | | `Prelu` | `nn.PReLu` | | `Relu` | `nn.ReLU` | | `Selu` | `nn.SELU` | | `Sigmoid` | `nn.Sigmoid` | | `Softplus` | `nn.Softplus` | | `SoftShrink` | `nn.Softshrink` | | `Softsign` | `nn.Softsign` | | `Shrink` | _No direct equivalent_ | | `RmsNorm` | _No direct equivalent_ | | `SwiGlu` | _No direct equivalent_ | | `Tanh` | `nn.Tanh` | | `ThresholdedRelu` | _No direct equivalent_ | ### Convolutions | Burn API | PyTorch Equivalent | | ----------------- | ------------------------------ | | `Conv1d` | `nn.Conv1d` | | `Conv2d` | `nn.Conv2d` | | `Conv3d` | `nn.Conv3d` | | `ConvTranspose1d` | `nn.ConvTranspose1d` | | `ConvTranspose2d` | `nn.ConvTranspose2d` | | `ConvTranspose3d` | `nn.ConvTranspose3d` | | `DeformConv2d` | `torchvision.ops.DeformConv2d` | ### Pooling | Burn API | PyTorch Equivalent | | ------------------- | ---------------------- | | `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` | | `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` | | `AvgPool1d` | `nn.AvgPool1d` | | `AvgPool2d` | `nn.AvgPool2d` | | `MaxPool1d` | `nn.MaxPool1d` | | `MaxPool2d` | `nn.MaxPool2d` | ### Interpolation | Burn API | PyTorch Equivalent | | --------------- | ------------------ | | `Interpolate1d` | `nn.Upsample` | | `Interpolate2d` | `nn.Upsample` | Interpolation modules resize tensors using one of the available `InterpolateMode` options: | Mode | Description | | --------- | -------------------------------------------------------- | | `Nearest` | Nearest-neighbor interpolation | | `Linear` | Linear interpolation (bilinear for 2D) | | `Cubic` | Cubic interpolation (bicubic for 2D) | | `Lanczos` | Lanczos3 resampling (6-tap sinc-based filter, a=3) | Configuration is done via `Interpolate1dConfig` / `Interpolate2dConfig` with these options: | Option | Type | Default | Description | | --------------- |------------------------------------------| --------- | -------------------------------------------------------- | | `output_size` | `Option` / `Option<[usize; 2]>` | `None` | Target output size (takes precedence over scale_factor) | | `scale_factor` | `Option` / `Option<[f32; 2]>` | `None` | Scale factor for resizing | | `mode` | `InterpolateMode` | `Nearest` | Interpolation algorithm | | `align_corners` | `bool` | `true` | Align input/output corner pixels | ### RNNs | Burn API | PyTorch Equivalent | | ---------------- | ---------------------- | | `Gru`/`BiGru` | `nn.GRU` | | `Lstm`/`BiLstm` | `nn.LSTM` | | `GateController` | _No direct equivalent_ | ### Transformer | Burn API | PyTorch Equivalent | | -------------------- | ----------------------- | | `MultiHeadAttention` | `nn.MultiheadAttention` | | `TransformerDecoder` | `nn.TransformerDecoder` | | `TransformerEncoder` | `nn.TransformerEncoder` | | `PositionalEncoding` | _No direct equivalent_ | | `RotaryEncoding` | _No direct equivalent_ | ### Loss | Burn API | PyTorch Equivalent | | ------------------------ | ------------------------ | | `BinaryCrossEntropyLoss` | `nn.BCELoss` | | `CosineEmbeddingLoss` | `nn.CosineEmbeddingLoss` | | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `CTCLoss` | `nn.CTCLoss` | | `GramMatrixLoss` | _No direct equivalent_ | | `HuberLoss` | `nn.HuberLoss` | | `KLDivLoss` | `nn.KLDivLoss` | | `LpLoss` | _No direct equivalent_ | | `MseLoss` | `nn.MSELoss` | | `PoissonNllLoss` | `nn.PoissonNLLLoss` | | `RNNTLoss` | `torchaudio.functional.rnnt_loss` | | `SmoothL1Loss` | `nn.SmoothL1Loss` | ================================================ FILE: burn-book/src/building-blocks/record.md ================================================ # Record Records are how states are saved with Burn. Compared to most other frameworks, Burn has its own advanced saving mechanism that allows interoperability between backends with minimal possible runtime errors. There are multiple reasons why Burn decided to create its own saving formats. First, Rust has [serde](https://serde.rs/), which is an extremely well-developed serialization and deserialization library that also powers the `safetensors` format developed by Hugging Face. If used properly, all the validations are done when deserializing, which removes the need to write validation code. Since modules in Burn are created with configurations, they can't implement serialization and deserialization. That's why the record system was created: allowing you to save the state of modules independently of the backend in use extremely fast while still giving you all the flexibility possible to include any non-serializable field within your module. **Why not use safetensors?** [`safetensors`](https://github.com/huggingface/safetensors) uses serde with the JSON file format and only supports serializing and deserializing tensors. The record system in Burn gives you the possibility to serialize any type, which is very useful for optimizers that save their state, but also for any non-standard, cutting-edge modeling needs you may have. Additionally, the record system performs automatic precision conversion by using Rust types, making it more reliable with fewer manual manipulations. It is important to note that the `safetensors` format uses the word _safe_ to distinguish itself from Pickle, which is vulnerable to Python code injection. On our end, the simple fact that we use Rust already ensures that no code injection is possible. If your storage mechanism doesn't handle data corruption, you might prefer a recorder that performs checksum validation (i.e., any recorder with Gzip compression). ## Recorder Recorders are independent of the backend and serialize records with precision and a format. Note that the format can also be in-memory, allowing you to save the records directly into bytes. | Recorder | Format | Compression | | ---------------------- | ------------------------ | ----------- | | DefaultFileRecorder | File - Named MessagePack | None | | NamedMpkFileRecorder | File - Named MessagePack | None | | NamedMpkGzFileRecorder | File - Named MessagePack | Gzip | | BinFileRecorder | File - Binary | None | | BinGzFileRecorder | File - Binary | Gzip | | JsonGzFileRecorder | File - Json | Gzip | | PrettyJsonFileRecorder | File - Pretty Json | Gzip | | BinBytesRecorder | In Memory - Binary | None | Each recorder supports precision settings decoupled from the precision used for training or inference. These settings allow you to define the floating-point and integer types that will be used for serialization and deserialization. | Setting | Float Precision | Integer Precision | | ------------------------- | --------------- | ----------------- | | `DoublePrecisionSettings` | `f64` | `i64` | | `FullPrecisionSettings` | `f32` | `i32` | | `HalfPrecisionSettings` | `f16` | `i16` | Note that when loading a record into a module, the type conversion is automatically handled, so you can't encounter errors. The only crucial aspect is using the same recorder for both serialization and deserialization; otherwise, you will encounter loading errors. **Which recorder should you use?** - If you want fast serialization and deserialization, choose a recorder without compression. The one with the lowest file size without compression is the binary format; otherwise, the named MessagePack could be used. - If you want to save models for storage, you can use compression, but avoid using the binary format, as it may not be backward compatible. - If you want to debug your model's weights, you can use the pretty JSON format. - If you want to deploy with `no-std`, use the in-memory binary format and include the bytes with the compiled code. For examples on saving and loading records, take a look at [Saving and Loading Models](../saving-and-loading.md). ================================================ FILE: burn-book/src/building-blocks/tensor.md ================================================ # Tensor As previously explained in the [model section](../basic-workflow/model.md), the Tensor struct has 3 generic arguments: the backend B, the dimensionality D, and the data type. ```rust, ignore Tensor // Float tensor (default) Tensor // Explicit float tensor Tensor // Int tensor Tensor // Bool tensor ``` Note that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by backend implementations. Burn Tensors are defined by the number of dimensions D in its declaration as opposed to its shape. The actual shape of the tensor is inferred from its initialization. For example, a Tensor of size (5,) is initialized as below: ```rust, ignore let floats = [1.0, 2.0, 3.0, 4.0, 5.0]; // Get the default device let device = Default::default(); // correct: Tensor is 1-Dimensional with 5 elements let tensor_1 = Tensor::::from_floats(floats, &device); // incorrect: let tensor_1 = Tensor::::from_floats(floats, &device); // this will lead to an error and is for creating a 5-D tensor ``` ### Initialization Burn Tensors are primarily initialized using the `from_data()` method which takes the `TensorData` struct as input. The `TensorData` struct has two public fields: `shape` and `dtype`. The `value`, now stored as bytes, is private but can be accessed via any of the following methods: `as_slice`, `as_mut_slice`, `to_vec` and `iter`. To retrieve the data from a tensor, the method `.to_data()` should be employed when intending to reuse the tensor afterward. Alternatively, `.into_data()` is recommended for one-time use. Let's look at a couple of examples for initializing a tensor from different inputs. ```rust, ignore // Initialization from a given Backend (Wgpu) let tensor_1 = Tensor::::from_data([1.0, 2.0, 3.0], &device); // Initialization from a generic Backend let tensor_2 = Tensor::::from_data(TensorData::from([1.0, 2.0, 3.0]), &device); // Initialization using from_floats (Recommended for f32 ElementType) // Will be converted to TensorData internally. let tensor_3 = Tensor::::from_floats([1.0, 2.0, 3.0], &device); // Initialization of Int Tensor from array slices let arr: [i32; 6] = [1, 2, 3, 4, 5, 6]; let tensor_4 = Tensor::::from_data(TensorData::from(&arr[0..3]), &device); // Initialization from a custom type struct BodyMetrics { age: i8, height: i16, weight: f32 } let bmi = BodyMetrics{ age: 25, height: 180, weight: 80.0 }; let data = TensorData::from([bmi.age as f32, bmi.height as f32, bmi.weight]); let tensor_5 = Tensor::::from_data(data, &device); ``` ## Ownership and Cloning Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple times will necessitate cloning it. Let's look at an example to understand the ownership rules and cloning better. Suppose we want to do a simple min-max normalization of an input tensor. ```rust, ignore let input = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device); let min = input.min(); let max = input.max(); let input = (input - min).div(max - min); ``` With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules will give an error and prevent using the input tensor after the first `.min()` operation. The ownership of the input tensor is transferred to the variable `min` and the input tensor is no longer available for further operations. Burn Tensors like most complex primitives do not implement the `Copy` trait and therefore have to be cloned explicitly. Now let's rewrite a working example of doing min-max normalization with cloning. ```rust, ignore let input = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device); let min = input.clone().min(); let max = input.clone().max(); let input = (input.clone() - min.clone()).div(max - min); println!("{}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0] // Notice that max, min have been moved in last operation so // the below print will give an error. // If we want to use them for further operations, // they will need to be cloned in similar fashion. // println!("{:?}", min.to_data()); ``` We don't need to be worried about memory overhead because with cloning, the tensor's buffer isn't copied, and only a reference to it is increased. This makes it possible to determine exactly how many times a tensor is used, which is very convenient for reusing tensor buffers or even fusing operations into a single kernel ([burn-fusion](https://burn.dev/docs/burn_fusion/index.htmls)). For that reason, we don't provide explicit inplace operations. If a tensor is used only one time, inplace operations will always be used when available. ## Tensor Operations Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making them useful only for data preprocessing or inference-only model implementations. With Burn, you can focus more on _what_ the model should do, rather than on _how_ to do it. We take the responsibility of making your code run as fast as possible during training as well as inference. The same principles apply to broadcasting; all operations support broadcasting unless specified otherwise. Here, we provide a list of all supported operations along with their PyTorch equivalents. Note that for the sake of simplicity, we ignore type signatures. For more details, refer to the [full documentation](https://docs.rs/burn/latest/burn/tensor/struct.Tensor.html). ### Basic Operations Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. | Burn | PyTorch Equivalent | | ---------------------------------------------------- | ------------------------------------------------------------------------- | | `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | | `Tensor::empty(shape, options)` | `torch.empty(shape, device=device, dtype=dtype)` | | `Tensor::from_primitive(primitive)` | N/A | | `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | | `tensor.all()` | `tensor.all()` | | `tensor.all_dim(dim)` | `tensor.all(dim)` | | `tensor.any()` | `tensor.any()` | | `tensor.any_dim(dim)` | `tensor.any(dim)` | | `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | | `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | | `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | | `tensor.device()` | `tensor.device` | | `tensor.dtype()` | `tensor.dtype` | | `tensor.dims()` | `tensor.size()` | | `tensor.equal(other)` | `x == y` | | `tensor.equal_elem(other)` | `tensor.eq(other)` | | `tensor.expand(shape)` | `tensor.expand(shape)` | | `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | | `tensor.flip(axes)` | `tensor.flip(axes)` | | `tensor.full_like(fill_value)` | `torch.full_like(tensor, fill_value)` | | `tensor.gather(dim, indices)` | `torch.gather(tensor, dim, indices)` | | `tensor.into_data()` | N/A | | `tensor.into_primitive()` | N/A | | `tensor.into_scalar()` | `tensor.item()` | | `tensor.mask_fill(mask, value)` | `tensor.masked_fill(mask, value)` | | `tensor.mask_where(mask, value_tensor)` | `torch.where(mask, value_tensor, tensor)` | | `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | | `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | | `tensor.not_equal(other)` | `x != y` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | | `tensor.permute(axes)` | `tensor.permute(axes)` | | `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | | `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | | `tensor.reshape(shape)` | `tensor.view(shape)` | | `tensor.roll(shifts, dims)` | `tensor.roll(shifts, dims)` | | `tensor.roll_dim(shift, dim)` | `tensor.roll([shift], [dim])` | | `tensor.scatter(dim, indices, values, update)` | `tensor.scatter_add(dim, indices, values)` | | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` | | `tensor.select_assign(dim, indices, values, update)` | `tensor.index_add(dim, indices, values)` | | `tensor.shape()` | `tensor.shape` | | `tensor.slice(slices)` | `tensor[(*ranges,)]` | | `tensor.slice_assign(slices, values)` | `tensor[(*ranges,)] = values` | | `tensor.slice_fill(slices, value)` | `tensor[(*ranges,)] = value` | | `tensor.slice_dim(dim, slice)` | N/A | | `tensor.squeeze()` | `tensor.squeeze()` | | `tensor.squeeze_dim(dim)` | `tensor.squeeze(dim)` | | `tensor.squeeze_dims(dims)` | `tensor.squeeze(dims)` where `dims` is a tuple of ints | | `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | | `tensor.take(dim, indices)` | `numpy.take(tensor, indices, dim)` | | `tensor.to_data()` | N/A | | `tensor.to_device(device)` | `tensor.to(device)` | | `tensor.transpose()` | `tensor.T` | | `tensor.t()` | `tensor.T` | | `tensor.unsqueeze()` | N/A | | `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | | `tensor.unsqueeze_dims(dims)` | N/A | | `tensor.zeros_like()` | `torch.zeros_like(tensor)` | | `Tensor::full(shape, fill_value, options)` | `torch.full(shape, fill_value, device=device, dtype=dtype)` | | `Tensor::ones(shape, options)` | `torch.ones(shape, device=device, dtype=dtype)` | | `Tensor::zeros(shape, options)` | `torch.zeros(shape, device=device, dtype=dtype)` | ### Numeric Operations Those operations are available for numeric tensor kinds: `Float` and `Int`. | Burn | PyTorch Equivalent | | --------------------------------------------------------------- | --------------------------------------------- | | `tensor.abs()` | `torch.abs(tensor)` | | `tensor.add(other)` or `tensor + other` | `tensor + other` | | `tensor.add_scalar(scalar)` or `tensor + scalar` | `tensor + scalar` | | `tensor.all_close(other, atol, rtol)` | `torch.allclose(tensor, other, atol, rtol)` | | `tensor.argmax(dim)` | `tensor.argmax(dim)` | | `tensor.argmin(dim)` | `tensor.argmin(dim)` | | `tensor.argsort(dim)` | `tensor.argsort(dim)` | | `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` | | `tensor.bool()` | `tensor.bool()` | | `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` | | `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` | | `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` | | `tensor.cumsum(dim)` | `tensor.cumsum(dim)` | | `tensor.cumprod(dim)` | `tensor.cumprod(dim)` | | `tensor.cummin(dim)` | `tensor.cummin(dim)` | | `tensor.cummax(dim)` | `tensor.cummax(dim)` | | `tensor.div(other)` or `tensor / other` | `tensor / other` | | `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` | | `tensor.dot(other)` | `torch.dot(tensor, other)` | | `tensor.greater(other)` | `tensor.gt(other)` | | `tensor.greater_elem(scalar)` | `tensor.gt(scalar)` | | `tensor.greater_equal(other)` | `tensor.ge(other)` | | `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` | | `tensor.lower(other)` | `tensor.lt(other)` | | `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` | | `tensor.lower_equal(other)` | `tensor.le(other)` | | `tensor.lower_equal_elem(scalar)` | `tensor.le(scalar)` | | `tensor.max()` | `tensor.max()` | | `tensor.max_abs()` | `tensor.abs().max()` | | `tensor.max_abs_dim(dim)` | `tensor.abs().max(dim, keepdim=True)` | | `tensor.max_abs_dims(dims)` | `tensor.abs().max(dims, keepdim=True)` | | `tensor.max_dim(dim)` | `tensor.max(dim, keepdim=True)` | | `tensor.max_dims(dims)` | `tensor.max(dims, keepdim=True)` | | `tensor.max_dim_with_indices(dim)` | N/A | | `tensor.max_pair(other)` | `torch.Tensor.max(a,b)` | | `tensor.mean()` | `tensor.mean()` | | `tensor.mean_dim(dim)` | `tensor.mean(dim, keepdim=True)` | | `tensor.mean_dims(dims)` | `tensor.mean(dims, keepdim=True)` | | `tensor.min()` | `tensor.min()` | | `tensor.min_dim(dim)` | `tensor.min(dim, keepdim=True)` | | `tensor.min_dims(dims)` | `tensor.min(dims, keepdim=True)` | | `tensor.min_dim_with_indices(dim)` | N/A | | `tensor.min_pair(other)` | `torch.Tensor.min(a,b)` | | `tensor.mul(other)` or `tensor * other` | `tensor * other` | | `tensor.mul_scalar(scalar)` or `tensor * scalar` | `tensor * scalar` | | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | | `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, mode)` | `torch.nn.functional.pad(tensor, pads, mode)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | | `tensor.prod()` | `tensor.prod()` | | `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` | | `tensor.prod_dims(dims)` | `tensor.prod(dims, keepdim=True)` | | `tensor.rem(other)` or `tensor % other` | `tensor % other` | | `tensor.sign()` | `tensor.sign()` | | `tensor.sort(dim)` | `tensor.sort(dim).values` | | `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` | | `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` | | `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` | | `tensor.sub(other)` or `tensor - other` | `tensor - other` | | `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` | | `tensor.sum()` | `tensor.sum()` | | `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` | | `tensor.sum_dims(dims)` | `tensor.sum(dims, keepdim=True)` | | `tensor.sum_dims_squeeze(dims)` | `tensor.sum(dims, keepdim=False)` | | `tensor.topk(k, dim)` | `tensor.topk(k, dim).values` | | `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` | | `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` | | `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` | | `tensor.unfold(dim, size, step)` | `tensor.unfold(dim, size, step)` | | `Tensor::eye(size, device)` | `torch.eye(size, device=device)` | | `scalar - tensor` | `scalar - tensor` | ### Float Operations Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | -------------------------------------------- | ------------------------------------------ | | `tensor.acos()` | `tensor.acos()` | | `tensor.acosh()` | `tensor.acosh()` | | `tensor.asin()` | `tensor.asin()` | | `tensor.asinh()` | `tensor.asinh()` | | `tensor.atan()` | `tensor.atan()` | | `tensor.atanh()` | `tensor.atanh()` | | `tensor.atan2(other_tensor)` | `tensor.atan2(other_tensor)` | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.contains_nan()` | N/A | | `tensor.cos()` | `tensor.cos()` | | `tensor.cosh()` | `tensor.cosh()` | | `tensor.cross(other)` | `torch.cross(tensor, other)` | | `tensor.deg2rad()` | `torch.deg2rad()` | | `tensor.erf()` | `tensor.erf()` | | `tensor.exp()` | `tensor.exp()` | | `tensor.floor()` | `tensor.floor()` | | `tensor.fmod(other)` | `tensor.fmod(other)` | | `tensor.fmod_scalar(scalar)` | `tensor.fmod(scalar)` | | `tensor.from_floats(floats, device)` | N/A | | `tensor.int()` | Similar to `tensor.to(torch.long)` | | `tensor.is_close(other, atol, rtol)` | `torch.isclose(tensor, other, atol, rtol)` | | `tensor.is_finite()` | `torch.isfinite(tensor)` | | `tensor.is_inf()` | `torch.isinf(tensor)` | | `tensor.is_nan()` | `torch.isnan(tensor)` | | `tensor.log()` | `tensor.log()` | | `tensor.log1p()` | `tensor.log1p()` | | `tensor.matmul(other)` | `tensor.matmul(other)` | | `tensor.rad2deg()` | `torch.rad2deg()` | | `tensor.random(shape, distribution, device)` | N/A | | `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | | `tensor.recip()` or `1.0 / tensor` | `tensor.reciprocal()` or `1.0 / tensor` | | `tensor.round()` | `tensor.round()` | | `tensor.sin()` | `tensor.sin()` | | `tensor.sinh()` | `tensor.sinh()` | | `tensor.square()` | `tensor.square()` | | `tensor.sqrt()` | `tensor.sqrt()` | | `tensor.tan()` | `tensor.tan()` | | `tensor.tanh()` | `tensor.tanh()` | | `tensor.trunc()` | `tensor.trunc()` | | `tensor.var(dim)` | `tensor.var(dim)` | | `tensor.var_bias(dim)` | N/A | | `tensor.var_mean(dim)` | N/A | | `tensor.var_mean_bias(dim)` | N/A | | `tensor.median(dim)` | `tensor.median(dim)` | | `tensor.median_with_indices(dim)` | `tensor.median(dim)` | ### Int Operations Those operations are only available for `Int` tensors. | Burn API | PyTorch Equivalent | | ------------------------------------------------ | ------------------------------------------------------- | | `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | | `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | | `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` | | `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` | | `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` | | `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` | | `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` | | `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` | | `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` | | `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` | | `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` | | `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` | | `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | ### Bool Operations Those operations are only available for `Bool` tensors. | Burn API | PyTorch Equivalent | | ------------------------------------ | ------------------------------- | | `Tensor::diag_mask(shape, diagonal)` | N/A | | `Tensor::tril_mask(shape, diagonal)` | N/A | | `Tensor::triu_mask(shape, diagonal)` | N/A | | `tensor.argwhere()` | `tensor.argwhere()` | | `tensor.bool_and()` | `tensor.logical_and()` | | `tensor.bool_not()` | `tensor.logical_not()` | | `tensor.bool_or()` | `tensor.logical_or()` | | `tensor.bool_xor()` | `tensor.logical_xor()` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.int()` | `tensor.to(torch.long)` | | `tensor.nonzero()` | `tensor.nonzero(as_tuple=True)` | ### Quantization Operations Those operations are only available for `Float` tensors on backends that implement quantization strategies. | Burn API | PyTorch Equivalent | | ---------------------------------- | ------------------ | | `tensor.quantize(scheme, qparams)` | N/A | | `tensor.dequantize()` | N/A | ## Activation Functions | Burn API | PyTorch Equivalent | | ------------------------------------------------ | -------------------------------------------------- | | `activation::celu(tensor, alpha)` | `nn.functional.celu(tensor, alpha)` | | `activation::elu(tensor, alpha)` | `nn.functional.elu(tensor, alpha)` | | `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` | | `activation::glu(tensor, dim)` | `nn.functional.glu(tensor, dim)` | | `activation::hard_shrink(tensor, lambda)` | `nn.functional.hardshrink(tensor, lambd)` | | `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` | | `activation::hard_swish(tensor)` | `nn.functional.hardswish(tensor)` | | `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` | | `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` | | `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` | | `activation::mish(tensor)` | `nn.functional.mish(tensor)` | | `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` | | `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` | | `activation::relu(tensor)` | `nn.functional.relu(tensor)` | | `activation::shrink(tensor, lambda, bias)` | _No direct equivalent_ | | `activation::soft_shrink(tensor, lambda)` | `nn.functional.softshrink(tensor, lambd)` | | `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` | | `activation::selu(tensor)` | `nn.functional.selu(tensor)` | | `activation::silu(tensor)` | `nn.functional.silu(tensor)` | | `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` | | `activation::softmin(tensor, dim)` | `nn.functional.softmin(tensor, dim)` | | `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` | | `activation::softsign(tensor)` | `nn.functional.softsign(tensor)` | | `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` | | `activation::thresholded_relu(tensor, alpha)` | `nn.functional.threshold(tensor, alpha, 0)` | ## Grid Functions | Burn API | PyTorch Equivalent | | --------------------------------------------------- | -------------------------------------------------------------------- | | `grid::affine_grid_2d(transformation_tensor, dims)` | `nn.functional.affine_grid(theta_tensor, size, align_corners)` | | `grid::meshgrid(tensors, GridIndexing::Matrix)` | `torch.meshgrid(tensors, indexing="ij")` | | `grid::meshgrid(tensors, GridIndexing::Cartesian)` | `torch.meshgrid(tensors, indexing="xy")` | | `grid::meshgrid_stack(tensors, index_pos)` | _No direct equivalent_ | ## Linalg Functions | Burn API | PyTorch Equivalent | | -------------------------------------------------- | --------------------------------------------------- | | `linalg::cosine_similarity(x1, x2, dim, eps)` | `nn.functional.cosine_similarity(x1, x2, dim, eps)` | | `linalg::diag(tensor)` | `torch.diag(tensor)` | | `linalg::l0_norm(tensor, dim)` | _No direct equivalent_ | | `linalg::l1_norm(tensor, dim)` | _No direct equivalent_ | | `linalg::l2_norm(tensor, dim)` | _No direct equivalent_ | | `linalg::lp_norm(tensor, p, dim)` | _No direct equivalent_ | | `linalg::lu_decomposition(tensor)` | `torch.linalg.lu(tensor)` | | `linalg::matvec(matrix, vector)` | `torch.matmul(matrix, vector)` / `@` operator | | `linalg::max_abs_norm(tensor, dim)` | _No direct equivalent_ | | `linalg::min_abs_norm(tensor, dim)` | _No direct equivalent_ | | `linalg::outer(lhs, rhs)` | `torch.outer(lhs, rhs)` / `einsum("bi,bj->bij", …)` | | `linalg::outer_dim(lhs, rhs, dim)` | _No direct equivalent_ | | `linalg::trace(tensor)` | `torch.trace(tensor)` | | `linalg::vector_norm(tensor, p, dim)` | `torch.linalg.vector_norm(tensor, p, dim)` | | `linalg::vector_normalize(tensor, norm, dim, eps)` | `nn.functional.normalize(tensor, p, dim, eps)` | ## Displaying Tensor Details Burn provides flexible options for displaying tensor information, allowing you to control the level of detail and formatting to suit your needs. ### Basic Display To display a detailed view of a tensor, you can simply use Rust's `println!` or `format!` macros: ```rust, ignore let tensor = Tensor::::full([2, 3], 0.123456789, &Default::default()); println!("{}", tensor); ``` This will output: ``` Tensor { data: [[0.12345679, 0.12345679, 0.12345679], [0.12345679, 0.12345679, 0.12345679]], shape: [2, 3], device: Cpu, backend: "ndarray", kind: "Float", dtype: "f32", } ``` ### Controlling Precision You can control the number of decimal places displayed using Rust's formatting syntax: ```rust println!("{:.2}", tensor); ``` Output: ``` Tensor { data: [[0.12, 0.12, 0.12], [0.12, 0.12, 0.12]], shape: [2, 3], device: Cpu, backend: "ndarray", kind: "Float", dtype: "f32", } ``` ### Global Print Options For more fine-grained control over tensor printing, Burn provides a `PrintOptions` struct and a `set_print_options` function: ```rust, ignore use burn::tensor::{set_print_options, PrintOptions}; let print_options = PrintOptions { precision: Some(2), ..Default::default() }; set_print_options(print_options); ``` Options: - `precision`: Number of decimal places for floating-point numbers (default: None) - `threshold`: Maximum number of elements to display before summarizing (default: 1000) - `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing (default: 3) ### Checking Tensor Closeness Burn provides a utility function `check_closeness` to compare two tensors and assess their similarity. This function is particularly useful for debugging and validating tensor operations, especially when working with floating-point arithmetic where small numerical differences can accumulate. It's also valuable when comparing model outputs during the process of importing models from other frameworks, helping to ensure that the imported model produces results consistent with the original. Here's an example of how to use `check_closeness`: ```rust, ignore use burn::tensor::{check_closeness, Tensor}; type B = burn::backend::NdArray; let device = Default::default(); let tensor1 = Tensor::::from_floats( [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1], &device, ); let tensor2 = Tensor::::from_floats( [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004], &device, ); check_closeness(&tensor1, &tensor2); ``` The `check_closeness` function compares the two input tensors element-wise, checking their absolute differences against a range of epsilon values. It then prints a detailed report showing the percentage of elements that are within each tolerance level. The output provides a breakdown for different epsilon values, allowing you to assess the closeness of the tensors at various precision levels. This is particularly helpful when dealing with operations that may introduce small numerical discrepancies. The function uses color-coded output to highlight the results: - Green [PASS]: All elements are within the specified tolerance. - Yellow [WARN]: Most elements (90% or more) are within tolerance. - Red [FAIL]: Significant differences are detected. This utility can be invaluable when implementing or debugging tensor operations, especially those involving complex mathematical computations or when porting algorithms from other frameworks. It's also an essential tool when verifying the accuracy of imported models, ensuring that the Burn implementation produces results that closely match those of the original model. ================================================ FILE: burn-book/src/custom-training-loop.md ================================================ # Custom Training Loops Even though Burn comes with a project dedicated to simplifying training, it doesn't mean that you have to use it. Sometimes you may have special needs for your training, and it might be faster to just reimplement the training loop yourself. Also, you may just prefer implementing your own training loop instead of using a pre-built one in general. Burn's got you covered! We will start from the same example shown in the [basic workflow](./basic-workflow) section, but without using the `Learner` struct. ```rust, ignore #[derive(Config, Debug)] pub struct MnistTrainingConfig { #[config(default = 10)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 4)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, #[config(default = 1e-4)] pub lr: f64, pub model: ModelConfig, pub optimizer: AdamConfig, } pub fn run(device: B::Device) { // Create the configuration. let config_model = ModelConfig::new(10, 1024); let config_optimizer = AdamConfig::new(); let config = MnistTrainingConfig::new(config_model, config_optimizer); B::seed(&device, config.seed); // Create the model and optimizer. let mut model = config.model.init::(&device); let mut optim = config.optimizer.init(); // Create the batcher. let batcher = MnistBatcher::default(); // Create the dataloaders. let dataloader_train = DataLoaderBuilder::new(batcher.clone()) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::train()); let dataloader_test = DataLoaderBuilder::new(batcher) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::test()); ... } ``` As seen with the previous example, setting up the configurations and the dataloader hasn't changed. Now, let's move forward and write our own training loop: ```rust, ignore pub fn run(device: B::Device) { ... // Iterate over our training and validation loop for X epochs. for epoch in 1..config.num_epochs + 1 { // Implement our training loop. for (iteration, batch) in dataloader_train.iter().enumerate() { let output = model.forward(batch.images); let loss = CrossEntropyLoss::new(None, &output.device()) .forward(output.clone(), batch.targets.clone()); let accuracy = accuracy(output, batch.targets); println!( "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %", epoch, iteration, loss.clone().into_scalar(), accuracy, ); // Gradients for the current backward pass let grads = loss.backward(); // Gradients linked to each parameter of the model. let grads = GradientsParams::from_grads(grads, &model); // Update the model using the optimizer. model = optim.step(config.lr, model, grads); } // Get the model without autodiff. let model_valid = model.valid(); // Implement our validation loop. for (iteration, batch) in dataloader_test.iter().enumerate() { let output = model_valid.forward(batch.images); let loss = CrossEntropyLoss::new(None, &output.device()) .forward(output.clone(), batch.targets.clone()); let accuracy = accuracy(output, batch.targets); println!( "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}", epoch, iteration, loss.clone().into_scalar(), accuracy, ); } } } ``` In the previous code snippet, we can observe that the loop starts from epoch `1` and goes up to `num_epochs`. Within each epoch, we iterate over the training dataloader. During this process, we execute the forward pass, which is necessary for computing both the loss and accuracy. To maintain simplicity, we print the results to stdout. Upon obtaining the loss, we can invoke the `backward()` function, which returns the gradients specific to each variable. It's important to note that we need to map these gradients to their corresponding parameters using the `GradientsParams` type. This step is essential because you might run multiple different autodiff graphs and accumulate gradients for each parameter id. Finally, we can perform the optimization step using the learning rate, the model, and the computed gradients. It's worth mentioning that, unlike PyTorch, there's no need to register the gradients with the optimizer, nor do you have to call `zero_grad`. The gradients are automatically consumed during the optimization step. If you're interested in gradient accumulation, you can easily achieve this by using the `GradientsAccumulator`. ```rust, ignore let mut accumulator = GradientsAccumulator::new(); let grads = model.backward(); let grads = GradientsParams::from_grads(grads, &model); accumulator.accumulate(&model, grads); ... let grads = accumulator.grads(); // Pop the accumulated gradients. ``` Note that after each epoch, we include a validation loop to assess our model's performance on previously unseen data. To disable gradient tracking during this validation step, we can invoke `model.valid()`, which provides a model on the inner backend without autodiff capabilities. It's important to emphasize that we've declared our validation batcher to be on the inner backend, specifically `MnistBatcher`; not using `model.valid()` will result in a compilation error. You can find the code above available as an [example](https://github.com/tracel-ai/burn/tree/main/examples/custom-training-loop) for you to test. ## Multiple optimizers It's common practice to set different learning rates, optimizer parameters, or use different optimizers entirely, for different parts of a model. In Burn, each `GradientParams` can contain only a subset of gradients to actually apply with an optimizer. This allows you to flexibly mix and match optimizers! ```rust,ignore // Start with calculating all gradients let grads = loss.backward(); // Now split the gradients into various parts. let grads_conv1 = GradientParams::from_module(&mut grads, &model.conv1); let grads_conv2 = GradientParams::from_module(&mut grads, &model.conv2); // You can step the model with these gradients, using different learning // rates for each param. You could also use an entirely different optimizer here! model = optim.step(config.lr * 2.0, model, grads_conv1); model = optim.step(config.lr * 4.0, model, grads_conv2); // For even more granular control you can split off individual parameter // eg. a linear bias usually needs a smaller learning rate. if let Some(bias) == model.linear1.bias { let grads_bias = GradientParams::from_params(&mut grads, &model.linear1, &[bias.id]); model = optim.step(config.lr * 0.1, model, grads_bias); } // Note that above calls remove gradients, so we can just get all "remaining" gradients. let grads = GradientsParams::from_grads(grads, &model); model = optim.step(config.lr, model, grads); ``` ## Custom Type The explanations above demonstrate how to create a basic training loop. However, you may find it beneficial to organize your program using intermediary types. There are various ways to do this, but it requires getting comfortable with generics. If you wish to group the optimizer and the model into the same structure, you have several options. It's important to note that the optimizer trait depends on both the `AutodiffModule` trait and the `AutodiffBackend` trait, while the module only depends on the `AutodiffBackend` trait. Here's a closer look at how you can create your types: **Create a struct that is generic over the backend and the optimizer, with a predefined model.** ```rust, ignore struct Learner where B: AutodiffBackend, { model: Model, optim: O, } ``` This is quite straightforward. You can be generic over the backend since it's used with the concrete type `Model` in this case. **Create a struct that is generic over the model and the optimizer.** ```rust, ignore struct Learner { model: M, optim: O, } ``` This option is a quite intuitive way to declare the struct. You don't need to write type constraints with a `where` statement when defining a struct; you can wait until you implement the actual function. However, with this struct, you may encounter some issues when trying to implement code blocks to your struct. ```rust, ignore impl Learner where B: AutodiffBackend, M: AutodiffModule, O: Optimizer, { pub fn step(&mut self, _batch: MnistBatch) { // } } ``` This will result in the following compilation error: ```console 1. the type parameter `B` is not constrained by the impl trait, self type, or predicates unconstrained type parameter [E0207] ``` To resolve this issue, you have two options. The first one is to make your function generic over the backend and add your trait constraint within its definition: ```rust, ignore #[allow(dead_code)] impl Learner2 { pub fn step(&mut self, _batch: MnistBatch) where B: AutodiffBackend, M: AutodiffModule, O: Optimizer, { // } } ``` However, some people may prefer to have the constraints on the implementation block itself. In that case, you can make your struct generic over the backend using `PhantomData`. **Create a struct that is generic over the backend, the model, and the optimizer.** ```rust, ignore struct Learner3 { model: M, optim: O, _b: PhantomData, } ``` You might wonder why `PhantomData` is required. Each generic argument must be used as a field when declaring a struct. When you don't need the generic argument, you can use `PhantomData` to mark it as a zero sized type. These are just some suggestions on how to define your own types, but you are free to use any pattern that you prefer. ================================================ FILE: burn-book/src/distributed-computing.md ================================================ # Distributed Computing ================================================ FILE: burn-book/src/examples.md ================================================ # Examples In the [next chapter](./basic-workflow) you'll have the opportunity to implement the whole Burn `guide` example yourself in a step by step manner. Many additional Burn examples are available in the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory. Burn examples are organized as library crates with one or more examples that are executable binaries. An example can then be executed using the following cargo command line in the root of the Burn repository: ```bash cargo run --example ``` To learn more about crates and examples, read the Rust section below.
🦀 About Rust crates Each Burn example is a **package** which are subdirectories of the `examples` directory. A package is composed of one or more **crates**. A package is a bundle of one or more crates that provides a set of functionality. A package contains a `Cargo.toml` file that describes how to build those crates. A crate is a compilation unit in Rust. It could be a single file, but it is often easier to split up crates into multiple **modules**. A module lets us organize code within a crate for readability and easy reuse. Modules also allow us to control the _privacy_ of items. For instance the `pub(crate)` keyword is employed to make a module publicly available inside the crate. In the snippet below there are four modules declared, two of them are public and visible to the users of the crates, one of them is public inside the crate only and crate users cannot see it, at last one is private when there is no keyword. These modules can be single files or a directory with a `mod.rs` file inside. ```rust, ignore pub mod data; pub mod inference; pub(crate) mod model; mod training; ``` A crate can come in one of two forms: a **binary crate** or a **library crate**. When compiling a crate, the compiler first looks in the crate root file (`src/lib.rs` for a library crate and `src/main.rs` for a binary crate). Any module declared in the crate root file will be inserted in the crate for compilation. All Burn examples are library crates and they can contain one or more executable examples that uses the library. We even have some Burn examples that uses the library crate of other examples. The examples are unique files under the `examples` directory. Each file produces an executable file with the same name. Each example can then be executed with `cargo run --example `. Below is a file tree of a typical Burn example package: ``` examples/burn-example ├── Cargo.toml ├── examples │ ├── example1.rs ---> compiled to example1 binary │ ├── example2.rs ---> compiled to example2 binary │ └── ... └── src ├── lib.rs ---> this is the root file for a library ├── module1.rs ├── module2.rs └── ... ```

The following additional examples are currently available if you want to check them out: | Example | Description | | :-------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [Custom CSV Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-csv-dataset) | Implements a dataset to parse CSV data for a regression task. | | [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression) | Trains a simple MLP on the California Housing dataset to predict the median house value for a district. | | [Custom Image Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-image-dataset) | Trains a simple CNN on custom image dataset following a simple folder structure. | | [Custom Renderer](https://github.com/tracel-ai/burn/tree/main/examples/custom-renderer) | Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress. | | [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. | | [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web) | An interactive MNIST inference demo in the browser. The demo is available [online](https://burn.dev/demo/). | | [MNIST Training](https://github.com/tracel-ai/burn/tree/main/examples/mnist) | Demonstrates how to train a custom [`Module`](./building-blocks/module.md) (MLP) with the [`Learner`](./building-blocks/learner.md) configured to log metrics and keep training checkpoints. | | [ONNX Import Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. | | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/import-model-weights) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. | | [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. | | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. | | [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. | For more information on each example, see their respective `README.md` file. Be sure to check out the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date list.
Note that some examples use the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download the datasets required in the examples. This is a Python library, which means that you will need to install Python before running these examples. This requirement will be clearly indicated in the example's README when applicable.
================================================ FILE: burn-book/src/getting-started.md ================================================ # Getting Started Burn is a deep learning framework in the Rust programming language. Therefore, it goes without saying that one must understand the basic notions of Rust. Reading the first chapters of the [Rust Book](https://doc.rust-lang.org/book/) is recommended, but don't worry if you're just starting out. We'll try to provide as much context and reference to external resources when required. Just look out for the **🦀 Rust Note** indicators. ## Installing Rust For installation instructions, please refer to the [installation page](https://doc.rust-lang.org/book/ch01-01-installation.html). It explains in details the most convenient way for you to install Rust on your computer, which is the very first thing to do to start using Burn. ## Creating a Burn application Once Rust is correctly installed, create a new Rust application by using Rust's build system and package manager Cargo. It is automatically installed with Rust.
🦀 Cargo Cheat Sheet [Cargo](https://doc.rust-lang.org/cargo/) is a very useful tool to manage Rust projects because it handles a lot of tasks. More precisely, it is used to compile your code, download the libraries/packages your code depends on, and build said libraries. Below is a quick cheat sheet of the main `cargo` commands you might use throughout this guide. | Command | Description | | ------------------- | -------------------------------------------------------------------------------------------- | | `cargo new` _path_ | Create a new Cargo package in the given directory. | | `cargo add` _crate_ | Add dependencies to the Cargo.toml manifest file. | | `cargo build` | Compile the local package and all of its dependencies (in debug mode, use `-r` for release). | | `cargo check` | Check the local package for compilation errors (much faster). | | `cargo run` | Run the local package binary. | For more information, check out [Hello, Cargo!](https://doc.rust-lang.org/book/ch01-03-hello-cargo.html) in the Rust Book.

In the directory of your choice, run the following: ```console cargo new my_burn_app ``` This will initialize the `my_burn_app` project directory with a `Cargo.toml` file and a `src` directory with an auto-generated `main.rs` file inside. Head inside the directory to check: ```console cd my_burn_app ``` Then, add Burn as a dependency: ```console cargo add burn --features wgpu ``` Finally, compile the local package by executing the following: ```console cargo build ``` That's it, you're ready to start! You have a project configured with Burn and the WGPU backend, which allows to execute low-level operations on any platform using the GPU.
When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain. To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file: ```rust #![recursion_limit = "256"] ``` The default recursion limit (128) is often just below the required depth (typically 130-150) due to deeply nested associated types and trait bounds.
## Writing a code snippet The `src/main.rs` was automatically generated by Cargo, so let's replace its content with the following: ```rust, ignore use burn::tensor::Tensor; use burn::backend::Wgpu; // Type alias for the backend to use. type Backend = Wgpu; fn main() { let device = Default::default(); // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first let tensor_1 = Tensor::::from_data([[2., 3.], [4., 5.]], &device); let tensor_2 = Tensor::::ones_like(&tensor_1); // Print the element-wise addition (done with the WGPU backend) of the two tensors. println!("{}", tensor_1 + tensor_2); } ```
🦀 Use Declarations To bring any of the Burn module or item into scope, a `use` declaration is added. In the example above, we wanted bring the `Tensor` struct and `Wgpu` backend into scope with the following: ```rust, ignore use burn::tensor::Tensor; use burn::backend::Wgpu; ``` This is pretty self-explanatory in this case. But, the same declaration could be written as a shortcut to simultaneously binding of multiple paths with a common prefix: ```rust, ignore use burn::{tensor::Tensor, backend::Wgpu}; ``` In this example, the common prefix is pretty short and there are only two items to bind locally. Therefore, the first usage with two `use` declarations might be preferred. But know that both examples are valid. For more details on the `use` keyword, take a look at [this section](https://doc.rust-lang.org/book/ch07-04-bringing-paths-into-scope-with-the-use-keyword.html) of the Rust Book or the [Rust reference](https://doc.rust-lang.org/reference/items/use-declarations.html).

🦀 Generic Data Types If you're new to Rust, you're probably wondering why we had to use `Tensor::::...`. That's because the `Tensor` struct is [generic](https://doc.rust-lang.org/book/ch10-01-syntax.html) over multiple concrete data types. More specifically, a `Tensor` can be defined using three generic parameters: the backend, the number of dimensions (rank) and the data type (defaults to `Float`). Here, we only specify the backend and number of dimensions since a `Float` tensor is used by default. For more details on the `Tensor` struct, take a look at [this section](./building-blocks/tensor.md). Most of the time when generics are involved, the compiler can infer the generic parameters automatically. In this case, the compiler needs a little help. This can usually be done in one of two ways: providing a type annotation or binding the generic parameter via the _turbofish_ `::<>` syntax. In the example above we used the so-called _turbofish_ syntax, but we could have used type annotations instead like this: ```rust, ignore let tensor_1: Tensor = Tensor::from_data([[2., 3.], [4., 5.]]); let tensor_2 = Tensor::ones_like(&tensor_1); ``` You probably noticed that we provided a type annotation for the first tensor only and yet this example still works. That's because the compiler (correctly) inferred that `tensor_2` had the same generic parameters. The same could have been done in the original example, but specifying the parameters for both is more explicit.

By running `cargo run`, you should now see the result of the addition: ```console Tensor { data: [[3.0, 4.0], [5.0, 6.0]], shape: [2, 2], device: DefaultDevice, backend: "wgpu", kind: "Float", dtype: "f32", } ``` While the previous example is somewhat trivial, the upcoming basic workflow section will walk you through a much more relevant example for deep learning applications. ## Using `prelude` Burn comes with a variety of things in its core library. When creating a new model or using an existing one for inference, you may need to import every single component you used, which could be a little verbose. To address it, a `prelude` module is provided, allowing you to easily import commonly used structs and macros as a group: ```rust, ignore use burn::prelude::*; ``` which is equal to: ```rust, ignore use burn::{ config::Config, module::Module, nn, tensor::{ backend::Backend, Bool, Device, ElementConversion, Float, Int, Shape, Tensor, TensorData, }, }; ```
For the sake of simplicity, the subsequent chapters of this book will all use this form of importing except in the [Building Blocks](./building-blocks) chapter, as explicit importing aids users in grasping the usage of particular structures and macros.
================================================ FILE: burn-book/src/models-and-pretrained-weights.md ================================================ # Models and Pre-Trained Weights ## Models Repository The [`models`](https://github.com/tracel-ai/models) repository contains definitions of different deep learning models with examples for different domains like computer vision and natural language processing. This includes image classification models such as [`MobileNetV2`](https://github.com/tracel-ai/models/tree/main/mobilenetv2-burn), [`SqueezeNet`](https://github.com/tracel-ai/models/tree/main/squeezenet-burn) and [`ResNet`](https://github.com/tracel-ai/models/tree/main/resnet-burn), object detection models such as [`YOLOX`](https://github.com/tracel-ai/models/tree/main/yolox-burn) and language models like [`BERT` and `RoBERTa`](https://github.com/tracel-ai/models/tree/main/bert-burn). Be sure to check out the up-to-date [collection of models](https://github.com/tracel-ai/models?tab=readme-ov-file#collection-of-official-models) to get you started. Pre-trained weights are available for every supported architecture in this collection. You will also find a spotlight of [community contributed models](https://github.com/tracel-ai/models?tab=readme-ov-file#community-contributions). ## Burn-LM (alpha) [`Burn-LM`](https://github.com/tracel-ai/burn-lm) is an LLM inference engine built on Burn. It provides access to large language models with open-source pre-trained weights and supports running, fine-tuning, and experimenting with them on any Burn backend. Unlike tools focused solely on inference, Burn-LM is designed to work in a unified way across different models and tasks, making it easier to explore both inference and training workflows within the same framework. ================================================ FILE: burn-book/src/motivation.md ================================================ # Why Burn? Why bother with the effort of creating an entirely new deep learning framework from scratch when PyTorch, TensorFlow, and other frameworks already exist? Spoiler alert: Burn isn't merely a replication of PyTorch or TensorFlow in Rust. It represents a novel approach, placing significant emphasis on making the right compromises in the right areas to facilitate exceptional flexibility, high performance, and a seamless developer experience. Burn isn’t a framework specialized for only one type of application, it is designed to serve as a versatile framework suitable for a wide range of research and production uses. The foundation of Burn's design revolves around three key user profiles: **Machine Learning Researchers** require tools to construct and execute experiments efficiently. It’s essential for them to iterate quickly on their ideas and design testable experiments which can help them discover new findings. The framework should facilitate the swift implementation of cutting-edge research while ensuring fast execution for testing. **Machine Learning Engineers** are another important demographic to keep in mind. Their focus leans less on swift implementation and more on establishing robustness, seamless deployment, and cost-effective operations. They seek dependable, economical models capable of achieving objectives without excessive expense. The whole machine learning workflow —from training to inference— must be as efficient as possible with minimal unpredictable behavior. **Low level Software Engineers** working with hardware vendors want their processing units to run models as fast as possible to gain competitive advantage. This endeavor involves harnessing hardware-specific features such as Tensor Core for Nvidia. Since they are mostly working at a system level, they want to have absolute control over how the computation will be executed. The goal of Burn is to satisfy all of those personas! ================================================ FILE: burn-book/src/onnx-import.md ================================================ # ONNX Import ## Introduction As deep learning evolves, interoperability between frameworks becomes crucial. Burn provides robust support for importing [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models through the [`burn-onnx`](https://github.com/tracel-ai/burn-onnx) crate, enabling you to leverage pre-trained models in your Rust-based deep learning projects. ## Why Import Models? Importing pre-trained models offers several advantages: 1. **Time-saving**: Skip the resource-intensive process of training models from scratch. 2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by researchers and industry leaders. 3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from knowledge transfer. 4. **Consistency across frameworks**: Maintain consistent performance when moving between frameworks. ## Understanding ONNX ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models with these key features: - **Framework agnostic**: Provides a common format that works across various deep learning frameworks. - **Comprehensive representation**: Captures both the model architecture and trained weights. - **Wide support**: Compatible with popular frameworks like PyTorch, TensorFlow, and scikit-learn. This standardization allows seamless movement of models between different frameworks and deployment environments. ## Burn's ONNX Support Burn's approach to ONNX import offers unique advantages: 1. **Native Rust code generation**: Translates ONNX models into Rust source code for deep integration with Burn's ecosystem. 2. **Compile-time optimization**: Leverages the Rust compiler to optimize the generated code, potentially improving performance. 3. **No runtime dependency**: Eliminates the need for an ONNX runtime, unlike many other solutions. 4. **Trainability**: Allows imported models to be further trained or fine-tuned using Burn. 5. **Portability**: Enables compilation for various targets, including WebAssembly and embedded devices. 6. **Backend flexibility**: Works with any of Burn's supported backends. ## ONNX Compatibility Burn recommends ONNX models use **opset version 16 or higher** for best compatibility. While models with older opset versions may work, opset 16+ ensures access to all supported operators and their latest behavior. If you encounter issues with an older model, consider upgrading it using the ONNX version converter. ### Upgrading ONNX Models There are two simple ways to upgrade your ONNX models to the recommended opset version: Option 1: Use the provided utility script: ``` uv run --script https://raw.githubusercontent.com/tracel-ai/burn-onnx/refs/heads/main/onnx_opset_upgrade.py ``` Option 2: Use a custom Python script: ```python import onnx from onnx import version_converter, shape_inference # Load your ONNX model model = onnx.load('path/to/your/model.onnx') # Convert the model to opset version 16 upgraded_model = version_converter.convert_version(model, 16) # Apply shape inference to the upgraded model inferred_model = shape_inference.infer_shapes(upgraded_model) # Save the converted model onnx.save(inferred_model, 'upgraded_model.onnx') ``` ## Step-by-Step Guide Follow these steps to import an ONNX model into your Burn project: ### Step 1: Update `Cargo.toml` First, add the required dependencies to your `Cargo.toml`: ```toml [dependencies] burn = { version = "~0.21", features = ["ndarray"] } [build-dependencies] burn-onnx = "~0.21" ``` ### Step 2: Update `build.rs` In your `build.rs` file: ```rust, ignore use burn_onnx::ModelGen; fn main() { ModelGen::new() .input("src/model/my_model.onnx") .out_dir("model/") .run_from_script(); } ``` This generates Rust code and a `.bpk` weights file from your ONNX model during the build process. ### Step 3: Modify `mod.rs` In your `src/model/mod.rs` file, include the generated code: ```rust, ignore pub mod my_model { include!(concat!(env!("OUT_DIR"), "/model/my_model.rs")); } ``` ### Step 4: Use the Imported Model Now you can use the imported model in your code: ```rust, ignore use burn::tensor; use burn_ndarray::{NdArray, NdArrayDevice}; use model::my_model::Model; fn main() { let device = NdArrayDevice::default(); // Create model instance and load weights from target dir default device let model: Model> = Model::default(); // Create input tensor (replace with your actual input) let input = tensor::Tensor::, 4>::zeros([1, 3, 224, 224], &device); // Perform inference let output = model.forward(input); println!("Model output: {:?}", output); } ``` ## Advanced Configuration The `ModelGen` struct provides configuration options: ```rust, ignore use burn_onnx::{ModelGen, LoadStrategy}; ModelGen::new() .input("path/to/model.onnx") .out_dir("model/") .development(true) // Enable development mode for debugging .load_strategy(LoadStrategy::Embedded) // Embed weights in the binary .run_from_script(); ``` - `input`: Path to the ONNX model file - `out_dir`: Output directory for generated code and weights - `development`: When enabled, generates additional debug files (`.onnx.txt`, `.graph.txt`) - `load_strategy`: Controls which weight-loading constructors are generated on the `Model` struct (see below) Model weights are stored in Burnpack format (`.bpk`), which provides efficient serialization and loading. ### Load Strategy The `LoadStrategy` enum controls how the generated model loads its weights: | Strategy | Generated constructors | `Default` impl | Use case | |------------|------------------------------------------------|-----------------|-------------------------------------------| | `File` | `from_file()`, `from_bytes()` | Yes | Standard desktop/server (default) | | `Embedded` | `from_embedded()`, `from_bytes()` | Yes | Single binary, small models | | `Bytes` | `from_bytes()` | No | WASM, embedded, custom loaders | | `None` | (none) | No | Manual weight management | The default strategy is `File`, which keeps weights in a separate `.bpk` file and generates a `from_file()` constructor. For WebAssembly or environments without filesystem access, use `LoadStrategy::Bytes`: ```rust, ignore ModelGen::new() .input("model.onnx") .out_dir("model/") .load_strategy(LoadStrategy::Bytes) .run_from_script(); ``` Then load weights at runtime from any byte source (e.g., a network fetch): ```rust, ignore let model = Model::::from_bytes(weight_bytes, &device); ``` ## Loading and Using Models You can load models in several ways, depending on the `LoadStrategy` used during code generation: ```rust, ignore // Load from the output directory with default device (recommended for most use cases) // This automatically loads weights from the .bpk file // Available with LoadStrategy::File or LoadStrategy::Embedded let model = Model::::default(); // Create a new model instance with a specific device // (initializes weights randomly; load weights via `load_from` afterward) let model = Model::::new(&device); // Load from a specific .bpk file (LoadStrategy::File) let model = Model::::from_file("path/to/weights.bpk", &device); // Load from in-memory bytes (LoadStrategy::File, Embedded, or Bytes) let model = Model::::from_bytes(weight_bytes, &device); // Load from embedded weights (LoadStrategy::Embedded) let model = Model::::from_embedded(&device); ``` ## Troubleshooting Common issues and solutions: 1. **Unsupported ONNX operator**: Check the [list of supported ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md). You may need to simplify your model or wait for support. 2. **Build errors**: Ensure your `burn-onnx` version matches your Burn version and verify the ONNX file path in `build.rs`. 3. **Runtime errors**: Confirm that your input tensors match the expected shape and data type of your model. 4. **Performance issues**: Consider using a more performant backend or optimizing your model architecture. 5. **Viewing generated files**: Find the generated Rust code and weights in the `OUT_DIR` directory (usually `target/debug/build//out`). ## Examples and Resources For practical examples, check out the [burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples): 1. [ONNX Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) - MNIST inference example 2. [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) - SqueezeNet running in the browser via WebAssembly 3. [Raspberry Pi Pico](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico) - Embedded deployment example These demonstrate real-world usage of ONNX import in Burn projects. For contributors looking to add support for new ONNX operators: - [Development Guide](https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md) - Step-by-step guide for implementing new operators ## Conclusion Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's performance and Rust's safety features. Following this guide, you can seamlessly integrate ONNX models into your Burn projects for inference, fine-tuning, or further development. The `burn-onnx` crate is actively developed, with ongoing work to support more ONNX operators and improve performance. Visit the [burn-onnx repository](https://github.com/tracel-ai/burn-onnx) for updates and to contribute! ================================================ FILE: burn-book/src/overview.md ================================================ # Overview Welcome to The Burn Book 👋 This book will help you get started with the Burn deep learning framework, whether you are an advanced user or a beginner. We have crafted some sections for you: - [Basic Workflow: From Training to Inference](./basic-workflow): We'll start with the fundamentals, guiding you through the entire workflow, from training your models to deploying them for inference. This section lays the groundwork for your Burn expertise. - [Building Blocks](./building-blocks): Dive deeper into Burn's core components, understanding how they fit together. This knowledge forms the basis for more advanced usage and customization. - [Performance - Good Practices](./performance/good-practices/): Tips for writing models and training code that make the most of hardware resources while avoiding common pitfalls that can slow down execution. - [Custom Training Loop](./custom-training-loop.md): Gain the power to customize your training loops, fine-tuning your models to meet your specific requirements. This section empowers you to harness Burn's flexibility to its fullest. - [Saving & Loading Models](./saving-and-loading.md): Learn how to save and load your trained models, including importing weights from PyTorch and SafeTensors formats. - [ONNX Import](./onnx-import.md): Learn how to import ONNX models using the [burn-onnx](https://github.com/tracel-ai/burn-onnx) crate. - [Models & Pre-Trained Weights](./models-and-pretrained-weights.md): Get started quickly with ready-to-use models and pre-trained weights. - [Advanced](./advanced): Finally, venture into advanced topics, exploring Burn's capabilities at their peak. This section caters to those who want to push the boundaries of what's possible with Burn. Throughout the book, we assume a basic understanding of deep learning concepts, but we may refer to additional material when it seems appropriate. ================================================ FILE: burn-book/src/performance/README.md ================================================ # Performance This section covers the key concepts you need to understand to get the most out of Burn and your hardware. ================================================ FILE: burn-book/src/performance/distributed-computing.md ================================================ # Distributed Computing Distributed computing support was introduced in Burn 0.19. Documentation and examples will be available soon. ================================================ FILE: burn-book/src/performance/good-practices/README.md ================================================ # Performance - Best Practices This section provides valuable insights into the performance characteristics of Burn and guides users on how to effectively leverage them for optimal results. It includes several sections, each offering relevant details. While understanding these concepts can aid in model optimization, it’s always crucial to conduct benchmarks and profile models to accurately assess performance improvements. - [Asynchronous Execution](./asynchronous-execution.md) - [Kernel Fusion](./kernel-fusion.md) - [Kernel Selection](./kernel-selection.md) ================================================ FILE: burn-book/src/performance/good-practices/asynchronous-execution.md ================================================ # Asynchronous Execution Most Burn backends execute tensor operations in an asynchronous manner. However, the async notation is often not required for most tensor operations, privileging the simplicity of sync Rust. There are only a few operations that trigger synchronization of the backend, and it is very important to correctly handle those to optimize hardware utilization. Those operations are `into_data`, `into_scalar`, and `sync`. Some tensor operations might call `into_data` underneath, triggering a synchronization, like `to_device` for some backends. There are several ways to minimize synchronization overhead, one of which is to batch sync operations into a single transaction. Burn provides a high-level composable API to build transactions, which will only trigger a single sync on the device. For instance, it is often used when collecting metrics during training: ```rust /// All of these variables are tensors. let (output, loss, targets) = ..; /// Now output, loss, and targets will be `TensorData` stored on the CPU. let [output, loss, targets] = Transaction::default() .register(output) .register(loss) .register(targets) .execute() .try_into() .expect("Correct amount of tensor data"); ``` Another way of optimizing reads and avoiding device stalls is to read the data on a different thread. Under the hood, CubeCL-based backends assign different execution queues for different threads, meaning that syncing a thread shouldn’t impact the throughput of another thread. ## Using Different Backends for Different Tasks Tensor operations aren’t the only things that are asynchronous; dataset and dataloading are also lazily executed. This allows for efficient data augmentation and sampling without having to cache huge datasets on disk. However, this might reduce training throughput if data augmentation is performed on the same device as the training itself. So, it is normally encouraged to use a different device, maybe even a different backend, for that purpose. For optimal performance, also avoid small allocations followed by a batching procedure. Even if it doesn’t break asynchronicity, it can slow down performance. ```rust /// Items is a vector of many tensors. let items = ..; let batch = Tensor::cat(items, 1); ``` Prefer doing the concatenation of tensors on the data augmentation device and not on the training device. ```rust /// Items is a vector of many tensors. let items = ..; let device_training = ..; let axis_batch = 0; let items = Tensor::cat(items, axis_batch); let batch = Tensor::from_data(items.into_data(), device_training); ``` ================================================ FILE: burn-book/src/performance/good-practices/kernel-fusion.md ================================================ # Kernel Fusion An interesting property of async execution is that it allows performance optimizations like kernel fusion. Coupled with CubeCL and its Just-In-Time compiler, Burn can serialize tensor operations into a symbolic graph, then optimize it for improved efficiency. Kernel fusion may reorder operations to reduce global memory reads, writes, and allocations. Being aware of which operations can be fused is relevant, as it can be easy to break an execution graph. The easiest way to optimize for fusion is to avoid keeping tensors alive for too long. When fusion isn’t possible, all tensors that will be used later will trigger a global memory write. Fortunately, Rust and Clippy are quite good at detecting unnecessary clones, but special care should still be taken. View operations can also interfere with fusion. They can be included in optimized graphs, but only to a limited extent, and they reduce vectorization potential as we have fewer guarantees about memory access patterns with transformed indices. So, it is good practice to group view operations together before executing blocks of operations. ```rust let tensor4 = tensor1.unsqueeze().matmul(tensor2) + tensor3.unsqueeze(); ``` Could be improved with the following: ```rust let tensor1 = tensor1.unsqueeze(); let tensor3 = tensor3.unsqueeze(); let tensor4 = tensor1.matmul(tensor2) + tensor3; ``` This reduces the necessary reordering and may reduce a global memory write or improve vectorization. We might be able to detect these patterns in the future, but for now, it’s a good idea to order your operations using this pattern. As a reminder, view operations typically only update tensor metadata in most cases. These operations include `slice`, `slice_assign`, `select`, `gather`, `scatter`, `reshape`, `swap_dims`, `transpose`, `unsqueeze`, etc. With fusion enabled, it is often not necessary to write custom kernels, as you can rely on our system to optimize most element-wise operations. However, most compute-bound kernels require many tricks and deep knowledge of GPU memory architectures, where automatic compiler optimizations often underperform compared to human-designed algorithms. This is why Burn’s approach to fusion is centered around fuse-on-read and fuse-on-write. This means that complex compute-bound kernels that change the shapes of tensors can fuse a block of element-wise operations when reading the input tensor and when writing the output tensor. The implication is that multiple compute-bound operations in a sequence can reduce fusion potential. ```rust // This line might trigger 3 writes: tensor1, tensor2, and tensor3, if tensor1 and tensor2 are abstract tensors. let tensor3 = tensor1.clone().sum_dim(tensor2.clone(), 2); let tensor4 = tensor2.sum_dim(tensor3, 2); let tensor5 = tensor4 + (tensor1 * tensor2); ``` ```rust let tmp = tensor1.clone() + tensor2.clone(); let tensor3 = tensor1.sum_dim(tensor2, 2); let tensor4 = tensor2.sum_dim(tensor3, 2); let tensor5 = tensor4 + tmp; ``` The lesson? Whenever possible, pass only the latest value to a compute operation. Don’t clone a tensor before compute-bound operations, as it might trigger an additional write if that tensor isn’t materialized from initial fusion. It’s a bit complex, but the first code snippet is actually better if `tensor1` and `tensor2` are concrete in global memory. This would be the case if `tensor1` and `tensor2` are model parameters, so prefer this implementation style in such scenarios. The second code snippet is preferred when `tensor1` and `tensor2` are virtual tensors, meaning they were fused by earlier operations and require a global memory read to be accessed later. This happens if those tensors are part of a signal in neural networks. Reordering operations can help in such scenarios but will not create temporary values, making the previous optimization harder. We might eventually automatically optimize these cases, but the solution space is quite large, and it’s not a planned optimization. Profiling model blocks is always a good idea to identify which code block is faster when faced with ambiguous situations. ================================================ FILE: burn-book/src/performance/good-practices/kernel-selection.md ================================================ # Kernel Selection As mentioned earlier, complex compute-bound operations are highly non-trivial and require many tricks for optimal performance. However, the way these tricks are applied varies depending on the hardware and problem shapes. To select the best kernel, we use a search method with a highly configurable autotune system that performs micro-benchmarks at runtime on the current hardware. This may trigger a cold start, but the results of these benchmarks are cached on disk for subsequent executions. For deployment or training on spot instances, it’s a good idea to bundle the autotune cache with the code to mitigate cold starts. Refer to the [CubeCL configuration documentation](https://burn.dev/books/cubecl/advanced-usage/config.html) for more details on fine-grained settings . From the user’s point of view, kernel selection shouldn’t be a problem, but as usual, crafting models with even shapes, multiples of 8, can significantly improve performance. Avoid creating tensors with shapes that are multiples of 10, like `[1000, 1000]`, as these typically require bounds checking and may limit vectorization. Prefer shapes like `[1024, 1024]`, where dimensions are multiples of 32 or powers of 2, as these are generally optimal. If you have no choice but to use a suboptimal shape, prefer handling it in a single kernel, transforming it into an optimal shape. It’s better to have a slow neural network layer followed by fast ones than to propagate unevenness and end up with smaller, but slower, layers. ================================================ FILE: burn-book/src/performance/quantization.md ================================================ # Quantization Quantization techniques perform computations and store tensors in lower precision data types like 8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep learning model categorized as: - Post-training quantization (PTQ) - Quantization aware training (QAT) In post-training quantization, the model is trained in floating point precision and later converted to the lower precision data type. There are two types of post-training quantization: 1. Static quantization: quantizes the weights and activations of the model. Quantizing the activations statically requires data to be calibrated (i.e., recording the activation values to compute the optimal quantization parameters with representative data). 1. Dynamic quantization: quantized the weights ahead of time (like static quantization) but the activations are dynamically at runtime. Sometimes post-training quantization is not able to achieve acceptable task accuracy. In general, this is where quantization-aware training (QAT) can be used: during training, fake-quantization modules are inserted in the forward and backward passes to simulate quantization effects, allowing the model to learn representations that are more robust to reduced precision. Burn does not currently support QAT. Only post-training quantization (PTQ) is implemented at this time.
Quantization support in Burn is currently in active development. It supports the following PTQ modes on some backends: - Per-tensor and per-block quantization to 8-bit, 4-bit and 2-bit representations No integer operations are currently supported, which means tensors are dequantized to perform the operations in floating point precision.
## Module Quantization Quantizing the weights of your model after training is quite simple. We have access to the weight tensors and can collect their statistics, such as the min and max value when using `MinMaxCalibration`, to compute the quantization parameters. ```rust , ignore # use burn::module::Quantizer; # use burn::tensor::quantization::{Calibration, QuantLevel, QuantParam, QuantScheme, QuantValue}; # // Quantization config let scheme = QuantScheme::default() .with_level(QuantLevel::Block(32)) .with_value(QuantValue::Q4F) .with_param(QuantParam::F16); let mut quantizer = Quantizer { calibration: Calibration::MinMax, scheme, }; // Quantize the weights let model = model.quantize_weights(&mut quantizer); ``` ### Calibration Calibration is the step during quantization where the range of all floating-point tensors is computed. This is pretty straightforward for weights since the actual range is known at _quantization-time_ (weights are static), but activations require more attention. To compute the quantization parameters, Burn supports the following `Calibration` methods. | Method | Description | | :------- | :------------------------------------------------------------------------------- | | `MinMax` | Computes the quantization range mapping based on the running min and max values. | ### Quantization Scheme A quantization scheme defines how an input is quantized, including the representation of quantized values, storage format, granularity, and how the values are scaled. ```rust let scheme = QuantScheme::default() .with_mode(QuantMode::Symmetric) // Quantization mode .with_level(QuantLevel::block([2, 16])) // Granularity (per-tensor or per-block) .with_value(QuantValue::Q8S) // Data type of quantized values, independent of how they're stored .with_store(QuantStore::Native) // Storage format for quantized values .with_param(QuantParam::F16); // Precision for quantization parameters ``` #### Quantization Mode | Mode | Description | | :---------- | :------------------------------------------- | | `Symmetric` | Values are scaled symmetrically around zero. | #### Quantization Level | Level | Description | | :----------------------------- | :----------------------------------------------------------------------------------------------------------- | | `Tensor` | A single quantization parameter set for the entire tensor. | | `Block(block_size: BlockSize)` | Tensor divided into blocks (1D, 2D, or higher) defined by block_size, each with its own quantization params. | #### Quantization Value | Value | Bits | Description | | :----- | :--: | :-------------------------------------------- | | `Q8F` | 8 | 8-bit full-range quantization | | `Q4F` | 4 | 4-bit full-range quantization | | `Q2F` | 2 | 2-bit full-range quantization | | `Q8S` | 8 | 8-bit symmetric quantization | | `Q4S` | 4 | 4-bit symmetric quantization | | `Q2S` | 2 | 2-bit symmetric quantization | | `E5M2` | 8 | 8-bit floating-point (5 exponent, 2 mantissa) | | `E4M3` | 8 | 8-bit floating-point (4 exponent, 3 mantissa) | | `E2M1` | 4 | 4-bit floating-point (2 exponent, 1 mantissa) | #### Quantization Store | Store | Description | | :------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------ | | `Native` | Each quantized value is stored directly in a native format, which doesn't require packing and unpacking. | | `PackedNative(dim)` | Multiple quantized values packed into a 32-bit integer. Argument is the dimension the tensor is packed on, starting from the innermost dimension. | | `PackedU32(dim)` | Multiple quantized values packed into a 32-bit integer. Argument is the dimension the tensor is packed on, starting from the innermost dimension. | Native storage is not supported for sub-byte quantization values. #### Quantization Parameters Precision | Param | Description | | :----- | :----------------------------- | | `F32` | Full floating-point precision. | | `F16` | Half-precision floating point. | | `BF16` | Brain float 16-bit precision. | ================================================ FILE: burn-book/src/saving-and-loading.md ================================================ # Saving and Loading Models Saving your trained machine learning model is quite easy, no matter the output format you choose. As mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the [MessagePack](https://msgpack.org/) binary serialization format with the help of [rmp_serde](https://docs.rs/rmp-serde/). ```rust, ignore // Save model in MessagePack format with full precision let recorder = NamedMpkFileRecorder::::new(); model .save_file(model_path, &recorder) .expect("Should be able to save the model"); ``` Note that the file extension is automatically handled by the recorder depending on the one you choose. Therefore, only the file path and base name should be provided. Now that you have a trained model saved to your disk, you can easily load it in a similar fashion. ```rust, ignore // Load model in full precision from MessagePack file let recorder = NamedMpkFileRecorder::::new(); model = model .load_file(model_path, &recorder, device) .expect("Should be able to load the model weights from the provided file"); ``` **Note:** models can be saved in different output formats, just make sure you are using the correct recorder type when loading the saved model. Type conversion between different precision settings is automatically handled, but formats are not interchangeable. A model can be loaded from one format and saved to another format, just as long as you load it back with the new recorder type afterwards. ## Initialization from Recorded Weights The most straightforward way to load weights for a module is simply by using the generated method [load_record](https://burn.dev/docs/burn/module/trait.Module.html#tymethod.load_record). Note that parameter initialization is lazy, therefore no actual tensor allocation and GPU/CPU kernels are executed before the module is used. This means that you can use `init(device)` followed by `load_record(record)` without any meaningful performance cost. ```rust, ignore // Create a dummy initialized model to save let device = Default::default(); let model = Model::::init(&device); // Save model in MessagePack format with full precision let recorder = NamedMpkFileRecorder::::new(); model .save_file(model_path, &recorder) .expect("Should be able to save the model"); ``` Afterwards, the model can just as easily be loaded from the record saved on disk. ```rust, ignore // Load model record on the backend's default device let record: ModelRecord = NamedMpkFileRecorder::::new() .load(model_path.into(), &device) .expect("Could not load model weights"); // Initialize a new model with the loaded record/weights let model = Model::init(&device).load_record(record); ``` ## Model Weight Store While the Recorder API works well for basic saving and loading, `burn-store` was introduced to address its limitations around memory efficiency and flexibility. It provides zero-copy memory-mapped loading, cross-framework interoperability (PyTorch and SafeTensors), key remapping, partial loading, and filtering. The `burn-store` crate is intended to eventually replace the Recorder API, but since it was recently released, both APIs are supported. ### Supported Formats | Format | Extension | Description | | --------------- | -------------- | ----------------------------------------------------------------------------------------- | | **Burnpack** | `.bpk` | Burn's native format with fast loading, zero-copy support, and training state persistence | | **SafeTensors** | `.safetensors` | Industry-standard format from Hugging Face for secure tensor serialization | | **PyTorch** | `.pt`, `.pth` | Direct loading of PyTorch model weights (read-only) | ### Saving a Model ```rust, ignore use burn_store::{ModuleSnapshot, BurnpackStore}; // Save to Burnpack (recommended) let mut store = BurnpackStore::from_file("model.bpk"); model.save_into(&mut store)?; // Or save to SafeTensors use burn_store::SafetensorsStore; let mut store = SafetensorsStore::from_file("model.safetensors"); model.save_into(&mut store)?; ``` ### Loading a Model ```rust, ignore use burn_store::{ModuleSnapshot, BurnpackStore}; let device = Default::default(); let mut model = MyModel::init(&device); // Load from Burnpack let mut store = BurnpackStore::from_file("model.bpk"); model.load_from(&mut store)?; ``` ### Loading from PyTorch You can load weights directly from PyTorch `.pt` files: ```rust, ignore use burn_store::{ModuleSnapshot, PytorchStore}; let mut model = MyModel::init(&device); let mut store = PytorchStore::from_file("pytorch_model.pt"); model.load_from(&mut store)?; ``` #### Exporting from PyTorch Save only the model weights (state_dict), not the entire model: ```python import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(2, 2, (2, 2)) self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False) def forward(self, x): return self.conv2(self.conv1(x)) model = Net() torch.save(model.state_dict(), "model.pt") # Correct: save state_dict # torch.save(model, "model.pt") # Wrong: saves entire model ``` #### Accessing Nested State Dicts Some PyTorch checkpoints nest the state_dict under a key: ```rust, ignore let mut store = PytorchStore::from_file("checkpoint.pt") .with_top_level_key("state_dict"); model.load_from(&mut store)?; ``` ### Loading from SafeTensors For SafeTensors files exported from PyTorch, use the adapter for proper weight transformation: ```rust, ignore use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore}; let mut model = MyModel::init(&device); let mut store = SafetensorsStore::from_file("model.safetensors") .with_from_adapter(PyTorchToBurnAdapter); model.load_from(&mut store)?; ``` For SafeTensors files created by Burn, no adapter is needed: ```rust, ignore let mut store = SafetensorsStore::from_file("model.safetensors"); model.load_from(&mut store)?; ``` #### Exporting from PyTorch to SafeTensors ```python from safetensors.torch import save_file model = Net() save_file(model.state_dict(), "model.safetensors") ``` ### Saving for PyTorch Compatibility Use the adapter when saving for PyTorch consumption: ```rust, ignore use burn_store::{BurnToPyTorchAdapter, SafetensorsStore}; let mut store = SafetensorsStore::from_file("for_pytorch.safetensors") .with_to_adapter(BurnToPyTorchAdapter) .skip_enum_variants(true); model.save_into(&mut store)?; ``` ### Handling Load Results The `load_from` method returns detailed information about the loading process: ```rust, ignore let result = model.load_from(&mut store)?; // Print a formatted summary with suggestions println!("{}", result); // Or inspect individual fields println!("Applied: {} tensors", result.applied.len()); println!("Missing: {:?}", result.missing); println!("Errors: {:?}", result.errors); if result.is_success() { println!("All tensors loaded successfully"); } ``` ### Adding Metadata Burnpack and SafeTensors support custom metadata: ```rust, ignore let mut store = BurnpackStore::from_file("model.bpk") .metadata("version", "1.0") .metadata("description", "My trained model") .metadata("epochs", "100"); model.save_into(&mut store)?; ``` ### Advanced Features #### Key Remapping Remap parameter names using regex patterns when model structures don't match: ```rust, ignore let mut store = PytorchStore::from_file("model.pt") // Remove prefix: "model.conv1.weight" -> "conv1.weight" .with_key_remapping(r"^model\.", "") // Rename: "layer1" -> "encoder.layer1" .with_key_remapping(r"^layer", "encoder.layer"); model.load_from(&mut store)?; ``` For complex remapping: ```rust, ignore use burn_store::KeyRemapper; let remapper = KeyRemapper::new() .add_pattern(r"^transformer\.h\.(\d+)\.", "transformer.layer$1.")? .add_pattern(r"\.attn\.", ".attention.")?; let mut store = SafetensorsStore::from_file("model.safetensors") .remap(remapper); ``` #### Partial Loading Load weights even when some tensors are missing: ```rust, ignore let mut store = PytorchStore::from_file("pretrained.pt") .allow_partial(true); let result = model.load_from(&mut store)?; println!("Missing (initialized randomly): {:?}", result.missing); ``` #### Filtering Tensors Load or save only specific layers: ```rust, ignore // Load only encoder layers let mut store = SafetensorsStore::from_file("model.safetensors") .with_regex(r"^encoder\..*") .allow_partial(true); // Save only encoder layers let mut store = SafetensorsStore::from_file("encoder.safetensors") .with_regex(r"^encoder\..*"); model.save_into(&mut store)?; // Multiple patterns (OR logic) let mut store = SafetensorsStore::from_file("model.safetensors") .with_regex(r"^encoder\..*") // encoder tensors .with_regex(r".*\.bias$") // OR any bias tensors .with_full_path("decoder.scale"); // OR specific tensor ``` #### Non-Contiguous Layer Indices PyTorch `nn.Sequential` with mixed layers creates non-contiguous indices. `PytorchStore` automatically remaps these: ``` PyTorch: fc.0.weight, fc.2.weight, fc.4.weight (gaps from ReLU layers) Burn: fc.0.weight, fc.1.weight, fc.2.weight (contiguous) ``` This is enabled by default. Disable if needed: ```rust, ignore let mut store = PytorchStore::from_file("model.pt") .map_indices_contiguous(false); ``` #### Zero-Copy Loading For embedded models or large files, use zero-copy loading to avoid memory copies: ```rust, ignore // Embedded model (compile-time) static MODEL_DATA: &[u8] = include_bytes!("model.bpk"); let mut store = BurnpackStore::from_static(MODEL_DATA); model.load_from(&mut store)?; // Large file (memory-mapped) let mut store = BurnpackStore::from_file("large_model.bpk") .zero_copy(true); model.load_from(&mut store)?; ``` #### Half-Precision Storage Save models at half precision (F16) to reduce file size by ~50%, then load back at full precision: ```rust, ignore use burn_store::{ModuleSnapshot, BurnpackStore, HalfPrecisionAdapter}; let adapter = HalfPrecisionAdapter::new(); // Save: F32 -> F16 (same adapter for both directions) let mut store = BurnpackStore::from_file("model_f16.bpk") .with_to_adapter(adapter.clone()); model.save_into(&mut store)?; // Load: F16 -> F32 let mut store = BurnpackStore::from_file("model_f16.bpk") .with_from_adapter(adapter); model.load_from(&mut store)?; ``` By default, weights in Linear, Embedding, Conv\*, LayerNorm, GroupNorm, InstanceNorm, RmsNorm, and PRelu modules are converted. BatchNorm is excluded because its running variance can underflow in F16. Customize with `with_module()` and `without_module()`: ```rust, ignore // Keep LayerNorm at full precision let adapter = HalfPrecisionAdapter::new() .without_module("LayerNorm"); // Add a custom module to the conversion set let adapter = HalfPrecisionAdapter::new() .with_module("CustomLayer"); ``` #### Direct Tensor Access Inspect tensors without loading into a model: ```rust, ignore use burn_store::ModuleStore; let mut store = PytorchStore::from_file("model.pt"); // List all tensor names let names = store.keys()?; // Get specific tensor if let Some(snapshot) = store.get_snapshot("encoder.layer0.weight")? { println!("Shape: {:?}, DType: {:?}", snapshot.shape, snapshot.dtype); } ``` #### Model Surgery Transfer weights between models: ```rust, ignore use burn_store::{ModuleSnapshot, PathFilter}; // Transfer all weights let snapshots = model1.collect(None, None, false); model2.apply(snapshots, None, None, false); // Transfer only encoder weights let filter = PathFilter::new().with_regex(r"^encoder\..*"); let snapshots = model1.collect(Some(filter.clone()), None, false); model2.apply(snapshots, Some(filter), None, false); ``` ### API Reference #### Builder Methods | Category | Method | Description | | ------------- | ------------------------------ | ---------------------------- | | **Filtering** | `with_regex(pattern)` | Filter by regex pattern | | | `with_full_path(path)` | Include specific tensor | | | `with_predicate(fn)` | Custom filter logic | | **Remapping** | `with_key_remapping(from, to)` | Regex-based renaming | | | `remap(KeyRemapper)` | Complex remapping rules | | **Adapters** | `with_from_adapter(adapter)` | Loading transformations | | | `with_to_adapter(adapter)` | Saving transformations | | | `HalfPrecisionAdapter::new()` | F32/F16 mixed-precision | | **Config** | `allow_partial(bool)` | Continue on missing tensors | | | `with_top_level_key(key)` | Access nested dict (PyTorch) | | | `skip_enum_variants(bool)` | Skip enum variants in paths | | | `map_indices_contiguous(bool)` | Remap non-contiguous indices | | | `metadata(key, value)` | Add custom metadata | | | `zero_copy(bool)` | Enable zero-copy loading | #### Direct Access Methods | Method | Description | | --------------------- | -------------------------------- | | `keys()` | Get ordered list of tensor names | | `get_all_snapshots()` | Get all tensors as BTreeMap | | `get_snapshot(name)` | Get specific tensor by name | ### Troubleshooting #### Common Issues 1. **"Missing source values" error**: You saved the entire PyTorch model instead of the state_dict. Re-export with `torch.save(model.state_dict(), "model.pt")`. 2. **Shape mismatch**: Your Burn model doesn't match the source architecture. Verify layer configurations (channels, kernel sizes, bias settings). 3. **Key not found**: Parameter names don't match. Use `with_key_remapping()` or inspect keys: ```rust, ignore let store = PytorchStore::from_file("model.pt"); println!("Available keys: {:?}", store.keys()?); ``` #### Inspecting Files Use [Netron](https://github.com/lutzroeder/netron) to visualize `.pt` and `.safetensors` files. For Burnpack files: ```bash cargo run --example burnpack_inspect model.bpk ``` ================================================ FILE: codecov.yml ================================================ coverage: status: project: default: # https://docs.codecov.com/docs/commit-status#informational informational: true target: 80% patch: default: informational: true target: 80% github_checks: annotations: false ================================================ FILE: contributor-book/.gitignore ================================================ target # MacOS temp file .DS_Store book-test guide/book .vscode tests/burn-book/book/ book/ # Ignore Jetbrains specific files. .idea/ # Ignore Vim temporary and swap files. *.sw? *~ ================================================ FILE: contributor-book/.prettierrc.json ================================================ { "printWidth": 100, "proseWrap": "always" } ================================================ FILE: contributor-book/book.toml ================================================ [book] authors = [ "Wouter Doppenberg", "Nathaniel Simard", "Louis Fortier-Dubois", "Dilshod Tadjibaev", "Guillaume Lagrange", "Joshua Ferguson", "The Burn Community", ] language = "en" src = "src" title = "The Burn Contributor Book 🔥" [output.html] mathjax-support = true ================================================ FILE: contributor-book/src/SUMMARY.md ================================================ - [Overview](./overview.md) - [How to Read This Book](./how-to-read-this-book.md) - [Getting Started](./getting-started/README.md) - [Setting Up The Environment](./getting-started/setting-up-the-environment.md) - [Configuring Your Editor (Optional)](./getting-started/configuring-your-editor.md) - [Testing](./getting-started/testing.md) - [Architecture Overview](./project-architecture/README.md) - [Modules](./project-architecture/module.md) - [Serialization](./project-architecture/serialization.md) - [Tensor](./project-architecture/tensor.md) - [Backend](./project-architecture/backend.md) - [Guides for Contributors](./guides/README.md) - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md) - [Submitting Examples to Burn](./guides/submitting-examples.md) - [Frequently Encountered Issues](./frequently-encountered-issues/README.md) - [Issues Related To Adding Operators](./frequently-encountered-issues/issues-while-adding-ops.md) ================================================ FILE: contributor-book/src/frequently-encountered-issues/README.md ================================================ # Frequently Encountered Issues This is a collection of issues people have encountered and asked about on the [Discord server](https://discord.gg/uPEBbYYDB6). This section is separated from the guides since it can involve lots of details that are only relevant to a small subset of contributors. ================================================ FILE: contributor-book/src/frequently-encountered-issues/issues-while-adding-ops.md ================================================ # Issues encountered while adding ops Below are some of the issues that were encountered while adding ops to the project. If you encounter an issue while adding an op that isn't listed here, and it's not obvious how to fix it, you can add it to this list or reach out on the [Discord server](https://discord.gg/uPEBbYYDB6) if you need help. ## Off by .000001 errors ```sh ---- fusion::base::tests::maxmin::tests::test_mean_dim_2d stdout ---- thread 'fusion::base::tests::maxmin::tests::test_mean_dim_2d' panicked at burn-wgpu/src/fusion/base.rs:185:5: assertion `left == right` failed left: Data { value: [1.0, 4.0], shape: Shape { dims: [2, 1] } } right: Data { value: [0.99999994, 3.9999998], shape: Shape { dims: [2, 1] } } ---- tests::maxmin::tests::test_mean_dim_2d stdout ---- thread 'tests::maxmin::tests::test_mean_dim_2d' panicked at burn-wgpu/src/lib.rs:49:5: assertion `left == right` failed left: Data { value: [1.0, 4.0], shape: Shape { dims: [2, 1] } } right: Data { value: [0.99999994, 3.9999998], shape: Shape { dims: [2, 1] } } ``` If you encounter this, swap out the `assert_eq!` in the failing test for `tensor1.to_data().assert_approx_eq` with `3` as the second argument. The second arguments specifies the level of precision: `3` is equivalent to a less than 10-3 (0.001) difference between the elements of the two tensors. ================================================ FILE: contributor-book/src/getting-started/README.md ================================================ # Getting Started This section is for setting up the environment and how to do basic development tasks such as running tests and checking your code before committing. If you need help with the process or run into issues, feel free to ask on the [Discord server](https://discord.gg/uPEBbYYDB6) in the Development channels. ================================================ FILE: contributor-book/src/getting-started/configuring-your-editor.md ================================================ # Configuring your editor These steps are not required, and most of this isn't specific to Burn, but it's definitely helpful if you haven't already done it. ## VSCode Install the following extensions: - [rust-lang.rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer) for Rust syntax and semantic analysis - [tamasfe.even-better-toml](https://marketplace.visualstudio.com/items?itemName=tamasfe.even-better-toml) for TOML syntax and semantic analysis - [fill-labs.dependi](https://marketplace.visualstudio.com/items?itemName=fill-labs.dependi) for managing dependencies - [vadimcn.vscode-lldb](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb) for debugging ### Setting up the Debugger To use the debugger, follow these steps: 1. Open `Command Palette` with `Ctrl+Shift+P` or `F1` and type `LLDB: Generate Launch Configurations from Cargo.toml` then select it, this will generate a file that should be saved as `.vscode/launch.json`. 2. Select the configuration from the "run and debug" side panel, then select the target from the list. Since this repo has `debug = 0` in the root `Cargo.toml` to speed up compilation, you need replace it with `debug = true` in the root `Cargo.toml` when using a debugger and breakpoints with `launch.json` settings. 3. Now you can enable breakpoints on code through IDE then start debugging the library/binary you want, like in the following example: ![debug-options](debug-options-vscode.png) If you're creating a new library or binary, keep in mind to repeat step 1 to always keep a fresh list of targets. ## Have another editor? Open a PR! ================================================ FILE: contributor-book/src/getting-started/setting-up-the-environment.md ================================================ # Setting up the environment Depending on what part of the project you plan on contributing to, there are a couple of tools to install and commands to be familiar with. This section should be up to date with current project practices (as of 2024-04-15). ## General There are a few commands you will want to run prior to any commit for a non-draft PR: 1. `cargo fmt --all` will run `rustfmt` on all files in the project. 2. `cargo clippy --fix` will run [Clippy](https://github.com/rust-lang/rust-clippy) and fix any coding issues it can. Clippy necessitates to be in a clean Git state, but this can be circumvented by adding the `--allow-dirty` flag. 3. `cargo run-checks` is a command used to test the project. It is required to run successfully prior to merging a PR. Fair warning, running these tests can take a while[^linux_mem_note]. > Want more detailed macro error diagnostics? This is especially useful for debugging tensor-related tests: > > ```bash > RUSTC_BOOTSTRAP=1 RUSTFLAGS="-Zmacro-backtrace" cargo run-checks > ``` ## Updating the burn semver version If for some reason you need to bump for the next version (though that should probably be left to the maintainers), edit the semantic version number in `burn/Cargo.toml`, and then run `cargo update` to update the lock file. ## Contributing to either the Burn Book or Contributor Book Both the Burn Book and the Contributor Book are built with mdbook. To open the book locally, run `mdbook serve ` or `cargo xtask books {burn|contributor} open` which will install and use mdbook automatically. Alternatively, if you want to install mdbook directly, run the following command[^update_note]: ```bash cargo install mdbook ``` Also instead of running `cargo run-checks`, you can run `cargo xtask check typos` to only check for misspellings. This will install [typo](https://crates.io/crates/typos-cli), and if any are encountered you should be able to run `typo -w /path/to/book` to fix them. [^linux_mem_note]: If your system is running into issues with memory and you are on linux, you may want to switch to a [virtual console](https://wiki.archlinux.org/title/Linux_console#Virtual_consoles) to run the tests. To do this, press `ctrl+alt+f3` to switch to a virtual console (and log in), and either `ctrl+alt+f1` or `ctrl+alt+f2` to switch back to your graphical session. [^update_note]: You might also want to install [cargo-update](https://github.com/nabijaczleweli/cargo-update) to easily keep your tools up to date, though it is in no way required. ================================================ FILE: contributor-book/src/getting-started/testing.md ================================================ # Testing ## Test for Tensor Operations Test for tensor operations (generally of the form: given this input, expect it match or approximate this output) are defined only in [`crates/burn-tensor/src/test/ops`](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/ops) and not in the backends (with the exception of `burn-autodiff`). The tensor operation tests are added to the `testgen_all` macro rule in [`crates/burn-tensor/src/tests/mod.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/mod.rs). This is then propagated to the existing backends without any additional work. ### Test for Autodiff Tests for autodiff go under [burn-autodiff/src/tests](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-autodiff/src/tests) and should verify backward pass correctness. For binary tensor operations, both the left and right sides need to be verified. Here's an easy way to define tests for a new operation's backward pass: 1. Use small tensors with simple values. 2. Pop open a terminal, launch `ipython` and import `numpy` then do the calculations by hand. You can also use [Google Colab](https://colab.google/) so you don't have to install the packages on your system. 3. Compare the actual outputs to the expected output for left-hand side, right-hand side. For float tensors, it is advised to use `actual_output_tensor.into_data().assert_approx_eq::>(&expected_tensor_data, Tolerance::default())` instead of `assert_eq!(...` due to occasional hiccups with floating point calculations. Other assertions should also always use `FloatElem`, and use `.elem()` to convert any literals. Backends are tested for multiple precisions, and hardcoding to a fixed type causes tests to fail with alternate floating point precisions. For convenience, it might be worth aliasing the type like `type FT = FloatElem;`. For integers, tests should use `IntElem`, and exit the test if the test values are unrepresentable (above `max_value`, below `min_value`). A minimum range of `[0..127]` (`i8`) can be assumed. ================================================ FILE: contributor-book/src/guides/README.md ================================================ # Guides for Contributors The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn. ================================================ FILE: contributor-book/src/guides/adding-a-new-operation-to-burn.md ================================================ # Adding a New Operation to burn Let's discuss how one might go about adding new operators to Burn, using the example of the pow operator added in [this PR](https://github.com/tracel-ai/burn/pull/1133/files). ## Adding the Op to burn-tensor `burn-tensor` is the crate that defines all tensor operations that need to be implemented by the various backends. The core of this lies in [crates/burn-backend/src/tensor/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/numeric.rs#L17), which is home to the numeric trait. The numeric trait is the home of all tensor operations that are numeric in nature and that are shared by `Int` and `Float` Tensor types. The numeric trait is implemented in [crates/burn-backend/src/tensor/ops/int.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/int.rs) for the int type and in [crates/burn-backend/src/tensor/ops/float.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/float.rs) for the float type. More information on the relationship between Tensor modules can be found under the section for [Tensor Architecture](../project-architecture/tensor.md#tensor-operations). Here is where pow was added to `crates/burn-tensor/src/tensor/api/numeric.rs`: 1. for the [`Tensor` struct](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L573) 2. for the [numeric trait](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L1955) 3. for the implementation of numeric for [float](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2722) and [int](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2375) Tensor is a struct that has a single member: `primitive` (defined [here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/base.rs#L27)), that is defined by its [`Kind`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/kind.rs#L16): one of `Bool`, `Float`, or `Int` (those linked in 3). These call the ops for that data type defined in the [`Backend`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/base.rs#L64) supertrait[^supertrait]. This is the trait that is then implemented by the different `burn-` backends (such as `burn-ndarray` and `burn-wgpu`) which must implement the functions if no default is provided. In this case, we don't need to worry about `Bool` Tensors. `Float` ops are implemented under [crates/burn-backend/src/backend/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/tensor.rs), and `Int` ops under [crates/burn-backend/src/backend/ops/int_tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/int_tensor.rs). The current convention is ops of each type, if not unique to that type, are prefixed with the type. So `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for `FloatTensorOps`. If an op is unique to a type, then it should be implemented under `burn-tensor/src/api/{type}.rs`. For example, here is an implementation for [`sin` under `crates/burn-tensor/src/api/float.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/float.rs#L82) which obviously doesn't make sense for `Int` or `Bool` tensors. The `Int` Tensor function uses the ones defined for Float with 2 extra casts (LHS to a `Float` tensor, Output to an `Int`). Given that the rest of the code will only look at the float implementations. With the addition of quantized float tensors, the `Float` tensor primitive is represented by the [`TensorPrimitive`](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/api/kind.rs#L69) enum. This allows us to handle both float and quantized float operations in the `Tensor` implementation, correctly dispatching to the corresponding op (float or quantized) based on the variant. Following the same convention, the equivalent [quantized tensor ops](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/ops/qtensor.rs#L45) are prefixed with `q_*` (e.g., `q_reshape` instead of `float_reshape`). Most ops have a default implementation that simply dequantizes the input into its floating-point representation, performs the operation on the float tensor, and quantizes the output. Backends can overwrite specific implementations when required/desired. ### Adding Tests Additional tests should be added to `burn-backend-tests` under [`crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/{op_name}.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend-tests/tests/tensor/float/ops/powf.rs), and the module name should be inserted into `crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/mod.rs`. If it makes sense for a floating point operation to support quantization, the [`QTensorOps`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/qtensor.rs#L117) counterpart is usually added at the same time with a default implementation (as mentioned in the previous section). Tests for `q_*` ops follow a similar procedure: the test is added under [`crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/{op_name}.rs`](https://github.com/tracel-ai/burn/tree/9f31281/crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended), the module name is inserted into [`crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs). If you take a look at any of the existing tests for an operation on a quantized tensor, you will see that the inputs and expected outputs are always defined with floating point values. While it assumes that the quantization and dequantization are correct, it makes the tests much more readable and easier to understand w.r.t. what is being tested. Effectively, the tests are there to ensure that a tensor operation is invariant to quantization (up to some quantization error, of course). _Note: the tests try to use tensors with floating point values which can be de/quantized without introducing too much quantization error, but the result always depends on the operation (e.g., tensor product of values can grow larger and significantly increase the output tensor range, leading to more de/quantization error on the results)._ ## Adding the Op to burn-autodiff Since this is probably the hardest and the least straightforward, we'll cover this backend separately. `burn-autodiff` enables other backends to use autodifferentiation[^autodiff]. Ops for float types are implemented in [crates/burn-autodiff/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-autodiff/src/ops/tensor.rs) and need to: 1. Define a unit struct [^absolute_units] that implements a backward (pass) function 2. Within the backward function, as this is an elementwise binary operation it implements the binary function (from `backward.rs` under the same directory), the last 2 arguments are two closures that define the left and right partial derivatives. 3. Then define what happens when a specific operation is tracked or untracked, where untracked just calls the function in the normal way, and tracked sets the execution the backward function defined above. 4. When tracked, operations are part of the autodiff graph and must save the needed information to efficiently perform their backward pass later. If the information is light (such as a shape), it should be directly saved in the state. If the operation's inputs are needed to compute the backward pass, it should be checkpointed rather than saved. This will allow the input to be provided lazily at the backward pass depending on the checkpointing strategy. 5. An operation must also be identified as _compute-bound_ (`.computeBound()`) or _memory-bound_ (`.memoryBound()`) for gradient checkpointing. _Compute-bound_ operation are heavy to compute (for instance matmul or convolution), which means that even with checkpointing they will save their output for the backward pass and not recompute it. _Memory-bound_ operations are more trivial (like `powf` which only performs one small operation per tensor entry), so it can be beneficial to recompute them during the backward pass instead of saving their whole forward output to memory. Operations registered as _memory-bound_ need to know their parents (`.parents()` method) and how to recompute their forward pass during the backward pass (with a struct that implements `RetroForward`), using their parents' outputs. The above steps are mostly boilerplate, so you can often just copy the contents of another similar op, change the name of the structs, and ensure that either both sides have the data they need (if they need to have a copy of the opposite sided tensor, clone its contents). ### Computing derivatives For those that need it, here is a quick refresher on the necessary calculus. If you are familiar with how to calculate partial derivatives, you can skip this section. Since `pow` is a binary operation, the left and right functions are the partial derivatives with respect to the left and right sided tensors. Let's define the operator as a function \\(f(x,y)=x^{y}\\) , where \\(x\\) is the left hand tensor and \\(y\\) is the right handed tensor. The two closures are defining the partial derivatives of \\(f\\) with respect to \\(x\\),\\(y\\). Treat the other variables as a constant $$\frac{\delta }{\delta x} (x^{y})= y \cdot x^{y-1}$$ is the left handed closure, and $$\frac{\delta }{\delta y} (x^{y}) = x^{y} \cdot ln(x)$$ is the right. If you aren't sure how to calculate these by hand, it is recommended to use [symbolab](), plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the variable \\(x\\)|\\(y\\) in the partial derivative to get the other side. ### Testing autodiff For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md). ## Adding the Op to other backends Most of these are fairly straightforward implementations. For reference here's pow's float implementation for torch and ndarray backends: 1. Torch implementation in [crates/burn-tch/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/tensor.rs#L467) and the Op used in [crates/burn-tch/src/ops/base.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/base.rs#L481) 2. NdArray in [crates/burn-ndarray/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-ndarray/src/ops/tensor.rs#L472) This is where any calculation happens currently. Playing a guessing game with method names and seeing what completions are suggested will take you far. If you are having trouble figuring out how to do it from the docs for that backend, [try searching github for relevant function calls](https://docs.github.com/en/search-github/github-code-search/understanding-github-code-search-syntax). ## Adding the Op to fusion, JIT and cubecl backends Adding an operator to these backends can be fairly straightforward, though due to what these backends are for, involves a bit more indirection. Fusion and jit, like autodiff, are not target backends as much as backends that enable certain functionality for other backends, in this case kernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation, you'll just be describing how the generated code should look. Most of this can be copy/pasted/adjusted from other functions. Here's how powf was added to `burn-fusion`: 1. Added powf to the float ops under [crates/burn-fusion/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-fusion/src/ops/tensor.rs#L2061) 2. Added powf to the `NumericOperationIr` enum under [crates/burn-ir/src/operation.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-ir/src/operation.rs#L564) 3. Added powf to the implementations of `NumericOperationIr` enum under [crates/burn-ir/src/operation.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-ir/src/operation.rs#L1086) 4. Added powf to the implemented of `NumericOperationIr` enum under [burn/crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-fusion/src/stream/context.rs#L883) The way `cubecl` handles tensor-scalar operations is by transforming both into a sequence of vectorized scalar operations. Since powf already existed in `cubecl`, it was pretty easy to reuse the existing implementation for the situation where both sides of the operation were tensors. The `cubecl` crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual implementation is defined in `burn-cubecl`. Here is where code was added for powf in `burn-cubecl` and `cubecl`: 1. to the implementation of [`FloatTensorOps` under `burn/crates/burn-cubecl/src/ops/tensor.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl/src/ops/tensor.rs#L578) 2. the function being called was added to [`burn/crates/burn-cubecl/src/ops/numeric.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl/src/ops/numeric.rs#L211-L214) 3. the operator was defined in [`cubecl/crates/cubecl-ir/src/arithmetic.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-ir/src/arithmetic.rs#L41) 4. how the operation looks to the gpu was added to [`burn/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs#L97) 5. the mappings between the gpu operation and the CPP, WGSL and SPIR-V instructions were added to [`cubecl/crates/cubecl-cpp/src/shared/base.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/base.rs#L1285), [`cubecl/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L869) and [`cubecl/crates/cubecl-spirv/src/arithmetic.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-spirv/src/arithmetic.rs#L491) 6. the instructions themselves were added for WGSL to [instruction op enum in `cubecl/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L124), and the actual [instruction in wgsl here](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L654), for CPP in the enum here [`cubecl/crates/cubecl-cpp/src/shared/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/instruction.rs#L187) and the actual instruction here [`cubecl/crates/cubecl-cpp/src/shared/binary.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/binary.rs#L216) We needed to generate some custom WGSL code for powf in WGSL, primarily due to issues with proper case handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an even power being positive. We reused as much as the existing logic as possible, and then branched at the last point based off the var type of the rhs. [See here](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L1229). For most operations, you shouldn't need to add to `cubecl-wgpu/src/compiler/wgsl/extension.rs` unless the operation isn't native to WGSL. For functions that need a complex kernel without a direct mapping to a base instruction, simply use the `cube` macro (see [the `cubecl` book](https://github.com/tracel-ai/cubecl/tree/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/cubecl-book)). And you're done! Congrats, you just fully added a new operation to burn, and we are all one step closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and it's freaking fast!". Buy yourself a coffee. [^supertrait]: for more on supertraits see [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait) [^autodiff]: wiki link for [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) [^absolute_units]: for more information on unit structs see [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields) ================================================ FILE: contributor-book/src/guides/submitting-examples.md ================================================ # Submitting Examples to Burn This guide explains how to create and submit new examples to the Burn repository. Examples are a great way to demonstrate Burn's capabilities and help users understand how to use the framework effectively. For a minimal working example, see the [simple-regression](https://github.com/tracel-ai/burn/blob/main/examples/simple-regression/examples/regression.rs) example in the repository. ## Repository Structure The Burn repository is set up as a workspace, with examples located in the `examples/` directory. Each example is a separate crate that can reuse workspace dependencies. ## Creating a New Example 1. Navigate to the examples directory: ```bash cd examples ``` 2. Create a new library crate: ```bash cargo new --lib ``` 3. Update the example's `Cargo.toml`: ```toml [package] name = "" version = "0.1.0" edition = "2021" readme = "README.md" # Remove this line if it exists # readme.workspace = true [dependencies] # Reuse workspace dependencies when available serde = { workspace = true } # Add example-specific dependencies burn = { path = "../../" } ``` ## Required Files and Structure ### README.md Each example must include a README.md file with: - A brief description of what the example demonstrates - A terminal command showing how to run the example - Any prerequisites or setup instructions Example README structure: ````markdown # Example Name Brief description of what this example demonstrates. ## Running the Example ```bash cargo run --example ``` ## Prerequisites List any prerequisites here. ```` ### Source Code Structure - `src/` directory: Contains the main implementation code - `examples/` directory: Contains example code - `.rs`: Example implementation ## Resource Handling - Resources (datasets, models, etc.) should be downloaded in the example code - Do not track external files in the repository - Include code to download and prepare resources when the example is run ## Best Practices 1. **Code Organization** - Keep the code modular and well-documented - Use clear, descriptive variable and function names - Include comments explaining complex operations 2. **Error Handling** - Implement proper error handling - Provide meaningful error messages - Handle resource download failures gracefully 3. **Performance** - Optimize for reasonable execution time - Include progress indicators for long-running operations - Consider adding configuration options for different hardware capabilities 4. **Documentation** - Document all public APIs - Include inline comments for complex logic - Explain any non-obvious implementation details ## Submitting Your Example 1. Ensure your example follows all the guidelines above 2. Test your example thoroughly 3. Create a pull request with: - A clear description of what the example demonstrates - Any relevant issue numbers - Screenshots or output examples (if applicable) Feel free to ask questions in the pull request if you need clarification or guidance. ================================================ FILE: contributor-book/src/how-to-read-this-book.md ================================================ # How to read this book Throughout this book, we maintain the following structure. ## Linking When referring to structures or functions within codebase, we provide permalinks to the lines in specific commits, and indicate them by the relative path of their parent file from the project root. For example this is a reference to the `Tensor` struct in [`crates/burn-tensor/src/tensor/api/base.rs`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/base.rs#L27) When some reference information is useful but is beyond the scope of contributing to Burn, we provide that information in a footnote. To build on the previous example, the `Tensor` mentioned is what's referred to as a newtype struct[^1]. Direct hyperlinks are for tools and resources that are not part of the Burn project, but are useful for contributing to it. For example, when working on implementing an operation for autodiff, it can be useful to use [symbolab](https://www.symbolab.com/) to calculate the left and right partial derivatives. [^1]: For more information on newtype please refer to [the Advanced Types chapter of the Rust Book](https://doc.rust-lang.org/book/ch19-04-advanced-types.html#using-the-newtype-pattern-for-type-safety-and-abstraction) ================================================ FILE: contributor-book/src/overview.md ================================================ # Overview Welcome to The Burn Contributor's Book 👋 This book will help you get acquainted with the internals of the Burn deep learning framework and provide some detailed guidance on how to contribute to the project. Before opening a PR, please read the [Contributing Guidelines](https://github.com/tracel-ai/burn/blob/main/CONTRIBUTING.md). We have crafted some sections for you: - [Getting Started](./getting-started): Much like the [Burn Book](https://burn.dev/books/burn/) which targets users, we'll start with the fundamentals, guiding you through tasks like setting up the development environment, running tests, and what you should check prior to each commit. - [Project Architecture](./project-architecture): This section will give you an in-depth look at the architecture of Burn. - [Guides](./guides): We provide some guides on how to do specific tasks, such as adding a new operations to Burn. - [Frequently Encountered Issues](./frequently-encountered-issues): If you are running into an issue that has you stumped, this is the section to check out prior to asking on the [Discord](https://discord.gg/uPEBbYYDB6). It's a collection of errors encountered by contributors, what caused them, and how they were resolved. As this book is geared towards contributors and not towards users of Burn, we'll assume you have a good understanding of software development, but will make efforts to explain anything outside of that scope, or at least provide links to resources that explain it better than we can. ================================================ FILE: contributor-book/src/project-architecture/README.md ================================================ # Project Architecture This section documents most major architectural decisions with the reasoning behind them. **Sections** - [Module](./module.md) - [Optimization](./module.md#optimization) - [Constraints](./module.md#constraints) - [Solution](./module.md#solution) - [Serialization](./serialization.md) - [Constraints](./serialization.md#constraints) - [Solution](./serialization.md#solution) - [Pros](./serialization.md#pros) - [Cons](./serialization.md#cons) - [Compatibility](./serialization.md#compatibility) - [Tensor](./tensor.md) - [Backend](./backend.md) - [Autodiff](./backend.md#autodiff) ================================================ FILE: contributor-book/src/project-architecture/backend.md ================================================ # Backend The Backend trait abstracts multiple things: - Device type - Float tensor type - Bool tensor type - Int tensor type - Float element type - Int element type - Float tensor operations (kernels) - Int tensor operations (kernels) - Bool tensor operations (kernels) ## Element types > Warning: there are plans to change this architecture in the near future. Even though having one type for tensors is convenient for the tensor API, it can be cumbersome when implementing a backend. Therefore, backends can decide, through associated types, what types they want to use for their int, float, and bool tensors. Since float and int can have multiple precisions, the float and int element types are also associated types that must be declared by the backend. Note that the backend chooses the precision and not the user. Since not all backends will support the same element types, no assumptions must be made. Therefore, there are no methods on tensors to change the precision, except for the `to_full_precision` function, which ensures numerical stability on the current backend. Backend implementations can provide a way to choose the precision, which can be accomplished with a generic parameter (e.g. `NdArray`). ## Operations To be as general as possible, tensor operations are implemented as plain functions. There is no object or self, just functions that take tensors as input and often return tensors as output as well. Backend implementations are free to use their own patterns to implement these kernels. Note that Burn is a dynamic graph deep learning framework, so backends may have to implement asynchronous kernel executions for performance reasons. ## Autodiff As of now, there is only one backend decorator that supports autodiff. It follows the decorator pattern, making any backend differentiable. However, the `AutodiffBackend` trait abstracts how gradients are calculated, and other approaches to autodiff might be added later. For more information about how the current autodiff backend works, you can read this (slightly outdated) [blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling). ================================================ FILE: contributor-book/src/project-architecture/module.md ================================================ # Module Modules are a way of creating neural network structures that can be easily optimized, saved, and loaded with little to no boilerplate. Unlike other frameworks, a module does not force the declaration of the forward pass, leaving it up to the implementer to decide how it should be defined. Additionally, most modules are created using a (de)serializable configuration, which defines the structure of the module and its hyperparameters. Parameters and hyperparameters are not serialized into the same file, and both are normally necessary to load a module for inference. ## Optimization Optimization is normally done with variants of gradient descent, and it is important to provide an easy API for optimizing modules. ### Constraints 1. **Users should be able to control what is optimized.** Modules can contain anything for maximum flexibility, but not everything needs to be optimized. 2. **Optimizers should have a serializable state that is updated during training.** Many optimizers keep track of previous gradients to implement some form of momentum. However, the state can be anything, not just tensors, allowing for easy implementation of any kind of optimizer. 3. **The learning rate can be updated during training.** Learning rate schedulers are often used during training and should be considered as a key aspect. ### Solution In the following, the `Module` trait is defined in [`crates/burn-core/src/module/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/module/base.rs#L83) and the `Optimizer` trait is defined in [`crates/burn-core/src/optim/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/base.rs#L8) The solution to this problem comprises multiple parts. Firstly, the `Optimizer` trait is quite similar to the `Module` trait, in terms of saving and loading the state. Please refer to the [serialization](./serialization.md) section for more details. Secondly, two traits were created. The `Optimizer` trait is general and relatively unopinionated, with a simple `step` method that takes a learning rate, a module, and the gradients. The other trait, `SimpleOptimizer`, aims to provide an easier API for implementing new optimizers. The goal is to allow implementations to avoid handling missing gradients, loading and exporting records, navigating the module parameter structure, handling tracked and untracked tensors, and other such tasks. Thirdly, each tensor that will be optimized needs to be wrapped into a `Param` struct, which gives them an ID used for (de)serialization and to associate the state of the optimizer to each parameter. The `Module` trait has two ways to navigate over parameters. The first one is the `map` function, which returns `Self` and makes it easy to implement any transformation and mutate all parameters. The second one is the `visit` function, which has a similar signature but does not mutate the parameter tensors. #### SimpleOptimizer Located in [`crates/burn-core/src/optim/simple/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/base.rs#L9), the `SimpleOptimizer` has two major assumptions: 1. The state of the optimizer is linked to each parameter. In other words, each parameter has its own optimizer state, decoupled from the other parameters. 2. The state of the optimizer implements `Record`, `Clone`, and has a `'static` lifetime. The benefits of those assumptions materialize in simplicity with little loss in flexibility. The state associative type is also generic over the dimension, making it extremely easy to include tensors in the state that share the same dimensionality as its parameter. To wrap a simple optimizer into the more general `Optimizer` trait, the `OptimizerAdaptor` struct is used. #### OptimizerAdaptor Located in in [`crates/burn-core/src/optim/simple/adaptor.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/adaptor.rs#L14), the `OptimizerAdaptor` is a simple struct composed of a `SimpleOptimizer` and a hashmap with all records associated with each parameter ID. When performing an optimization step, the adaptor handles the following: 1. Updates each parameter tensor in the given module using the `Module::map` function. 2. Checks if a gradient for the current tensor exists. 3. Makes sure that the gradient, the tensor, and the optimizer state associated with the current parameter are on the same device. The device can be different if the state is loaded from disk to restart training. 4. Performs the simple optimizer step using the inner tensor since the operations done by the optimizer should not be tracked in the autodiff graph. 5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are marked as required. Note that a parameter can still be updated by another process, as it is the case with running metrics used in batch norm. These tensors are still wrapped using the `Param` struct so that they are included in the module's state and given a proper parameter ID, but they are not registered in the autodiff graph. ================================================ FILE: contributor-book/src/project-architecture/serialization.md ================================================ # Serialization An important aspect of a deep learning framework is the ability to save and load models from disk. Despite appearing as a simple feature, it involves numerous constraints that require a proper solution. ## Constraints 1. **Users should be able to declare the precision of the model to be saved, independent of the backend in use.** The modules should not be duplicated in RAM in another precision to support this. Conversion should be done lazily during (de)serialization. 2. **Users should be able to add any field to a module, even fields that are not serializable.** This can include constants, database connections, other module references, or any other information. Only parameters should be serialized since the structure of the module itself should be encapsulated with module configurations (hyperparameters). 3. **Users should be able to declare the format in which the module should be saved.** This can involve saving to a compressed JSON file or directly to bytes in memory for `no-std` environments. 4. **Users should be able to create a module with its saved parameters without having to initialize the module first.** This will avoid unnecessary module initialization and tensor loading, resulting in reduced cold start when dealing with inference. In addition to all of these constraints, the solution should be easy to use. ## Solution In order to be able to add any field to a module without requiring it to be (de)serializable, we decouple the module type from its state. We create a new type for each module that only contains the parameters that need to be saved. To generate that type automatically, the user must either declare which field is a parameter or a constant, or we assume that each field implements the module trait. The second solution was chosen as it simplifies the code generation and reduces the size of the user API. This means that the `Module` trait should be implemented by [primitive types](https://github.com/tracel-ai/burn/blob/main/crates/burn-core/src/module/param/primitive.rs). The following diagrams highlight the main types and traits used in the solution.

Module Serialization Types

The way the types interact with each other is pretty straightforward. First, a module can be converted into a record using `into_record()`. Note that tensors can be cloned, but it won't actually copy any data; it will simply create another reference to the same data. Then, a `Recorder` instance can be used to serialize any record. The `Recorder` has the `PrecisionSettings` type as associate type, so any record will be serialized using the settings provided at the creation of the `Recorder` instance. Note that tensors implement record, and their item is just a wrapper struct that contains information about the precision in which the tensor should be saved or loaded. No actual copy of the tensor is made until this point. The tensor is converted to the `TensorData` struct and then converted into the specified precision only when `serialize()` or `deserialize()` are called, which makes the whole process lazy. To recapitulate, the `Module` trait has an associated type that implements `Record`, which only contains the parameters of the model. The `Record` trait has a generic associated type (GAT) that specifies a family of types that can be (de)serialized given any `PrecisionSettings`. Records are therefore decoupled from the backend in use, and the saved items can be loaded on any backend with any precision, since the conversion is type-safe and done when `serialize()` and `deserialize()` are called. All of the types are generated using simple derive macros without any conditional statements or complex syntax, as `Record` and `Module` are implemented for all primitive types. This makes the code simple and easy to maintain. In addition, you can extend the current system with your own `Recorder` and `PrecisionSettings` to control how your modules should be saved and loaded. ### Pros - All constraints are respected. - The code is simple and easy to maintain, with very few conditional statements. It is just recursive data structures, where all the complexity is handled by the framework in primitive implementations. - The user API is simple and small, with only two derives (`Record` and `Module`) and no additional attributes. - Users can create their own `Module` and `Record` primitive types, which gives them the flexibility to control how their data is serialized without having to fork the framework. ### Cons - There are more types, but most of them are automatically generated and single-purpose, so users don't need to interact with them for common use cases. However, they can do so if necessary. - When instantiating a new record manually, each field must be set to something, even if the type itself is `()`, which represents no value. Since the code generation step uses associative types, it doesn't know that a field type is actually nothing. Creating a record manually without using the generated function `into_record` or loading it from a file is only useful to load a set of parameters into a module from an arbitrary source. Using the record may not be the optimal solution to this problem, and another API could be created in the future. ### Compatibility Record may become incompatible with previous versions of Burn, depending on the chosen format. The more compact format (bincode) store minimal information about the type, making it significantly smaller but less resilient to type changes such adding an optional field. At some point, it might be necessary to provide a translation script that can translate a more resilient format from a previous version to a more compact one. ================================================ FILE: contributor-book/src/project-architecture/tensor.md ================================================ # Tensor A proper deep learning framework should have a fast tensor implementation with autodiff support, and Burn is no exception. The tensor API abstracts away backend implementation details and focuses on usability without compromising performance. To make it as easy as possible to use, there is only one tensor type, which is different from multiple tensor and deep learning crates in Rust. Generic parameters are used instead to specialize the tensor type. - **B: Backend:** The first argument is the backend on which the tensor implementation lies. - **const D: usize:** The second argument is the dimensionality of the tensor. - **K: TensorKind:** The third argument is the tensor kind, which can be either Float, Int or Bool. By default, the tensor kind is set to Float, so for most tensors, the kind argument is not necessary. Having one struct for tensors reduces the complexity of the tensor API, which also means less duplicated documentation to write and maintain. Tensors are thread-safe, which means that you can send a tensor to another thread, and everything will work, including auto-differentiation. Note that there are no explicit in-place tensor operations since all tensor operations take owned tensors as parameters, which make it possible to mutate them. Tensors can be shared simply by cloning them, but if there is only one reference to a tensor, the backend implementation is free to reuse the tensor's allocated data. For more information about how it is done, you can have a look at this [blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling). ## Tensor Operations Operations on Tensors (sometimes shortened to Ops) are defined in traits (generally part of the Backend Supertrait) and implemented for the Tensor struct. The appropriate parent trait of an operation depends on the type of operation: - `base` => All tensor kinds should implement these operations (reshape, into_data, etc.). The implementation is in [crates/burn-tensor/src/tensor/api/base.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/base.rs). - `numeric` => All tensors that are numeric by nature should implement these operations (Add, Sub, Div, etc.). The implementation is in [crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/numeric.rs). - `Float` => Tensor operations are only available for float tensors. The implementation is in [burn-tensor/src/tensor/api/float.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/float.rs). - `Int` => Tensor operations are only available for int tensors. The implementation is in [burn-tensor/src/tensor/api/int.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/int.rs). - `bool` => Tensor operations are only available for bool tensors. The implementation is in [burn-tensor/src/tensor/api/bool.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/bool.rs). `Numeric` is directly implemented for `Float` and `Int` tensors, and in general, The implementations for these methods are calling the corresponding `{Int|Float}` method defined in the backend supertrait. Anything that is implemented by numeric should have an implementation in the `{Int|Float}` traits, though it may be avoidable if the operation for one type requires casting to the other type. To provide an example, `powf` should be implemented for `Int` tensors, but it should not be an Int Tensor Operation. The LHS should be converted to a float, and the output should be converted back to an int. So it's possible to avoid implementing `IntTensorOp` altogether. Additionally there are some operations that should be defined as functions instead of tensor op methods. These are: `module` => These should be exported as functions instead of methods on tensors. The implementation is in [crates/burn-tensor/src/tensor/ops/module.rs](https://github.com/tracel-ai/burn/tree/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/modules). `activation` => These should also be exported as functions instead of methods on tensors. The implementation is in [crates/burn-tensor/src/tensor/ops/activation.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/activation.rs). Note that some activations are just a combination of backend operations and are not declared in there. ================================================ FILE: crates/burn/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Flexible and Comprehensive Deep Learning Framework in Rust" documentation = "https://docs.rs/burn" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn" readme.workspace = true repository = "https://github.com/tracel-ai/burn" rust-version = "1.92" version.workspace = true [lints] workspace = true [features] default = [ "std", "burn-core/default", "burn-train?/default", "burn-collective?/default", # Backends "burn-candle?/default", "burn-cpu?/default", "burn-ndarray?/default", "burn-tch?/default", "burn-wgpu?/default", "burn-router?/default", "burn-remote?/default", "burn-cuda?/default", "burn-autodiff?/default", "burn-rocm?/default", "burn-nn/default", "burn-optim/default", "burn-dispatch?/default", ] doc = [ "default", "train", "burn-core/doc", "burn-train/doc", "burn-collective/doc", "burn-store?/std", # Backends "burn-candle/doc", "burn-cpu?/doc", "burn-ndarray/doc", "burn-tch/doc", "burn-wgpu/doc", "burn-router/doc", "burn-cuda/doc", "burn-autodiff?/std", "burn-rocm/doc", "burn-nn/doc", "burn-optim/doc", "burn-dispatch?/doc", ] std = [ "burn-core/std", # Backends "burn-candle?/std", "burn-cpu?/std", "burn-ndarray?/std", "burn-wgpu?/std", "burn-router?/std", "burn-cuda?/std", "burn-autodiff?/std", "burn-rocm?/std", "burn-store?/std", "burn-tch?/std", "burn-nn/std", "burn-optim/std", "burn-dispatch?/std", ] tracing = [ "cubecl?/tracing", "burn-core/tracing", # Backends "burn-candle?/tracing", "burn-cpu?/tracing", "burn-ndarray?/tracing", "burn-wgpu?/tracing", "burn-router?/tracing", "burn-cuda?/tracing", "burn-autodiff?/tracing", "burn-rocm?/tracing", "burn-tch?/tracing", "burn-store?/tracing", "burn-nn/tracing", "burn-optim/tracing", "burn-dispatch?/tracing", ] network = ["burn-core/network"] # Training with full features train = ["burn-train", "autodiff", "dataset"] ## Includes the Text UI (progress bars, metric plots) tui = ["burn-train?/tui"] ## Includes system info metrics (CPU/GPU usage, etc) metrics = ["burn-train?/sys-metrics"] # Datasets dataset = ["burn-core/dataset"] sqlite = ["burn-core/sqlite"] sqlite-bundled = ["burn-core/sqlite-bundled"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["burn-core/record-item-custom-serde"] # Model storage and serialization (SafeTensors, PyTorch interop) store = ["burn-store"] # CubeCL re-export cubecl = ["dep:cubecl"] audio = ["burn-core/audio"] vision = ["burn-core/vision"] rl = ["dep:burn-rl", "burn-train?/rl"] # Backend ir = ["burn-ir"] autodiff = ["burn-autodiff", "burn-dispatch?/autodiff"] fusion = [ "ir", "burn-wgpu?/fusion", "burn-cuda?/fusion", "burn-rocm?/fusion", "burn-cpu?/fusion", ] ## Backend features accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] autotune = [ "burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-rocm?/autotune", "burn-cpu?/autotune", ] autotune-checks = [ "burn-wgpu?/autotune-checks", "burn-cuda?/autotune-checks", "burn-rocm?/autotune-checks", "burn-cpu?/autotune-checks", ] blas-netlib = ["burn-ndarray?/blas-netlib"] openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] remote = ["burn-remote/client", "ir"] router = ["burn-router", "ir"] server = ["burn-remote/server"] simd = ["burn-ndarray?/simd"] template = ["burn-wgpu?/template"] collective = ["burn-collective", "burn-optim/collective", "burn-train?/ddp"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] candle-metal = ["burn-candle?/metal"] cuda = ["burn-cuda", "burn-dispatch?/cuda"] rocm = ["burn-rocm", "burn-dispatch?/rocm"] ndarray = ["burn-ndarray", "burn-dispatch?/ndarray"] tch = ["burn-tch"] vulkan = ["wgpu", "burn-wgpu/vulkan", "burn-dispatch?/vulkan"] webgpu = ["wgpu", "burn-wgpu/webgpu", "burn-dispatch?/webgpu"] metal = ["wgpu", "burn-wgpu/metal", "burn-dispatch?/metal"] wgpu = ["burn-wgpu"] cpu = ["burn-cpu", "burn-dispatch?/cpu"] # Backend dispatch dispatch = ["burn-dispatch"] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false } burn-train = { path = "../burn-train", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", default-features = false } burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", default-features = false } burn-rl = { path = "../burn-rl", version = "=0.21.0-pre.2", optional = true, default-features = false } # Backends burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-candle = { path = "../burn-candle", version = "=0.21.0-pre.2", optional = true } burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-dispatch = { path = "../burn-dispatch", version = "=0.21.0-pre.2", optional = true, default-features = false } cubecl = { workspace = true, default-features = false, optional = true } ================================================ FILE: crates/burn/src/backend.rs ================================================ #[cfg(feature = "ndarray")] pub use burn_ndarray as ndarray; #[cfg(feature = "ndarray")] pub use ndarray::NdArray; #[cfg(feature = "autodiff")] pub use burn_autodiff as autodiff; #[cfg(feature = "remote")] pub use burn_remote as remote; #[cfg(feature = "remote")] pub use burn_remote::RemoteBackend; #[cfg(feature = "autodiff")] pub use burn_autodiff::Autodiff; #[cfg(feature = "wgpu")] pub use burn_wgpu as wgpu; #[cfg(feature = "wgpu")] pub use burn_wgpu::Wgpu; #[cfg(feature = "webgpu")] pub use burn_wgpu::WebGpu; #[cfg(feature = "vulkan")] pub use burn_wgpu::Vulkan; #[cfg(feature = "metal")] pub use burn_wgpu::Metal; #[cfg(feature = "cuda")] pub use burn_cuda as cuda; #[cfg(feature = "cuda")] pub use burn_cuda::Cuda; #[cfg(feature = "candle")] pub use burn_candle as candle; #[cfg(feature = "candle")] pub use burn_candle::Candle; #[cfg(feature = "rocm")] pub use burn_rocm as rocm; #[cfg(feature = "rocm")] pub use burn_rocm::Rocm; #[cfg(feature = "tch")] pub use burn_tch as libtorch; #[cfg(feature = "tch")] pub use burn_tch::LibTorch; #[cfg(feature = "router")] pub use burn_router::Router; #[cfg(feature = "router")] pub use burn_router as router; #[cfg(feature = "ir")] pub use burn_ir as ir; #[cfg(feature = "collective")] pub use burn_collective as collective; #[cfg(feature = "cpu")] pub use burn_cpu as cpu; #[cfg(feature = "cpu")] pub use burn_cpu::Cpu; ================================================ FILE: crates/burn/src/collective.rs ================================================ pub use burn_collective::*; ================================================ FILE: crates/burn/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] //! # Burn //! //! Burn is a new comprehensive dynamic Deep Learning Framework built using Rust //! with extreme flexibility, compute efficiency and portability as its primary goals. //! //! ## Performance //! //! Because we believe the goal of a deep learning framework is to convert computation //! into useful intelligence, we have made performance a core pillar of Burn. //! We strive to achieve top efficiency by leveraging multiple optimization techniques: //! //! - Automatic kernel fusion //! - Asynchronous execution //! - Thread-safe building blocks //! - Intelligent memory management //! - Automatic kernel selection //! - Hardware specific features //! - Custom Backend Extension //! //! ## Training & Inference //! //! The whole deep learning workflow is made easy with Burn, as you can monitor your training progress //! with an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU clusters. //! //! Burn was built from the ground up with training and inference in mind. It's also worth noting how Burn, //! in comparison to frameworks like PyTorch, simplifies the transition from training to deployment, //! eliminating the need for code changes. //! //! ## Backends //! //! Burn strives to be as fast as possible on as many hardwares as possible, with robust implementations. //! We believe this flexibility is crucial for modern needs where you may train your models in the cloud, //! then deploy on customer hardwares, which vary from user to user. //! //! Compared to other frameworks, Burn has a very different approach to supporting many backends. //! By design, most code is generic over the Backend trait, which allows us to build Burn with swappable backends. //! This makes composing backend possible, augmenting them with additional functionalities such as //! autodifferentiation and automatic kernel fusion. //! //! - WGPU (WebGPU): Cross-Platform GPU Backend //! - Candle: Backend using the Candle bindings //! - LibTorch: Backend using the LibTorch bindings //! - NdArray: Backend using the NdArray primitive as data structure //! - Autodiff: Backend decorator that brings backpropagation to any backend //! - Fusion: Backend decorator that brings kernel fusion to backends that support it //! //! # Quantization //! //! Quantization techniques perform computations and store tensors in lower precision data types like //! 8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep //! learning model categorized as post-training quantization (PTQ) and quantization aware training (QAT). //! //! In post-training quantization, the model is trained in floating point precision and later converted //! to the lower precision data type. There are two types of post-training quantization: //! //! 1. Static quantization: quantizes the weights and activations of the model. Quantizing the //! activations statically requires data to be calibrated (i.e., recording the activation values to //! compute the optimal quantization parameters with representative data). //! 2. Dynamic quantization: quantized the weights ahead of time (like static quantization) but the //! activations are dynamically at runtime. //! //! Sometimes post-training quantization is not able to achieve acceptable task accuracy. In general, //! this is where quantization-aware training (QAT) can be used: during training, fake-quantization //! modules are inserted in the forward and backward passes to simulate quantization effects, allowing //! the model to learn representations that are more robust to reduced precision. //! //! Burn does not currently support QAT. Only post-training quantization (PTQ) is implemented at this //! time. //! //! Quantization support in Burn is currently in active development. It supports the following PTQ modes on some backends: //! - Per-tensor and per-block quantization to 8-bit, 4-bit and 2-bit representations //! //! ## Feature Flags //! //! The following feature flags are available. //! By default, the feature `std` is activated. //! //! - Training //! - `train`: Enables features `dataset` and `autodiff` and provides a training environment //! - `tui`: Includes Text UI with progress bar and plots //! - `metrics`: Includes system info metrics (CPU/GPU usage, etc.) //! - Dataset //! - `dataset`: Includes a datasets library //! - `audio`: Enables audio datasets (SpeechCommandsDataset) //! - `sqlite`: Stores datasets in SQLite database //! - `sqlite_bundled`: Use bundled version of SQLite //! - `vision`: Enables vision datasets (MnistDataset) //! - Backends //! - `wgpu`: Makes available the WGPU backend //! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler //! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler //! - `cuda`: Makes available the CUDA backend //! - `rocm`: Makes available the ROCm backend //! - `candle`: Makes available the Candle backend //! - `tch`: Makes available the LibTorch backend //! - `ndarray`: Makes available the NdArray backend //! - Backend specifications //! - `accelerate`: If supported, Accelerate will be used //! - `blas-netlib`: If supported, Blas Netlib will be use //! - `openblas`: If supported, Openblas will be use //! - `openblas-system`: If supported, Openblas installed on the system will be use //! - `autotune`: Enable running benchmarks to select the best kernel in backends that support it. //! - `fusion`: Enable operation fusion in backends that support it. //! - Backend decorators //! - `autodiff`: Makes available the Autodiff backend //! - Model Storage //! - `store`: Enables model storage with SafeTensors format and PyTorch interoperability //! - Others: //! - `std`: Activates the standard library (deactivate for no_std) //! - `server`: Enables the remote server. //! - `network`: Enables network utilities (currently, only a file downloader with progress bar) //! //! You can also check the details in sub-crates [`burn-core`](https://docs.rs/burn-core) and [`burn-train`](https://docs.rs/burn-train). pub use burn_core::*; /// Train module #[cfg(feature = "train")] pub mod train { pub use burn_train::*; } /// Module for reinforcement learning. #[cfg(feature = "rl")] pub mod rl { pub use burn_rl::*; } /// Backend module. pub mod backend; #[cfg(feature = "server")] pub use burn_remote::server; /// Module for collective operations #[cfg(feature = "collective")] pub mod collective; /// Module for model storage and serialization #[cfg(feature = "store")] pub mod store { pub use burn_store::*; } /// Neural network module. pub mod nn { pub use burn_nn::*; } /// Optimizers module. pub mod optim { pub use burn_optim::*; } // For backward compat, `burn::lr_scheduler::*` /// Learning rate scheduler module. #[cfg(feature = "std")] pub mod lr_scheduler { pub use burn_optim::lr_scheduler::*; } // For backward compat, `burn::grad_clipping::*` /// Gradient clipping module. pub mod grad_clipping { pub use burn_optim::grad_clipping::*; } #[cfg(feature = "dispatch")] pub use burn_dispatch::*; /// CubeCL module re-export. #[cfg(feature = "cubecl")] pub mod cubecl { pub use cubecl::*; } pub mod prelude { //! Structs and macros used by most projects. Add `use //! burn::prelude::*` to your code to quickly get started with //! Burn. pub use burn_core::prelude::*; pub use crate::nn; } ================================================ FILE: crates/burn-autodiff/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Automatic differentiation backend for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-autodiff" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff" documentation = "https://docs.rs/burn-autodiff" version.workspace = true [lints] workspace = true [features] default = ["std", "tracing"] std = ["dep:parking_lot"] export_tests = [] # check checkpointer is_empty in tests tracing = [ "dep:tracing", "burn-std/tracing", "burn-backend/tracing", ] [dependencies] burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } derive-new = { workspace = true } spin = { workspace = true } parking_lot = { workspace = true, optional = true } log = { workspace = true } hashbrown = { workspace = true } num-traits = { workspace = true } portable-atomic = { workspace = true } tracing = { workspace = true, optional = true, features = ["default"] } [package.metadata.docs.rs] features = ["default"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-autodiff/README.md ================================================ # Burn Autodiff > [Burn](https://github.com/tracel-ai/burn) autodiff backend [![Current Crates.io Version](https://img.shields.io/crates/v/burn-autodiff.svg)](https://crates.io/crates/burn-autodiff) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-autodiff/blob/master/README.md) For now only first order reverse mode autodiff is supported. ================================================ FILE: crates/burn-autodiff/src/backend.rs ================================================ use crate::{ checkpoint::strategy::{CheckpointStrategy, NoCheckpointing}, grads::Gradients, tensor::AutodiffTensor, }; use alloc::{format, string::String}; use burn_backend::{ backend::{AutodiffBackend, Backend, ExecutionError}, tensor::{BoolTensor, IntTensor, QuantizedTensor}, }; use core::marker::PhantomData; /// Enable auto-differentiation on a backend. /// /// This works as a backend decorator, extending the functionality of any backend with /// backpropagation. #[derive(Clone, Copy, Debug, Default)] pub struct Autodiff { _b: PhantomData, _checkpoint_strategy: PhantomData, } impl Backend for Autodiff { type Device = B::Device; type FloatTensorPrimitive = AutodiffTensor; type FloatElem = B::FloatElem; type IntTensorPrimitive = B::IntTensorPrimitive; type IntElem = B::IntElem; type BoolTensorPrimitive = B::BoolTensorPrimitive; type BoolElem = B::BoolElem; type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive; fn ad_enabled(_device: &Self::Device) -> bool { true } fn name(device: &Self::Device) -> String { format!("autodiff<{}>", B::name(device)) } fn seed(device: &B::Device, seed: u64) { B::seed(device, seed) } fn sync(device: &B::Device) -> Result<(), ExecutionError> { B::sync(device) } fn memory_persistent_allocations< Output: Send, Input: Send, Func: Fn(Input) -> Output + Send, >( device: &Self::Device, input: Input, func: Func, ) -> Output { B::memory_persistent_allocations(device, input, func) } fn memory_cleanup(device: &Self::Device) { B::memory_cleanup(device) } fn staging<'a, Iter>(data: Iter, device: &Self::Device) where Iter: Iterator, { B::staging(data, device); } fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool { B::supports_dtype(device, dtype) } fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet { B::dtype_usage(device, dtype) } } impl AutodiffBackend for Autodiff { type InnerBackend = B; type Gradients = Gradients; fn backward(tensor: AutodiffTensor) -> Gradients { tensor.backward() } fn grad(tensor: &AutodiffTensor, grads: &Gradients) -> Option { tensor.grad(grads) } fn grad_remove( tensor: &AutodiffTensor, grads: &mut Gradients, ) -> Option { tensor.grad_remove(grads) } fn inner(tensor: AutodiffTensor) -> B::FloatTensorPrimitive { tensor.primitive } fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor { AutodiffTensor::new(tensor) } fn grad_replace( tensor: &AutodiffTensor, grads: &mut Self::Gradients, grad: B::FloatTensorPrimitive, ) { tensor.grad_replace(grads, grad); } fn int_inner(tensor: IntTensor) -> IntTensor { tensor } fn bool_inner(tensor: BoolTensor) -> BoolTensor { tensor } fn int_from_inner(tensor: IntTensor) -> IntTensor { tensor } fn bool_from_inner(tensor: BoolTensor) -> BoolTensor { tensor } fn q_inner(tensor: QuantizedTensor) -> QuantizedTensor { tensor } fn q_from_inner(tensor: QuantizedTensor) -> QuantizedTensor { tensor } } ================================================ FILE: crates/burn-autodiff/src/checkpoint/base.rs ================================================ use super::{ retro_forward::RetroForwards, state::{BackwardStates, State}, }; use crate::collections::HashMap; use crate::graph::NodeId; use alloc::{vec, vec::Vec}; #[derive(new, Debug)] /// Links a [NodeId] to its autodiff graph [NodeRef] pub(crate) struct NodeTree { map: HashMap>, } impl NodeTree { /// Gives the parents of the node in the autodiff graph pub(crate) fn parents(&self, node_id: &NodeId) -> Option> { self.map.get(node_id).cloned() } } #[derive(new, Debug)] /// Struct responsible of fetching the output for a node in the autodiff graph during a backward pass pub struct Checkpointer { backward_states: BackwardStates, retro_forwards: RetroForwards, node_tree: NodeTree, } impl Checkpointer { /// Gives the output of the given node, by recursively asking parents to compute themselves /// or give their pre-computed tensors. pub fn retrieve_node_output(&mut self, node_id: NodeId) -> T where T: Clone + Send + 'static, { self.topological_sort(node_id).into_iter().for_each(|node| { self.retro_forwards .execute_retro_forward(node, &mut self.backward_states) }); self.backward_states.get_state::(&node_id) } /// Sorts the ancestors of NodeId in a way such that all parents come before their children /// Useful to avoid recursivity later when mutating the states /// /// The sort on a compute bound state or a memory bound that is already computed is trivial. /// The match on State::Computed also serves as a stopping criterion for the sort, /// we don't need to look higher than that during recursivity. fn topological_sort(&self, node_id: NodeId) -> Vec { match self.backward_states.get_state_ref(&node_id) { Some(state) => match state { State::Recompute { n_required: _ } => { let mut sorted = Vec::new(); let parents = self.node_tree.parents(&node_id).unwrap(); for parent_node in parents { let parent_sorted = self.topological_sort(parent_node); for ps in parent_sorted { if !sorted.contains(&ps) { sorted.push(ps) } } } sorted.push(node_id); sorted } State::Computed { state_content: _, n_required: _, } => vec![node_id], }, None => panic!("Node {node_id:?} is not in the backward_states. "), } } /// Checks if checkpointer has been drained adequately. Useful for testing pub fn is_empty(&self) -> bool { self.backward_states.is_empty() && self.retro_forwards.is_empty() } } ================================================ FILE: crates/burn-autodiff/src/checkpoint/builder.rs ================================================ use crate::{ collections::HashMap, graph::{ComputingProperty, NodeId}, tensor::AutodiffTensor, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; use burn_backend::Backend; use core::any::Any; use super::{ base::{Checkpointer, NodeTree}, retro_forward::{RetroForward, RetroForwards}, state::{BackwardStates, State}, }; #[derive(Debug)] /// Determines if a node should checkpoint its computed output or its retro_forward for recomputation /// The action is normally created by the child of the node, once the node is determined to be needed pub enum CheckpointingAction { /// The node's already computed output should be saved Computed { /// The node node_id: NodeId, /// The node's output state_content: Box, }, /// The node should recompute itself when asked Recompute { /// The node node_id: NodeId, /// How the node should recompute itself retro_forward: Arc, }, } // TODO: Remove that when proper client server. unsafe impl Send for CheckpointingAction {} impl CheckpointingAction { /// Utility function to access the id of the node of the checkpointing action pub fn id(&self) -> NodeId { match self { CheckpointingAction::Computed { node_id: node_ref, state_content: _, } => *node_ref, CheckpointingAction::Recompute { node_id: node_ref, retro_forward: _, } => *node_ref, } } } #[derive(new, Debug, Default)] /// Accumulates checkpoints as checkpointing actions during the forward pass, /// and builds a checkpointer right before the backward pass pub struct CheckpointerBuilder { explicit_actions: Vec, backup_actions: Vec, } /// Determines if a checkpoint should impact the n_required values (Main) /// or if it should just keep the state in case it's required (Backup) /// pub(crate) enum ActionType { /// Explicit actions have been explicitly requested by some operation to retrieve their state Explicit, /// Backup actions are not always needed. They exist to save the output of an operation /// whose child is memory bound, in case the state is indirectly needed when computing /// the child's retro_forward. If no explicit action ever asks for the child's output, then /// the backup output will go out of scope when the checkpointer is built. Backup, } impl CheckpointerBuilder { pub(crate) fn checkpoint( &mut self, tensor: &AutodiffTensor, action_type: ActionType, ) { let action_list = match action_type { ActionType::Explicit => &mut self.explicit_actions, ActionType::Backup => &mut self.backup_actions, }; match &tensor.node.properties { ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => { action_list.push(CheckpointingAction::Computed { node_id: tensor.node.id, state_content: Box::new(tensor.primitive.clone()), }) } ComputingProperty::MemoryBound { retro_forward } => { action_list.push(CheckpointingAction::Recompute { node_id: tensor.node.id, retro_forward: retro_forward.clone(), }) } } } pub(crate) fn extend(&mut self, other: CheckpointerBuilder) { for other_action in other.explicit_actions { self.explicit_actions.push(other_action) } for other_unsure in other.backup_actions { self.backup_actions.push(other_unsure) } } pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer { let mut backward_states_map = HashMap::new(); let mut retro_forwards_map = HashMap::new(); // Find recursion stopping points let stop_nodes: Vec = self.find_stop_nodes(); // We start by identifying how many times each node will be required. let n_required_map = self.build_n_required_map(&node_tree, stop_nodes); // Then we checkpoint the nodes with the corresponding n_required value self.insert_checkpoints( &mut backward_states_map, &mut retro_forwards_map, n_required_map, ); Checkpointer::new( BackwardStates::new(backward_states_map), RetroForwards::new(retro_forwards_map), node_tree, ) } fn find_stop_nodes(&self) -> Vec { let mut stop_nodes = Vec::default(); for action in self .explicit_actions .iter() .chain(self.backup_actions.iter()) { match action { CheckpointingAction::Computed { node_id: node_ref, state_content: _, } => stop_nodes.push(*node_ref), CheckpointingAction::Recompute { node_id: _, retro_forward: _, } => {} } } stop_nodes } fn build_n_required_map( &self, node_tree: &NodeTree, stop_nodes: Vec, ) -> HashMap { let mut n_required_map = HashMap::::default(); for action in self.explicit_actions.iter() { match action { CheckpointingAction::Computed { node_id: node_ref, state_content: _, } => { let id = *node_ref; match n_required_map.remove(&id) { Some(n) => { n_required_map.insert(id, n + 1); } None => { n_required_map.insert(id, 1); } }; } CheckpointingAction::Recompute { node_id: node_ref, retro_forward: _, } => { let id = *node_ref; Self::update_n_required_of_parents( id, &mut n_required_map, node_tree, &stop_nodes, ); } } } n_required_map } fn insert_checkpoints( mut self, backward_states_map: &mut HashMap, retro_forward_map: &mut HashMap>, n_required_map: HashMap, ) { // We do not loop over checkpointing actions anymore because they can contain // duplicates or miss some that are in backup. We loop over the n_required_map // from which we use the ids to find them again in the checkpointing actions for (node_id, n_required) in n_required_map { // We find the checkpointing action for node_id. It's likely in checkpointing_actions // so we check there first, otherwise it will be in backup. // Technically it can be there several times but can never be of both types, so we can assume the first we find is fine let action = match self .explicit_actions .iter() .position(|action| action.id() == node_id) { Some(pos) => self.explicit_actions.remove(pos), None => { let pos = self .backup_actions .iter() .position(|action| action.id() == node_id); self.backup_actions.remove(pos.unwrap_or_else(|| { panic!("Node {:?} is needed but never checkpointed", &node_id) })) } }; match action { CheckpointingAction::Computed { node_id: _, state_content, } => { self.checkpoint_compute(backward_states_map, node_id, state_content, n_required) } CheckpointingAction::Recompute { node_id: _, retro_forward, } => self.checkpoint_lazy( backward_states_map, retro_forward_map, node_id, retro_forward, n_required, ), }; } } fn update_n_required_of_parents( id: NodeId, n_required_map: &mut HashMap, node_tree: &NodeTree, stop_nodes: &Vec, ) { match n_required_map.remove(&id) { Some(n) => { n_required_map.insert(id, n + 1); } None => { n_required_map.insert(id, 1); if !stop_nodes.contains(&id) && let Some(parents) = node_tree.parents(&id) { for p in parents { Self::update_n_required_of_parents( p, n_required_map, node_tree, stop_nodes, ); } } } } } fn checkpoint_compute( &self, backward_states_map: &mut HashMap, node_id: NodeId, state_content: Box, n_required: usize, ) { backward_states_map.insert( node_id, State::Computed { state_content, n_required, }, ); } fn checkpoint_lazy( &self, backward_states_map: &mut HashMap, retro_forward_map: &mut HashMap>, node_id: NodeId, retro_forward: Arc, n_required: usize, ) { retro_forward_map.insert(node_id, retro_forward); backward_states_map.insert(node_id, State::Recompute { n_required }); } } ================================================ FILE: crates/burn-autodiff/src/checkpoint/mod.rs ================================================ /// Checkpointer module pub mod base; pub(crate) mod builder; /// RetroForward module pub mod retro_forward; /// BackwardStates module pub mod state; /// CheckpointStrategy module pub mod strategy; ================================================ FILE: crates/burn-autodiff/src/checkpoint/retro_forward.rs ================================================ use crate::collections::HashMap; use crate::graph::NodeId; use alloc::sync::Arc; use core::fmt::Debug; use super::state::{BackwardStates, State}; /// Definition of the forward function of a node, called during retropropagation only. /// This is different from the normal forward function because it reads and writes from /// the [BackwardStates] map instead of having a clear function signature. pub trait RetroForward: Debug + Send + 'static { /// Applies the forward pass for retropropagation. fn forward(&self, states: &mut BackwardStates, out_node: NodeId); } #[derive(new, Debug)] /// Links [NodeId]s to their corresponding [RetroForward] pub(crate) struct RetroForwards { map: HashMap>, } impl RetroForwards { /// Executes the [RetroForward] for a given [NodeId] if the node's /// [State] is [State::Recompute], otherwise does nothing. pub(crate) fn execute_retro_forward( &mut self, node_id: NodeId, backward_states: &mut BackwardStates, ) { if let State::Recompute { n_required: _ } = backward_states .get_state_ref(&node_id) .unwrap_or_else(|| panic!("Should find node {node_id:?}")) { // Retro forwards are always used only once because afterwards their state is computed let retro_forward = self.map.remove(&node_id).unwrap(); retro_forward.forward(backward_states, node_id); } } pub(crate) fn is_empty(&self) -> bool { self.map.is_empty() } } #[macro_export] /// Creates a RetroForward struct for unary scalar operations macro_rules! retro_unary_scalar { ( $name:ident, $ops:expr ) => { #[derive(new, Debug, Clone)] struct $name { lhs_id: NodeId, rhs: Scalar, _backend: PhantomData, } impl RetroForward for $name { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let lhs = states.get_state::(&self.lhs_id); let out = $ops(lhs, self.rhs); states.save(out_node, out) } } }; } #[macro_export] /// Creates a RetroForward struct for unary scalar operations macro_rules! retro_unary { ( $name:ident, $ops:expr ) => { #[derive(new, Debug, Clone)] struct $name { input_id: NodeId, _backend: PhantomData, } impl RetroForward for $name { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = $ops(input); states.save(out_node, out) } } }; } #[macro_export] /// Creates a RetroForward struct for binary operations macro_rules! retro_binary { ( $name:ident, $ops:expr ) => { #[derive(new, Debug, Clone)] struct $name { lhs_id: NodeId, rhs_id: NodeId, _backend: PhantomData, } impl RetroForward for $name { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let lhs = states.get_state::(&self.lhs_id); let rhs = states.get_state::(&self.rhs_id); let out = $ops(lhs, rhs); states.save(out_node, out) } } }; } ================================================ FILE: crates/burn-autodiff/src/checkpoint/state.rs ================================================ use core::any::Any; use crate::collections::HashMap; use crate::graph::NodeId; use alloc::boxed::Box; /// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any. pub(crate) type StateContent = Box; #[derive(Debug)] /// The state contained at one node. Encapsulates the node output if precomputed, /// or clearly asks that it needs to be recomputed from the parents. /// Also keeps track of the number of times the state is required so it can be removed /// from the map of states on its last use. pub(crate) enum State { /// The state was not checkpointed, will need to recompute it from the node's parents Recompute { n_required: usize }, /// The state was checkpointed or computed during retropropagation and can be directly accessed Computed { state_content: StateContent, n_required: usize, }, } impl State { /// Returns a reference to the (not yet) downcasted node output, if checkpointed pub(crate) fn to_state_content(&self) -> &StateContent { match self { State::Recompute { n_required: _ } => { unreachable!( "Can't get state content of recompute state. A child has likely been accessed before its parents." ) } State::Computed { state_content, n_required: _, } => state_content, } } /// Returns a (not yet) downcasted node output, if checkpointed pub(crate) fn into_state_content(self) -> StateContent { match self { State::Recompute { n_required: _ } => { unreachable!( "Can't get state content of recompute state. A child has likely been accessed before its parents." ) } State::Computed { state_content, n_required: _, } => state_content, } } /// Returns the number of time the state is required pub(crate) fn n_required(&self) -> usize { match self { State::Recompute { n_required } => *n_required, State::Computed { state_content: _, n_required, } => *n_required, } } } #[derive(new, Default, Debug)] /// Links [NodeId]s to their current state pub struct BackwardStates { map: HashMap, } impl BackwardStates { /// Returns the output in the state of the given [NodeId], /// and decrements the number of times this state is required. /// This function always gives ownership of the output, but will clone it if needed for further uses. pub fn get_state(&mut self, node_id: &NodeId) -> T where T: Clone + Send + 'static, { // Fetch the state and decrement its number of required let state = self.map.remove(node_id).unwrap(); let remaining_n_required = state.n_required() - 1; // Downcast the state to whatever it is supposed to be // If still needed after giving ownership, we copy it back to the hashmap if remaining_n_required > 0 { let new_stored_state = match state { State::Recompute { n_required: _ } => unreachable!(), State::Computed { state_content, n_required: _, } => State::Computed { state_content, n_required: remaining_n_required, }, }; let downcasted = new_stored_state .to_state_content() .downcast_ref::() .unwrap() .clone(); self.insert_state(*node_id, new_stored_state); downcasted } else { let downcasted = state.into_state_content().downcast::().unwrap(); *downcasted } } /// Returns a reference to the [State] of the given node /// Useful when we need [State] information without needing the underlying tensor pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> { self.map.get(node_id) } /// Associates a [State] to its [NodeId] pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) { self.map.insert(node_id, state); } /// Saves the output to the state of the given [NodeId]. pub fn save(&mut self, node_id: NodeId, saved_output: T) where T: Clone + Send + 'static, { let n_required = self.get_state_ref(&node_id).unwrap().n_required(); self.insert_state( node_id, State::Computed { state_content: Box::new(saved_output), n_required, }, ); } pub(crate) fn is_empty(&self) -> bool { self.map.is_empty() } } ================================================ FILE: crates/burn-autodiff/src/checkpoint/strategy.rs ================================================ use core::fmt::Debug; use burn_backend::Backend; use crate::{graph::ComputingProperty, tensor::AutodiffTensor}; use alloc::sync::Arc; use super::{ builder::{ActionType, CheckpointerBuilder}, retro_forward::RetroForward, }; /// Strategy for the amount of checkpointing to do during autodiff pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static { /// May modify the compute property depending on the strategy fn compute_property(retro_forward: R) -> ComputingProperty; /// Checkpoints parents if necessary in the strategy fn checkpoint_parents<'a, B2, A>( parents: A, builder: &mut CheckpointerBuilder, ) -> Result<(), CheckpointingError> where B2: Backend, A: IntoIterator>; } #[derive(Debug)] /// Error that can happen when trying to checkpoint a tensor. pub enum CheckpointingError { /// When a parent is untracked, we can't easily checkpoint its state, since we don't know the /// requirements in advanced. UntrackedParent, } #[derive(Clone, Copy, Debug, Default)] /// All operations are considered compute bound, notwithstanding how they are marked pub struct NoCheckpointing {} impl CheckpointStrategy for NoCheckpointing { /// An operation marked as memory bound is actually compute bound. fn compute_property(_retro_forward: R) -> ComputingProperty { ComputingProperty::ComputeBound } /// An operation marked as memory bound is actually compute bound. /// It's therefore useless to checkpoint the parents fn checkpoint_parents<'a, B2, A>( _parents: A, _builder: &mut CheckpointerBuilder, ) -> Result<(), CheckpointingError> where B2: Backend, A: IntoIterator>, { // Nothing to do here Ok(()) } } #[derive(Clone, Copy, Debug, Default)] /// Operation properties are as they are marked (compute or memory bound) pub struct BalancedCheckpointing {} impl CheckpointStrategy for BalancedCheckpointing { /// An operation marked as memory bound is memory bound. /// When memory bound, an operation needs to save its RetroForward fn compute_property(retro_forward: R) -> ComputingProperty { ComputingProperty::MemoryBound { retro_forward: Arc::new(retro_forward), } } /// An operation marked as memory bound is really memory bound. /// Since the operation may not checkpoint its parents but may need them indirectly /// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them fn checkpoint_parents<'a, B2, A>( parents: A, builder: &mut CheckpointerBuilder, ) -> Result<(), CheckpointingError> where B2: Backend, A: IntoIterator>, { let mut can_checkpoint = true; for tensor in parents.into_iter() { if let crate::graph::Requirement::None = tensor.node.requirement { can_checkpoint = false; } else { builder.checkpoint(tensor, ActionType::Backup); } } if !can_checkpoint { *builder = CheckpointerBuilder::default(); return Err(CheckpointingError::UntrackedParent); } Ok(()) } } ================================================ FILE: crates/burn-autodiff/src/grads.rs ================================================ use burn_backend::{ Backend, TensorMetadata, TensorPrimitive, tensor::{FloatTensor, TensorContainer}, }; use crate::{ NodeId, graph::{NodeRef, Requirement}, tensor::AutodiffTensor, }; /// Gradient identifier. pub type GradID = u64; /// Gradients container used during the backward pass. pub struct Gradients { container: TensorContainer, } impl Gradients { /// Creates a new gradients container. pub fn new(root_node: NodeRef, root_tensor: FloatTensor) -> Self { let mut gradients = Self { container: TensorContainer::new(), }; gradients.register::( root_node.id, B::float_ones( root_tensor.shape(), &B::float_device(&root_tensor), root_tensor.dtype().into(), ), ); gradients } /// Consumes the gradients for a given tensor. /// /// Each tensor should be consumed exactly 1 time if its gradients are only required during the /// backward pass, otherwise, it may be consume multiple times. pub fn consume(&mut self, node: &NodeRef) -> FloatTensor { match node.requirement { Requirement::Grad => self .container .get::(&node.id.value) .map(|tensor| tensor.tensor()) .expect("Can't consume the gradients before they are registered at least once."), Requirement::GradInBackward => self .container .remove::(&node.id.value) .map(|tensor| tensor.tensor()) .expect("Can't consume the gradients before they are registered at least once."), Requirement::None => panic!("Trying to consume the gradients for an untracked tensor"), } } /// Removes a grad tensor from the container. pub fn remove(&mut self, tensor: &AutodiffTensor) -> Option> { self.container .remove::(&tensor.node.id.value) .map(|tensor| tensor.tensor()) } /// Gets a grad tensor from the container. pub fn get(&self, tensor: &AutodiffTensor) -> Option> { self.container .get::(&tensor.node.id.value) .map(|tensor| tensor.tensor()) } /// Register a grad tensor in the container. /// /// If the tensor already exists, add both tensors together before saving the result. pub fn register(&mut self, node_id: NodeId, value: FloatTensor) { if let Some(tensor_old) = self.container.remove::(&node_id.value) { self.container.register::( node_id.value, TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())), ); } else { self.container .register::(node_id.value, TensorPrimitive::Float(value)); } } } ================================================ FILE: crates/burn-autodiff/src/graph/base.rs ================================================ use super::NodeId; use crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent}; use alloc::boxed::Box; /// Backward step for reverse mode autodiff. pub trait Step: Send + core::fmt::Debug { /// Executes the step and consumes it. fn step(self: Box, grads: &mut Gradients, checkpointer: &mut Checkpointer); /// Depth of the operation relative to the first node added to a graph. fn depth(&self) -> usize; /// The node associated to the step. fn node(&self) -> NodeId; /// The parents of the node associated to the step. fn parents(&self) -> &[Parent]; } pub type StepBoxed = Box; ================================================ FILE: crates/burn-autodiff/src/graph/mod.rs ================================================ mod base; mod node; mod requirement; pub mod traversal; pub use base::*; pub use node::*; pub use requirement::*; ================================================ FILE: crates/burn-autodiff/src/graph/node.rs ================================================ use alloc::{sync::Arc, vec::Vec}; #[cfg(target_has_atomic = "64")] use core::sync::atomic::{AtomicU64, Ordering}; #[cfg(not(target_has_atomic = "64"))] use portable_atomic::{AtomicU64, Ordering}; use crate::checkpoint::retro_forward::RetroForward; use crate::runtime::AutodiffClientImpl; use super::Requirement; #[derive(Debug, Clone)] pub enum ComputingProperty { ComputeBound, MemoryBound { retro_forward: Arc, }, Ambiguous, // Maybe autotune someday } /// This is safe only because we only call RetroForward on the autodiff server. /// Therefore, the trait will never be used by multiple threads at the same time. /// /// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the /// Arc, which will make (dyn RetroForward) safely implement Send. unsafe impl Send for ComputingProperty {} /// unsafe Sync is required because Send is only implemented for Arc, not Arc. unsafe impl Sync for ComputingProperty {} /// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning. #[derive(new, Debug)] pub struct Node { pub parents: Vec, pub order: usize, pub id: NodeId, pub requirement: Requirement, pub properties: ComputingProperty, pub client: AutodiffClientImpl, } pub type NodeRef = Arc; #[derive(new, Debug, Clone, PartialEq, Eq)] pub struct Parent { pub id: NodeId, } impl Node { /// Returns the [node](Node) only if gradients are required. pub fn clone_if_require_grad(self: &Arc) -> Option { match self.requirement.is_none() { true => None, false => Some(self.clone()), } } } /// Unique identifier generated for each node. #[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)] pub struct NodeId { /// The integer representation of the id pub value: u64, } impl core::fmt::Display for NodeId { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!("NodeId({})", self.value)) } } impl NodeId { /// Create a unique [node id](NodeId). pub fn new() -> Self { static COUNTER: AtomicU64 = AtomicU64::new(0); let value = COUNTER.fetch_add(1, Ordering::Relaxed); if value == u64::MAX { panic!("NodeId overflowed"); } Self { value } } } impl Default for NodeId { fn default() -> Self { Self::new() } } ================================================ FILE: crates/burn-autodiff/src/graph/requirement.rs ================================================ use super::NodeRef; /// Requirement for each tensor in the graph. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Requirement { /// Operations that require gradients. Grad, /// Operations that require gradients only for backprop. GradInBackward, /// Operations that don't need gradients, therefore not to be included in the graph. None, } impl Requirement { /// Returns true if gradients are not required. pub fn is_none(&self) -> bool { matches!(self, Self::None) } /// Returns the right requirement from a list of nodes. pub fn from_nodes(nodes: &[NodeRef]) -> Self { if nodes.len() == 1 { return nodes[0].requirement.infer(&Requirement::None); } nodes .iter() .map(|node| node.requirement) .reduce(|acc, requirement| requirement.infer(&acc)) .unwrap_or(Requirement::None) } fn infer(&self, other: &Self) -> Self { match self.is_none() && other.is_none() { true => Self::None, false => Self::GradInBackward, } } } ================================================ FILE: crates/burn-autodiff/src/graph/traversal.rs ================================================ use super::{Step, StepBoxed}; use crate::{ NodeId, collections::{HashMap, HashSet}, graph::Parent, }; use alloc::vec::Vec; /// Breadth for search algorithm. pub struct BreadthFirstSearch; pub trait TraversalItem { fn id(&self) -> NodeId; fn parents(&self) -> &[Parent]; fn parent_nodes(&self) -> Vec { self.parents().iter().map(|p| p.id).collect() } } impl BreadthFirstSearch { /// Traverse the graph of backward steps from a root node. pub fn traverse( &self, root_id: NodeId, root_step: I, steps: &mut HashMap, mut callback: F, ) where F: FnMut(NodeId, I), I: TraversalItem, { let mut visited = HashSet::new(); let mut parents = Vec::new(); visited.insert(root_id); parents.append(&mut root_step.parent_nodes()); callback(root_id, root_step); while let Some(id) = parents.pop() { let step = match steps.remove(&id) { Some(step) => step, None => continue, }; let step_node = step.id(); let step_parents = step.parent_nodes(); if visited.contains(&step_node) { continue; } visited.insert(step_node); for id in step_parents.iter() { if !visited.contains(id) { parents.push(*id); } } callback(step_node, step); } } } impl TraversalItem for StepBoxed { fn id(&self) -> NodeId { Step::node(self.as_ref()) } fn parents(&self) -> &[Parent] { Step::parents(self.as_ref()) } } ================================================ FILE: crates/burn-autodiff/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! # Burn Autodiff //! //! This autodiff library is a part of the Burn project. It is a standalone crate //! that can be used to perform automatic differentiation on tensors. It is //! designed to be used with the Burn Tensor crate, but it can be used with any //! tensor library that implements the `Backend` trait. #[macro_use] extern crate derive_new; extern crate alloc; /// Checkpoint module. pub mod checkpoint; /// Gradients module. pub mod grads; /// Operation module. pub mod ops; pub(crate) mod graph; // Exported for backend extension pub use graph::NodeId; pub(crate) mod tensor; pub(crate) mod utils; mod backend; pub(crate) mod runtime; pub use backend::*; /// A facade around for HashMap and HashSet. /// This avoids elaborate import wrangling having to happen in every module. mod collections { #[cfg(not(feature = "std"))] pub use hashbrown::{HashMap, HashSet}; #[cfg(feature = "std")] pub use std::collections::{HashMap, HashSet}; } ================================================ FILE: crates/burn-autodiff/src/ops/activation.rs ================================================ use core::marker::PhantomData; use crate::{ Autodiff, checkpoint::{ base::Checkpointer, retro_forward::RetroForward, state::BackwardStates, strategy::CheckpointStrategy, }, grads::Gradients, graph::NodeId, ops::{Backward, Ops, OpsKind, unary}, retro_unary, }; use burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor}; impl ActivationOps> for Autodiff { fn gelu(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Gelu; retro_unary!(RetroGelu, B::gelu); impl Backward for Gelu { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { B::gelu_backward(input, grad) }); } } match Gelu .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroGelu::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::gelu(tensor.primitive.clone())) } OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)), } } fn relu(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Relu; retro_unary!(RetroRelu, B::relu); impl Backward for Relu { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let state = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { B::relu_backward(state, grad) }); } } match Relu .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroRelu::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::relu(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)), } } fn sigmoid(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sigmoid; retro_unary!(RetroSigmoid, B::sigmoid); impl Backward for Sigmoid { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); let output = B::sigmoid(input); unary::(ops.parents, ops.node, grads, |grad| { B::sigmoid_backward(output, grad) }); } } match Sigmoid .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSigmoid::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::sigmoid(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)), } } fn log_sigmoid(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct LogSigmoid; retro_unary!(RetroLogSigmoid, B::log_sigmoid); impl Backward for LogSigmoid { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { B::log_sigmoid_backward(input, grad) }); } } match LogSigmoid .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroLogSigmoid::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::log_sigmoid(tensor.primitive.clone())) } OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)), } } } ================================================ FILE: crates/burn-autodiff/src/ops/backward.rs ================================================ use super::{Ops, OpsPrep}; use crate::{ checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy}, grads::Gradients, graph::{ComputingProperty, NodeRef, Requirement}, utils::duplicate, }; use burn_backend::Backend; /// Trait for all operations. /// /// # Notes /// /// Concrete types implementing this trait should not have any state. /// If a state is necessary during the backward pass, /// they should be declared with the associated type 'State'. pub trait Backward: Send + core::fmt::Debug where Self: Sized + 'static, B: Backend, { /// Associated type to compute the backward pass. type State: Clone + Send + core::fmt::Debug + 'static; /// The backward pass. fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ); /// Prepare the backward ops. fn prepare( self, nodes: [NodeRef; N], ) -> OpsPrep { let requirement = Requirement::from_nodes(&nodes); OpsPrep::new( nodes, requirement, self, ComputingProperty::Ambiguous, // If not specified we start with ambiguous CheckpointerBuilder::default(), ) } } /// Execute a binary operation during the backward step. pub fn binary( parents: [Option; 2], node: NodeRef, grads: &mut Gradients, func_lhs: FLhs, func_rhs: FRhs, ) where B: Backend, FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive, FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive, { let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::(&node))); let [node_lhs, node_rhs] = parents; if let Some(node) = node_lhs { let grad = func_lhs(grad_4lhs.unwrap()); grads.register::(node.id, grad) } if let Some(node) = node_rhs { let grad = func_rhs(grad_4rhs.unwrap()); grads.register::(node.id, grad) } } /// Execute a unary operation during the backward step. pub fn unary(parents: [Option; 1], node: NodeRef, grads: &mut Gradients, func: F) where B: Backend, F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive, { let [parent_node] = parents; let grad = grads.consume::(&node); if let Some(node) = parent_node { let grad = func(grad); grads.register::(node.id, grad) } } ================================================ FILE: crates/burn-autodiff/src/ops/base.rs ================================================ use super::Backward; use crate::{ checkpoint::{ base::Checkpointer, builder::{ActionType, CheckpointerBuilder}, retro_forward::RetroForward, strategy::CheckpointStrategy, }, grads::Gradients, graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step}, tensor::AutodiffTensor, }; use alloc::boxed::Box; use burn_backend::{Backend, TensorMetadata, tensor::FloatTensor}; use burn_std::Shape; use core::marker::PhantomData; /// Operation in preparation. /// /// Each mode has its own set of functions to minimize cloning for unused backward states. #[derive(new)] pub struct OpsPrep { nodes: [NodeRef; N], requirement: Requirement, backward: Backward, compute_property: ComputingProperty, checkpointer_builder: CheckpointerBuilder, checkpoint_strategy: PhantomData, phantom_backend: PhantomData, phantom_state: PhantomData, marker: PhantomData, } /// Operation is initialized pub struct Init; /// Operation has been tagged as memory bound pub struct MemoryBound; /// Memory bound operation has received its RetroForward pub struct MemoryBoundRetroForward; /// Operation's compute property is fixed pub struct ComputePropertyDone; /// Tracked operation tag. pub struct Tracked; /// Untracked operation tag. pub struct UnTracked; impl OpsPrep where B: Backend, BO: Backward, { /// Indicates that the operation is compute bound, meaning its computation /// is heavy and should not be recomputed pub fn compute_bound(self) -> OpsPrep { OpsPrep::new( self.nodes, self.requirement, self.backward, ComputingProperty::ComputeBound, self.checkpointer_builder, ) } /// Indicates that the operation is memory bound, meaning its computation /// is light and can be recomputed pub fn memory_bound(self) -> OpsPrep { OpsPrep::new( self.nodes, self.requirement, self.backward, self.compute_property, self.checkpointer_builder, ) } } impl OpsPrep where B: Backend, BO: Backward, C: CheckpointStrategy, { /// Registers the retro forward, if needed pub fn retro_forward( self, retro_forward: R, ) -> OpsPrep { OpsPrep::new( self.nodes, self.requirement, self.backward, C::compute_property(retro_forward), self.checkpointer_builder, ) } } impl OpsPrep where B: Backend, BO: Backward, C: CheckpointStrategy, { /// Checkpoints the parents, if needed pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep where B2: Backend, A: IntoIterator>, { let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder) { Ok(..) => self.compute_property, Err(..) => ComputingProperty::ComputeBound, }; OpsPrep::new( self.nodes, self.requirement, self.backward, compute_property, self.checkpointer_builder, ) } } impl OpsPrep where B: Backend, BO: Backward, { /// Prepare a stateless operation. pub fn stateless(self, output: FloatTensor) -> AutodiffTensor { match self.stateful() { OpsKind::Tracked(prep) => prep.finish((), output), OpsKind::UnTracked(prep) => prep.finish(output), } } } impl OpsPrep where B: Backend, S: Clone + Send + core::fmt::Debug + 'static, BO: Backward, { /// Prepare an operation that requires a state during the backward pass. pub fn stateful(self) -> OpsKind { match self.requirement.is_none() { false => OpsKind::Tracked(OpsPrep::new( self.nodes, self.requirement, self.backward, self.compute_property, self.checkpointer_builder, )), true => OpsKind::UnTracked(OpsPrep::new( self.nodes, self.requirement, self.backward, self.compute_property, self.checkpointer_builder, )), } } } impl OpsPrep where B: Backend, S: Clone + Send + core::fmt::Debug + 'static, BO: Backward, { /// Finish the preparation of an untracked operation and returns the output tensor. pub fn finish(self, output: FloatTensor) -> AutodiffTensor { let output = AutodiffTensor::from_parents( output, &self.nodes, self.requirement, self.compute_property, ); let parents = self.nodes.map(|node| node.clone_if_require_grad()); let ops = Ops::new(parents, output.node.clone(), ()); // We register the ops in the graph even if untracked, otherwise memory bound operations // that have an untracked parent would not be able to retrieve it output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder) } } impl OpsPrep where B: Backend, S: Clone + Send + core::fmt::Debug + 'static, BO: Backward, { /// Finish the preparation of a tracked operation and returns the output tensor. pub fn finish(self, state: S, output: FloatTensor) -> AutodiffTensor { let output = AutodiffTensor::from_parents( output, &self.nodes, self.requirement, self.compute_property, ); let parents = self.nodes.map(|node| node.clone_if_require_grad()); let ops = Ops::new(parents, output.node.clone(), state); output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder) } /// Checkpoints the tensor pub fn checkpoint(&mut self, tensor: &AutodiffTensor) -> NodeId { self.checkpointer_builder .checkpoint(tensor, ActionType::Explicit); tensor.node.id } } /// Enum used before finishing tracked and untracked operations. pub enum OpsKind { /// Tracked operation preparation. Tracked(OpsPrep), /// Untracked operation preparation. UnTracked(OpsPrep), } /// Operation containing its parent nodes, its own node and the backward step state. #[derive(new, Debug)] pub struct Ops { /// Parents nodes. pub parents: [Option; N], /// The node. pub node: NodeRef, /// The state. pub state: S, } /// Operation implementing backward [step](Step) with type erasing. #[derive(new, Debug)] struct OpsStep where B: Backend, T: Backward, SB: Clone + Send + core::fmt::Debug + 'static, { ops: Ops, backward: T, phantom: PhantomData, } impl Step for OpsStep where B: Backend, T: Backward, SB: Clone + Send + core::fmt::Debug + 'static, { fn step(self: Box, grads: &mut Gradients, checkpointer: &mut Checkpointer) { self.backward.backward(self.ops, grads, checkpointer); } fn node(&self) -> NodeId { self.ops.node.id } fn parents(&self) -> &[Parent] { &self.ops.node.parents } fn depth(&self) -> usize { self.ops.node.order } } #[derive(new, Debug)] struct UntrackedOpsStep { ops: Ops<(), N>, } impl Step for UntrackedOpsStep { fn step(self: Box, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) { // Nothing to do } fn node(&self) -> NodeId { self.ops.node.id } fn parents(&self) -> &[Parent] { &self.ops.node.parents } fn depth(&self) -> usize { self.ops.node.order } } /// Make sure the grad tensor has the given shape. /// /// If broadcasting happened during the forward pass, the gradients will be sum along the /// broadcasted dimension. pub fn broadcast_shape(mut grad: FloatTensor, shape: &Shape) -> FloatTensor { let shape_grad = grad.shape(); let ndims = shape_grad.num_dims(); for i in 0..ndims { if shape_grad[i] != shape[i] { if shape[i] != 1 { panic!( "Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}", shape, shape_grad, "Expected the shape of the next grad to be 1." ); } grad = B::float_sum_dim(grad, i); } } grad } ================================================ FILE: crates/burn-autodiff/src/ops/bool_tensor.rs ================================================ use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor}; use alloc::vec::Vec; use burn_backend::{ Backend, ExecutionError, Scalar, TensorData, ops::BoolTensorOps, tensor::{BoolTensor, Device, IntTensor}, }; use burn_std::Shape; impl BoolTensorOps for Autodiff { fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { B::bool_from_data(data, device) } async fn bool_into_data(tensor: BoolTensor) -> Result { B::bool_into_data(tensor).await } fn bool_into_int(tensor: BoolTensor) -> IntTensor { B::bool_into_int(tensor) } fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { B::bool_to_device(tensor, device) } fn bool_device(tensor: &BoolTensor) -> Device { B::bool_device(tensor) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { B::bool_reshape(tensor, shape) } fn bool_slice(tensor: BoolTensor, slices: &[burn_std::Slice]) -> BoolTensor { B::bool_slice(tensor, slices) } fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { B::bool_empty(shape, device) } fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor { B::bool_zeros(shape, device) } fn bool_ones(shape: Shape, device: &Device) -> BoolTensor { B::bool_ones(shape, device) } fn bool_slice_assign( tensor: BoolTensor, slices: &[burn_std::Slice], value: BoolTensor, ) -> BoolTensor { B::bool_slice_assign(tensor, slices, value) } fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { B::bool_cat(tensors, dim) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { B::bool_equal(lhs, rhs) } fn bool_not(tensor: BoolTensor) -> BoolTensor { B::bool_not(tensor) } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { B::bool_and(lhs, rhs) } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { B::bool_or(lhs, rhs) } fn bool_xor(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { B::bool_xor(lhs, rhs) } fn bool_into_float(tensor: BoolTensor) -> as Backend>::FloatTensorPrimitive { AutodiffTensor::new(B::bool_into_float(tensor)) } fn bool_swap_dims( tensor: as Backend>::BoolTensorPrimitive, dim1: usize, dim2: usize, ) -> as Backend>::BoolTensorPrimitive { B::bool_swap_dims(tensor, dim1, dim2) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { B::bool_permute(tensor, axes) } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { B::bool_flip(tensor, axes) } async fn bool_argwhere(tensor: BoolTensor) -> IntTensor { B::bool_argwhere(tensor).await } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { B::bool_expand(tensor, shape) } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { B::bool_repeat_dim(tensor, dim, times) } fn bool_unfold( tensor: BoolTensor, dim: usize, size: usize, step: usize, ) -> BoolTensor { B::bool_unfold(tensor, dim, size, step) } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, source: BoolTensor, ) -> BoolTensor { B::bool_mask_where(tensor, mask, source) } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { B::bool_mask_fill(tensor, mask, value) } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { B::bool_gather(dim, tensor, indices) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { B::bool_scatter_or(dim, tensor, indices, value) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { B::bool_equal_elem(lhs, rhs) } } ================================================ FILE: crates/burn-autodiff/src/ops/int_tensor.rs ================================================ use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor}; use alloc::vec::Vec; use burn_backend::{ Backend, Distribution, ExecutionError, Scalar, TensorData, ops::IntTensorOps, tensor::{BoolTensor, Device, IntTensor}, }; use burn_std::{IntDType, Shape}; impl IntTensorOps for Autodiff { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { B::int_from_data(data, device) } async fn int_into_data(tensor: IntTensor) -> Result { B::int_into_data(tensor).await } fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { B::int_to_device(tensor, device) } fn int_device(tensor: &IntTensor) -> Device { B::int_device(tensor) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { B::int_reshape(tensor, shape) } fn int_slice(tensor: IntTensor, slices: &[burn_std::Slice]) -> IntTensor { B::int_slice(tensor, slices) } fn int_empty( shape: Shape, device: & as Backend>::Device, dtype: IntDType, ) -> IntTensor { B::int_empty(shape, device, dtype) } fn int_slice_assign( tensor: IntTensor, slices: &[burn_std::Slice], value: IntTensor, ) -> IntTensor { B::int_slice_assign(tensor, slices, value) } fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { B::int_cat(tensors, dim) } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { B::int_equal(lhs, rhs) } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { B::int_equal_elem(lhs, rhs) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_add(lhs, rhs) } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::int_add_scalar(lhs, rhs) } fn int_clamp_min(tensor: IntTensor, min: Scalar) -> IntTensor { B::int_clamp_min(tensor, min) } fn int_clamp_max(tensor: IntTensor, max: Scalar) -> IntTensor { B::int_clamp_max(tensor, max) } fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { B::int_clamp(tensor, min, max) } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_sub(lhs, rhs) } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::int_sub_scalar(lhs, rhs) } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_mul(lhs, rhs) } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::int_mul_scalar(lhs, rhs) } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_div(lhs, rhs) } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::int_div_scalar(lhs, rhs) } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_remainder(lhs, rhs) } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::int_remainder_scalar(lhs, rhs) } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::int_matmul(lhs, rhs) } fn int_neg(tensor: IntTensor) -> IntTensor { B::int_neg(tensor) } fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { B::int_zeros(shape, device, dtype) } fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { B::int_ones(shape, device, dtype) } fn int_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: IntDType, ) -> IntTensor { B::int_full(shape, fill_value, device, dtype) } fn int_sum(tensor: IntTensor) -> IntTensor { B::int_sum(tensor) } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { B::int_sum_dim(tensor, dim) } fn int_mean(tensor: IntTensor) -> IntTensor { B::int_mean(tensor) } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { B::int_mean_dim(tensor, dim) } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { B::int_cumsum(tensor, dim) } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { B::int_cumprod(tensor, dim) } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { B::int_cummin(tensor, dim) } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { B::int_cummax(tensor, dim) } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { B::int_repeat_dim(tensor, dim, times) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { B::int_greater(lhs, rhs) } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { B::int_greater_elem(lhs, rhs) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { B::int_greater_equal(lhs, rhs) } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { B::int_greater_equal_elem(lhs, rhs) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { B::int_lower(lhs, rhs) } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { B::int_lower_elem(lhs, rhs) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { B::int_lower_equal(lhs, rhs) } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { B::int_lower_equal_elem(lhs, rhs) } fn int_gather(dim: usize, tensor: IntTensor, indices: IntTensor) -> IntTensor { B::int_gather(dim, tensor, indices) } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { B::int_scatter_add(dim, tensor, indices, value) } fn int_select(tensor: IntTensor, dim: usize, indices: IntTensor) -> IntTensor { B::int_select(tensor, dim, indices) } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { B::int_select_add(tensor, dim, indices, value) } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> as Backend>::IntTensorPrimitive { B::int_mask_where(tensor, mask, value) } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> as Backend>::IntTensorPrimitive { B::int_mask_fill(tensor, mask, value) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { B::int_argmax(tensor, dim) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { B::int_argmin(tensor, dim) } fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { B::int_max(tensor) } fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive { B::int_max_dim(tensor, dim) } fn int_max_dim_with_indices( tensor: B::IntTensorPrimitive, dim: usize, ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { B::int_max_dim_with_indices(tensor, dim) } fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { B::int_min(tensor) } fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive { B::int_min_dim(tensor, dim) } fn int_min_dim_with_indices( tensor: B::IntTensorPrimitive, dim: usize, ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { B::int_min_dim_with_indices(tensor, dim) } fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive { B::int_abs(tensor) } fn int_into_float( tensor: as Backend>::IntTensorPrimitive, ) -> as Backend>::FloatTensorPrimitive { AutodiffTensor::new(B::int_into_float(tensor)) } fn int_swap_dims( tensor: as Backend>::IntTensorPrimitive, dim1: usize, dim2: usize, ) -> as Backend>::IntTensorPrimitive { B::int_swap_dims(tensor, dim1, dim2) } fn int_random( shape: Shape, distribution: Distribution, device: &Device, ) -> IntTensor { B::int_random(shape, distribution, device) } fn int_arange(range: core::ops::Range, device: &Device) -> IntTensor { B::int_arange(range, device) } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { B::int_permute(tensor, axes) } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { B::int_flip(tensor, axes) } fn int_sign(tensor: IntTensor) -> IntTensor { B::int_sign(tensor) } fn int_prod(tensor: IntTensor) -> IntTensor { B::int_prod(tensor) } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { B::int_prod_dim(tensor, dim) } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { B::int_expand(tensor, shape) } fn int_sort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { B::int_sort(tensor, dim, descending) } fn int_sort_with_indices( tensor: IntTensor, dim: usize, descending: bool, ) -> (IntTensor, IntTensor) { B::int_sort_with_indices(tensor, dim, descending) } fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { B::int_argsort(tensor, dim, descending) } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::bitwise_and(lhs, rhs) } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::bitwise_and_scalar(lhs, rhs) } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::bitwise_or(lhs, rhs) } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::bitwise_or_scalar(lhs, rhs) } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::bitwise_xor(lhs, rhs) } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::bitwise_xor_scalar(lhs, rhs) } fn bitwise_not(tensor: IntTensor) -> IntTensor { B::bitwise_not(tensor) } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::bitwise_left_shift(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::bitwise_left_shift_scalar(lhs, rhs) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::bitwise_right_shift(lhs, rhs) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::bitwise_right_shift_scalar(lhs, rhs) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { B::int_cast(tensor, dtype) } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { B::int_unfold(tensor, dim, size, step) } } ================================================ FILE: crates/burn-autodiff/src/ops/maxmin.rs ================================================ use super::{Backward, Ops, unary}; use crate::{checkpoint::base::Checkpointer, grads::Gradients}; use burn_backend::{Backend, TensorMetadata}; use burn_std::Shape; #[derive(Debug)] pub(crate) struct MaxMinDim; impl Backward for MaxMinDim { type State = (B::IntTensorPrimitive, Shape, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let (indices, shape, dim) = ops.state; let device = B::float_device(&grad); let dtype = grad.dtype(); let zeros = B::float_zeros(shape, &device, dtype.into()); B::float_scatter_add(dim, zeros, indices, grad) }); } } ================================================ FILE: crates/burn-autodiff/src/ops/mod.rs ================================================ mod activation; mod backward; mod base; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; pub(crate) mod maxmin; pub(crate) mod sort; pub use backward::*; pub use base::*; ================================================ FILE: crates/burn-autodiff/src/ops/module.rs ================================================ use crate::Autodiff; use crate::checkpoint::base::Checkpointer; use crate::checkpoint::strategy::CheckpointStrategy; use crate::grads::Gradients; use crate::graph::NodeId; use crate::ops::{Backward, Ops, unary}; use crate::tensor::AutodiffTensor; use burn_backend::Backend; use burn_backend::ops::attention::attention_fallback; use burn_backend::ops::*; use burn_backend::tensor::{FloatTensor, IntTensor}; use super::OpsKind; impl ModuleOps> for Autodiff { fn embedding(weights: AutodiffTensor, indices: IntTensor) -> AutodiffTensor { #[derive(Debug)] struct Embedding; impl Backward for Embedding { type State = (B::FloatTensorPrimitive, IntTensor); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (weights, indices) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::embedding_backward(weights, grad, indices) }); } } match Embedding .prepare::([weights.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (weights.primitive.clone(), indices.clone()), B::embedding(weights.primitive, indices), ), OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)), } } fn embedding_backward( _weights: AutodiffTensor, _output: AutodiffTensor, _indices: IntTensor, ) -> AutodiffTensor { panic!("Can't differentiate embedding backward."); } fn conv1d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvOptions<1>, ) -> AutodiffTensor { #[derive(Debug)] struct Conv1DWithBias; #[derive(Debug)] struct Conv1DNoBias; impl Backward for Conv1DWithBias { type State = (NodeId, NodeId, NodeId, ConvOptions<1>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv1d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv1d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for Conv1DNoBias { type State = (NodeId, NodeId, ConvOptions<1>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv1d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv1d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match Conv1DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv1d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match Conv1DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv1d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => { prep.finish(B::conv1d(x.primitive, weight.primitive, None, options)) } }, } } fn conv_transpose1d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> AutodiffTensor { #[derive(Debug)] struct ConvTranspose1DWithBias; #[derive(Debug)] struct ConvTranspose1DNoBias; impl Backward for ConvTranspose1DWithBias { type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<1>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv_transpose1d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose1d_weight_backward( x.clone(), weight, grad.clone(), options, ); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv_transpose1d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for ConvTranspose1DNoBias { type State = (NodeId, NodeId, ConvTransposeOptions<1>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv_transpose1d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match ConvTranspose1DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv_transpose1d( x.primitive, weight.primitive, Some(bias.primitive), options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match ConvTranspose1DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv_transpose1d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d( x.primitive, weight.primitive, None, options, )), }, } } fn conv2d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvOptions<2>, ) -> AutodiffTensor { #[derive(Debug)] struct Conv2DWithBias; #[derive(Debug)] struct Conv2DNoBias; impl Backward for Conv2DWithBias { type State = (NodeId, NodeId, NodeId, ConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv2d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv2d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for Conv2DNoBias { type State = (NodeId, NodeId, ConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv2d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv2d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match Conv2DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv2d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match Conv2DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv2d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => { prep.finish(B::conv2d(x.primitive, weight.primitive, None, options)) } }, } } fn deform_conv2d( x: AutodiffTensor, offset: AutodiffTensor, weight: AutodiffTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> AutodiffTensor { #[derive(Debug)] struct DeformConv2DWithMaskWithBias; #[derive(Debug)] struct DeformConv2DWithMaskNoBias; #[derive(Debug)] struct DeformConv2DNoMaskWithBias; #[derive(Debug)] struct DeformConv2DNoMaskNoBias; impl Backward for DeformConv2DWithMaskWithBias { type State = (NodeId, NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, offset_state, weight_state, mask_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output(x_state); let offset = checkpointer.retrieve_node_output(offset_state); let weight = checkpointer.retrieve_node_output(weight_state); let mask = Some(checkpointer.retrieve_node_output(mask_state)); let bias = Some(checkpointer.retrieve_node_output(bias_state)); let backward = B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options); if let Some(node) = node_x { grads.register::(node.id, backward.x_grad) } if let Some(node) = node_offset { grads.register::(node.id, backward.offset_grad) } if let Some(node) = node_weight { grads.register::(node.id, backward.weight_grad) } if let Some(node) = node_mask { grads.register::(node.id, backward.mask_grad.unwrap()) } if let Some(node) = node_bias { grads.register::(node.id, backward.bias_grad.unwrap()) } } } impl Backward for DeformConv2DWithMaskNoBias { type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_offset, node_weight, node_mask] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, offset_state, weight_state, mask_state, options) = ops.state; let x = checkpointer.retrieve_node_output(x_state); let offset = checkpointer.retrieve_node_output(offset_state); let weight = checkpointer.retrieve_node_output(weight_state); let mask = Some(checkpointer.retrieve_node_output(mask_state)); let backward = B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options); if let Some(node) = node_x { grads.register::(node.id, backward.x_grad) } if let Some(node) = node_offset { grads.register::(node.id, backward.offset_grad) } if let Some(node) = node_weight { grads.register::(node.id, backward.weight_grad) } if let Some(node) = node_mask { grads.register::(node.id, backward.mask_grad.unwrap()) } } } impl Backward for DeformConv2DNoMaskWithBias { type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_offset, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, offset_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output(x_state); let offset = checkpointer.retrieve_node_output(offset_state); let weight = checkpointer.retrieve_node_output(weight_state); let bias = Some(checkpointer.retrieve_node_output(bias_state)); let backward = B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options); if let Some(node) = node_x { grads.register::(node.id, backward.x_grad) } if let Some(node) = node_offset { grads.register::(node.id, backward.offset_grad) } if let Some(node) = node_weight { grads.register::(node.id, backward.weight_grad) } if let Some(node) = node_bias { grads.register::(node.id, backward.bias_grad.unwrap()) } } } impl Backward for DeformConv2DNoMaskNoBias { type State = (NodeId, NodeId, NodeId, DeformConvOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_offset, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, offset_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output(x_state); let offset = checkpointer.retrieve_node_output(offset_state); let weight = checkpointer.retrieve_node_output(weight_state); let backward = B::deform_conv2d_backward(x, offset, weight, None, None, grad, options); if let Some(node) = node_x { grads.register::(node.id, backward.x_grad) } if let Some(node) = node_offset { grads.register::(node.id, backward.offset_grad) } if let Some(node) = node_weight { grads.register::(node.id, backward.weight_grad) } } } match (mask, bias) { (Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias .prepare::([ x.node.clone(), offset.node.clone(), weight.node.clone(), mask.node.clone(), bias.node.clone(), ]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let offset_state = prep.checkpoint(&offset); let weight_state = prep.checkpoint(&weight); let mask_state = prep.checkpoint(&mask); let bias_state = prep.checkpoint(&bias); prep.finish( ( x_state, offset_state, weight_state, mask_state, bias_state, options.clone(), ), B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, Some(mask.primitive), Some(bias.primitive), options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, Some(mask.primitive), Some(bias.primitive), options, )), }, (Some(mask), None) => match DeformConv2DWithMaskNoBias .prepare::([ x.node.clone(), offset.node.clone(), weight.node.clone(), mask.node.clone(), ]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let offset_state = prep.checkpoint(&offset); let weight_state = prep.checkpoint(&weight); let mask_state = prep.checkpoint(&mask); prep.finish( ( x_state, offset_state, weight_state, mask_state, options.clone(), ), B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, Some(mask.primitive), None, options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, Some(mask.primitive), None, options, )), }, (None, Some(bias)) => match DeformConv2DNoMaskWithBias .prepare::([ x.node.clone(), offset.node.clone(), weight.node.clone(), bias.node.clone(), ]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let offset_state = prep.checkpoint(&offset); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( ( x_state, offset_state, weight_state, bias_state, options.clone(), ), B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, None, Some(bias.primitive), options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, None, Some(bias.primitive), options, )), }, (None, None) => match DeformConv2DNoMaskNoBias .prepare::([x.node.clone(), offset.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let offset_state = prep.checkpoint(&offset); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, offset_state, weight_state, options.clone()), B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, None, None, options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( x.primitive, offset.primitive, weight.primitive, None, None, options, )), }, } } fn deform_conv2d_backward( _x: AutodiffTensor, _offset: AutodiffTensor, _weight: AutodiffTensor, _mask: Option>, _bias: Option>, _output_grad: AutodiffTensor, _options: DeformConvOptions<2>, ) -> DeformConv2dBackward { panic!("Can't differentiate deform conv 2d backward."); } fn conv_transpose2d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> AutodiffTensor { #[derive(Debug)] struct ConvTranspose2DWithBias; #[derive(Debug)] struct ConvTranspose2DNoBias; impl Backward for ConvTranspose2DWithBias { type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv_transpose2d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose2d_weight_backward( x.clone(), weight, grad.clone(), options, ); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv_transpose2d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for ConvTranspose2DNoBias { type State = (NodeId, NodeId, ConvTransposeOptions<2>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv_transpose2d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match ConvTranspose2DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv_transpose2d( x.primitive, weight.primitive, Some(bias.primitive), options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match ConvTranspose2DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv_transpose2d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d( x.primitive, weight.primitive, None, options, )), }, } } fn conv3d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvOptions<3>, ) -> AutodiffTensor { #[derive(Debug)] struct Conv3DWithBias; #[derive(Debug)] struct Conv3DNoBias; impl Backward for Conv3DWithBias { type State = (NodeId, NodeId, NodeId, ConvOptions<3>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv3d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv3d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for Conv3DNoBias { type State = (NodeId, NodeId, ConvOptions<3>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv3d_x_backward( x.clone(), weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv3d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match Conv3DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv3d(x.primitive, weight.primitive, Some(bias.primitive), options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv3d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match Conv3DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv3d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => { prep.finish(B::conv3d(x.primitive, weight.primitive, None, options)) } }, } } fn conv_transpose3d( x: AutodiffTensor, weight: AutodiffTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> AutodiffTensor { #[derive(Debug)] struct ConvTranspose3DWithBias; #[derive(Debug)] struct ConvTranspose3DNoBias; impl Backward for ConvTranspose3DWithBias { type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<3>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight, node_bias] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, bias_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); let bias = checkpointer.retrieve_node_output::(bias_state); if let Some(node) = node_x { let grad = B::conv_transpose3d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose3d_weight_backward( x.clone(), weight, grad.clone(), options, ); grads.register::(node.id, grad) } if let Some(node) = node_bias { let grad = B::conv_transpose3d_bias_backward(x, bias, grad); grads.register::(node.id, grad) } } } impl Backward for ConvTranspose3DNoBias { type State = (NodeId, NodeId, ConvTransposeOptions<3>); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_x, node_weight] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, weight_state, options) = ops.state; let x = checkpointer.retrieve_node_output::(x_state); let weight = checkpointer.retrieve_node_output::(weight_state); if let Some(node) = node_x { let grad = B::conv_transpose3d_x_backward( weight.clone(), grad.clone(), options.clone(), ); grads.register::(node.id, grad) } if let Some(node) = node_weight { let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options); grads.register::(node.id, grad) } } } match bias { Some(bias) => match ConvTranspose3DWithBias .prepare::([x.node.clone(), weight.node.clone(), bias.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); let bias_state = prep.checkpoint(&bias); prep.finish( (x_state, weight_state, bias_state, options.clone()), B::conv_transpose3d( x.primitive, weight.primitive, Some(bias.primitive), options, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d( x.primitive, weight.primitive, Some(bias.primitive), options, )), }, None => match ConvTranspose3DNoBias .prepare::([x.node.clone(), weight.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let weight_state = prep.checkpoint(&weight); prep.finish( (x_state, weight_state, options.clone()), B::conv_transpose3d(x.primitive, weight.primitive, None, options), ) } OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d( x.primitive, weight.primitive, None, options, )), }, } } // TODO: Support a custom unfold4d operation by overriding the default implementation. // // We don't override it now because the fold operation isn't available for the backward pass. // This implies that when autodiff is enabled, custom unfold operations defined by backends // won't be used. Instead, the conv2d operation with custom weights matrix will be used. // Therefore, the conv2d backward pass will be used for the unfold4d backward pass. // // fn unfold4d( // x:AutodiffTensor, // kernel_size: [usize; 2], // options: UnfoldOptions, // ) -> AutodiffTensor { // todo!() // } fn avg_pool1d( x: AutodiffTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> AutodiffTensor { #[derive(Debug)] struct AvgPool1D; impl Backward for AvgPool1D { type State = (NodeId, usize, usize, usize, bool, bool); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) = ops.state; let x = checkpointer.retrieve_node_output(x_state); if let Some(node) = node_parent { let grad = B::avg_pool1d_backward( x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode, ); grads.register::(node.id, grad); } } } match AvgPool1D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); prep.finish( ( x_state, kernel_size, stride, padding, count_include_pad, ceil_mode, ), B::avg_pool1d( x.primitive.clone(), kernel_size, stride, padding, count_include_pad, ceil_mode, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d( x.primitive, kernel_size, stride, padding, count_include_pad, ceil_mode, )), } } fn avg_pool2d( x: AutodiffTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> AutodiffTensor { #[derive(Debug)] struct AvgPool2D; impl Backward for AvgPool2D { type State = (NodeId, [usize; 2], [usize; 2], [usize; 2], bool, bool); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) = ops.state; let x = checkpointer.retrieve_node_output(x_state); if let Some(node) = node_parent { let grad = B::avg_pool2d_backward( x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode, ); grads.register::(node.id, grad); } } } match AvgPool2D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); prep.finish( ( x_state, kernel_size, stride, padding, count_include_pad, ceil_mode, ), B::avg_pool2d( x.primitive.clone(), kernel_size, stride, padding, count_include_pad, ceil_mode, ), ) } OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d( x.primitive, kernel_size, stride, padding, count_include_pad, ceil_mode, )), } } fn avg_pool2d_backward( _x: AutodiffTensor, _grad: AutodiffTensor, _kernel_size: [usize; 2], _stride: [usize; 2], _padding: [usize; 2], _count_include_pad: bool, _ceil_mode: bool, ) -> AutodiffTensor { panic!("Can't differentiate avg pool 2d backward."); } fn max_pool1d( x: AutodiffTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> AutodiffTensor { match MaxPool1D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let output = B::max_pool1d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); prep.finish( ( x_state, output.indices, kernel_size, stride, padding, dilation, ceil_mode, ), output.output, ) } OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, )), } } fn max_pool1d_with_indices( x: AutodiffTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { match MaxPool1D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let output = B::max_pool1d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); let output_tensor = prep.finish( ( x_state, output.indices.clone(), kernel_size, stride, padding, dilation, ceil_mode, ), output.output, ); MaxPool1dWithIndices::new(output_tensor, output.indices) } OpsKind::UnTracked(prep) => { let output = B::max_pool1d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); let output_tensor = prep.finish(output.output); MaxPool1dWithIndices::new(output_tensor, output.indices) } } } fn max_pool1d_with_indices_backward( x: AutodiffTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: AutodiffTensor, indices: IntTensor, ) -> MaxPool1dBackward { let output = B::max_pool1d_with_indices_backward( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, output_grad.primitive, indices, ); MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad)) } fn max_pool2d( x: AutodiffTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> AutodiffTensor { match MaxPool2D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let output = B::max_pool2d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); prep.finish( ( x_state, output.indices, kernel_size, stride, padding, dilation, ceil_mode, ), output.output, ) } OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, )), } } fn max_pool2d_with_indices( x: AutodiffTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { match MaxPool2D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let output = B::max_pool2d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); let output_tensor = prep.finish( ( x_state, output.indices.clone(), kernel_size, stride, padding, dilation, ceil_mode, ), output.output, ); MaxPool2dWithIndices::new(output_tensor, output.indices) } OpsKind::UnTracked(prep) => { let output = B::max_pool2d_with_indices( x.primitive, kernel_size, stride, padding, dilation, ceil_mode, ); let output_tensor = prep.finish(output.output); MaxPool2dWithIndices::new(output_tensor, output.indices) } } } fn max_pool2d_with_indices_backward( _x: AutodiffTensor, _kernel_size: [usize; 2], _stride: [usize; 2], _padding: [usize; 2], _dilation: [usize; 2], _ceil_mode: bool, _output_grad: AutodiffTensor, _indices: IntTensor, ) -> MaxPool2dBackward { panic!("Can't differentiate max pool2d with indices backward."); } fn adaptive_avg_pool1d(x: AutodiffTensor, output_size: usize) -> AutodiffTensor { #[derive(Debug)] struct AdaptiveAvgPool1D; impl Backward for AdaptiveAvgPool1D { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let state = checkpointer.retrieve_node_output(ops.state); if let Some(node) = node_parent { let grad = B::adaptive_avg_pool1d_backward(state, grad); grads.register::(node.id, grad); } } } match AdaptiveAvgPool1D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); prep.finish(x_state, B::adaptive_avg_pool1d(x.primitive, output_size)) } OpsKind::UnTracked(prep) => { prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size)) } } } fn adaptive_avg_pool2d(x: AutodiffTensor, output_size: [usize; 2]) -> AutodiffTensor { #[derive(Debug)] struct AdaptiveAvgPool2D; impl Backward for AdaptiveAvgPool2D { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let state = checkpointer.retrieve_node_output(ops.state); if let Some(node) = node_parent { let grad = B::adaptive_avg_pool2d_backward(state, grad); grads.register::(node.id, grad); } } } match AdaptiveAvgPool2D .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); prep.finish(x_state, B::adaptive_avg_pool2d(x.primitive, output_size)) } OpsKind::UnTracked(prep) => { prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size)) } } } fn adaptive_avg_pool2d_backward( _x: AutodiffTensor, _grad: AutodiffTensor, ) -> as Backend>::FloatTensorPrimitive { panic!("Can't differentiate adaptive avg pool2d backward."); } fn interpolate( x: AutodiffTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> AutodiffTensor { #[derive(Debug)] struct Interpolate; impl Backward for Interpolate { type State = (NodeId, [usize; 2], InterpolateOptions); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, output_size, options) = ops.state; let state = checkpointer.retrieve_node_output(x_state); if let Some(node) = node_parent { let grad = B::interpolate_backward(state, grad, output_size, options); grads.register::(node.id, grad); } } } match Interpolate .prepare::([x.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let x_state = prep.checkpoint(&x); let output = B::interpolate(x.primitive.clone(), output_size, options.clone()); prep.finish((x_state, output_size, options), output) } OpsKind::UnTracked(prep) => { prep.finish(B::interpolate(x.primitive, output_size, options)) } } } fn interpolate_backward( _x: FloatTensor>, _grad: FloatTensor>, _output_size: [usize; 2], _options: InterpolateOptions, ) -> as Backend>::FloatTensorPrimitive { panic!("Can't differentiate interpolate backward."); } fn attention( query: FloatTensor>, key: FloatTensor>, value: FloatTensor>, mask: Option>>, attn_bias: Option>>, options: AttentionModuleOptions, ) -> FloatTensor> { attention_fallback::(query, key, value, mask, attn_bias, options) } } #[derive(Debug)] struct MaxPool1D; impl Backward for MaxPool1D { type State = (NodeId, IntTensor, usize, usize, usize, usize, bool); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state; let x = checkpointer.retrieve_node_output(x_state); if let Some(node) = node_parent { let grad = B::max_pool1d_with_indices_backward( x, kernel_size, stride, padding, dilation, ceil_mode, grad, indices, ); grads.register::(node.id, grad.x_grad); } } } #[derive(Debug)] struct MaxPool2D; impl Backward for MaxPool2D { type State = ( NodeId, IntTensor, [usize; 2], [usize; 2], [usize; 2], [usize; 2], bool, ); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let [node_parent] = ops.parents; let grad = grads.consume::(&ops.node); let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state; let x = checkpointer.retrieve_node_output(x_state); if let Some(node) = node_parent { let grad = B::max_pool2d_with_indices_backward( x, kernel_size, stride, padding, dilation, ceil_mode, grad, indices, ); grads.register::(node.id, grad.x_grad); } } } ================================================ FILE: crates/burn-autodiff/src/ops/qtensor.rs ================================================ use burn_backend::{ Backend, ExecutionError, TensorData, ops::QTensorOps, tensor::{ Device, FloatTensor, IntTensor, QuantizedTensor, quantization::QuantizationParametersPrimitive, }, }; use burn_std::{QuantScheme, Shape}; use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy}; impl QTensorOps for Autodiff { fn q_from_data(_data: TensorData, _device: &Device) -> QuantizedTensor { todo!() } fn quantize( _tensor: FloatTensor, _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { todo!() // required for QAT } fn quantize_dynamic( _tensor: FloatTensor, _scheme: &QuantScheme, ) -> QuantizedTensor { todo!() } fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { todo!() } fn q_device(tensor: &QuantizedTensor) -> Device { B::q_device(tensor) } fn q_to_device( _tensor: QuantizedTensor, _device: &Device, ) -> QuantizedTensor { unimplemented!() } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { B::q_reshape(tensor, shape) } async fn q_into_data(tensor: QuantizedTensor) -> Result { B::q_into_data(tensor).await } fn q_swap_dims( _tensor: QuantizedTensor, _dim1: usize, _dim2: usize, ) -> QuantizedTensor { unimplemented!() } fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_gather( _dim: usize, _tensor: QuantizedTensor, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_select( _tensor: QuantizedTensor, _dim: usize, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_slice( _tensor: QuantizedTensor, _slices: &[burn_std::Slice], ) -> QuantizedTensor { unimplemented!() } fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor { B::q_argmax(tensor, dim) } fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor { B::q_argmin(tensor, dim) } fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } } ================================================ FILE: crates/burn-autodiff/src/ops/sort.rs ================================================ use super::{Backward, Ops, unary}; use crate::{checkpoint::base::Checkpointer, grads::Gradients}; use burn_backend::{Backend, TensorMetadata}; use burn_std::Shape; #[derive(Debug)] pub(crate) struct SortDim; impl Backward for SortDim { type State = (B::IntTensorPrimitive, Shape, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let (indices, shape, dim) = ops.state; let device = B::float_device(&grad); let dtype = grad.dtype(); let zeros = B::float_zeros(shape, &device, dtype.into()); B::float_scatter_add(dim, zeros, indices, grad) }); } } ================================================ FILE: crates/burn-autodiff/src/ops/tensor.rs ================================================ use alloc::{boxed::Box, vec, vec::Vec}; use core::marker::PhantomData; #[cfg(not(feature = "std"))] #[allow(unused_imports, reason = "required on aarch64, unused on x86_64")] use num_traits::float::Float; use crate::{ Autodiff, checkpoint::{ base::Checkpointer, builder::CheckpointerBuilder, retro_forward::RetroForward, state::BackwardStates, strategy::CheckpointStrategy, }, grads::Gradients, graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step}, ops::{Backward, Ops, OpsKind, binary, broadcast_shape, unary}, retro_binary, retro_unary, retro_unary_scalar, tensor::AutodiffTensor, utils::duplicate, }; use burn_backend::{ Backend, ExecutionError, TensorData, TensorMetadata, ops::FloatTensorOps, tensor::{BoolTensor, Device, FloatTensor, IntTensor}, }; use burn_backend::{Scalar, ops::unfold::calculate_unfold_windows}; use burn_std::{FloatDType, Shape, Slice}; use super::maxmin::MaxMinDim; // Unsqueeze op on primitive. fn unsqueeze_like( tensor: B::FloatTensorPrimitive, shape: Shape, ) -> B::FloatTensorPrimitive { let ndims_out = shape.num_dims(); let shape = tensor.shape(); let ndims_in = shape.num_dims(); let mut dims = vec![1; ndims_out]; let num_ones = ndims_out - ndims_in; dims[num_ones..(ndims_in + num_ones)].copy_from_slice(&shape[..ndims_in]); B::float_reshape(tensor, Shape::from(dims)) } impl FloatTensorOps for Autodiff { #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(data), fields(?data.shape, ?data.dtype) ))] fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { AutodiffTensor::new(B::float_from_data(data, device)) } fn float_random( shape: Shape, distribution: burn_backend::Distribution, device: &Device, ) -> FloatTensor { AutodiffTensor::new(B::float_random(shape, distribution, device)) } fn float_zeros(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { AutodiffTensor::new(B::float_zeros(shape, device, dtype)) } fn float_ones(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { AutodiffTensor::new(B::float_ones(shape, device, dtype)) } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields( from = ?tensor.node, shape = ?tensor.shape(), dtype = ?tensor.dtype(), ) ))] async fn float_into_data(tensor: FloatTensor) -> Result { B::float_into_data(tensor.primitive).await } fn float_device(tensor: &FloatTensor) -> Device { B::float_device(&tensor.primitive) } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields( from = ?tensor.node, shape = ?tensor.shape(), dtype = ?tensor.dtype(), ) ))] fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { #[derive(Debug)] struct ToDevice; impl Backward for ToDevice { type State = B::Device; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { B::float_to_device(grad, &ops.state) }); } } match ToDevice .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let device_old = B::float_device(&tensor.primitive); prep.finish(device_old, B::float_to_device(tensor.primitive, device)) } OpsKind::UnTracked(prep) => prep.finish(B::float_to_device(tensor.primitive, device)), } } fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { AutodiffTensor::new(B::float_empty(shape, device, dtype)) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Add; retro_binary!(RetroAdd, B::float_add); impl Backward for Add { type State = (Shape, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape_lhs, shape_rhs) = ops.state; binary::( ops.parents, ops.node, grads, |grad| broadcast_shape::(grad, &shape_lhs), |grad| broadcast_shape::(grad, &shape_rhs), ); } } match Add .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroAdd::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (lhs.primitive.shape(), rhs.primitive.shape()), B::float_add(lhs.primitive, rhs.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_add(lhs.primitive, rhs.primitive)), } } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { #[derive(Debug)] struct AddScalar; retro_unary_scalar!(RetroAddScalar, B::float_add_scalar); impl Backward for AddScalar { type State = (); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| grad); } } AddScalar .prepare::([lhs.node.clone()]) .memory_bound() .retro_forward(RetroAddScalar::::new(lhs.node.id, rhs)) .parents([&lhs]) .stateless(B::float_add_scalar(lhs.primitive, rhs)) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sub; retro_binary!(RetroSub, B::float_sub); impl Backward for Sub { type State = (Shape, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape_lhs, shape_rhs) = ops.state; binary::( ops.parents, ops.node, grads, |grad| broadcast_shape::(grad, &shape_lhs), |grad| broadcast_shape::(B::float_neg(grad), &shape_rhs), ); } } match Sub .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroSub::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (lhs.primitive.shape(), rhs.primitive.shape()), B::float_sub(lhs.primitive, rhs.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_sub(lhs.primitive, rhs.primitive)), } } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { #[derive(Debug)] struct SubScalar; retro_unary_scalar!(RetroSubScalar, B::float_sub_scalar); impl Backward for SubScalar { type State = (); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| grad); } } SubScalar .prepare::([lhs.node.clone()]) .memory_bound() .retro_forward(RetroSubScalar::::new(lhs.node.id, rhs)) .parents([&lhs]) .stateless(B::float_sub_scalar(lhs.primitive, rhs)) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Mul; retro_binary!(RetroMul, B::float_mul); impl Backward for Mul { type State = (Option, Option, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs, rhs, broadcast) = ops.state; let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs)); let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs)); binary::( ops.parents, ops.node, grads, |grad| { let grad = B::float_mul(grad, rhs.unwrap()); broadcast.backward_lhs::(grad) }, |grad| { let grad = B::float_mul(grad, lhs.unwrap()); broadcast.backward_rhs::(grad) }, ); } } let lhs_tracked = lhs.is_tracked(); let rhs_tracked = rhs.is_tracked(); let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); match Mul .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroMul::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs)); let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs)); prep.finish( (lhs_state, rhs_state, broadcast), B::float_mul(lhs.primitive, rhs.primitive), ) } OpsKind::UnTracked(prep) => prep.finish(B::float_mul(lhs.primitive, rhs.primitive)), } } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { #[derive(Debug)] struct MulScalar; retro_unary_scalar!(RetroMulScalar, B::float_mul_scalar); impl Backward for MulScalar { type State = Scalar; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { B::float_mul_scalar(grad, ops.state) }); } } match MulScalar .prepare::([lhs.node.clone()]) .memory_bound() .retro_forward(RetroMulScalar::::new(lhs.node.id, rhs)) .parents([&lhs]) .stateful() { OpsKind::Tracked(prep) => prep.finish(rhs, B::float_mul_scalar(lhs.primitive, rhs)), OpsKind::UnTracked(prep) => prep.finish(B::float_mul_scalar(lhs.primitive, rhs)), } } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Div; retro_binary!(RetroDiv, B::float_div); impl Backward for Div { type State = (Option, Option, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs, rhs, broadcast) = ops.state; let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs)); let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs)); let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs); binary::( ops.parents, ops.node, grads, |grad| { let rhs = rhs_4lhs.unwrap(); let value = B::float_recip(rhs); let grad = B::float_mul(grad, value); broadcast.backward_lhs::(grad) }, |grad| { let rhs = rhs_4rhs.unwrap(); let lhs = lhs.unwrap(); let value = B::float_div(B::float_neg(lhs), B::float_powi_scalar(rhs, 2.into())); let grad = B::float_mul(grad, value); broadcast.backward_rhs::(grad) }, ); } } let lhs_tracked = lhs.is_tracked(); let rhs_tracked = rhs.is_tracked(); let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); match Div .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroDiv::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs)); let rhs_state = (lhs_tracked || rhs_tracked).then(|| prep.checkpoint(&rhs)); prep.finish( (lhs_state, rhs_state, broadcast), B::float_div(lhs.primitive, rhs.primitive), ) } OpsKind::UnTracked(prep) => prep.finish(B::float_div(lhs.primitive, rhs.primitive)), } } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { #[derive(Debug)] struct DivScalar; retro_unary_scalar!(RetroDivScalar, B::float_div_scalar); impl Backward for DivScalar { type State = Scalar; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let tmp = 1.0 / ops.state.elem::(); B::float_mul_scalar(grad, tmp.into()) }); } } match DivScalar .prepare::([lhs.node.clone()]) .memory_bound() .retro_forward(RetroDivScalar::::new(lhs.node.id, rhs)) .parents([&lhs]) .stateful() { OpsKind::Tracked(prep) => prep.finish(rhs, B::float_div_scalar(lhs.primitive, rhs)), OpsKind::UnTracked(prep) => prep.finish(B::float_div_scalar(lhs.primitive, rhs)), } } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Rem; retro_binary!(RetroRem, B::float_remainder); impl Backward for Rem { type State = (Option, Option, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs, rhs, broadcast) = ops.state; let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs)); let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs)); binary::( ops.parents, ops.node, grads, |grad| { // remainder(x, y) = x - floor(x / y) * y // partial(x - floor(x / y) * y, x) = 1 broadcast.backward_lhs::(grad) }, |grad| { // partial(x - floor(x / y) * y, y) = - floor(x / y) let rhs = rhs.unwrap(); let lhs = lhs.unwrap(); let value = B::float_neg(B::float_floor(B::float_div(lhs, rhs))); let grad = B::float_mul(grad, value); broadcast.backward_rhs::(grad) }, ); } } let lhs_tracked = lhs.is_tracked(); let rhs_tracked = rhs.is_tracked(); let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); match Rem .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroRem::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs)); let rhs_state = (lhs_tracked || rhs_tracked).then(|| prep.checkpoint(&rhs)); prep.finish( (lhs_state, rhs_state, broadcast), B::float_remainder(lhs.primitive, rhs.primitive), ) } OpsKind::UnTracked(prep) => { prep.finish(B::float_remainder(lhs.primitive, rhs.primitive)) } } } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { #[derive(Debug)] struct RemainderScalar; retro_unary_scalar!(RetroRemainderScalar, B::float_remainder_scalar); impl Backward for RemainderScalar { type State = (); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| grad); } } RemainderScalar .prepare::([lhs.node.clone()]) .memory_bound() .retro_forward(RetroRemainderScalar::::new(lhs.node.id, rhs)) .parents([&lhs]) .stateless(B::float_remainder_scalar(lhs.primitive, rhs)) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Matmul; impl Backward for Matmul { type State = (Option, Option, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs, rhs, broadcast) = ops.state; let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs)); let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs)); binary::( ops.parents, ops.node, grads, |grad| { let rhs = B::float_transpose(rhs.unwrap()); let grad = B::float_matmul(grad, rhs); broadcast.backward_lhs::(grad) }, |grad| { let lhs = B::float_transpose(lhs.unwrap()); let grad = B::float_matmul(lhs, grad); broadcast.backward_rhs::(grad) }, ); } } let lhs_tracked = lhs.is_tracked(); let rhs_tracked = rhs.is_tracked(); let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); match Matmul .prepare::([lhs.node.clone(), rhs.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs)); let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs)); prep.finish( (lhs_state, rhs_state, broadcast), B::float_matmul(lhs.primitive, rhs.primitive), ) } OpsKind::UnTracked(prep) => prep.finish(B::float_matmul(lhs.primitive, rhs.primitive)), } } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { #[derive(Debug)] struct Cross; impl Backward for Cross { type State = (Option, Option, usize); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs_id, rhs_id, dim) = ops.state; let lhs = lhs_id.map(|id| checkpointer.retrieve_node_output(id)); let rhs = rhs_id.map(|id| checkpointer.retrieve_node_output(id)); binary::( ops.parents, ops.node, grads, |grad| B::float_cross(rhs.unwrap(), grad, dim), |grad| B::float_cross(grad, lhs.unwrap(), dim), ); } } let lhs_tracked = lhs.is_tracked(); let rhs_tracked = rhs.is_tracked(); match Cross .prepare::([lhs.node.clone(), rhs.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs)); let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs)); prep.finish( (lhs_state, rhs_state, dim), B::float_cross(lhs.primitive, rhs.primitive, dim), ) } OpsKind::UnTracked(prep) => { prep.finish(B::float_cross(lhs.primitive, rhs.primitive, dim)) } } } fn float_neg(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Neg; retro_unary!(RetroNeg, B::float_neg); impl Backward for Neg { type State = (); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| B::float_neg(grad)); } } Neg.prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroNeg::::new(tensor.node.id)) .parents([&tensor]) .stateless(B::float_neg(tensor.primitive)) } fn float_recip(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Recip; retro_unary!(RetroRecip, B::float_recip); impl Backward for Recip { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let tensor = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let tmp = B::float_powi_scalar(tensor, (-2).into()); let value = B::float_neg(tmp); B::float_mul(grad, value) }); } } match Recip .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroRecip::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_recip(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_recip(tensor.primitive)), } } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { #[derive(Debug)] struct SwapDim; #[derive(new, Debug)] struct RetroSwapDims { input_id: NodeId, dim1: usize, dim2: usize, _backend: PhantomData, } impl RetroForward for RetroSwapDims { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_swap_dims(input, self.dim1, self.dim2); states.save(out_node, out) } } impl Backward for SwapDim { type State = (usize, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim1, dim2) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_swap_dims(grad, dim2, dim1) }); } } match SwapDim .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSwapDims::::new(tensor.node.id, dim1, dim2)) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( (dim1, dim2), B::float_swap_dims(tensor.primitive, dim1, dim2), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_swap_dims(tensor.primitive, dim1, dim2)) } } } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { #[derive(Debug)] struct PermuteDim; #[derive(new, Debug)] struct RetroPermuteDims { input_id: NodeId, axes: Vec, _backend: PhantomData, } impl RetroForward for RetroPermuteDims { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_permute(input, &self.axes); states.save(out_node, out) } } impl Backward for PermuteDim { type State = Vec; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let axes = ops.state; let mut inverse = vec![0usize; axes.len()]; axes.iter() .enumerate() .for_each(|(i, &axis)| inverse[axis] = i); unary::(ops.parents, ops.node, grads, |grad| { B::float_permute(grad, &inverse) }); } } match PermuteDim .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroPermuteDims::::new(tensor.node.id, axes.to_vec())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => { prep.finish(axes.to_vec(), B::float_permute(tensor.primitive, axes)) } OpsKind::UnTracked(prep) => prep.finish(B::float_permute(tensor.primitive, axes)), } } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { #[derive(Debug)] struct FlipDim; #[derive(new, Debug)] struct RetroFlipDims { input_id: NodeId, axes: Vec, _backend: PhantomData, } impl RetroForward for RetroFlipDims { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_flip(input, &self.axes); states.save(out_node, out) } } impl Backward for FlipDim { type State = Vec; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let axes = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_flip(grad, &axes) }); } } match FlipDim .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroFlipDims::::new(tensor.node.id, axes.to_vec())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => { prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes)) } OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)), } } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { #[derive(Debug)] struct ReshapeDim; #[derive(new, Debug)] struct RetroReshape { input_id: NodeId, shape: Shape, _backend: PhantomData, } impl RetroForward for RetroReshape { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_reshape(input, self.shape.clone()); states.save(out_node, out) } } impl Backward for ReshapeDim { type State = (Shape, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape_original, shape) = ops.state; let ndims_out = shape.num_dims(); unary::(ops.parents, ops.node, grads, |grad| { let shape_grad = grad.shape(); let mut grad = grad; for i in 0..ndims_out { if shape[i] == 1 && shape_grad[i] != 1 { grad = B::float_sum_dim(grad, i); } } B::float_reshape(grad, shape_original) }); } } match ReshapeDim .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroReshape::::new(tensor.node.id, shape.clone())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), shape.clone()), B::float_reshape(tensor.primitive, shape), ), OpsKind::UnTracked(prep) => prep.finish(B::float_reshape(tensor.primitive, shape)), } } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { #[derive(Debug)] struct Gather; impl Backward for Gather { type State = (usize, IntTensor, Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim, indices, shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let zeros = B::float_zeros(shape, &device, grad.dtype().into()); B::float_scatter_add(dim, zeros, indices, grad) }); } } match Gather .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( ( dim, indices.clone(), tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_gather(dim, tensor.primitive, indices), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_gather(dim, tensor.primitive, indices)) } } } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { #[derive(Debug)] struct Scatter; impl Backward for Scatter { type State = (usize, IntTensor); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim, indices) = ops.state; let [_, indices_4rhs] = duplicate(&ops.parents, Some(indices)); binary::( ops.parents, ops.node, grads, |grad| grad, |grad| B::float_gather(dim, grad, indices_4rhs.unwrap()), ); } } match Scatter .prepare::([tensor.node, value.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (dim, indices.clone()), B::float_scatter_add(dim, tensor.primitive, indices, value.primitive), ), OpsKind::UnTracked(prep) => prep.finish(B::float_scatter_add( dim, tensor.primitive, indices, value.primitive, )), } } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { #[derive(Debug)] struct Select; #[derive(new, Debug)] struct RetroSelect { input_id: NodeId, dim: usize, indices: IntTensor, } impl RetroForward for RetroSelect { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_select(input, self.dim, self.indices.clone()); states.save(out_node, out) } } impl Backward for Select { type State = (usize, IntTensor, Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim, indices, shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let zeros = B::float_zeros(shape, &device, grad.dtype().into()); B::float_select_add(zeros, dim, indices, grad) }); } } match Select .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSelect::::new(tensor.node.id, dim, indices.clone())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( ( dim, indices.clone(), tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_select(tensor.primitive, dim, indices), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_select(tensor.primitive, dim, indices)) } } } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { #[derive(Debug)] struct IndexSelectDimAssign; #[derive(new, Debug)] struct RetroSelectAssign { tensor_id: NodeId, dim: usize, indices: IntTensor, value_id: NodeId, } impl RetroForward for RetroSelectAssign { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let tensor = states.get_state::(&self.tensor_id); let value = states.get_state::(&self.value_id); let out = B::float_select_add(tensor, self.dim, self.indices.clone(), value); states.save(out_node, out) } } impl Backward for IndexSelectDimAssign { type State = (usize, IntTensor); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim, indices) = ops.state; binary::( ops.parents, ops.node, grads, |grad| grad, |grad| B::float_select(grad, dim, indices), ); } } match IndexSelectDimAssign .prepare::([tensor.node.clone(), value.node.clone()]) .memory_bound() .retro_forward(RetroSelectAssign::::new( tensor.node.id, dim, indices.clone(), value.node.id, )) .parents([&tensor, &value]) .stateful() { OpsKind::Tracked(prep) => prep.finish( (dim, indices.clone()), B::float_select_add(tensor.primitive, dim, indices, value.primitive), ), OpsKind::UnTracked(prep) => prep.finish(B::float_select_add( tensor.primitive, dim, indices, value.primitive, )), } } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { #[derive(Debug)] struct Index; #[derive(new, Debug)] struct RetroSlice { tensor_id: NodeId, slices: Vec, _backend: PhantomData, } impl RetroForward for RetroSlice { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let tensor = states.get_state::(&self.tensor_id); let out = B::float_slice(tensor, &self.slices); states.save(out_node, out) } } impl Backward for Index { type State = (Vec, Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (slices, shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let zeros = B::float_zeros(shape, &device, grad.dtype().into()); B::float_slice_assign(zeros, &slices, grad) }); } } match Index .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSlice::::new(tensor.node.id, slices.to_vec())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( ( slices.to_vec(), tensor.primitive.shape(), B::float_device(&tensor.primitive), ), B::float_slice(tensor.primitive, slices), ), OpsKind::UnTracked(prep) => prep.finish(B::float_slice(tensor.primitive, slices)), } } fn float_slice_assign( tensor: FloatTensor, slices: &[Slice], value: FloatTensor, ) -> FloatTensor { #[derive(Debug)] struct SliceAssign; #[derive(new, Debug)] struct RetroSliceAssign { tensor_id: NodeId, slices: Vec, value_id: NodeId, _backend: PhantomData, } impl RetroForward for RetroSliceAssign { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let tensor = states.get_state::(&self.tensor_id); let value = states.get_state::(&self.value_id); let out = B::float_slice_assign(tensor, &self.slices, value); states.save(out_node, out) } } impl Backward for SliceAssign { type State = (Vec, Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (slices, shape_rhs, device) = ops.state; let [slices_4lhs, slices_4rhs] = duplicate(&ops.parents, Some(slices)); binary::( ops.parents, ops.node, grads, |grad| { let zeros = B::float_zeros(shape_rhs, &device, grad.dtype().into()); B::float_slice_assign(grad, &slices_4lhs.unwrap(), zeros) }, |grad| B::float_slice(grad, &slices_4rhs.unwrap()), ); } } match SliceAssign .prepare::([tensor.node.clone(), value.node.clone()]) .memory_bound() .retro_forward(RetroSliceAssign::::new( tensor.node.id, slices.to_vec(), value.node.id, )) .parents([&tensor, &value]) .stateful() { OpsKind::Tracked(prep) => prep.finish( ( slices.to_vec(), value.primitive.shape(), B::float_device(&value.primitive), ), B::float_slice_assign(tensor.primitive, slices, value.primitive), ), OpsKind::UnTracked(prep) => prep.finish(B::float_slice_assign( tensor.primitive, slices, value.primitive, )), } } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, source: FloatTensor, ) -> FloatTensor { #[derive(Debug)] struct MaskWhere; impl Backward for MaskWhere { type State = (BoolTensor, Shape, Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (mask, shape_lhs, shape_rhs, device) = ops.state; let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask)); binary::( ops.parents, ops.node, grads, |grad| { let zeros = B::float_zeros(shape_lhs.clone(), &device, grad.dtype().into()); let grad = B::float_mask_where(grad, mask_4lhs.unwrap(), zeros); broadcast_shape::(grad, &shape_lhs) }, |grad| { let zeros = B::float_zeros(shape_rhs.clone(), &device, grad.dtype().into()); let grad = B::float_mask_where(zeros, mask_4rhs.unwrap(), grad); broadcast_shape::(grad, &shape_rhs) }, ); } } match MaskWhere .prepare::([tensor.node, source.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( ( mask.clone(), tensor.primitive.shape(), source.primitive.shape(), B::float_device(&source.primitive), ), B::float_mask_where(tensor.primitive, mask, source.primitive), ), OpsKind::UnTracked(prep) => prep.finish(B::float_mask_where( tensor.primitive, mask, source.primitive, )), } } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { #[derive(Debug)] struct MaskFill; impl Backward for MaskFill { type State = BoolTensor; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { B::float_mask_fill(grad, ops.state, 0f32.into()) }); } } match MaskFill .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( mask.clone(), B::float_mask_fill(tensor.primitive, mask, value), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_mask_fill(tensor.primitive, mask, value)) } } } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { B::float_equal(lhs.primitive, rhs.primitive) } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { B::float_equal_elem(lhs.primitive, rhs) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { B::float_greater(lhs.primitive, rhs.primitive) } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { B::float_greater_elem(lhs.primitive, rhs) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { B::float_greater_equal(lhs.primitive, rhs.primitive) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { B::float_greater_equal_elem(lhs.primitive, rhs) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { B::float_lower(lhs.primitive, rhs.primitive) } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { B::float_lower_elem(lhs.primitive, rhs) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { B::float_lower_equal(lhs.primitive, rhs.primitive) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { B::float_lower_equal_elem(lhs.primitive, rhs) } fn float_is_nan(tensor: FloatTensor) -> BoolTensor { B::float_is_nan(tensor.primitive) } fn float_is_inf(tensor: FloatTensor) -> BoolTensor { B::float_is_inf(tensor.primitive) } fn float_detach(tensor: FloatTensor) -> FloatTensor { // When we detach a tensor, we remove it from the graph, but we still want to keep the // `require_grad` setting. let is_require_grad = Self::float_is_require_grad(&tensor); let tensor = AutodiffTensor::new(tensor.primitive); match is_require_grad { true => tensor.require_grad(), false => tensor, } } fn float_set_require_grad(tensor: FloatTensor, require_grad: bool) -> FloatTensor { if require_grad { return tensor.require_grad(); } AutodiffTensor::new(tensor.primitive) } fn float_is_require_grad(tensor: &FloatTensor) -> bool { matches!(tensor.node.requirement, Requirement::Grad) } fn float_mean(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Mean; impl Backward for Mean { type State = Shape; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let shape = ops.state; let val = 1_f64 / shape.num_elements() as f64; let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into()); let val = B::float_mul_scalar(ones, val.into()); let grad = unsqueeze_like::(grad, val.shape()); B::float_mul(val, grad) }); } } match Mean.prepare::([tensor.node]).compute_bound().stateful() { OpsKind::Tracked(prep) => { prep.finish(tensor.primitive.shape(), B::float_mean(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_mean(tensor.primitive)), } } fn float_sum(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sum; impl Backward for Sum { type State = Shape; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let val = B::float_ones(ops.state, &B::float_device(&grad), grad.dtype().into()); let grad = unsqueeze_like::(grad, val.shape()); B::float_mul(val, grad) }); } } match Sum.prepare::([tensor.node]).compute_bound().stateful() { OpsKind::Tracked(prep) => { prep.finish(tensor.primitive.shape(), B::float_sum(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_sum(tensor.primitive)), } } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct MeanDim; impl Backward for MeanDim { type State = (Shape, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, dim) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let val = 1_f64 / shape[dim] as f64; let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into()); let val = B::float_mul_scalar(ones, val.into()); let grad = B::float_sum_dim(grad, dim); B::float_mul(val, grad) }); } } match MeanDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), dim), B::float_mean_dim(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_mean_dim(tensor.primitive, dim)), } } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct SumDim; impl Backward for SumDim { type State = (Shape, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, dim) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into()); let grad = B::float_sum_dim(grad, dim); B::float_mul(ones, grad) }); } } match SumDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), dim), B::float_sum_dim(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_sum_dim(tensor.primitive, dim)), } } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct CumSum; impl Backward for CumSum { type State = (Shape, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (_shape, dim) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { // Gradient of cumsum is cumsum of gradient in reverse let grad_reversed = B::float_flip(grad.clone(), &[dim]); let grad_cumsum = B::float_cumsum(grad_reversed, dim); B::float_flip(grad_cumsum, &[dim]) }); } } match CumSum .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), dim), B::float_cumsum(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)), } } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct CumProd; impl Backward for CumProd { type State = (B::FloatTensorPrimitive, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (input, dim) = ops.state; let output = B::float_cumprod(input.clone(), dim); unary::(ops.parents, ops.node, grads, |grad| { // Gradient of cumprod using negative step slicing // Formula: grad_input[i] = sum_{j>=i}(grad_output[j] * output[j] / input[i]) // = (1 / input[i]) * sum_{j>=i}(grad_output[j] * output[j]) // = (1 / input) * reverse_cumsum(grad * output) // // LIMITATION: This produces NaN when input contains zeros. // A proper zero-safe implementation requires more sophisticated algorithms // (see PyTorch's cumprod_backward or JAX's associative_scan approach). // TODO: Implement zero-safe gradient computation. // See: https://github.com/tracel-ai/burn/issues/3864 let grad_times_output = B::float_mul(grad, output.clone()); // Create slices to reverse along the specified dimension let shape = grad_times_output.shape(); let mut slices = vec![Slice::full(); shape.num_dims()]; slices[dim] = Slice::with_step(0, None, -1); // Reverse, cumsum, reverse back using negative step slicing let grad_reversed = B::float_slice(grad_times_output, &slices); let grad_cumsum = B::float_cumsum(grad_reversed, dim); let grad_result = B::float_slice(grad_cumsum, &slices); B::float_div(grad_result, input) }); } } match CumProd .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.clone(), dim), B::float_cumprod(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_cumprod(tensor.primitive, dim)), } } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct CumMin; impl Backward for CumMin { type State = (B::FloatTensorPrimitive, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (input, dim) = ops.state; let output = B::float_cummin(input.clone(), dim); unary::(ops.parents, ops.node, grads, |grad| { // Gradient flows to the input positions that produced each output // Use scatter to accumulate gradients (scatter does sum reduction) let shape = input.shape(); let device = B::float_device(&input); let dim_size = shape[dim] as i64; // Create indices [0, 1, 2, ...] along the dimension let arange_1d = B::int_arange(0..dim_size, &device); // Reshape to broadcast along the specified dimension let mut arange_shape = vec![1; shape.num_dims()]; arange_shape[dim] = dim_size as usize; let arange = B::int_reshape(arange_1d, Shape::from(arange_shape)); // Expand to match input shape let arange = B::int_expand(arange, shape.clone()); // Find where cummin[i] == input[i] (these are source positions) let is_source = B::float_equal(output.clone(), input.clone()); let is_source_int = B::bool_into_int(is_source); // Mask: where is_source, use index; else 0 let masked_indices = B::int_mul(arange, is_source_int); // Cummax propagates the last valid (non-zero) index forward let source_indices = B::int_cummax(masked_indices, dim); // Scatter gradients to source positions (sum reduction) let zeros = B::float_zeros(shape, &device, grad.dtype().into()); B::float_scatter_add(dim, zeros, source_indices, grad) }); } } match CumMin .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.clone(), dim), B::float_cummin(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_cummin(tensor.primitive, dim)), } } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(Debug)] struct CumMax; impl Backward for CumMax { type State = (B::FloatTensorPrimitive, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (input, dim) = ops.state; let output = B::float_cummax(input.clone(), dim); unary::(ops.parents, ops.node, grads, |grad| { // Gradient flows to the input positions that produced each output // Use scatter to accumulate gradients (scatter does sum reduction) let shape = input.shape(); let device = B::float_device(&input); let dim_size = shape[dim] as i64; // Create indices [0, 1, 2, ...] along the dimension let arange_1d = B::int_arange(0..dim_size, &device); // Reshape to broadcast along the specified dimension let mut arange_shape = vec![1; shape.num_dims()]; arange_shape[dim] = dim_size as usize; let arange = B::int_reshape(arange_1d, Shape::from(arange_shape)); // Expand to match input shape let arange = B::int_expand(arange, shape.clone()); // Find where cummax[i] == input[i] (these are source positions) let is_source = B::float_equal(output.clone(), input.clone()); let is_source_int = B::bool_into_int(is_source); // Mask: where is_source, use index; else 0 let masked_indices = B::int_mul(arange, is_source_int); // Cummax propagates the last valid (non-zero) index forward let source_indices = B::int_cummax(masked_indices, dim); // Scatter gradients to source positions (sum reduction) let zeros = B::float_zeros(shape, &device, grad.dtype().into()); B::float_scatter_add(dim, zeros, source_indices, grad) }); } } match CumMax .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.clone(), dim), B::float_cummax(tensor.primitive, dim), ), OpsKind::UnTracked(prep) => prep.finish(B::float_cummax(tensor.primitive, dim)), } } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { B::float_argmax(tensor.primitive, dim) } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { B::float_argmin(tensor.primitive, dim) } fn float_exp(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Exp; retro_unary!(RetroExp, B::float_exp); impl Backward for Exp { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); let output = B::float_exp(input); unary::(ops.parents, ops.node, grads, |grad| { B::float_mul(grad, output) }); } } match Exp .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroExp::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_exp(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_exp(tensor.primitive)), } } fn float_log(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Log; retro_unary!(RetroLog, B::float_log); impl Backward for Log { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_recip(input); B::float_mul(grad, value) }); } } match Log .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroLog::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_log(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_log(tensor.primitive)), } } fn float_log1p(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Log1P; retro_unary!(RetroLog1P, B::float_log1p); impl Backward for Log1P { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_add_scalar(input, 1f32.into()); let value = B::float_recip(value); B::float_mul(grad, value) }); } } match Log1P .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroLog1P::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_log1p(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_log1p(tensor.primitive)), } } fn float_powf_scalar_impl(tensor: FloatTensor, value: Scalar) -> FloatTensor { #[derive(Debug)] struct PowfScalar; #[derive(new, Debug)] struct RetroPowfScalar { lhs_id: NodeId, rhs: f64, _backend: PhantomData, } impl RetroForward for RetroPowfScalar { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let lhs = states.get_state::(&self.lhs_id); let out = B::float_powf_scalar(lhs, self.rhs.into()); states.save(out_node, out) } } impl Backward for PowfScalar { type State = (NodeId, f64); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (tensor_id, value) = ops.state; let tensor = checkpointer.retrieve_node_output(tensor_id); unary::(ops.parents, ops.node, grads, |grad| { let tmp = B::float_powf_scalar(tensor, (value - 1.).into()); let value = B::float_mul_scalar(tmp, value.into()); B::float_mul(grad, value) }); } } match PowfScalar .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroPowfScalar::::new(tensor.node.id, value.elem())) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = (prep.checkpoint(&tensor), value.elem()); prep.finish(state, B::float_powf_scalar(tensor.primitive, value)) } OpsKind::UnTracked(prep) => prep.finish(B::float_powf_scalar(tensor.primitive, value)), } } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sqrt; retro_unary!(RetroSqrt, B::float_sqrt); impl Backward for Sqrt { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_div_scalar( B::float_powf_scalar(input, (-0.5).into()), 2f32.into(), ); B::float_mul(grad, value) }); } } match Sqrt .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSqrt::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_sqrt(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_sqrt(tensor.primitive)), } } fn float_abs(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Abs; retro_unary!(RetroAbs, B::float_abs); impl Backward for Abs { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let tensor: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(ops.state); let state = B::float_sign(tensor); unary::(ops.parents, ops.node, grads, |grad| { B::float_mul(grad, state) }); } } match Abs .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAbs::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_abs(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_abs(tensor.primitive)), } } fn float_cos(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Cos; retro_unary!(RetroCos, B::float_cos); impl Backward for Cos { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_neg(B::float_sin(input)); B::float_mul(grad, value) }); } } match Cos .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroCos::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_cos(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_cos(tensor.primitive)), } } fn float_sin(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sin; retro_unary!(RetroSin, B::float_sin); impl Backward for Sin { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let state = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_cos(state); B::float_mul(grad, value) }); } } match Sin .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSin::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_sin(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_sin(tensor.primitive)), } } fn float_tanh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Tanh; retro_unary!(RetroTanh, B::float_tanh); impl Backward for Tanh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); let state = B::float_tanh(input); unary::(ops.parents, ops.node, grads, |grad| { let value = B::float_add_scalar( B::float_neg(B::float_powi_scalar(state, 2.into())), 1f32.into(), ); B::float_mul(grad, value) }); } } match Tanh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroTanh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_tanh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_tanh(tensor.primitive)), } } fn float_cosh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Cosh; retro_unary!(RetroCosh, B::float_cosh); impl Backward for Cosh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { B::float_mul(grad, B::float_sinh(input)) }); } } match Cosh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroCosh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_cosh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_cosh(tensor.primitive)), } } fn float_sinh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sinh; retro_unary!(RetroSinh, B::float_sinh); impl Backward for Sinh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { B::float_mul(grad, B::float_cosh(input)) }); } } match Sinh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSinh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_sinh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_sinh(tensor.primitive)), } } fn float_tan(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Tan; retro_unary!(RetroTan, B::float_tan); impl Backward for Tan { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); let tan_x = B::float_tan(input); unary::(ops.parents, ops.node, grads, |grad| { // d/dx tan(x) = 1 + tan^2(x) let tan_sq = B::float_powi_scalar(tan_x, 2.into()); B::float_mul(grad, B::float_add_scalar(tan_sq, 1f32.into())) }); } } match Tan .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroTan::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_tan(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_tan(tensor.primitive)), } } fn float_asin(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Asin; retro_unary!(RetroAsin, B::float_asin); impl Backward for Asin { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx asin(x) = 1/sqrt(1 - x^2) let x_sq = B::float_powi_scalar(input, 2.into()); let denom = B::float_sqrt(B::float_add_scalar(B::float_neg(x_sq), 1f32.into())); B::float_mul(grad, B::float_recip(denom)) }); } } match Asin .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAsin::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_asin(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_asin(tensor.primitive)), } } fn float_acos(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Acos; retro_unary!(RetroAcos, B::float_acos); impl Backward for Acos { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx acos(x) = -1/sqrt(1 - x^2) let x_sq = B::float_powi_scalar(input, 2.into()); let denom = B::float_sqrt(B::float_add_scalar(B::float_neg(x_sq), 1f32.into())); let value = B::float_neg(B::float_recip(denom)); B::float_mul(grad, value) }); } } match Acos .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAcos::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_acos(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_acos(tensor.primitive)), } } fn float_atan(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Atan; retro_unary!(RetroAtan, B::float_atan); impl Backward for Atan { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx atan(x) = 1/(1 + x^2) let x_sq = B::float_powi_scalar(input, 2.into()); let value = B::float_recip(B::float_add_scalar(x_sq, 1f32.into())); B::float_mul(grad, value) }); } } match Atan .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAtan::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_atan(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_atan(tensor.primitive)), } } fn float_asinh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Asinh; retro_unary!(RetroAsinh, B::float_asinh); impl Backward for Asinh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx asinh(x) = 1/sqrt(x^2 + 1) let x_sq = B::float_powi_scalar(input, 2.into()); let value = B::float_recip(B::float_sqrt(B::float_add_scalar(x_sq, 1f32.into()))); B::float_mul(grad, value) }); } } match Asinh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAsinh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_asinh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_asinh(tensor.primitive)), } } fn float_acosh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Acosh; retro_unary!(RetroAcosh, B::float_acosh); impl Backward for Acosh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx acosh(x) = 1/sqrt(x^2 - 1) let x_sq = B::float_powi_scalar(input, 2.into()); let value = B::float_recip(B::float_sqrt(B::float_sub_scalar(x_sq, 1f32.into()))); B::float_mul(grad, value) }); } } match Acosh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAcosh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_acosh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_acosh(tensor.primitive)), } } fn float_atanh(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Atanh; retro_unary!(RetroAtanh, B::float_atanh); impl Backward for Atanh { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let input = checkpointer.retrieve_node_output(ops.state); unary::(ops.parents, ops.node, grads, |grad| { // d/dx atanh(x) = 1/(1 - x^2) let x_sq = B::float_powi_scalar(input, 2.into()); let value = B::float_recip(B::float_add_scalar(B::float_neg(x_sq), 1f32.into())); B::float_mul(grad, value) }); } } match Atanh .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroAtanh::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_atanh(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_atanh(tensor.primitive)), } } fn float_atan2(y: FloatTensor, x: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Atan2; retro_binary!(RetroAtan2, B::float_atan2); impl Backward for Atan2 { type State = (Option, Option, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (y_id, x_id, broadcast) = ops.state; let y = y_id.map(|id| checkpointer.retrieve_node_output(id)); let x = x_id.map(|id| checkpointer.retrieve_node_output(id)); let [y_4y, y_4x] = duplicate(&ops.parents, y); let [x_4y, x_4x]: [Option>; 2] = duplicate(&ops.parents, x); binary::( ops.parents, ops.node, grads, |grad| { // d/dy atan2(y, x) = x/(x^2 + y^2) let y = y_4y.unwrap(); let x = x_4y.unwrap(); let x_sq = B::float_powi_scalar(x.clone(), 2.into()); let y_sq = B::float_powi_scalar(y, 2.into()); let denom = B::float_add(x_sq, y_sq); let value = B::float_div(x, denom); let grad = B::float_mul(grad, value); broadcast.backward_lhs::(grad) }, |grad| { // d/dx atan2(y, x) = -y/(x^2 + y^2) let y = y_4x.unwrap(); let x = x_4x.unwrap(); let x_sq = B::float_powi_scalar(x, 2.into()); let y_sq = B::float_powi_scalar(y.clone(), 2.into()); let denom = B::float_add(x_sq, y_sq); let value = B::float_neg(B::float_div(y, denom)); let grad = B::float_mul(grad, value); broadcast.backward_rhs::(grad) }, ); } } let y_tracked = y.is_tracked(); let x_tracked = x.is_tracked(); let broadcast = BinaryOpsBroadcast::new::(&y.primitive, &x.primitive); match Atan2 .prepare::([y.node.clone(), x.node.clone()]) .memory_bound() .retro_forward(RetroAtan2::::new(y.node.id, x.node.id)) .parents([&y, &x]) .stateful() { OpsKind::Tracked(mut prep) => { let is_tracked = y_tracked || x_tracked; let y_state = is_tracked.then(|| prep.checkpoint(&y)); let x_state = is_tracked.then(|| prep.checkpoint(&x)); prep.finish( (y_state, x_state, broadcast), B::float_atan2(y.primitive, x.primitive), ) } OpsKind::UnTracked(prep) => prep.finish(B::float_atan2(y.primitive, x.primitive)), } } fn float_round(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Round; retro_unary!(RetroRound, B::float_round); impl Backward for Round { type State = (Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_zeros(shape, &device, grad.dtype().into()) }) } } match Round .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroRound::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_round(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)), } } fn float_floor(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Floor; retro_unary!(RetroFloor, B::float_floor); impl Backward for Floor { type State = (Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_zeros(shape, &device, grad.dtype().into()) }) } } match Floor .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroFloor::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_floor(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)), } } fn float_ceil(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Ceil; retro_unary!(RetroCeil, B::float_ceil); impl Backward for Ceil { type State = (Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_zeros(shape, &device, grad.dtype().into()) }) } } match Ceil .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroCeil::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_ceil(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_ceil(tensor.primitive)), } } fn float_trunc(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Trunc; retro_unary!(RetroTrunc, B::float_trunc); impl Backward for Trunc { type State = (Shape, B::Device); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_zeros(shape, &device, grad.dtype().into()) }) } } match Trunc .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroTrunc::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(preps) => preps.finish( (tensor.primitive.shape(), B::float_device(&tensor.primitive)), B::float_trunc(tensor.primitive), ), OpsKind::UnTracked(preps) => preps.finish(B::float_trunc(tensor.primitive)), } } fn float_erf(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Erf; retro_unary!(RetroErf, B::float_erf); impl Backward for Erf { type State = NodeId; fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| { let ops = checkpointer.retrieve_node_output(ops.state); let exponent = B::float_neg(B::float_powi_scalar(ops, 2.into())); let numerator = B::float_mul_scalar(B::float_exp(exponent), 2.0.into()); let denominator = core::f64::consts::PI.sqrt().into(); let value = B::float_div_scalar(numerator, denominator); B::float_mul(grad, value) }); } } match Erf .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroErf::::new(tensor.node.id)) .parents([&tensor]) .stateful() { OpsKind::Tracked(mut prep) => { let state = prep.checkpoint(&tensor); prep.finish(state, B::float_erf(tensor.primitive)) } OpsKind::UnTracked(prep) => prep.finish(B::float_erf(tensor.primitive)), } } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CatStep { nodes: Vec>, // The dimension of each tensor along the dim dimension. // This indicates the number of dimension concatenated for each tensor. dim_sizes: Vec, output: NodeRef, phantom: PhantomData, dim: usize, parents: Vec, } impl Step for CatStep { fn step(self: Box, grads: &mut Gradients, _checkpointer: &mut Checkpointer) { let grad = grads.consume::(&self.output); let ranges_template: Vec<_> = grad.shape().iter().map(|&v| 0..v).collect(); self.nodes .into_iter() .zip(self.dim_sizes) .scan(0, |offset, (node_opt, dim_size)| { let start = *offset; let end = start + dim_size; *offset = end; Some(node_opt.map(|node| (node, start, end))) }) .flatten() .for_each(|(node, start, end)| { let mut ranges = ranges_template.clone(); ranges[self.dim] = start..end; let slices: Vec = ranges .iter() .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1)) .collect(); grads.register::(node.id, B::float_slice(grad.clone(), &slices)); }); } fn node(&self) -> NodeId { self.output.id } fn parents(&self) -> &[Parent] { &self.parents } fn depth(&self) -> usize { self.output.order } } let mut nodes = Vec::with_capacity(tensors.len()); let mut primitives = Vec::with_capacity(tensors.len()); let mut dim_sizes = Vec::with_capacity(tensors.len()); tensors.into_iter().for_each(|tensor| { dim_sizes.push(tensor.primitive.shape()[dim]); nodes.push(tensor.node); primitives.push(tensor.primitive); }); let requirement = Requirement::from_nodes(&nodes); // For simplicity, this operation does not checkpoint anything let cat_computing_property = ComputingProperty::Ambiguous; let checkpointer_builder = CheckpointerBuilder::default(); let output = B::float_cat(primitives, dim); if requirement.is_none() { return AutodiffTensor::from_parents( output, &nodes, requirement, cat_computing_property, ); } let output = AutodiffTensor::from_parents(output, &nodes, requirement, cat_computing_property); let mut parents = Vec::new(); let nodes = nodes .into_iter() .map(|node| node.clone_if_require_grad()) .collect::>(); for node in nodes.iter().flatten() { parents.push(Parent { id: node.id }); } let ops = CatStep::::new(nodes, dim_sizes, output.node.clone(), dim, parents); output.register_step(ops, checkpointer_builder) } fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { match MaxMinDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim); prep.finish((index, shape, dim), tensor) } OpsKind::UnTracked(prep) => prep.finish(B::float_max_dim(tensor.primitive, dim)), } } fn float_max_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { match MaxMinDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish((index.clone(), shape, dim), tensor); (tensor, index) } OpsKind::UnTracked(prep) => { let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish(tensor); (tensor, index) } } } fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { match MaxMinDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim); prep.finish((index, shape, dim), tensor) } OpsKind::UnTracked(prep) => prep.finish(B::float_min_dim(tensor.primitive, dim)), } } fn float_min_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { match MaxMinDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish((index.clone(), shape, dim), tensor); (tensor, index) } OpsKind::UnTracked(prep) => { let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim); let tensor = prep.finish(tensor); (tensor, index) } } } fn float_into_int(tensor: FloatTensor) -> as Backend>::IntTensorPrimitive { B::float_into_int(tensor.primitive) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { #[derive(Debug)] struct PowF; retro_binary!(RetroPowf, B::float_powf); impl Backward for PowF { type State = (NodeId, NodeId, BinaryOpsBroadcast); fn backward( self, ops: Ops, grads: &mut Gradients, checkpointer: &mut Checkpointer, ) { let (lhs_id, rhs_id, broadcast) = ops.state; let lhs: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(lhs_id); let rhs: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(rhs_id); // Both lhs and rhs are needed for both lhs and rhs gradients, but we clone them // the number of times required by the parents specification. let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs)); let [lhs_4lhs, lhs_4rhs] = duplicate(&ops.parents, Some(lhs)); binary::( ops.parents, ops.node, grads, |grad| { //rhs*(lhs.val**(rhs-1))*grad let rhs1 = rhs_4lhs.unwrap(); let rhs2 = rhs1.clone(); let lhs = lhs_4lhs.unwrap(); let tmp = B::float_powf(lhs, B::float_sub_scalar(rhs1, 1.0.into())); let value = B::float_mul(tmp, rhs2); let grad = B::float_mul(grad, value); broadcast.backward_lhs::(grad) }, |grad| { //lhs**rhs * ln(lhs) * grad let rhs = rhs_4rhs.unwrap(); let lhs1 = lhs_4rhs.unwrap(); let lhs2 = lhs1.clone(); let tmp = B::float_powf(lhs1, rhs); let value = B::float_mul(tmp, B::float_log(lhs2)); let grad = B::float_mul(grad, value); broadcast.backward_rhs::(grad) }, ); } } let broadcast = BinaryOpsBroadcast::new::(&lhs.primitive, &rhs.primitive); match PowF .prepare::([lhs.node.clone(), rhs.node.clone()]) .memory_bound() .retro_forward(RetroPowf::::new(lhs.node.id, rhs.node.id)) .parents([&lhs, &rhs]) .stateful() { OpsKind::Tracked(mut prep) => { let lhs_state = prep.checkpoint(&lhs); let rhs_state = prep.checkpoint(&rhs); prep.finish( (lhs_state, rhs_state, broadcast), B::float_powf(lhs.primitive, rhs.primitive), ) } OpsKind::UnTracked(prep) => prep.finish(B::float_powf(lhs.primitive, rhs.primitive)), } } fn float_sign(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Sign; retro_unary!(RetroSign, B::float_sign); impl Backward for Sign { type State = (); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { unary::(ops.parents, ops.node, grads, |grad| // Always return 0 because the derivative of the sign function // does not contribute to gradient updates in a meaningful way. B::float_mul_scalar(grad, 0f32.into())); } } Sign.prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroSign::::new(tensor.node.id)) .parents([&tensor]) .stateless(B::float_sign(tensor.primitive)) } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { // D1: tensor, D2: shape #[derive(Debug)] struct ExpandDim; #[derive(new, Debug)] struct RetroExpand { input_id: NodeId, shape: Shape, _backend: PhantomData, } impl RetroForward for RetroExpand { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let input = states.get_state::(&self.input_id); let out = B::float_expand(input, self.shape.clone()); states.save(out_node, out) } } impl Backward for ExpandDim { type State = (Shape, Shape); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape_in, shape_out) = ops.state; let ndims_in = shape_in.num_dims(); let ndims_out = shape_out.num_dims(); let mut shape_expanded = vec![1; ndims_out]; debug_assert!(ndims_out >= ndims_in); for i in 0..ndims_in { shape_expanded[i + (ndims_out - ndims_in)] = shape_in[i]; } unary::(ops.parents, ops.node, grads, |grad| { let shape_grad = grad.shape(); let mut grad = grad; #[allow(clippy::needless_range_loop)] for i in 0..ndims_out { if shape_expanded[i] == 1 && shape_grad[i] != 1 { grad = B::float_sum_dim(grad, i); } } B::float_reshape(grad, shape_in) }); } } match ExpandDim .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroExpand::::new(tensor.node.id, shape.clone())) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), shape.clone()), B::float_expand(tensor.primitive, shape), ), OpsKind::UnTracked(prep) => prep.finish(B::float_expand(tensor.primitive, shape)), } } fn float_sort(tensor: FloatTensor, dim: usize, descending: bool) -> FloatTensor { match super::sort::SortDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, indices) = B::float_sort_with_indices(tensor.primitive, dim, descending); prep.finish((indices, shape, dim), tensor) } OpsKind::UnTracked(prep) => { prep.finish(B::float_sort(tensor.primitive, dim, descending)) } } } fn float_sort_with_indices( tensor: FloatTensor, dim: usize, descending: bool, ) -> (FloatTensor, IntTensor) { match super::sort::SortDim .prepare::([tensor.node]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => { let shape = tensor.primitive.shape(); let (tensor, indices) = B::float_sort_with_indices(tensor.primitive, dim, descending); let tensor = prep.finish((indices.clone(), shape, dim), tensor); (tensor, indices) } OpsKind::UnTracked(prep) => { let (tensor, indices) = B::float_sort_with_indices(tensor.primitive, dim, descending); let tensor = prep.finish(tensor); (tensor, indices) } } } fn float_argsort(tensor: FloatTensor, dim: usize, descending: bool) -> IntTensor { B::float_argsort(tensor.primitive, dim, descending) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { #[derive(Debug)] struct Repeat; #[derive(new, Debug)] struct RetroRepeat { tensor_id: NodeId, dim: usize, times: usize, _backend: PhantomData, } impl RetroForward for RetroRepeat { fn forward(&self, states: &mut BackwardStates, out_node: NodeId) { let tensor = states.get_state::(&self.tensor_id); let out = B::float_repeat_dim(tensor, self.dim, self.times); states.save(out_node, out) } } impl Backward for Repeat { type State = (usize, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (dim, times) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let mut dims = grad.shape(); let orig_dim_size = dims[dim] / times; if orig_dim_size > 1 { dims[dim] = orig_dim_size; let orig_dims = dims.clone(); dims.insert(dim + 1, times); // shape [..., orig_dim_size, times, ...] let grad = B::float_reshape(grad, dims); let grad = B::float_sum_dim(grad, dim + 1); // sum over repeat times B::float_reshape(grad, orig_dims) } else { B::float_sum_dim(grad, dim) } }); } } match Repeat .prepare::([tensor.node.clone()]) .memory_bound() .retro_forward(RetroRepeat::::new(tensor.node.id, dim, times)) .parents([&tensor]) .stateful() { OpsKind::Tracked(prep) => prep.finish( (dim, times), B::float_repeat_dim(tensor.primitive, dim, times), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_repeat_dim(tensor.primitive, dim, times)) } } } fn float_cast(tensor: FloatTensor, dtype: burn_std::FloatDType) -> FloatTensor { #[derive(Debug)] struct Cast; impl Backward for Cast { type State = FloatDType; fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let dtype = ops.state; unary::(ops.parents, ops.node, grads, |grad| { B::float_cast(grad, dtype) }); } } match Cast .prepare::([tensor.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( tensor.dtype().into(), B::float_cast(tensor.primitive, dtype), ), OpsKind::UnTracked(prep) => prep.finish(B::float_cast(tensor.primitive, dtype)), } } // TODO: Implement float_prod and float_sum // https://github.com/tracel-ai/burn/issues/1458 fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { #[derive(Debug)] struct Unfold; impl Backward for Unfold { type State = (Shape, usize, usize, usize); fn backward( self, ops: Ops, grads: &mut Gradients, _checkpointer: &mut Checkpointer, ) { let (shape_in, dim, size, step) = ops.state; let windows = calculate_unfold_windows(shape_in[dim], size, step); unary::(ops.parents, ops.node, grads, |grad| { let device = B::float_device(&grad); let mut grad_input = B::float_zeros(shape_in.clone(), &device, grad.dtype().into()); if windows == 0 { return grad_input; } let ndims_in = shape_in.num_dims(); let ndims_out = grad.shape().num_dims(); let mut target_shape = shape_in.clone(); target_shape[dim] = size; for window_idx in 0..windows { let mut slices_out = vec![Slice::new(0, None, 1); ndims_out]; let start = window_idx * step; let end = start + size; slices_out[dim] = Slice::new(window_idx as isize, Some((window_idx + 1) as isize), 1); let window_grad = B::float_slice(grad.clone(), &slices_out); let last_axis = ndims_out - 1; let mut permutation: Vec = (0..dim).collect(); permutation.push(last_axis); permutation.extend(dim + 1..last_axis); permutation.push(dim); let window_grad = B::float_permute(window_grad, &permutation); let window_grad = B::float_reshape(window_grad, target_shape.clone()); let mut slices_in = vec![Slice::new(0, None, 1); ndims_in]; slices_in[dim] = Slice::new(start as isize, Some(end as isize), 1); let current = B::float_slice(grad_input.clone(), &slices_in); let updated = B::float_add(current, window_grad); grad_input = B::float_slice_assign(grad_input, &slices_in, updated); } grad_input }); } } match Unfold .prepare::([tensor.node.clone()]) .compute_bound() .stateful() { OpsKind::Tracked(prep) => prep.finish( (tensor.primitive.shape(), dim, size, step), B::float_unfold(tensor.primitive, dim, size, step), ), OpsKind::UnTracked(prep) => { prep.finish(B::float_unfold(tensor.primitive, dim, size, step)) } } } } #[derive(Debug, Clone)] enum BinaryOpsBroadcast { Broadcasted(Shape, Shape), None, } impl BinaryOpsBroadcast { fn new(lhs: &B::FloatTensorPrimitive, rhs: &B::FloatTensorPrimitive) -> Self { let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); let ndims = shape_lhs.num_dims(); for i in 0..ndims { if shape_rhs[i] != shape_lhs[i] { return Self::Broadcasted(shape_lhs, shape_rhs); } } Self::None } fn backward_lhs(&self, grad: B::FloatTensorPrimitive) -> B::FloatTensorPrimitive { match self { BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::(grad, lhs), BinaryOpsBroadcast::None => grad, } } fn backward_rhs(&self, grad: B::FloatTensorPrimitive) -> B::FloatTensorPrimitive { match self { BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::(grad, rhs), BinaryOpsBroadcast::None => grad, } } } ================================================ FILE: crates/burn-autodiff/src/ops/transaction.rs ================================================ use burn_backend::{ Backend, ExecutionError, ops::{TransactionOps, TransactionPrimitive}, }; use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy}; impl TransactionOps for Autodiff { async fn tr_execute( transaction: TransactionPrimitive, ) -> Result { B::tr_execute(TransactionPrimitive::new( transaction .read_floats .into_iter() .map(|t| t.primitive) .collect(), transaction.read_qfloats, transaction.read_ints, transaction.read_bools, )) .await } } ================================================ FILE: crates/burn-autodiff/src/runtime/client.rs ================================================ use crate::{ checkpoint::builder::CheckpointerBuilder, grads::Gradients, graph::StepBoxed, tensor::{AutodiffTensor, NodeRefCount}, }; use burn_backend::Backend; /// Client used to communicate with the autodiff server. pub trait AutodiffClient: Send + Clone { /// Register a new step. fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder); /// Call backpropagation from the given tensor. fn backward(&self, tensor: AutodiffTensor) -> Gradients; } /// Client implementation in used. pub type AutodiffClientImpl = super::graph::GraphMutexClient; ================================================ FILE: crates/burn-autodiff/src/runtime/graph.rs ================================================ use super::{AutodiffClient, server::AutodiffServer}; use crate::{ NodeId, checkpoint::builder::CheckpointerBuilder, grads::Gradients, graph::{Parent, StepBoxed}, runtime::server::NodeCleaner, tensor::{AutodiffTensor, NodeRefCount}, }; use alloc::sync::Arc; use alloc::vec::Vec; use burn_backend::Backend; use hashbrown::{HashMap, HashSet}; #[cfg(feature = "std")] use parking_lot::{Mutex, MutexGuard}; #[cfg(not(feature = "std"))] use spin::{Mutex, MutexGuard}; /// A client for managing multiple graphs using mutex-based synchronization. /// /// The biggest benefit of using this client implementation is that each graph can modify its own /// data without blocking other graphs, which is essential for multi-device training. /// /// # Notes /// /// The [AutodiffServer] fully supports multiple graphs with sharing nodes, however those type of /// graphs will be stored under a single mutex-protected graph by the client, limiting /// parallelisation. #[derive(Clone, new, Debug)] pub struct GraphMutexClient; /// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph. /// /// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent /// dependencies, ensuring proper synchronization and server allocation. /// /// # Notes /// /// Multiple node ids can point to the same graph, where the autodiff graph is stored. #[derive(Default)] pub struct GraphLocator { graphs: HashMap>, /// We keep a mapping of each original node id (graph id) => all nodes that point to that graph. /// This is to ensure that when merging graphs, we correctly move all previous graphs to /// the new merged one. keys: HashMap>, } /// Represents a single computation graph with a mutex-protected server. /// /// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was /// first created. pub(crate) struct Graph { origin: NodeId, state: Mutex, } #[derive(Default)] struct GraphState { server: AutodiffServer, } impl core::fmt::Debug for Graph { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Graph") .field("origin", &self.origin) .finish() } } static STATE: Mutex> = Mutex::new(None); impl GraphMutexClient { /// Retrieves or creates a graph for the given [NodeId] and parent dependencies. /// /// # Parameters /// - `node`: The unique identifier for the stream. /// - `parents`: A slice of parent nodes that the stream depends on. /// /// # Returns /// An `Arc` representing the selected or newly created stream. fn graph(node: NodeId, parents: &[Parent]) -> Arc { let mut state = STATE.lock(); match state.as_mut() { Some(locator) => locator.select(node, parents), None => { let mut locator = GraphLocator::default(); let stream = locator.select(node, parents); *state = Some(locator); stream } } } } impl AutodiffClient for GraphMutexClient { fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) { let node_id = *node_id_ref; let graph = GraphMutexClient::graph(node_id, step.parents()); let mut state = graph.state.lock(); state.server.register(node_id_ref, step, actions); } fn backward(&self, root: AutodiffTensor) -> Gradients { let node_id = root.node.id; let graph = GraphMutexClient::graph(root.node.id, &[]); let grads = Gradients::new::(root.node, root.primitive); let grads = { let mut state = graph.state.lock(); state.server.backward::(grads, node_id) }; // lock released GraphCleaner::cleanup_orphaned_entries(); grads } } struct GraphCleaner<'a> { guard: MutexGuard<'a, Option>, } impl<'a> GraphCleaner<'a> { fn cleanup_orphaned_entries() { let graphs = { // Get the available graphs and release the lock match STATE.lock().as_ref() { Some(state) => state.graphs.clone(), None => return, } }; let mut should_remove = Vec::new(); for graph in graphs.values() { { let mut guard = graph.state.lock(); // Double safety: in case it was marked as no longer useful, but other // nodes are still relevant, we only check which nodes can safely be removed. if !guard.server.maybe_useful() { guard .server .free_unused_roots(|node| should_remove.push(*node)); } } } if !should_remove.is_empty() { let mut state = STATE.lock(); if let Some(state) = state.as_mut() { for node in should_remove { state.remove_entry(&node); } } } } } impl<'a> NodeCleaner for GraphCleaner<'a> { fn init() -> Self { let guard = STATE.lock(); Self { guard } } fn clean(&mut self, node: &NodeId) { if let Some(state) = self.guard.as_mut() { state.remove_entry(node); } } } impl GraphLocator { /// Selects a single graph for the given [NodeId], considering parent dependencies. /// /// If multiple graphs are found, they are merged into a single one. /// /// # Parameters /// - `node`: The node ID of the graph to select. /// - `parents`: A slice of parent nodes that the graph depends on. /// /// # Returns /// /// An `Arc` representing the selected or merged graph. pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc { match self.analyse(node, parents) { GraphAnalysis::NoCollision(graph) => { if graph.origin != node { self.graphs.insert(node, graph.clone()); self.register_key(graph.origin, node); } graph } GraphAnalysis::Collisions(graphs) => self.merge(node, graphs), } } /// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`. fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis { // If no parents, there is no collision, therefore a single graph is ok. if parents.is_empty() { let graph = match self.graphs.get(&node) { Some(val) => val.clone(), None => self.new_graph(node), }; return GraphAnalysis::NoCollision(graph); }; // We collect all graphs of parents and of the current node based on their origin node id. let mut graphs = HashMap::>::new(); if let Some(val) = self.graphs.get(&node) { graphs.insert(val.origin, val.clone()); } for parent in parents { match self.graphs.get(&parent.id) { Some(graph) => graphs.insert(graph.origin, graph.clone()), None => continue, }; } if graphs.is_empty() { return match self.graphs.get(&node) { Some(old) => GraphAnalysis::NoCollision(old.clone()), None => GraphAnalysis::NoCollision(self.new_graph(node)), }; } if graphs.len() == 1 { return GraphAnalysis::NoCollision(graphs.drain().next().unwrap().1); } GraphAnalysis::Collisions(graphs) } /// Merges multiple graphs associated with a node into a single graph. fn merge(&mut self, node: NodeId, mut graphs: HashMap>) -> Arc { let mut graphs = graphs.drain().map(|g| g.1); let main = graphs.next().expect("At least one graph"); self.register_key(main.origin, node); let mut state = main.state.lock(); for graph in graphs { self.merge_two(&mut state, &main, graph); } self.graphs.insert(main.origin, main.clone()); self.graphs.insert(node, main.clone()); core::mem::drop(state); main } /// Registers a key for a given origin node. fn register_key(&mut self, origin: NodeId, key: NodeId) { if !self.keys.contains_key(&origin) { // Ensure an entry exists for this origin self.keys.insert(origin, HashSet::new()); } if origin != key { // Register this node to point to the origin graph self.keys.get_mut(&origin).unwrap().insert(key); } } /// Merges two graphs by combining their states and updating graph mappings. fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc, merged: Arc) { let mut locked = merged.state.lock(); let mut state_old = GraphState::default(); core::mem::swap(&mut state_old, &mut locked); main_state.server.extend(state_old.server); // Re-map merged origin to the main graph self.graphs.insert(merged.origin, main.clone()); // Move all keys (node IDs) from the merged graph to the main graph if let Some(locator_keys) = self.keys.remove(&merged.origin) { for k in locator_keys.iter() { self.graphs.insert(*k, main.clone()); } let locator_keys_main = self .keys .get_mut(&main.origin) .expect("Should be init before the merge."); locator_keys_main.extend(locator_keys); } } /// Creates a new graph for a given node. fn new_graph(&mut self, origin: NodeId) -> Arc { let graph = Arc::new(Graph { origin, state: Mutex::new(GraphState::default()), }); self.graphs.insert(origin, graph.clone()); self.keys.insert(origin, HashSet::new()); graph } fn remove_entry(&mut self, node: &NodeId) { if let Some(graph) = self.graphs.remove(node) { let mut remove = false; if let Some(entry) = self.keys.get_mut(&graph.origin) { entry.remove(node); if entry.is_empty() { remove = true; } } if remove { self.keys.remove(&graph.origin); } } } } /// Represents the analysis result of graph operations for a given node and its parents. #[derive(Debug)] enum GraphAnalysis { /// No collision detected, contains the graph associated with the node. NoCollision(Arc), /// Collision detected, contains a map of node IDs to their associated graphs. Collisions(HashMap>), } ================================================ FILE: crates/burn-autodiff/src/runtime/memory_management.rs ================================================ use crate::{ NodeId, collections::{HashMap, HashSet}, graph::Parent, tensor::NodeRefCount, }; use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec}; use core::mem; #[derive(Default, Debug)] pub struct GraphMemoryManagement { nodes: HashMap>, leaves: HashSet, statuses: HashMap, } #[derive(Debug, Clone, PartialEq)] enum NodeMemoryStatus { Useful, Unavailable, Unknown, } impl GraphMemoryManagement { pub fn extend(&mut self, other: Self) { self.nodes.extend(other.nodes); self.leaves.extend(other.leaves); self.statuses.extend(other.statuses); } /// Register a new node with its parent. pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) { let node_id = *node.as_ref(); for parent in parents.iter() { self.leaves.remove(&parent.id); } self.leaves.insert(node_id); self.nodes .insert(node, parents.iter().map(|p| p.id).collect()); } /// Free the node from the state. pub fn consume_node(&mut self, node_id: NodeId) { if !self.is_referenced(node_id) { self.leaves.remove(&node_id); self.nodes.remove(&node_id); } } /// Free all nodes whose backward call has become impossible /// /// This function goes into three steps, which must happen for all leaves /// before going into the next step. Then it deletes what can be safely deleted pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) { let leaves = self.leaves.clone(); let mut new_leaves = HashSet::new(); let mut deletables = Vec::new(); // When consuming nodes with a backward pass, some other backward passes become // unavailable because some of their parents have been consumed. They are // identified here. for leaf in leaves.clone() { self.unavailable_propagation(leaf); } // Among the available nodes that remain, some may be useless if no // available node with a tensor reference exist in their descendance. // But some may seem useless from some leaf but be useful from another one, // hence the need to iterate on all leaves. self.useful_propagation(leaves.clone()); // New leaves are the roots of a useful backward sub-tree. // Deletables are everything not marked as useful. for leaf in leaves { self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables); } // Replace leaves by the new ones and delete everything not useful anymore mem::swap(&mut self.leaves, &mut new_leaves); self.clear_unused_roots(&mut deletables); self.statuses.clear(); for node_to_delete in deletables { self.nodes.remove(&node_to_delete); on_free_graph(&node_to_delete) } } pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) { let mut deletables = Vec::new(); self.clear_unused_roots(&mut deletables); for node_id in deletables { self.nodes.remove(&node_id); on_free_graph(&node_id); } } fn clear_unused_roots(&self, to_delete: &mut Vec) { for (id, parents) in self.nodes.iter() { let is_useful = matches!( self.statuses.get(id.as_ref()), Some(NodeMemoryStatus::Useful) ); // Check if parents are either empty or absent from self.nodes let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p)); if !is_useful && Arc::strong_count(id) == 1 && parents_absent { to_delete.push(*id.as_ref()) } } } fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus { // If already visited if let Some(status) = self.statuses.get(&node_id) { return status.clone(); } match self.nodes.get(&node_id).cloned() { // If node exists and any of its parents is unavailable, it is unavailable as well // If node exists but the parents vec is empty, it is a tensor that never had parents; // the status remains unknown Some(parents) => { let mut node_status = NodeMemoryStatus::Unknown; for parent in parents { let parent_status = self.unavailable_propagation(parent); if let NodeMemoryStatus::Unavailable = parent_status { node_status = NodeMemoryStatus::Unavailable; } } self.statuses.insert(node_id, node_status.clone()); node_status } // If node does not exist, it was // deleted, so this and all its descendants are unavailable None => { self.statuses.insert(node_id, NodeMemoryStatus::Unavailable); NodeMemoryStatus::Unavailable } } } fn useful_propagation(&mut self, leaves: HashSet) { // Accumulate visited nodes let mut explored = HashSet::new(); let mut tagged_useful = HashSet::new(); // Queue of nodes to visit let mut to_tag_useful = PopNodeSet::default(); let mut to_explore = PopNodeSet::new(leaves); // Utility function to iterate over a node's parents let parents = |node_id| { self.nodes .get(&node_id) .cloned() .unwrap_or_default() .into_iter() }; loop { // Pop a node id, greedily looking at tag_useful ones first let (node_id, status) = match to_tag_useful.pop() { Some(node_id) => (node_id, NodeMemoryStatus::Useful), None => match to_explore.pop() { Some(node_id) => { let node_status = self .statuses .get(&node_id) .expect("All nodes should have received a status during unavailable_propagation") .to_owned(); if let NodeMemoryStatus::Unknown = node_status { match self.is_referenced(node_id) { true => (node_id, NodeMemoryStatus::Useful), false => (node_id, NodeMemoryStatus::Unknown), } } else { (node_id, node_status) } } None => { // There are no nodes in the queues anymore break; } }, }; match status { NodeMemoryStatus::Useful => { tagged_useful.insert(node_id); for parent in parents(node_id) { // The node can be explored, as long as it's not already tagged useful if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) { to_tag_useful.insert(parent); } } } _ => { explored.insert(node_id); for parent in parents(node_id) { if !(explored.contains(&parent) || to_explore.contains(&parent)) { to_explore.insert(parent); } } } } self.statuses.insert(node_id, status); } } fn identify_leaves_and_deletables( &self, leaf_id: NodeId, new_leaves: &mut HashSet, to_delete: &mut Vec, ) { let mut visited = HashSet::new(); let mut to_visit = vec![leaf_id]; while let Some(node_id) = to_visit.pop() { visited.insert(node_id); match self .statuses .get(&node_id) .expect("Node should have status") { NodeMemoryStatus::Useful => { new_leaves.insert(node_id); } _ => { to_delete.push(node_id); for parent in self .nodes .get(&node_id) .cloned() .unwrap_or_default() .into_iter() { if !visited.contains(&parent) { to_visit.push(parent); } } } }; } } fn is_referenced(&self, node_id: NodeId) -> bool { match self.nodes.get_key_value(&node_id) { Some((key, _value)) => Arc::strong_count(key) > 1, None => panic!("Node should be in the nodes map"), } } pub(crate) fn maybe_useful(&self) -> bool { self.nodes.keys().any(|node| Arc::strong_count(node) > 1) } } /// Wrapper over hash set for fast popping of any node #[derive(new, Default)] struct PopNodeSet { hash_set: HashSet, } impl PopNodeSet { #[inline(always)] fn pop(&mut self) -> Option { self.hash_set .iter() .next() .copied() .and_then(|node_id| self.hash_set.take(&node_id)) } #[inline(always)] fn contains(&self, node_id: &NodeId) -> bool { self.hash_set.contains(node_id) } #[inline(always)] fn insert(&mut self, node_id: NodeId) { self.hash_set.insert(node_id); } } ================================================ FILE: crates/burn-autodiff/src/runtime/mod.rs ================================================ mod client; mod memory_management; mod server; pub mod graph; pub use client::*; ================================================ FILE: crates/burn-autodiff/src/runtime/server.rs ================================================ use super::memory_management::GraphMemoryManagement; use crate::{ NodeId, checkpoint::{ base::{Checkpointer, NodeTree}, builder::CheckpointerBuilder, }, collections::HashMap, grads::Gradients, graph::{StepBoxed, traversal::BreadthFirstSearch}, tensor::NodeRefCount, }; use alloc::vec::Vec; #[derive(Default)] pub struct AutodiffServer { steps: HashMap, actions_builder: HashMap, memory_management: GraphMemoryManagement, } /// Defines how nodes are clean. pub trait NodeCleaner { /// Initialize a new cleaner. fn init() -> Self; /// Cleans a single [node](NodeId). fn clean(&mut self, node: &NodeId); } impl AutodiffServer { pub fn extend(&mut self, other: AutodiffServer) { self.steps.extend(other.steps); self.actions_builder.extend(other.actions_builder); self.memory_management.extend(other.memory_management); } pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) { let parents = step.parents(); let node_id = *rc.as_ref(); self.memory_management.register(rc, parents); self.steps.insert(node_id, step); self.actions_builder.insert(node_id, actions); } pub fn backward(&mut self, grads: Gradients, node_id: NodeId) -> Gradients { let step = self.steps.remove(&node_id).expect( "Node should have a step registered, did you forget to call \ `Tensor::register_grad` on the tensor where you need gradients?", ); let builder = self.actions_builder.remove(&node_id).unwrap(); let mut consumed = Vec::new(); let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed); let gradients = Self::execute_steps(tape, grads, checkpointer); // Cleanup let mut cleaner = NC::init(); self.memory_management .free_unavailable_nodes(|node_id: &NodeId| { self.steps.remove(node_id); self.actions_builder.remove(node_id); NC::clean(&mut cleaner, node_id); }); for node_id in consumed { cleaner.clean(&node_id) } gradients } pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) { self.memory_management.free_unused_roots(|node_id| { self.steps.remove(node_id); self.actions_builder.remove(node_id); on_free_graph(node_id); }); } fn build_tape( &mut self, node: NodeId, node_step: StepBoxed, mut builder: CheckpointerBuilder, consumed: &mut Vec, ) -> (Vec>, Checkpointer) { let mut tape = (0..node_step.depth()) .map(|_| Vec::with_capacity(1)) .collect::>(); let mut tree = HashMap::default(); BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| { self.memory_management.consume_node(id); // Clean up consumed node consumed.push(id); let depth = step.depth(); if depth == 0 { return; } if let Some(steps) = tape.get_mut(depth - 1) { let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id); tree.insert(id, parents.collect()); steps.push(step); } if let Some(node_builder) = self.actions_builder.remove(&id) { builder.extend(node_builder); } }); let checkpointer = builder.build(NodeTree::new(tree)); (tape, checkpointer) } fn execute_steps( tape: Vec>, mut grads: Gradients, mut checkpointer: Checkpointer, ) -> Gradients { tape.into_iter().rev().for_each(|steps| { steps .into_iter() .for_each(|step| step.step(&mut grads, &mut checkpointer)) }); // For checkpointing tests #[cfg(feature = "export_tests")] assert!(checkpointer.is_empty()); grads } pub(crate) fn maybe_useful(&self) -> bool { self.memory_management.maybe_useful() } } ================================================ FILE: crates/burn-autodiff/src/tensor.rs ================================================ use crate::{ checkpoint::{base::Checkpointer, builder::CheckpointerBuilder}, grads::Gradients, graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step}, runtime::{AutodiffClient, AutodiffClientImpl}, }; use alloc::{boxed::Box, sync::Arc, vec}; use burn_backend::{Backend, TensorMetadata}; #[derive(Debug, Clone)] pub struct AutodiffTensor { pub primitive: B::FloatTensorPrimitive, pub node: NodeRef, pub rc: NodeRefCount, } impl TensorMetadata for AutodiffTensor { fn dtype(&self) -> burn_std::DType { self.primitive.dtype() } fn shape(&self) -> burn_std::Shape { self.primitive.shape() } fn rank(&self) -> usize { self.primitive.rank() } } pub type NodeRefCount = Arc; #[derive(new, Debug)] pub(crate) struct RootStep { node: NodeRef, } impl Step for RootStep { fn step(self: Box, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) { // Nothing to do } fn node(&self) -> NodeId { self.node.id } fn parents(&self) -> &[Parent] { &self.node.parents } fn depth(&self) -> usize { self.node.order } } impl AutodiffTensor { /// Create a new leaf tensor. pub fn new(primitive: B::FloatTensorPrimitive) -> Self { let id = NodeId::new(); let node: NodeRef = Node::new( vec![], 0, id, Requirement::None, ComputingProperty::Ambiguous, AutodiffClientImpl::new(), ) .into(); Self { rc: Arc::new(node.id), primitive, node: node.clone(), } } pub fn is_tracked(&self) -> bool { !self.node.requirement.is_none() } /// Mark the tensor as requiring gradients. /// /// # Panics /// /// It panics if the tensor is not a leaf. pub fn require_grad(mut self) -> Self { match self.node.requirement { Requirement::Grad => self, Requirement::GradInBackward => { panic!("Can't convert a non leaf tensor into a tracked tensor") } Requirement::None => { self.node = Node::new( vec![], 0, self.node.id, Requirement::Grad, self.node.properties.clone(), self.node.client.clone(), ) .into(); let step = RootStep::new(self.node.clone()); self.register_step(step, CheckpointerBuilder::default()) } } } /// Create a tensor from parent infos. pub fn from_parents( primitive: B::FloatTensorPrimitive, parent_nodes: &[NodeRef], requirement: Requirement, computing_properties: ComputingProperty, ) -> Self { let order = parent_nodes .iter() .map(|node| node.order) .reduce(usize::max) .unwrap_or(0) + 1; let client = parent_nodes .first() .map(|node| node.client.clone()) .unwrap_or_else(AutodiffClientImpl::new); let node: NodeRef = Node::new( parent_nodes .iter() .filter_map(|node| node.clone_if_require_grad()) .map(|node| Parent::new(node.id)) .collect(), order, NodeId::new(), requirement, computing_properties, client, ) .into(); Self { rc: Arc::new(node.id), primitive, node, } } /// Register a step into a graph for that tensor. /// /// # Warning /// /// This should be called only once per tensor. pub fn register_step( self, step_that_created_the_tensor: S, actions: CheckpointerBuilder, ) -> Self { self.node.client.register( self.rc.clone(), Box::new(step_that_created_the_tensor), actions, ); self } pub fn into_primitive(self) -> B::FloatTensorPrimitive { self.primitive } pub fn backward(self) -> Gradients { let client = self.node.client.clone(); AutodiffClient::backward::(&client, self) } pub fn grad(&self, grads: &Gradients) -> Option { grads.get::(self) } pub fn grad_remove(&self, grads: &mut Gradients) -> Option { grads.remove::(self) } pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) { grads.remove::(self); grads.register::(self.node.id, grad); } } ================================================ FILE: crates/burn-autodiff/src/utils.rs ================================================ use alloc::vec::Vec; use crate::graph::NodeRef; /// Duplicate the given object for each node that requires gradients. /// /// # Notes /// /// This is useful since you don't have to keep N cloned references alive event if just 1 node /// will be updated. /// /// If the object is a tensor and if one reference exists, it can be updated inplace. pub fn duplicate( nodes: &[Option; N], obj: Option, ) -> [Option; N] { nodes .iter() .map(|node| match node { Some(_) => obj.clone(), None => None, }) .collect::>() .try_into() .unwrap() } ================================================ FILE: crates/burn-backend/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Core backend interfaces and data structures for executing tensor operations in Burn." documentation = "https://docs.rs/burn-backend" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-backend" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend" version.workspace = true [lints] workspace = true [features] default = ["std"] doc = ["default"] std = ["rand/std", "num-traits/std", "burn-std/std", "cubecl?/std"] tracing = ["burn-std/tracing", "cubecl/tracing"] cubecl = ["dep:cubecl", "burn-std/cubecl"] cubecl-cuda = ["cubecl", "cubecl/cuda"] cubecl-hip = ["cubecl", "cubecl/hip"] cubecl-wgpu = ["cubecl", "cubecl/wgpu"] cubecl-cpu = ["cubecl", "cubecl/cpu"] [dependencies] burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } cubecl = { workspace = true, optional = true, default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } derive-new = { workspace = true } enumset = { workspace = true } hashbrown = { workspace = true } num-traits = { workspace = true } rand = { workspace = true, default-features = false } rand_distr = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } [dev-dependencies] rand = { workspace = true, features = ["thread_rng"] } paste = { workspace = true } serde_json = { workspace = true, features = ["alloc"]} ================================================ FILE: crates/burn-backend/README.md ================================================ # Burn Backend This crate includes the core backend interfaces and data structures for executing tensor operations in Burn. ================================================ FILE: crates/burn-backend/src/backend/base.rs ================================================ use burn_std::DType; pub use burn_std::backtrace::BackTrace; use alloc::string::String; use enumset::{EnumSet, EnumSetType}; use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::element::Element; use crate::ops::*; use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; use crate::{QTensorPrimitive, TensorData, TensorMetadata}; use super::DeviceOps; /// This trait defines all types and functions needed for a backend to be used with burn. /// /// ## Design /// /// This trait aims to be as unopinionated as possible and allows implementations to define /// their own types and patterns. Therefore, there are few pre-defined abstractions baked /// into this trait. /// /// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`. /// Since we minimize assumptions, we chose to separate these types, as they are used in /// different contexts. However, some backends may have a generic tensor type that is used /// for all data types. /// /// ### Eager Mode /// /// Because burn supports dynamic graphs, the backend trait is designed around kernel /// implementations that can be called without any mutable context or graph. This may not be /// ideal for backends that want to configure their computational graphs and execute them /// multiple times. /// /// To implement this kind of backend, channels could be used to communicate with a backend /// server thread to build the computation graphs and re-execute the ones that are repeated, /// with some form of cache. Once that pattern has matured, a graph mode backend trait could /// be extracted from it, allowing other backends of the same kind to be quickly integrated /// with burn. This pattern could also be used to create an operation fusion trait, which /// allows backends to define what kind of graph structures can be fused into one operation. /// /// ### Multi-Threaded /// /// Backend tensor types are all `Clone` + `Send`, which allows them to be safely /// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc), /// which avoids copying the tensor's buffer. Note that it is still possible to mutate and /// reuse tensors' buffer without locking; see the next section on the Mutable API. /// /// ### Mutable API /// /// There is no mutable or inplace operation API to implement, but that does not mean that /// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and /// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable /// reference to their tensor buffer data structure if the tensor is not shared. In that case, /// backends can dispatch to their owned inplace operations for better performance. /// /// ## Documentation /// /// Most of the documentation for each function can be found on the user API #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// struct in the `burn-tensor` crate. /// For modules, public functions are often created, which can be used by `burn-core` modules. pub trait Backend: FloatTensorOps + BoolTensorOps + IntTensorOps + ModuleOps + ActivationOps + QTensorOps + TransactionOps + Clone + Default + Sized + Send + Sync + core::fmt::Debug + 'static { /// Device type. type Device: DeviceOps; /// Tensor primitive to be used for all float operations. type FloatTensorPrimitive: TensorMetadata + 'static; /// Default float element type. type FloatElem: Element; /// Tensor primitive to be used for all int operations. type IntTensorPrimitive: TensorMetadata + 'static; /// Int element type. type IntElem: Element; /// Tensor primitive to be used for all bool operations. type BoolTensorPrimitive: TensorMetadata + 'static; /// Tensor primitive to be used for all bool operations. type BoolElem: Element; /// Tensor primitive to be used for all quantized operations. type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static; /// If autodiff is enabled. fn ad_enabled(_device: &Self::Device) -> bool { false } /// Sets the current allocation mode to persistent. #[allow(unused_variables)] fn memory_persistent_allocations< Output: Send, Input: Send, Func: Fn(Input) -> Output + Send, >( device: &Self::Device, input: Input, func: Func, ) -> Output { func(input) } /// Manually triggers a memory cleanup on the given device. #[allow(unused_variables)] fn memory_cleanup(device: &Self::Device) {} /// Name of the backend. fn name(device: &Self::Device) -> String; /// Seeds the backend on the specified device. /// /// There is no guarantee that only the specified device will be seeded, but it is guaranteed /// that at least the specified device will be seeded. /// /// In all cases, this should ensure deterministic execution for a single-threaded program. fn seed(device: &Self::Device, seed: u64); /// Sync the backend, ensure that all computation are finished. fn sync(_device: &Self::Device) -> Result<(), ExecutionError> { Ok(()) } /// Marks the given data as being used as a staging buffer for transfer between CPU and /// accelerators like GPUs. /// /// The given data might be transferred to pinned memory or another format to improve data transfer /// speed. fn staging<'a, Iter>(_data: Iter, _device: &Self::Device) where Iter: Iterator, { } /// Whether the type is fully supported by the specified device for general operations. /// /// A type is considered supported if it can be used for the full suite of tensor /// operations, including storage, conversion, and basic arithmetic. /// /// Returning `false` does not necessarily mean the device cannot handle the type at all. /// For instance, a device might support a type only for specialized hardware /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such /// types should return `false` here as they are not globally supported. fn supports_dtype(device: &Self::Device, dtype: DType) -> bool { Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general()) } /// Returns the [DTypeUsageSet] for the given [DType] on the specified device. fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet; } /// An error that can happen when syncing a device. #[derive(Error, Serialize, Deserialize)] pub enum ExecutionError { /// A generic error happened during execution. /// /// The backtrace and context information should be included in the reason string. #[error("An error happened during execution\nCaused by:\n {reason}")] WithContext { /// The reason of the error. reason: String, }, /// A generic error happened during execution thrown in the Burn project. /// /// The full context isn't captured by the string alone. #[error("An error happened during execution\nCaused by:\n {reason}")] Generic { /// The reason of the error. reason: String, /// The backtrace. #[serde(skip)] backtrace: BackTrace, }, } impl core::fmt::Debug for ExecutionError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!("{self}")) } } /// Trait that allows a backend to support autodiff. pub trait AutodiffBackend: Backend { /// The inner backend type. type InnerBackend: Backend; /// Gradients type. type Gradients: Send; /// Backward pass. /// /// # Arguments /// /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed. /// /// # Returns /// /// The gradients. fn backward(tensor: FloatTensor) -> Self::Gradients; /// Returns the gradients of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to extract the gradients from. /// /// # Returns /// /// An optional tensor containing the gradient. fn grad( tensor: &FloatTensor, grads: &Self::Gradients, ) -> Option>; /// Pops the gradients of a tensor and returns them. /// /// # Arguments /// /// * `tensor` - The tensor to pop the gradients from. /// * `grads` - The gradients. /// /// # Returns /// /// An optional tensor containing the given gradients. fn grad_remove( tensor: &FloatTensor, grads: &mut Self::Gradients, ) -> Option>; /// Replace the gradients of a tensor with the one provided. /// /// If no gradient existed for the provided tensor, register it. /// /// # Arguments /// /// * `tensor` - The tensor to pop the gradients from. /// * `grads` - The gradients. /// * `grad` - The updated grad tensor. fn grad_replace( tensor: &FloatTensor, grads: &mut Self::Gradients, grad: FloatTensor, ); /// Returns the tensor with inner backend type. /// /// # Arguments /// /// * `tensor` - The tensor to get the inner backend tensor for. /// /// # Returns /// /// The inner backend tensor. fn inner(tensor: FloatTensor) -> FloatTensor; /// Returns the tensor with inner backend type. /// /// # Arguments /// /// * `tensor` - The tensor to get the inner backend tensor for. /// /// # Returns /// /// The inner backend tensor. fn int_inner(tensor: IntTensor) -> IntTensor; /// Returns the tensor with inner backend type. /// /// # Arguments /// /// * `tensor` - The tensor to get the inner backend tensor for. /// /// # Returns /// /// The inner backend tensor. fn bool_inner(tensor: BoolTensor) -> BoolTensor; /// Returns the tensor with inner backend type. /// /// # Arguments /// /// * `tensor` - The tensor to get the inner backend tensor for. /// /// # Returns /// /// The inner backend tensor. fn q_inner(tensor: QuantizedTensor) -> QuantizedTensor; /// Converts the inner backend tensor to the autodiff backend tensor. /// /// # Arguments /// /// * `tensor` - The inner backend tensor to convert. /// /// /// # Returns /// /// The autodiff backend tensor. fn from_inner(tensor: FloatTensor) -> FloatTensor; /// Converts the inner backend tensor to the autodiff backend tensor. /// /// # Arguments /// /// * `tensor` - The inner backend tensor to convert. /// /// /// # Returns /// /// The autodiff backend tensor. fn int_from_inner(tensor: IntTensor) -> IntTensor; /// Converts the inner backend tensor to the autodiff backend tensor. /// /// # Arguments /// /// * `tensor` - The inner backend tensor to convert. /// /// /// # Returns /// /// The autodiff backend tensor. fn bool_from_inner(tensor: BoolTensor) -> BoolTensor; /// Converts the inner backend tensor to the autodiff backend tensor. /// /// # Arguments /// /// * `tensor` - The inner backend tensor to convert. /// /// /// # Returns /// /// The autodiff backend tensor. fn q_from_inner(tensor: QuantizedTensor) -> QuantizedTensor; } /// Describes how a data type can be used on a given device. /// /// A data type may be supported for different classes of operations. Not all /// data types that appear in hardware or kernel implementations are suitable /// for general-purpose tensor operations. #[derive(Debug, EnumSetType)] pub enum DTypeUsage { /// The type can be stored in device memory and converted to and from /// other supported data types. Storage, /// The type supports general-purpose arithmetic and common tensor /// operations (e.g. elementwise ops, reductions, etc.). Arithmetic, /// The type is supported by hardware-accelerated execution paths. /// /// This typically indicates support for accelerator-backed compute units (e.g., tensor /// cores executing MMA instructions) for high-performance operations such as matrix /// multiplication and operations that lower to it. /// /// # Notes /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and /// [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations /// *and* accelerated paths. /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not /// suitable for general-purpose tensor operations and may only be used /// in specific accelerated operations. /// /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which /// operations are accelerated or which accelerator features are available. Accelerated, } /// A set of [DTypeUsage] representing the total capabilities of a data type on a device. pub type DTypeUsageSet = EnumSet; impl DTypeUsage { /// Returns the usage set required for general-purpose tensor support. pub fn general() -> DTypeUsageSet { DTypeUsage::Storage | DTypeUsage::Arithmetic } } ================================================ FILE: crates/burn-backend/src/backend/device.rs ================================================ pub use burn_std::device::*; /// Device trait for all burn backend devices. pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + Device { /// Returns the [device id](DeviceId). fn id(&self) -> DeviceId { self.to_id() } /// Returns the inner device without autodiff enabled. /// /// For most devices this is a no-op that returns `self`. For autodiff-enabled /// devices, this returns the underlying inner device. fn inner(&self) -> &Self { self } } ================================================ FILE: crates/burn-backend/src/backend/mod.rs ================================================ mod base; mod device; mod primitive; pub use base::*; pub use device::*; pub use primitive::*; /// Backend operations on tensors. pub mod ops; ================================================ FILE: crates/burn-backend/src/backend/ops/activation.rs ================================================ use crate::tensor::FloatTensor; use crate::{Backend, Scalar, TensorMetadata}; use core::f64::consts::SQRT_2; /// Activation function operations. /// /// This trait let backend implementations override activation functions for better performance. pub trait ActivationOps { /// Applies the LeakyReLU activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with. /// /// # Returns /// /// The output tensor. fn leaky_relu(tensor: FloatTensor, negative_slope: Scalar) -> FloatTensor { let mask = B::float_lower_elem(tensor.clone(), 0f32.into()); let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope); // Update the tensor where the values are `< 0` by `tensor * negative_slope`. B::float_mask_where(tensor, mask, scaled_tensor) } /// Applies the ReLU activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The output tensor. fn relu(tensor: FloatTensor) -> FloatTensor { let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into()); B::float_mask_fill(tensor, mask, 0f32.into()) } /// Applies the ReLU activation function backward. /// /// # Arguments /// /// * `output` - The output tensor. /// /// # Returns /// /// The gradient. fn relu_backward(output: FloatTensor, grad: FloatTensor) -> FloatTensor { let mask = B::float_lower_equal_elem(output, 0f32.into()); B::float_mask_fill(grad, mask, 0.into()) } /// Applies the Gelu activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The output tensor. fn gelu(tensor: FloatTensor) -> FloatTensor { let x = B::float_div_scalar(tensor.clone(), SQRT_2.into()); let x = B::float_erf(x); let x = B::float_add_scalar(x, 1f32.into()); let x = B::float_mul(tensor, x); B::float_div_scalar(x, 2f32.into()) } /// Applies the PReLu activation function. /// # Arguments /// * `tensor` - The input tensor /// * `alpha` - The weight tensor fn prelu(tensor: FloatTensor, alpha: FloatTensor) -> FloatTensor { let mask = B::float_lower_elem(tensor.clone(), 0f32.into()); let scaled_tensor = B::float_mul(tensor.clone(), alpha); B::float_mask_where(tensor, mask, scaled_tensor) } /// Applies the Gelu activation function backward. /// /// # Arguments /// /// * `x` - The tensor. /// * `grad` - The gradient. /// /// # Returns /// /// The output tensor. fn gelu_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor { // Derivative of the approximate gelu implementation based on tanh. let constant_1 = 0.0356774; let constant_2 = 0.797885; let constant_3 = 0.0535161; let constant_4 = 0.398942; let x3 = B::float_powi_scalar(x.clone(), 3.into()); let c1 = B::float_mul_scalar(x3.clone(), constant_1.into()); let c2 = B::float_mul_scalar(x.clone(), constant_2.into()); let c3 = B::float_mul_scalar(x3, constant_3.into()); let c4 = B::float_mul_scalar(x, constant_4.into()); let inner1 = B::float_add(c1, c2); let inner2 = B::float_add(c3, c4); let tanh = B::float_tanh(inner1); let sech = B::float_powi_scalar(tanh.clone(), 2.into()); let sech = B::float_neg(sech); let sech = B::float_add_scalar(sech, 1.into()); let y1 = B::float_mul_scalar(tanh, 0.5.into()); let y2 = B::float_mul(inner2, sech); let y2 = B::float_add_scalar(y2, 0.5.into()); let y = B::float_add(y1, y2); B::float_mul(y, grad) } /// Applies the Sigmoid activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The output tensor. fn sigmoid(tensor: FloatTensor) -> FloatTensor { let dtype = tensor.dtype(); let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32); let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar( B::float_exp(B::float_neg(tensor_full)), 1.0.into(), )))); B::float_cast(tensor_tmp, dtype.into()) } /// Applies the Sigmoid activation function backward. /// /// # Arguments /// /// * `output` - The output tensor of the sigmoid function. /// * `grad` - The gradient. /// /// # Returns /// /// The output tensor. fn sigmoid_backward(output: FloatTensor, grad: FloatTensor) -> FloatTensor { let value = B::float_mul( output.clone(), B::float_add_scalar(B::float_neg(output), 1.0.into()), ); B::float_mul(value, grad) } /// Applies the hard Sigmoid activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `alpha` - The alpha value that the tensor is multiplied with. /// * `beta` - The beta value that is added to the tensor /// /// # Returns /// /// The output tensor. fn hard_sigmoid(tensor: FloatTensor, alpha: Scalar, beta: Scalar) -> FloatTensor { let dtype = tensor.dtype(); let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32); let tensor_tmp = B::float_clamp( B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta), 0.0.into(), 1.0.into(), ); B::float_cast(tensor_tmp, dtype.into()) } /// Applies the LogSigmoid activation function. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The output tensor. fn log_sigmoid(tensor: FloatTensor) -> FloatTensor { // To avoid overflow, we use the log-sum-exp trick. // // ```ignore // log(sigmoid(x)) = log(1/(1 + exp(-x))) // = log(1) - log(1 + exp(-x)) // = -log(1 + exp(-x)) // = -log(exp(0) + exp(-x)) // ``` // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the // following equivalence: // ```ignore // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))) // ``` // // This extends the range of values for which we obtain accurate results. // max(-x, 0) let tensor_neg = B::float_neg(tensor); let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into()); let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into()); let max_elem_neg = B::float_neg(max_elem.clone()); // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)) let z = B::float_add( B::float_exp(max_elem_neg.clone()), B::float_exp(B::float_sub(tensor_neg, max_elem.clone())), ); // -max(-x, 0) - log(-z) B::float_sub(max_elem_neg, B::float_log(z)) } /// Applies the LogSigmoid activation function backward. /// /// # Arguments /// /// * `x` - The input tensor. /// * `grad` - The gradient. /// /// # Returns /// /// The output gradient. fn log_sigmoid_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor { // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) - exp(-x - max(-x, 0)))) is // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)) // // This simplifies to: // -max_derive - (z-1)/z if x is >= 0 // -max_derive + (z-1)/z if x is < 0 let shape = x.shape(); let dtype = x.dtype(); let device = B::float_device(&x); // max(-x, 0) let x_neg = B::float_neg(x); let mask = B::float_lower_elem(x_neg.clone(), 0f32.into()); // -x < 0 or x >= 0 let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into()); // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)) let z = B::float_add( B::float_exp(B::float_neg(max_elem.clone())), B::float_exp(B::float_sub(x_neg, max_elem)), ); // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0 let ones = B::float_ones(shape, &device, dtype.into()); let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into()); let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into()); // grad * (max_derive - sign * (1 - (1 / z))) B::float_mul( grad, B::float_sub( max_derive, B::float_mul(sign, B::float_sub(ones, B::float_recip(z))), ), ) } } ================================================ FILE: crates/burn-backend/src/backend/ops/argwhere.rs ================================================ use crate::tensor::{Device, IntTensor}; use crate::{Backend, TensorData, element::ElementConversion}; use alloc::vec::Vec; use burn_std::Shape; /// Compute the indices of the elements that are non-zero, grouped by element. /// /// # Arguments /// /// * `data` - The input tensor data. /// /// # Returns /// /// A 2D tensor containing the indices of all non-zero elements of the given tensor. /// Each row contains the indices of a non-zero element. /// /// # Remarks /// /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. pub fn argwhere_data(data: TensorData, device: &Device) -> IntTensor { let dims = &data.shape; let ndims = dims.len(); let count_nonzero = data.iter::().filter(|&v| v).count(); /// Converts a flat index into a vector of indices for the specified tensor shape fn unravel_index(index: usize, shape: &[usize]) -> Vec { shape .iter() .rev() .scan(index, |i, size| { let dim_idx = *i % size; *i /= size; Some((dim_idx as i64).elem()) }) .collect::>() .into_iter() .rev() .collect() } let indices = data .iter::() .enumerate() .filter_map(|(index, v)| if v { Some(index) } else { None }) .map(|index| unravel_index::(index, dims)) .collect::>() .concat(); B::int_from_data( TensorData::new(indices, Shape::new([count_nonzero, ndims])), device, ) } ================================================ FILE: crates/burn-backend/src/backend/ops/bool_tensor.rs ================================================ use super::{ argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, }; use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor}; use crate::{Backend, TensorData, TensorMetadata}; use crate::{ExecutionError, Scalar}; use alloc::vec::Vec; use burn_std::{Shape, Slice}; use core::future::Future; /// Bool Tensor API for basic operations, see #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// for documentation on each function. pub trait BoolTensorOps { /// Creates a new bool tensor. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The boolean tensor with the given shape. fn bool_empty(shape: Shape, device: &Device) -> BoolTensor; /// Creates a new bool tensor filled false. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The boolean tensor filled with false. fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor; /// Creates a new bool tensor filled true. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The boolean tensor filled with true. fn bool_ones(shape: Shape, device: &Device) -> BoolTensor; /// Converts the tensor to a data structure. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The data structure with the tensor's data. fn bool_into_data( tensor: BoolTensor, ) -> impl Future> + Send; /// Creates a tensor from the data structure. /// /// # Arguments /// /// * `data` - The data structure. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the data. fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor; /// Converts bool tensor to int tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The int tensor with the same data as the bool tensor. fn bool_into_int(tensor: BoolTensor) -> IntTensor; /// Converts bool tensor to float tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The float tensor with the same data as the bool tensor. fn bool_into_float(tensor: BoolTensor) -> FloatTensor; /// Gets the device of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The device of the tensor. fn bool_device(tensor: &BoolTensor) -> Device; /// Moves the tensor to the device. fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor; /// Reshapes the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `shape` - The new shape. /// /// # Returns /// /// The tensor with the new shape. fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor; /// Gets the values from the tensor for the given ranges. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `slices` - The slices specifying ranges and steps for each dimension. /// /// # Returns /// /// The tensor with the values for the given slices. /// /// # Note /// /// Empty slices (where start >= end) are handled at the high-level tensor API and will not /// be passed to this method. Backend implementations do not need to handle empty slices. fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor; /// Sets the values in the tensor for the given ranges. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `ranges` - The ranges to set the values for. /// * `value` - The values to set. /// /// # Returns /// /// The tensor with the values set for the given ranges. /// /// # Note /// /// Empty slice assignments (where any slice range produces 0 elements) are handled at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty slice assignments. fn bool_slice_assign( tensor: BoolTensor, slices: &[Slice], value: BoolTensor, ) -> BoolTensor; /// Fills the tensor with values from the value tensor if the mask is true at the given /// indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `mask` - The mask. /// * `value` - The value tensor. /// /// # Returns /// /// The tensor with the values filled. fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor; /// Fills the tensor with the given value if the mask is true at the given indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `mask` - The mask. /// * `value` - The value. /// /// # Returns /// /// The tensor with the values filled. fn bool_mask_fill(tensor: BoolTensor, mask: BoolTensor, value: Scalar) -> BoolTensor; /// Gather elements from the tensor at the given indices. /// /// # Arguments /// /// * `dim` - The dimension to gather from. /// * `tensor` - The tensor. /// * `indices` - The indices. fn bool_gather(dim: usize, tensor: BoolTensor, indices: IntTensor) -> BoolTensor; /// Scatter a given value to the tensor at the given indices using boolean or reduction. /// /// # Arguments /// /// * `dim` - The dimension to scatter to. /// * `tensor` - The tensor. /// * `indices` - The indices. /// * `value` - The value. /// /// # Returns /// /// The tensor with the values scattered. fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor; /// Select tensor elements along the given dimension corresponding to the given indices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `dim` - The dimension to select from. /// * `indices` - The indices of the elements to select. /// /// # Returns /// /// The tensor with the selected elements. fn bool_select(tensor: BoolTensor, dim: usize, indices: IntTensor) -> BoolTensor { // Default implementation: convert to int, select, then convert back to bool let int_tensor = B::bool_into_int(tensor); let selected = B::int_select(int_tensor, dim, indices); B::int_equal_elem(selected, 1.into()) } /// Assign the selected elements along the given dimension corresponding to the given indices /// to the given value using sum reduction. /// /// # Arguments /// /// * `tensor` - The tensor to assign the values to. /// * `dim` - The dimension to select from. /// * `indices` - The indices of the elements to assign. /// * `value` - The values to assign. /// /// # Returns /// /// The tensor with the assigned values. fn bool_select_or( tensor: BoolTensor, dim: usize, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { // Default implementation: convert to int, select_assign, then convert back to bool let int_tensor = B::bool_into_int(tensor); let int_values = B::bool_into_int(value); let assigned = B::int_select_add(int_tensor, dim, indices, int_values); // After select_assign with sum reduction, any non-zero value should be true B::int_greater_elem(assigned, 0.into()) } /// Repeats one dimension of the tensor a given number of times along that dimension. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to repeat. /// * `times` - The number of times to repeat the dimension. /// /// # Returns /// /// The tensor with the dimension repeated. fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { repeat_with_slice_assign::(tensor, dim, times) } /// Concatenates the tensors along the given dimension. /// /// # Arguments /// /// * `tensors` - The tensors to concatenate. /// * `dim` - The dimension to concatenate along. /// /// # Returns /// /// The tensor with the tensors concatenated along the given dimension. /// /// # Note /// /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty tensors. fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { cat_with_slice_assign::(tensors, dim) } /// Equates the two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor with the result of the equate. fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; /// Element-wise non-equality comparison. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor with the result of the comparison. fn bool_not_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let equal_tensor = B::bool_equal(lhs, rhs); B::bool_not(equal_tensor) } /// Element-wise equality comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor; /// Element-wise non-equality comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn bool_not_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { let equal_tensor = B::bool_equal_elem(lhs, rhs); B::bool_not(equal_tensor) } /// Inverses boolean values. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The tensor with the result of the negation. fn bool_not(tensor: BoolTensor) -> BoolTensor; /// Executes the logical and (`&&`) operation on two boolean tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor with the result of the logical and. fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; /// Executes the logical or (`||`) operation on two boolean tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor with the result of the logical or. fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor; /// Element-wise exclusive or. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor with the result of the comparison. fn bool_xor(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { Self::bool_not_equal(lhs, rhs) } /// Transposes a bool tensor. /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. fn bool_transpose(tensor: BoolTensor) -> BoolTensor { let ndims = tensor.shape().num_dims(); Self::bool_swap_dims(tensor, ndims - 2, ndims - 1) } /// Swaps two dimensions of a bool tensor. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap. /// * `dim2` - The second dimension to swap. /// /// # Returns /// /// The tensor with the dimensions swapped. fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor; /// Permutes the dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to permute the dimensions of. /// * `axes` - The new order of the dimensions. /// # Returns /// /// The tensor with the dimensions permuted. fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor; /// Reverse the order of elements in a tensor along the given axes. /// /// # Arguments /// /// * `tensor` - The tensor to reverse. /// * `axes` - The axes to reverse. /// /// The tensor with the elements reversed. fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor; /// Tests if any element in the boolean `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. fn bool_any(tensor: BoolTensor) -> BoolTensor { let sum = B::int_sum(B::bool_into_int(tensor)); B::int_greater_elem(sum, 0.into()) } /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. fn bool_any_dim(tensor: BoolTensor, dim: usize) -> BoolTensor { let sum = B::int_sum_dim(B::bool_into_int(tensor), dim); B::int_greater_elem(sum, 0.into()) } /// Tests if all elements in the boolean `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn bool_all(tensor: BoolTensor) -> BoolTensor { let num_elems = tensor.shape().num_elements() as i64; let sum = B::int_sum(B::bool_into_int(tensor)); B::int_equal_elem(sum, num_elems.into()) } /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. fn bool_all_dim(tensor: BoolTensor, dim: usize) -> BoolTensor { let num_elems = tensor.shape()[dim] as i64; let sum = B::int_sum_dim(B::bool_into_int(tensor), dim); B::int_equal_elem(sum, num_elems.into()) } /// Compute the indices of the elements that are non-zero, grouped by element. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A 2D tensor containing the indices of all non-zero elements of the given tensor. /// Each row contains the indices of a non-zero element. fn bool_argwhere(tensor: BoolTensor) -> impl Future> + 'static + Send { async { // Size of each output tensor is variable (= number of nonzero elements in the tensor). // Reading the data to count the number of truth values might cause sync but is required. let device = B::bool_device(&tensor); let data = B::bool_into_data(tensor) .await .expect("Can read the data without error"); argwhere_data::(data, &device) } } /// Broadcasts the bool `tensor` to the given `shape`. fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor; /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the selected dim. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with shape ``[pre=..., windows, size, post=...]``. fn bool_unfold(tensor: BoolTensor, dim: usize, size: usize, step: usize) -> BoolTensor; } ================================================ FILE: crates/burn-backend/src/backend/ops/cat.rs ================================================ use crate::{ Backend, TensorMetadata, tensor::{BasicOps, TensorKind}, }; use alloc::vec::Vec; use burn_std::Slice; pub(crate) fn cat_with_slice_assign + BasicOps>( tensors: Vec, dim: usize, ) -> K::Primitive { let first_tensor = tensors.first().expect("Tensors should not be empty"); let mut shape = first_tensor.shape(); let device = K::device(first_tensor); let dtype = first_tensor.dtype(); let output_dim_length: usize = tensors.iter().map(|tensor| tensor.shape()[dim]).sum(); shape[dim] = output_dim_length; let mut tensor_output = K::empty(shape.clone(), &device, dtype); let indices_select_all = shape.iter().map(|d| 0..*d).collect::>(); let mut output_index = 0; for tensor in tensors { let mut indices = indices_select_all.clone(); let tensor_dim_length = tensor.shape()[dim]; indices[dim] = output_index..output_index + tensor_dim_length; output_index += tensor_dim_length; // Convert ranges to Slice let slices: Vec = indices .iter() .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1)) .collect(); tensor_output = K::slice_assign(tensor_output, &slices, tensor); } tensor_output } ================================================ FILE: crates/burn-backend/src/backend/ops/int_tensor.rs ================================================ use super::cat::cat_with_slice_assign; use super::repeat_dim::repeat_with_slice_assign; use super::sort::{argsort, sort, sort_with_indices}; use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor}; use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion}; use crate::{ExecutionError, Scalar}; use alloc::vec::Vec; use burn_std::{IntDType, Shape, Slice}; use core::ops::Range; /// Int Tensor API for basic and numeric operations, see #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// for documentation on each function. pub trait IntTensorOps { /// Creates a new int tensor. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The integer tensor with the given shape. fn int_empty(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor; /// Converts the tensor to a data structure. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The data structure with the tensor's data. fn int_into_data( tensor: IntTensor, ) -> impl Future> + Send; /// Creates a tensor from the data structure. /// /// # Arguments /// /// * `data` - The data structure. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the data. fn int_from_data(data: TensorData, device: &Device) -> IntTensor; /// Gets the device of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The device of the tensor. fn int_device(tensor: &IntTensor) -> Device; /// Moves the tensor to the given device. fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor; /// Reshapes the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `shape` - The new shape. /// /// # Returns /// /// The tensor with the new shape. fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor; /// Gets the element at the given indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `slices` - The slices specifying ranges and steps for each dimension. /// /// # Returns /// /// The elements at the given indices. /// /// # Note /// /// Empty slices (where start >= end) are handled at the high-level tensor API and will not /// be passed to this method. Backend implementations do not need to handle empty slices. fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor; /// Sets the values in the tensor for the given ranges. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `ranges` - The ranges to set the values for. /// /// # Returns /// /// The tensor with the values set for the given ranges. /// /// # Note /// /// Empty slice assignments (where any slice range produces 0 elements) are handled at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty slice assignments. fn int_slice_assign( tensor: IntTensor, slices: &[Slice], value: IntTensor, ) -> IntTensor; /// Converts int tensor to float tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The int tensor with the same data as the float tensor. fn int_into_float(tensor: IntTensor) -> FloatTensor; /// Fills the tensor with values from the value tensor if the mask is true at the given /// indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `mask` - The mask. /// * `value` - The value tensor. /// /// # Returns /// /// The tensor with the values filled. fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> IntTensor; /// Fills the tensor with the given value if the mask is true at the given indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `mask` - The mask. /// * `value` - The value. /// /// # Returns /// /// The tensor with the values filled. fn int_mask_fill(tensor: IntTensor, mask: BoolTensor, value: Scalar) -> IntTensor; /// Gather elements from the tensor at the given indices. /// /// # Arguments /// /// * `dim` - The dimension to gather from. /// * `tensor` - The tensor. /// * `indices` - The indices. fn int_gather(dim: usize, tensor: IntTensor, indices: IntTensor) -> IntTensor; /// Scatter a given value to the tensor at the given indices using sum reduction. /// /// # Arguments /// /// * `dim` - The dimension to scatter to. /// * `tensor` - The tensor. /// * `indices` - The indices. /// * `value` - The value. /// /// # Returns /// /// The tensor with the values scattered. fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor; /// Select tensor elements along the given dimension corresponding to the given indices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to select from. /// * `indices` - The indices. /// /// # Returns /// /// The tensor with the selected elements. fn int_select(tensor: IntTensor, dim: usize, indices: IntTensor) -> IntTensor; /// Assign the selected elements along the given dimension corresponding to the given indices /// to the given value using sum reduction. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to select from. /// * `indices` - The indices. /// * `value` - The value. /// /// # Returns /// /// The tensor with the selected elements assigned to the given value. fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor; /// Repeats the tensor along the given dimension the given number of times. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to repeat. /// * `times` - The number of times to repeat. /// /// # Returns /// /// The tensor with the given dimension repeated the given number of times. fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { repeat_with_slice_assign::(tensor, dim, times) } /// Concatenates the given tensors along the given dimension. /// /// # Arguments /// /// * `tensors` - The tensors. /// * `dim` - The dimension to concatenate along. /// /// # Returns /// /// The concatenated tensor. /// /// # Note /// /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty tensors. fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { cat_with_slice_assign::(tensors, dim) } /// Element-wise equality comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; /// Element-wise non-equality comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_not_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let equal_tensor = B::int_equal(lhs, rhs); B::bool_not(equal_tensor) } /// Element-wise equality comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor; /// Element-wise non-equality comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_not_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let equal_tensor = B::int_equal_elem(lhs, rhs); B::bool_not(equal_tensor) } /// Element-wise greater than comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; /// Element-wise greater than comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor; /// Element-wise greater than or equal comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; /// Element-wise greater than or equal comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor; /// Element-wise less than comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; /// Element-wise less than comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor; /// Element-wise less than or equal comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor; /// Element-wise less than or equal comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The boolean tensor with the result of the comparison. fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor; // ==== NUMERIC ==== // /// Element-wise addition. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of the addition. fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise addition with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of the addition. fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Element-wise power with a IntTensor. /// /// # Arguments /// /// * `lhs` - The left-hand side IntTensor. /// * `rhs` - The right-hand side IntTensor. /// /// # Returns /// /// The elements of `lhs` raised to the power of the elements of `rhs`. fn int_powi(lhs: IntTensor, rhs: IntTensor) -> IntTensor { B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs)) } /// Element-wise power with a scalar. /// /// # Backend Implementors Note /// /// A number of common exponent cases can be implemented with operations /// which are much cheaper than generic exponentiation. /// /// This (`Backend` impl overridable) operation handles generic optimizations /// for several common integer exponent cases; and then dispatches to /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`] /// operation to handle the generic case. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. fn int_powi_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let exp = rhs.elem::(); match exp { 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()), 1 => lhs, 2 => Self::int_mul(lhs.clone(), lhs), _ => Self::int_powi_scalar_impl(lhs, rhs), } } /// Element-wise power with a scalar. /// /// # Backend Implementors Note /// /// This is the generic implementation of integer exponentiation /// called by [`Self::int_powi_scalar`] in the fallback case. /// /// By default, this performs a relatively expensive conversion to float, /// exponentiation in float, and conversion back to int. /// This reduces the minimal operation set for `Backend`s, /// at the cost of performance. /// /// This is a good target for specialized optimizations in `Backend` implementations. /// /// As a general rule, this should not be called directly. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. fn int_powi_scalar_impl(lhs: IntTensor, rhs: Scalar) -> IntTensor { B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs)) } /// Clamps a tensor under a minimum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// /// # Returns /// /// The clamped tensor. fn int_clamp_min(tensor: IntTensor, min: Scalar) -> IntTensor { let mask = Self::int_lower_elem(tensor.clone(), min); Self::int_mask_fill(tensor, mask, min) } /// Clamps a tensor over a maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn int_clamp_max(tensor: IntTensor, max: Scalar) -> IntTensor { let mask = Self::int_greater_elem(tensor.clone(), max); Self::int_mask_fill(tensor, mask, max) } /// Clamps a tensor between a minimum and maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) } /// Element-wise subtraction. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of the subtraction. fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise subtraction with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of the subtraction. fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Element-wise multiplication. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of the multiplication. fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise multiplication with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of the multiplication. fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Element-wise division. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of the division. fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise division with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of the division. fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Element-wise modulus. /// /// # Arguments /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of applying the modulus of the scalar to the tensor. fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise modulus with a scalar. /// /// # Arguments /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of applying the modulus of the scalar to the tensor. fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Multiplies two tensors together using matrix multiplication. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of multiplying the two tensors together using matrix multiplication. fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Element-wise negation. /// /// # Arguments /// /// * `tensor` - The tensor to negate. /// /// # Returns /// /// The negated tensor. fn int_neg(tensor: IntTensor) -> IntTensor { Self::int_mul_scalar(tensor, (-1).into()) } /// Creates a tensor of zeros. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor of zeros. fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device) } /// Creates a tensor of ones. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor of ones. fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device) } /// Creates a tensor filled with given value. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `fill_value` - The value with which to fill the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor filled with given value fn int_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: IntDType, ) -> IntTensor { Self::int_from_data( TensorData::full_dtype(shape, fill_value, dtype.into()), device, ) } /// Sums all elements in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// /// # Returns /// /// The sum of all elements in the tensor. fn int_sum(tensor: IntTensor) -> IntTensor; /// Sums all elements in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// * `dim` - The dimension to sum along. /// /// # Returns /// /// The sum of all elements in the tensor along the dimension. fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the product of all elements in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to compute the product of. /// /// # Returns /// /// The product of all elements in the tensor. fn int_prod(tensor: IntTensor) -> IntTensor; /// Computes the product of all elements in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the product of. /// * `dim` - The dimension to compute the product along. /// /// # Returns /// /// The product of all elements in the tensor along the dimension. fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the mean of all elements in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to compute the mean of. /// /// # Returns /// /// The mean of all elements in the tensor. fn int_mean(tensor: IntTensor) -> IntTensor { let num_elems = tensor.shape().num_elements() as i64; B::int_div_scalar(B::int_sum(tensor), num_elems.into()) } /// Computes the mean of all elements in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the mean of. /// /// # Returns /// /// The mean of all elements in the tensor along the dimension. fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the cumulative sum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative sum of. /// * `dim` - The dimension along which to compute the cumulative sum. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative sum /// of all elements up to and including that position along the dimension. fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the cumulative product of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative product of. /// * `dim` - The dimension along which to compute the cumulative product. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative product /// of all elements up to and including that position along the dimension. fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the cumulative minimum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative minimum of. /// * `dim` - The dimension along which to compute the cumulative minimum. /// /// # Returns /// /// A tensor with the same shape where each element is the minimum /// of all elements up to and including that position along the dimension. fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor; /// Computes the cumulative maximum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative maximum of. /// * `dim` - The dimension along which to compute the cumulative maximum. /// /// # Returns /// /// A tensor with the same shape where each element is the maximum /// of all elements up to and including that position along the dimension. fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor; /// Gets the indices of the maximum elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum indices of. /// * `dim` - The dimension to get the maximum indices along. /// /// # Returns /// /// The indices of the maximum elements along the dimension. fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor; /// Gets the indices of the minimum elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum indices of. /// * `dim` - The dimension to get the minimum indices along. /// /// # Returns /// /// The indices of the minimum elements along the dimension. fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor; /// Gets the maximum element in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum element of. /// /// # Returns /// /// The maximum element in the tensor. fn int_max(tensor: IntTensor) -> IntTensor { let shape = tensor.shape(); let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); B::int_max_dim(tensor, 0) } /// Gets the maximum element in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum element of. /// * `dim` - The dimension to get the maximum element along. /// /// # Returns /// /// The maximum element in the tensor along the dimension. fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { let index = B::int_argmax(tensor.clone(), dim); B::int_gather(dim, tensor, index) } /// Gets the maximum elements and corresponding indices along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements and indices of. /// * `dim` - The dimension to get the maximum elements and indices along. /// /// # Returns /// /// The maximum elements and corresponding indices along the dimension. fn int_max_dim_with_indices(tensor: IntTensor, dim: usize) -> (IntTensor, IntTensor) { let index = B::int_argmax(tensor.clone(), dim); let values = B::int_gather(dim, tensor, index.clone()); (values, index) } /// Gets the maximum absolute element in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum element of. /// /// # Returns /// /// The maximum element in the tensor. fn int_max_abs(tensor: IntTensor) -> IntTensor { let shape = tensor.shape(); let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); B::int_max_abs_dim(tensor, 0) } /// Gets the maximum absolute element in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum element of. /// * `dim` - The dimension to get the maximum element along. /// /// # Returns /// /// The maximum element in the tensor along the dimension. fn int_max_abs_dim(tensor: IntTensor, dim: usize) -> IntTensor { B::int_max_dim(B::int_abs(tensor), dim) } /// Gets the minimum element in the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum element of. /// /// # Returns /// /// The minimum element in the tensor. fn int_min(tensor: IntTensor) -> IntTensor { let shape = tensor.shape(); let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); B::int_min_dim(tensor, 0) } /// Gets the minimum elements in the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum element of. /// * `dim` - The dimension to get the minimum element along. /// /// # Returns /// /// The minimum element in the tensor along the dimension. fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { let index = B::int_argmin(tensor.clone(), dim); B::int_gather(dim, tensor, index) } /// Gets the minimum elements and corresponding indices along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements and indices of. /// * `dim` - The dimension to get the minimum elements and indices along. /// /// # Returns /// /// The minimum elements and corresponding indices along the dimension. fn int_min_dim_with_indices(tensor: IntTensor, dim: usize) -> (IntTensor, IntTensor) { let indices = B::int_argmin(tensor.clone(), dim); let values = B::int_gather(dim, tensor, indices.clone()); (values, indices) } /// Returns a new tensor with absolute values. /// /// # Arguments /// /// * `tensor` - The tensor to take absolute value of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with absolute values. fn int_abs(tensor: IntTensor) -> IntTensor; /// Transposes an int tensor. /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. fn int_transpose(tensor: IntTensor) -> IntTensor { let ndims = tensor.shape().num_dims(); Self::int_swap_dims(tensor, ndims - 2, ndims - 1) } /// Swaps two dimensions of an int tensor. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap. /// * `dim2` - The second dimension to swap. /// /// # Returns /// /// The tensor with the dimensions swapped. fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor; /// Permutes the dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to permute the dimensions of. /// * `axes` - The new order of the dimensions. /// # Returns /// /// The tensor with the dimensions permuted. fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor; /// Reverse the order of elements in a tensor along the given axes. /// /// # Arguments /// /// * `tensor` - The tensor to reverse. /// * `axes` - The axes to reverse. /// /// The tensor with the elements reversed. fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor; /// Creates a new int tensor with random values. /// /// # Arguments /// * `shape` - The shape of the tensor. /// * `distribution` - The distribution to sample from. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given shape and random values. fn int_random(shape: Shape, distribution: Distribution, device: &Device) -> IntTensor; /// Creates a new tensor with values from the given range with the given step size. /// /// # Arguments /// /// * `range` - The range of values. /// * `step` - The step size. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given values. fn int_arange_step(range: Range, step: usize, device: &Device) -> IntTensor { let value = range .step_by(step) .map(|i| i.elem()) .collect::>>(); let shape = Shape::new([value.len()]); let data = TensorData::new(value, shape); B::int_from_data(data, device) } /// Creates a new tensor with values from the given range. /// /// # Arguments /// /// * `range` - The range of values. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given values. /// /// # Remarks /// /// Uses `arange_step` with a step size of 1 under the hood. fn int_arange(range: Range, device: &Device) -> IntTensor { Self::int_arange_step(range, 1, device) } /// Tests if any element in the int `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. fn int_any(tensor: IntTensor) -> BoolTensor { let bool_tensor = B::int_equal_elem(tensor, 0.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum(B::bool_into_int(bool_tensor)); B::int_greater_elem(sum, 0.into()) } /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. fn int_any_dim(tensor: IntTensor, dim: usize) -> BoolTensor { let bool_tensor = B::int_equal_elem(tensor, 0.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim); B::int_greater_elem(sum, 0.into()) } /// Tests if all elements in the int `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn int_all(tensor: IntTensor) -> BoolTensor { let num_elems = tensor.shape().num_elements() as i64; let bool_tensor = B::int_equal_elem(tensor, 0.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum(B::bool_into_int(bool_tensor)); B::int_equal_elem(sum, num_elems.into()) } /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. fn int_all_dim(tensor: IntTensor, dim: usize) -> BoolTensor { let num_elems = tensor.shape()[dim] as i64; let bool_tensor = B::int_equal_elem(tensor, 0.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim); B::int_equal_elem(sum, num_elems.into()) } /// Returns the signs of the int `tensor`. /// /// # Arguments /// /// * `tensor` - The tensor to extract the signs from. /// /// # Returns /// /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`. fn int_sign(tensor: IntTensor) -> IntTensor { let dtype = tensor.dtype(); let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into()); let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into()); let greater_than_zero = B::int_greater_elem(tensor, 0.into()); let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into()); result = B::int_mask_fill(result, greater_than_zero, 1.into()); result } /// Broadcasts the int `tensor` to the given `shape`. fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor; /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. fn int_sort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { sort::(tensor, dim, descending) } /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. fn int_sort_with_indices( tensor: IntTensor, dim: usize, descending: bool, ) -> (IntTensor, IntTensor) { sort_with_indices::(tensor, dim, descending) } /// Returns the indices that sort the elements of the input `tensor` by value /// along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { argsort::(tensor, dim, descending) } /// Bitwise AND operation for Int Tensors fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Bitwise AND operation for Int Tensors with a scalar fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Bitwise OR operation for Int Tensors fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Bitwise OR operation for Int Tensors with a scalar fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Bitwise XOR operation for Int Tensors fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Bitwise XOR operation for Int Tensors with a scalar fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Bitwise NOT operation for Int Tensors fn bitwise_not(tensor: IntTensor) -> IntTensor; /// Bitwise left shift operation for Int Tensors fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Bitwise left shift operation for Int Tensors with a scalar fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Bitwise right shift operation for Int Tensors fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; /// Bitwise right shift operation for Int Tensors with a scalar fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor; /// Converts a tensor to another integer data type. /// /// # Arguments /// /// * `tensor` - The tensor to convert. /// * `dtype` - The target data type. /// /// # Returns /// /// A tensor with the same values as `tensor` but in the target integer data type. fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor; /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the selected dim. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with shape ``[pre=..., windows, size, post=...]``. fn int_unfold(tensor: IntTensor, dim: usize, size: usize, step: usize) -> IntTensor; } ================================================ FILE: crates/burn-backend/src/backend/ops/mod.rs ================================================ mod activation; mod bool_tensor; mod int_tensor; mod modules; mod qtensor; mod tensor; mod transaction; pub(crate) mod argwhere; pub(crate) mod cat; pub(crate) mod repeat_dim; pub(crate) mod sort; pub use activation::*; pub use bool_tensor::*; pub use int_tensor::*; pub use modules::*; pub use qtensor::*; pub use tensor::*; pub use transaction::*; ================================================ FILE: crates/burn-backend/src/backend/ops/modules/attention.rs ================================================ use core::f32; #[allow(unused_imports)] use num_traits::Float as _; use burn_std::Shape; use crate::{ Backend, TensorMetadata, ops::AttentionModuleOptions, tensor::{BoolTensor, FloatTensor}, }; /// Computes softmax(QKᵗ * scale) · V using separate kernels. /// Serves as a fallback when FlashAttention is not used. pub fn attention_fallback( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> FloatTensor { if let Some(softcap) = options.softcap { assert!(softcap > 0.0, "softcap must be positive, got {softcap}"); } // Attention scores: A = QKᵗ * scale let query_shape = query.shape().dims::<4>(); let scale = options .scale .unwrap_or_else(|| 1.0 / (*query_shape.last().unwrap() as f64).sqrt()); let transposed_key = B::float_transpose(key); let qk = B::float_matmul(query, transposed_key); let attention_scores = B::float_mul_scalar(qk, scale.into()); // Softcap: softcap * tanh(scores / softcap) // Applied to raw logits before any -inf masking, so that tanh does not // map -inf to a finite value (which would break masking semantics). let attention_scores = if let Some(softcap) = options.softcap { let scaled = B::float_div_scalar(attention_scores, softcap.into()); let tanh = B::float_tanh(scaled); B::float_mul_scalar(tanh, softcap.into()) } else { attention_scores }; // Bool masking let attention_scores = if let Some(mask) = mask { B::float_mask_fill(attention_scores, mask, f32::NEG_INFINITY.into()) } else { attention_scores }; // Causal masking: mask positions where col > row (future positions) let attention_scores = if options.is_causal { let causal_mask = build_causal_mask::(&attention_scores); B::float_mask_fill(attention_scores, causal_mask, f32::NEG_INFINITY.into()) } else { attention_scores }; // Additive bias (ALiBi, relative position biases, etc.) let attention_scores = if let Some(bias) = attn_bias { B::float_add(attention_scores, bias) } else { attention_scores }; // Softmax: S = softmax(A) let max_per_dim = B::float_max_dim(attention_scores.clone(), 3); let minus_max = B::float_sub(attention_scores, max_per_dim); let numerator = B::float_exp(minus_max); let sum_exp = B::float_sum_dim(numerator.clone(), 3); let softmax = B::float_div(numerator, sum_exp); // Context: S · V B::float_matmul(softmax, value) } /// Builds a causal (upper-triangular) bool mask where `true` means "mask this position". /// Shape: [batch_size, num_heads, seq_q, seq_k], masking positions where col > row. fn build_causal_mask(attention_scores: &FloatTensor) -> BoolTensor { let device = B::float_device(attention_scores); let scores_shape = attention_scores.shape().dims::<4>(); let [batch_size, num_heads, seq_q, seq_k] = scores_shape; // row indices [seq_q, 1] and col indices [1, seq_k] // Offset col indices so that the causal boundary aligns at the bottom-right corner, // which handles cross-attention (seq_k > seq_q) correctly. let offset = seq_k as i64 - seq_q as i64; let rows = B::int_reshape( B::int_arange(0..seq_q as i64, &device), Shape::new([seq_q, 1]), ); let cols = B::int_reshape( B::int_arange(0..seq_k as i64, &device), Shape::new([1, seq_k]), ); // mask where col > row + offset (upper triangle) let rows_shifted = B::int_add_scalar(rows, offset.into()); let mask_2d = B::int_lower(rows_shifted, cols); // Reshape to [1, 1, seq_q, seq_k] then expand to [batch_size, num_heads, seq_q, seq_k] let mask_4d = B::bool_reshape(mask_2d, Shape::new([1, 1, seq_q, seq_k])); B::bool_expand(mask_4d, Shape::new([batch_size, num_heads, seq_q, seq_k])) } ================================================ FILE: crates/burn-backend/src/backend/ops/modules/base.rs ================================================ use super::{conv, pool}; use crate::ops::unfold::unfold4d_using_conv2d; use crate::tensor::{BoolTensor, FloatTensor, IntTensor}; use crate::{Backend, ElementConversion, TensorMetadata}; use burn_std::Shape; use core::num::NonZeroUsize; /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d). #[derive(new)] pub struct Conv2dBackward { /// Gradient. pub x_grad: FloatTensor, /// Weights gradient. pub weights_grad: FloatTensor, /// Bias gradient. pub bias_grad: Option>, } /// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d). #[derive(new)] pub struct DeformConv2dBackward { /// Gradient. pub x_grad: FloatTensor, /// Offset gradient. pub offset_grad: FloatTensor, /// Weights gradient. pub weight_grad: FloatTensor, /// Mask gradient. pub mask_grad: Option>, /// Bias gradient. pub bias_grad: Option>, } /// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d). #[derive(new)] pub struct Conv3dBackward { /// Gradient. pub x_grad: FloatTensor, /// Weights gradient. pub weights_grad: FloatTensor, /// Bias gradient. pub bias_grad: Option>, } /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d). #[derive(new)] pub struct MaxPool1dBackward { /// Gradient. pub x_grad: FloatTensor, } /// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices). #[derive(new)] pub struct MaxPool1dWithIndices { /// The output tensor. pub output: FloatTensor, /// The indices tensor. pub indices: IntTensor, } /// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d). #[derive(new)] pub struct MaxPool2dBackward { /// Gradient. pub x_grad: FloatTensor, } /// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices). #[derive(new)] pub struct MaxPool2dWithIndices { /// The output tensor. pub output: FloatTensor, /// The indices tensor. pub indices: IntTensor, } /// Check that the parameter value is non-zero. // NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`. pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize { NonZeroUsize::new(value).expect(msg); value } /// Convolution options. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct ConvOptions { /// Stride (non-zero). pub stride: [usize; N], /// Padding. pub padding: [usize; N], /// Dilation (non-zero). pub dilation: [usize; N], /// Groups (non-zero). pub groups: usize, } impl ConvOptions { /// Constructs a new `ConvOptions`. pub fn new( stride: [usize; N], padding: [usize; N], dilation: [usize; N], groups: usize, ) -> Self { Self { stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), padding, dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), groups: check_nonzero(groups, "groups must be non-zero"), } } } /// Convolution options with support for asymmetric padding. /// /// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op) /// and adds optional asymmetric padding. When asymmetric padding is specified, /// the functional convolution layer applies an explicit pad operation before /// dispatching to the backend. /// /// Implements `From>` for backward compatibility. #[derive(Debug, Clone)] pub struct PaddedConvOptions { /// The underlying convolution options for the backend. pub options: ConvOptions, /// Padding at the end of each dimension (e.g., bottom/right for 2D). /// If `None`, padding is symmetric (same as `options.padding`). /// If `Some`, specifies different end-padding per dimension. pub padding_end: Option<[usize; N]>, } impl PaddedConvOptions { /// Creates options with asymmetric padding. /// /// `padding_start` is stored in `ConvOptions::padding`. /// `padding_end` specifies the end padding per dimension. pub fn asymmetric( stride: [usize; N], padding_start: [usize; N], padding_end: [usize; N], dilation: [usize; N], groups: usize, ) -> Self { let options = ConvOptions::new(stride, padding_start, dilation, groups); if padding_start == padding_end { Self { options, padding_end: None, } } else { Self { options, padding_end: Some(padding_end), } } } /// Returns true if padding is asymmetric. pub fn is_asymmetric(&self) -> bool { self.padding_end.is_some() } } impl From> for PaddedConvOptions { fn from(options: ConvOptions) -> Self { Self { options, padding_end: None, } } } /// Convolution options. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct DeformConvOptions { /// Stride (non-zero). pub stride: [usize; N], /// Padding. pub padding: [usize; N], /// Dilation (non-zero). pub dilation: [usize; N], /// Weight Groups (non-zero). pub weight_groups: usize, /// Offset Groups (non-zero). pub offset_groups: usize, } impl DeformConvOptions { /// Constructs a new `DeformConvOptions`. pub fn new( stride: [usize; N], padding: [usize; N], dilation: [usize; N], weight_groups: usize, offset_groups: usize, ) -> Self { Self { stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), padding, dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"), offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"), } } } /// Transposed convolution options. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct ConvTransposeOptions { /// Stride (non-zero). pub stride: [usize; N], /// Padding. pub padding: [usize; N], /// Padding out. pub padding_out: [usize; N], /// Dilation (non-zero). pub dilation: [usize; N], /// Groups (non-zero). pub groups: usize, } impl ConvTransposeOptions { /// Constructs a new `ConvTransposeOptions`. pub fn new( stride: [usize; N], padding: [usize; N], padding_out: [usize; N], dilation: [usize; N], groups: usize, ) -> Self { Self { stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), padding, padding_out, dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), groups: check_nonzero(groups, "groups must be non-zero"), } } } /// Unfold operation options. #[derive(Debug, Clone)] pub struct UnfoldOptions { /// The number of positions to slide over the input tensor in each dimension. /// A stride of `[1, 1]` will slide the kernel one pixel at a time. pub stride: [usize; 2], /// The number of zero-padding pixels added to each side of the input tensor in each dimension. pub padding: [usize; 2], /// The spacing between the blocks (patches) in the original input tensor. pub dilation: [usize; 2], } impl UnfoldOptions { /// Constructs a new `UnfoldOptions`. pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self { Self { stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), padding, dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), } } } /// Algorithm used for upsampling. #[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)] pub enum InterpolateMode { /// Nearest-neighbor interpolation. /// Nearest, /// Bilinear interpolation. /// Bilinear, /// Bicubic interpolation. /// Bicubic, /// Lanczos3 interpolation (6-tap sinc-based filter). /// Lanczos3, } /// Interpolation options. #[derive(Debug, Clone)] pub struct InterpolateOptions { /// Algorithm used for upsampling. pub mode: InterpolateMode, /// If `true`, the input and output tensors are aligned by their corner pixels. /// If `false`, half-pixel coordinate mapping is used instead. pub align_corners: bool, } impl InterpolateOptions { /// Create new interpolate options with the given mode. /// Defaults to `align_corners = true`. pub fn new(mode: InterpolateMode) -> Self { Self { mode, align_corners: true, } } /// Set align_corners. pub fn with_align_corners(mut self, align_corners: bool) -> Self { self.align_corners = align_corners; self } } /// Padding mode for grid sampling when coordinates are out of bounds. /// /// Matches PyTorch's `padding_mode` parameter in `grid_sample`. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)] pub enum GridSamplePaddingMode { /// Fill with zeros for out-of-bounds coordinates. #[default] Zeros, /// Clamp coordinates to the border (use nearest edge value). Border, /// Reflect coordinates at the boundary. Reflection, } /// Options for grid sampling operations. #[derive(Debug, Clone)] pub struct GridSampleOptions { /// Interpolation mode (bilinear, nearest, or bicubic). pub mode: InterpolateMode, /// Padding mode for out-of-bounds coordinates. pub padding_mode: GridSamplePaddingMode, /// If `true`, grid values of -1 and 1 correspond to the corner pixels. /// If `false`, they correspond to the corner points of the corner pixels /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates). pub align_corners: bool, } impl Default for GridSampleOptions { fn default() -> Self { Self { mode: InterpolateMode::Bilinear, padding_mode: GridSamplePaddingMode::Zeros, align_corners: false, } } } impl From for GridSampleOptions { fn from(value: InterpolateMode) -> Self { GridSampleOptions::new(value) } } impl GridSampleOptions { /// Create new grid sample options with the given interpolation mode. /// /// Uses default values for padding_mode (Zeros) and align_corners (false). pub fn new(mode: InterpolateMode) -> Self { Self { mode, ..Default::default() } } /// Set the padding mode. pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self { self.padding_mode = padding_mode; self } /// Set align_corners. pub fn with_align_corners(mut self, align_corners: bool) -> Self { self.align_corners = align_corners; self } } /// Padding mode for tensor pad operations. /// /// Defines how values are filled when padding a tensor beyond its original boundaries. /// Padding can be applied to any dimension of a tensor. /// /// # Modes /// /// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0) /// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size) /// - [`Edge`](PadMode::Edge): Replicate boundary values #[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)] pub enum PadMode { /// Fill padded regions with a constant value. /// /// # Example /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0: /// Result: `[0, 0, 1, 2, 3]` Constant(f32), /// Reflect values at the boundary, excluding the edge value. /// /// Padding must be less than the dimension size (i.e., `padding < dim_size`). /// /// # Example /// For tensor `[1, 2, 3, 4]` with padding 2 on the left: /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0) Reflect, /// Replicate the edge values. /// /// # Example /// For tensor `[1, 2, 3, 4]` with padding 2 on the left: /// Result: `[1, 1, 1, 2, 3, 4]` Edge, } impl Default for PadMode { fn default() -> Self { PadMode::Constant(0.0) } } impl From for PadMode { fn from(value: E) -> Self { PadMode::Constant(value.elem()) } } /// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate). #[derive(new)] pub struct InterpolateBackward { /// Gradient. pub x_grad: FloatTensor, } /// Options for [attention](ModuleOps::attention). #[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)] pub struct AttentionModuleOptions { /// Custom scale factor applied to QK^T. When `None`, defaults to `1/sqrt(head_dim)`. pub scale: Option, /// Soft capping applied before softmax: `softcap * tanh(scores / softcap)`. /// Used by Gemma-2 and similar models. Must be positive when set. pub softcap: Option, /// When `true`, applies causal (autoregressive) masking so that each query position /// can only attend to key positions at or before it. This is more efficient than /// passing an explicit lower-triangular bool mask because backends can use optimized /// kernel paths (e.g. flash attention with causal mode). pub is_causal: bool, } /// Module operations trait. pub trait ModuleOps { /// Embedding operation. /// /// # Arguments /// /// * `weights` - The embedding weights. /// * `indices` - The indices tensor. /// /// # Returns /// /// The output tensor. fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { let [batch_size, seq_length] = indices.shape().dims(); let [_, d_model] = weights.shape().dims(); let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); let output = B::float_select(weights, 0, indices); B::float_reshape(output, Shape::new([batch_size, seq_length, d_model])) } /// Embedding backward operation. /// /// # Arguments /// /// * `weights` - The embedding weights. /// * `output_grad` - The output gradient. /// * `indices` - The indices tensor. /// /// # Returns /// /// The gradient. fn embedding_backward( weights: FloatTensor, output_grad: FloatTensor, indices: IntTensor, ) -> FloatTensor { let [batch_size, seq_length] = indices.shape().dims(); let [n_embeddings, d_model] = weights.shape().dims(); let device = B::float_device(&weights); let dtype = output_grad.dtype(); let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); let output_grad = B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into()); B::float_select_add(grad, 0, indices, output_grad) } /// One dimensional convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, length]`, /// weight: `[channels_out, channels_in, kernel_size]`, /// bias: `[channels_out]`, fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { conv::conv1d_from_conv2d::(x, weight, bias, options) } /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`. fn conv1d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { conv::conv1d_x_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`. fn conv1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { conv::conv1d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`. fn conv1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv1d_bias_backward::(x, bias, output_grad) } /// Two dimensional convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, height, width]`, /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, /// bias: `[channels_out]`, fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<2>, ) -> FloatTensor; /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`. fn conv2d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { conv::conv2d_x_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`. fn conv2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { conv::conv2d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`. fn conv2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv2d_bias_backward::(x, bias, output_grad) } /// Two dimensional deformable convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, height, width]`, /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`, /// bias: `[channels_out]`, fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor; /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation. fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward; /// Three dimensional convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, depth, height, width]`, /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`, /// bias: `[channels_out]`, fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor; /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`. fn conv3d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { conv::conv3d_x_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`. fn conv3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { conv::conv3d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`. fn conv3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv3d_bias_backward::(x, bias, output_grad) } /// One dimensional transposed convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, length]`, /// weight: `[channels_in, channels_out, length]`, /// bias: `[channels_out]`, fn conv_transpose1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { conv::conv_transpose1d_from_conv_transpose2d::(x, weight, bias, options) } /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`. fn conv_transpose1d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { conv::conv_transpose1d_x_backward::(weight, output_grad, options) } /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`. fn conv_transpose1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { conv::conv_transpose1d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`. fn conv_transpose1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv_transpose1d_bias_backward::(x, bias, output_grad) } /// Two dimensional transposed convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, height, width]`, /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`, /// bias: `[channels_out]`, fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor; /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`. fn conv_transpose2d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { conv::conv_transpose2d_x_backward::(weight, output_grad, options) } /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`. fn conv_transpose2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { conv::conv_transpose2d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`. fn conv_transpose2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv_transpose2d_bias_backward::(x, bias, output_grad) } /// Three dimensional transposed convolution. /// /// # Shapes /// /// x: `[batch_size, channels_in, height, width]`, /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`, /// bias: `[channels_out]`, fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor; /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`. fn conv_transpose3d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { conv::conv_transpose3d_x_backward::(weight, output_grad, options) } /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`. fn conv_transpose3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { conv::conv_transpose3d_weight_backward::(x, weight, output_grad, options) } /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`. fn conv_transpose3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { conv::conv_transpose3d_bias_backward::(x, bias, output_grad) } /// Four-dimensional unfolding. /// /// # Shapes /// /// * x: ``[batch_size, channels_in, height, width]``, /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``, fn unfold4d( x: FloatTensor, kernel_size: [usize; 2], options: UnfoldOptions, ) -> FloatTensor { if options.padding == [0, 0] && options.dilation == [1, 1] { let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]); let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]); // batch, channels, h_blocks, w_blocks, h_kern, w_kern let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]); let shape = blocks.shape(); // batch, channels, h_kern, w_kern, h_blocks, w_blocks B::float_reshape( blocks, [ shape[0], shape[1] * shape[2] * shape[3], shape[4] * shape[5], ] .into(), ) } else { unfold4d_using_conv2d::(x, kernel_size, options) } } /// One dimensional avg pooling. /// /// # Shapes /// /// x: [batch_size, channels, length], fn avg_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { pool::avg_pool1d_from_2d::( x, kernel_size, stride, padding, count_include_pad, ceil_mode, ) } /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation. fn avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { pool::avg_pool1d_backward_from_2d::( x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode, ) } /// Two dimensional avg pooling. /// /// # Shapes /// /// x: [batch_size, channels, height, width], fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor; /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation. fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor; /// Two dimensional adaptive avg pooling. /// /// # Shapes /// /// x: [batch_size, channels, height, width], fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor; /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation. fn adaptive_avg_pool2d_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor; /// One dimensional adaptive avg pooling. /// /// # Shapes /// /// x: [batch_size, channels, length], fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { pool::adaptive_avg_pool1d_from_2d::(x, output_size) } /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation. fn adaptive_avg_pool1d_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor { pool::adaptive_avg_pool1d_backward_from_2d::(x, grad) } /// One dimensional max pooling. /// /// # Shapes /// /// x: [batch_size, channels, length], fn max_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> FloatTensor { pool::max_pool1d_from_2d::(x, kernel_size, stride, padding, dilation, ceil_mode) } /// One dimensional max pooling with indices. /// /// # Shapes /// /// x: [batch_size, channels, height, width], fn max_pool1d_with_indices( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { pool::max_pool1d_with_indices_from_2d::( x, kernel_size, stride, padding, dilation, ceil_mode, ) } /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation. #[allow(clippy::too_many_arguments)] fn max_pool1d_with_indices_backward( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool1dBackward { pool::max_pool1d_with_indices_backward_from_2d::( x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices, ) } /// Two dimensional max pooling. /// /// # Shapes /// /// x: [batch_size, channels, height, width], fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor; /// Two dimensional max pooling with indices. /// /// # Shapes /// /// x: [batch_size, channels, height, width], fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices; /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation. #[allow(clippy::too_many_arguments)] fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward; /// Down/up samples the input. /// /// # Shapes /// /// x: `[batch_size, channels, height, width]`, fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor; /// Backward pass for the [interpolate](ModuleOps::interpolate) operation. fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor; /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V, /// where scale defaults to 1/sqrt(head_dim). Optionally applies masking, /// additive bias, causal masking, and softcap to the attention scores. /// /// # Arguments /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]` /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]` /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]` /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`, /// where `true` indicates positions to mask (i.e. set to -inf before softmax). /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]` /// added to the attention scores before softmax (e.g. ALiBi, relative position biases). /// - `options`: Additional attention options (custom scale, softcap, causal masking). /// /// # Returns /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]` /// representing the attended context per head. /// /// # Note /// This implementation does not support dropout and is intended for inference or /// use cases where dropout is not needed. fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> FloatTensor; } #[cfg(test)] mod tests { use super::*; #[test] #[should_panic = "stride must be non-zero"] fn conv_options_stride_zero() { let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1); } #[test] #[should_panic = "dilation must be non-zero"] fn conv_options_dilation_zero() { let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1); } #[test] #[should_panic = "groups must be non-zero"] fn conv_options_groups_zero() { let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0); } #[test] #[should_panic = "stride must be non-zero"] fn conv_transpose_options_stride_zero() { let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1); } #[test] #[should_panic = "dilation must be non-zero"] fn conv_transpose_options_dilation_zero() { let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1); } #[test] #[should_panic = "groups must be non-zero"] fn conv_transpose_options_groups_zero() { let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0); } #[test] #[should_panic = "stride must be non-zero"] fn deform_conv_options_stride_zero() { let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1); } #[test] #[should_panic = "dilation must be non-zero"] fn deform_conv_options_dilation_zero() { let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1); } #[test] #[should_panic = "weight groups must be non-zero"] fn deform_conv_options_weights_groups_zero() { let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1); } #[test] #[should_panic = "offset groups must be non-zero"] fn deform_conv_options_offset_groups_zero() { let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0); } #[test] #[should_panic = "stride must be non-zero"] fn unfold_options_stride_zero() { let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]); } #[test] #[should_panic = "dilation must be non-zero"] fn unfold_options_dilation_zero() { let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]); } } ================================================ FILE: crates/burn-backend/src/backend/ops/modules/conv.rs ================================================ #![allow(clippy::single_range_in_vec_init)] use super::{ConvOptions, ConvTransposeOptions}; use crate::{Backend, TensorMetadata, tensor::FloatTensor}; use burn_std::{MetadataError, Shape, Slice}; use alloc::{vec, vec::Vec}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a pooling operation. pub fn calculate_pool_output_shape( in_shape: &Shape, kernel_size: &[usize; N], stride: &[usize; N], padding: &[usize; N], dilation: &[usize; N], ceil_mode: bool, ) -> Result { if in_shape.rank() != N + 2 { return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); } let mut out_shape = in_shape.clone(); // Spatial dims for (i, size_i) in out_shape[2..].iter_mut().enumerate() { *size_i = calculate_pool_output_size( kernel_size[i], stride[i], padding[i], dilation[i], *size_i, ceil_mode, ); } Ok(out_shape) } /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution. pub fn calculate_conv_output_shape( in_shape: &Shape, weight_shape: &Shape, stride: &[usize; N], padding: &[usize; N], dilation: &[usize; N], ) -> Result { if weight_shape.rank() != N + 2 { return Err(MetadataError::RankMismatch { left: weight_shape.rank(), right: N + 2, }); } if in_shape.rank() != N + 2 { return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); } let kernel_size = &weight_shape[2..]; let mut out_shape = in_shape.clone(); // Spatial dims for (i, size_i) in out_shape[2..].iter_mut().enumerate() { *size_i = calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i); } // Output channels out_shape[1] = weight_shape[0]; Ok(out_shape) } /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a transposed convolution. pub fn calculate_conv_transpose_output_shape( in_shape: &Shape, weight_shape: &Shape, stride: &[usize; N], padding: &[usize; N], padding_out: &[usize; N], dilation: &[usize; N], groups: usize, ) -> Result { if weight_shape.rank() != N + 2 { return Err(MetadataError::RankMismatch { left: weight_shape.rank(), right: N + 2, }); } if in_shape.rank() != N + 2 { return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); } let kernel_size = &weight_shape[2..]; let mut out_shape = in_shape.clone(); // Spatial dims for (i, size_i) in out_shape[2..].iter_mut().enumerate() { *size_i = calculate_conv_transpose_output_size( kernel_size[i], stride[i], padding[i], padding_out[i], dilation[i], *size_i, ); } // Output channels out_shape[1] = weight_shape[1] * groups; Ok(out_shape) } /// Calculate the expected padding size required when applying a convolution. pub fn calculate_conv_padding( kernel_size: usize, stride: usize, size_in: usize, size_out: usize, ) -> usize { let kernel_size = kernel_size as f32; let stride = stride as f32; let size_in = size_in as f32; let size_out = size_out as f32; let padding = stride * (size_out - 1.) - size_in + kernel_size; let padding = (padding / 2.).ceil(); padding as usize } /// Calculate the expected output size when doing a convolution operation. pub fn calculate_conv_output_size( kernel_size: usize, stride: usize, padding: usize, dilation: usize, size_in: usize, ) -> usize { (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 } /// Calculate the expected output sizes when doing a convolution operation. pub fn calculate_conv_output_sizes( kernel_size: &[usize], stride: &[usize], padding: &[usize], dilation: &[usize], size_in: &[usize], ) -> Vec { size_in .iter() .enumerate() .map(|(i, size_in)| { calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_in) }) .collect() } /// Calculate the expected output size when doing a transposed convolution operation. pub fn calculate_conv_transpose_output_size( kernel_size: usize, stride: usize, padding: usize, padding_out: usize, dilation: usize, size_in: usize, ) -> usize { (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding } /// Calculate the expected output size when doing a pooling operation. /// /// # Arguments /// /// * `kernel_size` - Size of the pooling kernel /// * `stride` - Stride of the pooling operation /// * `padding` - Padding applied to input /// * `dilation` - Dilation of the pooling kernel /// * `size_in` - Input size (height or width) /// * `ceil_mode` - If true, use ceiling instead of floor for output size calculation. /// This allows the last pooling window to go out-of-bounds if needed. pub fn calculate_pool_output_size( kernel_size: usize, stride: usize, padding: usize, dilation: usize, size_in: usize, ceil_mode: bool, ) -> usize { let numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1; if ceil_mode { // Ceiling division: (a + b - 1) / b numerator.div_ceil(stride) + 1 } else { // Floor division (default) numerator / stride + 1 } } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`. pub(crate) fn conv1d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { let weight_shape = weight.shape(); let [_batch_size, _, length_in] = x.shape().dims(); let [_batch_size, _channels_out, length_out] = output_grad.shape().dims(); let [_, _, kernel_size] = weight_shape.dims(); let padding_out = calculate_padding_out( kernel_size, options.stride[0], options.padding[0], options.dilation[0], length_in, length_out, ); B::conv_transpose1d( output_grad, weight, None, ConvTransposeOptions::new( options.stride, options.padding, [padding_out], options.dilation, options.groups, ), ) } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`. pub(crate) fn conv1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv1d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv1d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`. pub(crate) fn conv1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _, _length_in] = x.shape().dims(); let [_batch_size, channels_out, length_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`. pub(crate) fn conv2d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { let weight_shape = weight.shape(); let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims(); let [_, _, height_out, width_out] = output_grad.shape().dims(); let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims(); let padding_1_out = calculate_padding_out( kernel_size_1, options.stride[0], options.padding[0], options.dilation[0], height_in, height_out, ); let padding_2_out = calculate_padding_out( kernel_size_2, options.stride[1], options.padding[1], options.dilation[1], width_in, width_out, ); B::conv_transpose2d( output_grad, weight, None, ConvTransposeOptions::new( options.stride, options.padding, [padding_1_out, padding_2_out], options.dilation, options.groups, ), ) } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`. pub(crate) fn conv2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv2d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv2d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`. pub(crate) fn conv2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _, _, _] = x.shape().dims(); let [_, channels_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( grad, Shape::new([channels_out, batch_size * height_out * width_out]), ); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`. pub(crate) fn conv3d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { let weight_shape = weight.shape(); let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims(); let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims(); let [ _channels_out, _, kernel_size_1, kernel_size_2, kernel_size_3, ] = weight_shape.dims(); let padding_1_out = calculate_padding_out( kernel_size_1, options.stride[0], options.padding[0], options.dilation[0], depth_in, depth_out, ); let padding_2_out = calculate_padding_out( kernel_size_2, options.stride[1], options.padding[1], options.dilation[1], height_in, height_out, ); let padding_3_out = calculate_padding_out( kernel_size_3, options.stride[2], options.padding[2], options.dilation[2], width_in, width_out, ); B::conv_transpose3d( output_grad, weight, None, ConvTransposeOptions::new( options.stride, options.padding, [padding_1_out, padding_2_out, padding_3_out], options.dilation, options.groups, ), ) } /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`. pub(crate) fn conv3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv3d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv3d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`. pub(crate) fn conv3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims(); let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( grad, Shape::new([ channels_out, batch_size * depth_out * height_out * width_out, ]), ); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`. pub(crate) fn conv_transpose1d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { B::conv1d( output_grad, weight, None, ConvOptions::new( options.stride, options.padding, options.dilation, options.groups, ), ) } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`. pub(crate) fn conv_transpose1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv_transpose1d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv_transpose1d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`. pub(crate) fn conv_transpose1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _channels_in, _] = x.shape().dims(); let [_, channels_out, length_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`. pub(crate) fn conv_transpose2d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { B::conv2d( output_grad, weight, None, ConvOptions::new( options.stride, options.padding, options.dilation, options.groups, ), ) } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`. pub(crate) fn conv_transpose2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv_transpose2d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv_transpose2d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`. pub(crate) fn conv_transpose2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _channels_in, _, _] = x.shape().dims(); let [_, channels_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( grad, Shape::new([channels_out, batch_size * height_out * width_out]), ); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`. pub(crate) fn conv_transpose3d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { B::conv3d( output_grad, weight, None, ConvOptions::new( options.stride, options.padding, options.dilation, options.groups, ), ) } /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`. pub(crate) fn conv_transpose3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { let weight_dtype = weight.dtype(); let weight_shape = weight.shape(); let weight_device = B::float_device(&weight); match options.groups == 1 { true => conv_transpose3d_weight_grad_no_groups::(x, output_grad, weight_shape, options), false => conv_transpose3d_weight_grad_groups::( x, B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), output_grad, options, ), } } /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`. pub(crate) fn conv_transpose3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let [batch_size, _channels_in, _, _, _] = x.shape().dims(); let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims(); let grad = B::float_swap_dims(output_grad, 0, 1); let grad = B::float_reshape( grad, Shape::new([ channels_out, batch_size * depth_out * height_out * width_out, ]), ); let grad = B::float_sum_dim(grad, 1); B::float_reshape(grad, bias.shape()) } /// Execute a 1D convolution using a 2D convolution. pub(crate) fn conv1d_from_conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { let [channels_out, _channels_in, kernel_size] = weight.shape().dims(); let [batch_size, channels_in, length_in] = x.shape().dims(); let weight = B::float_reshape( weight, Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), ); let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); let tensor = B::conv2d( x, weight, bias, ConvOptions::new( [options.stride[0], 1], [options.padding[0], 0], [options.dilation[0], 1], options.groups, ), ); let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } /// Execute a 1D transposed convolution using a 2D transposed convolution. pub(crate) fn conv_transpose1d_from_conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { let [channels_in, channels_out, kernel_size] = weight.shape().dims(); let [batch_size, _channels_in, length_in] = x.shape().dims(); let weight = B::float_reshape( weight, Shape::new([channels_in, channels_out, kernel_size, 1]), ); let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); let tensor = B::conv_transpose2d( x, weight, bias, ConvTransposeOptions::new( [options.stride[0], 1], [options.padding[0], 0], [options.padding_out[0], 0], [options.dilation[0], 1], options.groups, ), ); let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) } fn conv1d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvOptions<1>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv1d( x_swapped, output_grad_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); if weight_grad.shape() != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv2d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvOptions<2>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv2d( x_swapped, output_grad_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); if weight_grad.shape() != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), Slice::from(0..weight_shape[3]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv3d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvOptions<3>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv3d( x_swapped, output_grad_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); if weight_grad.shape() != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), Slice::from(0..weight_shape[3]), Slice::from(0..weight_shape[4]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv1d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv1d( x, grad, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_co..end_idx_co), Slice::from(0..increment_ci), Slice::from(0..kernel_size), ], weight_grad_tmp, ); } weight_grad } fn conv2d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv2d( x, grad, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { let slices = vec![ Slice::from(0..increment_co), Slice::from(0..increment_ci), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), ]; weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices); } weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_co..end_idx_co), Slice::from(0..increment_ci), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), ], weight_grad_tmp, ); } weight_grad } fn conv3d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { let [ channels_out, increment_ci, kernel_size_1, kernel_size_2, kernel_size_3, ] = weight_grad.shape().dims(); let increment_co = channels_out / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv3d( x, grad, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [ _, _, kernel_size_1_tmp, kernel_size_2_tmp, kernel_size_3_tmp, ] = weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 || kernel_size_3_tmp != kernel_size_3 { let slices = vec![ Slice::from(0..increment_co), Slice::from(0..increment_ci), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), Slice::from(0..kernel_size_3), ]; weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices); } weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_co..end_idx_co), Slice::from(0..increment_ci), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), Slice::from(0..kernel_size_3), ], weight_grad_tmp, ); } weight_grad } fn conv_transpose1d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvTransposeOptions<1>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv1d( output_grad_swapped, x_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv_transpose2d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvTransposeOptions<2>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv2d( output_grad_swapped, x_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), Slice::from(0..weight_shape[3]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv_transpose3d_weight_grad_no_groups( x: FloatTensor, output_grad: FloatTensor, weight_shape: Shape, options: ConvTransposeOptions<3>, ) -> FloatTensor { let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); let weight_grad_swapped = B::conv3d( output_grad_swapped, x_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); let grad_shape = weight_grad.shape(); if grad_shape != weight_shape { let slices = vec![ Slice::from(0..weight_shape[0]), Slice::from(0..weight_shape[1]), Slice::from(0..weight_shape[2]), Slice::from(0..weight_shape[3]), Slice::from(0..weight_shape[4]), ]; weight_grad = B::float_slice(weight_grad, &slices); } weight_grad } fn conv_transpose1d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<1>, ) -> FloatTensor { let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims(); let increment_ci = channels_in / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv1d( grad, x, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims(); if kernel_size_tmp != kernel_size { let slices = vec![ Slice::from(0..increment_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size), ]; weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices); } weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_ci..end_idx_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size), ], weight_grad_tmp, ); } weight_grad } fn conv_transpose2d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<2>, ) -> FloatTensor { let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims(); let increment_ci = channels_in / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv2d( grad, x, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 { let slices = vec![ Slice::from(0..increment_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), ]; weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices); } weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_ci..end_idx_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), ], weight_grad_tmp, ); } weight_grad } fn conv_transpose3d_weight_grad_groups( x: FloatTensor, mut weight_grad: FloatTensor, output_grad: FloatTensor, options: ConvTransposeOptions<3>, ) -> FloatTensor { let [ channels_in, increment_co, kernel_size_1, kernel_size_2, kernel_size_3, ] = weight_grad.shape().dims(); let increment_ci = channels_in / options.groups; let x_swapped = B::float_swap_dims(x, 0, 1); let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let x_slice = vec![Slice::new( start_idx_ci as isize, Some(end_idx_ci as isize), 1, )]; let x = B::float_slice(x_swapped.clone(), &x_slice); let grad_slice = vec![Slice::new( start_idx_co as isize, Some(end_idx_co as isize), 1, )]; let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); let mut weight_grad_tmp = B::conv3d( grad, x, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), ); weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); let [ _, _, kernel_size_1_tmp, kernel_size_2_tmp, kernel_size_3_tmp, ] = weight_grad_tmp.shape().dims(); if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 || kernel_size_3_tmp != kernel_size_3 { let slices = vec![ Slice::from(0..increment_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), Slice::from(0..kernel_size_3), ]; weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices); } weight_grad = B::float_slice_assign( weight_grad, &[ Slice::from(start_idx_ci..end_idx_ci), Slice::from(0..increment_co), Slice::from(0..kernel_size_1), Slice::from(0..kernel_size_2), Slice::from(0..kernel_size_3), ], weight_grad_tmp, ); } weight_grad } fn calculate_padding_out( kernel_size: usize, stride: usize, padding: usize, dilation: usize, size_in: usize, size_out: usize, ) -> usize { if stride <= 1 { return 0; } let out = 1 + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil() as usize; i64::max(0, out as i64 - size_out as i64) as usize } #[cfg(test)] mod tests { use super::*; #[test] fn test_calculate_output_size_1() { let kernel_size = 3; let stride = 1; let padding = 1; let size_in = 3; let dilation = 1; let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, 3); } #[test] fn test_calculate_output_size_2() { let kernel_size = 5; let stride = 2; let padding = 3; let size_in = 27; let dilation = 1; let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, 15); } #[test] fn test_calculate_output_size_3() { let kernel_size = 5; let stride = 2; let padding = 3; let size_in = 27; let dilation = 2; let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, 13); } #[test] fn test_calculate_same_padding_1() { let kernel_size = 3; let stride = 1; let size_in = 3; let dilation = 1; let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_in, size_out, "Expected size"); } #[test] fn test_calculate_same_padding_2() { let kernel_size = 3; let stride = 2; let size_in = 7; let dilation = 1; let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in); let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_in, size_out, "Expected size"); } #[test] fn test_calculate_output_padding_1() { let kernel_size = 3; let stride = 2; let size_in = 7; let size_out = 10; let dilation = 1; let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out); let size_out_expected = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in); assert_eq!(size_out, size_out_expected, "Expected size"); } #[test] fn test_expect_conv2d_output_shape() { // in channels: 3 // out channels: 8 // size in: [27, 3] // kernel size: [5, 3] let stride = [2, 1]; let padding = [3, 1]; let dilation = [2, 1]; let shape = calculate_conv_output_shape( &Shape::new([12, 3, 27, 3]), &Shape::new([8, 3, 5, 3]), &stride, &padding, &dilation, ) .unwrap(); assert_eq!(shape, Shape::new([12, 8, 13, 3])) } } ================================================ FILE: crates/burn-backend/src/backend/ops/modules/grid_sample.rs ================================================ use crate::{ Backend, TensorMetadata, ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}, tensor::FloatTensor, }; use alloc::vec; use burn_std::{Shape, Slice}; /// Reference implementation of grid_sample_2d that supports all options. /// /// # Arguments /// /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in) /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right /// * `options` - Grid sampling options /// /// # Returns /// /// A tensor with shape (N, C, H_out, W_out) pub fn float_grid_sample_2d_ref( tensor: FloatTensor, grid: FloatTensor, options: GridSampleOptions, ) -> FloatTensor { match options.mode { InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::( tensor, grid, options.padding_mode, options.align_corners, ), _ => todo!( "Default implementation for grid_sample_2d with {:?} unimplemented", options.mode ), } } /// Bilinear grid sampling implementation. fn float_grid_sample_2d_bilinear( tensor: FloatTensor, grid: FloatTensor, padding_mode: GridSamplePaddingMode, align_corners: bool, ) -> FloatTensor { let n = tensor.shape()[0]; let c = tensor.shape()[1]; let h_in = tensor.shape()[2]; let w_in = tensor.shape()[3]; let h_out = grid.shape()[1]; let w_out = grid.shape()[2]; let spatial_in = h_in * w_in; let spatial_out = h_out * w_out; // Separate x and y coordinates from grid // shape: (N, H_out, W_out, 1) let grid_x_slice = vec![ Slice::new(0, Some(n as isize), 1), Slice::new(0, Some(h_out as isize), 1), Slice::new(0, Some(w_out as isize), 1), Slice::new(0, Some(1), 1), ]; let grid_y_slice = vec![ Slice::new(0, Some(n as isize), 1), Slice::new(0, Some(h_out as isize), 1), Slice::new(0, Some(w_out as isize), 1), Slice::new(1, Some(2), 1), ]; let grid_x = B::float_slice(grid.clone(), &grid_x_slice); let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out])); let grid_y = B::float_slice(grid.clone(), &grid_y_slice); let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out])); // Convert normalized grid coordinates [-1, 1] to pixel coordinates let w_in_f = w_in as f64; let h_in_f = h_in as f64; let (grid_x, grid_y) = if align_corners { // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2 // Maps -1 to 0 and 1 to width - 1 let grid_x = B::float_add_scalar(grid_x, 1f32.into()); let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).into()); let grid_y = B::float_add_scalar(grid_y, 1f32.into()); let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).into()); (grid_x, grid_y) } else { // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5 // Maps -1 to -0.5 and 1 to width - 0.5 let grid_x = B::float_add_scalar(grid_x, 1f32.into()); let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).into()); let grid_x = B::float_sub_scalar(grid_x, 0.5f32.into()); let grid_y = B::float_add_scalar(grid_y, 1f32.into()); let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).into()); let grid_y = B::float_sub_scalar(grid_y, 0.5f32.into()); (grid_x, grid_y) }; // Apply padding mode to coordinates let (grid_x, grid_y) = match padding_mode { GridSamplePaddingMode::Border => { // Clamp coordinates to valid range [0, size-1] let grid_x = B::float_clamp(grid_x, 0f32.into(), ((w_in - 1) as f32).into()); let grid_y = B::float_clamp(grid_y, 0f32.into(), ((h_in - 1) as f32).into()); (grid_x, grid_y) } GridSamplePaddingMode::Reflection => { // Reflect coordinates at boundaries let grid_x = reflect_coordinates::(grid_x, w_in_f, align_corners); let grid_y = reflect_coordinates::(grid_y, h_in_f, align_corners); (grid_x, grid_y) } GridSamplePaddingMode::Zeros => { // Keep coordinates as-is, we'll mask out-of-bounds later (grid_x, grid_y) } }; // Get floor indices for the four corners let grid_x_floored = B::float_floor(grid_x.clone()); let grid_y_floored = B::float_floor(grid_y.clone()); // Compute interpolation weights (fractional part) let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone()); let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone()); // Convert to integer indices let x0 = B::float_into_int(grid_x_floored.clone()); let y0 = B::float_into_int(grid_y_floored.clone()); let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1f32.into())); let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1f32.into())); // Create masks for out-of-bounds coordinates (only used for zeros padding) let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros { let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.into()); let x0_valid = B::bool_and( x0_valid, B::int_lower_elem(x0.clone(), (w_in as i32).into()), ); let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.into()); let x1_valid = B::bool_and( x1_valid, B::int_lower_elem(x1.clone(), (w_in as i32).into()), ); let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.into()); let y0_valid = B::bool_and( y0_valid, B::int_lower_elem(y0.clone(), (h_in as i32).into()), ); let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.into()); let y1_valid = B::bool_and( y1_valid, B::int_lower_elem(y1.clone(), (h_in as i32).into()), ); ( Some(B::bool_and(x0_valid.clone(), y0_valid.clone())), Some(B::bool_and(x0_valid.clone(), y1_valid.clone())), Some(B::bool_and(x1_valid.clone(), y0_valid)), Some(B::bool_and(x1_valid, y1_valid)), ) } else { (None, None, None, None) }; // Clamp indices to valid range for gather let x0_clamped = B::int_clamp(x0, 0.into(), ((w_in - 1) as i32).into()); let x1_clamped = B::int_clamp(x1, 0.into(), ((w_in - 1) as i32).into()); let y0_clamped = B::int_clamp(y0, 0.into(), ((h_in - 1) as i32).into()); let y1_clamped = B::int_clamp(y1, 0.into(), ((h_in - 1) as i32).into()); // Linear indices: idx = y * W_in + x let w_in_scalar: i32 = w_in as i32; let idx_00 = B::int_add( B::int_mul_scalar(y0_clamped.clone(), w_in_scalar.into()), x0_clamped.clone(), ); let idx_01 = B::int_add( B::int_mul_scalar(y1_clamped.clone(), w_in_scalar.into()), x0_clamped, ); let idx_10 = B::int_add( B::int_mul_scalar(y0_clamped, w_in_scalar.into()), x1_clamped.clone(), ); let idx_11 = B::int_add( B::int_mul_scalar(y1_clamped, w_in_scalar.into()), x1_clamped, ); // [N, 1, H_out, W_out] -> [N, 1, H_out * W_out] let idx_00 = B::int_reshape(idx_00, Shape::new([n, 1, spatial_out])); let idx_01 = B::int_reshape(idx_01, Shape::new([n, 1, spatial_out])); let idx_10 = B::int_reshape(idx_10, Shape::new([n, 1, spatial_out])); let idx_11 = B::int_reshape(idx_11, Shape::new([n, 1, spatial_out])); // [N, 1, spatial] -> [N, C, spatial] let idx_00 = B::int_expand(idx_00, Shape::new([n, c, spatial_out])); let idx_01 = B::int_expand(idx_01, Shape::new([n, c, spatial_out])); let idx_10 = B::int_expand(idx_10, Shape::new([n, c, spatial_out])); let idx_11 = B::int_expand(idx_11, Shape::new([n, c, spatial_out])); let tensor_flat = B::float_reshape(tensor, Shape::new([n, c, spatial_in])); let sample_00 = B::float_gather(2, tensor_flat.clone(), idx_00); let sample_01 = B::float_gather(2, tensor_flat.clone(), idx_01); let sample_10 = B::float_gather(2, tensor_flat.clone(), idx_10); let sample_11 = B::float_gather(2, tensor_flat, idx_11); // Reshape samples to (N, C, H_out, W_out) let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out])); let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out])); let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out])); let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out])); // Apply masks for zeros padding (set out-of-bounds samples to 0) let (sample_00, sample_01, sample_10, sample_11) = if padding_mode == GridSamplePaddingMode::Zeros { let mask_00 = mask_00.unwrap(); let mask_01 = mask_01.unwrap(); let mask_10 = mask_10.unwrap(); let mask_11 = mask_11.unwrap(); let mask_00_inv = B::bool_not(mask_00); let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out])); let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out])); let mask_01_inv = B::bool_not(mask_01); let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out])); let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out])); let mask_10_inv = B::bool_not(mask_10); let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out])); let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out])); let mask_11_inv = B::bool_not(mask_11); let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out])); let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out])); ( B::float_mask_fill(sample_00, mask_00_inv, 0f32.into()), B::float_mask_fill(sample_01, mask_01_inv, 0f32.into()), B::float_mask_fill(sample_10, mask_10_inv, 0f32.into()), B::float_mask_fill(sample_11, mask_11_inv, 0f32.into()), ) } else { (sample_00, sample_01, sample_10, sample_11) }; // Compute bilinear interpolation weights let one_minus_x = B::float_neg(x_frac.clone()); let one_minus_x = B::float_add_scalar(one_minus_x, 1f32.into()); let one_minus_y = B::float_neg(y_frac.clone()); let one_minus_y = B::float_add_scalar(one_minus_y, 1f32.into()); let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone()); let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone()); let weight_10 = B::float_mul(x_frac.clone(), one_minus_y); let weight_11 = B::float_mul(x_frac, y_frac); // Bilinear interpolation let result = B::float_mul(sample_00, weight_00); let result = B::float_add(result, B::float_mul(sample_01, weight_01)); let result = B::float_add(result, B::float_mul(sample_10, weight_10)); B::float_add(result, B::float_mul(sample_11, weight_11)) } /// Reflect coordinates at boundaries using a triangle wave pattern. /// /// For align_corners=true: reflects within [0, size-1] /// For align_corners=false: reflects within [-0.5, size-0.5] fn reflect_coordinates( coords: FloatTensor, size: f64, align_corners: bool, ) -> FloatTensor { let (min_val, max_val) = if align_corners { (0.0f32, (size - 1.0) as f32) } else { (-0.5f32, (size - 0.5) as f32) }; let span = max_val - min_val; if span <= 0.0 { // Edge case: size is 1, just return min_val everywhere let zeros = B::float_mul_scalar(coords, 0f32.into()); return B::float_add_scalar(zeros, min_val.into()); } // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val let period = 2.0 * span; // x = abs(coord - min_val) let x = B::float_sub_scalar(coords, min_val.into()); let x = B::float_abs(x); // x_mod = x - floor(x / period) * period let x_div = B::float_div_scalar(x.clone(), period.into()); let x_div_floor = B::float_floor(x_div); let x_mod = B::float_sub(x, B::float_mul_scalar(x_div_floor, period.into())); // result = span - abs(x_mod - span) + min_val let diff = B::float_sub_scalar(x_mod, span.into()); let abs_diff = B::float_abs(diff); let reflected = B::float_sub_scalar(abs_diff, span.into()); let reflected = B::float_neg(reflected); B::float_add_scalar(reflected, min_val.into()) } ================================================ FILE: crates/burn-backend/src/backend/ops/modules/mod.rs ================================================ /// Module with convolution operations. pub mod conv; /// Module with attention operations. pub mod attention; /// Module with unfold operations. pub mod unfold; /// Module with pooling operations. pub mod pool; /// Module for grid_sample operations pub mod grid_sample; mod base; pub use base::*; ================================================ FILE: crates/burn-backend/src/backend/ops/modules/pool.rs ================================================ use crate::tensor::{FloatTensor, IntTensor}; use crate::{Backend, TensorMetadata}; use burn_std::Shape; use super::{MaxPool1dBackward, MaxPool1dWithIndices}; pub(crate) fn avg_pool1d_from_2d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let [batch_size, channels, length] = x.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); let x = B::avg_pool2d( x, [kernel_size, 1], [stride, 1], [padding, 0], count_include_pad, ceil_mode, ); let [batch_size, channels, length, _] = x.shape().dims(); B::float_reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn avg_pool1d_backward_from_2d( x: FloatTensor, grad: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let [batch_size, channels, length_in] = x.shape().dims(); let [_, _, length_out] = grad.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1])); let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1])); let grad_x = B::avg_pool2d_backward( x, grad_x, [kernel_size, 1], [stride, 1], [padding, 0], count_include_pad, ceil_mode, ); B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn adaptive_avg_pool1d_from_2d( x: FloatTensor, output_size: usize, ) -> FloatTensor { let [batch_size, channels, length] = x.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); let x = B::adaptive_avg_pool2d(x, [output_size, 1]); let [batch_size, channels, length, _] = x.shape().dims(); B::float_reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn adaptive_avg_pool1d_backward_from_2d( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { let [batch_size, channels, length_in] = x.shape().dims(); let [_, _, length_out] = grad.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1])); let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1])); let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in])) } pub(crate) fn max_pool1d_from_2d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> FloatTensor { let [batch_size, channels, length] = x.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); let x = B::max_pool2d( x, [kernel_size, 1], [stride, 1], [padding, 0], [dilation, 1], ceil_mode, ); let [batch_size, channels, length, _] = x.shape().dims(); B::float_reshape(x, Shape::from([batch_size, channels, length])) } pub(crate) fn max_pool1d_with_indices_from_2d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { let [batch_size, channels, length] = x.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length])); let x = B::max_pool2d_with_indices( x, [1, kernel_size], [1, stride], [0, padding], [1, dilation], ceil_mode, ); let [batch_size, channels, _, length] = x.output.shape().dims(); let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length])); let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length])); MaxPool1dWithIndices::new(output, indices) } #[allow(clippy::too_many_arguments)] pub(crate) fn max_pool1d_with_indices_backward_from_2d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool1dBackward { let [batch_size, channels, length_in] = x.shape().dims(); let [_, _, length_out] = output_grad.shape().dims(); let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1])); let grad_x = B::float_reshape( output_grad, Shape::from([batch_size, channels, length_out, 1]), ); let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1])); let grad_x = B::max_pool2d_with_indices_backward( x, [kernel_size, 1], [stride, 1], [padding, 0], [dilation, 1], ceil_mode, grad_x, indices, ) .x_grad; MaxPool1dBackward::new(B::float_reshape( grad_x, Shape::from([batch_size, channels, length_in]), )) } ================================================ FILE: crates/burn-backend/src/backend/ops/modules/unfold.rs ================================================ use super::{ConvOptions, UnfoldOptions}; use crate::tensor::FloatTensor; use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion}; use alloc::vec; use alloc::vec::Vec; use burn_std::Shape; /// Constructs a special weight tensor used for unfolding. /// /// # Notes /// /// The idea behind using convolution for unfolding is to leverage the sliding window mechanism of /// convolution. By creating a weight tensor with ones in a particular pattern, we are able to borrow /// the convolution operation's mechanism as it moves across the input tensor, picking up the desired /// values in the pattern of the unfolding operation. pub(crate) fn create_unfolding_weight( in_channels: usize, kernel_size: [usize; 2], device: &B::Device, ) -> FloatTensor { let shape = Shape::new([ in_channels * kernel_size[0] * kernel_size[1], in_channels, kernel_size[0], kernel_size[1], ]); let mut strides = [0; 4]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { strides[index] = current; current *= val; }); let num_elements = shape.num_elements(); let mut weight: Vec = vec![0.0.elem(); num_elements]; for k in 0..in_channels { for i in 0..kernel_size[0] { for j in 0..kernel_size[1] { let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; let index = output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; weight[index] = 1.elem(); } } } B::float_from_data(TensorData::new(weight, shape), device) } /// Compute the unfold4d operation using the conv2d operations. pub(crate) fn unfold4d_using_conv2d( x: FloatTensor, kernel_size: [usize; 2], options: UnfoldOptions, ) -> FloatTensor { let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims(); let weight = create_unfolding_weight::(in_channels, kernel_size, &B::float_device(&x)); let unfolded = B::conv2d( x, weight, None, ConvOptions::new(options.stride, options.padding, options.dilation, 1), ); let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims(); B::float_reshape( unfolded, Shape::new([batch_size, channels_out, out_height * out_width]), ) } /// Calculate the number of unfolding windows that can be extracted from a dimension of given size. pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize { assert!(step_size > 0); let x = dim_size + step_size; if x < window_size { 0 } else { (x - window_size) / step_size } } /// Calculate the output shape for an unfold operation. /// /// The operation yields a view with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Arguments /// /// * `shape` - The input shape to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the dimension to unfold. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A shape with ``[pre=..., windows, post=..., size]``. pub fn calculate_unfold_shape>( shape: S, dim: usize, size: usize, step: usize, ) -> Shape { let mut shape = shape.into(); let d_shape = shape[dim]; let windows = calculate_unfold_windows(d_shape, size, step); shape[dim] = windows; shape.push(size); shape } #[cfg(test)] mod tests { use super::*; #[test] fn test_calculate_unfold_windows() { assert_eq!(calculate_unfold_windows(2, 5, 1), 0); assert_eq!(calculate_unfold_windows(2, 3, 1), 0); assert_eq!(calculate_unfold_windows(3, 3, 1), 1); assert_eq!(calculate_unfold_windows(4, 3, 1), 2); assert_eq!(calculate_unfold_windows(5, 3, 1), 3); assert_eq!(calculate_unfold_windows(2, 3, 2), 0); assert_eq!(calculate_unfold_windows(3, 3, 2), 1); assert_eq!(calculate_unfold_windows(4, 3, 2), 1); assert_eq!(calculate_unfold_windows(5, 3, 2), 2); } #[test] fn test_calculate_unfold_shape() { assert_eq!( calculate_unfold_shape([2, 6, 6], 1, 3, 2), Shape::new([2, 2, 6, 3]) ); } } ================================================ FILE: crates/burn-backend/src/backend/ops/qtensor.rs ================================================ use alloc::vec::Vec; use burn_std::{ Shape, Slice, quantization::{QuantPropagation, QuantScheme}, }; use crate::{ Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive, }; use crate::{ Scalar, tensor::{ BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor, quantization::{ Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range, }, }, }; /// Automatically applies `dequantization -> float operation -> quantization`. /// /// Used for tensor ops that should always return a quantized output. #[macro_export] macro_rules! dequant_op_quant { // Binary tensor float op w/ lhs & rhs ( ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr ) => {{ // Heuristic: prioritize lhs scheme let scheme = $t1.scheme().clone(); let t1_f = <$ty>::dequantize($t1); let t2_f = <$ty>::dequantize($t2); #[allow(clippy::redundant_closure_call)] let out_f = $float_op(t1_f, t2_f); <$ty>::quantize_dynamic(out_f, &scheme) }}; // Unary tensor float op ( ty $ty:ty, float_op $float_op:expr, $tensor:expr ) => {{ let scheme = $tensor.scheme().clone(); let tensor_f = <$ty>::dequantize($tensor); #[allow(clippy::redundant_closure_call)] let out_f = $float_op(tensor_f); <$ty>::quantize_dynamic(out_f, &scheme) }}; } /// Automatically applies `dequantization -> float operation [-> quantization]`. /// /// The output quantization step is optional. /// It is only performed when the input quantization scheme is propagated. #[macro_export] macro_rules! dequant_op_flow { // Binary tensor float op w/ lhs & rhs ( ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr ) => {{ // Heuristic: prioritize lhs scheme let scheme = $t1.scheme().clone(); let propagation = $t1.propagation(); let t1_f = <$ty>::dequantize($t1); let t2_f = <$ty>::dequantize($t2); #[allow(clippy::redundant_closure_call)] let out_f = $float_op(t1_f, t2_f); match propagation { QuantPropagation::Propagate => { TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme)) } QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), } }}; // Unary tensor float op ( ty $ty:ty, float_op $float_op:expr, $tensor:expr ) => {{ let scheme = $tensor.scheme().clone(); let propagation = $tensor.propagation(); let tensor_f = <$ty>::dequantize($tensor); #[allow(clippy::redundant_closure_call)] let out_f = $float_op(tensor_f); match propagation { QuantPropagation::Propagate => { TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme)) } QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), } }}; } /// Operations on quantized tensors. /// /// # Return Type Semantics /// /// The return type of each operation indicates how quantization is handled: /// /// ## [`QuantizedTensor`] /// If the method returns a `QuantizedTensor`, the operation is expected to preserve the quantized /// representation. Implementations should avoid dequantizing when possible to maintain performance. /// For example, shape or layout changes such as expand or transpose preserve quantization. /// /// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is /// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering /// the quantization parameters to match the new layout.* /// /// /// ## [`TensorPrimitive`] /// If the method returns a `TensorPrimitive` enum, the return type should align with propagation /// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`]) /// returned in floating-point form ([`TensorPrimitive::Float`]). /// /// This distinction allows for fine-grained control over mixed-precision flows while still operating /// through a unified API. pub trait QTensorOps { /// Creates a new tensor from the data structure. /// /// # Arguments /// /// * `data` - The data structure. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given data. fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor; /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. fn quantize( tensor: FloatTensor, scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor; /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. fn quantize_dynamic(tensor: FloatTensor, scheme: &QuantScheme) -> QuantizedTensor { // Dynamically compute min/max tensor range and qparams before quantizing let (min, max) = compute_range::(scheme, tensor.clone(), &Calibration::MinMax); let qparams = compute_q_params(scheme, min, max); Self::quantize(tensor, scheme, qparams) } /// Convert the tensor back to a higher precision data type. fn dequantize(tensor: QuantizedTensor) -> FloatTensor; /// Gets the device of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The device of the tensor. fn q_device(tensor: &QuantizedTensor) -> Device; /// Moves the tensor to the given device. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `device` - The device to move the tensor to. /// /// # Returns /// /// The tensor on the given device. fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor; /// Reshapes a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to reshape. /// * `shape` - The new shape of the tensor. /// /// # Returns /// /// The tensor with the new shape. fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor; /// Converts the tensor to a data structure. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The data structure with the tensor's data. fn q_into_data( tensor: QuantizedTensor, ) -> impl Future> + Send; /// Detaches a tensor from the computation graph. fn q_detach(tensor: QuantizedTensor) -> QuantizedTensor { // Should only be overridden by autodiff backends. tensor } /// Sets the `require_grad` flag of a tensor. fn q_set_require_grad(tensor: QuantizedTensor, _require_grad: bool) -> QuantizedTensor { // Should only be overridden by autodiff backends. tensor } /// Returns the `require_grad` flag of a tensor. fn q_is_require_grad(_tensor: &QuantizedTensor) -> bool { // Should only be overridden by autodiff backends. false } /// Broadcasts the `tensor` to the given `shape`. fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor; /// Transposes a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor { let ndims = tensor.shape().num_dims(); Self::q_swap_dims(tensor, ndims - 2, ndims - 1) } /// Swaps two dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap. /// * `dim2` - The second dimension to swap. /// /// # Returns /// /// The tensor with the dimensions swapped. fn q_swap_dims(tensor: QuantizedTensor, dim1: usize, dim2: usize) -> QuantizedTensor; /// Permutes the dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to permute the dimensions of. /// * `axes` - The new order of the dimensions. /// # Returns /// /// The tensor with the dimensions permuted. fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; /// Reverse the order of elements in a tensor along the given axes. /// /// # Arguments /// /// * `tensor` - The tensor to reverse. /// * `axes` - The axes to reverse. /// /// The tensor with the elements reversed. fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; /// Select tensor elements along the given dimension corresponding for the given indices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `dim` - The dimension to select from. /// * `indices` - The indices to select. /// /// # Returns /// /// The selected elements. fn q_select( tensor: QuantizedTensor, dim: usize, indices: IntTensor, ) -> QuantizedTensor; /// Select tensor elements corresponding to the given slices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `slices` - The slices specifying ranges and steps for each dimension. /// /// # Returns /// /// The selected elements in a new tensor. fn q_slice(tensor: QuantizedTensor, slices: &[Slice]) -> QuantizedTensor; /// Gather elements from a tensor. /// /// # Arguments /// /// * `dim` - The dimension to gather from. /// * `tensor` - The tensor to gather from. /// * `indices` - The indices to gather. /// /// # Returns /// /// The gathered elements. fn q_gather( dim: usize, tensor: QuantizedTensor, indices: IntTensor, ) -> QuantizedTensor { // Default implementation. Backends can gather on the quantized values when supported. dequant_op_quant!( ty Self, float_op |tensor| B::float_gather(dim, tensor, indices), tensor ) } /// Repeat the tensor along the given dimension. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to repeat. /// * `times` - The number of times to repeat the dimension. /// /// # Returns /// /// The tensor with the given dimension repeated. fn q_repeat_dim(tensor: QuantizedTensor, dim: usize, times: usize) -> QuantizedTensor { dequant_op_quant!( ty Self, float_op |tensor| B::float_repeat_dim(tensor, dim, times), tensor ) } /// Adds two tensors together. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The result of adding the two tensors together. fn q_add(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_add(lhs, rhs), lhs, rhs ) } /// Adds a scalar to a tensor. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The result of adding the scalar to the tensor. fn q_add_scalar(lhs: QuantizedTensor, rhs: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_add_scalar(tensor, rhs), lhs ) } /// Clamps a tensor under a minimum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// /// # Returns /// /// The clamped tensor. fn q_clamp_min(tensor: QuantizedTensor, min: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_clamp_min(tensor, min), tensor ) } /// Clamps a tensor over a maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn q_clamp_max(tensor: QuantizedTensor, max: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_clamp_max(tensor, max), tensor ) } /// Clamps a tensor between a minimum and maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn q_clamp(tensor: QuantizedTensor, min: Scalar, max: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_clamp(tensor, min, max), tensor ) } /// Subtracts two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The result of subtracting the two tensors. fn q_sub(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_sub(lhs, rhs), lhs, rhs ) } /// Subtracts a scalar from a tensor. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The result of subtracting the scalar from the tensor. fn q_sub_scalar(lhs: QuantizedTensor, rhs: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sub_scalar(tensor, rhs), lhs ) } /// Multiplies two tensors together element-wise. fn q_mul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_mul(lhs, rhs), lhs, rhs ) } /// Multiplies a tensor by a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The result of multiplying the tensor by the scalar. fn q_mul_scalar(lhs: QuantizedTensor, rhs: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_mul_scalar(tensor, rhs), lhs ) } /// Divides two tensors element-wise. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The result of dividing the two tensors. fn q_div(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_div(lhs, rhs), lhs, rhs ) } /// Divides a tensor by a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The result of dividing the tensor by the scalar. fn q_div_scalar(lhs: QuantizedTensor, rhs: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_div_scalar(tensor, rhs), lhs ) } /// Multiplies two tensors together using matrix multiplication. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The result of multiplying the two tensors together using matrix multiplication. fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { let mut propagation = QuantPropagation::Inhibit; let mut scheme = QuantScheme::default(); let lhs = match lhs { TensorPrimitive::Float(lhs) => lhs, TensorPrimitive::QFloat(lhs) => { propagation = lhs.propagation(); scheme = *lhs.scheme(); Self::dequantize(lhs) } }; let rhs = match rhs { TensorPrimitive::Float(rhs) => rhs, TensorPrimitive::QFloat(rhs) => { propagation = rhs.propagation(); scheme = *rhs.scheme(); Self::dequantize(rhs) } }; let out_f = B::float_matmul(lhs, rhs); match propagation { QuantPropagation::Propagate => { TensorPrimitive::QFloat(::quantize_dynamic(out_f, &scheme)) } QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), } } /// Negates a tensor element-wise. fn q_neg(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_neg(tensor), tensor ) } /// Calculates the reciprocals element-wise fn q_recip(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_recip(tensor), tensor ) } /// Sum of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// /// # Returns /// /// A scalar tensor with the sum of all elements in `tensor`. fn q_sum(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sum(tensor), tensor ) } /// Sum of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// * `dim` - The dimension along which to sum. /// /// # Returns /// /// A tensor with the sum of all elements in `tensor` along `dim`. fn q_sum_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sum_dim(tensor, dim), tensor ) } /// Product of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to product. /// /// # Returns /// /// A scalar tensor with the product of all elements in `tensor`. fn q_prod(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_prod(tensor), tensor ) } /// Product of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to product. /// /// # Returns /// /// A tensor with the product of all elements in `tensor` along `dim`. fn q_prod_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_prod_dim(tensor, dim), tensor ) } /// Mean of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to mean. /// /// # Returns /// /// A scalar tensor with the mean of all elements in `tensor`. fn q_mean(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_mean(tensor), tensor ) } /// Mean of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to mean. /// * `dim` - The dimension along which to mean. /// /// # Returns /// /// A tensor with the mean of all elements in `tensor` along `dim`. fn q_mean_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_mean_dim(tensor, dim), tensor ) } /// Computes the cumulative sum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative sum of. /// * `dim` - The dimension along which to compute the cumulative sum. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative sum /// of all elements up to and including that position along the dimension. fn q_cumsum(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cumsum(tensor, dim), tensor ) } /// Computes the cumulative product of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative product of. /// * `dim` - The dimension along which to compute the cumulative product. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative product /// of all elements up to and including that position along the dimension. fn q_cumprod(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cumprod(tensor, dim), tensor ) } /// Computes the cumulative minimum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative minimum of. /// * `dim` - The dimension along which to compute the cumulative minimum. /// /// # Returns /// /// A tensor with the same shape where each element is the minimum /// of all elements up to and including that position along the dimension. fn q_cummin(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cummin(tensor, dim), tensor ) } /// Computes the cumulative maximum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative maximum of. /// * `dim` - The dimension along which to compute the cumulative maximum. /// /// # Returns /// /// A tensor with the same shape where each element is the maximum /// of all elements up to and including that position along the dimension. fn q_cummax(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cummax(tensor, dim), tensor ) } /// Returns a new tensor with exponential values. /// /// # Arguments /// /// * `tensor` - The tensor to exponentiate. /// /// # Returns /// /// A tensor with the same shape as `tensor` with exponential values. fn q_exp(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_exp(tensor), tensor ) } /// Returns a new tensor with natural logarithm values. /// /// # Arguments /// /// * `tensor` - The tensor to take the logarithm of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with natural logarithm values. fn q_log(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_log(tensor), tensor ) } /// Returns a new tensor with logarithm values of (1 + Xi). /// /// # Arguments /// /// * `tensor` - The tensor to take the logarithm of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). fn q_log1p(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_log1p(tensor), tensor ) } /// Element-wise power with another tensor. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The elements of `lhs` raised to the power of the elements of `rhs`. fn q_powf(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_powf(lhs, rhs), lhs, rhs ) } /// Element-wise power with an IntTensor. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side floatTensor. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. fn q_powi(lhs: QuantizedTensor, rhs: IntTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_powi(tensor, rhs), lhs ) } /// Element-wise power with an int scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. fn q_powi_scalar(lhs: QuantizedTensor, rhs: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_powi_scalar(tensor, rhs), lhs ) } /// Element-wise power with a float scalar. /// /// # Arguments /// /// * `tensor` - The tensor to exponentiate. /// * `value` - The exponent. /// /// # Returns /// /// A tensor with the same shape as `tensor` with values raised to the power of `value`. fn q_powf_scalar(tensor: QuantizedTensor, value: Scalar) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_powf_scalar(tensor, value), tensor ) } /// Returns a new tensor with square root values. /// /// # Arguments /// /// * `tensor` - The tensor to take the square root of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with square root values. fn q_sqrt(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sqrt(tensor), tensor ) } /// Returns a new tensor with absolute values. /// /// # Arguments /// /// * `tensor` - The tensor to take absolute value of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with absolute values. fn q_abs(tensor: QuantizedTensor) -> QuantizedTensor { dequant_op_quant!( ty Self, float_op |tensor| B::float_abs(tensor), tensor ) } /// Returns a new tensor with cosine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the cosine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with cosine values. fn q_cos(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cos(tensor), tensor ) } /// Returns a new tensor with sine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the sine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with sine values. fn q_sin(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sin(tensor), tensor ) } /// Returns a new tensor with tangent values. /// /// # Arguments /// /// * `tensor` - The tensor to take the tangent of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with tangent values. fn q_tan(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_tan(tensor), tensor ) } /// Returns a new tensor with hyperbolic cosine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic cosine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic cosine values. fn q_cosh(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_cosh(tensor), tensor ) } /// Returns a new tensor with hyperbolic sine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic sine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic sine values. fn q_sinh(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_sinh(tensor), tensor ) } /// Returns a new tensor with hyperbolic tangent values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic tangent of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic tangent values. fn q_tanh(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_tanh(tensor), tensor ) } /// Returns a new tensor with the error function values. /// /// # Arguments /// /// * `tensor` - The tensor to take the error function of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with error function values. fn q_erf(tensor: QuantizedTensor) -> TensorPrimitive { dequant_op_flow!( ty Self, float_op |tensor| B::float_erf(tensor), tensor ) } /// Concatenates tensors along a dimension. /// /// # Arguments /// /// * `tensors` - The tensors to concatenate. /// * `dim` - The dimension along which to concatenate. /// /// # Returns /// /// A tensor with the concatenated tensors along `dim`. fn q_cat(tensors: Vec>, dim: usize) -> QuantizedTensor { // Heuristic: prioritize first tensor scheme let scheme = *tensors.first().unwrap().scheme(); let tensor_f = tensors .into_iter() .map(|tensor| Self::dequantize(tensor)) .collect(); let out_f = B::float_cat(tensor_f, dim); Self::quantize_dynamic(out_f, &scheme) } /// Gets the indices of the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the indices of the maximum elements of `tensor` along `dim`. fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor { // Default implementation. Backends can sort on the int values since qparams remain the same. let tensor_f = Self::dequantize(tensor); B::float_argmax(tensor_f, dim) } /// Gets the indices of the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tensor with the indices of the minimum elements of `tensor` along `dim`. fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor { // Default implementation. Backends can sort on the int values since qparams remain the same. let tensor_f = Self::dequantize(tensor); B::float_argmin(tensor_f, dim) } /// Gets the maximum element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// /// # Returns /// /// A tensor with the maximum element of `tensor`. fn q_max(tensor: QuantizedTensor) -> QuantizedTensor { let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_max_dim(tensor, 0) } /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the maximum elements of `tensor` along `dim`. fn q_max_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { let index = B::q_argmax(tensor.clone(), dim); B::q_gather(dim, tensor, index) } /// Gets the maximum elements of a tensor along an axis and their indices. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tuple with the maximum elements of `tensor` along `dim` and their indices. fn q_max_dim_with_indices( tensor: QuantizedTensor, dim: usize, ) -> (QuantizedTensor, IntTensor) { let index = B::q_argmax(tensor.clone(), dim); let values = B::q_gather(dim, tensor, index.clone()); (values, index) } /// Gets the minimum element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// /// # Returns /// /// A tensor with the minimum element of `tensor`. fn q_min(tensor: QuantizedTensor) -> QuantizedTensor { let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_min_dim(tensor, 0) } /// Gets the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tensor with the minimum elements of `tensor` along `dim`. fn q_min_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { let index = B::q_argmin(tensor.clone(), dim); B::q_gather(dim, tensor, index) } /// Gets the minimum elements of a tensor along an axis and their indices. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tuple with the minimum elements of `tensor` along `dim` and their indices. fn q_min_dim_with_indices( tensor: QuantizedTensor, dim: usize, ) -> (QuantizedTensor, IntTensor) { let index = B::q_argmin(tensor.clone(), dim); let values = B::q_gather(dim, tensor, index.clone()); (values, index) } /// Gets the maximum element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// /// # Returns /// /// A tensor with the maximum element of `tensor`. fn q_max_abs(tensor: QuantizedTensor) -> QuantizedTensor { let shape = tensor.shape(); let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); B::q_max_abs_dim(tensor, 0) } /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the maximum elements of `tensor` along `dim`. fn q_max_abs_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { let index = B::q_argmax(B::q_abs(tensor.clone()), dim); B::q_gather(dim, tensor, index) } /// Tests if any element in the `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. fn q_any(tensor: QuantizedTensor) -> BoolTensor { let tensor_f = Self::dequantize(tensor); B::float_any(tensor_f) } /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the /// input evaluates to True, False otherwise. fn q_any_dim(tensor: QuantizedTensor, dim: usize) -> BoolTensor { let tensor_f = Self::dequantize(tensor); B::float_any_dim(tensor_f, dim) } /// Tests if all elements in the `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn q_all(tensor: QuantizedTensor) -> BoolTensor { let tensor_f = Self::dequantize(tensor); B::float_all(tensor_f) } /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. fn q_all_dim(tensor: QuantizedTensor, dim: usize) -> BoolTensor { let tensor_f = Self::dequantize(tensor); B::float_all_dim(tensor_f, dim) } /// Sort the elements of the input `tensor` by value in along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. fn q_sort(tensor: QuantizedTensor, dim: usize, descending: bool) -> QuantizedTensor { // Default implementation. Backends can sort on the int values since qparams remain the same. dequant_op_quant!( ty Self, float_op |tensor| B::float_sort(tensor, dim, descending), tensor ) } /// Sort the elements of the input `tensor` by value in along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. fn q_sort_with_indices( tensor: QuantizedTensor, dim: usize, descending: bool, ) -> (QuantizedTensor, IntTensor) { // Default implementation. Backends can sort on the int values since qparams remain the same. let scheme = *tensor.scheme(); let tensor_f = Self::dequantize(tensor); let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending); (Self::quantize_dynamic(out_f, &scheme), indices) } /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. fn q_argsort(tensor: QuantizedTensor, dim: usize, descending: bool) -> IntTensor { // Default implementation. Backends can sort on the int values since qparams remain the same. let tensor_f = Self::dequantize(tensor); B::float_argsort(tensor_f, dim, descending) } } ================================================ FILE: crates/burn-backend/src/backend/ops/repeat_dim.rs ================================================ use crate::{ Backend, TensorMetadata, tensor::{BasicOps, TensorKind}, }; use alloc::vec::Vec; use burn_std::Slice; pub(crate) fn repeat_with_slice_assign + BasicOps>( tensor: K::Primitive, dim: usize, times: usize, ) -> K::Primitive { let shape = tensor.shape(); let device = K::device(&tensor); let dtype = tensor.dtype(); let original_dim_length = shape[dim]; let shape = shape.repeat(dim, times).unwrap(); let mut tensor_output = K::empty(shape.clone(), &device, dtype); let indices_select_all = shape.iter().map(|d| 0..*d).collect::>(); let mut output_index = 0; for _ in 0..times { let mut indices = indices_select_all.clone(); indices[dim] = output_index..output_index + original_dim_length; output_index += original_dim_length; // Convert ranges to Slice let slices: Vec = indices .iter() .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1)) .collect(); tensor_output = K::slice_assign(tensor_output, &slices, tensor.clone()); } tensor_output } ================================================ FILE: crates/burn-backend/src/backend/ops/sort.rs ================================================ use core::cmp::Ordering; use crate::{ Backend, DType, TensorData, element::{ElementConversion, ElementOrdered}, tensor::{BasicOps, IntElem, IntTensor}, }; use alloc::{vec, vec::Vec}; use burn_std::reader::try_read_sync; use burn_std::{bf16, f16}; /// Macro used to dispatch sort operations based on dtype. macro_rules! sort_dispatch_dtype { ($fn:ident, $data:ident, $($args:expr),*) => { match $data.dtype { DType::F64 => $fn::($data, $($args),*), DType::F32 | DType::Flex32 => $fn::($data, $($args),*), DType::F16 => $fn::($data, $($args),*), DType::BF16 => $fn::($data, $($args),*), DType::I64 => $fn::($data, $($args),*), DType::I32 => $fn::($data, $($args),*), DType::I16 => $fn::($data, $($args),*), DType::I8 => $fn::($data, $($args),*), DType::U64 => $fn::($data, $($args),*), DType::U32 => $fn::($data, $($args),*), DType::U16 => $fn::($data, $($args),*), DType::U8 => $fn::($data, $($args),*), DType::Bool(_) | DType::QFloat(_) => unimplemented!("not supported for sorting operations"), } }; } /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. /// /// # Remarks /// /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. pub fn sort>( tensor: K::Primitive, dim: usize, descending: bool, ) -> K::Primitive { let device = K::device(&tensor); let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; let data = try_read_sync(K::into_data_async(tensor)) .expect(msg) .expect(msg); let data = sort_dispatch_dtype!(sort_data, data, dim, descending); K::from_data(data, &device) } pub fn sort_data( mut data: TensorData, dim: usize, descending: bool, ) -> TensorData { let dims = data.shape.clone(); let data_slice = data.as_mut_slice().unwrap(); if dims.len() == 1 { // 1D sort data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending)); } else { sort_slice::(data_slice, &dims, dim, None, false, descending); } data } /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. /// /// # Remarks /// /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. pub fn sort_with_indices>( tensor: K::Primitive, dim: usize, descending: bool, ) -> (K::Primitive, IntTensor) { let device = K::device(&tensor); let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; let data = try_read_sync(K::into_data_async(tensor)) .expect(msg) .expect(msg); let (values, indices) = sort_dispatch_dtype!(sort_data_with_indices, data, dim, descending); ( K::from_data(values, &device), B::int_from_data(indices, &device), ) } fn sort_data_with_indices( mut data: TensorData, dim: usize, descending: bool, ) -> (TensorData, TensorData) { let dims = data.shape.clone(); let mut indices_data = dim_indices::(&dims, dim); let data_slice = data.as_mut_slice().unwrap(); if dims.len() == 1 { // 1D sort indices_data.sort_unstable_by(|&a, &b| { compare( &data_slice[a.elem::() as usize], &data_slice[b.elem::() as usize], descending, ) }); // Permute data in-place by the sorted indices let mut indices = indices_data .clone() .iter() .map(|i| i.elem::() as usize) .collect::>(); for idx in 0..indices.len() { if indices[idx] != idx { let mut current_idx = idx; loop { let target_idx = indices[current_idx]; indices[current_idx] = current_idx; if indices[target_idx] == target_idx { // correct position break; } // Permute data by indices data_slice.swap(current_idx, target_idx); current_idx = target_idx; } } } } else { sort_slice::( data_slice, &dims, dim, Some(&mut indices_data), true, descending, ); } (data, TensorData::new(indices_data, dims)) } /// Returns the indices that sort the elements of the input `tensor` along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. /// /// # Remarks /// /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. pub fn argsort>( tensor: K::Primitive, dim: usize, descending: bool, ) -> IntTensor { let device = K::device(&tensor); let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; let data = try_read_sync(K::into_data_async(tensor)) .expect(msg) .expect(msg); let data = sort_dispatch_dtype!(argsort_data, data, dim, descending); B::int_from_data(data, &device) } fn argsort_data( mut data: TensorData, dim: usize, descending: bool, ) -> TensorData { let dims = data.shape.clone(); let mut indices_data = dim_indices::(&dims, dim); if dims.len() == 1 { // 1D sort let slice = data.as_slice::().unwrap(); indices_data.sort_unstable_by(|&a, &b| { compare( &slice[a.elem::() as usize], &slice[b.elem::() as usize], descending, ) }); } else { sort_slice::( data.as_mut_slice().unwrap(), &dims, dim, Some(&mut indices_data), false, descending, ); } TensorData::new(indices_data, dims) } /// Sort the elements by value along a given dimension. /// /// When `indices` are not provided, the `data` is sorted. /// Otherwise, the `indices` are sorted based on the value of the elements in `data`, /// and if `permute_both` is enabled then the data is also sorted. /// /// This sort is unstable (i.e., may reorder equal elements). fn sort_slice( data: &mut [E], dims: &[usize], dim: usize, mut indices: Option<&mut [IntElem]>, permute_both: bool, descending: bool, ) { let ndims = dims.len(); let strides = compute_strides(dims); // Dimensions to access elements to sort let mut sort_dims = dims.to_vec(); sort_dims[dim] = 1; let strides_out = compute_strides(&sort_dims); // Number of groups to sort let num_sorts: usize = dims .iter() .enumerate() .filter(|&(i, _)| i != dim) .map(|(_, d)| d) .product(); // TODO: run each sort in parallel // run_par!(|| { // iter_range_par!(0, num_sorts).for_each(|id| {...}) for id in 0..num_sorts { let mut index_offset = 0; let mut stride_dim = 0; let mut shape_dim = 0; for d in 0..ndims { let stride_input = strides[d]; let stride_output = strides_out[d]; let shape_output = sort_dims[d]; let num_block = id / stride_output % shape_output; if d != dim { index_offset += num_block * stride_input; } else { let shape_input = dims[d]; stride_dim = stride_input; shape_dim = shape_input; index_offset += num_block; } } // For each group, sort the indices based on the element values // NOTE: Sorting methods like `sort_unstable_by` are in-place but we need to sort // different views/groups of the underlying data, so the swap is performed on the elements // of the (flat index, element value) collection. let mut elements = (0..shape_dim) .map(|d| { let flat_index = d * stride_dim + index_offset; let elem = data[flat_index]; (d, flat_index, elem) }) .collect::>(); elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending)); // Permute data in-place by the sorted indices for idx in 0..elements.len() { if elements[idx].0 != idx { let mut current_idx = idx; loop { let target_idx = elements[current_idx].0; elements[current_idx].0 = current_idx; if elements[target_idx].0 == target_idx { // correct position break; } if indices.is_none() || permute_both { // Permute data by indices data.swap(elements[current_idx].1, elements[target_idx].1); } if let Some(ref mut indices_data) = indices { // Permute data element indices indices_data.swap(elements[current_idx].1, elements[target_idx].1); } current_idx = target_idx; } } } } } /// Computes the steps for each dimension when traversing an array. fn compute_strides(dims: &[usize]) -> Vec { let mut strides = vec![0; dims.len()]; let mut current = 1; dims.iter().enumerate().rev().for_each(|(index, val)| { strides[index] = current; current *= val; }); strides } /// Generates the indices for each element along the specified dimension. fn dim_indices(dims: &[usize], dim: usize) -> Vec> { if dims.len() == 1 { (0..dims[dim]) .map(|i| (i as i64).elem::>()) .collect::>() } else { // Dimension indices tensor let numel_leading_dims: usize = dims[..dim].iter().product(); let numel_trailing_dims: usize = dims[dim + 1..].iter().product(); (0..dims[dim]) .map(|i| [(i as i64).elem::>()].repeat(numel_trailing_dims)) .collect::>() .concat() .repeat(numel_leading_dims) } } /// Compare two elements fn compare(a: &E, b: &E, descending: bool) -> Ordering { if descending { b.cmp(a) } else { a.cmp(b) } } ================================================ FILE: crates/burn-backend/src/backend/ops/tensor.rs ================================================ use super::cat::cat_with_slice_assign; use super::grid_sample::float_grid_sample_2d_ref; use super::repeat_dim::repeat_with_slice_assign; use super::sort::{argsort, sort, sort_with_indices}; use crate::ops::GridSampleOptions; use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor}; use crate::{Backend, Distribution, TensorData}; use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive}; use alloc::vec::Vec; use burn_std::{FloatDType, Shape, Slice}; /// Operations on float tensors. pub trait FloatTensorOps { /// Creates a new tensor from the data structure. /// /// # Arguments /// /// * `data` - The data structure. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given data. fn float_from_data(data: TensorData, device: &Device) -> FloatTensor; /// Creates a new tensor with random values. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `distribution` - The distribution to sample from. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// The tensor with the given shape and random values. fn float_random(shape: Shape, distribution: Distribution, device: &Device) -> FloatTensor; /// Creates a new tensor with zeros. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor with the given shape and zeros. fn float_zeros(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device) } /// Creates a new tensor with ones. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor with the given shape and ones. fn float_ones(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device) } /// Creates a tensor filled with given value. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `fill_value` - The value with which to fill the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor filled with given value fn float_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: FloatDType, ) -> FloatTensor { Self::float_from_data( TensorData::full_dtype(shape, fill_value, dtype.into()), device, ) } /// Converts the tensor to a data structure. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The data structure with the tensor's data. fn float_into_data( tensor: FloatTensor, ) -> impl Future> + Send; /// Gets the device of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The device of the tensor. fn float_device(tensor: &FloatTensor) -> Device; /// Moves the tensor to the given device. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `device` - The device to move the tensor to. /// /// # Returns /// /// The tensor on the given device. fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor; /// Converts float tensor to int tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The int tensor with the same data as the float tensor. fn float_into_int(tensor: FloatTensor) -> IntTensor; /// Creates an empty tensor with the given shape. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// * `dtype` - The target data type. /// /// # Returns /// /// The empty tensor with the given shape. fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor; /// Repeat the tensor along the given dimension. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension to repeat. /// * `times` - The number of times to repeat the dimension. /// /// # Returns /// /// The tensor with the given dimension repeated. fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { repeat_with_slice_assign::(TensorPrimitive::Float(tensor), dim, times).tensor() } /// Adds two tensors together. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of adding the two tensors together. fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Adds a scalar to a tensor. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of adding the scalar to the tensor. fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor; /// Clamps a tensor under a minimum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// /// # Returns /// /// The clamped tensor. fn float_clamp_min(tensor: FloatTensor, min: Scalar) -> FloatTensor { // Default implementation let mask = Self::float_lower_elem(tensor.clone(), min); B::float_mask_fill(tensor, mask, min) } /// Clamps a tensor over a maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn float_clamp_max(tensor: FloatTensor, max: Scalar) -> FloatTensor { // Default implementation let mask = Self::float_greater_elem(tensor.clone(), max); B::float_mask_fill(tensor, mask, max) } /// Clamps a tensor between a minimum and maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// * `max` - The maximum value. /// /// # Returns /// /// The clamped tensor. fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { // Default implementation Self::float_clamp_min(Self::float_clamp_max(tensor, max), min) } /// Subtracts two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of subtracting the two tensors. fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Subtracts a scalar from a tensor. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of subtracting the scalar from the tensor. fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor; /// Multiplies two tensors together element-wise. fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Multiplies a tensor by a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of multiplying the tensor by the scalar. fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor; /// Divides two tensors element-wise. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of dividing the two tensors. fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Divides a tensor by a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of dividing the tensor by the scalar. fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor; /// Computes the remainder of division between two tensors element-wise. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The element-wise remainder when dividing `lhs` by `rhs`. fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Computes the modulus of a tensor given a scalar. /// /// # Arguments /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The result of applying the modulus of the scalar to the tensor. fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor; /// Multiplies two tensors together using matrix multiplication. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The result of multiplying the two tensors together using matrix multiplication. fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Computes the cross product of two tensors along a given dimension. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// * `dim` - The dimension to compute the cross product along. /// /// # Returns /// /// The cross product of the two tensors. fn float_cross(lhs: FloatTensor, rhs: FloatTensor, dim: usize) -> FloatTensor; /// Negates a tensor element-wise. fn float_neg(tensor: FloatTensor) -> FloatTensor { Self::float_mul_scalar(tensor, (-1f32).into()) } /// Calculates the reciprocals element-wise fn float_recip(tensor: FloatTensor) -> FloatTensor; /// Transposes a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. fn float_transpose(tensor: FloatTensor) -> FloatTensor { let ndims = tensor.shape().num_dims(); Self::float_swap_dims(tensor, ndims - 2, ndims - 1) } /// Swaps two dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap. /// * `dim2` - The second dimension to swap. /// /// # Returns /// /// The tensor with the dimensions swapped. fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor; /// Permutes the dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to permute the dimensions of. /// * `axes` - The new order of the dimensions. /// # Returns /// /// The tensor with the dimensions permuted. fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor; /// Reverse the order of elements in a tensor along the given axes. /// /// # Arguments /// /// * `tensor` - The tensor to reverse. /// * `axes` - The axes to reverse. /// /// The tensor with the elements reversed. fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor; /// Reshapes a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to reshape. /// * `shape` - The new shape of the tensor. /// /// # Returns /// /// The tensor with the new shape. fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor; /// Gather elements from a tensor. /// /// # Arguments /// /// * `dim` - The dimension to gather from. /// * `tensor` - The tensor to gather from. /// * `indices` - The indices to gather. /// /// # Returns /// /// The gathered elements. fn float_gather(dim: usize, tensor: FloatTensor, indices: IntTensor) -> FloatTensor; /// Scatter elements into a tensor using sum reduction. /// /// # Arguments /// /// * `dim` - The dimension to scatter into. /// * `tensor` - The tensor to scatter into. /// * `indices` - The indices to scatter into. /// * `value` - The value to scatter. /// /// # Returns /// /// The tensor with the scattered elements. fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor; /// Select tensor elements along the given dimension corresponding for the given indices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `dim` - The dimension to select from. /// * `indices` - The indices to select. /// /// # Returns /// /// The selected elements. fn float_select(tensor: FloatTensor, dim: usize, indices: IntTensor) -> FloatTensor; /// Assign the selected elements along the given dimension corresponding for the given indices /// to the given value using sum reduction. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `dim` - The dimension to select from. /// * `indices` - The indices to select. /// * `value` - The value to assign. /// /// # Returns /// /// The tensor with the selected elements assigned to the given value. fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor; /// Select tensor elements corresponding to the given slices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `slices` - The slices specifying ranges and steps for each dimension. /// /// # Returns /// /// The selected elements in a new tensor. /// /// # Note /// /// Empty slices (where start >= end) are handled at the high-level tensor API and will not /// be passed to this method. Backend implementations do not need to handle empty slices. fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor; /// Assign the selected elements corresponding to the given slices to the given value. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `ranges` - The ranges to select. /// * `value` - The value to assign. /// /// # Returns /// /// The tensor with the selected elements assigned to the given value. /// /// # Note /// /// Empty slice assignments (where any slice range produces 0 elements) are handled at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty slice assignments. fn float_slice_assign( tensor: FloatTensor, slices: &[Slice], value: FloatTensor, ) -> FloatTensor; /// Update the given tensor with the value tensor where the mask is true. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `mask` - The boolean mask to select with. /// * `value` - The value to assign to the selected elements from the value tensor. /// /// # Returns /// /// The tensor with the selected elements assigned to the given value. fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor; /// Update the given tensor with the value where the mask is true. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `mask` - The boolean mask to select with. /// * `value` - The value to assign to the selected elements. /// /// # Returns /// /// The tensor with the selected elements assigned to the given value. fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor; /// Equal comparison of two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; /// Element-wise non-equality comparison. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_not_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let equal_tensor = B::float_equal(lhs, rhs); B::bool_not(equal_tensor) } /// Equal comparison of a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor; /// Element-wise non-equality comparison with a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_not_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let equal_tensor = B::float_equal_elem(lhs, rhs); B::bool_not(equal_tensor) } /// Greater than comparison of two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; /// Greater than comparison of a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor; /// Greater than or equal comparison of two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; /// Greater than or equal comparison of a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor; /// Less than comparison of two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; /// Less than comparison of a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor; /// Less than or equal comparison of two tensors. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor; /// Less than or equal comparison of a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// A boolean tensor with the result of the comparison. fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor; /// Detaches a tensor from the computation graph. fn float_detach(tensor: FloatTensor) -> FloatTensor { // Should only be overridden by autodiff backends. tensor } /// Sets the `require_grad` flag of a tensor. fn float_set_require_grad(tensor: FloatTensor, _require_grad: bool) -> FloatTensor { // Should only be overridden by autodiff backends. tensor } /// Returns the `require_grad` flag of a tensor. fn float_is_require_grad(_tensor: &FloatTensor) -> bool { // Should only be overridden by autodiff backends. false } /// Sum of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// /// # Returns /// /// A scalar tensor with the sum of all elements in `tensor`. fn float_sum(tensor: FloatTensor) -> FloatTensor; /// Sum of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// * `dim` - The dimension along which to sum. /// /// # Returns /// /// A tensor with the sum of all elements in `tensor` along `dim`. fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Product of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to product. /// /// # Returns /// /// A scalar tensor with the product of all elements in `tensor`. fn float_prod(tensor: FloatTensor) -> FloatTensor { // Product of all elements in a tensor B::float_exp(B::float_sum(B::float_log(tensor))) } /// Product of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to product. /// /// # Returns /// /// A tensor with the product of all elements in `tensor` along `dim`. fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { // Product of all elements in a tensor along a dimension B::float_exp(B::float_sum_dim(B::float_log(tensor), dim)) } /// Mean of all elements in a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to mean. /// /// # Returns /// /// A scalar tensor with the mean of all elements in `tensor`. fn float_mean(tensor: FloatTensor) -> FloatTensor { let num_elems = tensor.shape().num_elements() as f32; B::float_div_scalar(B::float_sum(tensor), num_elems.into()) } /// Mean of all elements in a tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to mean. /// * `dim` - The dimension along which to mean. /// /// # Returns /// /// A tensor with the mean of all elements in `tensor` along `dim`. fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Computes the cumulative sum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative sum of. /// * `dim` - The dimension along which to compute the cumulative sum. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative sum /// of all elements up to and including that position along the dimension. fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Computes the cumulative product of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative product of. /// * `dim` - The dimension along which to compute the cumulative product. /// /// # Returns /// /// A tensor with the same shape where each element is the cumulative product /// of all elements up to and including that position along the dimension. fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Computes the cumulative minimum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative minimum of. /// * `dim` - The dimension along which to compute the cumulative minimum. /// /// # Returns /// /// A tensor with the same shape where each element is the minimum /// of all elements up to and including that position along the dimension. fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Computes the cumulative maximum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative maximum of. /// * `dim` - The dimension along which to compute the cumulative maximum. /// /// # Returns /// /// A tensor with the same shape where each element is the maximum /// of all elements up to and including that position along the dimension. fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor; /// Converts a tensor to another floating point data type. /// /// # Arguments /// /// * `tensor` - The tensor to convert. /// * `dtype` - The target data type. /// /// # Returns /// /// A tensor with the same values as `tensor` but in the target floating point data type. fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor; /// Returns a new tensor with exponential values. /// /// # Arguments /// /// * `tensor` - The tensor to exponentiate. /// /// # Returns /// /// A tensor with the same shape as `tensor` with exponential values. fn float_exp(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with natural logarithm values. /// /// # Arguments /// /// * `tensor` - The tensor to take the logarithm of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with natural logarithm values. fn float_log(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with logarithm values of (1 + Xi). /// /// # Arguments /// /// * `tensor` - The tensor to take the logarithm of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). fn float_log1p(tensor: FloatTensor) -> FloatTensor; /// Element-wise power with a FloatTensor. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side tensor. /// /// # Returns /// /// The elements of `lhs` raised to the power of the elements of `rhs`. fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Element-wise power with an IntTensor. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side floatTensor. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. fn float_powi(lhs: FloatTensor, rhs: IntTensor) -> FloatTensor { Self::float_powf(lhs, B::int_into_float(rhs)) } /// Raises a tensor to the power of an int scalar. /// /// # Backend Implementors Note /// /// A number of common exponent cases can be implemented with operations /// which are much cheaper than generic exponentiation. /// /// This (`Backend` impl overridable) operation handles generic optimizations /// for several common integer exponent cases; and then dispatches to /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`] /// operation to handle the generic case. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. fn float_powi_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { match rhs.elem::() { 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()), 1 => lhs, 2 => B::float_mul(lhs.clone(), lhs), -1 => Self::float_recip(lhs), -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)), _ => Self::float_powi_scalar_impl(lhs, rhs), } } /// Raises a tensor to the power of an int scalar. /// /// # Backend Implementors Note /// /// This is the generic implementation of integer exponentiation /// called by [`Self::float_powi_scalar`] in the fallback case. /// /// As a general rule, this should not be called directly. /// /// # Arguments /// /// * `lhs` - The left-hand side tensor. /// * `rhs` - The right-hand side scalar. /// /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. fn float_powi_scalar_impl(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { // Avoid a recursive loop by deferring directly to float_powf_scalar_impl. Self::float_powf_scalar_impl(lhs, rhs) } /// Returns a new tensor with values raised to the power of float `value`. /// /// # Backend Implementors Note /// /// This (`Backend` impl overridable) operation dispatches integer exponentiation /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`] /// operation to handle the generic case. /// /// # Arguments /// /// * `tensor` - The tensor to exponentiate. /// * `value` - The exponent. /// /// # Returns /// /// A tensor with the same shape as `tensor` with values raised to the power of `value`. fn float_powf_scalar(tensor: FloatTensor, value: Scalar) -> FloatTensor { if let Some(exp) = value.try_as_integer() { Self::float_powi_scalar(tensor, exp) } else { Self::float_powf_scalar_impl(tensor, value) } } /// Returns a new tensor with values raised to the power of float `value`. /// /// # Backend Implementors Note /// /// This is the generic implementation of integer exponentiation /// called by [`Self::float_powf_scalar`] in the fallback case. /// /// This is the minimal required support a `Backend` must implement /// for exponentiation. /// /// As a general rule, this should not be called directly. /// /// # Arguments /// /// * `tensor` - The tensor to exponentiate. /// * `value` - The exponent. /// /// # Returns /// /// A tensor with the same shape as `tensor` with values raised to the power of `value`. fn float_powf_scalar_impl(tensor: FloatTensor, value: Scalar) -> FloatTensor; /// Returns a new tensor with square root values. /// /// # Arguments /// /// * `tensor` - The tensor to take the square root of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with square root values. fn float_sqrt(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with absolute values. /// /// # Arguments /// /// * `tensor` - The tensor to take absolute value of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with absolute values. fn float_abs(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with cosine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the cosine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with cosine values. fn float_cos(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with sine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the sine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with sine values. fn float_sin(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with tangent values. /// /// # Arguments /// /// * `tensor` - The tensor to take the tangent of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with tangent values. fn float_tan(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with hyperbolic cosine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic cosine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic cosine values. fn float_cosh(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with hyperbolic sine values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic sine of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic sine values. fn float_sinh(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with hyperbolic tangent values. /// /// # Arguments /// /// * `tensor` - The tensor to take the hyperbolic tangent of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic tangent values. fn float_tanh(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with inverse cosine values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with inverse cosine values. fn float_acos(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with inverse hyperbolic cosine values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values. fn float_acosh(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with inverse sine values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with inverse sine values. fn float_asin(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with inverse hyperbolic sine values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with inverse hyperbolic sine values. fn float_asinh(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with the inverse tangent values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with the inverse tangent values. fn float_atan(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with the inverse hyperbolic tangent values. /// /// # Arguments /// /// * `tensor` - The input tensor. /// /// # Returns /// /// A tensor with the same shape as `tensor` with the inverse hyperbolic tangent values. fn float_atanh(tensor: FloatTensor) -> FloatTensor; /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`. /// /// # Arguments /// /// * `lhs` - The tensor with y coordinates. /// * `rhs` - The tensor with x coordinates. /// /// # Returns /// /// A tensor with the four-quadrant inverse tangent values. fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor; /// Returns a new tensor with rounded values. /// /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even) /// strategy, with halfway cases rounded to the nearest even integer value. /// /// # Arguments /// /// * `tensor` - The tensor to be rounded. /// /// # Returns /// /// A tensor with the same shape as `tensor` with rounded values. fn float_round(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with floored values. /// /// # Arguments /// /// * `tensor` - The tensor to be floored. /// /// # Returns /// /// A tensor with the same shape as `tensor` with floored values. fn float_floor(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with ceiled values. /// /// # Arguments /// /// * `tensor` - The tensor to be ceiled. /// /// # Returns /// /// A tensor with the same shape as `tensor` with ceiled values. fn float_ceil(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with truncated values. /// /// # Arguments /// /// * `tensor` - The tensor to be truncated. /// /// # Returns /// /// A tensor with the same shape as `tensor` with truncated values. fn float_trunc(tensor: FloatTensor) -> FloatTensor; /// Returns a new tensor with the error function values. /// /// # Arguments /// /// * `tensor` - The tensor to take the error function of. /// /// # Returns /// /// A tensor with the same shape as `tensor` with error function values. fn float_erf(tensor: FloatTensor) -> FloatTensor; /// Concatenates tensors along a dimension. /// /// # Arguments /// /// * `tensors` - The tensors to concatenate. /// * `dim` - The dimension along which to concatenate. /// /// # Returns /// /// A tensor with the concatenated tensors along `dim`. /// /// # Note /// /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the /// high-level tensor API and will not be passed to this method. Backend implementations do /// not need to handle empty tensors. fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { cat_with_slice_assign::( tensors.into_iter().map(TensorPrimitive::Float).collect(), dim, ) .tensor() } /// Gets the indices of the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the indices of the maximum elements of `tensor` along `dim`. fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor; /// Gets the indices of the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tensor with the indices of the minimum elements of `tensor` along `dim`. fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor; /// Gets the maximum element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// /// # Returns /// /// A tensor with the maximum element of `tensor`. fn float_max(tensor: FloatTensor) -> FloatTensor { let shape = tensor.shape(); let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); B::float_max_dim(tensor, 0) } /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the maximum elements of `tensor` along `dim`. fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { let index = B::float_argmax(tensor.clone(), dim); B::float_gather(dim, tensor, index) } /// Gets the maximum elements of a tensor along an axis and their indices. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tuple with the maximum elements of `tensor` along `dim` and their indices. fn float_max_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { let index = B::float_argmax(tensor.clone(), dim); let values = B::float_gather(dim, tensor, index.clone()); (values, index) } /// Gets the minimum element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// /// # Returns /// /// A tensor with the minimum element of `tensor`. fn float_min(tensor: FloatTensor) -> FloatTensor { let shape = tensor.shape(); let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); B::float_min_dim(tensor, 0) } /// Gets the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tensor with the minimum elements of `tensor` along `dim`. fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { let index = B::float_argmin(tensor.clone(), dim); B::float_gather(dim, tensor, index) } /// Gets the minimum elements of a tensor along an axis and their indices. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements of. /// * `dim` - The dimension along which to get the minimum elements. /// /// # Returns /// /// A tuple with the minimum elements of `tensor` along `dim` and their indices. fn float_min_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { let index = B::float_argmin(tensor.clone(), dim); let values = B::float_gather(dim, tensor, index.clone()); (values, index) } /// Gets the maximum absolute element of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// /// # Returns /// /// A tensor with the maximum element of `tensor`. fn float_max_abs(tensor: FloatTensor) -> FloatTensor { let shape = tensor.shape(); let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); B::float_max_abs_dim(tensor, 0) } /// Gets the maximum absolute elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements of. /// * `dim` - The dimension along which to get the maximum elements. /// /// # Returns /// /// A tensor with the maximum elements of `tensor` along `dim`. fn float_max_abs_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { B::float_max_dim(B::float_abs(tensor), dim) } /// Tests if any element in the float `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. fn float_any(tensor: FloatTensor) -> BoolTensor { let bool_tensor = B::float_equal_elem(tensor, 0f32.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::float_sum(B::bool_into_float(bool_tensor)); B::float_greater_elem(sum, 0f32.into()) } /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the /// input evaluates to True, False otherwise. fn float_any_dim(tensor: FloatTensor, dim: usize) -> BoolTensor { let bool_tensor = B::float_equal_elem(tensor, 0f32.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim); B::float_greater_elem(sum, 0f32.into()) } /// Tests if all elements in the float `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. fn float_all(tensor: FloatTensor) -> BoolTensor { let num_elems = tensor.shape().num_elements() as f32; let bool_tensor = B::float_equal_elem(tensor, 0f32.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::float_sum(B::bool_into_float(bool_tensor)); B::float_equal_elem(sum, num_elems.into()) } /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. fn float_all_dim(tensor: FloatTensor, dim: usize) -> BoolTensor { let num_elems = tensor.shape()[dim] as f32; let bool_tensor = B::float_equal_elem(tensor, 0f32.into()); let bool_tensor = B::bool_not(bool_tensor); let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim); B::float_equal_elem(sum, num_elems.into()) } /// Returns the signs of the float `tensor`. /// /// # Arguments /// /// * `tensor` - The tensor to extract the signs from. /// /// # Returns /// /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`. fn float_sign(tensor: FloatTensor) -> FloatTensor { let zeros = B::float_zeros( tensor.shape(), &B::float_device(&tensor), tensor.dtype().into(), ); let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into()); let greater_than_zero = B::float_greater_elem(tensor, 0f32.into()); let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into()); result = B::float_mask_fill(result, greater_than_zero, 1f32.into()); result } /// Broadcasts the float `tensor` to the given `shape`. fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor; /// Sort the elements of the input `tensor` by value in along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. fn float_sort(tensor: FloatTensor, dim: usize, descending: bool) -> FloatTensor { sort::(TensorPrimitive::Float(tensor), dim, descending).tensor() } /// Sort the elements of the input `tensor` by value in along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. fn float_sort_with_indices( tensor: FloatTensor, dim: usize, descending: bool, ) -> (FloatTensor, IntTensor) { let (values, indices) = sort_with_indices::(TensorPrimitive::Float(tensor), dim, descending); (values.tensor(), indices) } /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. fn float_argsort(tensor: FloatTensor, dim: usize, descending: bool) -> IntTensor { argsort::(TensorPrimitive::Float(tensor), dim, descending) } /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values, /// using the given locations in [-1, 1]. /// /// # Arguments /// /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in) /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right /// * `options` - Grid sampling options (mode, padding_mode, align_corners) /// /// # Returns /// /// A tensor with shape (N, C, H_out, W_out) fn float_grid_sample_2d( tensor: FloatTensor, grid: FloatTensor, options: GridSampleOptions, ) -> FloatTensor { float_grid_sample_2d_ref::(tensor, grid, options) } /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the selected dim. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with shape ``[pre=..., windows, size, post=...]``. fn float_unfold(tensor: FloatTensor, dim: usize, size: usize, step: usize) -> FloatTensor; /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// /// # Returns /// /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value. fn float_is_nan(tensor: FloatTensor) -> BoolTensor { // Check if the input tensor is NaN by comparing it to itself // NaN is the only value that is not equal to itself B::float_not_equal(tensor.clone(), tensor) } /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF). /// /// # Returns /// /// A boolean tensor where `true` indicates that the value is infinite fn float_is_inf(tensor: FloatTensor) -> BoolTensor { B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into()) } } ================================================ FILE: crates/burn-backend/src/backend/ops/transaction.rs ================================================ use alloc::vec::Vec; use core::future::Future; use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; use crate::{Backend, ExecutionError, TensorData, TensorPrimitive}; enum Order { Float(usize), QFloat(usize), Int(usize), Bool(usize), } #[derive(Default)] /// Contains all tensor primitives that are going to be read. pub struct TransactionPrimitive { /// Float tensors. pub read_floats: Vec>, /// Quantized tensors. pub read_qfloats: Vec>, /// Int tensors. pub read_ints: Vec>, /// Bool tensors. pub read_bools: Vec>, orders: Vec, } #[derive(Default)] /// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive). pub struct TransactionPrimitiveData { /// Float tensor data. pub read_floats: Vec, /// Quantized tensor data. pub read_qfloats: Vec, /// Int tensor data. pub read_ints: Vec, /// Bool tensor data. pub read_bools: Vec, } /// Operations that are sync by nature and that can be batch together in transactions to improve /// compute utilization with efficient laziness. pub trait TransactionOps { /// Executes a [transaction](TransactionPrimitive) and return its /// [data](TransactionPrimitiveData). fn tr_execute( transaction: TransactionPrimitive, ) -> impl Future> + Send { async move { let mut floats = Vec::new(); let mut qfloats = Vec::new(); let mut ints = Vec::new(); let mut bools = Vec::new(); for t in transaction.read_floats { floats.push(B::float_into_data(t).await?); } for t in transaction.read_qfloats { qfloats.push(B::q_into_data(t).await?); } for t in transaction.read_ints { ints.push(B::int_into_data(t).await?); } for t in transaction.read_bools { bools.push(B::bool_into_data(t).await?); } Ok(TransactionPrimitiveData { read_floats: floats, read_qfloats: qfloats, read_ints: ints, read_bools: bools, }) } } } impl TransactionPrimitive { /// Creates a new transaction. pub fn new( read_floats: Vec>, read_qfloats: Vec>, read_ints: Vec>, read_bools: Vec>, ) -> Self { Self { read_floats, read_qfloats, read_ints, read_bools, orders: Vec::default(), } } /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order /// in which they were [registered](crate::tensor::BasicOps::register_transaction). pub async fn execute_async(mut self) -> Result, ExecutionError> { let mut orders = Vec::new(); core::mem::swap(&mut orders, &mut self.orders); let result = B::tr_execute(self).await?; let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect(); let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect(); let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect(); let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect(); Ok(orders .into_iter() .map(|order| match order { Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(), Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(), Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(), Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(), }) .collect::>()) } pub(crate) fn register_float(&mut self, tensor: TensorPrimitive) { match tensor { TensorPrimitive::Float(tensor) => { self.orders.push(Order::Float(self.read_floats.len())); self.read_floats.push(tensor); } TensorPrimitive::QFloat(tensor) => { self.orders.push(Order::QFloat(self.read_qfloats.len())); self.read_qfloats.push(tensor); } } } pub(crate) fn register_int(&mut self, tensor: IntTensor) { self.orders.push(Order::Int(self.read_ints.len())); self.read_ints.push(tensor); } pub(crate) fn register_bool(&mut self, tensor: BoolTensor) { self.orders.push(Order::Bool(self.read_bools.len())); self.read_bools.push(tensor); } } ================================================ FILE: crates/burn-backend/src/backend/primitive.rs ================================================ use crate::Backend; use burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme}; use burn_std::{DType, Shape}; #[derive(Debug, Clone)] /// A primitive tensor representation. pub enum TensorPrimitive { /// Float tensor primitive. Float(B::FloatTensorPrimitive), /// Quantized float tensor primitive. QFloat(B::QuantizedTensorPrimitive), } impl TensorPrimitive { /// Returns the full tensor representation. pub fn tensor(self) -> B::FloatTensorPrimitive { match self { Self::QFloat(tensor) => B::dequantize(tensor), Self::Float(tensor) => tensor, } } } impl TensorMetadata for TensorPrimitive { fn dtype(&self) -> DType { match self { TensorPrimitive::Float(tensor) => tensor.dtype(), TensorPrimitive::QFloat(tensor) => tensor.dtype(), } } fn shape(&self) -> Shape { match self { TensorPrimitive::Float(tensor) => tensor.shape(), TensorPrimitive::QFloat(tensor) => tensor.shape(), } } fn rank(&self) -> usize { match self { TensorPrimitive::Float(tensor) => tensor.rank(), TensorPrimitive::QFloat(tensor) => tensor.rank(), } } } /// Tensor metadata trait for tensor primitive. pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug { /// The dtype of the tensor. fn dtype(&self) -> DType; /// The shape of the tensor. fn shape(&self) -> Shape; /// The number of dimensions of the tensor. fn rank(&self) -> usize { self.shape().num_dims() } } /// Quantized tensor primitive. pub trait QTensorPrimitive { /// Returns the quantization settings for the given tensor. fn scheme(&self) -> &QuantScheme; /// The precision used for the accumulation in various kernels. fn acc_precision(&self) -> QuantAcc { QuantAcc::F32 } /// How quantization is propagated during computation. fn propagation(&self) -> QuantPropagation { QuantPropagation::Inhibit } /// Returns the default tensor quantization scheme. fn default_scheme() -> QuantScheme { QuantScheme::default() } } ================================================ FILE: crates/burn-backend/src/data/compare.rs ================================================ use alloc::format; use alloc::string::String; use burn_std::{BoolStore, DType, bf16, f16}; use num_traits::{Float, ToPrimitive}; use super::TensorData; use crate::{Element, ElementOrdered}; /// The tolerance used to compare to floating point numbers. /// /// Generally, two numbers `x` and `y` are approximately equal if /// /// ```text /// |x - y| < max(R * (|x + y|), A) /// ``` /// /// where `R` is the relative tolerance and `A` is the absolute tolerance. /// /// /// The most common way to initialize this struct is to use `Tolerance::::default()`. /// In that case, the relative and absolute tolerances are computed using an heuristic based /// on the EPSILON and MIN_POSITIVE values of the given floating point type `F`. /// /// Another common initialization is `Tolerance::::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`. /// This will use a sane default to manage values too close to 0.0 and /// use different relative tolerances depending on the floating point precision. #[derive(Debug, Clone, Copy)] pub struct Tolerance { relative: F, absolute: F, } impl Default for Tolerance { fn default() -> Self { Self::balanced() } } impl Tolerance { /// Create a tolerance with strict precision setting. pub fn strict() -> Self { Self { relative: F::from(0.00).unwrap(), absolute: F::from(64).unwrap() * F::min_positive_value(), } } /// Create a tolerance with balanced precision setting. pub fn balanced() -> Self { Self { relative: F::from(0.005).unwrap(), // 0.5% absolute: F::from(1e-5).unwrap(), } } /// Create a tolerance with permissive precision setting. pub fn permissive() -> Self { Self { relative: F::from(0.01).unwrap(), // 1.0% absolute: F::from(0.01).unwrap(), } } /// When comparing two numbers, this uses both the relative and absolute differences. /// /// That is, `x` and `y` are approximately equal if /// /// ```text /// |x - y| < max(R * (|x + y|), A) /// ``` /// /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance. pub fn rel_abs(relative: FF, absolute: FF) -> Self { let relative = Self::check_relative(relative); let absolute = Self::check_absolute(absolute); Self { relative, absolute } } /// When comparing two numbers, this uses only the relative difference. /// /// That is, `x` and `y` are approximately equal if /// /// ```text /// |x - y| < R * max(|x|, |y|) /// ``` /// /// where `R` is the relative `tolerance`. pub fn relative(tolerance: FF) -> Self { let relative = Self::check_relative(tolerance); Self { relative, absolute: F::from(0.0).unwrap(), } } /// When comparing two numbers, this uses only the absolute difference. /// /// That is, `x` and `y` are approximately equal if /// /// ```text /// |x - y| < A /// ``` /// /// where `A` is the absolute `tolerance`. pub fn absolute(tolerance: FF) -> Self { let absolute = Self::check_absolute(tolerance); Self { relative: F::from(0.0).unwrap(), absolute, } } /// Change the relative tolerance to the given one. pub fn set_relative(mut self, tolerance: FF) -> Self { self.relative = Self::check_relative(tolerance); self } /// Change the relative tolerance to the given one only if `F` is half precision. pub fn set_half_precision_relative(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 2 { self.relative = Self::check_relative(tolerance); } self } /// Change the relative tolerance to the given one only if `F` is single precision. pub fn set_single_precision_relative(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 4 { self.relative = Self::check_relative(tolerance); } self } /// Change the relative tolerance to the given one only if `F` is double precision. pub fn set_double_precision_relative(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 8 { self.relative = Self::check_relative(tolerance); } self } /// Change the absolute tolerance to the given one. pub fn set_absolute(mut self, tolerance: FF) -> Self { self.absolute = Self::check_absolute(tolerance); self } /// Change the absolute tolerance to the given one only if `F` is half precision. pub fn set_half_precision_absolute(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 2 { self.absolute = Self::check_absolute(tolerance); } self } /// Change the absolute tolerance to the given one only if `F` is single precision. pub fn set_single_precision_absolute(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 4 { self.absolute = Self::check_absolute(tolerance); } self } /// Change the absolute tolerance to the given one only if `F` is double precision. pub fn set_double_precision_absolute(mut self, tolerance: FF) -> Self { if core::mem::size_of::() == 8 { self.absolute = Self::check_absolute(tolerance); } self } /// Checks if `x` and `y` are approximately equal given the tolerance. pub fn approx_eq(&self, x: F, y: F) -> bool { // See the accepted answer here // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison // This also handles the case where both a and b are infinity so that we don't need // to manage it in the rest of the function. if x == y { return true; } let diff = (x - y).abs(); let max = F::max(x.abs(), y.abs()); diff < self.absolute.max(self.relative * max) } fn check_relative(tolerance: FF) -> F { let tolerance = F::from(tolerance).unwrap(); assert!(tolerance <= F::one()); tolerance } fn check_absolute(tolerance: FF) -> F { let tolerance = F::from(tolerance).unwrap(); assert!(tolerance >= F::zero()); tolerance } } impl TensorData { /// Asserts the data is equal to another data. /// /// # Arguments /// /// * `other` - The other data. /// * `strict` - If true, the data types must the be same. /// Otherwise, the comparison is done in the current data type. /// /// # Panics /// /// Panics if the data is not equal. #[track_caller] pub fn assert_eq(&self, other: &Self, strict: bool) { if strict { assert_eq!( self.dtype, other.dtype, "Data types differ ({:?} != {:?})", self.dtype, other.dtype ); } match self.dtype { DType::F64 => self.assert_eq_elem::(other), DType::F32 | DType::Flex32 => self.assert_eq_elem::(other), DType::F16 => self.assert_eq_elem::(other), DType::BF16 => self.assert_eq_elem::(other), DType::I64 => self.assert_eq_elem::(other), DType::I32 => self.assert_eq_elem::(other), DType::I16 => self.assert_eq_elem::(other), DType::I8 => self.assert_eq_elem::(other), DType::U64 => self.assert_eq_elem::(other), DType::U32 => self.assert_eq_elem::(other), DType::U16 => self.assert_eq_elem::(other), DType::U8 => self.assert_eq_elem::(other), DType::Bool(BoolStore::Native) => self.assert_eq_elem::(other), DType::Bool(BoolStore::U8) => self.assert_eq_elem::(other), DType::Bool(BoolStore::U32) => self.assert_eq_elem::(other), DType::QFloat(q) => { // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality let q_other = if let DType::QFloat(q_other) = other.dtype { q_other } else { panic!("Quantized data differs from other not quantized data") }; // Data equality mostly depends on input quantization type, but we also check level if q.value == q_other.value && q.level == q_other.level { self.assert_eq_elem::(other) } else { panic!("Quantization schemes differ ({q:?} != {q_other:?})") } } } } #[track_caller] fn assert_eq_elem(&self, other: &Self) { let mut message = String::new(); if self.shape != other.shape { message += format!( "\n => Shape is different: {:?} != {:?}", self.shape, other.shape ) .as_str(); } let mut num_diff = 0; let max_num_diff = 5; for (i, (a, b)) in self.iter::().zip(other.iter::()).enumerate() { if !a.eq(&b) { // Only print the first 5 different values. if num_diff < max_num_diff { message += format!("\n => Position {i}: {a} != {b}").as_str(); } num_diff += 1; } } if num_diff >= max_num_diff { message += format!("\n{} more errors...", num_diff - max_num_diff).as_str(); } if !message.is_empty() { panic!("Tensors are not eq:{message}"); } } /// Asserts the data is approximately equal to another data. /// /// # Arguments /// /// * `other` - The other data. /// * `tolerance` - The tolerance of the comparison. /// /// # Panics /// /// Panics if the data is not approximately equal. #[track_caller] pub fn assert_approx_eq(&self, other: &Self, tolerance: Tolerance) { let mut message = String::new(); if self.shape != other.shape { message += format!( "\n => Shape is different: {:?} != {:?}", self.shape, other.shape ) .as_str(); } let iter = self.iter::().zip(other.iter::()); let mut num_diff = 0; let max_num_diff = 5; for (i, (a, b)) in iter.enumerate() { //if they are both nan, then they are equally nan let both_nan = a.is_nan() && b.is_nan(); //this works for both infinities let both_inf = a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero())); if both_nan || both_inf { continue; } if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) { // Only print the first 5 different values. if num_diff < max_num_diff { let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap(); let max = F::max(a.abs(), b.abs()); let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap(); let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap(); let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap(); message += format!( "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})" ) .as_str(); } num_diff += 1; } } if num_diff >= max_num_diff { message += format!("\n{} more errors...", num_diff - 5).as_str(); } if !message.is_empty() { panic!("Tensors are not approx eq:{message}"); } } /// Asserts each value is within a given range. /// /// # Arguments /// /// * `range` - The range. /// /// # Panics /// /// If any value is not within the half-open range bounded inclusively below /// and exclusively above (`start..end`). pub fn assert_within_range(&self, range: core::ops::Range) { for elem in self.iter::() { if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() { panic!("Element ({elem:?}) is not within range {range:?}"); } } } /// Asserts each value is within a given inclusive range. /// /// # Arguments /// /// * `range` - The range. /// /// # Panics /// /// If any value is not within the half-open range bounded inclusively (`start..=end`). pub fn assert_within_range_inclusive( &self, range: core::ops::RangeInclusive, ) { let start = range.start(); let end = range.end(); for elem in self.iter::() { if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() { panic!("Element ({elem:?}) is not within range {range:?}"); } } } } #[cfg(test)] mod tests { use super::*; #[test] fn should_assert_appox_eq_limit() { let data1 = TensorData::from([[3.0, 5.0, 6.0]]); let data2 = TensorData::from([[3.03, 5.0, 6.0]]); data1.assert_approx_eq::(&data2, Tolerance::absolute(3e-2)); data1.assert_approx_eq::(&data2, Tolerance::absolute(3e-2)); } #[test] #[should_panic] fn should_assert_approx_eq_above_limit() { let data1 = TensorData::from([[3.0, 5.0, 6.0]]); let data2 = TensorData::from([[3.031, 5.0, 6.0]]); data1.assert_approx_eq::(&data2, Tolerance::absolute(1e-2)); } #[test] #[should_panic] fn should_assert_approx_eq_check_shape() { let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]); let data2 = TensorData::from([[3.0, 5.0, 6.0]]); data1.assert_approx_eq::(&data2, Tolerance::absolute(1e-2)); } } ================================================ FILE: crates/burn-backend/src/data/mod.rs ================================================ mod compare; mod tensor; pub use compare::*; pub use tensor::*; ================================================ FILE: crates/burn-backend/src/data/tensor.rs ================================================ use core::f32; use alloc::boxed::Box; use alloc::format; use alloc::string::String; use alloc::vec::Vec; use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError}; use rand::Rng; use thiserror::Error; use crate::Scalar; use crate::distribution::Distribution; use crate::element::{Element, ElementConversion}; use burn_std::tensor::DType; use burn_std::{ BoolStore, Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, Shape, bf16, f16, }; use serde::{Deserialize, Serialize}; /// Data structure for tensors. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TensorData { /// The values of the tensor (as bytes). pub bytes: Bytes, /// The shape of the tensor. #[serde(with = "shape_inner")] pub shape: Shape, /// The data type of the tensor. pub dtype: DType, } // For backward compatibility with shape `Vec` mod shape_inner { use burn_std::SmallVec; use super::*; pub fn serialize( shape: &Shape, serializer: S, ) -> Result { shape.as_slice().serialize(serializer) } pub fn deserialize<'de, D: serde::Deserializer<'de>>( deserializer: D, ) -> Result { let dims = SmallVec::<[usize; _]>::deserialize(deserializer)?; Ok(Shape::new_raw(dims)) } } impl TensorData { /// Creates a new tensor data structure. pub fn new>(value: Vec, shape: S) -> Self { // Ensure shape is valid let shape = shape.into(); Self::check_data_len(&value, &shape); Self { bytes: Bytes::from_elems(value), shape, dtype: E::dtype(), } } /// Creates a new quantized tensor data structure. pub fn quantized>( value: Vec, shape: S, scheme: QuantScheme, qparams: &[f32], ) -> Self { let shape = shape.into(); Self::check_data_len(&value, &shape); let q_bytes = QuantizedBytes::new(value, scheme, qparams); Self { bytes: q_bytes.bytes, shape, dtype: DType::QFloat(q_bytes.scheme), } } /// Creates a new tensor data structure from raw bytes. pub fn from_bytes>(bytes: Bytes, shape: S, dtype: DType) -> Self { Self { bytes, shape: shape.into(), dtype, } } /// Creates a new tensor data structure from raw bytes stored in a vector. /// /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are /// certain that the bytes representation is valid. pub fn from_bytes_vec>(bytes: Vec, shape: S, dtype: DType) -> Self { Self { bytes: Bytes::from_bytes_vec(bytes), shape: shape.into(), dtype, } } // Check that the input vector contains a correct number of elements fn check_data_len(data: &[E], shape: &Shape) { let expected_data_len = Self::numel(shape); let num_data = data.len(); assert_eq!( expected_data_len, num_data, "Shape {shape:?} is invalid for input of size {num_data:?}", ); } /// Returns the immutable slice view of the tensor data. pub fn as_slice(&self) -> Result<&[E], DataError> { if self.matches_target_dtype::() { match E::dtype() { // The only way to create a bool `TensorData` with invalid values is by unsafely modifying // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow. DType::Bool(BoolStore::Native) => { let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes) .map_err(DataError::CastError)?; Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) }) } _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError), } } else { Err(DataError::TypeMismatch(format!( "Invalid target element type (expected {:?}, got {:?})", self.dtype, E::dtype() ))) } } /// Returns the mutable slice view of the tensor data. /// /// # Panics /// If the target element type is different from the stored element type. pub fn as_mut_slice(&mut self) -> Result<&mut [E], DataError> { if self.matches_target_dtype::() { match E::dtype() { // The only way to create a bool `TensorData` with invalid values is by unsafely modifying // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow. DType::Bool(BoolStore::Native) => { let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes) .map_err(DataError::CastError)?; Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) }) } _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes) .map_err(DataError::CastError), } } else { Err(DataError::TypeMismatch(format!( "Invalid target element type (expected {:?}, got {:?})", self.dtype, E::dtype() ))) } } /// Returns the tensor data as a vector of scalar values. pub fn to_vec(&self) -> Result, DataError> { Ok(self.as_slice()?.to_vec()) } /// Returns the tensor data as a vector of scalar values. pub fn into_vec(self) -> Result, DataError> { // This means we cannot call `into_vec` for QFloat if !self.matches_target_dtype::() { return Err(DataError::TypeMismatch(format!( "Invalid target element type (expected {:?}, got {:?})", self.dtype, E::dtype() ))); } match E::dtype() { // The only way to create a bool `TensorData` with invalid values is by unsafely modifying // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow. DType::Bool(BoolStore::Native) => { let vec = self.into_vec_unchecked::()?; Ok(unsafe { core::mem::transmute::, Vec>(vec) }) } _ => self.into_vec_unchecked(), } } /// Returns the tensor data as a vector of scalar values. Does not check dtype. fn into_vec_unchecked(self) -> Result, DataError> { let mut me = self; me.bytes = match me.bytes.try_into_vec::() { Ok(elems) => return Ok(elems), Err(bytes) => bytes, }; // The bytes might have been deserialized and allocated with a different align. // In that case, we have to memcopy the data into a new vector, more suitably allocated Ok(bytemuck::checked::try_cast_slice(me.as_bytes()) .map_err(DataError::CastError)? .to_vec()) } fn matches_target_dtype(&self) -> bool { let target_dtype = E::dtype(); match self.dtype { DType::Bool(BoolStore::U8) => { matches!(target_dtype, DType::U8 | DType::Bool(BoolStore::U8)) } DType::Bool(BoolStore::U32) => { matches!(target_dtype, DType::U32 | DType::Bool(BoolStore::U32)) } dtype => dtype == target_dtype, } } /// Returns an iterator over the values of the tensor data. pub fn iter(&self) -> Box + '_> { if E::dtype() == self.dtype { Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied()) } else { match self.dtype { DType::I8 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &i8| e.elem::()), ), DType::I16 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &i16| e.elem::()), ), DType::I32 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &i32| e.elem::()), ), DType::I64 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &i64| e.elem::()), ), DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::())), DType::U16 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &u16| e.elem::()), ), DType::U32 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &u32| e.elem::()), ), DType::U64 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &u64| e.elem::()), ), DType::BF16 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &bf16| e.elem::()), ), DType::F16 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &f16| e.elem::()), ), DType::F32 | DType::Flex32 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &f32| e.elem::()), ), DType::F64 => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &f64| e.elem::()), ), // bool is a byte value equal to either 0 or 1 DType::Bool(BoolStore::Native) | DType::Bool(BoolStore::U8) => { Box::new(self.bytes.iter().map(|e| e.elem::())) } DType::Bool(BoolStore::U32) => Box::new( bytemuck::checked::cast_slice(&self.bytes) .iter() .map(|e: &u32| e.elem::()), ), DType::QFloat(scheme) => match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S // Represent sub-byte values as i8 | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S, .. } => { // Quantized int8 values let q_bytes = QuantizedBytes { bytes: self.bytes.clone(), scheme, num_elements: self.num_elements(), }; let (values, _) = q_bytes.into_vec_i8(); Box::new( values .iter() .map(|e: &i8| e.elem::()) .collect::>() .into_iter(), ) } QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1, .. } => { unimplemented!("Not yet implemented for iteration"); } }, } } } /// Returns the rank (the number of dimensions). pub fn rank(&self) -> usize { self.shape.len() } /// Returns the total number of elements of the tensor data. pub fn num_elements(&self) -> usize { Self::numel(&self.shape) } fn numel(shape: &[usize]) -> usize { shape.iter().product() } /// Populates the data with random values. pub fn random>( shape: S, distribution: Distribution, rng: &mut R, ) -> Self { let shape = shape.into(); let num_elements = Self::numel(&shape); let mut data = Vec::with_capacity(num_elements); for _ in 0..num_elements { data.push(E::random(distribution, rng)); } TensorData::new(data, shape) } /// Populates the data with zeros. pub fn zeros>(shape: S) -> TensorData { let shape = shape.into(); let num_elements = Self::numel(&shape); let mut data = Vec::::with_capacity(num_elements); for _ in 0..num_elements { data.push(0.elem()); } TensorData::new(data, shape) } /// Populates the data with ones. pub fn ones>(shape: S) -> TensorData { let shape = shape.into(); let num_elements = Self::numel(&shape); let mut data = Vec::::with_capacity(num_elements); for _ in 0..num_elements { data.push(1.elem()); } TensorData::new(data, shape) } /// Populates the data with the given value pub fn full>(shape: S, fill_value: E) -> TensorData { let shape = shape.into(); let num_elements = Self::numel(&shape); let mut data = Vec::::with_capacity(num_elements); for _ in 0..num_elements { data.push(fill_value) } TensorData::new(data, shape) } /// Populates the data with the given value pub fn full_dtype, S: Into>( shape: S, fill_value: E, dtype: DType, ) -> TensorData { let fill_value = fill_value.into(); match dtype { DType::F64 => Self::full::(shape, fill_value.elem()), DType::F32 | DType::Flex32 => Self::full::(shape, fill_value.elem()), DType::F16 => Self::full::(shape, fill_value.elem()), DType::BF16 => Self::full::(shape, fill_value.elem()), DType::I64 => Self::full::(shape, fill_value.elem()), DType::I32 => Self::full::(shape, fill_value.elem()), DType::I16 => Self::full::(shape, fill_value.elem()), DType::I8 => Self::full::(shape, fill_value.elem()), DType::U64 => Self::full::(shape, fill_value.elem()), DType::U32 => Self::full::(shape, fill_value.elem()), DType::U16 => Self::full::(shape, fill_value.elem()), DType::U8 => Self::full::(shape, fill_value.elem()), DType::Bool(BoolStore::Native) => Self::full::(shape, fill_value.elem()), DType::Bool(BoolStore::U8) => { Self::full::(shape, fill_value.elem()).into_bool_u8() } DType::Bool(BoolStore::U32) => { Self::full::(shape, fill_value.elem()).into_bool_u32() } DType::QFloat(_) => unreachable!(), } } // Unchecked, used to overwrite the dtype fn into_bool_u8(mut self) -> Self { self.dtype = DType::Bool(BoolStore::U8); self } // Unchecked, used to overwrite the dtype fn into_bool_u32(mut self) -> Self { self.dtype = DType::Bool(BoolStore::U32); self } /// Converts the data to a different element type. pub fn convert(self) -> Self { self.convert_dtype(E::dtype()) } /// Converts the data to a different element type. pub fn convert_dtype(self, dtype: DType) -> Self { if dtype == self.dtype { self } else if dtype.size() == self.dtype.size() && !matches!( self.dtype, DType::Bool(BoolStore::Native) | DType::QFloat(_) ) && !matches!(dtype, DType::Bool(BoolStore::Native) | DType::QFloat(_)) { match self.dtype { DType::F64 => self.convert_inplace_dtype::(dtype), DType::F32 | DType::Flex32 => self.convert_inplace_dtype::(dtype), DType::F16 => self.convert_inplace_dtype::(dtype), DType::BF16 => self.convert_inplace_dtype::(dtype), DType::I64 => self.convert_inplace_dtype::(dtype), DType::I32 => self.convert_inplace_dtype::(dtype), DType::I16 => self.convert_inplace_dtype::(dtype), DType::I8 => self.convert_inplace_dtype::(dtype), DType::U64 => self.convert_inplace_dtype::(dtype), DType::U32 => self.convert_inplace_dtype::(dtype), DType::U16 => self.convert_inplace_dtype::(dtype), DType::U8 => self.convert_inplace_dtype::(dtype), DType::Bool(BoolStore::U8) => self.convert_inplace_dtype::(dtype), DType::Bool(BoolStore::U32) => self.convert_inplace_dtype::(dtype), DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(), } } else { match self.dtype { DType::F64 => self.convert_clone_dtype::(dtype), DType::F32 | DType::Flex32 => self.convert_clone_dtype::(dtype), DType::F16 => self.convert_clone_dtype::(dtype), DType::BF16 => self.convert_clone_dtype::(dtype), DType::I64 => self.convert_clone_dtype::(dtype), DType::I32 => self.convert_clone_dtype::(dtype), DType::I16 => self.convert_clone_dtype::(dtype), DType::I8 => self.convert_clone_dtype::(dtype), DType::U64 => self.convert_clone_dtype::(dtype), DType::U32 => self.convert_clone_dtype::(dtype), DType::U16 => self.convert_clone_dtype::(dtype), DType::U8 => self.convert_clone_dtype::(dtype), DType::Bool(BoolStore::Native) => self.convert_clone_dtype::(dtype), DType::Bool(BoolStore::U8) => self.convert_clone_dtype::(dtype), DType::Bool(BoolStore::U32) => self.convert_clone_dtype::(dtype), DType::QFloat(_) => unreachable!(), } } } fn convert_inplace_dtype(self, dtype: DType) -> Self { match dtype { DType::F64 => self.convert_inplace::(), DType::F32 | DType::Flex32 => self.convert_inplace::(), DType::F16 => self.convert_inplace::(), DType::BF16 => self.convert_inplace::(), DType::I64 => self.convert_inplace::(), DType::I32 => self.convert_inplace::(), DType::I16 => self.convert_inplace::(), DType::I8 => self.convert_inplace::(), DType::U64 => self.convert_inplace::(), DType::U32 => self.convert_inplace::(), DType::U16 => self.convert_inplace::(), DType::U8 => self.convert_inplace::(), DType::Bool(BoolStore::U8) => self.convert_inplace::().into_bool_u8(), DType::Bool(BoolStore::U32) => self.convert_inplace::().into_bool_u32(), DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(), } } fn convert_inplace( mut self, ) -> Self { for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) { let t: Target = x.elem(); let x = cast_mut::<_, Target>(x); *x = t; } self.dtype = Target::dtype(); self } fn convert_clone_dtype(self, dtype: DType) -> Self { match dtype { DType::F64 => self.convert_clone::(), DType::F32 | DType::Flex32 => self.convert_clone::(), DType::F16 => self.convert_clone::(), DType::BF16 => self.convert_clone::(), DType::I64 => self.convert_clone::(), DType::I32 => self.convert_clone::(), DType::I16 => self.convert_clone::(), DType::I8 => self.convert_clone::(), DType::U64 => self.convert_clone::(), DType::U32 => self.convert_clone::(), DType::U16 => self.convert_clone::(), DType::U8 => self.convert_clone::(), DType::Bool(BoolStore::Native) => self.convert_clone::(), DType::Bool(BoolStore::U8) => self.convert_clone::().into_bool_u8(), DType::Bool(BoolStore::U32) => self.convert_clone::().into_bool_u32(), DType::QFloat(_) => unreachable!(), } } fn convert_clone( self, ) -> Self { let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes); let mut out: Vec = ::alloc::vec![Zeroable::zeroed(); self.num_elements()]; for (x, out) in this.iter().zip(&mut out) { *out = x.elem(); } Self::new(out, self.shape) } /// Returns the data as a slice of bytes. pub fn as_bytes(&self) -> &[u8] { &self.bytes } /// Returns the bytes representation of the data. pub fn into_bytes(self) -> Bytes { self.bytes } } impl From<[E; A]> for TensorData { fn from(elems: [E; A]) -> Self { TensorData::new(elems.to_vec(), [A]) } } impl From<[usize; A]> for TensorData { fn from(elems: [usize; A]) -> Self { TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A]) } } impl From<&[usize]> for TensorData { fn from(elems: &[usize]) -> Self { let mut data = Vec::with_capacity(elems.len()); for elem in elems.iter() { data.push(*elem as i64); } TensorData::new(data, [elems.len()]) } } impl From<&[E]> for TensorData { fn from(elems: &[E]) -> Self { let mut data = Vec::with_capacity(elems.len()); for elem in elems.iter() { data.push(*elem); } TensorData::new(data, [elems.len()]) } } impl From<[[E; B]; A]> for TensorData { fn from(elems: [[E; B]; A]) -> Self { let mut data = Vec::with_capacity(A * B); for elem in elems.into_iter().take(A) { for elem in elem.into_iter().take(B) { data.push(elem); } } TensorData::new(data, [A, B]) } } impl From<[[[E; C]; B]; A]> for TensorData { fn from(elems: [[[E; C]; B]; A]) -> Self { let mut data = Vec::with_capacity(A * B * C); for elem in elems.into_iter().take(A) { for elem in elem.into_iter().take(B) { for elem in elem.into_iter().take(C) { data.push(elem); } } } TensorData::new(data, [A, B, C]) } } impl From<[[[[E; D]; C]; B]; A]> for TensorData { fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { let mut data = Vec::with_capacity(A * B * C * D); for elem in elems.into_iter().take(A) { for elem in elem.into_iter().take(B) { for elem in elem.into_iter().take(C) { for elem in elem.into_iter().take(D) { data.push(elem); } } } } TensorData::new(data, [A, B, C, D]) } } impl From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData { fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self { let mut data = Vec::with_capacity(A * B * C * D * E); for elem in elems.into_iter().take(A) { for elem in elem.into_iter().take(B) { for elem in elem.into_iter().take(C) { for elem in elem.into_iter().take(D) { for elem in elem.into_iter().take(E) { data.push(elem); } } } } } TensorData::new(data, [A, B, C, D, E]) } } impl core::fmt::Display for TensorData { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let fmt = match self.dtype { DType::F64 => format!("{:?}", self.as_slice::().unwrap()), DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::().unwrap()), DType::F16 => format!("{:?}", self.as_slice::().unwrap()), DType::BF16 => format!("{:?}", self.as_slice::().unwrap()), DType::I64 => format!("{:?}", self.as_slice::().unwrap()), DType::I32 => format!("{:?}", self.as_slice::().unwrap()), DType::I16 => format!("{:?}", self.as_slice::().unwrap()), DType::I8 => format!("{:?}", self.as_slice::().unwrap()), DType::U64 => format!("{:?}", self.as_slice::().unwrap()), DType::U32 => format!("{:?}", self.as_slice::().unwrap()), DType::U16 => format!("{:?}", self.as_slice::().unwrap()), DType::U8 => format!("{:?}", self.as_slice::().unwrap()), DType::Bool(BoolStore::Native) => format!("{:?}", self.as_slice::().unwrap()), DType::Bool(BoolStore::U8) => format!("{:?}", self.as_slice::().unwrap()), DType::Bool(BoolStore::U32) => format!("{:?}", self.as_slice::().unwrap()), DType::QFloat(scheme) => match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S // Display sub-byte values as i8 | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S, .. } => { format!("{:?} {scheme:?}", self.iter::().collect::>()) }, QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1, .. } => { unimplemented!("Can't format yet"); } }, }; f.write_str(fmt.as_str()) } } /// The things that can go wrong when manipulating tensor data. #[derive(Debug, Error)] pub enum DataError { /// Failed to cast the values to a specified element type. #[error("Failed to cast values to the specified element type.\nError:\n {0}")] CastError(CheckedCastError), /// Invalid target element type. #[error("{0}")] TypeMismatch(String), } #[cfg(test)] mod tests { use super::*; use alloc::vec; use burn_std::shape; use rand::{ SeedableRng, rngs::{StdRng, SysRng}, }; #[test] fn should_have_rank() { let shape = [3, 5, 6]; let data = TensorData::random::( shape, Distribution::Default, &mut StdRng::try_from_rng(&mut SysRng).unwrap(), ); assert_eq!(data.rank(), 3); } #[test] fn into_vec_should_yield_same_value_as_iter() { let shape = [3, 5, 6]; let data = TensorData::random::( shape, Distribution::Default, &mut StdRng::try_from_rng(&mut SysRng).unwrap(), ); let expected = data.iter::().collect::>(); let actual = data.into_vec::().unwrap(); assert_eq!(expected, actual); } #[test] #[should_panic] fn into_vec_should_assert_wrong_dtype() { let shape = [3, 5, 6]; let data = TensorData::random::( shape, Distribution::Default, &mut StdRng::try_from_rng(&mut SysRng).unwrap(), ); data.into_vec::().unwrap(); } #[test] fn should_have_right_num_elements() { let shape = [3, 5, 6]; let num_elements: usize = shape.iter().product(); let data = TensorData::random::( shape, Distribution::Default, &mut StdRng::try_from_rng(&mut SysRng).unwrap(), ); assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s assert_eq!(num_elements, data.as_slice::().unwrap().len()); } #[test] fn should_have_right_shape() { let data = TensorData::from([[3.0, 5.0, 6.0]]); assert_eq!(data.shape, shape![1, 3]); let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); assert_eq!(data.shape, shape![2, 3]); let data = TensorData::from([3.0, 5.0, 6.0]); assert_eq!(data.shape, shape![3]); } #[test] fn should_convert_bytes_correctly() { let mut vector: Vec = Vec::with_capacity(5); vector.push(2.0); vector.push(3.0); let data1 = TensorData::new(vector, vec![2]); let factor = core::mem::size_of::() / core::mem::size_of::(); assert_eq!(data1.bytes.len(), 2 * factor); assert_eq!(data1.bytes.capacity(), 5 * factor); } #[test] fn should_convert_bytes_correctly_inplace() { fn test_precision() { let data = TensorData::new((0..32).collect(), [32]); for (i, val) in data .clone() .convert::() .into_vec::() .unwrap() .into_iter() .enumerate() { assert_eq!(i as u32, val.elem::()) } } test_precision::(); test_precision::(); test_precision::(); test_precision::(); } macro_rules! test_dtypes { ($test_name:ident, $($dtype:ty),*) => { $( paste::paste! { #[test] fn [<$test_name _ $dtype:snake>]() { let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype()); let full = TensorData::full::<$dtype, _>([2, 16], 4.elem()); assert_eq!(full_dtype, full); } } )* }; } test_dtypes!( should_create_with_dtype, bool, i8, i16, i32, i64, u8, u16, u32, u64, f16, bf16, f32, f64 ); #[test] fn should_serialize_deserialize_tensor_data() { let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]); assert_eq!( data.as_bytes(), [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64 ] ); let serialized = serde_json::to_string(&data).unwrap(); let deserialized: TensorData = serde_json::from_str(&serialized).unwrap(); assert_eq!(data, deserialized); } #[test] fn should_deserialize_tensor_data_with_shape_inner() { // TensorData `shape` was previously a Vec. let serialized = r#"{ "bytes": [0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64], "shape": [2, 3], "dtype": "F32" }"#; let data: TensorData = serde_json::from_str(serialized).unwrap(); assert_eq!(data.shape, shape![2, 3]); assert_eq!( data.as_slice::().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0] ); } #[test] fn should_serialize_shape_as_flat_array() { // Ensure the new Shape serializes identically to how Vec used to, // i.e. as a flat JSON array, not as an object like `{"dims": [2, 3]}`. let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]); let serialized = serde_json::to_string(&data).unwrap(); let json: serde_json::Value = serde_json::from_str(&serialized).unwrap(); assert_eq!(json["shape"], serde_json::json!([2, 3])); } } ================================================ FILE: crates/burn-backend/src/distribution.rs ================================================ //! Random value distributions used to initialize and populate tensor data. use rand::{Rng, RngExt, distr::StandardUniform}; use super::element::{Element, ElementConversion}; /// Distribution for random value of a tensor. #[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] pub enum Distribution { /// Uniform distribution from 0 (inclusive) to 1 (exclusive). #[default] Default, /// Bernoulli distribution with the given probability. Bernoulli(f64), /// Uniform distribution `[low, high)`. Uniform(f64, f64), /// Normal distribution with the given mean and standard deviation. Normal(f64, f64), } /// Distribution sampler for random value of a tensor. #[derive(new)] pub struct DistributionSampler<'a, E, R> where StandardUniform: rand::distr::Distribution, E: rand::distr::uniform::SampleUniform, R: Rng, { kind: DistributionSamplerKind, rng: &'a mut R, } /// Distribution sampler kind for random value of a tensor. pub enum DistributionSamplerKind where StandardUniform: rand::distr::Distribution, E: rand::distr::uniform::SampleUniform, { /// Standard distribution. Standard(rand::distr::StandardUniform), /// Uniform distribution. Uniform(rand::distr::Uniform), /// Bernoulli distribution. Bernoulli(rand::distr::Bernoulli), /// Normal distribution. Normal(rand_distr::Normal), } impl DistributionSampler<'_, E, R> where StandardUniform: rand::distr::Distribution, E: rand::distr::uniform::SampleUniform, E: Element, R: Rng, { /// Sames a random value from the distribution. pub fn sample(&mut self) -> E { match &self.kind { DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution), DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution), DistributionSamplerKind::Bernoulli(distribution) => { if self.rng.sample(distribution) { 1.elem() } else { 0.elem() } } DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(), } } } impl Distribution { /// Creates a new distribution sampler. /// /// # Arguments /// /// * `rng` - The random number generator. /// /// # Returns /// /// The distribution sampler. pub fn sampler(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> where R: Rng, E: Element + rand::distr::uniform::SampleUniform, StandardUniform: rand::distr::Distribution, { let kind = match self { Distribution::Default => { DistributionSamplerKind::Standard(rand::distr::StandardUniform {}) } Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform( rand::distr::Uniform::new(low.elem::(), high.elem::()).unwrap(), ), Distribution::Bernoulli(prob) => { DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap()) } Distribution::Normal(mean, std) => { DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) } }; DistributionSampler::new(kind, rng) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_distribution_default() { let dist: Distribution = Default::default(); assert_eq!(dist, Distribution::Default); assert_eq!(Distribution::default(), Distribution::Default); } } ================================================ FILE: crates/burn-backend/src/element/base.rs ================================================ use core::cmp::Ordering; use rand::Rng; use crate::distribution::Distribution; use burn_std::{BoolStore, DType, bf16, f16}; #[cfg(feature = "cubecl")] use burn_std::flex32; use super::cast::ToElement; /// Core element trait for tensor values. /// /// This trait defines the minimal set of capabilities required for a type to be /// stored and manipulated as a tensor element across all backends. pub trait Element: ToElement + ElementRandom + ElementConversion + ElementEq + ElementLimits + bytemuck::CheckedBitPattern + bytemuck::NoUninit + bytemuck::Zeroable + core::fmt::Debug + core::fmt::Display + Default + Send + Sync + Copy + 'static { /// The dtype of the element. fn dtype() -> DType; } /// Ordered element trait for tensor values. /// /// This trait extends [`Element`] with ordering semantics, enabling comparison /// and order-dependent operations in generic Rust implementations. /// /// Backends that implement these operations entirely at the device level do /// not rely on this trait. It only constrains the scalar type for generic Rust code. pub trait ElementOrdered: Element + ElementComparison {} /// Element conversion trait for tensor. pub trait ElementConversion { /// Converts an element to another element. /// /// # Arguments /// /// * `elem` - The element to convert. /// /// # Returns /// /// The converted element. fn from_elem(elem: E) -> Self; /// Converts and returns the converted element. fn elem(self) -> E; } /// Element trait for random value of a tensor. pub trait ElementRandom { /// Returns a random value for the given distribution. /// /// # Arguments /// /// * `distribution` - The distribution to sample from. /// * `rng` - The random number generator. /// /// # Returns /// /// The random value. fn random(distribution: Distribution, rng: &mut R) -> Self; } /// Element trait for equality of a tensor. pub trait ElementEq { /// Returns whether `self` and `other` are equal. fn eq(&self, other: &Self) -> bool; } /// Element ordering trait. pub trait ElementComparison { /// Returns and [Ordering] between `self` and `other`. fn cmp(&self, other: &Self) -> Ordering; } /// Element limits trait. pub trait ElementLimits { /// The minimum representable value const MIN: Self; /// The maximum representable value const MAX: Self; } /// Macro to implement the element trait for a type. #[macro_export] macro_rules! make_element { ( ty $type:ident, convert $convert:expr, random $random:expr, cmp $cmp:expr, dtype $dtype:expr ) => { make_element!(ty $type, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX); }; ( ty $type:ident, convert $convert:expr, random $random:expr, cmp $cmp:expr, dtype $dtype:expr, min $min:expr, max $max:expr ) => { impl Element for $type { #[inline(always)] fn dtype() -> burn_std::DType { $dtype } } impl ElementEq for $type { fn eq(&self, other: &Self) -> bool { self == other } } impl ElementConversion for $type { #[inline(always)] fn from_elem(elem: E) -> Self { #[allow(clippy::redundant_closure_call)] $convert(&elem) } #[inline(always)] fn elem(self) -> E { E::from_elem(self) } } impl ElementRandom for $type { fn random(distribution: Distribution, rng: &mut R) -> Self { #[allow(clippy::redundant_closure_call)] $random(distribution, rng) } } impl ElementComparison for $type { fn cmp(&self, other: &Self) -> Ordering { let a = self.elem::<$type>(); let b = other.elem::<$type>(); #[allow(clippy::redundant_closure_call)] $cmp(&a, &b) } } impl ElementLimits for $type { const MIN: Self = $min; const MAX: Self = $max; } impl ElementOrdered for $type {} }; } make_element!( ty f64, convert ToElement::to_f64, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &f64, b: &f64| a.total_cmp(b), dtype DType::F64 ); make_element!( ty f32, convert ToElement::to_f32, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &f32, b: &f32| a.total_cmp(b), dtype DType::F32 ); make_element!( ty i64, convert ToElement::to_i64, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i64, b: &i64| Ord::cmp(a, b), dtype DType::I64 ); make_element!( ty u64, convert ToElement::to_u64, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u64, b: &u64| Ord::cmp(a, b), dtype DType::U64 ); make_element!( ty i32, convert ToElement::to_i32, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i32, b: &i32| Ord::cmp(a, b), dtype DType::I32 ); make_element!( ty u32, convert ToElement::to_u32, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u32, b: &u32| Ord::cmp(a, b), dtype DType::U32 ); make_element!( ty i16, convert ToElement::to_i16, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i16, b: &i16| Ord::cmp(a, b), dtype DType::I16 ); make_element!( ty u16, convert ToElement::to_u16, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u16, b: &u16| Ord::cmp(a, b), dtype DType::U16 ); make_element!( ty i8, convert ToElement::to_i8, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i8, b: &i8| Ord::cmp(a, b), dtype DType::I8 ); make_element!( ty u8, convert ToElement::to_u8, random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u8, b: &u8| Ord::cmp(a, b), dtype DType::U8 ); make_element!( ty f16, convert ToElement::to_f16, random |distribution: Distribution, rng: &mut R| { let sample: f32 = distribution.sampler(rng).sample(); f16::from_elem(sample) }, cmp |a: &f16, b: &f16| a.total_cmp(b), dtype DType::F16 ); make_element!( ty bf16, convert ToElement::to_bf16, random |distribution: Distribution, rng: &mut R| { let sample: f32 = distribution.sampler(rng).sample(); bf16::from_elem(sample) }, cmp |a: &bf16, b: &bf16| a.total_cmp(b), dtype DType::BF16 ); #[cfg(feature = "cubecl")] make_element!( ty flex32, convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()), random |distribution: Distribution, rng: &mut R| { let sample: f32 = distribution.sampler(rng).sample(); flex32::from_elem(sample) }, cmp |a: &flex32, b: &flex32| a.total_cmp(b), dtype DType::Flex32, min flex32::from_f32(f16::MIN.to_f32_const()), max flex32::from_f32(f16::MAX.to_f32_const()) ); make_element!( ty bool, convert ToElement::to_bool, random |distribution: Distribution, rng: &mut R| { let sample: u8 = distribution.sampler(rng).sample(); bool::from_elem(sample) }, cmp |a: &bool, b: &bool| Ord::cmp(a, b), dtype DType::Bool(BoolStore::Native), min false, max true ); ================================================ FILE: crates/burn-backend/src/element/cast.rs ================================================ use core::mem::size_of; use burn_std::{bf16, f16}; /// A generic trait for converting a value to a number. /// Adapted from num_traits::ToPrimitive to support [bool]. /// /// A value can be represented by the target type when it lies within /// the range of scalars supported by the target type. /// For example, a negative integer cannot be represented by an unsigned /// integer type, and an `i64` with a very high magnitude might not be /// convertible to an `i32`. /// On the other hand, conversions with possible precision loss or truncation /// are admitted, like an `f32` with a decimal part to an integer type, or /// even a large `f64` saturating to `f32` infinity. /// /// The methods *panic* when the value cannot be represented by the target type. pub trait ToElement { /// Converts the value of `self` to an `isize`. #[inline] fn to_isize(&self) -> isize { ToElement::to_isize(&self.to_i64()) } /// Converts the value of `self` to an `i8`. #[inline] fn to_i8(&self) -> i8 { ToElement::to_i8(&self.to_i64()) } /// Converts the value of `self` to an `i16`. #[inline] fn to_i16(&self) -> i16 { ToElement::to_i16(&self.to_i64()) } /// Converts the value of `self` to an `i32`. #[inline] fn to_i32(&self) -> i32 { ToElement::to_i32(&self.to_i64()) } /// Converts the value of `self` to an `i64`. fn to_i64(&self) -> i64; /// Converts the value of `self` to an `i128`. /// /// The default implementation converts through `to_i64()`. Types implementing /// this trait should override this method if they can represent a greater range. #[inline] fn to_i128(&self) -> i128 { i128::from(self.to_i64()) } /// Converts the value of `self` to a `usize`. #[inline] fn to_usize(&self) -> usize { ToElement::to_usize(&self.to_u64()) } /// Converts the value of `self` to a `u8`. #[inline] fn to_u8(&self) -> u8 { ToElement::to_u8(&self.to_u64()) } /// Converts the value of `self` to a `u16`. #[inline] fn to_u16(&self) -> u16 { ToElement::to_u16(&self.to_u64()) } /// Converts the value of `self` to a `u32`. #[inline] fn to_u32(&self) -> u32 { ToElement::to_u32(&self.to_u64()) } /// Converts the value of `self` to a `u64`. fn to_u64(&self) -> u64; /// Converts the value of `self` to a `u128`. /// /// The default implementation converts through `to_u64()`. Types implementing /// this trait should override this method if they can represent a greater range. #[inline] fn to_u128(&self) -> u128 { u128::from(self.to_u64()) } /// Converts the value of `self` to an `f16`. Overflows may map to positive /// or negative infinity. #[inline] fn to_f16(&self) -> f16 { f16::from_f32(self.to_f32()) } /// Converts the value of `self` to an `bf16`. Overflows may map to positive /// or negative infinity. #[inline] fn to_bf16(&self) -> bf16 { bf16::from_f32(self.to_f32()) } /// Converts the value of `self` to an `f32`. Overflows may map to positive /// or negative infinity. #[inline] fn to_f32(&self) -> f32 { ToElement::to_f32(&self.to_f64()) } /// Converts the value of `self` to an `f64`. Overflows may map to positive /// or negative infinity. /// /// The default implementation tries to convert through `to_i64()`, and /// failing that through `to_u64()`. Types implementing this trait should /// override this method if they can represent a greater range. #[inline] fn to_f64(&self) -> f64 { ToElement::to_f64(&self.to_u64()) } /// Converts the value of `self` to a bool. /// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are /// adopted (anything that's not 0 is true). /// /// The default implementation tries to convert through `to_i64()`, and /// failing that through `to_u64()`. Types implementing this trait should /// override this method if they can represent a greater range. #[inline] fn to_bool(&self) -> bool { ToElement::to_bool(&self.to_u64()) } } macro_rules! impl_to_element_int_to_int { ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $DstT { let min = $DstT::MIN as $SrcT; let max = $DstT::MAX as $SrcT; if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) { *self as $DstT } else { panic!( "Element cannot be represented in the target type: {:?}({:?}) => {:?}", core::any::type_name::<$SrcT>(), self, core::any::type_name::<$DstT>(), ) } } )*} } macro_rules! impl_to_element_int_to_uint { ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $DstT { let max = $DstT::MAX as $SrcT; if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) { *self as $DstT } else { panic!( "Element cannot be represented in the target type: {:?}({:?}) => {:?}", core::any::type_name::<$SrcT>(), self, core::any::type_name::<$DstT>(), ) } } )*} } macro_rules! impl_to_element_int { ($T:ident) => { impl ToElement for $T { impl_to_element_int_to_int! { $T: fn to_isize -> isize; fn to_i8 -> i8; fn to_i16 -> i16; fn to_i32 -> i32; fn to_i64 -> i64; fn to_i128 -> i128; } impl_to_element_int_to_uint! { $T: fn to_usize -> usize; fn to_u8 -> u8; fn to_u16 -> u16; fn to_u32 -> u32; fn to_u64 -> u64; fn to_u128 -> u128; } #[inline] fn to_f32(&self) -> f32 { *self as f32 } #[inline] fn to_f64(&self) -> f64 { *self as f64 } #[inline] fn to_bool(&self) -> bool { *self != 0 } } }; } impl_to_element_int!(isize); impl_to_element_int!(i8); impl_to_element_int!(i16); impl_to_element_int!(i32); impl_to_element_int!(i64); impl_to_element_int!(i128); macro_rules! impl_to_element_uint_to_int { ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $DstT { let max = $DstT::MAX as $SrcT; if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max { *self as $DstT } else { panic!( "Element cannot be represented in the target type: {:?}({:?}) => {:?}", core::any::type_name::<$SrcT>(), self, core::any::type_name::<$DstT>(), ) } } )*} } macro_rules! impl_to_element_uint_to_uint { ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $DstT { let max = $DstT::MAX as $SrcT; if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max { *self as $DstT } else { panic!( "Element cannot be represented in the target type: {:?}({:?}) => {:?}", core::any::type_name::<$SrcT>(), self, core::any::type_name::<$DstT>(), ) } } )*} } macro_rules! impl_to_element_uint { ($T:ident) => { impl ToElement for $T { impl_to_element_uint_to_int! { $T: fn to_isize -> isize; fn to_i8 -> i8; fn to_i16 -> i16; fn to_i32 -> i32; fn to_i64 -> i64; fn to_i128 -> i128; } impl_to_element_uint_to_uint! { $T: fn to_usize -> usize; fn to_u8 -> u8; fn to_u16 -> u16; fn to_u32 -> u32; fn to_u64 -> u64; fn to_u128 -> u128; } #[inline] fn to_f32(&self) -> f32 { *self as f32 } #[inline] fn to_f64(&self) -> f64 { *self as f64 } #[inline] fn to_bool(&self) -> bool { *self != 0 } } }; } impl_to_element_uint!(usize); impl_to_element_uint!(u8); impl_to_element_uint!(u16); impl_to_element_uint!(u32); impl_to_element_uint!(u64); impl_to_element_uint!(u128); macro_rules! impl_to_element_float_to_float { ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$( #[inline] fn $method(&self) -> $DstT { // We can safely cast all values, whether NaN, +-inf, or finite. // Finite values that are reducing size may saturate to +-inf. *self as $DstT } )*} } macro_rules! float_to_int_unchecked { // SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating. // We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`. ($float:expr => $int:ty) => { unsafe { $float.to_int_unchecked::<$int>() } }; } macro_rules! impl_to_element_float_to_signed_int { ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $i { // Float as int truncates toward zero, so we want to allow values // in the exclusive range `(MIN-1, MAX+1)`. if size_of::<$f>() > size_of::<$i>() { // With a larger size, we can represent the range exactly. const MIN_M1: $f = $i::MIN as $f - 1.0; const MAX_P1: $f = $i::MAX as $f + 1.0; if *self > MIN_M1 && *self < MAX_P1 { return float_to_int_unchecked!(*self => $i); } } else { // We can't represent `MIN-1` exactly, but there's no fractional part // at this magnitude, so we can just use a `MIN` inclusive boundary. const MIN: $f = $i::MIN as $f; // We can't represent `MAX` exactly, but it will round up to exactly // `MAX+1` (a power of two) when we cast it. const MAX_P1: $f = $i::MAX as $f; if *self >= MIN && *self < MAX_P1 { return float_to_int_unchecked!(*self => $i); } } panic!("Float cannot be represented in the target signed int type") } )*} } macro_rules! impl_to_element_float_to_unsigned_int { ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$( #[inline] $(#[$cfg])* fn $method(&self) -> $u { // Float as int truncates toward zero, so we want to allow values // in the exclusive range `(-1, MAX+1)`. if size_of::<$f>() > size_of::<$u>() { // With a larger size, we can represent the range exactly. const MAX_P1: $f = $u::MAX as $f + 1.0; if *self > -1.0 && *self < MAX_P1 { return float_to_int_unchecked!(*self => $u); } } else { // We can't represent `MAX` exactly, but it will round up to exactly // `MAX+1` (a power of two) when we cast it. // (`u128::MAX as f32` is infinity, but this is still ok.) const MAX_P1: $f = $u::MAX as $f; if *self > -1.0 && *self < MAX_P1 { return float_to_int_unchecked!(*self => $u); } } panic!("Float cannot be represented in the target unsigned int type") } )*} } macro_rules! impl_to_element_float { ($T:ident) => { impl ToElement for $T { impl_to_element_float_to_signed_int! { $T: fn to_isize -> isize; fn to_i8 -> i8; fn to_i16 -> i16; fn to_i32 -> i32; fn to_i64 -> i64; fn to_i128 -> i128; } impl_to_element_float_to_unsigned_int! { $T: fn to_usize -> usize; fn to_u8 -> u8; fn to_u16 -> u16; fn to_u32 -> u32; fn to_u64 -> u64; fn to_u128 -> u128; } impl_to_element_float_to_float! { $T: fn to_f32 -> f32; fn to_f64 -> f64; } #[inline] fn to_bool(&self) -> bool { *self != 0.0 } } }; } impl_to_element_float!(f32); impl_to_element_float!(f64); impl ToElement for f16 { #[inline] fn to_i64(&self) -> i64 { Self::to_f32(*self).to_i64() } #[inline] fn to_u64(&self) -> u64 { Self::to_f32(*self).to_u64() } #[inline] fn to_i8(&self) -> i8 { Self::to_f32(*self).to_i8() } #[inline] fn to_u8(&self) -> u8 { Self::to_f32(*self).to_u8() } #[inline] fn to_i16(&self) -> i16 { Self::to_f32(*self).to_i16() } #[inline] fn to_u16(&self) -> u16 { Self::to_f32(*self).to_u16() } #[inline] fn to_i32(&self) -> i32 { Self::to_f32(*self).to_i32() } #[inline] fn to_u32(&self) -> u32 { Self::to_f32(*self).to_u32() } #[inline] fn to_f16(&self) -> f16 { *self } #[inline] fn to_f32(&self) -> f32 { Self::to_f32(*self) } #[inline] fn to_f64(&self) -> f64 { Self::to_f64(*self) } #[inline] fn to_bool(&self) -> bool { *self != f16::from_f32_const(0.0) } } impl ToElement for bf16 { #[inline] fn to_i64(&self) -> i64 { Self::to_f32(*self).to_i64() } #[inline] fn to_u64(&self) -> u64 { Self::to_f32(*self).to_u64() } #[inline] fn to_i8(&self) -> i8 { Self::to_f32(*self).to_i8() } #[inline] fn to_u8(&self) -> u8 { Self::to_f32(*self).to_u8() } #[inline] fn to_i16(&self) -> i16 { Self::to_f32(*self).to_i16() } #[inline] fn to_u16(&self) -> u16 { Self::to_f32(*self).to_u16() } #[inline] fn to_i32(&self) -> i32 { Self::to_f32(*self).to_i32() } #[inline] fn to_u32(&self) -> u32 { Self::to_f32(*self).to_u32() } #[inline] fn to_bf16(&self) -> bf16 { *self } #[inline] fn to_f32(&self) -> f32 { Self::to_f32(*self) } #[inline] fn to_f64(&self) -> f64 { Self::to_f64(*self) } #[inline] fn to_bool(&self) -> bool { *self != bf16::from_f32_const(0.0) } } #[cfg(feature = "cubecl")] impl ToElement for burn_std::flex32 { #[inline] fn to_i64(&self) -> i64 { Self::to_f32(*self).to_i64() } #[inline] fn to_u64(&self) -> u64 { Self::to_f32(*self).to_u64() } #[inline] fn to_i8(&self) -> i8 { Self::to_f32(*self).to_i8() } #[inline] fn to_u8(&self) -> u8 { Self::to_f32(*self).to_u8() } #[inline] fn to_i16(&self) -> i16 { Self::to_f32(*self).to_i16() } #[inline] fn to_u16(&self) -> u16 { Self::to_f32(*self).to_u16() } #[inline] fn to_i32(&self) -> i32 { Self::to_f32(*self).to_i32() } #[inline] fn to_u32(&self) -> u32 { Self::to_f32(*self).to_u32() } #[inline] fn to_f32(&self) -> f32 { Self::to_f32(*self) } #[inline] fn to_f64(&self) -> f64 { Self::to_f64(*self) } #[inline] fn to_bool(&self) -> bool { *self != burn_std::flex32::from_f32(0.0) } } impl ToElement for bool { #[inline] fn to_i64(&self) -> i64 { *self as i64 } #[inline] fn to_u64(&self) -> u64 { *self as u64 } #[inline] fn to_i8(&self) -> i8 { *self as i8 } #[inline] fn to_u8(&self) -> u8 { *self as u8 } #[inline] fn to_i16(&self) -> i16 { *self as i16 } #[inline] fn to_u16(&self) -> u16 { *self as u16 } #[inline] fn to_i32(&self) -> i32 { *self as i32 } #[inline] fn to_u32(&self) -> u32 { *self as u32 } #[inline] fn to_f32(&self) -> f32 { self.to_u8() as f32 } #[inline] fn to_f64(&self) -> f64 { self.to_u8() as f64 } #[inline] fn to_bool(&self) -> bool { *self } } mod tests { #[allow(unused_imports)] use super::*; #[test] fn to_element_float() { let f32_toolarge = 1e39f64; assert_eq!(f32_toolarge.to_f32(), f32::INFINITY); assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY); assert_eq!((f32::MAX as f64).to_f32(), f32::MAX); assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX); assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY); assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY); assert!((f64::NAN).to_f32().is_nan()); } #[test] #[should_panic] fn to_element_signed_to_u8_underflow() { let _x = (-1i8).to_u8(); } #[test] #[should_panic] fn to_element_signed_to_u16_underflow() { let _x = (-1i8).to_u16(); } #[test] #[should_panic] fn to_element_signed_to_u32_underflow() { let _x = (-1i8).to_u32(); } #[test] #[should_panic] fn to_element_signed_to_u64_underflow() { let _x = (-1i8).to_u64(); } #[test] #[should_panic] fn to_element_signed_to_u128_underflow() { let _x = (-1i8).to_u128(); } #[test] #[should_panic] fn to_element_signed_to_usize_underflow() { let _x = (-1i8).to_usize(); } #[test] #[should_panic] fn to_element_unsigned_to_u8_overflow() { let _x = 256.to_u8(); } #[test] #[should_panic] fn to_element_unsigned_to_u16_overflow() { let _x = 65_536.to_u16(); } #[test] #[should_panic] fn to_element_unsigned_to_u32_overflow() { let _x = 4_294_967_296u64.to_u32(); } #[test] #[should_panic] fn to_element_unsigned_to_u64_overflow() { let _x = 18_446_744_073_709_551_616u128.to_u64(); } #[test] fn to_element_int_to_float() { assert_eq!((-1).to_f32(), -1.0); assert_eq!((-1).to_f64(), -1.0); assert_eq!(255.to_f32(), 255.0); assert_eq!(65_535.to_f64(), 65_535.0); } #[test] fn to_element_float_to_int() { assert_eq!((-1.0).to_i8(), -1); assert_eq!(1.0.to_u8(), 1); assert_eq!(1.8.to_u16(), 1); assert_eq!(123.456.to_u32(), 123); } } ================================================ FILE: crates/burn-backend/src/element/mod.rs ================================================ //! Traits and helpers for working with element types and conversions. mod base; mod scalar; /// Tensor element casting. pub mod cast; pub use base::*; pub use scalar::*; ================================================ FILE: crates/burn-backend/src/element/scalar.rs ================================================ use burn_std::{DType, bf16, f16}; use num_traits::ToPrimitive; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use crate::{Element, ElementConversion}; /// A scalar element. #[derive(Clone, Copy, Debug)] #[allow(missing_docs)] pub enum Scalar { Float(f64), Int(i64), UInt(u64), Bool(bool), } impl Scalar { /// Creates a scalar with the specified data type. /// /// # Note /// [`QFloat`](DType::QFloat) scalars are represented as float for element-wise operations. pub fn new(value: E, dtype: &DType) -> Self { if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) { Self::Float(value.elem()) } else if dtype.is_int() { Self::Int(value.elem()) } else if dtype.is_uint() { Self::UInt(value.elem()) } else if dtype.is_bool() { Self::Bool(value.elem()) } else { unimplemented!("Scalar not supported for {dtype:?}") } } /// Converts and returns the converted element. pub fn elem(self) -> E { match self { Self::Float(x) => x.elem(), Self::Int(x) => x.elem(), Self::UInt(x) => x.elem(), Self::Bool(x) => x.elem(), } } /// Returns the exact integer value, if valid. pub fn try_as_integer(&self) -> Option { match self { Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())), Scalar::Int(_) | Scalar::UInt(_) => Some(*self), Scalar::Bool(x) => Some(Scalar::Int(*x as i64)), } } } macro_rules! impl_from_scalar { ($($ty:ty => $variant:ident),+ $(,)?) => { $( impl From<$ty> for Scalar { fn from(value: $ty) -> Self { Scalar::$variant(value.elem()) } } )+ }; } impl_from_scalar! { f64 => Float, f32 => Float, f16 => Float, bf16 => Float, i64 => Int, i32 => Int, i16 => Int, i8 => Int, u64 => UInt, u32 => UInt, u16 => UInt, u8 => UInt, bool => Bool, } // CubeCL requirement impl ToPrimitive for Scalar { fn to_i64(&self) -> Option { match self { Scalar::Float(x) => x.to_i64(), Scalar::UInt(x) => x.to_i64(), Scalar::Int(x) => Some(*x), Scalar::Bool(x) => Some(*x as i64), } } fn to_u64(&self) -> Option { match self { Scalar::Float(x) => x.to_u64(), Scalar::UInt(x) => Some(*x), Scalar::Int(x) => x.to_u64(), Scalar::Bool(x) => Some(*x as u64), } } fn to_f64(&self) -> Option { match self { Scalar::Float(x) => Some(*x), Scalar::UInt(x) => x.to_f64(), Scalar::Int(x) => x.to_f64(), Scalar::Bool(x) => (*x as u8).to_f64(), } } } ================================================ FILE: crates/burn-backend/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! This library provides the core types that define how Burn tensor data is represented, stored, and interpreted. #[macro_use] extern crate derive_new; extern crate alloc; mod data; pub use data::*; pub mod distribution; pub use distribution::*; pub mod element; pub use element::*; /// [`Backend`] trait and required types. pub mod backend; pub use backend::*; /// Backend tensor primitives and operations. pub mod tensor; // Re-exported types pub use burn_std::reader::*; // Useful so that backends don't have to add `burn_std` as a dependency. pub use burn_std::{ AllocationProperty, BoolDType, BoolStore, Bytes, DType, DeviceHandle, FloatDType, IntDType, bf16, f16, stream_id::StreamId, }; /// Shape definition. pub mod shape { pub use burn_std::shape::*; } pub use shape::*; /// Slice utilities. pub mod slice { pub use burn_std::{s, slice::*}; } pub use slice::*; /// Indexing utilities. pub mod indexing { pub use burn_std::indexing::*; } pub use indexing::*; /// Quantization data representation. pub mod quantization { pub use crate::tensor::quantization::*; pub use burn_std::quantization::{ BlockSize, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantStore, QuantValue, QuantizedBytes, }; } #[cfg(feature = "cubecl-wgpu")] mod cube_wgpu { use crate::backend::DeviceOps; use cubecl::wgpu::WgpuDevice; impl DeviceOps for WgpuDevice {} } #[cfg(feature = "cubecl-cuda")] mod cube_cuda { use crate::backend::DeviceOps; use cubecl::cuda::CudaDevice; impl DeviceOps for CudaDevice {} } #[cfg(feature = "cubecl-cpu")] mod cube_cpu { use crate::backend::DeviceOps; use cubecl::cpu::CpuDevice; impl DeviceOps for CpuDevice {} } #[cfg(feature = "cubecl-hip")] mod cube_hip { use crate::backend::DeviceOps; use cubecl::hip::AmdDevice; impl DeviceOps for AmdDevice {} } /// Convenience macro to link to the `burn-tensor` docs for this crate version. /// /// Usage: /// ```rust,ignore /// # use burn_backend::doc_tensor; /// doc_tensor!(); // Links to `Tensor` struct /// doc_tensor!("zeros"); // Links to `Tensor::zeros` method /// ``` #[macro_export] macro_rules! doc_tensor { () => { concat!( "[`Tensor`](https://docs.rs/burn-tensor/", env!("CARGO_PKG_VERSION"), "/burn_tensor/struct.Tensor.html)" ) }; ($method:literal) => { concat!( "[`Tensor::", $method, "`](", "https://docs.rs/burn-tensor/", env!("CARGO_PKG_VERSION"), "/burn_tensor/struct.Tensor.html#method.", $method, ")" ) }; } ================================================ FILE: crates/burn-backend/src/tensor/alias.rs ================================================ use crate::backend::Backend; // We provide some type aliases to improve the readability of using associated types without // having to use the disambiguation syntax. /// Device type used by the backend. pub type Device = ::Device; /// Float element type used by backend. pub type FloatElem = ::FloatElem; /// Integer element type used by backend. pub type IntElem = ::IntElem; /// Boolean element type used by backend. pub type BoolElem = ::BoolElem; /// Float tensor primitive type used by the backend. pub type FloatTensor = ::FloatTensorPrimitive; /// Integer tensor primitive type used by the backend. pub type IntTensor = ::IntTensorPrimitive; /// Boolean tensor primitive type used by the backend. pub type BoolTensor = ::BoolTensorPrimitive; /// Quantized tensor primitive type used by the backend. pub type QuantizedTensor = ::QuantizedTensorPrimitive; ================================================ FILE: crates/burn-backend/src/tensor/container.rs ================================================ use alloc::boxed::Box; use core::any::Any; #[cfg(not(feature = "std"))] use alloc::vec::Vec; #[cfg(not(feature = "std"))] use hashbrown::HashMap; #[cfg(feature = "std")] use std::collections::HashMap; use crate::{TensorPrimitive, backend::Backend}; /// Contains tensor of arbitrary dimension. #[derive(Debug)] pub struct TensorContainer { tensors: HashMap>, } impl Default for TensorContainer where ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { fn default() -> Self { Self::new() } } impl TensorContainer where ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug, { /// Create an empty container. pub fn new() -> Self { Self { tensors: HashMap::new(), } } /// Get a tensor with the given ID. pub fn get(&self, id: &ID) -> Option> where B: Backend, { let grad = self.tensors.get(id)?; let tensor = grad .downcast_ref::>() // .map(|primitive| Tensor::::from_primitive(primitive.clone())) .unwrap(); Some(tensor.clone()) } /// Register a new tensor for the given ID. /// /// # Notes /// /// If a tensor is already registered for the given ID, it will be replaced. pub fn register(&mut self, id: ID, value: TensorPrimitive) where B: Backend, { self.tensors.insert(id, Box::new(value)); } /// Remove a tensor for the given ID and returns it. pub fn remove(&mut self, id: &ID) -> Option> where B: Backend, { self.tensors .remove(id) .map(|item| *item.downcast::>().unwrap()) // .map(|primitive| Tensor::from_primitive(*primitive)) } /// The number of tensors registered. pub fn len(&self) -> usize { self.tensors.len() } /// If any tensor is contained. pub fn is_empty(&self) -> bool { self.len() == 0 } /// Get id of every tensor in the container pub fn ids(&self) -> Vec<&ID> { self.tensors.keys().collect() } } ================================================ FILE: crates/burn-backend/src/tensor/kind.rs ================================================ use crate::{Backend, TensorMetadata, TensorPrimitive}; /// A type-level representation of the kind of a float tensor #[derive(Clone, Debug)] pub struct Float; /// A type-level representation of the kind of a int tensor. #[derive(Clone, Debug)] pub struct Int; /// A type-level representation of the kind of a bool tensor. #[derive(Clone, Debug)] pub struct Bool; /// A type-level representation of the kind of a tensor. /// Metadata access is lazy. pub trait TensorKind: Clone + core::fmt::Debug { /// The primitive type of the tensor. type Primitive: TensorMetadata; /// The name of the tensor kind. fn name() -> &'static str; } impl TensorKind for Float { type Primitive = TensorPrimitive; fn name() -> &'static str { "Float" } } impl TensorKind for Int { type Primitive = B::IntTensorPrimitive; fn name() -> &'static str { "Int" } } impl TensorKind for Bool { type Primitive = B::BoolTensorPrimitive; fn name() -> &'static str { "Bool" } } ================================================ FILE: crates/burn-backend/src/tensor/mod.rs ================================================ mod alias; mod container; mod kind; mod ops; pub use alias::*; pub use container::*; pub use kind::*; pub use ops::*; /// Tensor quantization module. pub mod quantization; ================================================ FILE: crates/burn-backend/src/tensor/ops/autodiff.rs ================================================ use crate::{ AutodiffBackend, tensor::{BasicOps, TensorKind}, }; /// Trait that list all operations that can be applied on all tensors on an autodiff backend. /// /// # Warnings /// /// This is an internal trait, use the public API provided by the #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// struct. pub trait BasicAutodiffOps: BasicOps + BasicOps { /// Inner primitive tensor. type InnerKind: BasicOps; /// Returns the inner tensor without the autodiff information. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("inner"))] #[cfg_attr(not(doc), doc = "`Tensor::inner`")] /// function, which is more high-level and designed for public use. fn inner( tensor: >::Primitive, ) -> >::Primitive; /// Convert a tensor to the autodiff backend. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("from_inner"))] #[cfg_attr(not(doc), doc = "`Tensor::from_inner`")] /// function, which is more high-level and designed for public use. fn from_inner( inner: >::Primitive, ) -> >::Primitive; } ================================================ FILE: crates/burn-backend/src/tensor/ops/base.rs ================================================ use alloc::vec::Vec; use burn_std::{DType, Shape, Slice}; use crate::{ Backend, ExecutionError, Scalar, TensorData, TensorMetadata, element::Element, ops::TransactionPrimitive, tensor::{IndexingUpdateOp, IntTensor, TensorKind}, }; /// Trait that list all operations that can be applied on all tensors. /// /// # Warnings /// /// This is an internal trait, use the public API provided by the #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// struct. pub trait BasicOps: TensorKind { /// The type of the tensor elements. type Elem: Element; /// Creates an empty tensor with the given shape. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device on which the tensor will be allocated. /// * `dtype` - The target data type. /// /// # Returns /// /// The empty tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating empty tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("empty"))] #[cfg_attr(not(doc), doc = "`Tensor::empty`")] /// function, which is more high-level and designed for public use. fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive; /// Creates a tensor filled with zeros. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device on which the tensor will be allocated. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor filled with zeros. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating a tensor filled with zeros, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("zeros"))] #[cfg_attr(not(doc), doc = "`Tensor::zeros`")] /// function, which is more high-level and designed for public use. fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive; /// Creates a tensor filled with ones. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device on which the tensor will be allocated. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor filled with ones. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating a tensor filled with ones, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("ones"))] #[cfg_attr(not(doc), doc = "`Tensor::ones`")] /// function, which is more high-level and designed for public use. fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive; /// Creates a tensor of the given shape where each element is equal to the provided value. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `fill_value` - The value with which to fill the tensor. /// * `device` - The device on which the tensor will be allocated. /// * `dtype` - The target data type. /// /// # Returns /// /// The tensor filled with the specified value. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating full tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("full"))] #[cfg_attr(not(doc), doc = "`Tensor::full`")] /// function, which is more high-level and designed for public use. fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive; /// Reshapes the tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `shape` - The new shape of the tensor. /// /// # Returns /// /// The reshaped tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For reshaping a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("reshape"))] #[cfg_attr(not(doc), doc = "`Tensor::reshape`")] /// function, which is more high-level and designed for public use. fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive; /// Transposes a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. fn transpose(tensor: Self::Primitive) -> Self::Primitive; /// Swaps two dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap. /// * `dim2` - The second dimension to swap. /// /// # Returns /// /// The tensor with the dimensions swapped. fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive; /// Permutes the dimensions of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to permute the dimensions of. /// * `axes` - The new order of the dimensions. /// /// # Returns /// /// The tensor with the dimensions permuted. fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive; /// Flips the tensor along the given axes. /// /// # Arguments /// /// * `tensor` - The tensor to flip. /// * `axes` - The axes to flip the tensor along. /// /// # Returns /// /// The tensor with the axes flipped. fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive; /// Select tensor elements corresponding to the given slices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `slices` - The slices specifying ranges and steps for each dimension. /// /// # Returns /// /// The selected elements. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For selecting elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("slice"))] #[cfg_attr(not(doc), doc = "`Tensor::slice`")] /// function, which is more high-level and designed for public use. fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive; /// Assigns the given value to the tensor elements corresponding to the given slices. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `slices` - The slices specifying which elements to assign, including support for steps. /// * `value` - The value to assign. /// /// # Returns /// /// The tensor with the assigned values. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For assigning values to elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("slice_assign"))] #[cfg_attr(not(doc), doc = "`Tensor::slice_assign`")] /// function, which is more high-level and designed for public use. fn slice_assign( tensor: Self::Primitive, slices: &[Slice], value: Self::Primitive, ) -> Self::Primitive; /// Select tensor elements along the given dimension corresponding to the given indices. /// /// # Arguments /// /// * `tensor` - The tensor to select from. /// * `dim` - The dimension along which to select. /// * `indices` - The indices of the elements to select. /// /// # Returns /// /// The selected tensor elements. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For selecting elements from a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("select"))] #[cfg_attr(not(doc), doc = "`Tensor::select`")] /// function, which is more high-level and designed for public use. fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor) -> Self::Primitive; /// Assign the selected elements along the given dimension corresponding to the given indices /// from the value tensor. /// /// # Arguments /// /// * `tensor` - The tensor to assign elements to. /// * `dim` - The axis along which to assign elements. /// * `indices` - The indices of the elements to assign. /// * `values` - The values to assign to the tensor. /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add). /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is taken from the /// corresponding element of the input tensor at the corresponding index along the specified axis, /// except for the elements at the specified indices, which are taken from the corresponding /// element of the values tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For assigning elements to a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("select_assign"))] #[cfg_attr(not(doc), doc = "`Tensor::select_assign`")] /// function, which is more high-level and designed for public use. fn select_assign( tensor: Self::Primitive, dim: usize, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive; /// Selects elements from a tensor based on a boolean mask. /// /// # Arguments /// /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true. /// * `mask` - The boolean mask to use for selecting elements. /// * `source` - The tensor to select elements from when the corresponding element of the mask is false. /// /// # Returns /// /// A tensor with the same shape as the input tensors, where each element is taken from the /// corresponding element of the left hand side tensor if the corresponding element of the mask /// is true, and from the corresponding element of the right hand side tensor otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For selecting elements from a tensor based on a boolean mask, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mask_where"))] #[cfg_attr(not(doc), doc = "`Tensor::mask_where`")] /// function, which is more high-level and designed for public use. fn mask_where( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, source: Self::Primitive, ) -> Self::Primitive; /// Fills elements of a tensor based on a boolean mask. /// /// # Arguments /// /// * `tensor` - The tensor where will be overwritten with the value /// when the corresponding element of the mask is true. /// * `mask` - The boolean mask to use for filling elements. /// * `value` - The value to fill elements with when the corresponding element of the mask is true. /// /// # Returns /// /// A tensor with the same shape as the input tensors, where each element is taken from the /// corresponding element unmodified if the corresponding element of the mask is false, and /// filled with the value otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For filling elements of a tensor based on a boolean mask, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mask_fill"))] #[cfg_attr(not(doc), doc = "`Tensor::mask_fill`")] /// function, which is more high-level and designed for public use. fn mask_fill( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, value: Scalar, ) -> Self::Primitive; /// Gathers elements from a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to gather elements. /// * `tensor` - The tensor to gather elements from. /// * `indices` - The indices of the elements to gather. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is taken from the /// corresponding element of the input tensor at the corresponding index along the specified axis. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For gathering elements from a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("gather"))] #[cfg_attr(not(doc), doc = "`Tensor::gather`")] /// function, which is more high-level and designed for public use. fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor) -> Self::Primitive; /// Scatters elements into a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to scatter elements. /// * `tensor` - The tensor to scatter elements into. /// * `indices` - The indices of the elements to scatter. /// * `values` - The values to scatter into the tensor. /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add). /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is taken from the /// corresponding element of the input tensor at the corresponding index along the specified axis, /// except for the elements at the specified indices, which are taken from the corresponding /// element of the values tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For scattering elements into a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("scatter"))] #[cfg_attr(not(doc), doc = "`Tensor::scatter`")] /// function, which is more high-level and designed for public use. fn scatter( dim: usize, tensor: Self::Primitive, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive; /// Returns the device on which the tensor is allocated. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The device on which the tensor is allocated. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the device of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("device"))] #[cfg_attr(not(doc), doc = "`Tensor::device`")] /// function, which is more high-level and designed for public use. fn device(tensor: &Self::Primitive) -> B::Device; /// Moves the tensor to the given device. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `device` - The device on which the tensor will be moved. /// /// # Returns /// /// The tensor on the given device. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For moving a tensor to a device, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("to_device"))] #[cfg_attr(not(doc), doc = "`Tensor::to_device`")] /// function, which is more high-level and designed for public use. #[allow(clippy::wrong_self_convention)] fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive; /// Extracts the data from the tensor asynchronously. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The data of the tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For extracting the data of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("into_data"))] #[cfg_attr(not(doc), doc = "`Tensor::into_data`")] /// function, which is more high-level and designed for public use. #[allow(clippy::wrong_self_convention)] fn into_data_async( tensor: Self::Primitive, ) -> impl Future> + Send; /// Read the data from the tensor using a transaction. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. fn register_transaction(tr: &mut TransactionPrimitive, tensor: Self::Primitive); /// Creates a tensor from the given data. /// /// # Arguments /// /// * `data` - The data of the tensor. /// * `device` - The device on which the tensor will be allocated. /// /// # Returns /// /// The tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating a tensor from data, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("from_data"))] #[cfg_attr(not(doc), doc = "`Tensor::from_data`")] /// function, which is more high-level and designed for public use. fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive; /// Creates a tensor from the given data enforcing the given data type. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For creating a tensor from data, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("from_data_dtype"))] #[cfg_attr(not(doc), doc = "`Tensor::from_data_dtype`")] /// function, which is more high-level and designed for public use. fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive; /// Repeat the tensor along the given dimension. /// /// # Arguments /// /// * `tensor` - The tensor. /// * `dim` - The dimension along which the tensor will be repeated. /// * `times` - The number of times the tensor will be repeated. /// /// # Returns /// /// The repeated tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For repeating a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("repeat_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::repeat_dim`")] /// function, which is more high-level and designed for public use. fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive; /// Concatenates the given tensors along the given dimension. /// /// # Arguments /// /// * `vectors` - The tensors to concatenate. /// * `dim` - The dimension along which the tensors will be concatenated. /// /// # Returns /// /// The concatenated tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For concatenating tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("cat"))] #[cfg_attr(not(doc), doc = "`Tensor::cat`")] /// function, which is more high-level and designed for public use. fn cat(vectors: Vec, dim: usize) -> Self::Primitive; /// Equates the given tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor of booleans indicating whether the corresponding elements are equal. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For equating tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("equal"))] #[cfg_attr(not(doc), doc = "`Tensor::equal`")] /// function, which is more high-level and designed for public use. fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise equality between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding elements of the input tensors are equal, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise equality between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("equal_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::equal_elem`")] /// function, which is more high-level and designed for public use. fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Applies element-wise non-equality comparison between the given tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The tensor of booleans indicating whether the corresponding elements are equal. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For non-equality comparison of tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal"))] #[cfg_attr(not(doc), doc = "`Tensor::not_equal`")] /// function, which is more high-level and designed for public use. fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise non-equality between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding elements of the input tensors are equal, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise non-equality between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("not_equal_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::not_equal_elem`")] /// function, which is more high-level and designed for public use. fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Returns the name of the element type. fn elem_type_name() -> &'static str { core::any::type_name::() } /// Returns the tensor data type. fn dtype(tensor: &Self::Primitive) -> DType { tensor.dtype() } /// Tests if any element in the `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("any"))] #[cfg_attr(not(doc), doc = "`Tensor::any`")] /// function, which is more high-level and designed for public use. fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive; /// Tests if any element in the tensor evaluates to True along a given dimension dim. /// /// # Arguments /// /// * tensor - The tensor to test. /// * dim - The axis along which to test. /// /// # Returns /// /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1. /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("any_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::any_dim`")] /// function, which is more high-level and designed for public use. fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive; /// Tests if all elements in the `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("all"))] #[cfg_attr(not(doc), doc = "`Tensor::all`")] /// function, which is more high-level and designed for public use. fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive; /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. /// /// # Returns /// /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1. /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("all_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::all_dim`")] /// function, which is more high-level and designed for public use. fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive; /// Broadcasts the given tensor to the specified shape. /// /// # Arguments /// /// * `tensor` - The tensor to broadcast. /// * `shape` - The shape to broadcast to. /// /// # Returns /// /// The broadcasted tensor. fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive; /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Warning /// /// For the `ndarray` and `candle` backends; this is not a view but a full copy. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the dimension to unfold. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with shape ``[pre=..., windows, post=..., size]``. fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive; } ================================================ FILE: crates/burn-backend/src/tensor/ops/bool.rs ================================================ use alloc::vec::Vec; use burn_std::{DType, Shape, Slice}; use crate::{ AutodiffBackend, Backend, ExecutionError, Scalar, TensorData, element::Element, ops::TransactionPrimitive, tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind}, }; impl BasicOps for Bool { type Elem = B::BoolElem; fn empty(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { if dtype != Self::Elem::dtype() { panic!("Expected bool data type, got {dtype:?}"); } B::bool_empty(shape, device) } fn zeros(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { if dtype != Self::Elem::dtype() { panic!("Expected bool data type, got {dtype:?}"); } B::bool_zeros(shape, device) } fn ones(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { if dtype != Self::Elem::dtype() { panic!("Expected bool data type, got {dtype:?}"); } B::bool_ones(shape, device) } fn full(shape: Shape, fill_value: Scalar, device: &Device, dtype: DType) -> Self::Primitive { if dtype != Self::Elem::dtype() { panic!("Expected bool data type, got {dtype:?}"); } if fill_value.elem() { B::bool_ones(shape, device) } else { B::bool_zeros(shape, device) } } fn register_transaction(tr: &mut TransactionPrimitive, tensor: Self::Primitive) { tr.register_bool(tensor); } fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::bool_reshape(tensor, shape) } fn transpose(tensor: Self::Primitive) -> Self::Primitive { B::bool_transpose(tensor) } fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive { B::bool_swap_dims(tensor, dim1, dim2) } fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive { B::bool_slice(tensor, slices) } fn slice_assign( tensor: Self::Primitive, slices: &[Slice], value: Self::Primitive, ) -> Self::Primitive { B::bool_slice_assign(tensor, slices, value) } fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor) -> Self::Primitive { B::bool_select(tensor, dim, indices) } fn select_assign( tensor: Self::Primitive, dim: usize, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { match update { IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values), } } fn mask_where( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, source: Self::Primitive, ) -> Self::Primitive { B::bool_mask_where(tensor, mask, source) } fn mask_fill( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, value: Scalar, ) -> Self::Primitive { B::bool_mask_fill(tensor, mask, value) } fn gather( dim: usize, tensor: Self::Primitive, indices: B::IntTensorPrimitive, ) -> Self::Primitive { B::bool_gather(dim, tensor, indices) } fn scatter( dim: usize, tensor: Self::Primitive, indices: B::IntTensorPrimitive, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { match update { IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values), } } fn device(tensor: &Self::Primitive) -> Device { B::bool_device(tensor) } fn to_device(tensor: Self::Primitive, device: &Device) -> Self::Primitive { B::bool_to_device(tensor, device) } async fn into_data_async(tensor: Self::Primitive) -> Result { B::bool_into_data(tensor).await } fn from_data(data: TensorData, device: &Device) -> Self::Primitive { B::bool_from_data(data.convert::(), device) } fn from_data_dtype(data: TensorData, device: &Device, _dtype: DType) -> Self::Primitive { // Bool tensors have exactly one representation per backend, so the // requested dtype is irrelevant. Convert to `B::BoolElem` directly. B::bool_from_data(data.convert::(), device) } fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive { B::bool_repeat_dim(tensor, dim, times) } fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::bool_equal(lhs, rhs) } fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::bool_not_equal(lhs, rhs) } fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::bool_equal_elem(lhs, rhs) } fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::bool_not_equal_elem(lhs, rhs) } fn cat(vectors: Vec, dim: usize) -> Self::Primitive { B::bool_cat(vectors, dim) } fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive { B::bool_any(tensor) } fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive { B::bool_any_dim(tensor, dim) } fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive { B::bool_all(tensor) } fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive { B::bool_all_dim(tensor, dim) } fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::bool_permute(tensor, axes) } fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::bool_expand(tensor, shape) } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::bool_flip(tensor, axes) } fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive { B::bool_unfold(tensor, dim, size, step) } } impl BasicAutodiffOps for Bool { type InnerKind = Bool; fn inner( tensor: >::Primitive, ) -> ::InnerBackend>>::Primitive { B::bool_inner(tensor) } fn from_inner( inner: ::InnerBackend>>::Primitive, ) -> >::Primitive { B::bool_from_inner(inner) } } ================================================ FILE: crates/burn-backend/src/tensor/ops/float.rs ================================================ use alloc::vec::Vec; use burn_std::{DType, Shape, Slice}; use crate::{ AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorPrimitive, ops::TransactionPrimitive, tensor::{ BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered, TensorKind, }, }; macro_rules! q_bin_ops { ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => { match ($lhs, $rhs) { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::$op(lhs, rhs)) } (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::$q_op(lhs, rhs), (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::$op(B::dequantize(lhs), rhs)) } (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => { TensorPrimitive::Float(B::$op(lhs, B::dequantize(rhs))) } } }; } impl BasicOps for Float { type Elem = B::FloatElem; fn empty(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { TensorPrimitive::Float(B::float_empty(shape, device, dtype.into())) } fn zeros(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into())) } fn ones(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { TensorPrimitive::Float(B::float_ones(shape, device, dtype.into())) } fn full(shape: Shape, fill_value: Scalar, device: &Device, dtype: DType) -> Self::Primitive { TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into())) } fn register_transaction(tr: &mut TransactionPrimitive, tensor: Self::Primitive) { tr.register_float(tensor); } fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_reshape(tensor, shape)) } TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)), } } fn transpose(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)), } } fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2)) } } } fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_slice(tensor, slices)) } TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)), } } fn slice_assign( tensor: Self::Primitive, slices: &[Slice], value: Self::Primitive, ) -> Self::Primitive { TensorPrimitive::Float(B::float_slice_assign( tensor.tensor(), slices, value.tensor(), )) } fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_select(tensor, dim, indices)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_select(tensor, dim, indices)) } } } fn select_assign( tensor: Self::Primitive, dim: usize, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { // Select assign is ambiguous for QFloat match update { IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add( tensor.tensor(), dim, indices, values.tensor(), )), } } fn mask_where( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, source: Self::Primitive, ) -> Self::Primitive { TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor())) } fn mask_fill( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, value: Scalar, ) -> Self::Primitive { TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value)) } fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_gather(dim, tensor, indices)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices)) } } } fn scatter( dim: usize, tensor: Self::Primitive, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { match update { IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add( dim, tensor.tensor(), indices, values.tensor(), )), } } fn device(tensor: &Self::Primitive) -> Device { match tensor { TensorPrimitive::Float(tensor) => B::float_device(tensor), TensorPrimitive::QFloat(tensor) => B::q_device(tensor), } } fn to_device(tensor: Self::Primitive, device: &Device) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_to_device(tensor, device)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_to_device(tensor, device)) } } } async fn into_data_async(tensor: Self::Primitive) -> Result { match tensor { TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await, TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await, } } fn from_data(data: TensorData, device: &Device) -> Self::Primitive { match &data.dtype { DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)), _ => TensorPrimitive::Float(B::float_from_data(data.convert::(), device)), } } fn from_data_dtype(data: TensorData, device: &Device, dtype: DType) -> Self::Primitive { match dtype { DType::QFloat(_scheme) => { TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device)) } _ if dtype.is_float() => { TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device)) } _ => panic!("Expected float dtype, got {dtype:?}"), } } fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times)) } } } fn cat(vectors: Vec, dim: usize) -> Self::Primitive { match vectors.first().unwrap() { TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat( vectors.into_iter().map(|tensor| tensor.tensor()).collect(), dim, )), TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat( vectors .into_iter() .map(|tensor| { if let TensorPrimitive::QFloat(t) = tensor { t } else { panic!("Concatenation only works with vector of QFloat") } }) .collect(), dim, )), } } fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_equal(lhs.tensor(), rhs.tensor()) } fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_not_equal(lhs.tensor(), rhs.tensor()) } fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_equal_elem(lhs.tensor(), rhs) } fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_not_equal_elem(lhs.tensor(), rhs) } fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive { B::float_any(tensor.tensor()) } fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive { B::float_any_dim(tensor.tensor(), dim) } fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive { B::float_all(tensor.tensor()) } fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive { B::float_all_dim(tensor.tensor(), dim) } fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_permute(tensor, axes)) } TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)), } } fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape)) } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)), } } fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive { TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step)) } } impl Numeric for Float { fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { q_bin_ops!(lhs, rhs, float_add, q_add) } fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { match lhs { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs), } } fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { q_bin_ops!(lhs, rhs, float_sub, q_sub) } fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { match lhs { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs), } } fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { q_bin_ops!(lhs, rhs, float_div, q_div) } fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { match lhs { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs), } } fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor())) } fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs)) } fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { q_bin_ops!(lhs, rhs, float_mul, q_mul) } fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { match lhs { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs), } } fn neg(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)), TensorPrimitive::QFloat(tensor) => B::q_neg(tensor), } } fn sum(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)), TensorPrimitive::QFloat(tensor) => B::q_sum(tensor), } } fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)), TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim), } } fn prod(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)), TensorPrimitive::QFloat(tensor) => B::q_prod(tensor), } } fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_prod_dim(tensor, dim)) } TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim), } } fn mean(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)), TensorPrimitive::QFloat(tensor) => B::q_mean(tensor), } } fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_mean_dim(tensor, dim)) } TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim), } } fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)), TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim), } } fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)), TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim), } } fn abs(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)), } } fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { q_bin_ops!(lhs, rhs, float_powf, q_powf) } fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { match lhs { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs), } } fn random(shape: Shape, distribution: Distribution, device: &Device) -> Self::Primitive { TensorPrimitive::Float(B::float_random(shape, distribution, device)) } fn sign(tensor: Self::Primitive) -> Self::Primitive { TensorPrimitive::Float(B::float_sign(tensor.tensor())) } /// Applies the matrix multiplication operation. /// /// `C = AB` /// /// # Panics /// /// If the two tensors don't have a compatible shape. fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { match (lhs, rhs) { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_matmul(lhs, rhs)) } (lhs, rhs) => B::q_matmul(lhs, rhs), } } } impl Ordered for Float { fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_sort(tensor, dim, descending)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending)) } } } fn sort_with_indices( tensor: Self::Primitive, dim: usize, descending: bool, ) -> (Self::Primitive, IntTensor) { match tensor { TensorPrimitive::Float(tensor) => { let (values, indices) = B::float_sort_with_indices(tensor, dim, descending); (TensorPrimitive::Float(values), indices) } TensorPrimitive::QFloat(tensor) => { let (values, indices) = B::q_sort_with_indices(tensor, dim, descending); (TensorPrimitive::QFloat(values), indices) } } } fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor { match tensor { TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending), TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending), } } fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)), TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim), } } fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)), TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim), } } fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_greater(lhs.tensor(), rhs.tensor()) } fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_greater_elem(lhs.tensor(), rhs) } fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_greater_equal(lhs.tensor(), rhs.tensor()) } fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_greater_equal_elem(lhs.tensor(), rhs) } fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_lower(lhs.tensor(), rhs.tensor()) } fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_lower_elem(lhs.tensor(), rhs) } fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::float_lower_equal(lhs.tensor(), rhs.tensor()) } fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::float_lower_equal_elem(lhs.tensor(), rhs) } fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor { match tensor { TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim), TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim), } } fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor { match tensor { TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim), TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim), } } fn max(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)), } } fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)), } } fn max_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, IntTensor) { match tensor { TensorPrimitive::Float(tensor) => { let (values, indices) = B::float_max_dim_with_indices(tensor, dim); (TensorPrimitive::Float(values), indices) } TensorPrimitive::QFloat(tensor) => { let (values, indices) = B::q_max_dim_with_indices(tensor, dim); (TensorPrimitive::QFloat(values), indices) } } } fn min(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)), } } fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)), } } fn min_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, IntTensor) { match tensor { TensorPrimitive::Float(tensor) => { let (values, indices) = B::float_min_dim_with_indices(tensor, dim); (TensorPrimitive::Float(values), indices) } TensorPrimitive::QFloat(tensor) => { let (values, indices) = B::q_min_dim_with_indices(tensor, dim); (TensorPrimitive::QFloat(values), indices) } } } fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp(tensor, min, max)) } TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max), } } fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp_min(tensor, min)) } TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min), } } fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp_max(tensor, max)) } TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max), } } fn max_abs(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)), } } fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim)) } } } } impl BasicAutodiffOps for Float { type InnerKind = Float; fn inner( tensor: >::Primitive, ) -> ::InnerBackend>>::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)), } } fn from_inner( inner: ::InnerBackend>>::Primitive, ) -> >::Primitive { match inner { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)), TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)), } } } ================================================ FILE: crates/burn-backend/src/tensor/ops/int.rs ================================================ use alloc::vec::Vec; use burn_std::{DType, Shape, Slice}; use crate::{ AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, ops::TransactionPrimitive, tensor::{ BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric, Ordered, TensorKind, }, }; impl BasicOps for Int { type Elem = B::IntElem; fn empty(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { B::int_empty(shape, device, dtype.into()) } fn zeros(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { B::int_zeros(shape, device, dtype.into()) } fn ones(shape: Shape, device: &Device, dtype: DType) -> Self::Primitive { B::int_ones(shape, device, dtype.into()) } fn full(shape: Shape, fill_value: Scalar, device: &Device, dtype: DType) -> Self::Primitive { B::int_full(shape, fill_value, device, dtype.into()) } fn register_transaction(tr: &mut TransactionPrimitive, tensor: Self::Primitive) { tr.register_int(tensor); } fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::int_reshape(tensor, shape) } fn transpose(tensor: Self::Primitive) -> Self::Primitive { B::int_transpose(tensor) } fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive { B::int_swap_dims(tensor, dim1, dim2) } fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive { B::int_slice(tensor, slices) } fn slice_assign( tensor: Self::Primitive, slices: &[Slice], value: Self::Primitive, ) -> Self::Primitive { B::int_slice_assign(tensor, slices, value) } fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor) -> Self::Primitive { B::int_select(tensor, dim, indices) } fn select_assign( tensor: Self::Primitive, dim: usize, indices: IntTensor, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { match update { IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values), } } fn mask_where( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, source: Self::Primitive, ) -> Self::Primitive { B::int_mask_where(tensor, mask, source) } fn mask_fill( tensor: Self::Primitive, mask: B::BoolTensorPrimitive, value: Scalar, ) -> Self::Primitive { B::int_mask_fill(tensor, mask, value) } fn gather( dim: usize, tensor: Self::Primitive, indices: B::IntTensorPrimitive, ) -> Self::Primitive { B::int_gather(dim, tensor, indices) } fn scatter( dim: usize, tensor: Self::Primitive, indices: B::IntTensorPrimitive, values: Self::Primitive, update: IndexingUpdateOp, ) -> Self::Primitive { match update { IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values), } } fn device(tensor: &Self::Primitive) -> Device { B::int_device(tensor) } fn to_device(tensor: Self::Primitive, device: &Device) -> Self::Primitive { B::int_to_device(tensor, device) } async fn into_data_async(tensor: Self::Primitive) -> Result { B::int_into_data(tensor).await } fn from_data(data: TensorData, device: &Device) -> Self::Primitive { B::int_from_data(data.convert::(), device) } fn from_data_dtype(data: TensorData, device: &Device, dtype: DType) -> Self::Primitive { if !dtype.is_int() { panic!("Expected int dtype, got {dtype:?}") } B::int_from_data(data.convert_dtype(dtype), device) } fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive { B::int_repeat_dim(tensor, dim, times) } fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor { B::int_equal(lhs, rhs) } fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor { B::int_not_equal(lhs, rhs) } fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_equal_elem(lhs, rhs) } fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_not_equal_elem(lhs, rhs) } fn cat(vectors: Vec, dim: usize) -> Self::Primitive { B::int_cat(vectors, dim) } fn any(tensor: Self::Primitive) -> BoolTensor { B::int_any(tensor) } fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor { B::int_any_dim(tensor, dim) } fn all(tensor: Self::Primitive) -> BoolTensor { B::int_all(tensor) } fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor { B::int_all_dim(tensor, dim) } fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::int_permute(tensor, axes) } fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive { B::int_expand(tensor, shape) } fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive { B::int_flip(tensor, axes) } fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive { B::int_unfold(tensor, dim, size, step) } } impl Numeric for Int { fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_add(lhs, rhs) } fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_add_scalar(lhs, rhs) } fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_sub(lhs, rhs) } fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_sub_scalar(lhs, rhs) } fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_div(lhs, rhs) } fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_div_scalar(lhs, rhs) } fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_remainder(lhs, rhs) } fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_remainder_scalar(lhs, rhs) } fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_mul(lhs, rhs) } fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_mul_scalar(lhs, rhs) } fn neg(tensor: Self::Primitive) -> Self::Primitive { B::int_neg(tensor) } fn sum(tensor: Self::Primitive) -> Self::Primitive { B::int_sum(tensor) } fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_sum_dim(tensor, dim) } fn prod(tensor: Self::Primitive) -> Self::Primitive { B::int_prod(tensor) } fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_prod_dim(tensor, dim) } fn mean(tensor: Self::Primitive) -> Self::Primitive { B::int_mean(tensor) } fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_mean_dim(tensor, dim) } fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_cumsum(tensor, dim) } fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_cumprod(tensor, dim) } fn abs(tensor: Self::Primitive) -> Self::Primitive { B::int_abs(tensor) } fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_powi(lhs, rhs) } fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive { B::int_powi_scalar(lhs, rhs) } fn random(shape: Shape, distribution: Distribution, device: &Device) -> Self::Primitive { B::int_random(shape, distribution, device) } fn sign(tensor: Self::Primitive) -> Self::Primitive { B::int_sign(tensor) } /// Applies the matrix multiplication operation. /// /// `C = AB` /// /// # Panics /// /// If the two tensors don't have a compatible shape. fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive { B::int_matmul(lhs, rhs) } } impl Ordered for Int { fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive { B::int_sort(tensor, dim, descending) } fn sort_with_indices( tensor: Self::Primitive, dim: usize, descending: bool, ) -> (Self::Primitive, IntTensor) { B::int_sort_with_indices(tensor, dim, descending) } fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor { B::int_argsort(tensor, dim, descending) } fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_cummin(tensor, dim) } fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_cummax(tensor, dim) } fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::int_greater(lhs, rhs) } fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_greater_elem(lhs, rhs) } fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::int_greater_equal(lhs, rhs) } fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_greater_equal_elem(lhs, rhs) } fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::int_lower(lhs, rhs) } fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_lower_elem(lhs, rhs) } fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive { B::int_lower_equal(lhs, rhs) } fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive { B::int_lower_equal_elem(lhs, rhs) } fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor { B::int_argmax(tensor, dim) } fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor { B::int_argmin(tensor, dim) } fn max(tensor: Self::Primitive) -> Self::Primitive { B::int_max(tensor) } fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_max_dim(tensor, dim) } fn max_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, IntTensor) { B::int_max_dim_with_indices(tensor, dim) } fn max_abs(tensor: Self::Primitive) -> Self::Primitive { B::int_max_abs(tensor) } fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_max_abs_dim(tensor, dim) } fn min(tensor: Self::Primitive) -> Self::Primitive { B::int_min(tensor) } fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { B::int_min_dim(tensor, dim) } fn min_dim_with_indices( tensor: Self::Primitive, dim: usize, ) -> (Self::Primitive, IntTensor) { B::int_min_dim_with_indices(tensor, dim) } fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive { B::int_clamp(tensor, min, max) } fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive { B::int_clamp_min(tensor, min) } fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive { B::int_clamp_max(tensor, max) } } impl BasicAutodiffOps for Int { type InnerKind = Int; fn inner( tensor: >::Primitive, ) -> ::InnerBackend>>::Primitive { B::int_inner(tensor) } fn from_inner( inner: ::InnerBackend>>::Primitive, ) -> >::Primitive { B::int_from_inner(inner) } } ================================================ FILE: crates/burn-backend/src/tensor/ops/mod.rs ================================================ mod autodiff; mod base; mod bool; mod float; mod int; mod numeric; mod ordered; pub use autodiff::*; pub use base::*; pub use numeric::*; pub use ordered::*; /// Computation to be used to update the existing values in indexed assignment operations (scatter/select). #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum IndexingUpdateOp { // Assign, /// Performs an addition. Add, // Mul } ================================================ FILE: crates/burn-backend/src/tensor/ops/numeric.rs ================================================ use burn_std::Shape; use crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps}; /// Trait that list all operations that can be applied on all numerical tensors. /// /// # Warnings /// /// This is an internal trait, use the public API provided by the #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// struct. pub trait Numeric: BasicOps where Self::Elem: Element, { /// Adds two tensors together. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The sum of the two tensors. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For adding tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("add"))] #[cfg_attr(not(doc), doc = "`Tensor::add`")] /// function, which is more high-level and designed for public use. fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Adds a scalar to a tensor element-wise. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The sum of the tensor and the scalar. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For adding a scalar to a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("add_scalar"))] #[cfg_attr(not(doc), doc = "`Tensor::add_scalar`")] /// function, which is more high-level and designed for public use. fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Subtracts two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The difference of the two tensors. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For subtracting tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sub"))] #[cfg_attr(not(doc), doc = "`Tensor::sub`")] /// function, which is more high-level and designed for public use. fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Subtracts a scalar from a tensor element-wise. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The difference of the tensor and the scalar. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For subtracting a scalar from a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sub_scalar"))] #[cfg_attr(not(doc), doc = "`Tensor::sub_scalar`")] /// function, which is more high-level and designed for public use. fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Divides two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The quotient of the two tensors. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For dividing tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("div"))] #[cfg_attr(not(doc), doc = "`Tensor::div`")] /// function, which is more high-level and designed for public use. fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Divides a tensor by a scalar element-wise. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The quotient of the tensor and the scalar. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For dividing a tensor by a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("div_scalar"))] #[cfg_attr(not(doc), doc = "`Tensor::div_scalar`")] /// function, which is more high-level and designed for public use. fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is /// less than that of the divisor. /// /// # Arguments /// /// * `lhs` - The dividend. /// * `rhs` - The divisor. /// /// # Returns /// /// The modulo of the input tensor with the divisor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For performing the modulo operation, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("remainder"))] #[cfg_attr(not(doc), doc = "`Tensor::remainder`")] /// function, which is more high-level and designed for public use. fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is /// less than that of the divisor. /// /// # Arguments /// /// * `lhs` - The dividend. /// * `rhs` - The divisor. /// /// # Returns /// /// The modulo of the input tensor with the divisor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For performing the modulo operation, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("remainder_scalar"))] #[cfg_attr(not(doc), doc = "`Tensor::remainder_scalar`")] /// function, which is more high-level and designed for public use. fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Multiplies two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// The product of the two tensors. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For multiplying tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mul"))] #[cfg_attr(not(doc), doc = "`Tensor::mul`")] /// function, which is more high-level and designed for public use. fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Multiplies a tensor by a scalar element-wise. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// The product of the tensor and the scalar. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For multiplying a tensor by a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mul_scalar"))] #[cfg_attr(not(doc), doc = "`Tensor::mul_scalar`")] /// function, which is more high-level and designed for public use. fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Negates a tensor. /// /// # Arguments /// /// * `tensor` - The tensor to negate. /// /// # Returns /// /// The negated tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For negating a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("neg"))] #[cfg_attr(not(doc), doc = "`Tensor::neg`")] /// function, which is more high-level and designed for public use. fn neg(tensor: Self::Primitive) -> Self::Primitive; /// Returns the signs of the elements of a tensor. /// /// # Arguments /// /// * `tensor` - The tensor. /// /// # Returns /// /// The signs of the elements of the tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the signs of the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sign"))] #[cfg_attr(not(doc), doc = "`Tensor::sign`")] /// function, which is more high-level and designed for public use. fn sign(tensor: Self::Primitive) -> Self::Primitive; /// Sums all the elements of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// /// # Returns /// /// The sum of all the elements of the tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For summing all the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sum"))] #[cfg_attr(not(doc), doc = "`Tensor::sum`")] /// function, which is more high-level and designed for public use. fn sum(tensor: Self::Primitive) -> Self::Primitive; /// Sums all the elements of the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to sum. /// * `dim` - The dimension along which to sum. /// /// # Returns /// /// The sum of all the elements of the tensor along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For summing all the elements of a tensor along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sum_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::sum_dim`")] /// function, which is more high-level and designed for public use. fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Computes the product of all the elements of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to compute the product of. /// /// # Returns /// /// The product of all the elements of the tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the product of all the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("prod"))] #[cfg_attr(not(doc), doc = "`Tensor::prod`")] /// function, which is more high-level and designed for public use. fn prod(tensor: Self::Primitive) -> Self::Primitive; /// Computes the product of all the elements of the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the product of. /// * `dim` - The dimension along which to compute the product. /// /// # Returns /// /// The product of all the elements of the tensor along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the product of all the elements of a tensor along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("prod_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::prod_dim`")] /// function, which is more high-level and designed for public use. fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Computes the mean of all the elements of the tensor. /// /// # Arguments /// /// * `tensor` - The tensor to compute the mean of. /// /// # Returns /// /// The mean of all the elements of the tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the mean of all the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mean"))] #[cfg_attr(not(doc), doc = "`Tensor::mean`")] /// function, which is more high-level and designed for public use. fn mean(tensor: Self::Primitive) -> Self::Primitive; /// Computes the mean of all the elements of the tensor along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the mean of. /// * `dim` - The dimension along which to compute the mean. /// /// # Returns /// /// The mean of all the elements of the tensor along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the mean of all the elements of a tensor along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("mean_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::mean_dim`")] /// function, which is more high-level and designed for public use. fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Computes the cumulative sum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative sum of. /// * `dim` - The dimension along which to compute the cumulative sum. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the cumulative sum /// of all elements up to and including that position along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the cumulative sum of elements along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("cumsum"))] #[cfg_attr(not(doc), doc = "`Tensor::cumsum`")] /// function, which is more high-level and designed for public use. fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Computes the cumulative product of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative product of. /// * `dim` - The dimension along which to compute the cumulative product. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the cumulative product /// of all elements up to and including that position along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the cumulative product of elements along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("cumprod"))] #[cfg_attr(not(doc), doc = "`Tensor::cumprod`")] /// function, which is more high-level and designed for public use. fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Calculate absolute value on all elements of a tensor /// /// # Arguments /// /// * `tensor` - The tensor to apply abs to. /// /// # Returns /// /// A tensor with absolute values. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For calculating abs of the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("abs"))] #[cfg_attr(not(doc), doc = "`Tensor::abs`")] /// function, which is more high-level and designed for public use. fn abs(tensor: Self::Primitive) -> Self::Primitive; /// Element-wise power of a tensor /// /// # Arguments /// * `tensor` - The tensor to apply power to. /// * `power` - The power to apply to the tensor. fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; /// Element-wise power of a tensor to a scalar int /// /// # Arguments /// * `tensor` - The tensor to apply power to. /// * `power` - The power to apply to the tensor. fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive; /// Create a random tensor. /// /// # Arguments /// /// * `shape` - The shape of the output tensor. /// * `distribution` - The distribution used to sample. /// * `device` - The device to use. /// /// # Returns /// /// A new tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("random"))] #[cfg_attr(not(doc), doc = "`Tensor::random`")] /// function, which is more high-level and designed for public use. fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive; /// Applies the matrix multiplication operation. /// /// ```math /// C = AB /// ``` fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive; } ================================================ FILE: crates/burn-backend/src/tensor/ops/ordered.rs ================================================ use crate::{ Backend, Scalar, tensor::{IntTensor, Numeric}, }; /// Trait that list all operations that can be applied on all numerical tensors /// whose elements have a well-defined ordering. /// /// This includes operations such as comparisons, minimum/maximum reductions, /// and other order-dependent computations that are not strictly valid for all numerical /// types. /// /// # Warnings /// /// This is an internal trait, use the public API provided by the #[cfg_attr(doc, doc = crate::doc_tensor!())] #[cfg_attr(not(doc), doc = "`Tensor`")] /// struct. pub trait Ordered: Numeric { /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where the elements are sorted by value. /// /// # Remarks /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sort"))] #[cfg_attr(not(doc), doc = "`Tensor::sort`")] /// function, which is more high-level and designed for public use. fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive; /// Sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// the elements are sorted by value and the indices map back to the original input tensor. /// /// # Remarks /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For sorting the elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("sort_with_indices"))] #[cfg_attr(not(doc), doc = "`Tensor::sort_with_indices`")] /// function, which is more high-level and designed for public use. fn sort_with_indices( tensor: Self::Primitive, dim: usize, descending: bool, ) -> (Self::Primitive, IntTensor); /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `tensor` - The input tensor. /// * `dim` - The axis along which to sort. /// * `descending` - The sorting order. /// /// # Returns /// /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. /// /// # Remarks /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// Users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("argsort"))] #[cfg_attr(not(doc), doc = "`Tensor::argsort`")] /// function, which is more high-level and designed for public use. fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor; /// Computes the cumulative minimum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative minimum of. /// * `dim` - The dimension along which to compute the cumulative minimum. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the minimum /// of all elements up to and including that position along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the cumulative minimum of elements along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("cummin"))] #[cfg_attr(not(doc), doc = "`Tensor::cummin`")] /// function, which is more high-level and designed for public use. fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Computes the cumulative maximum of elements along a dimension. /// /// # Arguments /// /// * `tensor` - The tensor to compute the cumulative maximum of. /// * `dim` - The dimension along which to compute the cumulative maximum. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the maximum /// of all elements up to and including that position along the specified dimension. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For computing the cumulative maximum of elements along a dimension, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("cummax"))] #[cfg_attr(not(doc), doc = "`Tensor::cummax`")] /// function, which is more high-level and designed for public use. fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Element-wise greater than comparison between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding element of the left hand side tensor is greater than the corresponding element /// of the right hand side tensor, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise greater than comparison between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("greater"))] #[cfg_attr(not(doc), doc = "`Tensor::greater`")] /// function, which is more high-level and designed for public use. fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise greater than comparison between a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensor, where each element is true if the /// corresponding element of the left hand side tensor is greater than the right hand side /// scalar, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise greater than comparison between a tensor and a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("greater_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::greater_elem`")] /// function, which is more high-level and designed for public use. fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Element-wise greater than or equal comparison between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding element of the left hand side tensor is greater than or equal to the /// corresponding element of the right hand side tensor, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise greater than or equal comparison between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal"))] #[cfg_attr(not(doc), doc = "`Tensor::greater_equal`")] /// function, which is more high-level and designed for public use. fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise greater than or equal comparison between a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensor, where each element is true if the /// corresponding element of the left hand side tensor is greater than or equal to the right /// hand side scalar, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("greater_equal_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::greater_equal_elem`")] /// function, which is more high-level and designed for public use. fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Element-wise less than comparison between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding element of the left hand side tensor is less than the corresponding element of /// the right hand side tensor, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise less than comparison between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("lower"))] #[cfg_attr(not(doc), doc = "`Tensor::lower`")] /// function, which is more high-level and designed for public use. fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise less than comparison between a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensor, where each element is true if the /// corresponding element of the left hand side tensor is less than the right hand side scalar, /// and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise less than comparison between a tensor and a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("lower_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::lower_elem`")] /// function, which is more high-level and designed for public use. fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Element-wise less than or equal comparison between two tensors. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side tensor. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors, where each element is true if the /// corresponding element of the left hand side tensor is less than or equal to the corresponding /// element of the right hand side tensor, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise less than or equal comparison between two tensors, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal"))] #[cfg_attr(not(doc), doc = "`Tensor::lower_equal`")] /// function, which is more high-level and designed for public use. fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive; /// Element-wise less than or equal comparison between a tensor and a scalar. /// /// # Arguments /// /// * `lhs` - The left hand side tensor. /// * `rhs` - The right hand side scalar. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensor, where each element is true if the /// corresponding element of the left hand side tensor is less than or equal to the right hand /// side scalar, and false otherwise. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("lower_equal_elem"))] #[cfg_attr(not(doc), doc = "`Tensor::lower_equal_elem`")] /// function, which is more high-level and designed for public use. fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive; /// Gets the indices of the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to get the indices of the maximum elements. /// * `tensor` - The tensor to get the indices of the maximum elements from. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the index of the /// maximum element of the input tensor at the corresponding index along the specified axis. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("argmax"))] #[cfg_attr(not(doc), doc = "`Tensor::argmax`")] /// function, which is more high-level and designed for public use. fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor; /// Gets the indices of the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to get the indices of the minimum elements. /// * `tensor` - The tensor to get the indices of the minimum elements from. /// /// # Returns /// /// A tensor with the same shape as the input tensor, where each element is the index of the /// minimum element of the input tensor at the corresponding index along the specified axis. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("argmin"))] #[cfg_attr(not(doc), doc = "`Tensor::argmin`")] /// function, which is more high-level and designed for public use. fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor; /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to get the maximum elements. /// /// # Returns /// /// A single-element tensor containing the maximum element of the input tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the maximum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("max"))] #[cfg_attr(not(doc), doc = "`Tensor::max`")] /// function, which is more high-level and designed for public use. fn max(tensor: Self::Primitive) -> Self::Primitive; /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements from. /// * `dim` - The axis along which to get the maximum elements. /// /// # Returns /// /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1. /// Each element is the maximum element of the corresponding input dim. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the maximum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::max_dim`")] /// function, which is more high-level and designed for public use. fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements from. /// * `dim` - The axis along which to get the maximum elements. /// /// # Returns /// /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape /// as the input tensor, where each element is the index of the maximum element of the input tensor /// at the corresponding index along the specified axis. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the maximum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("max_dim_with_indices"))] #[cfg_attr(not(doc), doc = "`Tensor::max_dim_with_indices`")] /// function, which is more high-level and designed for public use. fn max_dim_with_indices(tensor: Self::Primitive, dim: usize) -> (Self::Primitive, IntTensor); /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `dim` - The axis along which to get the maximum elements. /// /// # Returns /// /// A single-element tensor containing the maximum absolute element of the input tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the maximum absolute elements of a tensor, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs"))] #[cfg_attr(not(doc), doc = "`Tensor::max_abs`")] /// function, which is more high-level and designed for public use. fn max_abs(tensor: Self::Primitive) -> Self::Primitive; /// Gets the maximum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the maximum elements from. /// * `dim` - The axis along which to get the maximum elements. /// /// # Returns /// /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1. /// Each element is the maximum absolute element of the corresponding input dim. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the maximum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("max_abs_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::max_abs_dim`")] /// function, which is more high-level and designed for public use. fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Gets the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements from. /// /// # Returns /// /// A single-element tensor containing the minimum element of the input tensor. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the minimum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("min"))] #[cfg_attr(not(doc), doc = "`Tensor::min`")] /// function, which is more high-level and designed for public use. fn min(tensor: Self::Primitive) -> Self::Primitive; /// Gets the minimum elements of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements from. /// * `dim` - The axis along which to get the minimum elements. /// /// # Returns /// /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1. /// Each element is the minimum element of the corresponding input dim. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the minimum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim"))] #[cfg_attr(not(doc), doc = "`Tensor::min_dim`")] /// function, which is more high-level and designed for public use. fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; /// Gets the minimum elements and indices of a tensor along an axis. /// /// # Arguments /// /// * `tensor` - The tensor to get the minimum elements from. /// /// # Returns /// /// A tensor with the same shape as the input tensor and corresponding indices, where /// each element is the minimum element of the input tensor at the corresponding index /// along the specified axis. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. /// /// For getting the minimum elements of a tensor along an axis, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("min_dim_with_indices"))] #[cfg_attr(not(doc), doc = "`Tensor::min_dim_with_indices`")] /// function, which is more high-level and designed for public use. fn min_dim_with_indices(tensor: Self::Primitive, dim: usize) -> (Self::Primitive, IntTensor); /// Clamp the tensor between the given min and max values. /// /// # Arguments /// /// * `min` - The minimum value. /// * `max` - The maximum value. /// /// # Returns /// /// A new tensor with the values clamped between the given min and max values. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users. /// /// For clamping a tensor between the given min and max values, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("clamp"))] #[cfg_attr(not(doc), doc = "`Tensor::clamp`")] /// function, which is more high-level and designed for public use. fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive; /// Clamps a tensor under a minimum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// /// # Returns /// /// A new tensor with the values clamped under the given min value. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users. /// /// For clamping a tensor under a minimum value, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_min"))] #[cfg_attr(not(doc), doc = "`Tensor::clamp_min`")] /// function, which is more high-level and designed for public use. fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive; /// Clamps a tensor over a maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `max` - The maximum value. /// /// # Returns /// /// A new tensor with the values clamped over the given max value. /// /// # Remarks /// /// This is a low-level function used internally by the library to call different backend functions /// with static dispatch. It is not designed for direct usage by users. /// /// For clamping a tensor over a maximum value, users should prefer the #[cfg_attr(doc, doc = crate::doc_tensor!("clamp_max"))] #[cfg_attr(not(doc), doc = "`Tensor::clamp_max`")] /// function, which is more high-level and designed for public use. fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive; } ================================================ FILE: crates/burn-backend/src/tensor/quantization/calibration.rs ================================================ /// Calibration method used to compute the quantization range mapping. pub enum Calibration { /// Computes quantization range mapping based on the min and max values. MinMax, } ================================================ FILE: crates/burn-backend/src/tensor/quantization/mod.rs ================================================ mod calibration; mod parameters; mod scheme; pub use calibration::*; pub use parameters::*; pub use scheme::*; ================================================ FILE: crates/burn-backend/src/tensor/quantization/parameters.rs ================================================ use crate::Backend; pub use burn_std::quantization::{QParamTensor, QParams}; /// The quantization parameters primitive. /// /// # Remarks /// /// This is a low-level struct used internally by the library to provide the quantization parameters /// to the backends. It is not designed for direct usage by users, and not recommended to import /// or use this struct directly. pub struct QuantizationParametersPrimitive { /// The scaling factor. pub scales: B::FloatTensorPrimitive, } ================================================ FILE: crates/burn-backend/src/tensor/quantization/scheme.rs ================================================ pub use burn_std::{QPARAM_ALIGN, params_shape}; use burn_std::{QuantLevel, QuantMode, QuantScheme, Shape}; use super::{Calibration, QuantizationParametersPrimitive}; use crate::{Backend, TensorMetadata}; /// Compute the quantization range mapping. pub fn compute_range( scheme: &QuantScheme, tensor: B::FloatTensorPrimitive, calibration: &Calibration, ) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) { match calibration { Calibration::MinMax => match scheme.level { QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)), QuantLevel::Block(block_size) => { let block_elems = block_size.num_elements(); let shape = tensor.shape(); let numel = shape.num_elements(); assert_eq!( numel % block_elems, 0, "Tensor {shape:?} must be evenly divisible by block size {block_elems}" ); let num_blocks = numel / block_elems; let params_shape = params_shape(&shape, scheme.level); let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems])); let blocks_min = B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone()); let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape); (blocks_min, blocks_max) } }, } } /// Compute the quantization parameters. pub fn compute_q_params( scheme: &QuantScheme, min: B::FloatTensorPrimitive, max: B::FloatTensorPrimitive, ) -> QuantizationParametersPrimitive { match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, .. } => { // Quantized range `[a, b]` let (a, b) = scheme.value.range(); // Compute scale to convert an input value in range `[-alpha, alpha]` let min_abs = B::float_abs(min); let max_abs = B::float_abs(max); // `min_abs.max_pair(max_abs)` let mask = B::float_lower(min_abs.clone(), max_abs.clone()); let values_range = B::float_mul_scalar(B::float_mask_where(min_abs, mask, max_abs), 2f32.into()); QuantizationParametersPrimitive { scales: B::float_div_scalar(values_range, (b - a).into()), } } } } ================================================ FILE: crates/burn-backend-tests/.cargo/config.toml ================================================ [alias] test-cpu = "test --release --no-default-features --features cpu,std" test-cuda = "test --release --no-default-features --features cuda,std" test-ndarray = "test --release --no-default-features --features ndarray,std" test-rocm = "test --release --no-default-features --features rocm,std" test-router = "test --release --no-default-features --features router,std" test-tch = "test --release --no-default-features --features tch,std" test-wgpu = "test --release --no-default-features --features wgpu,std" test-vulkan = "test --release --no-default-features --features vulkan,std" test-metal = "test --release --no-default-features --features metal,std" ================================================ FILE: crates/burn-backend-tests/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Tensor tests for Burn backends" documentation = "https://docs.rs/burn-backend-tests" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-backend-tests" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-backend-tests" version.workspace = true [lints] workspace = true [features] default = [ "burn-tensor/default", "burn-autodiff/default", # Backends (default not enabled for CubeCL backends as it activates fusion) "burn-cpu?/default", "burn-ndarray?/default", "burn-tch?/default", # Default "ndarray", "std", ] std = [ "burn-tensor/std", "burn-autodiff/std", # Backends "burn-cpu?/std", "burn-ndarray?/std", "burn-wgpu?/std", "burn-router?/std", "burn-cuda?/std", "burn-rocm?/std", ] tracing = [ "cubecl?/tracing", "burn-tensor/tracing", "burn-autodiff/tracing", # Backends "burn-cpu?/tracing", "burn-ndarray?/tracing", "burn-wgpu?/tracing", "burn-router?/tracing", "burn-cuda?/tracing", "burn-rocm?/tracing", ] # Backends cuda = ["burn-cuda", "quantization", "cube"] rocm = ["burn-rocm", "quantization", "cube"] ndarray = ["burn-ndarray", "quantization"] tch = ["burn-tch"] vulkan = ["wgpu", "burn-wgpu/vulkan"] webgpu = ["wgpu", "burn-wgpu/webgpu"] metal = ["wgpu", "burn-wgpu/metal"] wgpu = ["burn-wgpu", "quantization", "cube"] cpu = ["burn-cpu", "cube"] router = ["burn-router", "ndarray", "burn-wgpu"] autotune = [ "burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-rocm?/autotune", "burn-cpu?/autotune", ] autotune-checks = [ "burn-wgpu?/autotune-checks", "burn-cuda?/autotune-checks", "burn-rocm?/autotune-checks", "burn-cpu?/autotune-checks", ] # CubeCL backends cube = [ "cubecl", "cubek", "autotune", "burn-fusion", "burn-cubecl", "burn-ndarray", ] # Test configs quantization = [] [dependencies] burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2" } # Backends burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", default-features = false, features = [ "export_tests", ] } burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false, features = [ "export_tests", ] } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } # To wrap `Fusion burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true, features = [ "fusion", ] } num-traits = { workspace = true } serial_test = { workspace = true } cubecl = { workspace = true, optional = true } cubek = { workspace = true, features = ["random"], optional = true } ================================================ FILE: crates/burn-backend-tests/README.md ================================================ # Burn Backend Tests This crate provides a comprehensive suite of tests for Burn backends, covering: - Tensor operations: [tests/tensor/](./tests/tensor/) - Autodiff: [tests/autodiff/](./tests/autodiff/) - (Optional) CubeCL kernels correctness: [tests/cubecl/](./tests/cubecl/) ## Running Tests The `TestBackend` is selected via feature flags. Use the provided shorthand commands for convenience: ```sh # Cpu cargo test-cpu # Cuda cargo test-cuda # Rocm cargo test-rocm # Wgpu / WebGpu cargo test-wgpu # Vulkan cargo test-vulkan # Metal cargo test-metal # Router cargo test-router # NdArray cargo test-ndarray # LibTorch cargo test-tch ``` By default, `cargo test` fail-fast across integration test binaries. When one integration test binary fails, Cargo does not run the remaining test binaries. If you want to run all test binaries regardless of failures, pass `--no-fail-fast`, for example: ```sh cargo test-cuda --no-fail-fast ``` ## Structure - `tests/tensor.rs`: Tensor tests - `tests/autodiff.rs`: Autodiff tests - `tests/fusion.rs`: Fusion backend tests wrapping tensor and autodiff tests - `tests/cubecl.rs`: CubeCL kernel tests Each test module assumes exactly one `FloatElemType`, `IntElemType`, and `TestBackend` in scope. ### Common Modules - `common/backend.rs`: Backend type definitions - `common/tensor.rs`: Reusable tensor test suite, split across float, int and bool tensor kinds - `common/autodiff.rs`: Reusable autodiff test suite, with and without checkpointing ### Test Reusability This crate uses a pattern of parameterized test modules to run the same tests with different configurations (backends, dtypes, etc.): 1. **Type aliases define the configuration**: Each test scope declares `FloatElemType`, `IntElemType`, and `TestBackend` 1. **`#[path = "..."]` references shared modules**: Points to test files outside the normal module hierarchy, e.g. `"common/tensor.rs"` 1. **`include!()` imports test code**: Test modules are included multiple times with different type configurations 1. **`use super::*;`** propagates types down the module tree: Each level re-exports parent types so deeply nested tests have access to the configured types For example, `common/tensor.rs` can be included with `FloatElemType = f32` for base tests, then included again with `FloatElemType = f16` for half-precision tests, running the same test suite twice with different dtypes. ## Adding New Tests Add test modules under `tests/tensor/`, `tests/autodiff/`, or `tests/cubecl` respectively. They will automatically run for all required configurations. For tensor tests, make sure to add the test to each relevant tensor kind: - `tensor/bool`: boolean tensor tests - `tensor/float`: float tensor tests - `tensor/int`: integer tensor tests **Guidelines:** Import types with `use super::*;` at the top of each module and use the types defined in `common/backend.rs`: ```rust /// Collection of types used across tests pub use burn_autodiff::Autodiff; pub use burn_tensor::Tensor; pub type TestBackend = ...; pub type TestTensor = Tensor; pub type TestTensorInt = Tensor; pub type TestTensorBool = Tensor; pub type FloatElem = burn_tensor::ops::FloatElem; pub type IntElem = burn_tensor::ops::IntElem; pub type TestAutodiffBackend = Autodiff; pub type TestAutodiffTensor = Tensor; ``` Tests will automatically run with default dtypes and any variants (f16, bf16, etc.) based on the backend configuration. ================================================ FILE: crates/burn-backend-tests/cubecl.toml ================================================ [profiling] logger = { file = "target/profiling.log", level = "disabled" } [autotune] logger = { file = "target/autotune.log", level = "disabled" } [compilation] logger = { file = "target/compilation.log", level = "disabled" } [memory] logger = { file = "target/memory.log", level = "disabled" } [streaming] max_streams = 4 ================================================ FILE: crates/burn-backend-tests/src/lib.rs ================================================ extern crate alloc; #[cfg(feature = "std")] pub use burn_tensor_testgen::might_panic; /// Generate a test module with custom floating element types. #[macro_export] macro_rules! test_float_elem_variant { ($modname:ident, $float:ty, $module:literal, [$($feat:literal),* $(,)?]) => { #[cfg(all(test, any($(feature = $feat),*)))] mod $modname { pub type FloatElemType = $float; #[allow(unused)] pub use super::IntElemType; mod ty { include!("backend.rs"); include!($module); } } }; } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/abs.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance, cast::ToElement}; #[test] fn should_diff_abs() { let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_abs_no_nans() { let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]); let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let contains_nan = grad_2.contains_nan(); assert!(!contains_nan.into_scalar().to_bool()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/adaptive_avgpool1d.rs ================================================ use super::*; use burn_tensor::module::adaptive_avg_pool1d; use burn_tensor::{Shape, Tolerance}; #[test] fn test_avg_pool1d_simple() { let test = AdaptiveAvgPool1dTestCase { batch_size: 1, channels: 2, length: 5, output_size: 3, }; test.assert_output(TestTensor::from_floats( [[ [0.5000, 0.83333, 0.33333, 0.83333, 0.5000], [0.5000, 0.83333, 0.33333, 0.83333, 0.5000], ]], &Default::default(), )); } struct AdaptiveAvgPool1dTestCase { batch_size: usize, channels: usize, length: usize, output_size: usize, } impl AdaptiveAvgPool1dTestCase { fn assert_output(self, x_grad: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels, self.length]); let device = Default::default(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = adaptive_avg_pool1d(x.clone(), self.output_size); let grads = output.backward(); let x_grad_actual = x.grad(&grads).unwrap(); x_grad.to_data().assert_approx_eq::( &x_grad_actual.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/adaptive_avgpool2d.rs ================================================ use super::*; use burn_tensor::module::adaptive_avg_pool2d; use burn_tensor::{Shape, Tolerance}; #[test] fn test_avg_pool2d_simple() { let test = AdaptiveAvgPool2dTestCase { batch_size: 1, channels: 2, height: 5, width: 3, output_size_1: 3, output_size_2: 2, }; test.assert_output(TestTensor::from_floats( [[ [ [0.2500, 0.5000, 0.2500], [0.41667, 0.83333, 0.41667], [0.16667, 0.33333, 0.16667], [0.41667, 0.83333, 0.41667], [0.2500, 0.5000, 0.2500], ], [ [0.2500, 0.5000, 0.2500], [0.41667, 0.83333, 0.41667], [0.16667, 0.33333, 0.16667], [0.41667, 0.83333, 0.41667], [0.2500, 0.5000, 0.2500], ], ]], &Default::default(), )); } #[test] fn test_avg_pool2d_output_1() { let test = AdaptiveAvgPool2dTestCase { batch_size: 1, channels: 1, height: 4, width: 8, output_size_1: 1, output_size_2: 1, }; test.assert_output(TestTensor::from_floats( [[[ [ 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, ], [ 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, ], [ 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, ], [ 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, ], ]]], &Default::default(), )); } struct AdaptiveAvgPool2dTestCase { batch_size: usize, channels: usize, height: usize, width: usize, output_size_1: usize, output_size_2: usize, } impl AdaptiveAvgPool2dTestCase { fn assert_output(self, x_grad: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let device = Default::default(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]); let grads = output.backward(); let x_grad_actual = x.grad(&grads).unwrap(); x_grad.to_data().assert_approx_eq::( &x_grad_actual.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/add.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_add() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad(); let tensor_3 = tensor_1.clone() + tensor_2.clone(); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([1.0, 1.0]), false); grad_2 .to_data() .assert_eq(&TensorData::from([1.0, 1.0]), false); tensor_3 .to_data() .assert_eq(&TensorData::from([6.0, 6.0]), false); } #[test] fn should_diff_add_scalar() { let data = TensorData::from([2.0, 10.0]); let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad(); let tensor_out = tensor.clone().add_scalar(5.0); let grads = tensor_out.backward(); let grad = tensor.grad(&grads).unwrap(); grad.to_data() .assert_eq(&TensorData::from([1.0, 1.0]), false); tensor_out .into_data() .assert_eq(&TensorData::from([7.0, 15.0]), false); } #[test] fn test_add_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().add(tensor_2.clone()); let tensor_5 = tensor_4 .add(tensor_3) .add_scalar(5.0) .add(tensor_1.clone()) .add(tensor_2.clone()); let tensor_6 = tensor_1.clone().add(tensor_5); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/aggregation.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_mean() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_sum_1() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_sum_2() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.clone().sum_dim(1); let tensor_5 = tensor_4.mul(tensor_3); let grads = tensor_5.sum().backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_mean_dim() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_sum_dim() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/avgpool1d.rs ================================================ use super::*; use burn_tensor::module::avg_pool1d; use burn_tensor::{Shape, Tolerance}; #[test] fn test_avg_pool1d_simple() { let test = AvgPool1dTestCase { batch_size: 1, channels: 1, kernel_size: 3, padding: 0, stride: 1, length: 6, count_include_pad: true, }; test.assert_output(TestTensor::from_floats( [[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]], &Default::default(), )); } #[test] fn test_avg_pool1d_complex() { let test = AvgPool1dTestCase { batch_size: 1, channels: 2, kernel_size: 3, padding: 1, stride: 2, length: 6, count_include_pad: true, }; test.assert_output(TestTensor::from_floats( [[ [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333], [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333], ]], &Default::default(), )); } #[test] fn test_avg_pool1d_complex_dont_count_pad() { let test = AvgPool1dTestCase { batch_size: 1, channels: 2, kernel_size: 3, padding: 1, stride: 2, length: 6, count_include_pad: false, }; test.assert_output(TestTensor::from_floats( [[ [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333], [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333], ]], &Default::default(), )); } struct AvgPool1dTestCase { batch_size: usize, channels: usize, kernel_size: usize, padding: usize, stride: usize, length: usize, count_include_pad: bool, } impl AvgPool1dTestCase { fn assert_output(self, x_grad: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels, self.length]); let device = Default::default(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = avg_pool1d( x.clone(), self.kernel_size, self.stride, self.padding, self.count_include_pad, false, ); let grads = output.backward(); let x_grad_actual = x.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); x_grad .to_data() .assert_approx_eq::(&x_grad_actual.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/avgpool2d.rs ================================================ use super::*; use burn_tensor::module::avg_pool2d; use burn_tensor::{Shape, Tolerance}; #[test] fn test_avg_pool2d_simple() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, height: 6, width: 6, count_include_pad: true, }; test.assert_output(TestTensor::from_floats( [[[ [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111], [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222], [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333], [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333], [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222], [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111], ]]], &Default::default(), )); } #[test] fn test_avg_pool2d_complex() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 2, height: 4, width: 6, count_include_pad: true, }; test.assert_output(TestTensor::from_floats( [[[ [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333], [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000], [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333], ]]], &Default::default(), )); } #[test] fn test_avg_pool2d_complex_dont_include_pad() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 2, height: 4, width: 6, count_include_pad: false, }; test.assert_output(TestTensor::from_floats( [[[ [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250], [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750], [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750], [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250], ]]], &Default::default(), )); } struct AvgPool2dTestCase { batch_size: usize, channels: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, height: usize, width: usize, count_include_pad: bool, } impl AvgPool2dTestCase { fn assert_output(self, x_grad: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let device = Default::default(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = avg_pool2d( x.clone(), [self.kernel_size_1, self.kernel_size_2], [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], self.count_include_pad, false, ); let grads = output.backward(); let x_grad_actual = x.grad(&grads).unwrap(); x_grad.to_data().assert_approx_eq::( &x_grad_actual.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/backward.rs ================================================ use super::*; use burn_tensor::{Int, Tensor, TensorData, module::embedding}; #[test] fn test_embedding_backward() { let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TensorData::from([[0, 1], [1, 1]]); let x = TensorData::from([ [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]], [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]], ]); let device = Default::default(); let weights = Tensor::::from_data(weights, &device).require_grad(); let indices = Tensor::::from_data(indices, &device); let x = Tensor::::from_data(x, &device).require_grad(); let output = embedding(weights.clone(), indices); let output = output.matmul(x); let grads = output.backward(); let grad = weights.grad(&grads).unwrap(); grad.to_data() .assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/bridge.rs ================================================ use super::*; use burn_tensor::{DType, Distribution, Tensor}; #[test] fn test_full_precision() { let device = Default::default(); let x1 = Tensor::::random([32, 32], Distribution::Default, &device) .require_grad(); let x2 = Tensor::::random([32, 32], Distribution::Default, &device) .require_grad(); let dtype = x1.dtype(); let x3 = x1.clone().cast(DType::F32); let x4 = x2.clone().cast(DType::F32); let x5 = x3.matmul(x4); let x6 = x5.cast(dtype); let x7 = x6 * x1.clone() / x2.clone(); let grads = x7.backward(); let x1_grad = x1.grad(&grads); let x2_grad = x2.grad(&grads); assert!(x1_grad.is_some()); assert!(x2_grad.is_some()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/broadcast.rs ================================================ use super::*; #[test] fn mul_broadcast() { test_ops_broadcast_backward(|x, y| x * y); } #[test] fn div_broadcast() { test_ops_broadcast_backward(|x, y| x / y); } #[test] fn sub_broadcast() { test_ops_broadcast_backward(|x, y| x - y); } #[test] fn add_broadcast() { test_ops_broadcast_backward(|x, y| x + y); } #[test] fn matmul_broadcast() { test_ops_broadcast_backward(|x, y| x.matmul(y)); } #[test] fn mask_where_broadcast() { test_ops_broadcast_backward(|x, y| { let cond = y.clone().equal_elem(4); x.mask_where(cond, y) }); } fn test_ops_broadcast_backward(func: F) where F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>, { let device = Default::default(); let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad(); let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad(); // Slice isn't a broadcastable operation, so it will fail when the previous backward pass // of an operation that support broadcast doesn't support it during the backward pass. let y = func(w.clone().slice([0..1]), x.clone()); // Will panic if broadcast isn't supported! let grads = y.backward(); let w_grad = w.grad(&grads).unwrap(); let x_grad = x.grad(&grads).unwrap(); assert_eq!(w_grad.shape(), w.shape()); assert_eq!(x_grad.shape(), x.shape()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cast.rs ================================================ // Skip on metal - F64 not supported #![cfg(all(feature = "std", not(feature = "metal")))] use super::*; use burn_backend_tests::might_panic; use burn_tensor::{DType, Tensor, TensorData}; #[might_panic(reason = "Unsupported precision for fusion")] #[test] fn cast_keeps_gradient_flow() { let device = Default::default(); let x = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ) .require_grad(); let y = x.clone().cast(DType::F64); let z = y.sum(); let grads = z.backward(); let grad_x = x.grad(&grads).unwrap(); grad_x .to_data() .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cat.rs ================================================ use super::*; use burn_tensor::Tolerance; #[test] fn should_diff_cat() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let mut tensor_1_list = Vec::new(); let mut tensor_2_list = Vec::new(); for i in 0..2 { tensor_1_list.push(tensor_1.clone().slice([i..i + 1])); tensor_2_list.push(tensor_2.clone().slice([i..i + 1])); } let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0); let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0); let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone()); let grads = tensor_3_cat.backward(); let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]); let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]); let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]); let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]); grad_1 .clone() .slice([0..1]) .to_data() .assert_approx_eq::(&grad_1_slice_1.to_data(), Tolerance::default()); grad_1 .slice([1..2]) .to_data() .assert_approx_eq::(&grad_1_slice_2.to_data(), Tolerance::default()); grad_2 .clone() .slice([0..1]) .to_data() .assert_approx_eq::(&grad_2_slice_1.to_data(), Tolerance::default()); grad_2 .slice([1..2]) .to_data() .assert_approx_eq::(&grad_2_slice_2.to_data(), Tolerance::default()); } #[test] fn should_diff_cat_more_than_1_dim() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device) .require_grad(); // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0. // The resulting tensor should be [5, 2] let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0); assert_eq!(tensor_3.dims(), [5, 2]); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); assert_eq!(tensor_1.dims(), grad_1.dims()); assert_eq!(tensor_2.dims(), grad_2.dims()); } #[test] fn should_slice_grads_correctly_when_some_inputs_not_tracked() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad(); // tracked let tensor_2 = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device); // not tracked let tensor_3 = TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad(); // tracked let cat = TestAutodiffTensor::cat( vec![tensor_1.clone(), tensor_2.clone(), tensor_3.clone()], 1, ); // Make gradient per column unique so wrong slicing shows up. let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device); let loss = (cat * weights).sum(); let grads = loss.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_3 = tensor_3.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&burn_tensor::TensorData::from([[1.0]]), false); grad_3 .to_data() .assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/ceil.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_ceil() { let data = TensorData::from([ [-1.9751, 0.0714, 0.0643, 0.2406], [-1.3172, 0.1252, -0.1119, -0.0127], ]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); let tensor_2 = tensor_1.clone().ceil(); let grads = tensor_2.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/checkpoint.rs ================================================ use super::*; use burn_tensor::{Bool, Tensor, TensorData}; #[test] fn test_autodiff_checkpoint_complicated_computation() { let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]); let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]); let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]); let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]); let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad(); let tensor_5 = compute_bound_eager(tensor_0, tensor_1); let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone()); let tensor_7 = memory_bound_eager(tensor_3, tensor_4); let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone()); let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.); let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone()); let tensor_11 = memory_bound_lazy(tensor_8, tensor_9); let tensor_12 = compute_bound_lazy(tensor_10, tensor_11); assert_checkpoint(tensor_12); } #[test] fn test_autodiff_checkpoint_with_missing_requirement() { let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]); let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad let tensor_2 = memory_bound_eager(tensor_0, tensor_1); let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.); let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.); let tensor_5 = compute_bound_lazy(tensor_3, tensor_4); let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.); let tensor_7 = memory_bound_eager(tensor_5, tensor_2); let tensor_8 = memory_bound_eager(tensor_6, tensor_7); assert_checkpoint(tensor_8); } #[test] fn test_autodiff_checkpoint_with_many_duplicates() { let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad(); let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone()); let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone()); let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone()); let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone()); let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone()); let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone()); let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone()); let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone()); let tensor_9 = memory_bound_lazy(tensor_6, tensor_7); let tensor_10 = memory_bound_eager(tensor_0, tensor_9); let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.); let tensor_12 = compute_bound_lazy(tensor_8, tensor_11); assert_checkpoint(tensor_12); } #[test] fn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() { let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]); let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]); let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]); let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]); let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad(); let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone()); let tensor_6 = memory_bound_eager(tensor_5, tensor_2); let tensor_7 = memory_bound_eager(tensor_6, tensor_3); let tensor_8 = memory_bound_eager(tensor_7, tensor_4); let tensor_9 = memory_bound_lazy(tensor_8, tensor_1); assert_checkpoint(tensor_9) } #[test] fn test_autodiff_checkpoint_half_sub_graph_not_tracked() { let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]); let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]); let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]); let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]); let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]); let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad(); let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad(); let tensor_6 = memory_bound_lazy(tensor_0, tensor_1); let tensor_7 = compute_bound_eager(tensor_6, tensor_2); let tensor_8 = memory_bound_eager(tensor_3, tensor_4); let tensor_9 = compute_bound_lazy(tensor_8, tensor_5); let tensor_10 = compute_bound_lazy(tensor_7, tensor_9); assert_checkpoint(tensor_10); } #[test] fn test_autodiff_checkpoint_very_complex() { let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]); let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]); let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]); let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]); let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]); let device = Default::default(); let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad(); let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.); let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone()); let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6); let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone()); let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.); let tensor_10 = compute_bound_eager(tensor_5, tensor_8); let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9); let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2); let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11); let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.); let tensor_15 = compute_bound_lazy(tensor_4, tensor_12); let tensor_16 = memory_bound_lazy(tensor_10, tensor_7); let tensor_17 = compute_bound_lazy(tensor_13, tensor_1); let tensor_18 = memory_bound_eager(tensor_15, tensor_16); let tensor_19 = compute_bound_eager(tensor_14, tensor_17); let tensor_20 = memory_bound_lazy(tensor_18, tensor_19); let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.); assert_checkpoint(tensor_21) } fn assert_checkpoint(tensor: TestAutodiffTensor) { // Assert is not explicit here, but the test can fail // - when a tensor is actually required more than n_required, it won't be found and will panic // - when a tensor is actually required less than n_required, the backward states map won't be // empty and will fail the assertion within the backward code, same for retro_forwards tensor.backward(); } // Does not save its state and does not need its parents fn memory_bound_eager( tensor_a: TestAutodiffTensor, tensor_b: TestAutodiffTensor, ) -> TestAutodiffTensor { tensor_a.add(tensor_b) } fn memory_bound_eager_scalar( tensor_a: TestAutodiffTensor, b: f32, ) -> TestAutodiffTensor { tensor_a.add_scalar(b) } // Saves its own state and does not need its parents fn compute_bound_eager( tensor_a: TestAutodiffTensor, tensor_b: TestAutodiffTensor, ) -> TestAutodiffTensor { let mask = Tensor::::empty(tensor_a.shape(), &tensor_a.device()); tensor_a.mask_where(mask, tensor_b) } fn compute_bound_eager_scalar( tensor_a: TestAutodiffTensor, b: f32, ) -> TestAutodiffTensor { let mask = Tensor::::empty(tensor_a.shape(), &tensor_a.device()); tensor_a.mask_fill(mask, b) } // Does not save its state and needs its parents fn memory_bound_lazy( tensor_a: TestAutodiffTensor, tensor_b: TestAutodiffTensor, ) -> TestAutodiffTensor { tensor_a.mul(tensor_b) } // Saves its own state and needs its parents fn compute_bound_lazy( tensor_a: TestAutodiffTensor, tensor_b: TestAutodiffTensor, ) -> TestAutodiffTensor { tensor_a.matmul(tensor_b) } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/complex.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_full_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.matmul(tensor_1.clone()); let tensor_5 = tensor_4.mul(tensor_2.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false); } #[test] fn should_diff_full_complex_2() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.matmul(tensor_1.clone()); let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false); } #[test] fn should_diff_full_complex_3() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.matmul(tensor_1.clone()); let tensor_5 = tensor_4.clone().sub(tensor_2.clone()); let tensor_6 = tensor_5.add(tensor_4); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv1d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions}; #[test] fn test_conv1d_basic() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 1, groups: 1, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[14., 24., 24., 18.], [26., 42., 42., 30.]], [[14., 24., 24., 18.], [26., 42., 42., 30.]], ], &device, ), weight: TestTensor::from_floats( [ [[30., 44., 36.], [54., 76., 60.]], [[30., 44., 36.], [54., 76., 60.]], ], &device, ), bias: TestTensor::from_floats([8., 8.], &device), }; test.assert_grads(grads); } #[test] fn test_conv1d_different_channels() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 3, kernel_size: 3, padding: 1, stride: 1, dilation: 1, groups: 1, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[39., 63., 63., 45.], [57., 90., 90., 63.]], [[39., 63., 63., 45.], [57., 90., 90., 63.]], ], &device, ), weight: TestTensor::from_floats( [ [[30., 44., 36.], [54., 76., 60.]], [[30., 44., 36.], [54., 76., 60.]], [[30., 44., 36.], [54., 76., 60.]], ], &device, ), bias: TestTensor::from_floats([8., 8., 8.], &device), }; test.assert_grads(grads); } #[test] fn test_conv1d_with_padding() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 2, stride: 1, dilation: 1, groups: 1, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[24., 24., 24., 24.], [42., 42., 42., 42.]], [[24., 24., 24., 24.], [42., 42., 42., 42.]], ], &device, ), weight: TestTensor::from_floats( [ [[44., 44., 44.], [76., 76., 76.]], [[44., 44., 44.], [76., 76., 76.]], ], &device, ), bias: TestTensor::from_floats([12., 12.], &device), }; test.assert_grads(grads); } #[test] fn test_conv1d_with_stride() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 2, dilation: 1, groups: 1, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[8., 16., 8., 10.], [14., 28., 14., 16.]], [[8., 16., 8., 10.], [14., 28., 14., 16.]], ], &device, ), weight: TestTensor::from_floats( [ [[10., 20., 24.], [18., 36., 40.]], [[10., 20., 24.], [18., 36., 40.]], ], &device, ), bias: TestTensor::from_floats([4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_conv1d_dilation() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 2, groups: 1, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[6., 8., 8., 10.], [12., 14., 14., 16.]], [[6., 8., 8., 10.], [12., 14., 14., 16.]], ], &device, ), weight: TestTensor::from_floats( [ [[8., 22., 14.], [16., 38., 22.]], [[8., 22., 14.], [16., 38., 22.]], ], &device, ), bias: TestTensor::from_floats([4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_conv1d_groups() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 1, groups: 2, length: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[1., 3., 3., 3.], [7., 12., 12., 9.]], [[1., 3., 3., 3.], [7., 12., 12., 9.]], ], &device, ), weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device), bias: TestTensor::from_floats([8., 8.], &device), }; test.assert_grads(grads); } struct Conv1dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size: usize, padding: usize, stride: usize, dilation: usize, groups: usize, length: usize, } struct Grads { x: TestTensor<3>, weight: TestTensor<3>, bias: TestTensor<1>, } impl Conv1dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<3, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv1d( x.clone(), weight.clone(), Some(bias.clone()), ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::default(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv2d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions}; #[test] fn test_conv2d_basic() { let test = Conv2dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [88., 138., 138., 96.], [150., 234., 234., 162.], [150., 234., 234., 162.], [112., 174., 174., 120.], ], [ [160., 246., 246., 168.], [258., 396., 396., 270.], [258., 396., 396., 270.], [184., 282., 282., 192.], ], ], [ [ [88., 138., 138., 96.], [150., 234., 234., 162.], [150., 234., 234., 162.], [112., 174., 174., 120.], ], [ [160., 246., 246., 168.], [258., 396., 396., 270.], [258., 396., 396., 270.], [184., 282., 282., 192.], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], ], [ [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], ], ], &device, ), bias: TestTensor::from_floats([32., 32.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_channels() { let test = Conv2dTestCase { batch_size: 2, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [240., 369., 369., 252.], [387., 594., 594., 405.], [387., 594., 594., 405.], [276., 423., 423., 288.], ], [ [348., 531., 531., 360.], [549., 837., 837., 567.], [549., 837., 837., 567.], [384., 585., 585., 396.], ], ], [ [ [240., 369., 369., 252.], [387., 594., 594., 405.], [387., 594., 594., 405.], [276., 423., 423., 288.], ], [ [348., 531., 531., 360.], [549., 837., 837., 567.], [549., 837., 837., 567.], [384., 585., 585., 396.], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], ], [ [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], ], [ [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]], [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]], ], ], &device, ), bias: TestTensor::from_floats([32., 32., 32.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_kernel_size() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [116., 180., 192., 132.], [198., 306., 324., 222.], [198., 306., 324., 222.], [148., 228., 240., 164.], ], [ [212., 324., 336., 228.], [342., 522., 540., 366.], [342., 522., 540., 366.], [244., 372., 384., 260.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [27., 45., 54., 39.], [52., 84., 96., 68.], [51., 81., 90., 63.], ], [ [123., 189., 198., 135.], [180., 276., 288., 196.], [147., 225., 234., 159.], ], ], [ [ [27., 45., 54., 39.], [52., 84., 96., 68.], [51., 81., 90., 63.], ], [ [123., 189., 198., 135.], [180., 276., 288., 196.], [147., 225., 234., 159.], ], ], ], &device, ), bias: TestTensor::from_floats([12., 12.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_padding() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [138., 138., 138., 138.], [234., 234., 234., 234.], [234., 234., 234., 234.], [174., 174., 174., 174.], ], [ [246., 246., 246., 246.], [396., 396., 396., 396.], [396., 396., 396., 396.], [282., 282., 282., 282.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], ], [ [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]], [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]], ], ], &device, ), bias: TestTensor::from_floats([24., 24.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_width() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 5, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [88., 138., 138., 138., 96.], [150., 234., 234., 234., 162.], [150., 234., 234., 234., 162.], [112., 174., 174., 174., 120.], ], [ [160., 246., 246., 246., 168.], [258., 396., 396., 396., 270.], [258., 396., 396., 396., 270.], [184., 282., 282., 282., 192.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], ], [ [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]], [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]], ], ], &device, ), bias: TestTensor::from_floats([20., 20.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_stride_2() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 2, stride_2: 2, dilation_1: 1, dilation_2: 1, groups: 1, height: 6, width: 6, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [26., 52., 26., 52., 26., 28.], [52., 104., 52., 104., 52., 56.], [26., 52., 26., 52., 26., 28.], [52., 104., 52., 104., 52., 56.], [26., 52., 26., 52., 26., 28.], [32., 64., 32., 64., 32., 34.], ], [ [44., 88., 44., 88., 44., 46.], [88., 176., 88., 176., 88., 92.], [44., 88., 44., 88., 44., 46.], [88., 176., 88., 176., 88., 92.], [44., 88., 44., 88., 44., 46.], [50., 100., 50., 100., 50., 52.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], ], [ [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]], [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]], ], ], &device, ), bias: TestTensor::from_floats([9., 9.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_stride() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 3, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 8, width: 8, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [50., 78., 78., 78., 78., 78., 78., 54.], [62., 96., 96., 96., 96., 96., 96., 66.], [38., 60., 60., 60., 60., 60., 60., 42.], [50., 78., 78., 78., 78., 78., 78., 54.], [62., 96., 96., 96., 96., 96., 96., 66.], [38., 60., 60., 60., 60., 60., 60., 42.], [50., 78., 78., 78., 78., 78., 78., 54.], [62., 96., 96., 96., 96., 96., 96., 66.], ], [ [86., 132., 132., 132., 132., 132., 132., 90.], [98., 150., 150., 150., 150., 150., 150., 102.], [74., 114., 114., 114., 114., 114., 114., 78.], [86., 132., 132., 132., 132., 132., 132., 90.], [98., 150., 150., 150., 150., 150., 150., 102.], [74., 114., 114., 114., 114., 114., 114., 78.], [86., 132., 132., 132., 132., 132., 132., 90.], [98., 150., 150., 150., 150., 150., 150., 102.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], [ [1330., 1528., 1344.], [1911., 2196., 1932.], [2079., 2388., 2100.], ], ], [ [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]], [ [1330., 1528., 1344.], [1911., 2196., 1932.], [2079., 2388., 2100.], ], ], ], &device, ), bias: TestTensor::from_floats([24., 24.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_dilation_2() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 2, dilation_2: 2, groups: 1, height: 6, width: 6, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [18., 38., 38., 42., 42., 22.], [42., 88., 88., 96., 96., 50.], [42., 88., 88., 96., 96., 50.], [54., 112., 112., 120., 120., 62.], [54., 112., 112., 120., 120., 62.], [30., 62., 62., 66., 66., 34.], ], [ [36., 74., 74., 78., 78., 40.], [78., 160., 160., 168., 168., 86.], [78., 160., 160., 168., 168., 86.], [90., 184., 184., 192., 192., 98.], [90., 184., 184., 192., 192., 98.], [48., 98., 98., 102., 102., 52.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], ], [ [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]], [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]], ], ], &device, ), bias: TestTensor::from_floats([16., 16.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_different_dilation() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 2, dilation_2: 3, groups: 1, height: 6, width: 6, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [18., 0., 20., 20., 0., 22.], [42., 0., 46., 46., 0., 50.], [42., 0., 46., 46., 0., 50.], [54., 0., 58., 58., 0., 62.], [54., 0., 58., 58., 0., 62.], [30., 0., 32., 32., 0., 34.], ], [ [36., 0., 38., 38., 0., 40.], [78., 0., 82., 82., 0., 86.], [78., 0., 82., 82., 0., 86.], [90., 0., 94., 94., 0., 98.], [90., 0., 94., 94., 0., 98.], [48., 0., 50., 50., 0., 52.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], ], [ [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]], [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]], ], ], &device, ), bias: TestTensor::from_floats([8., 8.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_groups() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 5, width: 5, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [0., 1., 3., 3., 2.], [3., 8., 15., 12., 7.], [9., 21., 36., 27., 15.], [9., 20., 33., 24., 13.], [6., 13., 21., 15., 8.], ], [ [9., 19., 30., 21., 11.], [21., 44., 69., 48., 25.], [36., 75., 117., 81., 42.], [27., 56., 87., 60., 31.], [15., 31., 48., 33., 17.], ], ]], &device, ), weight: TestTensor::from_floats( [ [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]], [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]], ], &device, ), bias: TestTensor::from_floats([9., 9.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_groups_stride_2() { let test = Conv2dTestCase { batch_size: 1, channels_in: 4, channels_out: 4, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 2, stride_2: 2, dilation_1: 1, dilation_2: 1, groups: 4, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [4., 8., 4., 5.], [8., 16., 8., 10.], [4., 8., 4., 5.], [7., 14., 7., 8.], ], [ [13., 26., 13., 14.], [26., 52., 26., 28.], [13., 26., 13., 14.], [16., 32., 16., 17.], ], [ [22., 44., 22., 23.], [44., 88., 44., 46.], [22., 44., 22., 23.], [25., 50., 25., 26.], ], [ [31., 62., 31., 32.], [62., 124., 62., 64.], [31., 62., 31., 32.], [34., 68., 34., 35.], ], ]], &device, ), weight: TestTensor::from_floats( [ [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]], [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]], [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]], [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]], ], &device, ), bias: TestTensor::from_floats([4., 4., 4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_groups_different_channels() { let test = Conv2dTestCase { batch_size: 1, channels_in: 3, channels_out: 6, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 3, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [9., 20., 24., 13.], [24., 52., 60., 32.], [36., 76., 84., 44.], [21., 44., 48., 25.], ], [ [45., 92., 96., 49.], [96., 196., 204., 104.], [108., 220., 228., 116.], [57., 116., 120., 61.], ], [ [81., 164., 168., 85.], [168., 340., 348., 176.], [180., 364., 372., 188.], [93., 188., 192., 97.], ], ]], &device, ), weight: TestTensor::from_floats( [ [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]], [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]], [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]], ], &device, ), bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_complex() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 2, kernel_size_2: 3, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 2, dilation_1: 2, dilation_2: 3, groups: 1, height: 4, width: 5, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [36., 39., 0., 39., 42.], [81., 87., 0., 87., 93.], [81., 87., 0., 87., 93.], [45., 48., 0., 48., 51.], ], [ [54., 57., 0., 57., 60.], [117., 123., 0., 123., 129.], [117., 123., 0., 123., 129.], [63., 66., 0., 66., 69.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[15., 42., 27.], [30., 72., 42.]], [[75., 162., 87.], [90., 192., 102.]], ], [ [[15., 42., 27.], [30., 72., 42.]], [[75., 162., 87.], [90., 192., 102.]], ], [ [[15., 42., 27.], [30., 72., 42.]], [[75., 162., 87.], [90., 192., 102.]], ], ], &device, ), bias: TestTensor::from_floats([8., 8., 8.], &device), }; test.assert_grads(grads); } #[test] fn test_conv2d_groups_stride_2_no_pad() { let test = Conv2dTestCase { batch_size: 1, channels_in: 4, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 2, stride_2: 2, dilation_1: 1, dilation_2: 1, groups: 2, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [0., 1., 2., 0.], [3., 4., 5., 0.], [6., 7., 8., 0.], [0., 0., 0., 0.], ], [ [9., 10., 11., 0.], [12., 13., 14., 0.], [15., 16., 17., 0.], [0., 0., 0., 0.], ], [ [18., 19., 20., 0.], [21., 22., 23., 0.], [24., 25., 26., 0.], [0., 0., 0., 0.], ], [ [27., 28., 29., 0.], [30., 31., 32., 0.], [33., 34., 35., 0.], [0., 0., 0., 0.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]], [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]], ], [ [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]], [[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]], ], ], &device, ), bias: TestTensor::from_floats([1., 1.], &device), }; test.assert_grads(grads); } struct Conv2dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, dilation_1: usize, dilation_2: usize, groups: usize, height: usize, width: usize, } struct Grads { x: TestTensor<4>, weight: TestTensor<4>, bias: TestTensor<1>, } impl Conv2dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv2d( x.clone(), weight.clone(), Some(bias.clone()), ConvOptions::new( [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], [self.dilation_1, self.dilation_2], self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::rel_abs(0.01, 0.01); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv3d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions}; #[test] fn test_conv3d_basic() { let test = Conv3dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 4, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [ [536., 816., 816., 552.], [840., 1278., 1278., 864.], [840., 1278., 1278., 864.], [584., 888., 888., 600.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [680., 1032., 1032., 696.], [1056., 1602., 1602., 1080.], [1056., 1602., 1602., 1080.], [728., 1104., 1104., 744.], ], ], [ [ [968., 1464., 1464., 984.], [1488., 2250., 2250., 1512.], [1488., 2250., 2250., 1512.], [1016., 1536., 1536., 1032.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1112., 1680., 1680., 1128.], [1704., 2574., 2574., 1728.], [1704., 2574., 2574., 1728.], [1160., 1752., 1752., 1176.], ], ], ], [ [ [ [536., 816., 816., 552.], [840., 1278., 1278., 864.], [840., 1278., 1278., 864.], [584., 888., 888., 600.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [912., 1386., 1386., 936.], [1422., 2160., 2160., 1458.], [1422., 2160., 2160., 1458.], [984., 1494., 1494., 1008.], ], [ [680., 1032., 1032., 696.], [1056., 1602., 1602., 1080.], [1056., 1602., 1602., 1080.], [728., 1104., 1104., 744.], ], ], [ [ [968., 1464., 1464., 984.], [1488., 2250., 2250., 1512.], [1488., 2250., 2250., 1512.], [1016., 1536., 1536., 1032.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1560., 2358., 2358., 1584.], [2394., 3618., 3618., 2430.], [2394., 3618., 3618., 2430.], [1632., 2466., 2466., 1656.], ], [ [1112., 1680., 1680., 1128.], [1704., 2574., 2574., 1728.], [1704., 2574., 2574., 1728.], [1160., 1752., 1752., 1176.], ], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [ [ [4590., 6156., 4644.], [6264., 8400., 6336.], [4806., 6444., 4860.], ], [ [6696., 8976., 6768.], [9120., 12224., 9216.], [6984., 9360., 7056.], ], [ [5454., 7308., 5508.], [7416., 9936., 7488.], [5670., 7596., 5724.], ], ], [ [ [8046., 10764., 8100.], [10872., 14544., 10944.], [8262., 11052., 8316.], ], [ [11304., 15120., 11376.], [15264., 20416., 15360.], [11592., 15504., 11664.], ], [ [8910., 11916., 8964.], [12024., 16080., 12096.], [9126., 12204., 9180.], ], ], ], [ [ [ [4590., 6156., 4644.], [6264., 8400., 6336.], [4806., 6444., 4860.], ], [ [6696., 8976., 6768.], [9120., 12224., 9216.], [6984., 9360., 7056.], ], [ [5454., 7308., 5508.], [7416., 9936., 7488.], [5670., 7596., 5724.], ], ], [ [ [8046., 10764., 8100.], [10872., 14544., 10944.], [8262., 11052., 8316.], ], [ [11304., 15120., 11376.], [15264., 20416., 15360.], [11592., 15504., 11664.], ], [ [8910., 11916., 8964.], [12024., 16080., 12096.], [9126., 12204., 9180.], ], ], ], ], &device, ), bias: TestTensor::from_floats([128., 128.], &device), }; test.assert_grads(grads); } #[test] fn test_conv3d_complex() { let test = Conv3dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 2, kernel_size_2: 3, kernel_size_3: 4, padding_1: 1, padding_2: 2, padding_3: 3, stride_1: 1, stride_2: 2, stride_3: 3, dilation_1: 2, dilation_2: 3, dilation_3: 4, groups: 1, depth: 5, height: 6, width: 7, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [ [0., 147., 0., 0., 0., 150., 0.], [0., 159., 0., 0., 0., 162., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 159., 0., 0., 0., 162., 0.], [0., 171., 0., 0., 0., 174., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 330., 0., 0., 0., 336., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 354., 0., 0., 0., 360., 0.], [0., 378., 0., 0., 0., 384., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 183., 0., 0., 0., 186., 0.], [0., 195., 0., 0., 0., 198., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 195., 0., 0., 0., 198., 0.], [0., 207., 0., 0., 0., 210., 0.], [0., 0., 0., 0., 0., 0., 0.], ], ], [ [ [0., 219., 0., 0., 0., 222., 0.], [0., 231., 0., 0., 0., 234., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 231., 0., 0., 0., 234., 0.], [0., 243., 0., 0., 0., 246., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 474., 0., 0., 0., 480., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 498., 0., 0., 0., 504., 0.], [0., 522., 0., 0., 0., 528., 0.], [0., 0., 0., 0., 0., 0., 0.], ], [ [0., 255., 0., 0., 0., 258., 0.], [0., 267., 0., 0., 0., 270., 0.], [0., 0., 0., 0., 0., 0., 0.], [0., 267., 0., 0., 0., 270., 0.], [0., 279., 0., 0., 0., 282., 0.], [0., 0., 0., 0., 0., 0., 0.], ], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], [ [ [ [0., 256., 272., 0.], [0., 624., 656., 0.], [0., 368., 384., 0.], ], [ [0., 424., 440., 0.], [0., 960., 992., 0.], [0., 536., 552., 0.], ], ], [ [ [0., 1096., 1112., 0.], [0., 2304., 2336., 0.], [0., 1208., 1224., 0.], ], [ [0., 1264., 1280., 0.], [0., 2640., 2672., 0.], [0., 1376., 1392., 0.], ], ], ], ], &device, ), bias: TestTensor::from_floats([10., 10., 10.], &device), }; test.assert_grads(grads); } #[test] fn test_conv3d_groups_stride_2_no_pad() { let test = Conv3dTestCase { batch_size: 1, channels_in: 4, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 0, padding_2: 0, padding_3: 0, stride_1: 2, stride_2: 2, stride_3: 2, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 2, depth: 4, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [ [0., 1., 2., 0.], [3., 4., 5., 0.], [6., 7., 8., 0.], [0., 0., 0., 0.], ], [ [9., 10., 11., 0.], [12., 13., 14., 0.], [15., 16., 17., 0.], [0., 0., 0., 0.], ], [ [18., 19., 20., 0.], [21., 22., 23., 0.], [24., 25., 26., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [27., 28., 29., 0.], [30., 31., 32., 0.], [33., 34., 35., 0.], [0., 0., 0., 0.], ], [ [36., 37., 38., 0.], [39., 40., 41., 0.], [42., 43., 44., 0.], [0., 0., 0., 0.], ], [ [45., 46., 47., 0.], [48., 49., 50., 0.], [51., 52., 53., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [54., 55., 56., 0.], [57., 58., 59., 0.], [60., 61., 62., 0.], [0., 0., 0., 0.], ], [ [63., 64., 65., 0.], [66., 67., 68., 0.], [69., 70., 71., 0.], [0., 0., 0., 0.], ], [ [72., 73., 74., 0.], [75., 76., 77., 0.], [78., 79., 80., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], [ [ [81., 82., 83., 0.], [84., 85., 86., 0.], [87., 88., 89., 0.], [0., 0., 0., 0.], ], [ [90., 91., 92., 0.], [93., 94., 95., 0.], [96., 97., 98., 0.], [0., 0., 0., 0.], ], [ [99., 100., 101., 0.], [102., 103., 104., 0.], [105., 106., 107., 0.], [0., 0., 0., 0.], ], [ [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]], [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]], [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]], ], [ [[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]], [[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]], [[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]], ], ], [ [ [[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]], [[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]], [[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]], ], [ [[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]], [[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]], [[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]], ], ], ], &device, ), bias: TestTensor::from_floats([1., 1.], &device), }; test.assert_grads(grads); } struct Conv3dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, kernel_size_3: usize, padding_1: usize, padding_2: usize, padding_3: usize, stride_1: usize, stride_2: usize, stride_3: usize, dilation_1: usize, dilation_2: usize, dilation_3: usize, groups: usize, depth: usize, height: usize, width: usize, } struct Grads { x: TestTensor<5>, weight: TestTensor<5>, bias: TestTensor<1>, } impl Conv3dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([ self.batch_size, self.channels_in, self.depth, self.height, self.width, ]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, self.kernel_size_3, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<5, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<5, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv3d( x.clone(), weight.clone(), Some(bias.clone()), ConvOptions::new( [self.stride_1, self.stride_2, self.stride_3], [self.padding_1, self.padding_2, self.padding_3], [self.dilation_1, self.dilation_2, self.dilation_3], self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::default(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose1d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions}; #[test] fn test_conv_transpose1d_basic() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 2], kernel_size: 3, padding: 0, padding_out: 0, stride: 1, dilation: 1, groups: 1, size: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]], ], &device, ), weight: TestTensor::from_floats( [ [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]], [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]], ], &device, ), bias: TestTensor::from_floats([12., 12.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose1d_padding() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 2], kernel_size: 3, padding: 2, padding_out: 0, stride: 1, dilation: 1, groups: 1, size: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[7., 12., 8., 3.], [19., 36., 32., 15.]], [[7., 12., 8., 3.], [19., 36., 32., 15.]], ], &device, ), weight: TestTensor::from_floats( [ [[26., 22., 18.], [26., 22., 18.]], [[42., 38., 34.], [42., 38., 34.]], ], &device, ), bias: TestTensor::from_floats([4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose1d_stride() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 2], kernel_size: 3, padding: 0, padding_out: 0, stride: 2, dilation: 1, groups: 1, size: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[15., 15., 15., 15.], [51., 51., 51., 51.]], [[15., 15., 15., 15.], [51., 51., 51., 51.]], ], &device, ), weight: TestTensor::from_floats( [ [[44., 44., 44.], [44., 44., 44.]], [[76., 76., 76.], [76., 76., 76.]], ], &device, ), bias: TestTensor::from_floats([18., 18.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose1d_stride_padding_out() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 2], kernel_size: 3, padding: 0, padding_out: 1, stride: 2, dilation: 1, groups: 1, size: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[15., 15., 15., 15.], [51., 51., 51., 51.]], [[15., 15., 15., 15.], [51., 51., 51., 51.]], ], &device, ), weight: TestTensor::from_floats( [ [[44., 44., 44.], [44., 44., 44.]], [[76., 76., 76.], [76., 76., 76.]], ], &device, ), bias: TestTensor::from_floats([20., 20.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose1d_dilation() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 2], kernel_size: 3, padding: 0, padding_out: 0, stride: 1, dilation: 2, groups: 1, size: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [[15., 15., 15., 15.], [51., 51., 51., 51.]], [[15., 15., 15., 15.], [51., 51., 51., 51.]], ], &device, ), weight: TestTensor::from_floats( [ [[44., 44., 44.], [44., 44., 44.]], [[76., 76., 76.], [76., 76., 76.]], ], &device, ), bias: TestTensor::from_floats([16., 16.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose1d_complex() { let test = ConvTranspose1dTestCase { batch_size: 2, channels: [2, 4], kernel_size: 3, padding: 1, padding_out: 1, stride: 2, dilation: 2, groups: 2, size: 8, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], ], [ [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0], [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0], ], ], &device, ), weight: TestTensor::from_floats( [ [[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]], [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]], ], &device, ), bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device), }; test.assert_grads(grads); } struct ConvTranspose1dTestCase { batch_size: usize, channels: [usize; 2], kernel_size: usize, padding: usize, padding_out: usize, stride: usize, dilation: usize, groups: usize, size: usize, } struct Grads { x: TestTensor<3>, weight: TestTensor<3>, bias: TestTensor<1>, } impl ConvTranspose1dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]); let shape_weight = Shape::new([ self.channels[0], self.channels[1] / self.groups, self.kernel_size, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<3, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv_transpose1d( x.clone(), weight.clone(), Some(bias.clone()), ConvTransposeOptions::new( [self.stride], [self.padding], [self.padding_out], [self.dilation], self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), Tolerance::default()); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose2d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions}; #[test] fn test_conv_transpose2d_basic() { let test = ConvTranspose2dTestCase { batch_size: 2, channels: [2, 2], kernel_size: [3, 3], padding: [0, 0], padding_out: [0, 0], stride: [1, 1], dilation: [1, 1], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [153., 153., 153., 153.], [153., 153., 153., 153.], [153., 153., 153., 153.], [153., 153., 153., 153.], ], [ [477., 477., 477., 477.], [477., 477., 477., 477.], [477., 477., 477., 477.], [477., 477., 477., 477.], ], ], [ [ [153., 153., 153., 153.], [153., 153., 153., 153.], [153., 153., 153., 153.], [153., 153., 153., 153.], ], [ [477., 477., 477., 477.], [477., 477., 477., 477.], [477., 477., 477., 477.], [477., 477., 477., 477.], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]], ], [ [ [1264., 1264., 1264.], [1264., 1264., 1264.], [1264., 1264., 1264.], ], [ [1264., 1264., 1264.], [1264., 1264., 1264.], [1264., 1264., 1264.], ], ], ], &device, ), bias: TestTensor::from_floats([72., 72.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_padding() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [1, 1], kernel_size: [3, 3], padding: [1, 2], padding_out: [0, 0], stride: [1, 1], dilation: [1, 1], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[[ [13., 24., 20., 9.], [15., 27., 21., 9.], [15., 27., 21., 9.], [7., 12., 8., 3.], ]]], &device, ), weight: TestTensor::from_floats( [[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]], &device, ), bias: TestTensor::from_floats([8.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_stride() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [1, 1], kernel_size: [3, 3], padding: [0, 0], padding_out: [0, 0], stride: [2, 3], dilation: [1, 1], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[[ [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], ]]], &device, ), weight: TestTensor::from_floats( [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]], &device, ), bias: TestTensor::from_floats([108.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_stride_padding_out() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [1, 1], kernel_size: [3, 3], padding: [0, 0], padding_out: [1, 2], stride: [2, 3], dilation: [1, 1], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[[ [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], ]]], &device, ), weight: TestTensor::from_floats( [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]], &device, ), bias: TestTensor::from_floats([140.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_dilation() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [1, 1], kernel_size: [3, 3], padding: [0, 0], padding_out: [0, 0], stride: [1, 1], dilation: [2, 3], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[[ [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], ]]], &device, ), weight: TestTensor::from_floats( [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]], &device, ), bias: TestTensor::from_floats([80.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_channels() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [2, 3], kernel_size: [3, 3], padding: [0, 0], padding_out: [0, 0], stride: [1, 1], dilation: [1, 1], groups: 1, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [351., 351., 351., 351.], [351., 351., 351., 351.], [351., 351., 351., 351.], [351., 351., 351., 351.], ], [ [1080., 1080., 1080., 1080.], [1080., 1080., 1080., 1080.], [1080., 1080., 1080., 1080.], [1080., 1080., 1080., 1080.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]], ], [ [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]], ], ], &device, ), bias: TestTensor::from_floats([36., 36., 36.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_kernel_size() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [1, 1], kernel_size: [3, 5], padding: [0, 0], padding_out: [0, 0], stride: [1, 1], dilation: [1, 1], groups: 1, size: [6, 6], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[[ [105., 105., 105., 105., 105., 105.], [105., 105., 105., 105., 105., 105.], [105., 105., 105., 105., 105., 105.], [105., 105., 105., 105., 105., 105.], [105., 105., 105., 105., 105., 105.], [105., 105., 105., 105., 105., 105.], ]]], &device, ), weight: TestTensor::from_floats( [[[ [630., 630., 630., 630., 630.], [630., 630., 630., 630., 630.], [630., 630., 630., 630., 630.], ]]], &device, ), bias: TestTensor::from_floats([80.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_groups() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [2, 2], kernel_size: [3, 3], padding: [0, 0], padding_out: [0, 0], stride: [1, 1], dilation: [1, 1], groups: 2, size: [4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], [36., 36., 36., 36.], ], [ [117., 117., 117., 117.], [117., 117., 117., 117.], [117., 117., 117., 117.], [117., 117., 117., 117.], ], ]], &device, ), weight: TestTensor::from_floats( [ [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]], [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]], ], &device, ), bias: TestTensor::from_floats([36., 36.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_complex_no_groups() { let test = ConvTranspose2dTestCase { batch_size: 2, channels: [2, 3], kernel_size: [3, 5], padding: [1, 2], padding_out: [1, 2], stride: [2, 3], dilation: [2, 3], groups: 1, size: [6, 8], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [600., 735., 735., 735., 735., 735., 735., 735.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], ], [ [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], ], ], [ [ [600., 735., 735., 735., 735., 735., 735., 735.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], [810., 990., 990., 990., 990., 990., 990., 990.], ], [ [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [ [5320., 6040., 6040., 6040., 6040.], [6048., 6864., 6864., 6864., 6864.], [6048., 6864., 6864., 6864., 6864.], ], [ [5320., 6040., 6040., 6040., 6040.], [6048., 6864., 6864., 6864., 6864.], [6048., 6864., 6864., 6864., 6864.], ], [ [5320., 6040., 6040., 6040., 6040.], [6048., 6864., 6864., 6864., 6864.], [6048., 6864., 6864., 6864., 6864.], ], ], [ [ [8680., 9880., 9880., 9880., 9880.], [10080., 11472., 11472., 11472., 11472.], [10080., 11472., 11472., 11472., 11472.], ], [ [8680., 9880., 9880., 9880., 9880.], [10080., 11472., 11472., 11472., 11472.], [10080., 11472., 11472., 11472., 11472.], ], [ [8680., 9880., 9880., 9880., 9880.], [10080., 11472., 11472., 11472., 11472.], [10080., 11472., 11472., 11472., 11472.], ], ], ], &device, ), bias: TestTensor::from_floats([896., 896., 896.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_complex_no_groups_2() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [4, 2], kernel_size: [2, 3], padding: [1, 2], padding_out: [1, 2], stride: [2, 3], dilation: [1, 2], groups: 1, size: [10, 10], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.], ], [ [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.], ], [ [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.], ], [ [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [[4455., 4905., 4905.], [4500., 4950., 4950.]], [[4455., 4905., 4905.], [4500., 4950., 4950.]], ], [ [[12555., 13905., 13905.], [13500., 14950., 14950.]], [[12555., 13905., 13905.], [13500., 14950., 14950.]], ], [ [[20655., 22905., 22905.], [22500., 24950., 24950.]], [[20655., 22905., 22905.], [22500., 24950., 24950.]], ], [ [[28755., 31905., 31905.], [31500., 34950., 34950.]], [[28755., 31905., 31905.], [31500., 34950., 34950.]], ], ], &device, ), bias: TestTensor::from_floats([570., 570.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose2d_complex_groups() { let test = ConvTranspose2dTestCase { batch_size: 1, channels: [4, 2], kernel_size: [2, 3], padding: [1, 2], padding_out: [1, 2], stride: [2, 3], dilation: [1, 2], groups: 2, size: [10, 10], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.], ], [ [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.], ], [ [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.], ], [ [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.], ], ]], &device, ), weight: TestTensor::from_floats( [ [[[4455., 4905., 4905.], [4500., 4950., 4950.]]], [[[12555., 13905., 13905.], [13500., 14950., 14950.]]], [[[20655., 22905., 22905.], [22500., 24950., 24950.]]], [[[28755., 31905., 31905.], [31500., 34950., 34950.]]], ], &device, ), bias: TestTensor::from_floats([570., 570.], &device), }; test.assert_grads(grads); } struct ConvTranspose2dTestCase { batch_size: usize, channels: [usize; 2], kernel_size: [usize; 2], padding: [usize; 2], padding_out: [usize; 2], stride: [usize; 2], dilation: [usize; 2], groups: usize, size: [usize; 2], } struct Grads { x: TestTensor<4>, weight: TestTensor<4>, bias: TestTensor<1>, } impl ConvTranspose2dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([ self.batch_size, self.channels[0], self.size[0], self.size[1], ]); let shape_weight = Shape::new([ self.channels[0], self.channels[1] / self.groups, self.kernel_size[0], self.kernel_size[1], ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = conv_transpose2d( x.clone(), weight.clone(), Some(bias.clone()), ConvTransposeOptions::new( self.stride, self.padding, self.padding_out, self.dilation, self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::permissive(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/conv_transpose3d.rs ================================================ use super::*; use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions}; #[test] fn test_conv_transpose3d_basic() { let test = ConvTranspose3dTestCase { batch_size: 2, channels: [2, 2], kernel_size: [3, 3, 3], padding: [0, 0, 0], padding_out: [0, 0, 0], stride: [1, 1, 1], dilation: [1, 1, 1], groups: 1, size: [4, 4, 4], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], ], [ [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], ], ], [ [ [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], [ [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], [13.250001, 13.250001, 13.250001, 13.250001], ], ], [ [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], [ [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], [40.249992, 40.249992, 40.249992, 40.249992], ], ], ], ], &device, ), weight: TestTensor::from_floats( [ [ [ [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], ], [ [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], [ [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], [47.750000, 47.750000, 47.750000], ], ], ], [ [ [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], ], [ [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], [ [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], [79.750000, 79.750000, 79.750000], ], ], ], ], &device, ), bias: TestTensor::from_floats([432., 432.], &device), }; test.assert_grads(grads); } #[test] fn test_conv_transpose3d_complex_groups() { let test = ConvTranspose3dTestCase { batch_size: 1, channels: [4, 2], kernel_size: [2, 3, 4], padding: [1, 2, 3], padding_out: [1, 2, 3], stride: [2, 3, 4], dilation: [1, 2, 3], groups: 2, size: [6, 6, 6], }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [ [1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000], [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500], [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500], [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500], [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500], [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500], ], [ [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], ], [ [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], ], [ [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], ], [ [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], ], [ [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000], ], ], [ [ [2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000], [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500], [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500], [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500], [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500], [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500], ], [ [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], ], [ [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], ], [ [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], ], [ [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], ], [ [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000], ], ], [ [ [4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000], [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500], [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500], [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500], [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500], [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500], ], [ [ 7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], ], [ [ 7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], ], [ [ 7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], ], [ [ 7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], ], [ [ 7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], [ 11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000, ], ], ], [ [ [5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000], [ 8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500, ], [ 8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500, ], [ 8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500, ], [ 8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500, ], [ 8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500, ], ], [ [ 10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], ], [ [ 10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], ], [ [ 10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], ], [ [ 10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], ], [ [ 10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], [ 15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000, ], ], ], ]], &device, ), weight: TestTensor::from_floats( [ [[ [ [18.663193, 22.309027, 22.309027, 22.309027], [21.875000, 26.145834, 26.145834, 26.145834], [21.875000, 26.145834, 26.145834, 26.145834], ], [ [19.270832, 23.020834, 23.020834, 23.020834], [22.500000, 26.875002, 26.875002, 26.875002], [22.500000, 26.875002, 26.875002, 26.875002], ], ]], [[ [ [49.913193, 59.809029, 59.809029, 59.809029], [59.375000, 71.145836, 71.145836, 71.145836], [59.375000, 71.145836, 71.145836, 71.145836], ], [ [56.770836, 68.020836, 68.020836, 68.020836], [67.500000, 80.875000, 80.875000, 80.875000], [67.500000, 80.875000, 80.875000, 80.875000], ], ]], [[ [ [81.163193, 97.309029, 97.309029, 97.309029], [96.875000, 116.145828, 116.145828, 116.145828], [96.875000, 116.145828, 116.145828, 116.145828], ], [ [94.270828, 113.020828, 113.020828, 113.020828], [112.500000, 134.875000, 134.875000, 134.875000], [112.500000, 134.875000, 134.875000, 134.875000], ], ]], [[ [ [112.413200, 134.809021, 134.809021, 134.809021], [134.375000, 161.145828, 161.145828, 161.145828], [134.375000, 161.145828, 161.145828, 161.145828], ], [ [131.770844, 158.020828, 158.020828, 158.020828], [157.500000, 188.875000, 188.875000, 188.875000], [157.500000, 188.875000, 188.875000, 188.875000], ], ]], ], &device, ), bias: TestTensor::from_floats([5346., 5346.], &device), }; test.assert_grads(grads); } struct ConvTranspose3dTestCase { batch_size: usize, channels: [usize; 2], kernel_size: [usize; 3], padding: [usize; 3], padding_out: [usize; 3], stride: [usize; 3], dilation: [usize; 3], groups: usize, size: [usize; 3], } struct Grads { x: TestTensor<5>, weight: TestTensor<5>, bias: TestTensor<1>, } impl ConvTranspose3dTestCase { fn assert_grads(self, expected_grads: Grads) { let shape_x = Shape::new([ self.batch_size, self.channels[0], self.size[0], self.size[1], self.size[2], ]); let shape_weight = Shape::new([ self.channels[0], self.channels[1] / self.groups, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<5, _>(shape_weight.clone()) .into_data(), &device, ) .div_scalar(shape_weight.num_elements() as f32) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<5, _>(shape_x.clone()) .into_data(), &device, ) .div_scalar(shape_x.num_elements() as f32) .require_grad(); let output = conv_transpose3d( x.clone(), weight.clone(), Some(bias.clone()), ConvTransposeOptions::new( self.stride, self.padding, self.padding_out, self.dilation, self.groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); let tolerance = Tolerance::permissive(); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cross.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[cfg(feature = "std")] use burn_backend_tests::might_panic; #[test] fn backward_basic() { let device = Default::default(); let a = TestAutodiffTensor::<2>::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ) .require_grad(); let b = TestAutodiffTensor::<2>::from_data( TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), &device, ) .require_grad(); // Simple cross product; grad is a vector of ones. let c = a.clone().cross(b.clone(), 1); let grads = c.backward(); let a_grad = a.grad(&grads).unwrap().to_data(); let b_grad = b.grad(&grads).unwrap().to_data(); // For a: b×grad_out, where grad_out = [1,1,1] let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]); // For b: grad_out×a let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]); a_grad.assert_approx_eq::(&expected_a, Tolerance::default()); b_grad.assert_approx_eq::(&expected_b, Tolerance::default()); } #[test] fn backward_after_sum() { let device = Default::default(); let a = TestAutodiffTensor::<2>::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ) .require_grad(); let b = TestAutodiffTensor::<2>::from_data( TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), &device, ) .require_grad(); // Sum reduces to scalar, but the gradient should be the same. let c = a.clone().cross(b.clone(), 1).sum(); let grads = c.backward(); let a_grad = a.grad(&grads).unwrap().to_data(); let b_grad = b.grad(&grads).unwrap().to_data(); let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]); let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]); a_grad.assert_approx_eq::(&expected_a, Tolerance::default()); b_grad.assert_approx_eq::(&expected_b, Tolerance::default()); } #[cfg(feature = "std")] #[might_panic(reason = "not implemented: Cross product on non-last dimension")] #[test] fn different_dim() { // Also check when the cross is along a different dimension (e.g. dim 0). let device = Default::default(); let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]]; let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]]; let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device); let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device); // Cross along dim 0. Some backends (for example CubeCL) may not support // cross on non-last dimensions and will intentionally panic with a // message like "Cross product on non-last dimension not yet implemented". // In that case we treat the panic as a skipped test for that backend. let out = a.cross(b.clone(), 0); // Manually compute cross of each column vector using raw arrays let expected = [ [ a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0], a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1], a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2], ], [ a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0], a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1], a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2], ], [ a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0], a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1], a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2], ], ]; out.to_data() .assert_approx_eq::(&TensorData::from(expected), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cross_entropy.rs ================================================ use super::*; use burn_tensor::{Tensor, TensorData, Tolerance, loss}; #[test] fn test_cross_entropy_loss_grad() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let data_targets = TensorData::from([[0.8, 0.2], [0.9, 0.1]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data_1, &device).require_grad(); let tensor_2 = Tensor::::from_data(data_2, &device).require_grad(); let tensor_targets = Tensor::::from_data(data_targets, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::permissive(); let expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cummax.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_cummax() { // Simple test to verify cummax gradients work let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device) .require_grad(); let output = tensor.clone().cummax(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 2.0, 0.0] let expected = TensorData::from([1.0, 2.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummax_2d() { // Test 2D cummax gradients let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data( TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]), &device, ) .require_grad(); let output = tensor.clone().cummax(1); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]] let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummax_duplicate_values() { // Test with duplicate maximum values - critical edge case let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 3.0, 2.0]), &device) .require_grad(); let output = tensor.clone().cummax(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // input: [1.0, 3.0, 3.0, 2.0] // cummax: [1.0, 3.0, 3.0, 3.0] // PyTorch reference: [1.0, 1.0, 2.0, 0.0] // Position 2 gets grad from itself + position 3 let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummax_all_same() { // Test with all same values let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device) .require_grad(); let output = tensor.clone().cummax(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 1.0, 1.0] // Each position matches cummax, so each gets its own gradient let expected = TensorData::from([1.0, 1.0, 1.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummax_increasing() { // Test with increasing sequence let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cummax(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 1.0, 1.0, 1.0] // Each position is a new maximum let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummax_2d_duplicates() { // Test 2D with duplicate values let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data( TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]), &device, ) .require_grad(); let output = tensor.clone().cummax(1); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]] let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cummin.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_cummin() { // Simple test to verify cummin gradients work let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cummin(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 2.0, 0.0] let expected = TensorData::from([1.0, 2.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummin_2d() { // Test 2D cummin gradients let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data( TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]), &device, ) .require_grad(); let output = tensor.clone().cummin(1); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]] let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummin_duplicate_values() { // Test with duplicate minimum values - critical edge case let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 2.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cummin(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // input: [3.0, 2.0, 2.0, 4.0] // cummin: [3.0, 2.0, 2.0, 2.0] // PyTorch reference: [1.0, 1.0, 2.0, 0.0] // Position 2 gets grad from itself + position 3 let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummin_all_same() { // Test with all same values let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device) .require_grad(); let output = tensor.clone().cummin(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 1.0, 1.0] // Each position matches cummin, so each gets its own gradient let expected = TensorData::from([1.0, 1.0, 1.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummin_decreasing() { // Test with decreasing sequence let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([5.0, 4.0, 3.0, 2.0]), &device) .require_grad(); let output = tensor.clone().cummin(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 1.0, 1.0, 1.0] // Each position is a new minimum let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cummin_2d_duplicates() { // Test 2D with duplicate values let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data( TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]), &device, ) .require_grad(); let output = tensor.clone().cummin(1); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]] let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cumprod.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_cumprod() { // Simple test to verify cumprod gradients work let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cumprod(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [16.0, 10.0, 6.0] let expected = TensorData::from([16.0, 10.0, 6.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cumprod_2d() { // Test 2D cumprod gradients let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ) .require_grad(); let output = tensor.clone().cumprod(1); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]] let expected = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } // TODO: The following tests are currently ignored due to a known limitation // in the cumprod gradient implementation. The current implementation uses // division (grad / input), which produces NaN when the input contains zeros. // // A proper fix requires implementing a zero-safe algorithm using exclusive // cumulative products (similar to PyTorch's cumprod_backward or JAX's // associative_scan approach). This is a non-trivial implementation that // requires careful handling of cumulative products in both forward and // reverse directions. // // See: https://github.com/tracel-ai/burn/issues/3864 // // References: // - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward) // - JAX PR #2596: Parallel prefix scan implementation // - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros #[test] #[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"] fn should_diff_cumprod_zero_in_middle() { // Test cumprod with zero in the middle - edge case for division let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cumprod(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 32.0, 0.0, 0.0] let expected = TensorData::from([1.0, 32.0, 0.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"] fn should_diff_cumprod_zero_at_start() { // Test cumprod with zero at the beginning let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([0.0, 2.0, 3.0, 4.0]), &device) .require_grad(); let output = tensor.clone().cumprod(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [33.0, 0.0, 0.0, 0.0] let expected = TensorData::from([33.0, 0.0, 0.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"] fn should_diff_cumprod_zero_at_end() { // Test cumprod with zero at the end let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0, 0.0]), &device) .require_grad(); let output = tensor.clone().cumprod(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [16.0, 10.0, 6.0, 24.0] let expected = TensorData::from([16.0, 10.0, 6.0, 24.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[ignore = "cumprod gradient with zeros not yet implemented - produces NaN due to division by zero"] fn should_diff_cumprod_multiple_zeros() { // Test cumprod with multiple zeros let device = Default::default(); let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]), &device) .require_grad(); let output = tensor.clone().cumprod(0); let grads = output.sum().backward(); let grad = tensor.grad(&grads).unwrap(); // PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0] let expected = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/cumsum.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_cumsum_dim0() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.cumsum(0); let tensor_5 = tensor_1.clone().mul(tensor_4); let grads = tensor_5.sum().backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); // Expected gradients computed with PyTorch let expected = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cumsum_dim1() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.cumsum(1); let tensor_5 = tensor_1.clone().mul(tensor_4); let grads = tensor_5.sum().backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); // Expected gradients computed with PyTorch let expected = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[18.0, 13.0], [71.0, 58.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_cumsum_complex() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.clone().cumsum(1); let tensor_5 = tensor_4.mul(tensor_3); let grads = tensor_5.sum().backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); // Expected gradients computed with PyTorch let expected = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[507.0, 528.0], [704.0, 733.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/deform_conv2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Shape, module::deform_conv2d, ops::DeformConvOptions}; #[test] fn test_deform_conv2d_basic() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, offset_groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [0.000, 6.0678, 14.2071, 12.2477], [11.2292, 33.7937, 50.1555, 44.0561], [17.9294, 57.2174, 85.1505, 79.1840], [18.0220, 73.6263, 126.8184, 151.6910], ], [ [0.000, 8.9783, 20.7620, 17.7888], [16.2326, 48.7386, 71.7961, 62.5845], [25.3808, 80.5195, 119.0949, 110.0938], [25.0567, 101.8461, 174.3329, 206.6013], ], ]], &device, ), offset: TestTensor::from_floats( [[ [[0.000, 15.0000], [30.000, 45.0000]], [[0.000, 3.7500], [7.5000, 11.2500]], [[62.6667, 78.3333], [94.0000, 109.6667]], [[15.6667, 19.5833], [23.5000, 27.4167]], [[130.6667, 104.1250], [163.3333, 122.2732]], [[32.6667, -492.9583], [40.8333, -787.1620]], [[204.0000, 221.0000], [238.0000, 255.0000]], [[51.0000, 55.2500], [59.5000, 63.7500]], [[282.6667, 300.3333], [318.0000, 335.6667]], [[70.6667, 75.0833], [79.5000, 83.9167]], [[366.6667, 144.3750], [403.3333, 146.4121]], [[91.6667, -1788.9860], [100.8333, -2392.7456]], [[456.0000, 475.0000], [-2718.6250, -2953.2188]], [[114.0000, 118.7500], [37.7361, 37.4063]], [[550.6667, 570.3334], [-3404.5139, -3672.5312]], [[137.6667, 142.5833], [28.6806, 27.5197]], [[650.6667, 27.9584], [-4174.3657, -59.7509]], [[162.6667, -3991.0139], [14.4028, -298.7557]], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [0.7029, 2.8356, 5.1067], [12.7492, 19.4745, 17.8345], [22.0687, 25.9156, 14.6394], ], [ [3.3696, 12.6134, 19.2671], [36.7492, 50.5856, 43.5506], [50.8774, 56.3292, 30.7470], ], ], [ [ [0.7029, 2.8356, 5.1067], [12.7492, 19.4745, 17.8345], [22.0687, 25.9156, 14.6394], ], [ [3.3696, 12.6134, 19.2671], [36.7492, 50.5856, 43.5506], [50.8774, 56.3292, 30.7470], ], ], [ [ [0.7029, 2.8356, 5.1067], [12.7492, 19.4745, 17.8345], [22.0687, 25.9156, 14.6394], ], [ [3.3696, 12.6134, 19.2671], [36.7492, 50.5856, 43.5506], [50.8774, 56.3292, 30.7470], ], ], ], &device, ), mask: TestTensor::from_floats( [[ [[1303.5000, 1447.8750], [1862.2500, 2006.6250]], [[1571.1666, 1721.9581], [2154.7500, 2305.5417]], [[1857.4999, 1396.7151], [2465.9167, 1753.2246]], [[2315.5000, 2479.1250], [2948.7502, 3112.3750]], [[2645.1665, 2815.2085], [3303.2500, 3473.2917]], [[2993.5000, 1150.0625], [3676.4165, 1300.4055]], [[3531.5000, 3714.3752], [1150.1876, 1148.4744]], [[3923.1665, 4112.4585], [794.3865, 770.0470]], [[4333.5000, 181.4101], [368.3260, 4.2679]], ]], &device, ), bias: TestTensor::from_floats([4., 4., 4.], &device), }; test.assert_grads(grads); } #[test] fn test_deform_conv2d_batched() { let test = Conv2dTestCase { batch_size: 2, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, offset_groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [ [ [ [0.000, 3.4604, 8.7539, 6.8080], [8.4661, 24.0784, 35.4610, 26.4276], [19.5988, 51.0406, 68.4389, 53.4993], [17.4698, 47.9106, 67.3808, 56.6063], ], [ [0.000, 5.1185, 12.7803, 9.8796], [12.1957, 34.5728, 50.4616, 37.3777], [27.4521, 71.1227, 94.5778, 73.4724], [24.1147, 65.8443, 91.8995, 76.7475], ], ], [ [ [6.3750, 19.3553, 26.4935, 22.5650], [17.0026, 57.8088, 85.5580, 78.0746], [20.7334, 86.5793, 139.4667, 136.4133], [16.8126, 103.0225, 186.4502, 206.9613], ], [ [9.5625, 28.8786, 39.1137, 32.9178], [25.1984, 85.0747, 124.6941, 112.5691], [30.0242, 124.2863, 198.6056, 192.4489], [23.5826, 143.4660, 257.8752, 283.2587], ], ], ], &device, ), offset: TestTensor::from_floats( [ [ [[0.000, 7.5000], [15.0000, 22.5000]], [[0.000, 1.8750], [3.7500, 5.6250]], [[31.3333, 39.1667], [47.0000, 54.8333]], [[7.8333, 9.7917], [11.7500, 13.7083]], [[65.3333, 62.7813], [81.6667, 75.4849]], [[16.3333, -237.8021], [20.4167, -381.7280]], [[102.0000, 110.5000], [119.0000, 127.5000]], [[25.5000, 27.6250], [29.7500, 31.8750]], [[141.3333, 150.1667], [159.0000, 167.8333]], [[35.3333, 37.5417], [39.7500, 41.9583]], [[183.3333, 132.3438], [201.6667, 142.0197]], [[45.8333, -839.6840], [50.4167, -1133.4155]], [[228.0000, 237.5000], [-1336.1562, -1452.1173]], [[57.0000, 59.3750], [40.3090, 41.4141]], [[275.3333, 285.1667], [-1670.5034, -1802.9244]], [[68.8333, 71.2917], [44.0451, 44.9841]], [[325.3333, 174.7396], [-2045.1747, -1090.4585]], [[81.3333, -1844.0659], [46.8090, -1150.2101]], ], [ [[270.000, 277.5000], [285.0000, 292.5000]], [[67.5000, 69.3750], [71.2500, 73.1250]], [[313.3333, 321.1667], [329.0000, 336.8333]], [[78.3333, 80.2917], [82.2500, 84.2083]], [[359.3333, 130.1563], [375.6667, 130.6099]], [[89.8333, -4312.7603], [93.9167, -4893.6035]], [[408.0000, 416.5000], [425.0000, 433.5000]], [[102.0000, 104.1250], [106.2500, 108.3750]], [[459.3333, 468.1667], [477.0000, 485.8333]], [[114.8333, 117.0417], [119.2500, 121.4583]], [[513.3334, 97.9688], [531.6667, 93.8947]], [[128.3333, -6720.3926], [132.9167, -7504.5405]], [[570.000, 579.5000], [-7971.8438, -8251.0850]], [[142.5000, 144.8750], [22.4965, 21.8203]], [[629.3333, 639.1667], [-8948.2334, -9249.6641]], [[157.3333, 159.7917], [15.7743, 14.8695]], [[691.3333, 14.6145], [-9992.9453, -70.4040]], [[172.8333, -9818.5234], [7.4132, -352.0222]], ], ], &device, ), weight: TestTensor::from_floats( [ [ [ [77.7195, 89.8692, 69.0213], [121.0760, 137.0775, 92.2989], [100.0212, 106.5561, 61.1851], ], [ [112.3862, 131.6470, 103.8793], [177.0760, 200.1887, 138.2681], [149.5922, 158.7074, 94.3991], ], ], [ [ [77.7195, 89.8692, 69.0213], [121.0760, 137.0775, 92.2989], [100.0212, 106.5561, 61.1851], ], [ [112.3862, 131.6470, 103.8793], [177.0760, 200.1887, 138.2681], [149.5922, 158.7074, 94.3991], ], ], [ [ [77.7195, 89.8692, 69.0213], [121.0760, 137.0775, 92.2989], [100.0212, 106.5561, 61.1851], ], [ [112.3862, 131.6470, 103.8793], [177.0760, 200.1887, 138.2681], [149.5922, 158.7074, 94.3991], ], ], ], &device, ), mask: TestTensor::from_floats( [ [ [[1299.7499, 1439.4375], [1849.1249, 1988.8125]], [[1528.0834, 1673.9791], [2101.8750, 2247.7708]], [[1771.7500, 1624.9811], [2369.9583, 2099.5039]], [[2183.7500, 2342.0625], [2806.3750, 2964.6875]], [[2464.0833, 2628.6042], [3111.1250, 3275.6458]], [[2759.7500, 1979.2551], [3431.2085, 2390.0286]], [[3241.7498, 3418.6873], [2415.3589, 2500.8682]], [[3574.0835, 3757.2292], [2394.3889, 2471.7510]], [[3921.7500, 2095.5293], [2345.9363, 1199.5048]], ], [ [[5957.2500, 6096.9375], [6506.6250, 6646.3125]], [[6392.5835, 6538.4790], [6966.3750, 7112.2705]], [[6843.2500, 2443.8982], [7441.4585, 2550.9199]], [[7462.2505, 7620.5625], [8084.8745, 8243.1875]], [[7949.5835, 8114.1045], [8596.6250, 8761.1465]], [[8452.2500, 1591.6719], [9123.7080, 1589.9454]], [[9141.2500, 9318.1875], [1414.3584, 1375.1803]], [[9680.5840, 9863.7285], [949.0560, 897.3544]], [[10235.2500, 213.4454], [428.2699, 2.4790]], ], ], &device, ), bias: TestTensor::from_floats([8., 8., 8.], &device), }; test.assert_grads(grads); } #[test] fn test_deform_conv2d_different_kernel_size() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, offset_groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [14.558521, 27.249609, 37.382030, 36.039406], [33.151936, 60.480656, 81.264656, 78.618156], [57.520061, 108.623283, 153.413559, 170.072998], [54.706184, 102.596664, 144.367157, 162.643570], ], [ [25.836353, 48.088451, 65.249161, 62.103317], [56.805233, 102.995605, 136.983124, 131.120911], [96.105408, 179.790192, 250.550934, 272.668793], [90.210945, 167.567917, 232.847275, 257.934692], ], ]], &device, ), offset: TestTensor::from_floats( [[ [ [0.0e+00, 5.355903e+00, 1.171528e+01], [3.124999e-01, 8.000000e+00, 1.000000e+01], [7.500000e-01, 1.400000e+01, 1.600000e+01], [1.312500e+00, 2.000000e+01, 2.200000e+01], ], [ [0.0e+00, 1.736104e-03, 6.944418e-03], [1.606250e+01, 2.000000e+00, 2.500000e+00], [4.425000e+01, 3.500000e+00, 4.000000e+00], [8.456250e+01, 5.000000e+00, 5.500000e+00], ], [ [6.745834e+01, 7.996479e+01, 9.353048e+01], [3.166667e+01, 3.377778e+01, 3.588889e+01], [3.800000e+01, 4.011111e+01, 4.222223e+01], [4.433333e+01, 4.644444e+01, 4.855556e+01], ], [ [5.277777e-01, 5.955827e-01, 6.670526e-01], [7.916667e+00, 8.444445e+00, 8.972222e+00], [9.500000e+00, 1.002778e+01, 1.055556e+01], [1.108333e+01, 1.161111e+01, 1.213889e+01], ], [ [1.547778e+02, 1.751640e+02, 1.518874e+02], [6.000000e+01, 6.222223e+01, 4.989969e+01], [6.666666e+01, 6.888889e+01, 5.432098e+01], [7.333334e+01, 7.555556e+01, 5.860340e+01], ], [ [2.222223e+00, 2.363040e+00, -3.360339e+01], [1.500000e+01, 1.555556e+01, -2.277485e+02], [1.666667e+01, 1.722222e+01, -3.231605e+02], [1.833333e+01, 1.888889e+01, -4.320448e+02], ], [ [2.641250e+02, 2.021189e+02, 0.0e+00], [9.100000e+01, 6.481482e+01, 0.0e+00], [9.800000e+01, 6.863078e+01, 0.0e+00], [1.050000e+02, 7.230093e+01, 0.0e+00], ], [ [5.250000e+00, -7.268316e+01, 0.0e+00], [2.275000e+01, -3.346296e+02, 0.0e+00], [2.450000e+01, -4.611053e+02, 0.0e+00], [2.625000e+01, -6.017269e+02, 0.0e+00], ], [ [4.400000e+01, 1.197778e+02, 1.222222e+02], [4.804860e+01, 1.271111e+02, 1.295556e+02], [5.225000e+01, 1.344444e+02, 1.368889e+02], [-3.138958e+02, -8.007446e+02, -8.507313e+02], ], [ [3.377778e+02, 2.994445e+01, 3.055556e+01], [4.848542e+02, 3.177778e+01, 3.238889e+01], [6.467500e+02, 3.361111e+01, 3.422222e+01], [4.909653e+02, 2.239892e+01, 2.265992e+01], ], [ [1.533333e+02, 1.558889e+02, 1.584444e+02], [1.610000e+02, 1.635556e+02, 1.661111e+02], [1.686667e+02, 1.712222e+02, 1.737778e+02], [-9.952491e+02, -1.054551e+03, -1.115134e+03], ], [ [3.833333e+01, 3.897222e+01, 3.961111e+01], [4.025000e+01, 4.088889e+01, 4.152778e+01], [4.216667e+01, 4.280556e+01, 4.344445e+01], [2.433767e+01, 2.453511e+01, 2.472810e+01], ], [ [1.920000e+02, 1.946667e+02, 8.907407e+01], [2.000000e+02, 2.026667e+02, 9.054632e+01], [2.080000e+02, 2.106667e+02, 9.185186e+01], [-1.272938e+03, -1.343509e+03, -5.811921e+02], ], [ [4.800000e+01, 4.866667e+01, -7.413704e+02], [5.000000e+01, 5.066667e+01, -9.788981e+02], [5.200000e+01, 5.266667e+01, -1.232593e+03], [2.531250e+01, 2.543518e+01, -6.388311e+02], ], [ [2.333333e+02, 8.772182e+01, 0.0e+00], [2.416667e+02, 8.827161e+01, 0.0e+00], [2.500000e+02, 8.864776e+01, 0.0e+00], [-1.587216e+03, -5.535372e+02, 0.0e+00], ], [ [5.833333e+01, -9.011902e+02, 0.0e+00], [6.041667e+01, -1.179988e+03, 0.0e+00], [6.250000e+01, -1.475625e+03, 0.0e+00], [2.489150e+01, -6.213175e+02, 0.0e+00], ], [ [1.964444e+02, 2.802222e+02, 2.831111e+02], [2.055625e+02, 2.888889e+02, 2.917778e+02], [-1.173472e+03, -1.679611e+03, -1.771290e+03], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [1.144889e+03, 7.005556e+01, 7.077778e+01], [1.469646e+03, 7.222223e+01, 7.294444e+01], [5.029167e+02, 2.298823e+01, 2.295062e+01], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [3.240000e+02, 3.270000e+02, 3.300000e+02], [3.330000e+02, 3.360000e+02, 3.390000e+02], [-1.931469e+03, -2.034961e+03, -2.139958e+03], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [8.100000e+01, 8.175000e+01, 8.250000e+01], [8.325000e+01, 8.400000e+01, 8.475000e+01], [1.959376e+01, 1.946614e+01, 1.933334e+01], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [3.733333e+02, 3.764445e+02, 4.480865e+01], [3.826667e+02, 3.857778e+02, 4.185955e+01], [-2.313792e+03, -2.431276e+03, -2.392101e+02], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [9.333333e+01, 9.411111e+01, -1.904932e+03], [9.566667e+01, 9.644444e+01, -2.344715e+03], [1.429166e+01, 1.406212e+01, -3.417283e+02], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [4.253333e+02, 1.636843e+01, 0.0e+00], [4.350000e+02, 1.217279e+01, 0.0e+00], [-2.738517e+03, -4.792887e+01, 0.0e+00], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [1.063333e+02, -2.178747e+03, 0.0e+00], [1.087500e+02, -2.670679e+03, 0.0e+00], [6.947917e+00, -1.629574e+02, 0.0e+00], [0.0e+00, 0.0e+00, 0.0e+00], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [1.856041, 7.203409, 12.833395, 11.969448], [24.236776, 40.125511, 41.396423, 27.642044], [43.613083, 57.508926, 46.093338, 25.174383], ], [ [6.989914, 26.580338, 42.618557, 37.501404], [75.623192, 116.925674, 113.288368, 72.567764], [112.724869, 139.826447, 107.653435, 56.799385], ], ], [ [ [1.856041, 7.203409, 12.833395, 11.969448], [24.236776, 40.125511, 41.396423, 27.642044], [43.613083, 57.508926, 46.093338, 25.174383], ], [ [6.989914, 26.580338, 42.618557, 37.501404], [75.623192, 116.925674, 113.288368, 72.567764], [112.724869, 139.826447, 107.653435, 56.799385], ], ], ], &device, ), mask: TestTensor::from_floats( [[ [ [0.0e+00, 2.677941e+00, 5.857617e+00], [4.015623e+01, 7.759999e+02, 8.492499e+02], [6.637500e+01, 1.067750e+03, 1.141000e+03], [9.865628e+01, 1.359500e+03, 1.432750e+03], ], [ [6.745831e+01, 7.688924e+01, 8.684974e+01], [8.387916e+02, 9.161111e+02, 9.934306e+02], [1.146750e+03, 1.224069e+03, 1.301389e+03], [1.454708e+03, 1.532028e+03, 1.609347e+03], ], [ [1.547778e+02, 1.716607e+02, 1.460455e+02], [9.861667e+02, 1.067556e+03, 8.756536e+02], [1.310333e+03, 1.391722e+03, 1.110864e+03], [1.634500e+03, 1.715889e+03, 1.339339e+03], ], [ [2.641250e+02, 1.993876e+02, 0.0e+00], [1.144875e+03, 8.365740e+02, 0.0e+00], [1.485250e+03, 1.056253e+03, 0.0e+00], [1.825625e+03, 1.268859e+03, 0.0e+00], ], [ [3.800000e+02, 1.047861e+03, 1.137389e+03], [5.276354e+02, 1.404444e+03, 1.493972e+03], [6.826807e+02, 1.761028e+03, 1.850555e+03], [5.038855e+02, 1.256341e+03, 1.304936e+03], ], [ [1.123500e+03, 1.217097e+03, 1.310694e+03], [1.496292e+03, 1.589889e+03, 1.683486e+03], [1.869083e+03, 1.962681e+03, 2.056278e+03], [1.146700e+03, 1.190136e+03, 1.232930e+03], ], [ [1.300000e+03, 1.397667e+03, 6.512036e+02], [1.689000e+03, 1.786667e+03, 8.072734e+02], [2.078000e+03, 2.175667e+03, 9.552593e+02], [1.060781e+03, 1.097745e+03, 4.656539e+02], ], [ [1.487833e+03, 5.672195e+02, 0.0e+00], [1.893042e+03, 6.972655e+02, 0.0e+00], [2.298250e+03, 8.188910e+02, 0.0e+00], [9.472098e+02, 3.238781e+02, 0.0e+00], ], [ [1.216444e+03, 1.792806e+03, 1.898611e+03], [1.536448e+03, 2.214222e+03, 2.320028e+03], [5.177084e+02, 7.256571e+02, 7.493920e+02], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [1.897500e+03, 2.007375e+03, 2.117250e+03], [2.335125e+03, 2.445000e+03, 2.554875e+03], [5.591096e+02, 5.750975e+02, 5.903336e+02], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [2.119333e+03, 2.233278e+03, 2.654414e+02], [2.573167e+03, 2.687111e+03, 2.907444e+02], [3.856317e+02, 3.924502e+02, 3.737657e+01], [0.0e+00, 0.0e+00, 0.0e+00], ], [ [2.352500e+03, 9.009851e+01, 0.0e+00], [2.822542e+03, 7.854909e+01, 0.0e+00], [1.785990e+02, 2.930897e+00, 0.0e+00], [0.0e+00, 0.0e+00, 0.0e+00], ], ]], &device, ), bias: TestTensor::from_floats([12., 12.], &device), }; test.assert_grads(grads); } #[test] fn test_deform_conv2d_different_padding() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 2, padding_2: 3, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, offset_groups: 1, height: 4, width: 4, }; let device = Default::default(); let grads = Grads { x: TestTensor::from_floats( [[ [ [60.633026, 60.906506, 61.179493, 61.451954], [122.557770, 123.088188, 123.618599, 124.149033], [126.801132, 127.331535, 127.861938, 128.392365], [131.044434, 131.574875, 132.105286, 132.635712], ], [ [102.000595, 102.497604, 102.993835, 103.489281], [198.932983, 199.830597, 200.728210, 201.625870], [206.113968, 207.011627, 207.909256, 208.806870], [213.294952, 214.192627, 215.090271, 215.987930], ], ]], &device, ), // => Position 788: 10.421875 != 10.0546875 // diff (rel = +1.79e-2, abs = +3.67e-1), tol (rel = +1.00e-2, abs = +9.77e-4) offset: TestTensor::from_floats( [[ [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 0.0, 0.895062, 14.760561, 17.604168, 20.698063, 22.200424, 0.0, ], [ 0.0, 0.0, 0.687500, 9.500000, 10.0, 10.500000, 10.108797, 0.0, ], [ 0.0, 0.0, 1.113426, 13.500000, 14.000000, 14.499999, 13.645835, 0.0, ], [ 0.0, 0.0, 1.613426, 17.500000, 18.000000, 18.500000, 17.108795, 0.0, ], [ 0.0, 0.0, -12.395836, -122.399445, -130.752319, -139.355469, -131.526810, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 0.0, 0.154321, 0.017506, 0.020833, 0.024450, -0.387539, 0.0, ], [ 0.0, 0.0, 24.187502, 2.375000, 2.500000, 2.625000, -37.863422, 0.0, ], [ 0.0, 0.0, 48.057869, 3.375000, 3.500000, 3.625000, -66.770836, 0.0, ], [ 0.0, 0.0, 80.02312, 4.375000, 4.500000, 4.625000, -103.752319, 0.0, ], [ 0.0, 0.0, 113.215271, 5.107495, 5.219907, 5.332031, -139.725891, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 14.206017, 83.017586, 92.379395, 102.010040, 90.356323, 0.0, 0.0, ], [ 0.0, 6.504737, 35.444443, 35.981483, 36.518517, 29.978970, 0.0, 0.0, ], [ 0.0, 7.668316, 39.740742, 40.277779, 40.814816, 33.071907, 0.0, 0.0, ], [ 0.0, 8.911458, 44.037037, 44.574074, 45.111111, 36.085281, 0.0, 0.0, ], [ 0.0, -57.523048, -274.267914, -289.547089, -305.095093, -248.578552, 0.0, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 9.749230, 0.955354, 0.980994, 1.006945, -13.930464, 0.0, 0.0, ], [ 0.0, 96.046921, 8.861111, 8.995371, 9.129629, -129.920715, 0.0, 0.0, ], [ 0.0, 147.434769, 9.935185, 10.069445, 10.203704, -186.718735, 0.0, 0.0, ], [ 0.0, 207.494781, 11.009259, 11.143518, 11.277778, -252.188889, 0.0, 0.0, ], [ 0.0, 226.050003, 10.153355, 10.252030, 10.350393, -266.255280, 0.0, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 44.224964, 159.898483, 176.651901, 193.692688, 146.270813, 0.0, 0.0, 0.0, ], [ 19.050755, 64.870377, 65.444443, 66.018517, 46.553150, 0.0, 0.0, 0.0, ], [ 21.049385, 69.462967, 70.037033, 70.611115, 49.104595, 0.0, 0.0, 0.0, ], [ 23.133059, 74.055557, 74.629631, 75.203705, 51.570988, 0.0, 0.0, 0.0, ], [ -141.200272, -445.302155, -468.381012, -491.747223, -341.553131, 0.0, 0.0, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 35.665298, 3.505739, 3.556735, 3.608062, -48.756947, 0.0, 0.0, 0.0, ], [ 181.404663, 16.217594, 16.361111, 16.504629, -238.136124, 0.0, 0.0, 0.0, ], [ 263.888885, 17.365742, 17.509258, 17.652779, -326.403656, 0.0, 0.0, 0.0, ], [ 355.643341, 18.513889, 18.657408, 18.800926, -423.941345, 0.0, 0.0, 0.0, ], [ 318.709198, 14.359658, 14.441552, 14.523109, -369.819580, 0.0, 0.0, 0.0, ], ], [ [ 0.0, 0.0, 88.846703, 237.478439, 261.731201, 286.289917, 182.508713, 0.0, ], [ 0.0, 0.0, 37.688015, 94.722221, 95.333328, 95.944450, 57.441605, 0.0, ], [ 0.0, 0.0, 40.562500, 99.611107, 100.222229, 100.833336, 59.410744, 0.0, ], [ 0.0, 0.0, 43.527519, 104.500000, 105.111107, 105.722221, 61.289349, 0.0, ], [ 0.0, 0.0, -258.324371, -618.353943, -649.340271, -680.632507, -397.101013, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 0.0, 76.229431, 7.564093, 7.641718, 7.719699, -102.792252, 0.0, ], [ 0.0, 0.0, 272.015167, 23.680555, 23.833332, 23.986113, -351.944214, 0.0, ], [ 0.0, 0.0, 386.062500, 24.902777, 25.055557, 25.208334, -472.147888, 0.0, ], [ 0.0, 0.0, 509.978149, 26.125000, 26.277777, 26.430555, -602.219971, 0.0, ], [ 0.0, 0.0, 378.410248, 17.123661, 17.187500, 17.250984, -436.000732, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 157.623291, 331.938538, 365.283356, 398.952606, 205.988480, 0.0, 0.0, ], [ 0.0, 66.495949, 130.925934, 131.574066, 132.222229, 64.435974, 0.0, 0.0, ], [ 0.0, 70.396835, 136.111115, 136.759262, 137.407410, 65.672256, 0.0, 0.0, ], [ 0.0, 74.393753, 141.296295, 141.944458, 142.592606, 66.812523, 0.0, 0.0, ], [ 0.0, -432.798035, -827.492065, -867.978455, -908.789368, -425.074158, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 140.150024, 14.043960, 14.152921, 14.262260, -187.656906, 0.0, 0.0, ], [ 0.0, 386.813873, 32.731483, 32.893517, 33.055557, -494.779602, 0.0, 0.0, ], [ 0.0, 538.926697, 34.027779, 34.189816, 34.351852, -653.421875, 0.0, 0.0, ], [ 0.0, 701.505859, 35.324074, 35.486115, 35.648151, -822.530640, 0.0, 0.0, ], [ 0.0, 416.044586, 18.903570, 18.944647, 18.985338, -476.728790, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 249.876541, 435.868500, 479.178772, 522.832031, 207.919815, 0.0, 0.0, 0.0, ], [ 105.417015, 170.611115, 171.296295, 171.981476, 64.750000, 0.0, 0.0, 0.0, ], [ 110.441696, 176.092590, 176.777771, 177.462952, 65.156044, 0.0, 0.0, 0.0, ], [ 115.567902, 181.574066, 182.259247, 182.944458, 65.460571, 0.0, 0.0, 0.0, ], [ -662.743530, -1056.641846, -1107.501953, -1158.704712, -409.510162, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 227.160507, 22.982454, 23.125793, 23.269531, -303.112030, 0.0, 0.0, 0.0, ], [ 518.495178, 42.652779, 42.824074, 42.995369, -657.157410, 0.0, 0.0, 0.0, ], [ 712.252380, 44.023148, 44.194443, 44.365738, -857.817200, 0.0, 0.0, 0.0, ], [ 917.074036, 45.393517, 45.564812, 45.736115, -1069.541626, 0.0, 0.0, 0.0, ], [ 416.581482, 18.997831, 19.013102, 19.027966, -475.031525, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 0.0, 151.750259, 210.166672, 210.888885, 211.611099, 57.506927, 0.0, ], [ 0.0, 0.0, 157.929276, 215.944443, 216.666672, 217.388901, 57.052204, 0.0, ], [ 0.0, 0.0, 164.215271, 221.722229, 222.444458, 223.166672, 56.490482, 0.0, ], [ 0.0, 0.0, -931.783752, -1285.353760, -1346.555908, -1408.119385, -346.739044, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 0.0, 655.669983, 52.541668, 52.722221, 52.902775, -824.946777, 0.0, ], [ 0.0, 0.0, 890.972473, 53.986111, 54.166668, 54.347225, -1067.525024, 0.0, ], [ 0.0, 0.0, 1137.937500, 55.430557, 55.611115, 55.791668, -1321.765625, 0.0, ], [ 0.0, 0.0, 375.580566, 17.180984, 17.169498, 17.157579, -425.993713, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 213.521454, 256.629639, 257.388885, 258.148132, 41.652927, 0.0, 0.0, ], [ 0.0, 221.015625, 262.703705, 263.462982, 264.222229, 40.176598, 0.0, 0.0, ], [ 0.0, 228.622284, 268.777802, 269.537048, 270.296295, 38.587788, 0.0, 0.0, ], [ 0.0, -1285.466797, -1554.254517, -1627.530640, -1701.186646, -228.291397, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 823.380554, 64.157410, 64.347221, 64.537033, -1028.532715, 0.0, 0.0, ], [ 0.0, 1107.296509, 65.675926, 65.865746, 66.055557, -1320.097534, 0.0, 0.0, ], [ 0.0, 1403.473022, 67.194450, 67.384262, 67.574074, -1623.922974, 0.0, 0.0, ], [ 0.0, 288.151398, 13.201796, 13.158524, 13.114797, -323.577820, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 288.790131, 306.574066, 307.370361, 308.166656, 15.734239, 0.0, 0.0, 0.0, ], [ 297.696838, 312.944427, 313.740723, 314.537048, 13.138914, 0.0, 0.0, 0.0, ], [ 306.721527, 319.314819, 320.111115, 320.907410, 10.425544, 0.0, 0.0, 0.0, ], [ -1711.543213, -1844.013062, -1930.236572, -2016.858643, -46.846100, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 1011.358093, 76.643517, 76.842590, 77.041664, -1255.045654, 0.0, 0.0, 0.0, ], [ 1347.466431, 78.236107, 78.435181, 78.634262, -1599.175903, 0.0, 0.0, 0.0, ], [ 1696.433350, 79.828705, 80.027779, 80.226852, -1956.164917, 0.0, 0.0, 0.0, ], [ 146.703568, 6.690874, 6.612756, 6.534196, -159.277222, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ]], &device, ), weight: TestTensor::from_floats( [ [ [ [10.341997, 22.988085, 35.634174], [46.920216, 59.566299, 72.212387], [80.881615, 92.591522, 104.158524], ], [ [29.213360, 68.837769, 108.462166], [143.825104, 183.449509, 223.073944], [228.029373, 256.751740, 283.807098], ], ], [ [ [10.341997, 22.988085, 35.634174], [46.920216, 59.566299, 72.212387], [80.881615, 92.591522, 104.158524], ], [ [29.213360, 68.837769, 108.462166], [143.825104, 183.449509, 223.073944], [228.029373, 256.751740, 283.807098], ], ], ], &device, ), mask: TestTensor::from_floats( [[ [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 0.0, 0.447531, 7.380288, 8.802088, 10.349031, 11.100212, 0.0, ], [ 0.0, 0.0, 44.343754, 584.937439, 639.250000, 693.562439, 683.262756, 0.0, ], [ 0.0, 0.0, 68.390068, 803.437561, 857.750000, 912.062500, 874.698059, 0.0, ], [ 0.0, 0.0, 96.473381, 1021.937500, 1076.250000, 1130.562500, 1062.095947, 0.0, ], [ 0.0, 0.0, 121.302101, 1168.487915, 1218.373779, 1268.134888, 1169.444702, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 0.0, 13.084491, 75.860909, 83.767761, 91.809029, 80.728188, 0.0, 0.0, ], [ 0.0, 118.950417, 649.486084, 707.821777, 766.157410, 658.076599, 0.0, 0.0, ], [ 0.0, 170.660782, 884.171265, 942.506958, 1000.842651, 837.809326, 0.0, 0.0, ], [ 0.0, 226.707260, 1118.856445, 1177.192261, 1235.527710, 1013.205933, 0.0, 0.0, ], [ 0.0, 234.939651, 1106.213867, 1153.415649, 1200.482666, 966.248901, 0.0, 0.0, ], ], [ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [ 42.524002, 153.045700, 168.319275, 183.736511, 138.144653, 0.0, 0.0, 0.0, ], [ 207.319611, 718.432800, 780.791626, 843.150391, 619.975037, 0.0, 0.0, 0.0, ], [ 290.277802, 969.303223, 1031.661987, 1094.020752, 784.421631, 0.0, 0.0, 0.0, ], [ 377.871063, 1220.173584, 1282.532471, 1344.891235, 944.233032, 0.0, 0.0, 0.0, ], [ 328.083038, 1025.494995, 1069.130615, 1112.622192, 766.054932, 0.0, 0.0, 0.0, ], ], [ [ 0.0, 0.0, 88.238174, 235.055206, 258.194336, 281.486389, 178.858536, 0.0, ], [ 0.0, 0.0, 305.575500, 789.868042, 856.250061, 922.631897, 572.466064, 0.0, ], [ 0.0, 0.0, 421.809021, 1056.923584, 1123.305542, 1189.687500, 719.598816, 0.0, ], [ 0.0, 0.0, 542.976746, 1323.979248, 1390.361206, 1456.743042, 861.797302, 0.0, ], [ 0.0, 0.0, 393.291565, 934.439697, 974.010376, 1013.428101, 586.924011, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 157.214920, 330.227448, 362.473419, 394.881653, 203.374420, 0.0, 0.0, ], [ 0.0, 424.340576, 867.495361, 937.900452, 1008.305542, 505.640503, 0.0, 0.0, ], [ 0.0, 578.894897, 1150.736084, 1221.141235, 1291.546265, 630.414001, 0.0, 0.0, ], [ 0.0, 738.682495, 1433.976929, 1504.381958, 1574.787109, 749.954346, 0.0, 0.0, ], [ 0.0, 429.912781, 816.507507, 850.771973, 884.873779, 411.152588, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 249.876541, 434.964233, 477.198730, 519.604675, 206.215576, 0.0, 0.0, 0.0, ], [ 560.309326, 949.520813, 1023.949097, 1098.377319, 422.458344, 0.0, 0.0, 0.0, ], [ 756.768127, 1248.946777, 1323.375000, 1397.803223, 521.289001, 0.0, 0.0, 0.0, ], [ 958.759216, 1548.372803, 1622.800903, 1697.229248, 614.587402, 0.0, 0.0, 0.0, ], [ 428.833923, 679.269775, 707.346252, 735.250916, 258.169373, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 0.0, 707.671387, 1033.687378, 1112.138916, 1190.590210, 328.295044, 0.0, ], [ 0.0, 0.0, 947.779419, 1349.298584, 1427.750000, 1506.201416, 399.438080, 0.0, ], [ 0.0, 0.0, 1193.718872, 1664.909668, 1743.361084, 1821.812500, 464.749847, 0.0, ], [ 0.0, 0.0, 388.737854, 532.503540, 553.962891, 575.241089, 140.658310, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 0.0, 880.797302, 1124.393555, 1206.868042, 1289.342651, 209.627625, 0.0, 0.0, ], [ 0.0, 1169.882812, 1456.189819, 1538.664429, 1621.138916, 247.754730, 0.0, 0.0, ], [ 0.0, 1465.098755, 1787.986084, 1870.460571, 1952.935181, 279.751526, 0.0, 0.0, ], [ 0.0, 297.330719, 356.362152, 369.893524, 383.234344, 50.974621, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ [ 1074.567993, 1219.497681, 1305.995361, 1392.493042, 71.162437, 0.0, 0.0, 0.0, ], [ 1416.214722, 1567.479126, 1653.976929, 1740.474609, 72.689949, 0.0, 0.0, 0.0, ], [ 1764.290771, 1915.460571, 2001.958496, 2088.456055, 67.787628, 0.0, 0.0, 0.0, ], [ 151.018372, 160.055023, 164.776138, 169.298447, 3.865937, 0.0, 0.0, 0.0, ], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ]], &device, ), bias: TestTensor::from_floats([48., 48.], &device), }; test.assert_grads(grads); } struct Conv2dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, dilation_1: usize, dilation_2: usize, groups: usize, offset_groups: usize, height: usize, width: usize, } struct Grads { x: TestTensor<4>, offset: TestTensor<4>, weight: TestTensor<4>, mask: TestTensor<4>, bias: TestTensor<1>, } impl Conv2dTestCase { fn assert_grads(self, expected_grads: Grads) { let out_height = (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1) / self.stride_1 + 1; let out_width = (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1) / self.stride_2 + 1; let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let shape_offset = Shape::new([ self.batch_size, 2 * self.offset_groups * self.kernel_size_1 * self.kernel_size_2, out_height, out_width, ]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, ]); let shape_mask = Shape::new([ self.batch_size, self.offset_groups * self.kernel_size_1 * self.kernel_size_2, out_height, out_width, ]); let device = Default::default(); let weight = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight) .into_data(), &device, ) .require_grad(); let bias = TestAutodiffTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ) .require_grad(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let offset = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device) .reshape::<4, _>(shape_offset.clone()) .into_data(), &device, ) .div_scalar(shape_offset.num_elements() as f32) .require_grad(); let mask = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device) .reshape::<4, _>(shape_mask.clone()) .into_data(), &device, ) .div_scalar(shape_mask.num_elements() as f32) .require_grad(); let output = deform_conv2d( x.clone(), offset.clone(), weight.clone(), Some(mask.clone()), Some(bias.clone()), DeformConvOptions::new( [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], [self.dilation_1, self.dilation_2], self.groups, self.offset_groups, ), ); let grads = output.backward(); // Assert let x_grad_actual = x.grad(&grads).unwrap(); let offset_grad_actual = offset.grad(&grads).unwrap(); let weight_grad_actual = weight.grad(&grads).unwrap(); let mask_grad_actual = mask.grad(&grads).unwrap(); let bias_grad_actual = bias.grad(&grads).unwrap(); // Relative is set to 5%, which is much higher than typical numerical test tolerances. // This is due to the complexity of the deformable convolution operation. // Unlike regular conv2d, which samples from fixed integer grid positions, // deformable conv2d samples input values at fractional offset locations (learned offsets). // These non-integer positions require bilinear interpolation to estimate the input value. // Gradients computed through all these floating-point operations can compound numerical differences. let tolerance = Tolerance::relative(0.5); println!("Testing bias"); expected_grads .bias .to_data() .assert_approx_eq::(&bias_grad_actual.to_data(), tolerance); println!("Testing input"); expected_grads .x .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), tolerance); println!("Testing offset"); expected_grads .offset .to_data() .assert_approx_eq::(&offset_grad_actual.to_data(), tolerance); println!("Testing mask"); expected_grads .mask .to_data() .assert_approx_eq::(&mask_grad_actual.to_data(), tolerance); println!("Testing weight"); expected_grads .weight .to_data() .assert_approx_eq::(&weight_grad_actual.to_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/div.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_div() { let data_1 = TensorData::from([1.0, 7.0]); let data_2 = TensorData::from([4.0, 7.0]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().div(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([0.25, 0.14285715]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([-0.0625, -0.14285715]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_div_scalar() { let data = TensorData::from([1.0, 7.0]); let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad(); let tensor_out = tensor.clone().div_scalar(4.0); let grads = tensor_out.backward(); let grad = tensor.grad(&grads).unwrap(); grad.to_data() .assert_eq(&TensorData::from([0.25, 0.25]), false); } #[test] fn test_div_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().div(tensor_2.clone()); let tensor_5 = tensor_4.div(tensor_3.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let grad_3 = tensor_3.grad(&grads).unwrap(); let expected = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]); grad_3 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_div_complex_2() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.div(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_absolute(2e-3); let expected = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/erf.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_erf() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[8.0, 8.0], [8.0, 8.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/exp.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_exp() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default(); let expected = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/expand.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_expand() { // Python code to generate the test case values // import torch // x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True) // x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True) // y = x1.expand(4, 4) // z = (x2 * y).sum() // z.backward() // print("x1", x1.grad) // print("x2", x2.grad) let device = Default::default(); let data_1 = TensorData::from([4.0, 7.0, 2.0, 3.0]); let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad(); let data_2 = TensorData::from([2.0, 4.5, 7.0, 3.0]); let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().expand([4, 4]); // Use unsqueeze to make tensor_2 have the same shape as tensor_3 let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum(); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([8., 18., 28., 12.]), false); grad_2 .to_data() .assert_eq(&TensorData::from([16., 28., 8., 12.]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/flip.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_flip() { let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2 let data_2 = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3 let device = Default::default(); let tensor_1 = TestAutodiffTensor::<3>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().flip([1, 2]); let tensor_4 = tensor_1.clone().matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); grad_1 .into_data() .assert_approx_eq::(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2 grad_2.into_data().assert_approx_eq::( &TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]), tolerance, ); // 1x2x3 } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/floor.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_floor() { let data = TensorData::from([ [-1.9751, 0.0714, 0.0643, 0.2406], [-1.3172, 0.1252, -0.1119, -0.0127], ]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); let tensor_2 = tensor_1.clone().floor(); let grads = tensor_2.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/gather_scatter.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData}; #[test] fn test_gather_grad() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data( TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), &device, ) .require_grad(); let indices = Tensor::::from_data( TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]), &device, ); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); let tensor_3 = tensor_1.clone().gather(1, indices); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[94., 150., 187.], [242., 305., 304.]]), false, ); } #[test] fn test_scatter_grad() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data( TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), &device, ) .require_grad(); let values = TestAutodiffTensor::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ) .require_grad(); let indices = Tensor::::from_data( TensorData::from([[2, 1, 0], [2, 0, 1]]), &device, ); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); let tensor_3 = tensor_1 .clone() .scatter(1, indices, values.clone(), IndexingUpdateOp::Add); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = values.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[127., 181., 235.], [226., 316., 406.]]), false, ); grad_2 .to_data() .assert_eq(&TensorData::from([[19., 19., 19.], [64., 64., 64.]]), false); } #[test] fn test_scatter_add_grad_partial_indices() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data(TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]), &device) .require_grad(); let tensor_2 = TestAutodiffTensor::from_data(TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), &device) .require_grad(); let values = TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad(); let indices = Tensor::::from_data(TensorData::from([[2, 1, 0]]), &device); let tensor_3 = tensor_1.clone().mul(tensor_2); let tensor_4 = tensor_3 .clone() .scatter(1, indices, values.clone(), IndexingUpdateOp::Add); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = values.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[1., 2., 3., 4., 5., 6.]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[1., 1., 1.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/gelu.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance, activation}; #[test] fn should_diff_gelu() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad(); let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone())); let x = tensor_1.clone().matmul(x); let grads = x.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::permissive(); let expected = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/gradients.rs ================================================ use super::*; use burn_tensor::{Distribution, activation}; #[test] fn should_update_tensor_when_grad_replace() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad(); let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default, &device); let x = tensor_1.clone().matmul(activation::gelu(tensor_2)); let mut grads = x.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_1_updated = TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad(); tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner()); let grad_1_new = tensor_1.grad(&grads).unwrap(); assert_ne!(grad_1_new.to_data(), grad_1.into_data()); assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/log.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_diff_log() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); let expected = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/log1p.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_log1p() { let tensor_1 = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad(); let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); let expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/log_sigmoid.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn should_diff_log_sigmoid() { let data = TensorData::from([[0.8762, -0.1423], [-300., 200.]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); let tensor_2 = activation::log_sigmoid(tensor_1.clone()); let grads = tensor_2.backward(); let grad = tensor_1.grad(&grads).unwrap(); let expected = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/mask.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Bool, Tensor, TensorData}; #[test] fn should_diff_mask_fill() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let mask = TensorData::from([[true, false], [false, true]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let mask = Tensor::::from_bool(mask, &device); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.mask_fill(mask, 2.0); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[7.0, 3.0], [4.0, 2.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[2.0, 1.0], [3.0, 7.0]]), false); } #[test] fn should_diff_mask_where() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad(); let mask = Tensor::::from_data([[true, false], [false, true]], &device); let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_4.clone().matmul(tensor_3.clone()); let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone()); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let grad_3 = tensor_3.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); let expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]); grad_1 .into_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]); grad_2 .into_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[15., 18.], [23., 29.]]); grad_3 .into_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/matmul.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_matmul() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false); tensor_3 .to_data() .assert_eq(&TensorData::from([[18.0, 28.0], [14.0, 23.0]]), false); } #[test] fn test_matmul_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_4.matmul(tensor_3); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[44.0, 20.0], [44.0, 20.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[56.0, 56.0], [16.0, 16.0]]), false); } #[test] fn test_matmul_complex_2() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_4.matmul(tensor_3.clone()); let tensor_6 = tensor_1.clone().matmul(tensor_5); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[800.0, 792.0], [360.0, 592.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[264., 264.0], [344.0, 344.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/maxmin.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_max_dim() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_min_dim() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_min_dim_3d_dim1() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad(); let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); let tensor_4 = tensor_3.min_dim(1); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[[0., -7.], [2., 0.]]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[[0., 7.], [-2., -0.]]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/maxpool1d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::module::max_pool1d; #[test] fn test_max_pool1d_simple() { let kernel_size = 4; let padding = 0; let stride = 1; let dilation = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device); let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool1d_with_dilation() { let kernel_size = 4; let padding = 0; let stride = 1; let dilation = 2; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, 0.4610, 0.5365, 0.6880, ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<3>::from_floats( [[[ 0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0., 0., 0., 1., ]]], &device, ); let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool1d_complex() { let kernel_size = 4; let padding = 0; let stride = 1; let dilation = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, 0.4610, 0.5365, 0.6880, ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<3>::from_floats( [[[ 0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., 1., 1., ]]], &device, ); let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool1d_complex_with_padding() { let kernel_size = 4; let padding = 2; let stride = 1; let dilation = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ 0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073, 0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230, 0.4610, 0.5365, 0.6880, ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<3>::from_floats( [[[ 1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0., 1., 1., 3., ]]], &device, ); let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/maxpool2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::module::max_pool2d; #[test] fn test_max_pool2d_simple_1() { let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 0; let padding_2 = 0; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 1; let dilation_2 = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ [0.2479, 0.6386, 0.3166, 0.5742], [0.7065, 0.1940, 0.6305, 0.8959], [0.5416, 0.8602, 0.8129, 0.1662], [0.3358, 0.3059, 0.8293, 0.0990], ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<4>::from_floats( [[[ [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], ]]], &device, ); let output = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool2d_simple_2() { let kernel_size_1 = 2; let kernel_size_2 = 2; let padding_1 = 1; let padding_2 = 1; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 1; let dilation_2 = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ [0.2479, 0.6386, 0.3166, 0.5742], [0.7065, 0.1940, 0.6305, 0.8959], [0.5416, 0.8602, 0.8129, 0.1662], [0.3358, 0.3059, 0.8293, 0.0990], ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<4>::from_floats( [[[ [1., 3., 0., 2.], [3., 0., 0., 4.], [1., 4., 0., 1.], [2., 0., 3., 1.], ]]], &device, ); let output = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool2d_with_dilation() { let kernel_size_1 = 2; let kernel_size_2 = 2; let padding_1 = 1; let padding_2 = 1; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 2; let dilation_2 = 2; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ [0.2479, 0.6386, 0.3166, 0.5742], [0.7065, 0.1940, 0.6305, 0.8959], [0.5416, 0.8602, 0.8129, 0.1662], [0.3358, 0.3059, 0.8293, 0.0990], ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<4>::from_floats( [[[ [0., 0., 0., 0.], [1., 1., 1., 2.], [0., 4., 4., 0.], [0., 1., 2., 0.], ]]], &device, ); let output = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool2d_complex() { let kernel_size_1 = 4; let kernel_size_2 = 2; let padding_1 = 2; let padding_2 = 1; let stride_1 = 1; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; let device = Default::default(); let x = TestAutodiffTensor::from_floats( [[[ [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], ]]], &device, ) .require_grad(); let x_grad_expected = TestAutodiffTensor::<4>::from_floats( [[[ [0., 0., 0., 3., 0.], [4., 0., 2., 1., 0.], [0., 0., 0., 0., 0.], [2., 4., 0., 0., 0.], [0., 0., 0., 0., 2.], ]]], &device, ); let output = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } #[test] fn test_max_pool2d_ceil_mode() { // Test ceil_mode=true with gradient computation // Using 1x1x6x6 input with kernel 3x3, stride 2x2, padding 0 // Floor mode: output 2x2 // Ceil mode: output 3x3 let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 0; let padding_2 = 0; let stride_1 = 2; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; let device = Default::default(); // Input (values 1-36): let x = TestAutodiffTensor::from_floats( [[[ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0, 17.0, 18.0], [19.0, 20.0, 21.0, 22.0, 23.0, 24.0], [25.0, 26.0, 27.0, 28.0, 29.0, 30.0], [31.0, 32.0, 33.0, 34.0, 35.0, 36.0], ]]], &device, ) .require_grad(); // Expected gradients for ceil_mode output 3x3: // Output positions and their max value positions: // (0,0): max at (2,2)=15 -> grad[2,2] += 1 // (0,1): max at (2,4)=17 -> grad[2,4] += 1 // (0,2): max at (2,5)=18 -> grad[2,5] += 1 // (1,0): max at (4,2)=27 -> grad[4,2] += 1 // (1,1): max at (4,4)=29 -> grad[4,4] += 1 // (1,2): max at (4,5)=30 -> grad[4,5] += 1 // (2,0): max at (5,2)=33 -> grad[5,2] += 1 // (2,1): max at (5,4)=35 -> grad[5,4] += 1 // (2,2): max at (5,5)=36 -> grad[5,5] += 1 let x_grad_expected = TestAutodiffTensor::<4>::from_floats( [[[ [0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 1., 1.], [0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 1., 1.], [0., 0., 1., 0., 1., 1.], ]]], &device, ); let output = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], true, ); let grads = output.backward(); // Asserts let x_grad_actual = x.grad(&grads).unwrap(); x_grad_expected .to_data() .assert_approx_eq::(&x_grad_actual.to_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/memory_management.rs ================================================ use super::*; use burn_tensor::{Tensor, TensorData}; #[test] fn test_mm_independent_trees() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3; let tensor_6 = tensor_4 * tensor_5; // Second tree let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_11 = tensor_7.clone() * tensor_8.clone(); let tensor_12 = tensor_9.clone() * tensor_10.clone(); let tensor_13 = tensor_11 * tensor_12; let _grads = tensor_6.backward(); let grads = tensor_13.backward(); assert!(tensor_7.grad(&grads).is_some()); assert!(tensor_8.grad(&grads).is_some()); assert!(tensor_9.grad(&grads).is_some()); assert!(tensor_10.grad(&grads).is_some()); } #[test] #[should_panic] fn test_mm_crossover_trees_root_unavailable() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3; let tensor_6 = tensor_4.clone() * tensor_5; // Second tree let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_9 = tensor_7.clone() * tensor_8.clone(); let tensor_10 = tensor_4 * tensor_9; let _grads = tensor_6.backward(); let _grads = tensor_10.backward(); } #[test] fn test_mm_crossover_trees_with_referred_subtree() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3; let tensor_6 = tensor_4.clone() * tensor_5; // Second tree let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_9 = tensor_7.clone() * tensor_8.clone(); let _tensor_10 = tensor_4 * tensor_9.clone(); let _grads = tensor_6.backward(); let _grads = tensor_9.backward(); } #[test] fn test_mm_three_crossover_trees_last_still_usable() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3; let tensor_6 = tensor_4 * tensor_5.clone(); // Third tree let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_11 = tensor_7 * tensor_8; let tensor_12 = tensor_9 * tensor_10; let tensor_13 = tensor_11 * tensor_12.clone(); // Second tree (in between) let _tensor_14 = tensor_5 * tensor_12; let _grads = tensor_6.backward(); let _grads = tensor_13.backward(); } #[test] #[should_panic] fn test_mm_three_crossover_trees_middle_one_unavailable() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3; let tensor_6 = tensor_4 * tensor_5.clone(); // Third tree let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_11 = tensor_7 * tensor_8; let tensor_12 = tensor_9 * tensor_10; let _tensor_13 = tensor_11 * tensor_12.clone(); // Second tree (in between) let tensor_14 = tensor_5 * tensor_12; let _grads = tensor_6.backward(); let _grads = tensor_14.backward(); } #[test] fn test_mm_self_referencing_tree() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // First tree let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad(); let tensor_3 = tensor_0 * tensor_1; let tensor_5 = tensor_2 * tensor_3.clone(); let tensor_6 = tensor_3 * tensor_5; let _grads = tensor_6.backward(); } #[test] fn test_mm_with_non_impacting_detach() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_2 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_3 = Tensor::::from_data(data, &device).require_grad(); let tensor_4 = tensor_1.clone() * tensor_2.clone(); let tensor_5 = tensor_4.detach() * tensor_3.clone(); let grads = tensor_5.backward(); assert!(tensor_3.grad(&grads).is_some()); } #[test] fn test_mm_with_missing_require_grad_after_cleanup() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_2 = Tensor::::from_data(data.clone(), &device); let tensor_3 = Tensor::::from_data(data.clone(), &device); let tensor_4 = tensor_1.clone() * tensor_2.clone(); let tensor_5 = tensor_4 * tensor_3.clone(); // Trivial backward, just to trigger cleanup Tensor::::from_data(data, &device) .require_grad() .backward(); let grads = tensor_5.backward(); assert!(tensor_1.grad(&grads).is_some()); assert!(tensor_2.grad(&grads).is_none()); assert!(tensor_3.grad(&grads).is_none()); } #[test] fn test_mm_with_detach_after_cleanup() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_2 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_3 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_4 = tensor_1.clone() * tensor_2.clone(); let tensor_5 = tensor_4 * tensor_3.clone().detach(); // Trivial backward, just to trigger cleanup Tensor::::from_data(data, &device) .require_grad() .backward(); let grads = tensor_5.backward(); assert!(tensor_1.grad(&grads).is_some()); assert!(tensor_2.grad(&grads).is_some()); assert!(tensor_3.grad(&grads).is_none()); } #[test] #[should_panic] fn test_mm_deletables_propagate_well() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); let tensor_0 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_1 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_2 = tensor_0 * tensor_1; let tensor_3 = tensor_2.clone().exp(); let _tensor_4 = tensor_3.clone().log(); let _grads = tensor_2.backward(); // We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but // the intermediate tensor_3 as well let _grads = tensor_3.backward(); } #[test] fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() { let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); // The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative // By repeating it many times it becomes almost impossible that it passes if it shouldn't for _ in 0..12 { let tensor_0 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_1 = Tensor::::from_data(data.clone(), &device).require_grad(); let tensor_2 = tensor_1.clone().exp(); let tensor_3 = tensor_0.exp(); let _tensor_4 = tensor_3.clone() * tensor_2.clone(); let tensor_5 = tensor_2.exp(); let tensor_6 = tensor_5.exp(); let tensor_7 = tensor_6.exp(); let tensor_8 = tensor_7.exp(); // tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8 // which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search tensor_3.backward(); let grads = tensor_8.backward(); assert!(tensor_1.grad(&grads).is_some()); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/mod.rs ================================================ #[allow(unused_imports)] // required for re-included modules pub use super::*; mod abs; mod adaptive_avgpool1d; mod adaptive_avgpool2d; mod add; mod aggregation; mod avgpool1d; mod avgpool2d; mod backward; mod bridge; mod broadcast; mod cast; mod cat; mod ceil; mod checkpoint; mod complex; mod conv1d; mod conv2d; mod conv3d; mod conv_transpose1d; mod conv_transpose2d; mod conv_transpose3d; mod cross; mod cross_entropy; mod cummax; mod cummin; mod cumprod; mod cumsum; mod deform_conv2d; mod div; mod erf; mod exp; mod expand; mod flip; mod floor; mod gather_scatter; mod gelu; mod gradients; mod log; mod log1p; mod log_sigmoid; mod mask; mod matmul; mod maxmin; mod maxpool1d; mod maxpool2d; mod memory_management; mod mul; mod multithread; mod nearest_interpolate; mod neg; mod nonzero; mod permute; mod pow; mod recip; mod relu; mod remainder; mod repeat_dim; mod reshape; mod round; mod select; mod sigmoid; mod sign; mod slice; mod slice_assign; mod softmax; mod sort; mod sqrt; mod sub; mod transpose; mod trig; mod unfold; ================================================ FILE: crates/burn-backend-tests/tests/autodiff/mul.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_mul() { let data_1 = TensorData::from([1.0, 7.0]); let data_2 = TensorData::from([4.0, 7.0]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad(); let tensor_3 = tensor_1.clone().mul(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let _grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_eq(&data_2, false); tensor_3 .into_data() .assert_eq(&TensorData::from([4.0, 49.0]), false); } #[test] fn should_diff_mul_scalar() { let data = TensorData::from([2.0, 5.0]); let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad(); let tensor_out = tensor.clone().mul_scalar(4.0); let grads = tensor_out.backward(); let grad = tensor.grad(&grads).unwrap(); tensor_out .into_data() .assert_eq(&TensorData::from([8.0, 20.0]), false); grad.to_data() .assert_eq(&TensorData::from([4.0, 4.0]), false); } #[test] fn test_mul_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().mul(tensor_2.clone()); let tensor_5 = tensor_4.mul(tensor_3); let tensor_6 = tensor_1.clone().mul(tensor_5); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[16.0, 196.0], [104.0, -36.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[2.0, 98.0], [338.0, 18.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/multithread.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_behave_the_same_with_multithread() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let with_move = || { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_4.matmul(tensor_3); // Task 1 let tensor_1_cloned = tensor_1.clone(); let tensor_2_cloned = tensor_2.clone(); let tensor_5_cloned = tensor_5.clone(); let first_call = move || { let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned); tensor_6_1.matmul(tensor_1_cloned) }; // Task 2 let tensor_1_cloned = tensor_1.clone(); let tensor_2_cloned = tensor_2.clone(); let tensor_5_cloned = tensor_5; let second_call = move || { let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned); tensor_6_2.matmul(tensor_2_cloned) }; let tensor_7_1_handle = std::thread::spawn(first_call); let tensor_7_2_handle = std::thread::spawn(second_call); let tensor_7_1 = tensor_7_1_handle.join().unwrap(); let tensor_7_2 = tensor_7_2_handle.join().unwrap(); let tensor_8 = tensor_7_1.matmul(tensor_7_2); let grads = tensor_8.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); (grad_1, grad_2) }; let without_move = || { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_4.matmul(tensor_3); // Task 1 let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone()); let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone()); // Task 2 let tensor_6_2 = tensor_5.matmul(tensor_1.clone()); let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone()); let tensor_8 = tensor_7_1.matmul(tensor_7_2); let grads = tensor_8.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); (grad_1, grad_2) }; let (grad_1, grad_2) = without_move(); let (grad_1_moved, grad_2_moved) = with_move(); grad_1 .into_data() .assert_approx_eq::(&grad_1_moved.into_data(), Tolerance::default()); grad_2 .into_data() .assert_approx_eq::(&grad_2_moved.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/nearest_interpolate.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::interpolate; use burn_tensor::ops::{InterpolateMode, InterpolateOptions}; #[test] fn test_upsample_interpolation() { let test = InterpolateTestCase { batch_size: 2, channels: 1, height: 7, width: 5, height_out: 8, width_out: 7, }; test.assert_output(TestTensor::from([ [[ [4., 2., 4., 2., 2.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], ]], [[ [4., 2., 4., 2., 2.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], [2., 1., 2., 1., 1.], ]], ])); } #[test] fn test_downsample_interpolation() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 8, width: 8, height_out: 4, width_out: 6, }; test.assert_output(TestTensor::from([[[ [1., 1., 1., 0., 1., 1., 1., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [1., 1., 1., 0., 1., 1., 1., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [1., 1., 1., 0., 1., 1., 1., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], [1., 1., 1., 0., 1., 1., 1., 0.], [0., 0., 0., 0., 0., 0., 0., 0.], ]]])); } struct InterpolateTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl InterpolateTestCase { fn assert_output(self, x_grad: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let device = Default::default(); let x = TestAutodiffTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device()) .reshape::<4, _>(shape_x) .into_data(), &device, ) .require_grad(); let output = interpolate( x.clone(), [self.height_out, self.width_out], InterpolateOptions::new(InterpolateMode::Nearest), ); let grads = output.backward(); let x_grad_actual = x.grad(&grads).unwrap(); x_grad .to_data() .assert_approx_eq::(&x_grad_actual.into_data(), Tolerance::permissive()); } } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/neg.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_neg() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg()); let tensor_4 = tensor_3.neg(); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/nonzero.rs ================================================ use super::*; use burn_tensor::{Bool, Tensor, TensorData}; #[test] fn should_diff_nonzero() { let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let data_2 = TensorData::from([-1.0, 1.0]); let mask = TensorData::from([[false, true], [true, false]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad(); // Multi-dimensional tensor indexing isn't really supported yet so the easiest way to do // this is to flatten the mask and tensor to get proper indexing. Anyway the returned tensor would // have dimensions different from the input, so this is somewhat equivalent. let mask = Tensor::::from_bool(mask, &device).flatten::<1>(0, 1); let indices = mask.nonzero(); let tensor_3 = tensor_1 .clone() .flatten::<1>(0, 1) .select(0, indices[0].clone()); // Vector dot product not supported (only 2D matmuls) so unsqueeze for test purposes let tensor_4 = tensor_2 .clone() .unsqueeze_dim::<2>(0) .matmul(tensor_3.unsqueeze_dim(1)); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[0.0, -1.0], [1.0, 0.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([2.0, 3.0]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/permute.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_permute() { let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2 let data_2 = TensorData::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2 let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().permute([0, 2, 1]); let tensor_4 = tensor_1.clone().matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); grad_1 .into_data() .assert_approx_eq::(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2 grad_2.into_data().assert_approx_eq::( &TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]), tolerance, ); // 1x3x2 } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/pow.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_powf_scalar() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf_scalar(0.4)); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(2e-3); let expected = TensorData::from([[68.0, 79.0328], [68.0, 79.0328]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[23.5081, 25.2779], [26.0502, 28.6383]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn should_diff_powf() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad(); let tensor_3 = tensor_1.clone().powf(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([32.0, 14.0]); grad_1 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([11.09035, 95.34960]); grad_2 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([16.0, 49.0]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_powf_with_untracked_lhs() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device); let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad(); let tensor_3 = tensor_1.clone().powf(tensor_2.clone()); let grads = tensor_3.backward(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([11.09035, 95.34960]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_powf_with_untracked_rhs() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device); let tensor_3 = tensor_1.clone().powf(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let expected = TensorData::from([32.0, 14.0]); grad_1 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/recip.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_recip() { let data = TensorData::from([2.0, 5.0, 0.4]); let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad(); let tensor_out = tensor.clone().recip(); let grads = tensor_out.backward(); let grad = tensor.grad(&grads).unwrap(); tensor_out .into_data() .assert_eq(&TensorData::from([0.5, 0.2, 2.5]), false); grad.to_data().assert_approx_eq::( &TensorData::from([-0.25, -0.04, -6.25]), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/relu.rs ================================================ use super::*; use burn_tensor::{TensorData, activation}; #[test] fn should_diff_relu() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = activation::relu(tensor_3); let tensor_5 = tensor_4.matmul(tensor_2.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[-47.0, 9.0], [-35.0, 15.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[15.0, 13.0], [-2.0, 39.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/remainder.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_remainder() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data( TensorData::from([ 0.9742, 0.3676, 0.0905, 0.8066, 0.7072, 0.7883, 0.6987, 0.1560, 0.7179, 0.7874, 0.9032, 0.1845, ]), &device, ) .require_grad(); let tensor_2 = TestAutodiffTensor::<1>::from_data( TensorData::from([ 0.3357, 0.0285, 0.4115, 0.5511, 0.8637, 0.3593, 0.3885, 0.2569, 0.0936, 0.7172, 0.4792, 0.4898, ]), &device, ) .require_grad(); let tensor_3 = tensor_1.clone().remainder(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([ -2.0, -12.0, -0.0, -1.0, -0.0, -2.0, -1.0, -0.0, -7.0, -1.0, -1.0, -0.0, ]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/repeat_dim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_repeat() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0], [2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().repeat_dim(1, 3); let tensor_3 = tensor_1.matmul(tensor_3); let grads = tensor_3.backward(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_2 .to_data() .assert_eq(&TensorData::from([[-3.0], [12.0]]), false); } #[test] fn should_diff_repeat_multi_dim() { let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]); let data_2 = TensorData::from([[4.0, 2.0], [2.0, 4.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().repeat_dim(1, 3); let tensor_3 = tensor_1.matmul(tensor_3); let grads = tensor_3.backward(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_2 .to_data() .assert_eq(&TensorData::from([[-3.0, -3.0], [12.0, 12.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/reshape.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_reshape() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([4.0, 7.0, 2.0, 3.0]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().reshape([2, 2]); let tensor_4 = tensor_1.clone().matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([3.0, 3.0, 10.0, 10.0]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/round.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_round() { let data = TensorData::from([ [-1.9751, 0.0714, 0.0643, 0.2406], [-1.3172, 0.1252, -0.1119, -0.0127], ]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); let tensor_2 = tensor_1.clone().round(); let grads = tensor_2.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/select.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData}; #[test] fn test_select_grad() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data( TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), &device, ) .require_grad(); let indices = Tensor::::from_data(TensorData::from([1, 0]), &device); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); let tensor_3 = tensor_1.clone().select(0, indices); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); grad_1.into_data().assert_eq( &TensorData::from([[109., 148., 187.], [37., 58., 79.]]), false, ); } #[test] fn test_select_add_grad() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data( TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), &device, ) .require_grad(); let values = TestAutodiffTensor::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ) .require_grad(); let indices = Tensor::::from_data(TensorData::from([1, 0]), &device); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); let tensor_3 = tensor_1 .clone() .select_assign(0, indices, values.clone(), IndexingUpdateOp::Add); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = values.grad(&grads).unwrap(); grad_1.into_data().assert_eq( &TensorData::from([[127., 199., 271.], [172., 244., 316.]]), false, ); grad_2 .into_data() .assert_eq(&TensorData::from([[64., 64., 64.], [19., 19., 19.]]), false); } #[test] fn test_select_add_grad_different_shapes() { let device = Default::default(); let indices: Tensor = Tensor::from_ints([1], &device); let x: Tensor = Tensor::ones([1, 1], &device).require_grad(); let y = Tensor::ones([2, 1], &device).require_grad(); let w = y .clone() .select_assign(0, indices, x.clone(), IndexingUpdateOp::Add); let w = w.matmul(y.clone().transpose()); let grads = w.backward(); let x_grad = x.grad(&grads).unwrap(); let y_grad = y.grad(&grads).unwrap(); x_grad .into_data() .assert_eq(&TensorData::from([[2.0]]), false); y_grad .into_data() .assert_eq(&TensorData::from([[5.0], [5.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/sigmoid.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn should_diff_sigmoid() { let data = TensorData::from([0.8762]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad(); let tensor_2 = activation::sigmoid(tensor_1.clone()); let grads = tensor_2.backward(); let grad = tensor_1.grad(&grads).unwrap(); let expected = TensorData::from([0.207549]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn small_neg_val_should_not_cause_grad_overflow() { let data = TensorData::from([-90.0]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad(); let tensor_2 = activation::sigmoid(tensor_1.clone()); let grads = tensor_2.backward(); let grad = tensor_1.grad(&grads).unwrap(); let expected = TensorData::from([0.0]); grad.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/sign.rs ================================================ use super::*; use burn_tensor::TensorData; /// Example using the sign function with PyTorch: // >>> import torch // >>> # Create a tensor with requires_grad=True // >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) // >>> # Forward pass: Apply the sign function // >>> y = torch.sign(x) // >>> print("Forward pass:") // Forward pass: // >>> print("x:", x) // x: tensor([-2., -1., 0., 1., 2.], requires_grad=True) // >>> print("y:", y) // y: tensor([-1., -1., 0., 1., 1.], grad_fn=) // >>> # Compute the loss (just an example) // >>> loss = y.sum() // >>> # Backward pass: Compute the gradients // >>> loss.backward() // >>> print("\nBackward pass:") // Backward pass: // >>> print("x.grad:", x.grad) // x.grad: tensor([0., 0., 0., 0., 0.]) #[test] fn should_diff_sign() { let data = TensorData::from([-2.0, -1.0, 0.0, 1.0, 2.0]); let device = Default::default(); let x = TestAutodiffTensor::<1>::from_data(data, &device).require_grad(); let y = x.clone().sign(); let loss = y.clone().sum(); let grads = loss.backward(); let grad = x.grad(&grads).unwrap(); y.to_data() .assert_eq(&TensorData::from([-1., -1., 0., 1., 1.]), false); grad.to_data() .assert_eq(&TensorData::from([0., 0., 0., 0., 0.]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/slice.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_matmul_with_slice() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_2.clone().slice([0..2, 0..2]); let tensor_4 = tensor_1.clone().matmul(tensor_3); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false); grad_2.to_data().assert_eq( &TensorData::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]), false, ); } #[test] fn should_diff_matmul_with_slice_stepped() { use burn_tensor::s; let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]); let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]] let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]] let tensor_5 = tensor_3.clone().matmul(tensor_4); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_eq( &TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]), false, ); grad_2.to_data().assert_eq( &TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]), false, ); } #[test] fn should_panic_on_slice_with_step() { use burn_tensor::s; let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]); let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); // This should panic because step is 2 let _sliced = tensor.slice(s![.., 0..4; 2]); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/slice_assign.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_matmul_with_slice_assign() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_assigned = TensorData::from([[9.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_assigned = TestAutodiffTensor::from_data(data_assigned, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned); let tensor_5 = tensor_4.matmul(tensor_1.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[58.0, 38.0], [118.0, 82.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[16.0, 15.0], [24.0, 50.0]]), false); } #[test] fn should_diff_matmul_with_slice_assign_complex() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[9.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_5 = tensor_2.clone().slice([0..1, 0..1]); let tensor_6 = tensor_5.mul(tensor_3.clone()); let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6); let tensor_8 = tensor_7.matmul(tensor_1.clone()); let grads = tensor_8.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let grad_3 = tensor_3.grad(&grads).unwrap(); grad_3 .to_data() .assert_eq(&TensorData::from([[32.0]]), false); grad_1 .to_data() .assert_eq(&TensorData::from([[85.0, 65.0], [118.0, 82.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[88.0, 15.0], [24.0, 50.0]]), false); } #[test] fn slice_assign_diff_should_give_same_results_as_cat() { let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let data_2 = TensorData::from([[5.0, 6.0], [7.0, 8.0]]); let data_3 = TensorData::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device); let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default()); let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone()); let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone()); let slice_assign_output = slice_assign_output / tensor_3.clone(); let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1); let cat_output = cat_output / tensor_3; slice_assign_output .to_data() .assert_approx_eq::(&cat_output.to_data(), Tolerance::default()); let slice_assign_grads = slice_assign_output.backward(); let cat_grads = cat_output.backward(); let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap(); let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap(); let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap(); let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap(); slice_assign_grad_1 .to_data() .assert_approx_eq::(&cat_grad_1.to_data(), Tolerance::default()); slice_assign_grad_2 .to_data() .assert_approx_eq::(&cat_grad_2.to_data(), Tolerance::default()); } #[test] fn should_diff_slice_assign_with_step() { use burn_tensor::s; let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]); let value_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad(); // Assign with step=2 let result = tensor.clone().slice_assign(s![.., 0..4; 2], value.clone()); let result = result * 2.0; // Scale to create gradients let grads = result.backward(); let grad_tensor = tensor.grad(&grads).unwrap(); let grad_value = value.grad(&grads).unwrap(); // The gradient for tensor should be 2.0 everywhere except the assigned positions grad_tensor.to_data().assert_eq( &TensorData::from([[0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0]]), false, ); // The gradient for value should be 2.0 at all positions grad_value .to_data() .assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false); } #[test] fn should_diff_slice_assign_with_negative_step() { use burn_tensor::s; let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]); let value_data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]); let device = Default::default(); let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad(); // Assign with step=-1 (reverse order, all elements) let result = tensor.clone().slice_assign(s![.., ..;-1], value.clone()); let result = result * 2.0; // Scale to create gradients let grads = result.backward(); let grad_tensor = tensor.grad(&grads).unwrap(); let grad_value = value.grad(&grads).unwrap(); // The gradient for tensor should be 0 since all values were replaced grad_tensor.to_data().assert_eq( &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), false, ); // The gradient for value should be 2.0 at all positions grad_value.to_data().assert_eq( &TensorData::from([[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/softmax.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Tensor, TensorData, activation}; #[test] fn test_softmax_grad() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data_1, &device).require_grad(); let tensor_2 = Tensor::::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(0.05, 0.5)); let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(0.05, 0.05)); } #[test] fn test_log_softmax_grad() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data_1, &device).require_grad(); let tensor_2 = Tensor::::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]); // f16 gradients from log-softmax + matmul amplify error, so we increase the tolerance // to account for limited precision and large representable step sizes in this range. let tolerance = Tolerance::permissive().set_half_precision_relative(6e-2); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[30.5984, -47.2267], [55.9631, -56.5914]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_quiet_softmax_grad() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = Tensor::::from_data(data_1, &device).require_grad(); let tensor_2 = Tensor::::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]); // Precision is quite bad yet on softmax grad especially with half precision. let tolerance = Tolerance::rel_abs(0.5, 0.2); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/sort.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_sort() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1)); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_sort_with_indices() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let (values, _indices) = tensor_3.sort_with_indices(1); let tensor_4 = tensor_1.clone().mul(values); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_diff_sort_3d_dim1() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_floats([[[4.0, -7.0], [2.0, 3.0]]], &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone()); let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1)); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let expected = TensorData::from([[[-1., -8.], [-27., 37.]]]); grad_1 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[[-4., -17.], [-17., -42.]]]); grad_2 .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/sqrt.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_sqrt() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); let expected = TensorData::from([[82.112640, 99.083275], [82.112640, 99.083275]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[30.309311, 33.120457], [34.581974, 38.769463]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/sub.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_diff_sub() { let data_1 = TensorData::from([2.0, 5.0]); let data_2 = TensorData::from([4.0, 1.0]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().sub(tensor_2.clone()); let grads = tensor_3.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([1.0, 1.0]), false); grad_2 .to_data() .assert_eq(&TensorData::from([-1.0, -1.0]), false); tensor_3 .into_data() .assert_eq(&TensorData::from([-2.0, 4.0]), false); } #[test] fn should_diff_sub_scalar() { let data = TensorData::from([2.0, 10.0]); let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad(); let tensor_out = tensor.clone().sub_scalar(5.0); let grads = tensor_out.backward(); let grad = tensor.grad(&grads).unwrap(); grad.to_data() .assert_eq(&TensorData::from([1.0, 1.0]), false); tensor_out .into_data() .assert_eq(&TensorData::from([-3.0, 5.0]), false); } #[test] fn test_sub_complex_1() { let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1.clone().sub(tensor_2.clone()); let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0); let tensor_6 = tensor_1.clone().sub(tensor_5); let grads = tensor_6.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1 .to_data() .assert_eq(&TensorData::from([[0.0, 0.0], [0.0, 0.0]]), false); grad_2 .to_data() .assert_eq(&TensorData::from([[1.0, 1.0], [1.0, 1.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/transpose.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_transpose() { let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]); let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose()); let tensor_4 = tensor_3.transpose(); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[6.0, 10.0], [6.0, 10.0]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[3.0, 10.0], [3.0, 10.0]]), Tolerance::default(), ); } #[test] fn should_diff_swap_dims() { let device = Default::default(); let tensor_1 = TestAutodiffTensor::<3>::from_floats( [[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]], &device, ) .require_grad(); let tensor_2 = TestAutodiffTensor::from_floats( [[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]], &device, ) .require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2)); let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2)); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/trig.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_diff_cos() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); // Metal has less precise trigonometric functions let tolerance = Tolerance::default().set_half_precision_relative(1e-2); grad_1.to_data().assert_approx_eq::( &TensorData::from([[26.8063, -27.7870], [26.8063, -27.7870]]), tolerance, ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[9.222064, -39.123375], [-28.721354, 49.748356]]), tolerance, ); } #[test] fn should_diff_sin() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); // Metal has less precise trigonometric functions let tolerance = Tolerance::default().set_half_precision_relative(1e-2); let expected = TensorData::from([[8.8500, -4.9790], [8.8500, -4.9790]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[38.668987, 44.194775], [-59.97261, -80.46094]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn should_diff_tanh() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let tolerance = Tolerance::default().set_half_precision_relative(8e-3); let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]); grad_1 .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[8.00092, 8.000153], [8.000003, 7.999995]]); grad_2 .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn should_diff_cosh() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cosh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[7.092221, 16.696301], [7.092221, 16.696301]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[17.489855, 27.484539], [39.409813, 86.910278]]), Tolerance::default(), ); } #[test] fn should_diff_sinh() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sinh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[4.894847, 15.887931], [4.894847, 15.887931]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[17.284000, 28.412029], [39.302979, 87.498329]]), Tolerance::default(), ); } #[test] fn should_diff_tan() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [0.3, 0.8]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tan()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[2.532602, 1.596607], [2.532602, 1.596607]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[9.028598, 14.489801], [18.038082, 21.151270]]), Tolerance::default(), ); } #[test] fn should_diff_asin() { let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]); let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asin()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[0.435841, 0.969651], [0.435841, 0.969651]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[0.475300, 0.668141], [0.701834, 1.100658]]), Tolerance::default(), ); } #[test] fn should_diff_acos() { let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]); let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acos()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[2.077433, 1.543624], [2.077433, 1.543624]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[0.781337, 0.588496], [0.554804, 0.155979]]), Tolerance::default(), ); } #[test] fn should_diff_atan() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atan()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[3.444365, 5.349211], [3.444365, 5.349211]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[9.904911, 11.554912], [10.199631, 11.391938]]), Tolerance::default(), ); } #[test] fn should_diff_asinh() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asinh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[3.806625, 6.844869], [3.806625, 6.844869]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[11.442373, 14.842072], [14.022551, 17.688538]]), Tolerance::default(), ); } #[test] fn should_diff_acosh() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[1.5, 2.0], [2.5, 3.0]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acosh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[10.611752, 15.178907], [10.611752, 15.178907]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[20.112753, 20.247547], [20.402235, 22.487328]]), Tolerance::default(), ); } #[test] fn should_diff_atanh() { let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]); let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atanh()); let tensor_4 = tensor_3.matmul(tensor_2.clone()); let grads = tensor_4.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[0.441838, 1.037115], [0.441838, 1.037115]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[0.491723, 0.698110], [0.772763, 1.298805]]), Tolerance::default(), ); } #[test] fn should_diff_atan2() { let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]); let data_3 = TensorData::from([[1.0, 0.5], [2.0, 1.5]]); let device = Default::default(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad(); let tensor_4 = tensor_1 .clone() .matmul(tensor_2.clone().atan2(tensor_3.clone())); let tensor_5 = tensor_4.matmul(tensor_2.clone()); let grads = tensor_5.backward(); let grad_1 = tensor_1.grad(&grads).unwrap(); let grad_2 = tensor_2.grad(&grads).unwrap(); let grad_3 = tensor_3.grad(&grads).unwrap(); grad_1.to_data().assert_approx_eq::( &TensorData::from([[4.570492, 4.210785], [4.570492, 4.210785]]), Tolerance::default(), ); grad_2.to_data().assert_approx_eq::( &TensorData::from([[8.208448, 8.808449], [10.357923, 12.157923]]), Tolerance::default(), ); grad_3.to_data().assert_approx_eq::( &TensorData::from([[-1.8, -8.4], [-1.8, -5.6]]), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff/unfold.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn unfold_backward_accumulates_overlaps() { let device = Default::default(); let x = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device).require_grad(); let y = x.clone().unfold::<3, _>(1, 2, 1); let loss = y.sum(); let grads = loss.backward(); let grad_x = x.grad(&grads).unwrap(); grad_x .to_data() .assert_eq(&TensorData::from([[1., 2., 2., 1.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/autodiff.rs ================================================ //! Burn autodiff tests. #![allow( clippy::single_range_in_vec_init, clippy::duplicate_mod, reason = "false positive" )] extern crate alloc; pub type FloatElemType = f32; #[allow(unused)] pub type IntElemType = i32; #[path = "common/backend.rs"] mod backend; pub use backend::*; #[allow(clippy::module_inception)] #[path = "common/autodiff.rs"] mod autodiff; ================================================ FILE: crates/burn-backend-tests/tests/common/autodiff.rs ================================================ // Burn autodiff tests, reusable with element types. pub use super::*; #[path = "../autodiff/mod.rs"] mod base; mod checkpointing { pub use super::*; use burn_autodiff::checkpoint::strategy::BalancedCheckpointing; // Override type def pub type TestAutodiffBackend = Autodiff; pub type TestAutodiffTensor = Tensor; include!("../autodiff/mod.rs"); } use burn_backend_tests::test_float_elem_variant; // NOTE: this currently doesn't test checkpointing with different dtypes test_float_elem_variant!( f16, burn_tensor::f16, "../autodiff/mod.rs", ["vulkan", "cuda", "rocm", "metal"] ); // TODO: bf16 not yet supported on any backend for full test suite // test_float_elem_variant!( // bf16, // burn_tensor::bf16, // "../autodiff/mod.rs", // [] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul, metal/wgpu doesn't support bf16 // ); ================================================ FILE: crates/burn-backend-tests/tests/common/backend.rs ================================================ // Re-export use super::FloatElemType; // Default #[cfg(feature = "ndarray")] pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "tch")] pub type TestBackend = burn_tch::LibTorch; #[cfg(feature = "cuda")] pub type TestBackend = burn_cuda::Cuda; #[cfg(feature = "rocm")] pub type TestBackend = burn_rocm::Rocm; #[cfg(feature = "wgpu")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "cpu")] pub type TestBackend = burn_cpu::Cpu; #[cfg(feature = "router")] pub type TestBackend = burn_router::BackendRouter< burn_router::DirectByteChannel<(burn_ndarray::NdArray, burn_wgpu::Wgpu)>, >; /// Collection of types used across tests #[allow(unused)] pub mod prelude { pub use burn_autodiff::Autodiff; pub use burn_tensor::Tensor; use super::*; pub type TestTensor = Tensor; pub type TestTensorInt = Tensor; pub type TestTensorBool = Tensor; pub type FloatElem = burn_tensor::ops::FloatElem; pub type IntElem = burn_tensor::ops::IntElem; pub type TestAutodiffBackend = Autodiff; pub type TestAutodiffTensor = Tensor; } #[allow(unused)] pub use prelude::*; ================================================ FILE: crates/burn-backend-tests/tests/common/tensor.rs ================================================ // Burn backend tensor tests, reusable with element types. pub use super::*; #[path = "../tensor/clone_invariance.rs"] mod clone_invariance; #[cfg(feature = "std")] #[path = "../tensor/multi_threads.rs"] mod multi_threads; // Default float dtype #[path = "../tensor/float/mod.rs"] mod float; // Default integer dtype #[path = "../tensor/int/mod.rs"] mod int; // Default bool dtype #[path = "../tensor/bool/mod.rs"] mod bool; use burn_backend_tests::test_float_elem_variant; test_float_elem_variant!( f16, burn_tensor::f16, "../tensor/float/mod.rs", ["vulkan", "cuda", "rocm", "metal"] ); // TODO: bf16 not yet supported on any backend for full test suite // test_float_elem_variant!( // bf16, // burn_tensor::bf16, // "../tensor/float/mod.rs", // [] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul, metal/wgpu doesn't support bf16 // ); ================================================ FILE: crates/burn-backend-tests/tests/cubecl/avg_pool2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ Distribution, Tensor, TensorPrimitive, backend::Backend, module, ops::ModuleOps, }; #[test] fn avg_pool2d_should_match_reference_backend() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, &Default::default(), ); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let kernel_size = [3, 4]; let stride = [1, 2]; let padding = [1, 2]; let count_include_pad = true; let pooled = module::avg_pool2d( tensor, kernel_size, stride, padding, count_include_pad, false, ); let pooled_ref = module::avg_pool2d( tensor_ref, kernel_size, stride, padding, count_include_pad, false, ); pooled .into_data() .assert_approx_eq::(&pooled_ref.into_data(), Tolerance::default()); } #[test] fn avg_pool2d_backward_should_match_reference_backend() { let device = Default::default(); TestBackend::seed(&device, 0); ReferenceBackend::seed(&Default::default(), 0); let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default, &device); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let kernel_size = [3, 3]; let stride = [1, 1]; let padding = [1, 1]; let count_include_pad = true; let shape_out = module::avg_pool2d( tensor.clone(), kernel_size, stride, padding, count_include_pad, false, ) .shape(); let grad_output = Tensor::::random(shape_out, Distribution::Default, &Default::default()); let grad_output_ref = Tensor::::from_data(grad_output.to_data(), &Default::default()); let grad: Tensor = Tensor::from_primitive(TensorPrimitive::Float(TestBackend::avg_pool2d_backward( tensor.into_primitive().tensor(), grad_output.into_primitive().tensor(), kernel_size, stride, padding, count_include_pad, false, ))); let grad_ref: Tensor = Tensor::from_primitive(TensorPrimitive::Float( ReferenceBackend::avg_pool2d_backward( tensor_ref.into_primitive().tensor(), grad_output_ref.into_primitive().tensor(), kernel_size, stride, padding, count_include_pad, false, ), )); grad.into_data() .assert_approx_eq::(&grad_ref.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/bernoulli.rs ================================================ use super::*; use serial_test::serial; use core::f32; use burn_tensor::{Distribution, Shape, Tensor, backend::Backend}; use cubek::random::{assert_number_of_1_proportional_to_prob, assert_wald_wolfowitz_runs_test}; #[test] #[serial] fn number_of_1_proportional_to_prob() { let device = Default::default(); TestBackend::seed(&device, 0); let shape: Shape = [40, 40].into(); let prob = 0.7; let tensor = Tensor::::random(shape.clone(), Distribution::Bernoulli(prob), &device) .into_data(); let numbers = tensor .as_slice::<::FloatElem>() .unwrap(); assert_number_of_1_proportional_to_prob(numbers, prob as f32); } #[test] #[serial] fn wald_wolfowitz_runs_test() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = Shape::new([512, 512]); let device = Default::default(); let tensor = Tensor::::random(shape, Distribution::Bernoulli(0.5), &device); let data = tensor.into_data(); let numbers = data .as_slice::<::FloatElem>() .unwrap(); // High bound slightly over 1 so 1.0 is included in second bin assert_wald_wolfowitz_runs_test(numbers, 0., 1.1); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/cast.rs ================================================ use super::*; use burn_tensor::{Int, Tensor, TensorData}; #[test] fn should_cast_int_to_float() { const START: usize = 0; const END: usize = 100; let device = Default::default(); let tensor = Tensor::::arange(START as i64..END as i64, &device); let data_int = tensor.to_data(); let data_int = data_int.as_slice::().unwrap(); let data_float = tensor.float().into_data(); let data_float = data_float.as_slice::().unwrap(); for i in START..END { assert_eq!(data_int[i], i as i32); assert_eq!(data_float[i], i as f32); } } #[test] fn should_cast_bool_to_int() { let device = Default::default(); let tensor_1 = Tensor::::from_floats([[1., 0., 3.], [0., 0., 900.]], &device); let tensor_2: Tensor = tensor_1.clone().greater_elem(0.0).int(); tensor_2 .to_data() .assert_eq(&TensorData::from([[1, 0, 1], [0, 0, 1]]), false); } #[test] fn should_cast_bool_to_float() { let device = Default::default(); let tensor_1 = Tensor::::from_floats([[1., 0., 3.], [0., 0., 900.]], &device); let tensor_2: Tensor = tensor_1.clone().greater_elem(0.0).float(); tensor_2 .to_data() .assert_eq(&TensorData::from([[1., 0., 1.], [0., 0., 1.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/cat.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, backend::Backend}; #[test] fn cat_should_match_reference_backend_dim0() { test_same_as_reference([6, 256], 2, 0); } #[test] fn cat_should_match_reference_backend_dim1() { test_same_as_reference([6, 256], 2, 1); } #[test] fn cat_should_support_uneven_launch() { test_same_as_reference([1, 137], 2, 0); } fn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) { let device = Default::default(); TestBackend::seed(&device, 0); let tensors = (0..num_tensors) .map(|_| { Tensor::::random(shape, Distribution::Default, &Default::default()) }) .collect::>(); let tensors_ref = tensors .iter() .map(|tensor| { Tensor::::from_data(tensor.to_data(), &Default::default()) }) .collect::>(); let tensor = Tensor::::cat(tensors, dim); let tensor_ref = Tensor::::cat(tensors_ref, dim); tensor .into_data() .assert_approx_eq::(&tensor_ref.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/clamp.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor}; #[test] fn clamp_should_match_reference() { let input = Tensor::::random( [1, 5, 32, 32], Distribution::Default, &Default::default(), ); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); let output = input.clamp(0.3, 0.7); output.into_data().assert_approx_eq::( &input_ref.clamp(0.3, 0.7).into_data(), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/contiguous.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Int, Tensor}; #[test] pub fn into_contiguous_match_reference_backend_1() { for shape in [ [4, 4, 4, 4], [32, 42, 24, 48], [8, 3, 7, 4], [1, 4, 1, 1], [1, 32, 256, 128], ] { let num_elems = shape.iter().product::() as i64; let tensor: Tensor = Tensor::::arange(0..num_elems, &Default::default()) .reshape(shape) .float(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for (i, j) in get_combinations(shape.len()) { let view = tensor.clone().swap_dims(i, j); let view_ref = tensor_ref.clone().swap_dims(i, j); let data = view.into_data(); let data_ref = view_ref.into_data(); data_ref.assert_approx_eq::(&data, Tolerance::default()); } } } fn get_combinations(n: usize) -> impl Iterator { // Iterate from 0 up to n (0..n).flat_map(move |i| { // For each i, iterate from i + 1 up to n // This ensures no repeats (i == j) and no duplicates (j, i) (i + 1..n).map(move |j| (i, j)) }) } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/conv2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::ops::{ConvOptions, ModuleOps}; use burn_tensor::{Distribution, Tensor, TensorPrimitive, module}; #[test] fn conv2d_should_match_reference_backend() { let test_device = Default::default(); let input = Tensor::::random([6, 16, 32, 32], Distribution::Default, &test_device); let weight = Tensor::::random([12, 8, 3, 3], Distribution::Default, &test_device); let bias = Tensor::::random([12], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let options = ConvOptions::new([2, 3], [2, 3], [2, 3], 2); let output = module::conv2d(input, weight, Some(bias), options.clone()); let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); output .into_data() .assert_approx_eq::(&output_ref.into_data(), Tolerance::default()); } #[test] fn conv2d_should_match_reference_backend_implicit() { let test_device = Default::default(); let input = Tensor::::random([4, 16, 6, 6], Distribution::Default, &test_device); let weight = Tensor::::random([16, 16, 3, 3], Distribution::Default, &test_device); let bias = Tensor::::random([16], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let options = ConvOptions::new([1, 1], [2, 2], [1, 1], 1); let output = module::conv2d(input, weight, Some(bias), options.clone()); let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); let tolerance = Tolerance::default(); output .into_data() .assert_approx_eq::(&output_ref.into_data(), tolerance); } /// Regression test for bias loader in new implicit GEMM #[test] fn conv2d_should_match_reference_backend_bias_regression() { let test_device = Default::default(); let input = Tensor::::random([1, 1, 1, 1], Distribution::Default, &test_device); let weight = Tensor::::random([32, 1, 3, 3], Distribution::Default, &test_device); let bias = Tensor::::random([32], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1); let output = module::conv2d(input, weight, Some(bias), options.clone()).permute([0, 2, 3, 1]); let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options).permute([0, 2, 3, 1]); let tolerance = Tolerance::default(); output .into_data() .assert_approx_eq::(&output_ref.into_data(), tolerance); } #[test] fn conv2d_weight_backward_should_run() { // https://github.com/tracel-ai/burn/issues/4226#issuecomment-3911335769 let device = Default::default(); let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1); let x = Tensor::::random([1, 1, 1, 672], Distribution::Default, &device); // let x = x.permute([0, 3, 1, 2]); let output_grad = Tensor::::random([1, 168, 1, 1], Distribution::Default, &device); let weight = Tensor::::random([168, 672, 1, 1], Distribution::Default, &device); let ref_device = Default::default(); let x_ref = Tensor::::from_data(x.to_data(), &ref_device); let output_grad_ref = Tensor::::from_data(output_grad.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); // Input shape [672, 1] and strides [672, 672] should be valid let output = TestBackend::conv2d_weight_backward( x.permute([0, 3, 1, 2]).into_primitive().tensor(), weight.into_primitive().tensor(), output_grad.into_primitive().tensor(), options.clone(), ); // Input shape [672, 1] and strides [672, 672] should be valid let output_ref = ReferenceBackend::conv2d_weight_backward( x_ref.permute([0, 3, 1, 2]).into_primitive().tensor(), weight_ref.into_primitive().tensor(), output_grad_ref.into_primitive().tensor(), options, ); let tolerance = Tolerance::default(); Tensor::::from_primitive(TensorPrimitive::Float(output)) .into_data() .assert_approx_eq::( &Tensor::::from_primitive(TensorPrimitive::Float(output_ref)) .into_data(), tolerance, ); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/conv3d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, module}; #[test] fn conv3d_should_match_reference_backend() { let test_device = Default::default(); let input = Tensor::::random([6, 16, 32, 32, 32], Distribution::Default, &test_device); let weight = Tensor::::random([12, 8, 3, 3, 3], Distribution::Default, &test_device); let bias = Tensor::::random([12], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let options = burn_tensor::ops::ConvOptions::new([2, 3, 4], [2, 3, 4], [2, 3, 4], 2); let output = module::conv3d(input, weight, Some(bias), options.clone()); let output_ref = module::conv3d(input_ref, weight_ref, Some(bias_ref), options); output .into_data() .assert_approx_eq::(&output_ref.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/conv_transpose2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, backend::Backend, module}; #[test] fn conv_transpose2d_should_match_reference_backend() { let device = Default::default(); TestBackend::seed(&device, 0); let height = 8; let width = 8; let in_channels = 8; let out_channels = 8; let batch_size = 32; let kernel_size_0 = 3; let kernel_size_1 = 3; let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1); let test_device = Default::default(); let input = Tensor::::random( [batch_size, in_channels, height, width], Distribution::Default, &test_device, ); let weight = Tensor::::random( [ in_channels, out_channels / options.groups, kernel_size_0, kernel_size_1, ], Distribution::Default, &test_device, ); let bias = Tensor::::random([out_channels], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let output = module::conv_transpose2d(input, weight, Some(bias), options.clone()); let output_ref = module::conv_transpose2d(input_ref, weight_ref, Some(bias_ref), options); output .into_data() .assert_approx_eq::(&output_ref.into_data(), Tolerance::rel_abs(0.01, 0.02)); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/conv_transpose3d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, backend::Backend, module}; #[test] fn conv_transpose3d_should_match_reference_backend() { let test_device = Default::default(); TestBackend::seed(&test_device, 0); let depth = 8; let height = 8; let width = 8; let in_channels = 8; let out_channels = 8; let batch_size = 32; let kernel_size_0 = 3; let kernel_size_1 = 3; let kernel_size_2 = 3; let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1], 1); let input = Tensor::::random( [batch_size, in_channels, depth, height, width], Distribution::Default, &test_device, ); let weight = Tensor::::random( [ in_channels, out_channels / options.groups, kernel_size_0, kernel_size_1, kernel_size_2, ], Distribution::Default, &test_device, ); let bias = Tensor::::random([out_channels], Distribution::Default, &test_device); let ref_device = Default::default(); let input_ref = Tensor::::from_data(input.to_data(), &ref_device); let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); let output = module::conv_transpose3d(input, weight, Some(bias), options.clone()); let output_ref = module::conv_transpose3d(input_ref, weight_ref, Some(bias_ref), options); output .into_data() .assert_approx_eq::(&output_ref.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/cross.rs ================================================ use super::*; use burn_tensor::Tensor; use burn_tensor::Tolerance; #[test] fn test_cross_product() { let device = Default::default(); // Test with well-known orthogonal vectors for clearer validation let a = Tensor::::from_data([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], &device); let b = Tensor::::from_data([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], &device); let result = a.cross(b, 1); // For orthogonal unit vectors: // i × j = k // j × k = i let expected = Tensor::::from_data([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], &device); // Use Tolerance for floating-point comparisons let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } #[test] fn test_cross_product_zeros() { let device = Default::default(); // Test cross product with zero vector - should always give zero vector let a = Tensor::::from_data([[2.0, 3.0, 4.0]], &device); let b = Tensor::::zeros([1, 3], &device); let result = a.cross(b, 1); let expected = Tensor::::zeros([1, 3], &device); // For zeros, we can use exact equality or a very tight tolerance let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } #[test] fn test_cross_product_batch() { let device = Default::default(); // Test typical cross product computations in batch let a = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let b = Tensor::::from_data([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device); let result = a.cross(b, 1); // Cross products: // [1,2,3] × [4,5,6] = [-3,6,-3] // [4,5,6] × [7,8,9] = [-3,6,-3] let expected = Tensor::::from_data([[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]], &device); let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } #[test] #[should_panic] fn test_cross_product_invalid_dimension() { let device = Default::default(); let a = Tensor::::zeros([1, 4], &device); let b = Tensor::::zeros([1, 4], &device); let _ = a.cross(b, 1); } #[test] fn test_cross_product_parallel_vectors() { let device = Default::default(); // Test cross product of parallel vectors (should be zero) let a = Tensor::::from_data([[1.0, 2.0, 3.0]], &device); let b = Tensor::::from_data([[2.0, 4.0, 6.0]], &device); // b = 2 * a let result = a.cross(b, 1); let expected = Tensor::::zeros([1, 3], &device); let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } #[test] fn test_cross_product_3d_tensor() { let device = Default::default(); // Test with 3D tensor (batch of matrices) let a = Tensor::::from_data( [ [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], ], &device, ); let b = Tensor::::from_data( [ [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], ], &device, ); let result = a.cross(b, 2); // Cross on last dimension let expected = Tensor::::from_data( [ [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], [[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]], ], &device, ); let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } // Test to verify that padding doesn't affect results #[test] fn test_cross_product_with_padding_awareness() { let device = Default::default(); // Create tensors that would span multiple 4-element blocks // This tests that the padding doesn't corrupt adjacent data let a = Tensor::::from_data( [ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], // Two vectors: [1,2,3] and [4,5,6] ], &device, ); let b = Tensor::::from_data( [ [7.0, 8.0, 9.0, 10.0, 11.0, 12.0], // Two vectors: [7,8,9] and [10,11,12] ], &device, ); // Reshape to have proper 3-element vectors in last dimension let a_reshaped = a.reshape([2, 3]); let b_reshaped = b.reshape([2, 3]); let result = a_reshaped.cross(b_reshaped, 1); // Expected cross products: // [1,2,3] × [7,8,9] = [-6,12,-6] // [4,5,6] × [10,11,12] = [-6,12,-6] let expected = Tensor::::from_data([[-6.0, 12.0, -6.0], [-6.0, 12.0, -6.0]], &device); let tolerance = Tolerance::::default(); result .to_data() .assert_approx_eq(&expected.to_data(), tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/gather.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Int, Shape, Tensor, backend::Backend}; #[test] fn gather_should_work_with_multiple_workgroups_dim0() { test_same_as_ref([6, 256], 0); } #[test] fn gather_should_work_with_multiple_workgroups_dim1() { test_same_as_ref([6, 256], 1); } fn test_same_as_ref(shape: [usize; D], dim: usize) { let device = Default::default(); TestBackend::seed(&device, 0); let max = shape[dim]; let shape = Shape::new(shape); let tensor = Tensor::::random(shape.clone(), Distribution::Default, &Default::default()); let indices = Tensor::::from_data( Tensor::::random( [shape.num_elements()], Distribution::Uniform(0., max as f64), &Default::default(), ) .into_data(), &Default::default(), ) .reshape(shape); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let indices_ref = Tensor::::from_data(indices.to_data(), &Default::default()); let actual = tensor.gather(dim, indices); let expected = tensor_ref.gather(dim, indices_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/mask_fill.rs ================================================ use super::*; use burn_cubecl::kernel::{MaskFillStrategy, mask_fill}; use burn_tensor::Tolerance; use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend}; use cubecl::prelude::InputScalar; #[test] fn mask_fill_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); let dtype_bool = <::BoolElem as Element>::dtype(); let dtype_ft = ::dtype(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill( tensor.into_primitive().tensor(), mask.into_primitive(), InputScalar::new(4.0, dtype_ft), MaskFillStrategy::Readonly, dtype_bool, ))); let expected = tensor_ref.mask_fill(mask_ref, 4.0); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn mask_fill_inplace_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); let dtype_bool = <::BoolElem as Element>::dtype(); let dtype_ft = ::dtype(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill::<_>( tensor.into_primitive().tensor(), mask.into_primitive(), InputScalar::new(4.0, dtype_ft), MaskFillStrategy::Inplace, dtype_bool, ))); let expected = tensor_ref.mask_fill(mask_ref, 4.0); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[allow(clippy::type_complexity)] fn inputs_mask_fill() -> ( Tensor, Tensor, Tensor, Tensor, ) { let test_device = Default::default(); let tensor = Tensor::::random([2, 6, 256], Distribution::Default, &test_device); let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.), &test_device) .lower_equal_elem(0.5); let ref_device = Default::default(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &ref_device); let mask_ref = Tensor::::from_data(mask.to_data(), &ref_device); (tensor, mask, tensor_ref, mask_ref) } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/mask_where.rs ================================================ use super::*; use burn_cubecl::kernel::{MaskWhereStrategy, mask_where}; use burn_tensor::Tolerance; use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend}; #[test] fn mask_where_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let actual = tensor.mask_where(mask, value); let expected = tensor_ref.mask_where(mask_ref, value_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn mask_where_inplace_lhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let dtype_bool = <::BoolElem as Element>::dtype(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where::<_>( tensor.into_primitive().tensor(), mask.into_primitive(), value.into_primitive().tensor(), MaskWhereStrategy::InplaceLhs, dtype_bool, ))); let expected = tensor_ref.mask_where(mask_ref, value_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn mask_where_inplace_rhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let dtype_bool = <::BoolElem as Element>::dtype(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where::<_>( tensor.into_primitive().tensor(), mask.into_primitive(), value.into_primitive().tensor(), MaskWhereStrategy::InplaceRhs, dtype_bool, ))); let expected = tensor_ref.mask_where(mask_ref, value_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[allow(clippy::type_complexity)] fn inputs_mask_where() -> ( Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, ) { let device = Default::default(); TestBackend::seed(&device, 0); let tensor = Tensor::::random([2, 6, 256], Distribution::Default, &device); let value = Tensor::::random([2, 6, 256], Distribution::Default, &device); let mask = Tensor::::random([2, 6, 256], Distribution::Uniform(0., 1.), &device) .lower_equal_elem(0.5); let device_ref = Default::default(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &device_ref); let value_ref = Tensor::::from_data(value.to_data(), &device_ref); let mask_ref = Tensor::::from_data(mask.to_data(), &device_ref); mask.to_data().assert_eq(&mask_ref.to_data(), false); (tensor, value, mask, tensor_ref, value_ref, mask_ref) } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/max_pool2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, module}; #[test] pub fn max_pool2d_should_match_reference_backends() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, &Default::default(), ); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let kernel_size = [3, 3]; let stride = [2, 2]; let padding = [1, 1]; let dilation = [1, 1]; let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation, false); let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation, false); pooled .into_data() .assert_approx_eq::(&pooled_ref.into_data(), Tolerance::default()); } #[test] pub fn max_pool2d_with_indices_should_match_reference_backend() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, &Default::default(), ); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let kernel_size = [3, 3]; let stride = [2, 2]; let padding = [1, 1]; let dilation = [1, 1]; let (pooled, indices) = module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation, false); let (pooled_ref, indices_ref) = module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation, false); pooled .into_data() .assert_approx_eq::(&pooled_ref.into_data(), Tolerance::default()); indices .into_data() .assert_eq(&indices_ref.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/max_pool2d_backward.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor, TensorPrimitive, module, ops::ModuleOps}; #[test] pub fn max_pool2d_with_indices_backward_should_match_reference_backend() { let test_device = Default::default(); let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default, &test_device); let grad_output = Tensor::::random([32, 32, 16, 16], Distribution::Default, &test_device); let ref_device = Default::default(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &ref_device); let grad_output_ref = Tensor::::from_data(grad_output.to_data(), &ref_device); let kernel_size = [3, 3]; let stride = [2, 2]; let padding = [1, 1]; let dilation = [1, 1]; let (_, indices) = module::max_pool2d_with_indices( tensor.clone(), kernel_size, stride, padding, dilation, false, ); let (_, indices_ref) = module::max_pool2d_with_indices( tensor_ref.clone(), kernel_size, stride, padding, dilation, false, ); let grad = TestBackend::max_pool2d_with_indices_backward( tensor.into_primitive().tensor(), kernel_size, stride, padding, dilation, false, grad_output.into_primitive().tensor(), indices.into_primitive(), ) .x_grad; let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward( tensor_ref.into_primitive().tensor(), kernel_size, stride, padding, dilation, false, grad_output_ref.into_primitive().tensor(), indices_ref.into_primitive(), ) .x_grad; Tensor::::from_primitive(TensorPrimitive::Float(grad)) .into_data() .assert_approx_eq::( &Tensor::::from_primitive(TensorPrimitive::Float(grad_ref)) .into_data(), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/mod.rs ================================================ // #[allow(unused_imports)] // required for re-included modules pub use super::*; mod avg_pool2d; mod bernoulli; mod cast; mod cat; mod clamp; mod contiguous; mod conv2d; mod conv3d; mod conv_transpose2d; mod conv_transpose3d; mod cross; mod gather; mod mask_fill; mod mask_where; mod max_pool2d; mod max_pool2d_backward; mod normal; mod quantization; mod reduce; mod repeat_dim; mod scatter; mod select; mod select_assign; mod slice; mod slice_assign; mod unary; mod uniform; ================================================ FILE: crates/burn-backend-tests/tests/cubecl/normal.rs ================================================ use super::*; use burn_tensor::{Distribution, Shape, Tensor, backend::Backend}; use cubek::random::{assert_mean_approx_equal, assert_normal_respects_68_95_99_rule}; use serial_test::serial; #[test] #[serial] fn empirical_mean_close_to_expectation() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = [100, 100]; let mean = 10.; let tensor = Tensor::::random(shape, Distribution::Normal(mean, 2.), &device) .into_data(); let numbers = tensor.as_slice::().unwrap(); assert_mean_approx_equal(numbers, mean as f32); } #[test] #[serial] fn normal_respects_68_95_99_rule() { // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule let shape: Shape = [1000, 1000].into(); let device = Default::default(); let mu = 0.; let s = 1.; let tensor = Tensor::::random(shape.clone(), Distribution::Normal(mu, s), &device) .into_data(); let numbers = tensor.as_slice::().unwrap(); assert_normal_respects_68_95_99_rule(numbers, mu as f32, s as f32); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/quantization.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ Shape, Tensor, backend::Backend, quantization::{QuantLevel, QuantScheme, QuantStore, QuantValue}, }; fn should_quantize_dequantize_symmetric_arange>( value: QuantValue, store: QuantStore, shape: S, ) { let shape = shape.into(); assert_eq!(shape.rank(), 2); // 2D tests let scheme = QuantScheme::default().with_value(value).with_store(store); let scheme_ref = scheme.clone().with_store(QuantStore::Native); let input: Tensor = Tensor::arange(0..shape.num_elements() as i64, &Default::default()) .float() .reshape(shape); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); let output = input.quantize_dynamic(&scheme); let output_ref = input_ref.quantize_dynamic(&scheme_ref); output.to_data().assert_eq(&output_ref.to_data(), false); let output = output.dequantize(); let output_ref = output_ref.dequantize(); output .into_data() .assert_approx_eq::(&output_ref.to_data(), Tolerance::default()); } fn should_quantize_dequantize_symmetric_per_block_arange>( value: QuantValue, block_size: usize, store: QuantStore, shape: S, ) { let scheme = QuantScheme::default() .with_value(value) .with_level(QuantLevel::block([block_size as u8])) .with_store(store); let scheme_ref = scheme.clone().with_store(QuantStore::Native); let shape = shape.into(); let input: Tensor = Tensor::arange(0..shape.num_elements() as i64, &Default::default()) .float() .reshape(shape); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); let output = input.quantize_dynamic(&scheme); let output_ref = input_ref.quantize_dynamic(&scheme_ref); output.to_data().assert_eq(&output_ref.to_data(), false); let output = output.dequantize(); let output_ref = output_ref.dequantize(); output .into_data() .assert_approx_eq::(&output_ref.to_data(), Tolerance::default()); } fn should_quantize_dequantize_symmetric_per_block( value: QuantValue, block_size: usize, store: QuantStore, ) { let scheme = QuantScheme::default() .with_value(value) .with_level(QuantLevel::block([block_size as u8])) .with_store(store); let scheme_ref = scheme.clone().with_store(QuantStore::Native); let input = Tensor::::from_floats( [ [ -1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5, 0.01, 0.025, 0.03, 0.04, 0.01, 0.025, 0.03, 0.04, ], [ 1.8, 1.0, 0.0, -0.5, 1.8, 1.0, 0.0, -0.5, -0.01, -0.025, -0.03, -0.04, -0.01, -0.025, -0.03, -0.04, ], ], &Default::default(), ); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); let output = input.quantize_dynamic(&scheme); let output_ref = input_ref.quantize_dynamic(&scheme_ref); output.to_data().assert_eq(&output_ref.to_data(), false); let output = output.dequantize(); let output_ref = output_ref.dequantize(); output .into_data() .assert_approx_eq::(&output_ref.to_data(), Tolerance::default()); } fn supports_native() -> bool { let name = ::name(&Default::default()); // TODO: Proper checks for i8 support. name.contains("cuda") || name.contains("rocm") || name.contains("hip") || name.contains("vulkan") || name.contains("spirv") || name.contains("metal") || name.contains("msl") } #[test] fn should_quantize_dequantize_symmetric_arange_q8s_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_arange_q8f_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q8F, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_arange_q4s_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q4S, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_arange_q4f_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q4F, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_arange_q2s_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q2S, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_arange_q2f_packed() { should_quantize_dequantize_symmetric_arange(QuantValue::Q2F, QuantStore::PackedU32(0), [8, 16]) } #[test] fn should_quantize_dequantize_symmetric_per_block_q8s_packed() { should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::PackedU32(0)) } #[test] fn should_quantize_dequantize_symmetric_per_block_q4s_packed() { should_quantize_dequantize_symmetric_per_block(QuantValue::Q4S, 8, QuantStore::PackedU32(0)) } #[test] #[should_panic = "Block size must be divisible by 16"] fn should_panic_when_block_size_cannot_store_num_quants() { // num_quants in u32 = 32 bits / 2 bits = 16 should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 8, QuantStore::PackedU32(0)) } #[test] fn should_quantize_dequantize_symmetric_per_block_q2s_packed() { should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 16, QuantStore::PackedU32(0)) } #[test] fn should_quantize_dequantize_symmetric_arange_q8s_native() { if supports_native() { should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::Native, [32, 32]) } } #[test] fn should_quantize_dequantize_symmetric_per_block_q8s_native() { if supports_native() { should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::Native) } } #[test] fn should_quantize_dequantize_symmetric_per_block_arange_q8s_packed() { should_quantize_dequantize_symmetric_per_block_arange( QuantValue::Q8S, 32, QuantStore::PackedU32(0), [32, 32], ) } #[test] fn should_quantize_dequantize_symmetric_per_block_arange_q8s_native() { if supports_native() { should_quantize_dequantize_symmetric_per_block_arange( QuantValue::Q8S, 32, QuantStore::Native, [32, 32], ) } } #[test] fn should_quantize_dequantize_symmetric_arange_128x256_q8s_native() { if supports_native() { should_quantize_dequantize_symmetric_per_block_arange( QuantValue::Q8S, 32, QuantStore::Native, [128, 256], ) } } #[test] fn should_quantize_dequantize_symmetric_arange_128x256_q8s_packed() { should_quantize_dequantize_symmetric_per_block_arange( QuantValue::Q8S, 32, QuantStore::PackedU32(0), [128, 256], ) } #[test] #[should_panic = "Can't store in u32"] fn should_panic_when_shape_cannot_store_quants() { let device = Default::default(); let scheme = QuantScheme::default(); let _tensor_1 = Tensor::::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device) .quantize_dynamic(&scheme); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/reduce.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor}; const RANK: usize = 4; const SHAPE: [usize; RANK] = [2, 4, 8, 16]; #[test] fn reduction_argmax_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for dim in 0..RANK { tensor .clone() .argmax(dim) .into_data() .assert_eq(&tensor_ref.clone().argmax(dim).into_data(), false); } } #[test] fn reduction_argmin_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for dim in 0..RANK { tensor .clone() .argmin(dim) .into_data() .assert_eq(&tensor_ref.clone().argmin(dim).into_data(), false); } } #[test] fn reduction_mean_dim_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for dim in 0..RANK { tensor .clone() .mean_dim(dim) .into_data() .assert_approx_eq::( &tensor_ref.clone().mean_dim(dim).into_data(), Tolerance::default(), ); } } #[test] fn reduction_mean_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); tensor .clone() .mean() .into_data() .assert_approx_eq::( &tensor_ref.clone().mean().into_data(), Tolerance::default(), ); } #[test] fn reduction_prod_dim_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for dim in 0..RANK { tensor .clone() .prod_dim(dim) .into_data() .assert_approx_eq::( &tensor_ref.clone().prod_dim(dim).into_data(), Tolerance::default(), ); } } #[test] fn reduction_prod_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); tensor .clone() .prod() .into_data() .assert_approx_eq::( &tensor_ref.clone().prod().into_data(), Tolerance::default(), ); } #[test] fn reduction_sum_dim_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); for dim in 0..RANK { tensor .clone() .sum_dim(dim) .into_data() .assert_approx_eq::( &tensor_ref.clone().sum_dim(dim).into_data(), Tolerance::default(), ); } } #[test] fn reduction_sum_should_match_reference_backend() { let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); tensor .clone() .sum() .into_data() .assert_approx_eq::(&tensor_ref.clone().sum().into_data(), Tolerance::default()); } #[test] #[ignore = "Impossible to run unless you have tons of VRAM. Also reference backend is broken."] fn reduction_sum_should_match_reference_backend_64bit() { const SHAPE: [usize; RANK] = [33, 512, 512, 512]; let tensor = Tensor::::random(SHAPE, Distribution::Default, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let data = tensor.sum().into_data(); let data_ref = tensor_ref.sum().into_data(); println!("result: {:?}", data.as_slice::()); data.assert_approx_eq::(&data_ref, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/repeat_dim.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor}; #[test] fn repeat_dim_0_few_times() { let tensor = Tensor::::random([1, 6, 6], Distribution::Default, &Default::default()); let dim = 0; let times = 4; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let actual = tensor.repeat_dim(dim, times); let expected = tensor_ref.repeat_dim(dim, times); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn repeat_dim_1_few_times() { let tensor = Tensor::::random([6, 1, 6], Distribution::Default, &Default::default()); let dim = 1; let times = 4; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let actual = tensor.repeat_dim(dim, times); let expected = tensor_ref.repeat_dim(dim, times); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn repeat_dim_2_few_times() { let tensor = Tensor::::random([6, 6, 1], Distribution::Default, &Default::default()); let dim = 2; let times = 4; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let actual = tensor.repeat_dim(dim, times); let expected = tensor_ref.repeat_dim(dim, times); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } #[test] fn repeat_dim_2_many_times() { let tensor = Tensor::::random([10, 10, 1], Distribution::Default, &Default::default()); let dim = 2; let times = 200; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let actual = tensor.repeat_dim(dim, times); let expected = tensor_ref.repeat_dim(dim, times); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/scatter.rs ================================================ use super::*; use burn_tensor::{Distribution, Int, Tensor, backend::Backend}; use burn_tensor::{IndexingUpdateOp, Tolerance}; #[test] fn scatter_should_work_with_multiple_workgroups_2d_dim0() { same_as_reference_same_shape(0, [256, 32]); } #[test] fn scatter_should_work_with_multiple_workgroups_2d_dim1() { same_as_reference_same_shape(1, [32, 256]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim0() { same_as_reference_same_shape(0, [256, 6, 6]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim1() { same_as_reference_same_shape(1, [6, 256, 6]); } #[test] fn scatter_should_work_with_multiple_workgroups_3d_dim2() { same_as_reference_same_shape(2, [6, 6, 256]); } #[test] fn scatter_should_work_with_multiple_workgroups_diff_shapes() { same_as_reference_diff_shape(1, [32, 128], [32, 1]); } fn same_as_reference_diff_shape( dim: usize, shape1: [usize; D], shape2: [usize; D], ) { let test_device = Default::default(); TestBackend::seed(&test_device, 0); let tensor = Tensor::::random(shape1, Distribution::Default, &test_device); let value = Tensor::::random(shape2, Distribution::Default, &test_device); let indices = Tensor::::random( [shape2.iter().product::()], Distribution::Uniform(0., shape2[dim] as f64), &test_device, ) .reshape(shape2); let ref_device = Default::default(); let tensor_ref = Tensor::::from_data(tensor.to_data(), &ref_device); let value_ref = Tensor::::from_data(value.to_data(), &ref_device); let indices_ref = Tensor::::from_data(indices.to_data(), &ref_device); let actual = tensor.scatter(dim, indices, value, IndexingUpdateOp::Add); let expected = tensor_ref.scatter(dim, indices_ref, value_ref, IndexingUpdateOp::Add); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } fn same_as_reference_same_shape(dim: usize, shape: [usize; D]) { same_as_reference_diff_shape(dim, shape, shape); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/select.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Int, Tensor}; #[test] fn select_should_work_with_multiple_workgroups() { let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default()); let indices = Tensor::::arange(0..100, &Default::default()); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let indices_ref = Tensor::::from_data(indices.to_data(), &Default::default()); let actual = tensor.select(1, indices); let expected = tensor_ref.select(1, indices_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/select_assign.rs ================================================ use super::*; use burn_tensor::{Distribution, Int, Tensor, backend::Backend}; use burn_tensor::{IndexingUpdateOp, Tolerance}; #[test] fn select_add_should_work_with_multiple_workgroups_2d_dim0() { select_add_same_as_ref(0, [256, 6]); } #[test] fn select_add_should_work_with_multiple_workgroups_2d_dim1() { select_add_same_as_ref(1, [6, 256]); } #[test] fn select_add_should_work_with_multiple_workgroups_3d_dim0() { select_add_same_as_ref(0, [256, 6, 6]); } #[test] fn select_add_should_work_with_multiple_workgroups_3d_dim1() { select_add_same_as_ref(1, [6, 256, 6]); } #[test] fn select_add_should_work_with_multiple_workgroups_3d_dim2() { select_add_same_as_ref(2, [6, 6, 256]); } fn select_add_same_as_ref(dim: usize, shape: [usize; D]) { let device = Default::default(); TestBackend::seed(&device, 0); let tensor = Tensor::::random(shape, Distribution::Default, &Default::default()); let value = Tensor::::random(shape, Distribution::Default, &Default::default()); let indices = Tensor::::random( [shape[dim]], Distribution::Uniform(0., shape[dim] as f64), &Default::default(), ); let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let value_ref = Tensor::::from_data(value.to_data(), &Default::default()); let indices_ref = Tensor::::from_data(indices.to_data(), &Default::default()); let actual = tensor.select_assign(dim, indices, value, IndexingUpdateOp::Add); let expected = tensor_ref.select_assign(dim, indices_ref, value_ref, IndexingUpdateOp::Add); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/slice.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, Tensor}; #[test] fn slice_should_work_with_multiple_workgroups() { let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default()); let indices = [3..5, 45..256]; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let actual = tensor.slice(indices.clone()); let expected = tensor_ref.slice(indices); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/slice_assign.rs ================================================ use super::*; use burn_tensor::{Distribution, Tensor, Tolerance}; #[test] fn slice_assign_should_work_with_multiple_workgroups() { let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default()); let value = Tensor::::random([2, 211], Distribution::Default, &Default::default()); let indices = [3..5, 45..256]; let tensor_ref = Tensor::::from_data(tensor.to_data(), &Default::default()); let value_ref = Tensor::::from_data(value.to_data(), &Default::default()); let actual = tensor.slice_assign(indices.clone(), value); let expected = tensor_ref.slice_assign(indices, value_ref); expected .into_data() .assert_approx_eq::(&actual.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/unary.rs ================================================ use super::*; use burn_tensor::Tensor; #[test] fn tanh_should_not_have_numerical_bugs_on_macos() { fn tanh_one_value(input: f32) -> f32 { let tensor = Tensor::::ones([1], &Default::default()) * input; let output = tensor.tanh().into_primitive(); Tensor::::from_primitive(output) .into_data() .as_slice() .unwrap()[0] } let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36 let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36 let neg = tanh_one_value(-45.0); // metal works correctly here assert!(!ok.is_nan() && ok == 1.0); assert!(!zero.is_nan() && zero == 1.0); assert!(!nan.is_nan() && nan == 1.0); assert!(!neg.is_nan() && neg == -1.0); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl/uniform.rs ================================================ use super::*; use burn_tensor::{Distribution, Int, Shape, Tensor, backend::Backend}; use burn_tensor::{ElementConversion, Tolerance}; use serial_test::serial; use cubek::random::{assert_at_least_one_value_per_bin, assert_wald_wolfowitz_runs_test}; #[test] #[serial] fn values_all_within_interval_default() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = [24, 24]; let tensor = Tensor::::random(shape, Distribution::Default, &device); tensor .to_data() .assert_within_range::(0.elem()..1.elem()); } #[test] #[serial] fn values_all_within_interval_uniform() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = [24, 24]; let tensor = Tensor::::random(shape, Distribution::Uniform(5., 17.), &device); tensor .to_data() .assert_within_range::(5.elem()..17.elem()); } #[test] #[serial] fn at_least_one_value_per_bin_uniform() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = [64, 64]; let tensor = Tensor::::random(shape, Distribution::Uniform(-5., 10.), &device) .into_data(); let numbers = tensor.as_slice::().unwrap(); assert_at_least_one_value_per_bin(numbers, 3, -5., 10.); } #[test] #[serial] fn runs_test() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = Shape::new([512, 512]); let tensor = Tensor::::random(shape, Distribution::Default, &device).into_data(); let numbers = tensor.as_slice::().unwrap(); assert_wald_wolfowitz_runs_test(numbers, 0., 1.); } #[test] #[serial] fn int_values_all_within_interval_uniform() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = Shape::new([20, 20]); let tensor: Tensor = Tensor::random(shape, Distribution::Default, &device); let data_float = tensor.float().into_data(); data_float.assert_within_range(0..255); } #[test] #[serial] fn at_least_one_value_per_bin_int_uniform() { let device = Default::default(); TestBackend::seed(&device, 0); let shape = Shape::new([64, 64]); let tensor: Tensor = Tensor::random(shape, Distribution::Uniform(-10.0, 10.0), &device); let data_float = tensor.float().into_data(); let numbers = data_float.as_slice::().unwrap(); assert_at_least_one_value_per_bin(numbers, 10, -10., 10.); } #[test] fn should_not_fail_on_non_float_autotune() { let device = Default::default(); let tensor_1 = Tensor::::from_floats([[1., 2., 3.], [3., 4., 5.]], &device); // Autotune of all (reduce) on lower_equal_elem's output calls uniform distribution tensor_1.lower_equal_elem(1.0).all(); } #[test] #[serial] fn test_seed_reproducibility() { let device = Default::default(); TestBackend::seed(&device, 42); let t1 = TestTensor::<1>::random([5], Distribution::Default, &device); TestBackend::seed(&device, 42); let t2 = TestTensor::<1>::random([5], Distribution::Default, &device); t1.into_data() .assert_approx_eq::(&t2.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/cubecl.rs ================================================ //! CubeCL kernel tests. #[cfg(feature = "cube")] #[path = "."] mod cube { type FloatElemType = f32; type IntElemType = i32; mod backend { include!("common/backend.rs"); pub type ReferenceBackend = burn_ndarray::NdArray; } pub use backend::*; #[path = "cubecl/mod.rs"] mod kernel; } ================================================ FILE: crates/burn-backend-tests/tests/fused_ops/mod.rs ================================================ mod reduce_broadcasted; ================================================ FILE: crates/burn-backend-tests/tests/fused_ops/reduce_broadcasted.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance, backend::Backend}; #[test] fn test_reduce_broadcasted_1() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_read = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_write = TestTensorInt::<1>::arange(0..4, &device) .reshape([4, 1]) .float(); // Forces previous tensors to be materialized. TestBackend::sync(&device).unwrap(); let x = tensor + fused_on_read.clone(); let x = x.sum_dim(1); let x = x + fused_on_write; // Broadcast let end = x + fused_on_read; let actual = end.into_data(); let expected = TensorData::from([ [56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0], [193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0], [330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 337.0], [467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0], ]); actual.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_reduce_broadcasted_2() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_read = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_write = TestTensorInt::<1>::arange(16..48, &device) .reshape([4, 8]) .float(); // Second fuse on read let y = TestTensorInt::<1>::arange(32..64, &device) .reshape([4, 8]) .float(); // Forces previous tensors to be materialized. TestBackend::sync(&device).unwrap(); let x = tensor + fused_on_read.clone(); let x = x.sum_dim(1); let x = x + fused_on_write; let x = x.mean_dim(1); let end = x + y; TestBackend::sync(&device).unwrap(); let actual = end.into_data(); let expected = TensorData::from([ [107.5, 108.5, 109.5, 110.5, 111.5, 112.5, 113.5, 114.5], [251.5, 252.5, 253.5, 254.5, 255.5, 256.5, 257.5, 258.5], [395.5, 396.5, 397.5, 398.5, 399.5, 400.5, 401.5, 402.5], [539.5, 540.5, 541.5, 542.5, 543.5, 544.5, 545.5, 546.5], ]); actual.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_reduce_broadcasted_3() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_read = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_write = TestTensorInt::<1>::arange(0..4, &device) .reshape([4, 1]) .float(); // Second fuse on read let y = TestTensorInt::<1>::arange(32..64, &device) .reshape([4, 8]) .float(); // Forces previous tensors to be materialized. TestBackend::sync(&device).unwrap(); let x = tensor + fused_on_read.clone(); let x = x.sum_dim(1); let x = x + fused_on_write; // Broadcast let x = x + fused_on_read; // Second reduce let x = x.mean_dim(1); let end = x + y; let actual = end.into_data(); let expected = TensorData::from([ [91.5, 92.5, 93.5, 94.5, 95.5, 96.5, 97.5, 98.5], [236.5, 237.5, 238.5, 239.5, 240.5, 241.5, 242.5, 243.5], [381.5, 382.5, 383.5, 384.5, 385.5, 386.5, 387.5, 388.5], [526.5, 527.5, 528.5, 529.5, 530.5, 531.5, 532.5, 533.5], ]); actual.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_reduce_broadcasted_4_reused_partial() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_read = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let fused_on_write = TestTensorInt::<1>::arange(0..4, &device) .reshape([4, 1]) .float(); let y = TestTensorInt::<1>::arange(32..64, &device) .reshape([4, 8]) .float(); // Forces previous tensors to be materialized. TestBackend::sync(&device).unwrap(); // In fusion we have to create a global buffer to keep the intermediate data for now. let x_previous = tensor + fused_on_read; let x = x_previous.clone().sum_dim(1); let x = x * fused_on_write; // Broadcast let x = x + x_previous; // Second reduce let x = x.mean_dim(1); // Second fuse on read let end = x + y; let actual = end.into_data(); let expected = TensorData::from([ [39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0], [247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0], [711.0, 712.0, 713.0, 714.0, 715.0, 716.0, 717.0, 718.0], [ 1431.0, 1432.0, 1433.0, 1434.0, 1435.0, 1436.0, 1437.0, 1438.0, ], ]); actual.assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/fusion.rs ================================================ //! Burn tensor and autodiff tests for CubeCL backends with fusion enabled. #![allow( clippy::single_range_in_vec_init, clippy::duplicate_mod, reason = "false positive" )] extern crate alloc; #[cfg(feature = "cube")] #[path = "."] mod fusion { pub type FloatElemType = f32; pub type IntElemType = i32; #[path = "common/backend.rs"] mod backend; pub use backend::prelude::*; // NOTE: // We re-include the tensor and autodiff test suites after overriding `TestBackend` // with `Fusion`. This intentionally duplicates module names and test // logic to execute the same tests under fusion. pub type TestBackend = burn_fusion::Fusion; pub type TestTensor = Tensor; pub type TestTensorInt = Tensor; pub type TestTensorBool = Tensor; // Tensor tests mod tensor { include!("common/tensor.rs"); } // Autodiff tests mod autodiff { include!("common/autodiff.rs"); } // Fusion tests include!("fused_ops/mod.rs"); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/mod.rs ================================================ pub use super::*; // re-export test types mod ops; ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/all.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_all() { let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorBool::<2>::from([[true, true, true], [true, true, true]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_all_dim() { let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]); let data_actual = tensor.all_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_all_with_bool_from_lower_equal() { let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6; let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6; let ge = tensor1.lower_equal(tensor2); let all = ge.clone().all(); TensorData::from([true]).assert_eq(&all.clone().into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/any.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_any() { let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_any_dim() { let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/argwhere_nonzero.rs ================================================ use super::*; use alloc::vec::Vec; use burn_tensor::{Shape, TensorData}; #[test] fn test_argwhere_1d() { let tensor = TestTensorBool::<1>::from([false, true, false, true, true]); let output = tensor.argwhere(); output .into_data() .assert_eq(&TensorData::from([[1], [3], [4]]), false); } #[test] fn test_argwhere_2d() { let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]); let output = tensor.argwhere(); output .into_data() .assert_eq(&TensorData::from([[1, 1], [2, 0], [2, 1]]), false); } #[test] fn test_argwhere_3d() { let tensor = TestTensorBool::<3>::from([ [[false, false, false], [false, true, false]], [[true, false, true], [true, true, false]], ]); let output = tensor.argwhere(); output.into_data().assert_eq( &TensorData::from([[0, 1, 1], [1, 0, 0], [1, 0, 2], [1, 1, 0], [1, 1, 1]]), false, ); } #[test] fn test_nonzero_1d() { let tensor = TestTensorBool::<1>::from([false, true, false, true, true]); let data_actual = tensor .nonzero() .into_iter() .map(|t| t.into_data()) .collect::>(); assert_eq!(data_actual.len(), 1); data_actual[0].assert_eq(&TensorData::from([1, 3, 4]), false); } #[test] fn test_nonzero_2d() { // 2-D tensor let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]); let data_actual = tensor .nonzero() .into_iter() .map(|t| t.into_data()) .collect::>(); let data_expected = [TensorData::from([1, 2, 2]), TensorData::from([1, 0, 1])]; assert_eq!(data_actual.len(), 2); for (idx, actual) in data_actual.iter().enumerate() { actual.assert_eq(&data_expected[idx], false) } } #[test] fn test_nonzero_3d() { // 3-D tensor let tensor = TestTensorBool::<3>::from([ [[false, false, false], [false, true, false]], [[true, false, true], [true, true, false]], ]); let data_actual = tensor .nonzero() .into_iter() .map(|t| t.into_data()) .collect::>(); let data_expected = [ TensorData::from([0, 1, 1, 1, 1]), TensorData::from([1, 0, 0, 1, 1]), TensorData::from([1, 0, 2, 0, 1]), ]; assert_eq!(data_actual.len(), 3); for (idx, actual) in data_actual.iter().enumerate() { actual.assert_eq(&data_expected[idx], false) } } #[test] fn test_nonzero_empty() { let tensor = TestTensorBool::<1>::from([false, false, false, false, false]); let output = tensor.nonzero(); assert_eq!(output.len(), 0); } #[test] fn test_argwhere_empty() { let tensor = TestTensorBool::<1>::from([false, false, false, false, false]); let output = tensor.argwhere(); assert_eq!(output.shape(), Shape::new([0, 1])); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/cat.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_cat_ops_bool() { let device = Default::default(); let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device); let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 0); output.into_data().assert_eq( &TensorData::from([[false, true, true], [true, true, false]]), false, ); } #[test] fn should_support_cat_with_empty_tensor_bool() { let device = Default::default(); let tensor_1 = TestTensorBool::<2>::from_data([[true, false, true]], &device); let tensor_2: TestTensorBool<2> = TestTensorBool::empty([1, 0], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 1); output .into_data() .assert_eq(&TensorData::from([[true, false, true]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/comparison.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_bool_equal() { let data_1 = TensorData::from([[false, true, true], [true, false, true]]); let data_2 = TensorData::from([[false, false, true], [false, true, true]]); let device = Default::default(); let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device); let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); let data_expected = TensorData::from([[true, false, true], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn should_support_bool_not_equal() { let data_1 = TensorData::from([[false, true, true], [true, false, true]]); let data_2 = TensorData::from([[false, false, true], [false, true, true]]); let device = Default::default(); let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device); let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); let data_expected = TensorData::from([[false, true, false], [true, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn should_support_bool_not() { let data_1 = TensorData::from([[false, true, true], [true, true, false]]); let tensor_1 = TestTensorBool::<2>::from_data(data_1, &Default::default()); let data_actual_cloned = tensor_1.clone().bool_not(); let data_actual_inplace = tensor_1.bool_not(); let data_expected = TensorData::from([[true, false, false], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_bool_equal_elem() { let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]); let data_actual_cloned = tensor_1.clone().equal_elem(false); let data_actual_inplace = tensor_1.equal_elem(false); let data_expected = TensorData::from([[false, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_bool_not_equal_elem() { let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]); let data_actual_cloned = tensor_1.clone().not_equal_elem(true); let data_actual_inplace = tensor_1.not_equal_elem(true); let data_expected = TensorData::from([[false, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/create_like.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_zeros_like() { let tensor = TestTensorBool::<3>::from([ [[false, true, false], [true, true, true]], [[false, false, false], [true, true, false]], ]); let tensor = tensor.zeros_like(); let expected = TensorData::from([ [[false, false, false], [false, false, false]], [[false, false, false], [false, false, false]], ]); tensor.into_data().assert_eq(&expected, false); } #[test] fn should_support_ones_like() { let tensor = TestTensorBool::<3>::from([ [[false, true, false], [true, true, true]], [[false, false, false], [true, true, false]], ]); let tensor = tensor.ones_like(); let expected = TensorData::from([ [[true, true, true], [true, true, true]], [[true, true, true], [true, true, true]], ]); tensor.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/expand.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn expand_2d_bool() { let tensor = TestTensorBool::<1>::from([false, true, false]); let expanded_tensor = tensor.expand([3, 3]); let expected_data = TensorData::from([ [false, true, false], [false, true, false], [false, true, false], ]); expanded_tensor.into_data().assert_eq(&expected_data, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/flip.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn flip_bool() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); let flipped = tensor.clone().flip([0, 2]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10) let data_expected = TensorData::from([ [ [true, true, true, true], [true, true, true, true], [true, true, true, true], ], [ [false, false, false, false], [false, false, false, false], [true, false, false, false], ], ]); flipped.into_data().assert_eq(&data_expected, false); // Test with no flip let flipped = tensor.clone().flip([]); tensor.into_data().assert_eq(&flipped.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/full.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_tensor_full() { let device = Default::default(); let bool_tensor = TestTensorBool::<2>::full([2, 2], true, &device); bool_tensor .into_data() .assert_eq(&TensorData::from([[true, true], [true, true]]), false); let bool_tensor = TestTensorBool::<2>::full([2, 2], false, &device); bool_tensor .into_data() .assert_eq(&TensorData::from([[false, false], [false, false]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/gather_scatter.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_scatter_1d_bool() { let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, false], &device); let values = TestTensorBool::from_data([false, true, true], &device); let indices = TestTensorInt::from_ints([1, 0, 2], &device); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([true, false, true]), false); } #[test] fn should_gather_1d_dim0_bool() { let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, false], &device); let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device); let output = tensor.gather(0, indices); output .into_data() .assert_eq(&TensorData::from([false, false, true, false, false]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/init.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_bool_empty() { let shape = [2, 2]; let tensor = TestTensorBool::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_bool_zeros() { let shape = [2, 2]; let tensor = TestTensorBool::<2>::zeros(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[false, false], [false, false]]), false); } #[test] fn should_support_bool_ones() { let shape = [2, 2]; let tensor = TestTensorBool::<2>::ones(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[true, true], [true, true]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/logical.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_bool_and() { let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]); let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]); let data_actual = tensor1.bool_and(tensor2).into_data(); let data_expected = TensorData::from([[false, true, false], [false, false, true]]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_bool_or() { let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]); let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]); let data_actual = tensor1.bool_or(tensor2).into_data(); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_bool_xor() { let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]); let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]); let data_actual = tensor1.bool_xor(tensor2).into_data(); let data_expected = TensorData::from([[true, false, false], [true, false, false]]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_bool_or_vec() { let device = Default::default(); let tensor1 = TestTensorBool::<1>::full([256], 0, &device); let tensor2 = TestTensorBool::<1>::full([256], 1, &device); let data_actual = tensor1.bool_or(tensor2).into_data(); let data_expected = TensorData::from([true; 256]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_bool_and_vec() { let device = Default::default(); let tensor1 = TestTensorBool::<1>::full([256], 0, &device); let tensor2 = TestTensorBool::<1>::full([256], 1, &device); let data_actual = tensor1.bool_and(tensor2).into_data(); let data_expected = TensorData::from([false; 256]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/mask.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_bool_mask_where_ops() { let device = Default::default(); let tensor = TestTensorBool::<2>::from_data([[true, false], [false, false]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let value = TestTensorBool::<2>::from_data(TensorData::from([[false, true], [true, false]]), &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[false, false], [false, false]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_bool_mask_fill_ops() { let device = Default::default(); let tensor = TestTensorBool::<2>::from_data([[false, true], [false, false]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let output = tensor.mask_fill(mask, true); let expected = TensorData::from([[true, true], [false, true]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/mod.rs ================================================ pub use super::*; // re-export test types mod all; mod any; mod argwhere_nonzero; mod cat; mod comparison; mod create_like; mod expand; mod flip; mod full; mod gather_scatter; mod init; mod logical; mod mask; mod movedim; mod permute; mod repeat; mod repeat_dim; mod reshape; mod select; mod stack; mod take; mod transpose; mod tri_mask; mod unfold; ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/movedim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn movedim_bool() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); let permuted = tensor.clone().movedim(0, 2); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).gt(10) let expected = TensorData::from([ [[false, true], [false, true], [false, true], [false, true]], [[false, true], [false, true], [false, true], [false, true]], [[false, true], [false, true], [false, true], [true, true]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().movedim(0, -1); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().movedim(0, 0); permuted.into_data().assert_eq(&tensor.into_data(), false); } #[test] fn vec_input_bool() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]); // from pytorch // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).gt(10) let expected = TensorData::from([ [[false, false, false, false], [true, true, true, true]], [[false, false, false, false], [true, true, true, true]], [[false, false, false, true], [true, true, true, true]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axes let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]); permuted.into_data().assert_eq(&expected, false); // Test with the same axes let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]); permuted.into_data().assert_eq(&tensor.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/permute.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn permute_bool() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .greater_elem(10); let permuted = tensor.clone().permute([2, 1, 0]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0).gt(10) let expected = TensorData::from([ [[false, true], [false, true], [false, true]], [[false, true], [false, true], [false, true]], [[false, true], [false, true], [false, true]], [[false, true], [false, true], [true, true]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().permute([-1, 1, 0]); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().permute([0, 1, 2]); permuted.into_data().assert_eq(&tensor.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/repeat.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_bool_repeat_ops_one_dimension() { let data = TensorData::from([[true, false, false]]); let tensor = TestTensorBool::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ [true, false, false], [true, false, false], [true, false, false], [true, false, false], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_bool_repeat_on_many_dimension() { let data = TensorData::from([ [[false, true], [true, false]], [[true, true], [false, false]], ]); let tensor = TestTensorBool::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); let expected = TensorData::from([ [ [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], ], [ [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], ], [ [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], ], [ [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], ], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/repeat_dim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_bool_repeat_ops() { let data = TensorData::from([[true, false, false]]); let tensor = TestTensorBool::<2>::from_data(data, &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([ [true, false, false], [true, false, false], [true, false, false], [true, false, false], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_bool_repeat_on_dims_larger_than_1() { let data = TensorData::from([ [[false, true], [true, false]], [[true, true], [false, false]], ]); let tensor = TestTensorBool::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(1, 2); let expected = TensorData::from([ [[false, true], [true, false], [false, true], [true, false]], [[true, true], [false, false], [true, true], [false, false]], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/reshape.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_reshape_bool() { let data = TensorData::from([false, true, false]); let tensor = TestTensorBool::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[false, true, false]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/select.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_select_bool_tensor_1d() { // Test that select works for boolean tensors let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); let indices = TestTensorInt::from_data([0, 2, 1, 0], &device); let output = tensor.select(0, indices); let expected = TensorData::from([true, true, false, true]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_bool_tensor_2d() { // Test that select works for boolean 2D tensors let device = Default::default(); let tensor = TestTensorBool::<2>::from_data([[true, false, true], [false, true, false]], &device); let indices = TestTensorInt::from_data([1, 0], &device); let output = tensor.select(0, indices); let expected = TensorData::from([[false, true, false], [true, false, true]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_bool_tensor() { // Test that select_add works for boolean tensors let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); let values = TestTensorBool::<1>::from_data([false, true], &device); let indices = TestTensorInt::from_data([0, 2], &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); // Note: select_add uses sum reduction, so: // index 0: true OR false = true // index 2: true OR true = true // index 1: false (unchanged) let expected = TensorData::from([true, false, true]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_bool_overlapping_indices() { // Test accumulation behavior with overlapping indices let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([false, true], &device); let indices = TestTensorInt::from_data([0, 0], &device); let values = TestTensorBool::<1>::from_data([true, false], &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); // Index 0: false OR true OR false = true let expected = TensorData::from([true, true]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_bool_false_to_true_case() { // Test false OR true = true let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([false], &device); let indices = TestTensorInt::from_data([0], &device); let values = TestTensorBool::<1>::from_data([true], &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([true]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_bool_true_or_true_accumulation() { // Test multiple true accumulations let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false], &device); let indices = TestTensorInt::from_data([0, 0, 0], &device); let values = TestTensorBool::<1>::from_data([true, true, true], &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([true, false]); output.into_data().assert_eq(&expected, false); } #[test] fn should_match_default_implementation_behavior() { // Verify optimized implementation matches original default logic let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); let indices = TestTensorInt::from_data([0, 1, 0], &device); let values = TestTensorBool::<1>::from_data([false, true, true], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); // Manual default implementation logic let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] fn should_select_add_bool_overlapping_indices_vs_default() { // Test overlapping indices against default implementation let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([false, true], &device); let indices = TestTensorInt::from_data([0, 0], &device); let values = TestTensorBool::<1>::from_data([true, false], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] fn should_select_add_bool_true_or_true_accumulation_vs_default() { // Test multiple true accumulations against default implementation let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false], &device); let indices = TestTensorInt::from_data([0, 0, 0], &device); let values = TestTensorBool::<1>::from_data([true, true, true], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] fn should_select_add_bool_false_to_true_case_vs_default() { // Test false OR true case against default implementation let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([false], &device); let indices = TestTensorInt::from_data([0], &device); let values = TestTensorBool::<1>::from_data([true], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] fn should_select_add_bool_tensor_vs_default() { // Test existing basic case against default implementation let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); let indices = TestTensorInt::from_data([0, 2], &device); let values = TestTensorBool::<1>::from_data([false, false], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] #[should_panic(expected = "Tensors are not eq")] fn should_fail_if_replacement_semantics_were_used() { // Test that framework uses accumulation, not replacement let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true], &device); let indices = TestTensorInt::from_data([0], &device); let values = TestTensorBool::<1>::from_data([false], &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let replacement_expected = TensorData::from([false]); output.into_data().assert_eq(&replacement_expected, false); } #[test] #[should_panic(expected = "Tensors are not eq")] fn should_fail_if_replacement_semantics_were_used_vs_default() { // Test that default implementation also uses accumulation, not replacement let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true], &device); let indices = TestTensorInt::from_data([0], &device); let values = TestTensorBool::<1>::from_data([false], &device); let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); let replacement_expected = TensorData::from([false]); default_result .into_data() .assert_eq(&replacement_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/stack.rs ================================================ use super::*; use alloc::vec; use burn_tensor::{Tensor, TensorData}; #[test] fn should_support_stack_ops_bool() { let device = Default::default(); let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device); let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[false, true, true]], [[true, true, false]]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/take.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_take_bool_tensor() { // Test take with boolean tensors let device = Default::default(); let tensor = TestTensorBool::<2>::from_data([[true, false], [false, true]], &device); let indices = TestTensorInt::<1>::from_data([1, 0], &device); let output = tensor.take::<1, 2>(0, indices); let expected = TensorData::from([[false, true], [true, false]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_bool_tensor_with_2d_indices() { // Test take with boolean tensors - output will be 3D let device = Default::default(); let tensor = TestTensorBool::<2>::from_data( [ [true, false, true], [false, true, false], [true, true, false], ], &device, ); // 2D indices - shape [2, 2] let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device); let output = tensor.take::<2, 3>(0, indices); // Expected: shape [2, 2, 3] let expected = TensorData::from([ [[true, false, true], [true, true, false]], [[false, true, false], [true, false, true]], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/transpose.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_transpose_bool() { let tensor = TestTensorBool::<3>::from_data( [ [[false, true, false], [false, false, false]], [[false, false, true], [false, false, true]], ], &Default::default(), ); let output = tensor.transpose(); let expected = TensorData::from([ [[false, false], [true, false], [false, false]], [[false, false], [false, false], [true, true]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_swap_dims_bool() { let tensor = TestTensorBool::<3>::from_data( [ [[false, true, false], [false, false, false]], [[false, false, true], [false, false, true]], ], &Default::default(), ); let output = tensor.swap_dims(0, 2); let expected = TensorData::from([ [[false, false], [false, false]], [[true, false], [false, false]], [[false, true], [false, true]], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/tri_mask.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn square_diag() { let device = Default::default(); let data_expected = TensorData::from([ [false, true, true], [true, false, true], [true, true, false], ]); let tensor = TestTensorBool::<2>::diag_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn square_diag_offset() { let device = Default::default(); let data_expected = TensorData::from([[true, false, true], [true, true, false], [true, true, true]]); let tensor = TestTensorBool::<2>::diag_mask([3, 3], 1, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn square_tri_upper() { let device = Default::default(); let data_expected = TensorData::from([ [false, false, false], [true, false, false], [true, true, false], ]); let tensor = TestTensorBool::<2>::triu_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn square_tri_upper_offset() { let device = Default::default(); let data_expected = TensorData::from([ [true, false, false], [true, true, false], [true, true, true], ]); let tensor = TestTensorBool::<2>::triu_mask([3, 3], 1, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn square_tri_lower() { let device = Default::default(); let data_expected = TensorData::from([ [false, true, true], [false, false, true], [false, false, false], ]); let tensor = TestTensorBool::<2>::tril_mask([3, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn square_tri_lower_offset() { let device = Default::default(); let data_expected = TensorData::from([ [true, true, true], [false, true, true], [false, false, true], ]); let tensor = TestTensorBool::<2>::tril_mask([3, 3], -1, &device); tensor.into_data().assert_eq(&data_expected, false); } #[test] fn rect_diag() { let device = Default::default(); let data_expected = TensorData::from([ [false, true, true, true], [true, false, true, true], [true, true, false, true], ]); let tensor = TestTensorBool::<2>::diag_mask([3, 4], 0, &device); tensor.into_data().assert_eq(&data_expected, false); let data_expected = TensorData::from([ [false, true, true], [true, false, true], [true, true, false], [true, true, true], ]); let tensor = TestTensorBool::<2>::diag_mask([4, 3], 0, &device); tensor.into_data().assert_eq(&data_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/bool/ops/unfold.rs ================================================ use super::*; use burn_tensor::Distribution; use burn_tensor::s; #[test] fn test_unfold_bool() { let device = Default::default(); let input = TestTensor::<3>::random([2, 6, 6], Distribution::Default, &device).greater_elem(0.5); let dim = 1; let size = 3; let step = 2; let actual: TestTensorBool<4> = input.clone().unfold(dim, size, step); let expected = TestTensorBool::<4>::empty([2, 2, 6, 3], &device) .slice_assign( s![.., 0, .., ..], input .clone() .slice(s![.., 0..3, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ) .slice_assign( s![.., 1, .., ..], input .clone() .slice(s![.., 2..5, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ); actual.to_data().assert_eq(&expected.to_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/clone_invariance.rs ================================================ /// This module tests whether basic tensor operations remain invariant when performed on clones, /// meaning that cloning input tensors won't affect the results. /// /// Those are relevant tests because backends may employ unsafe optimizations to reuse tensor data /// and use different kernels in such cases. We ensure that the results are consistent regardless /// of the approach and that the input tensors are not modified when cloned. use super::*; use burn_tensor::Tolerance; use burn_tensor::activation::{ gelu, log_sigmoid, log_softmax, mish, relu, sigmoid, silu, softmax, softplus, tanh, }; use burn_tensor::{Distribution, IndexingUpdateOp, TensorData}; pub trait CloneInvarianceTest { type Args; fn args(&self) -> Self::Args; fn run(&self, args: &Self::Args, inplace: bool) -> TensorData; fn check(&self) { let args = self.args(); let out = self.run(&args, false); let out_inplace = self.run(&args, true); out.assert_approx_eq::(&out_inplace, Tolerance::default()); } } macro_rules! clone_invariance_test { (unary: $name:ident, ops_float: $ops:expr) => { #[test] #[allow(non_snake_case)] fn $name() { struct $name; impl CloneInvarianceTest<2> for $name { type Args = TensorData; fn args(&self) -> Self::Args { TestTensor::<2>::random([32, 32], Distribution::Default, &Default::default()) .into_data() .convert::() } fn run(&self, args: &Self::Args, inplace: bool) -> TensorData { let lhs = TestTensor::from_data(args.clone(), &Default::default()); if inplace { $ops(lhs).into_data().convert::() } else { let out = $ops(lhs.clone()).into_data().convert::(); lhs.into_data() .assert_approx_eq::(args, Tolerance::default()); out } } } CloneInvarianceTest::<2>::check(&$name); } }; (binary: $name:ident, ops_float: $ops:expr) => { #[test] #[allow(non_snake_case)] fn $name() { struct $name; impl CloneInvarianceTest<2> for $name { type Args = (TensorData, TensorData); fn args(&self) -> Self::Args { let device = Default::default(); ( TestTensor::<2>::ones([32, 32], &device) .into_data() .convert::(), // Avoid div by zero. TestTensor::<2>::ones([32, 32], &device) .into_data() .convert::(), ) } fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData { let device = Default::default(); let lhs = TestTensor::from_data(lhs_arg.clone(), &device); let rhs = TestTensor::from_data(rhs_arg.clone(), &device); if inplace { $ops(lhs, rhs).into_data().convert::() } else { let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::(); lhs.into_data() .assert_approx_eq::(lhs_arg, Tolerance::default()); rhs.into_data() .assert_approx_eq::(rhs_arg, Tolerance::default()); out } } } CloneInvarianceTest::<2>::check(&$name); } }; (unary: $name:ident, ops_int: $ops:expr) => { #[test] #[allow(non_snake_case)] fn $name() { struct $name; impl CloneInvarianceTest<2> for $name { type Args = TensorData; fn args(&self) -> Self::Args { TestTensor::<2>::random( [32, 32], Distribution::Uniform(0.0, 50.0), &Default::default(), ) .into_data() .convert::() } fn run(&self, args: &Self::Args, inplace: bool) -> TensorData { let lhs = TestTensorInt::from_data(args.clone(), &Default::default()); if inplace { $ops(lhs).into_data().convert::() } else { let out = $ops(lhs.clone()).into_data().convert::(); lhs.into_data() .convert::() .assert_approx_eq::(args, Tolerance::default()); out } } } CloneInvarianceTest::<2>::check(&$name); } }; (binary: $name:ident, ops_int: $ops:expr) => { #[test] #[allow(non_snake_case)] fn $name() { struct $name; impl CloneInvarianceTest<2> for $name { type Args = (TensorData, TensorData); fn args(&self) -> Self::Args { let device = Default::default(); ( TestTensor::<2>::random([32, 32], Distribution::Uniform(0., 50.), &device) .into_data() .convert::(), // Avoid div by zero. TestTensor::<2>::random([32, 32], Distribution::Uniform(1., 51.), &device) .into_data() .convert::(), ) } fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData { let device = Default::default(); let lhs = TestTensorInt::from_data(lhs_arg.clone(), &device); let rhs = TestTensorInt::from_data(rhs_arg.clone(), &device); if inplace { $ops(lhs, rhs).into_data().convert::() } else { let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::(); lhs.into_data() .convert::() .assert_approx_eq::(lhs_arg, Tolerance::default()); rhs.into_data() .convert::() .assert_approx_eq::(rhs_arg, Tolerance::default()); out } } } CloneInvarianceTest::<2>::check(&$name); } }; } mod float { use super::*; // Unary ops clone_invariance_test!( unary: AddScalar, ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0) ); clone_invariance_test!( unary: SubScalar, ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0) ); clone_invariance_test!( unary: DivScalar, ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0) ); clone_invariance_test!( unary: MulScalar, ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0) ); clone_invariance_test!( unary: PowScalar, ops_float: |tensor: TestTensor<2>| tensor.powf_scalar(2.0) ); clone_invariance_test!( unary: Square, ops_float: |tensor: TestTensor<2>| tensor.square() ); clone_invariance_test!( unary: Sqrt, ops_float: |tensor: TestTensor<2>| tensor.sqrt() ); clone_invariance_test!( unary: Exp, ops_float: |tensor: TestTensor<2>| tensor.exp() ); clone_invariance_test!( unary: Neg, ops_float: |tensor: TestTensor<2>| tensor.neg() ); clone_invariance_test!( unary: MeanDim, ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1) ); clone_invariance_test!( unary: SumDim, ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1) ); clone_invariance_test!( unary: Sum, ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze::<2>() ); clone_invariance_test!( unary: Mean, ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze::<2>() ); clone_invariance_test!( unary: Clamp, ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.) ); clone_invariance_test!( unary: ClampMin, ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.) ); clone_invariance_test!( unary: ClampMax, ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.) ); clone_invariance_test!( unary: Abs, ops_float: |tensor: TestTensor<2>| tensor.abs() ); clone_invariance_test!( unary: Cos, ops_float: |tensor: TestTensor<2>| tensor.cos() ); clone_invariance_test!( unary: Sin, ops_float: |tensor: TestTensor<2>| tensor.sin() ); clone_invariance_test!( unary: Tan, ops_float: |tensor: TestTensor<2>| tensor.tan() ); clone_invariance_test!( unary: Log, ops_float: |tensor: TestTensor<2>| tensor.log() ); clone_invariance_test!( unary: Log1P, ops_float: |tensor: TestTensor<2>| tensor.log1p() ); clone_invariance_test!( unary: SwapDims, ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1) ); clone_invariance_test!( unary: Transpose, ops_float: |tensor: TestTensor<2>| tensor.transpose() ); clone_invariance_test!( unary: Slice, ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24]) ); clone_invariance_test!( unary: Erf, ops_float: |tensor: TestTensor<2>| tensor.erf() ); clone_invariance_test!( unary: EqualElem, ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5) ); clone_invariance_test!( unary: NotEqualElem, ops_float: |tensor: TestTensor<2>| tensor.not_equal_elem(0.5) ); clone_invariance_test!( unary: GreaterElem, ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5) ); clone_invariance_test!( unary: GreaterEqualElem, ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5) ); clone_invariance_test!( unary: LowerElem, ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5) ); clone_invariance_test!( unary: LowerEqualElem, ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5) ); clone_invariance_test!( unary: Argmax, ops_float: |tensor: TestTensor<2>| tensor.argmax(0) ); clone_invariance_test!( unary: Argmin, ops_float: |tensor: TestTensor<2>| tensor.argmin(0) ); clone_invariance_test!( unary: Max, ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze::<2>() ); clone_invariance_test!( unary: Min, ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze::<2>() ); clone_invariance_test!( unary: MaxDim, ops_float: |tensor: TestTensor<2>| tensor.max_dim(1) ); clone_invariance_test!( unary: MaxDimWithIndices, ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0 ); clone_invariance_test!( unary: MinDimWithIndices, ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0 ); clone_invariance_test!( unary: MinDim, ops_float: |tensor: TestTensor<2>| tensor.min_dim(1) ); clone_invariance_test!( unary: Repeat, ops_float: |tensor: TestTensor<2>| { tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32]) } ); clone_invariance_test!( unary: Reshape, ops_float: |tensor: TestTensor<2>| { let shape = tensor.shape(); let new_shape = [shape.num_elements(), 1]; tensor.reshape(new_shape) } ); clone_invariance_test!( unary: Gatter, ops_float: |tensor: TestTensor<2>| { let shape = tensor.shape(); let indices = TestTensorInt::ones(shape, &Default::default()); tensor.gather(0, indices) } ); clone_invariance_test!( unary: Select, ops_float: |tensor: TestTensor<2>| { let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default()); tensor.select(0, indices) } ); clone_invariance_test!( unary: MaskFill, ops_float: |tensor: TestTensor<2>| { let mask = tensor.clone().greater_elem(0.5); tensor.mask_fill(mask, 77.0) } ); // Activation clone_invariance_test!( unary: Softmax, ops_float: |tensor: TestTensor<2>| softmax(tensor, 1) ); clone_invariance_test!( unary: LogSoftmax, ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1) ); clone_invariance_test!( unary: Sigmoid, ops_float: |tensor: TestTensor<2>| sigmoid(tensor) ); clone_invariance_test!( unary: LogSigmoid, ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor) ); clone_invariance_test!( unary: Relu, ops_float: |tensor: TestTensor<2>| relu(tensor) ); clone_invariance_test!( unary: Gelu, ops_float: |tensor: TestTensor<2>| gelu(tensor) ); clone_invariance_test!( unary: Mish, ops_float: |tensor: TestTensor<2>| mish(tensor) ); clone_invariance_test!( unary: Silu, ops_float: |tensor: TestTensor<2>| silu(tensor) ); clone_invariance_test!( unary: Softplus, ops_float: |tensor: TestTensor<2>| softplus(tensor, 1.0) ); clone_invariance_test!( unary: Tanh, ops_float: |tensor: TestTensor<2>| tanh(tensor) ); // Binary ops clone_invariance_test!( binary: Add, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs) ); clone_invariance_test!( binary: Sub, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs) ); clone_invariance_test!( binary: Div, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs) ); clone_invariance_test!( binary: Mul, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs) ); clone_invariance_test!( binary: Matmul, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs) ); clone_invariance_test!( binary: Equal, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs) ); clone_invariance_test!( binary: Greater, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs) ); clone_invariance_test!( binary: GreaterEqual, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs) ); clone_invariance_test!( binary: Lower, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs) ); clone_invariance_test!( binary: LowerEqual, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs) ); clone_invariance_test!( binary: Cat, ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| { let lhs = lhs.reshape([1usize, 32, 32]); let rhs = rhs.reshape([1usize, 32, 32]); TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32]) } ); clone_invariance_test!( binary: Scatter, ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { let shape = tensor.shape(); let indices = TestTensorInt::ones(shape, &Default::default()); tensor.scatter(0, indices, values, IndexingUpdateOp::Add) } ); clone_invariance_test!( binary: SliceAssign, ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) } ); clone_invariance_test!( binary: MaskWhere, ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { let mask = tensor.clone().greater_elem(0.5); tensor.mask_where(mask, values) } ); clone_invariance_test!( binary: SelectAssign, ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| { let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default()); let values = values.select(0, indices.clone()); tensor.select_assign(0, indices, values, IndexingUpdateOp::Add) } ); } mod int { use super::*; // Unary ops clone_invariance_test!( unary: AddScalar, ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0) ); clone_invariance_test!( unary: SubScalar, ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0) ); clone_invariance_test!( unary: DivScalar, ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0) ); clone_invariance_test!( unary: MulScalar, ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0) ); clone_invariance_test!( unary: Neg, ops_int: |tensor: TestTensorInt<2>| tensor.neg() ); clone_invariance_test!( unary: MeanDim, ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1) ); clone_invariance_test!( unary: SumDim, ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1) ); clone_invariance_test!( unary: Sum, ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze::<2>() ); clone_invariance_test!( unary: Mean, ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze::<2>() ); clone_invariance_test!( unary: Clamp, ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.) ); clone_invariance_test!( unary: ClampMin, ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.) ); clone_invariance_test!( unary: ClampMax, ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.) ); clone_invariance_test!( unary: Abs, ops_int: |tensor: TestTensorInt<2>| tensor.abs() ); clone_invariance_test!( unary: SwapDims, ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1) ); clone_invariance_test!( unary: Transpose, ops_int: |tensor: TestTensorInt<2>| tensor.transpose() ); clone_invariance_test!( unary: Slice, ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24]) ); clone_invariance_test!( unary: EqualElem, ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25) ); clone_invariance_test!( unary: NotEqualElem, ops_int: |tensor: TestTensorInt<2>| tensor.not_equal_elem(25) ); clone_invariance_test!( unary: GreaterElem, ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25) ); clone_invariance_test!( unary: GreaterEqualElem, ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25) ); clone_invariance_test!( unary: LowerElem, ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25) ); clone_invariance_test!( unary: LowerEqualElem, ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25) ); clone_invariance_test!( unary: Argmax, ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0) ); clone_invariance_test!( unary: Argmin, ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0) ); clone_invariance_test!( unary: Max, ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze::<2>() ); clone_invariance_test!( unary: Min, ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze::<2>() ); clone_invariance_test!( unary: MaxDim, ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1) ); clone_invariance_test!( unary: MaxDimWithIndices, ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0 ); clone_invariance_test!( unary: MinDimWithIndices, ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0 ); clone_invariance_test!( unary: MinDim, ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1) ); clone_invariance_test!( unary: Repeat, ops_int: |tensor: TestTensorInt<2>| { tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32]) } ); clone_invariance_test!( unary: Reshape, ops_int: |tensor: TestTensorInt<2>| { let shape = tensor.shape(); let new_shape = [shape.num_elements(), 1]; tensor.reshape(new_shape) } ); clone_invariance_test!( unary: Gatter, ops_int: |tensor: TestTensorInt<2>| { let shape = tensor.shape(); let indices = TestTensorInt::ones(shape, &Default::default()); tensor.gather(0, indices) } ); clone_invariance_test!( unary: Select, ops_int: |tensor: TestTensorInt<2>| { let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default()); tensor.select(0, indices) } ); clone_invariance_test!( unary: MaskFill, ops_int: |tensor: TestTensorInt<2>| { let mask = tensor.clone().greater_elem(0.5); tensor.mask_fill(mask, 77.0) } ); // Binary ops clone_invariance_test!( binary: Add, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs) ); clone_invariance_test!( binary: Sub, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs) ); clone_invariance_test!( binary: Div, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs) ); clone_invariance_test!( binary: Mul, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs) ); clone_invariance_test!( binary: Equal, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs) ); clone_invariance_test!( binary: NotEqual, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.not_equal(rhs) ); clone_invariance_test!( binary: Greater, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs) ); clone_invariance_test!( binary: GreaterEqual, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs) ); clone_invariance_test!( binary: Lower, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs) ); clone_invariance_test!( binary: LowerEqual, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs) ); clone_invariance_test!( binary: Cat, ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| { let lhs = lhs.reshape([1usize, 32, 32]); let rhs = rhs.reshape([1usize, 32, 32]); TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32]) } ); clone_invariance_test!( binary: Scatter, ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { let shape = tensor.shape(); let indices = TestTensorInt::ones(shape, &Default::default()); tensor.scatter(0, indices, values, IndexingUpdateOp::Add) } ); clone_invariance_test!( binary: SliceAssign, ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12])) } ); clone_invariance_test!( binary: MaskWhere, ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { let mask = tensor.clone().greater_elem(0.5); tensor.mask_where(mask, values) } ); clone_invariance_test!( binary: SelectAssign, ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| { let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default()); let values = values.select(0, indices.clone()); tensor.select_assign(0, indices, values, IndexingUpdateOp::Add) } ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/celu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_celu_d2() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [-3.0, 0.5]]); let output = activation::celu(tensor, 1.0); // celu(1, 1) = 1 // celu(7, 1) = 7 // celu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213 // celu(0.5, 1) = 0.5 let expected = TensorData::from([[1.0, 7.0], [-0.950213, 0.5]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_celu_with_alpha() { let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]); let output = activation::celu(tensor, 2.0); // celu(0, 2) = 0 // celu(-1, 2) = 2 * (exp(-0.5) - 1) = -0.786939 // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241 let expected = TensorData::from([0.0, -0.786939, -1.264241]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/elu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_elu() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::elu(tensor, 1.0); // elu(1, 1) = 1, elu(7, 1) = 7, elu(13, 1) = 13 // elu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213 let expected = TensorData::from([[1.0, 7.0], [13.0, -0.950213]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_elu_alpha() { let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]); let output = activation::elu(tensor, 2.0); // elu(0, 2) = 2*(exp(0)-1) = 0 // elu(-1, 2) = 2*(exp(-1)-1) = 2*(-0.632121) = -1.264241 // elu(-2, 2) = 2*(exp(-2)-1) = 2*(-0.864665) = -1.729329 let expected = TensorData::from([0.0, -1.264241, -1.729329]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/gelu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_gelu() { let tensor = TestTensor::<2>::from([[ 0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737, ]]); let output = activation::gelu(tensor); let expected = TensorData::from([[ 0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051, ]]); // Low precision to allow approximation implementation using tanh output.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_absolute(2e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/glu.rs ================================================ use super::*; use burn_tensor::{TensorData, activation}; #[test] fn test_glu_d3() { let tensor = TestTensor::<3>::from([[ [ -0.5710, -1.3416, 1.9128, -0.8257, -0.1331, -1.4804, -0.6281, -0.6115, ], [ 0.0267, -1.3834, 0.2752, 0.7844, -0.3549, -0.4274, 0.3290, -0.5459, ], [ -1.6347, -2.0908, 1.8801, 0.3541, 0.2237, 1.0377, 2.4850, 0.3490, ], ]]); let output = activation::glu(tensor, 2); output.into_data().assert_approx_eq::( &TensorData::from([[ [-0.2665, -0.2487, 0.6656, -0.2904], [0.0110, -0.5461, 0.1601, 0.2877], [-0.9084, -1.5439, 1.7355, 0.2077], ]]), Default::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/hard_sigmoid.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_hard_sigmoid() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::hard_sigmoid(tensor, 0.2, 0.5); let expected = TensorData::from([[0.7, 1.0], [1.0, 0.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_hard_sigmoid_overflow() { let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]); let output = activation::hard_sigmoid(tensor, 0.2, 0.5); let expected = TensorData::from([1.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/leaky_relu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_leaky_relu_d2() { let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]); let output = activation::leaky_relu(tensor, 0.01); // Account for conversion errors if `FloatType != f32` output.into_data().assert_approx_eq::( &TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]), Tolerance::default(), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/log_sigmoid.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ElementConversion, TensorData, activation}; #[test] fn test_log_sigmoid() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::log_sigmoid(tensor); let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]); let tolerance = Tolerance::rel_abs(0.01, 0.0001); output .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_log_sigmoid_numerical_stability() { let tensor = TestTensor::<1>::from([300.0, -300.0]); let output = activation::log_sigmoid(tensor); // For large negative values, the previous implementation −log(1 + exp(−x)) would give -inf let expected = TensorData::from([0.0, -300.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]); let output = activation::log_sigmoid(tensor); let expected = TensorData::from([0.elem(), FloatElem::MIN]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/mish.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_mish() { let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]); let output = activation::mish(tensor); let expected = TensorData::from([ [-0.19709, -0.30056, -0.11714], [-0.24132, 0.58235, -0.08877], ]); // Metal has less precise trigonometric functions (tanh inside mish) let tolerance = Tolerance::default().set_half_precision_relative(1e-2); output .into_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/mod.rs ================================================ use super::*; mod celu; mod elu; mod gelu; mod glu; mod hard_sigmoid; mod leaky_relu; mod log_sigmoid; mod mish; mod prelu; mod quiet_softmax; mod relu; mod selu; mod sigmoid; mod silu; mod softmax; mod softmin; mod softplus; mod softsign; mod tanh_activation; mod thresholded_relu; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/prelu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_prelu_2_dimension() { let data = [ [-1.1, 0.0, 1.2, 0.25, -5.4], [-4.567, 0.56, -1.55, 99.9, 0.0], ]; let tensor = TestTensor::<2>::from(data); let output = activation::prelu(tensor, TestTensor::from([0.5, 0.25, 0.0, -0.8, -0.4])); let expected = TensorData::from([ [-0.5500, 0.0000, 1.2000, 0.2500, 2.1600], [-2.2835, 0.5600, -0.0000, 99.9000, -0.0000], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_prelu_2_dimension_scalar_weight() { let data = [ [-1.1, 0.0, 1.2, 0.25, -5.4], [-4.567, 0.56, -1.55, 99.9, 0.0], ]; let tensor = TestTensor::<2>::from(data); let output = activation::prelu(tensor, TestTensor::from([-0.8])); let expected = TensorData::from([ [0.8800, -0.0000, 1.2000, 0.2500, 4.3200], [3.6536, 0.5600, 1.2400, 99.9000, -0.0000], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_prelu_positives() { // Check that positives are untouched let data = [[ 0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737, ]]; let tensor = TestTensor::<2>::from(data); let output = activation::prelu(tensor, TestTensor::from([0.25])); let expected = TensorData::from(data); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_prelu_zero_weight() { // test that with weight 0 it behaves as relu let data = [-1.1, 0.0, 1.2, 0.25, -5.4]; let tensor = TestTensor::<1>::from(data); let output = activation::prelu(tensor, TestTensor::from([0.0])); let expected = TensorData::from([0.0, 0.0, 1.2, 0.25, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_prelu_some_weight() { // test that with some non zero weight it works like leaky relu let data = [-1.1, 0.0, 1.2, 0.25, -5.4]; let tensor = TestTensor::<1>::from(data); let output = activation::prelu(tensor, TestTensor::from([0.5])); let expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[should_panic] fn test_prelu_single_dim_multi_weight() { // should panic because the data has only 1 channel let data = [-1.1, 2.0, 1.2, 0.25, -5.4]; let tensor = TestTensor::<1>::from(data); let data_actual = activation::prelu(tensor, TestTensor::from([0.5, -0.25, 0.0, 0.5, -1.0])).into_data(); let data_expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]); data_expected.assert_approx_eq::(&data_actual, Tolerance::default()); } #[test] #[should_panic] fn test_prelu_multi_dim_wrong_weights() { let data = [ [-1.1, 0.0, 1.2, 0.25, -5.4], [-4.567, 0.56, -1.55, 99.9, 0.0], ]; let tensor = TestTensor::<2>::from(data); let _ = activation::prelu(tensor, TestTensor::from([-0.8, 0.1])); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/quiet_softmax.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_quiet_softmax_d2() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::quiet_softmax(tensor, 1); let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/relu.rs ================================================ use super::*; use burn_tensor::{TensorData, activation}; #[test] fn test_relu_d2() { let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]); let output = activation::relu(tensor); output .into_data() .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/selu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_selu() { // selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0 // alpha = 1.6733, gamma = 1.0507 let tensor = TestTensor::<2>::from([[0.0, 1.0, -1.0], [2.0, -2.0, 0.5]]); let output = activation::selu(tensor); // Expected values computed from the formula: // selu(0.0) = 1.0507 * 1.6733 * (exp(0) - 1) = 0.0 // selu(1.0) = 1.0507 * 1.0 = 1.0507 // selu(-1.0) = 1.0507 * 1.6733 * (exp(-1) - 1) = 1.7581 * (0.3679 - 1) = -1.1113 // selu(2.0) = 1.0507 * 2.0 = 2.1014 // selu(-2.0) = 1.0507 * 1.6733 * (exp(-2) - 1) = 1.7581 * (0.1353 - 1) = -1.5202 // selu(0.5) = 1.0507 * 0.5 = 0.5254 let expected = TensorData::from([[0.0, 1.0507, -1.1113], [2.1014, -1.5202, 0.5254]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_selu_zero() { let tensor = TestTensor::<1>::from([0.0]); let output = activation::selu(tensor); let expected = TensorData::from([0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/sigmoid.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_sigmoid() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::sigmoid(tensor); let expected = TensorData::from([[0.731059, 0.999089], [0.999998, 0.047426]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_sigmoid_overflow() { let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]); let output = activation::sigmoid(tensor); let expected = TensorData::from([1.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/silu.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_silu() { let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let output = activation::silu(tensor); let expected = TensorData::from([[0.73106, 1.76159], [2.85772, 3.92806]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/softmax.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_softmax_d2() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::softmax(tensor, 1); let expected = TensorData::from([[2.472623e-03, 9.975274e-01], [1.0, 1.125352e-07]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/softmin.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_softmin_d2() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::softmin(tensor, 1); let expected = TensorData::from([[9.975274e-01, 2.472623e-03], [1.125352e-07, 1.0000]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/softplus.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_softplus_d2() { let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]); let output = activation::softplus(tensor.clone(), 1.0); let expected = TensorData::from([ [0.503453, 0.324898, 0.588517], [0.445806, 1.117805, 0.615424], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let output = activation::softplus(tensor, 2.0); let expected = TensorData::from([ [0.178232, 0.068737, 0.247990], [0.137132, 0.827771, 0.272106], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/softsign.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_softsign() { let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); let output = activation::softsign(tensor); let expected = TensorData::from([[0.5, 0.875], [0.928571, -0.75]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_softsign_zero() { let tensor = TestTensor::<1>::from([0.0]); let output = activation::softsign(tensor); let expected = TensorData::from([0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/tanh_activation.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{TensorData, activation}; #[test] fn test_tanh() { let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]); let output = activation::tanh(tensor); let expected = TensorData::from([[0.761594, 0.964028], [0.995055, 0.999329]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/activation/thresholded_relu.rs ================================================ use super::*; use burn_tensor::{TensorData, activation}; #[test] fn test_thresholded_relu_d2() { // alpha = 1.0 (ONNX default): x if x > 1.0, else 0 let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 1.0, 0.5]]); let output = activation::thresholded_relu(tensor, 1.0); output .into_data() .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.0]]), false); } #[test] fn test_thresholded_relu_d2_alpha() { // alpha = 0.5: x if x > 0.5, else 0 let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 0.5, 0.6]]); let output = activation::thresholded_relu(tensor, 0.5); output .into_data() .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.6]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/grid/affine_grid.rs ================================================ use super::*; use burn_tensor::grid::affine_grid_2d; fn create_identity_transform(batch_size: usize) -> TestTensor<3> { // Identity affine transform (batch_size, 2, 3) TestTensor::<3>::from([[[1f32, 0., 0.], [0., 1., 0.]]]).expand([batch_size, 2, 3]) } #[test] fn test_affine_grid_identity() { let batch_size = 1; let channels = 1; let height = 2; let width = 2; let transform = create_identity_transform(batch_size); let output = affine_grid_2d(transform, [batch_size, channels, height, width]); // Expected normalized coords: // [-1, -1], [ 1,-1] // [-1, 1], [ 1, 1] let expected = TestTensor::<4>::from([[ [[-1f32, -1f32], [1f32, -1f32]], [[-1f32, 1f32], [1f32, 1f32]], ]]); output.into_data().assert_eq(&expected.into_data(), false); } #[test] fn test_affine_grid_scaling() { let batch_size = 1; let channels = 1; let height = 2; let width = 2; let scale = 2.0f32; let transform = TestTensor::<3>::from([[[scale, 0., 0.], [0., scale, 0.]]]); let output = affine_grid_2d(transform, [batch_size, channels, height, width]); // Expect scaled coordinates from normalized grid, so coords * 2 let expected = TestTensor::<4>::from([[ [[-2f32, -2f32], [2f32, -2f32]], [[-2f32, 2f32], [2f32, 2f32]], ]]); output.into_data().assert_eq(&expected.into_data(), false); } #[test] fn test_affine_grid_translation() { let batch_size = 1; let channels = 1; let height = 2; let width = 2; // Translate by 0.5 in x and -0.5 in y (normalized coords) let tx = 0.5f32; let ty = -0.5f32; let transform = TestTensor::<3>::from([[[1.0, 0.0, tx], [0.0, 1.0, ty]]]); let output = affine_grid_2d(transform, [batch_size, channels, height, width]); // Expected coordinates: // Original normalized coords are [-1,1] in x and y // After translation, each coordinate shifts by tx and ty // So points become: // [-1 + 0.5, -1 - 0.5] = [-0.5, -1.5] // [ 1 + 0.5, -1 - 0.5] = [1.5, -1.5] // [-1 + 0.5, 1 - 0.5] = [-0.5, 0.5] // [ 1 + 0.5, 1 - 0.5] = [1.5, 0.5] let expected = TestTensor::<4>::from([[ [[-0.5f32, -1.5f32], [1.5f32, -1.5f32]], [[-0.5f32, 0.5f32], [1.5f32, 0.5f32]], ]]); output.into_data().assert_eq(&expected.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/grid/meshgrid.rs ================================================ use super::*; use burn_tensor::BasicOps; use burn_tensor::Tensor; use burn_tensor::TensorData; use burn_tensor::backend::Backend; use burn_tensor::grid::{ GridIndexing, GridOptions, GridSparsity, IndexPos, meshgrid, meshgrid_stack, }; fn assert_tensors_equal( actual: &[Tensor; N], expected: &[Tensor; N], ) where K: BasicOps, { for (a, e) in actual.iter().zip(expected.iter()) { a.clone() .into_data() .assert_eq(&e.clone().into_data(), true); } } #[test] fn test_meshgrid() { let x = TestTensor::<1>::from([1, 2, 3, 4]); let y = TestTensor::<1>::from([5, 6]); let z = TestTensor::<1>::from([7, 8]); let grid_shape = [x.dims()[0], y.dims()[0], z.dims()[0]]; // 3D, Dense, Matrix assert_tensors_equal( &meshgrid(&[x.clone(), y.clone(), z.clone()], GridOptions::default()), &[ x.clone().reshape([4, 1, 1]).expand(grid_shape), y.clone().reshape([1, 2, 1]).expand(grid_shape), z.clone().reshape([1, 1, 2]).expand(grid_shape), ], ); assert_tensors_equal( &meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Dense), &[ x.clone().reshape([4, 1, 1]).expand(grid_shape), y.clone().reshape([1, 2, 1]).expand(grid_shape), z.clone().reshape([1, 1, 2]).expand(grid_shape), ], ); assert_tensors_equal( &meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Matrix), &[ x.clone().reshape([4, 1, 1]).expand(grid_shape), y.clone().reshape([1, 2, 1]).expand(grid_shape), z.clone().reshape([1, 1, 2]).expand(grid_shape), ], ); // 3D, Sparse, Matrix assert_tensors_equal( &meshgrid( &[x.clone(), y.clone(), z.clone()], GridOptions { indexing: GridIndexing::Matrix, sparsity: GridSparsity::Sparse, }, ), &[ x.clone().reshape([4, 1, 1]), y.clone().reshape([1, 2, 1]), z.clone().reshape([1, 1, 2]), ], ); assert_tensors_equal( &meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Sparse), &[ x.clone().reshape([4, 1, 1]), y.clone().reshape([1, 2, 1]), z.clone().reshape([1, 1, 2]), ], ); // 3D, Dense, Cartesian assert_tensors_equal( &meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Cartesian), &[ x.clone() .reshape([4, 1, 1]) .expand(grid_shape) .swap_dims(0, 1), y.clone() .reshape([1, 2, 1]) .expand(grid_shape) .swap_dims(0, 1), z.clone() .reshape([1, 1, 2]) .expand(grid_shape) .swap_dims(0, 1), ], ); // 3D, Sparse, Cartesian assert_tensors_equal( &meshgrid( &[x.clone(), y.clone(), z.clone()], GridOptions::new(GridIndexing::Cartesian, GridSparsity::Sparse), ), &[ x.clone().reshape([4, 1, 1]).swap_dims(0, 1), y.clone().reshape([1, 2, 1]).swap_dims(0, 1), z.clone().reshape([1, 1, 2]).swap_dims(0, 1), ], ); assert_tensors_equal( &meshgrid( &[x.clone(), y.clone(), z.clone()], GridOptions { indexing: GridIndexing::Cartesian, sparsity: GridSparsity::Sparse, }, ), &[ x.clone().reshape([4, 1, 1]).swap_dims(0, 1), y.clone().reshape([1, 2, 1]).swap_dims(0, 1), z.clone().reshape([1, 1, 2]).swap_dims(0, 1), ], ); } #[test] fn test_meshgrid_stack() { let tensors = [ TestTensor::from([0.5, 1.0, 2.5]), TestTensor::from([0.5, 1.0]), ]; let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::First); result.to_data().assert_eq( &TensorData::from([ [[0.5, 0.5], [1.0, 1.0], [2.5, 2.5]], [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]], ]), false, ); let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::Last); result.to_data().assert_eq( &TensorData::from([ [[0.5, 0.5], [0.5, 1.0]], [[1.0, 0.5], [1.0, 1.0]], [[2.5, 0.5], [2.5, 1.0]], ]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/grid/mod.rs ================================================ use super::*; pub(crate) mod affine_grid; pub(crate) mod meshgrid; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/cosine_similarity.rs ================================================ use super::*; use burn_tensor::{ElementConversion, Tolerance}; use burn_tensor::{TensorData, linalg}; #[test] fn test_cosine_similarity_basic() { // Create test tensors let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]]); let x2 = TestTensor::<2>::from([[1.5, 2.5, 3.5], [0.7, 1.7, 2.7]]); // Test cosine similarity along dimension 1 let expected = TensorData::from([[0.99983203], [0.99987257]]); linalg::cosine_similarity(x1.clone(), x2.clone(), 1, None) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); // Test with explicit epsilon linalg::cosine_similarity(x1.clone(), x2.clone(), 1, Some(1e-8.elem::())) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_cosine_similarity_orthogonal() { // Create orthogonal vectors let x1 = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]); let x2 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); // Orthogonal vectors should have cosine similarity of 0 let expected = TensorData::from([[0.0], [0.0]]); linalg::cosine_similarity(x1, x2, 1, None) .into_data() .assert_eq(&expected, false); } #[test] fn test_cosine_similarity_parallel() { // Create parallel vectors let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let x2 = TestTensor::<2>::from([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]]); // Parallel vectors should have cosine similarity of 1 let expected = TensorData::from([[1.0], [1.0]]); linalg::cosine_similarity(x1, x2, 1, None) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_cosine_similarity_opposite() { // Create opposite direction vectors let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let x2 = TestTensor::<2>::from([[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]]); // Opposite vectors should have cosine similarity of -1 let expected = TensorData::from([[-1.0], [-1.0]]); linalg::cosine_similarity(x1, x2, 1, None) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_cosine_similarity_different_dimension() { // Test with a 3D tensor let x1 = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]); let x2 = TestTensor::<3>::from([[[2.0, 3.0], [4.0, 5.0]], [[6.0, 7.0], [8.0, 9.0]]]); // Test along dimension 2 let expected = TensorData::from([[[0.9959688], [0.9958376]], [[0.9955946], [0.9955169]]]); // sensitive to rounding in dot/norm; loosen f16 tolerance let tolerance = Tolerance::default().set_half_precision_relative(7e-3); linalg::cosine_similarity(x1.clone(), x2.clone(), 2, None) .into_data() .assert_approx_eq::(&expected, tolerance); // Test with negative dimension (-1 is the last dimension, which is 2 in this case) linalg::cosine_similarity(x1.clone(), x2.clone(), -1, None) .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_cosine_similarity_near_zero() { // Test with near-zero vectors let x1 = TestTensor::<2>::from([[1e-10, 2e-10, 3e-10], [4e-10, 5e-10, 6e-10]]); let x2 = TestTensor::<2>::from([[2e-10, 4e-10, 6e-10], [8e-10, 10e-10, 12e-10]]); // Update the expected values based on the actual implementation behavior let expected = TensorData::from([[0.0028], [0.0154]]); // Smaller values result in NaN on metal f16 let epsilon = Some(FloatElem::from_elem(1e-2)); let tolerance = Tolerance::absolute(0.2); linalg::cosine_similarity(x1, x2, 1, epsilon) .into_data() .assert_approx_eq::(&expected, tolerance); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/diag.rs ================================================ use super::*; use burn_tensor::{TensorData, linalg::diag}; #[test] fn test_diag_2d_square() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let result = diag::<_, 2, 1, _>(tensor); let expected = TensorData::from([1.0, 4.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_2d_tall() { let device = Default::default(); // 4x2 matrix (tall) - min(4,2) = 2 diagonal elements let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // Result should have shape [2] with values [1.0, 4.0] let expected = TensorData::from([1.0, 4.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_2d_wide() { let device = Default::default(); // 2x4 matrix (wide) - min(2,4) = 2 diagonal elements let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // Result should have shape [2] with values [1.0, 6.0] let expected = TensorData::from([1.0, 6.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_3d_batch_square() { let device = Default::default(); // Batch of 2 matrices, each 2x2 let tensor = TestTensor::<3>::from_data( [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], &device, ); let result = diag::<_, 3, 2, _>(tensor); // Result should have shape [2, 2] let expected = TensorData::from([[1.0, 4.0], [5.0, 8.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_3d_batch_tall() { let device = Default::default(); // Batch of 2 matrices, each 3x2 (tall) let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], ], &device, ); let result = diag::<_, 3, 2, _>(tensor); // Result should have shape [2, 2] - min(3,2) = 2 diagonal elements each let expected = TensorData::from([[1.0, 4.0], [7.0, 10.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_3d_batch_wide() { let device = Default::default(); // Batch of 2 matrices, each 2x3 (wide) let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], ], &device, ); let result = diag::<_, 3, 2, _>(tensor); // Result should have shape [2, 2] - min(2,3) = 2 diagonal elements each let expected = TensorData::from([[1.0, 5.0], [7.0, 11.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_4d_batch_channel_square() { let device = Default::default(); // [batch=2, channel=2, rows=2, cols=2] let tensor = TestTensor::<4>::from_data( [ [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], [[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]], ], &device, ); let result = diag::<_, 4, 3, _>(tensor); // Result should have shape [2, 2, 2] let expected = TensorData::from([[[1.0, 4.0], [5.0, 8.0]], [[9.0, 12.0], [13.0, 16.0]]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_4d_batch_channel_tall() { let device = Default::default(); // [batch=2, channel=1, rows=3, cols=2] let tensor = TestTensor::<4>::from_data( [ [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]], [[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]], ], &device, ); let result = diag::<_, 4, 3, _>(tensor); // Result should have shape [2, 1, 2] - min(3,2) = 2 diagonal elements each let expected = TensorData::from([[[1.0, 4.0]], [[7.0, 10.0]]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_4d_batch_channel_wide() { let device = Default::default(); // [batch=1, channel=2, rows=2, cols=4] let tensor = TestTensor::<4>::from_data( [[ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], ]], &device, ); let result = diag::<_, 4, 3, _>(tensor); // Result should have shape [1, 2, 2] - min(2,4) = 2 diagonal elements each let expected = TensorData::from([[[1.0, 6.0], [9.0, 14.0]]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_1x1() { let device = Default::default(); // Single element matrix let tensor = TestTensor::<2>::from_data([[5.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // Should return [5.0] with shape [1] let expected = TensorData::from([5.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_single_row() { let device = Default::default(); // Single row matrix let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // min(1,3) = 1, should return [1.0] with shape [1] let expected = TensorData::from([1.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_single_column() { let device = Default::default(); // Single column matrix let tensor = TestTensor::<2>::from_data([[1.0], [2.0], [3.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // min(3,1) = 1, should return [1.0] with shape [1] let expected = TensorData::from([1.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_zeros() { let device = Default::default(); // Matrix with zeros on diagonal let tensor = TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 0.0]], &device); let result = diag::<_, 2, 1, _>(tensor); // Should extract diagonal: [0.0, 0.0] let expected = TensorData::from([0.0, 0.0]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_batch_single_element() { let device = Default::default(); // Batch with single element matrices let tensor = TestTensor::<3>::from_data([[[5.0]], [[7.0]]], &device); let result = diag::<_, 3, 2, _>(tensor); // Should return [[5.0], [7.0]] with shape [2, 1] let expected = TensorData::from([[5.0], [7.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_batch_mixed_zeros() { let device = Default::default(); // Batch with mixed zero and non-zero diagonal elements let tensor = TestTensor::<3>::from_data( [[[1.0, 2.0], [3.0, 0.0]], [[0.0, 5.0], [6.0, 7.0]]], &device, ); let result = diag::<_, 3, 2, _>(tensor); // Should return [[1.0, 0.0], [0.0, 7.0]] with shape [2, 2] let expected = TensorData::from([[1.0, 0.0], [0.0, 7.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_int_tensor() { let device = Default::default(); // Test with integer tensor let tensor = TestTensorInt::<2>::from_data([[1, 2], [3, 4]], &device); let result = diag::<_, 2, 1, _>(tensor); // Result should have shape [2] with values [1, 4] let expected = TensorData::from([1, 4]); result.into_data().assert_eq(&expected, false); } #[test] fn test_diag_int_3x3() { let device = Default::default(); // Test with 3x3 integer matrix let tensor = TestTensorInt::<2>::from_data([[1, 2, 3], [4, 5, 6], [7, 8, 9]], &device); let result = diag::<_, 2, 1, _>(tensor); // Result should have shape [3] with values [1, 5, 9] let expected = TensorData::from([1, 5, 9]); result.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn test_diag_1d_should_panic() { let device = Default::default(); // 1D tensor should panic - diagonal requires at least 2 dimensions let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device); let _result = diag::<_, 1, 0, _>(tensor); } #[test] #[should_panic] fn test_diag_wrong_output_rank_should_panic() { let device = Default::default(); // Providing wrong output rank should panic let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let _result = diag::<_, 2, 2, _>(tensor); // Should be 2,1 not 2,2 } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/lu_decomposition.rs ================================================ use super::*; use burn_tensor::{ Distribution, Shape, TensorData, Tolerance, cast::ToElement, linalg::lu_decomposition, s, }; #[test] fn test_lu_2x2_decomposition() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[4.0, 3.0], [6.0, 3.0]], &device); let (result, _permutations) = lu_decomposition(tensor); let expected = TensorData::from([[6.0, 3.0], [2.0 / 3.0, 1.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_lu_3x3_decomposition() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [[0.0, 5.0, 22.0 / 3.0], [4.0, 2.0, 1.0], [2.0, 7.0, 9.0]], &device, ); let (result, permutations) = lu_decomposition(tensor); let expected = TestTensor::<2>::from_data( [ [4.0, 2.0, 1.0], [0.5, 6.0, 8.5], [0.0, 0.8333333, 0.25000048], ], &device, ); let expected_permutations = TensorData::from([1, 2, 0]); permutations .into_data() .assert_eq(&expected_permutations, false); let tolerance = Tolerance::default().set_half_precision_absolute(5e-3); result .into_data() .assert_approx_eq::(&expected.into_data(), tolerance); } #[test] #[should_panic] fn test_lu_singular_matrix() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [2.0, 4.0]], &device); let _result = lu_decomposition(tensor); } #[test] #[should_panic] fn test_lu_non_square_matrix() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let _result = lu_decomposition(tensor); } #[test] fn test_lu_1x1_element_matrix() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[5.0]], &device); let (result, _permutations) = lu_decomposition(tensor); let expected = TensorData::from([[5.0]]); result.into_data().assert_eq(&expected, false); } #[test] fn test_lu_identity_matrix() { let device = Default::default(); let tensor = TestTensor::<2>::eye(4, &device); let (result, _permutations) = lu_decomposition(tensor); let expected = TestTensor::<2>::eye(4, &device); result.into_data().assert_eq(&expected.into_data(), true); } #[test] fn test_lu_50x50_random_matrix() { let device = Default::default(); let size = 50; let distribution = Distribution::Uniform(0.0, 1.0); let tensor = TestTensor::<2>::random(Shape::new([size, size]), distribution, &device); let (result, permutations) = lu_decomposition(tensor.clone()); // Reconstruct the original matrix from L and U let mut l = TestTensor::<2>::eye(size, &device); let mut u = TestTensor::<2>::zeros(Shape::new([size, size]), &device); for i in 0..size { for j in 0..size { if i > j { l = l.slice_assign(s![i, j], result.clone().slice(s![i, j])); } else { u = u.slice_assign(s![i, j], result.clone().slice(s![i, j])); } } } // Construct the permutation matrix P from the permutation vector let mut p = TestTensor::<2>::zeros(Shape::new([size, size]), &device); for i in 0..size { let perm_index = permutations.clone().slice(s![i]).into_scalar().to_usize(); p = p.slice_assign( s![perm_index, i], TestTensor::<2>::from_data([[1.0]], &device), ); } // Verify that P * L * U reconstructs the original matrix let reconstructed = p.matmul(l).matmul(u); reconstructed .into_data() .assert_approx_eq::(&tensor.into_data(), Tolerance::permissive()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/matvec.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance, linalg}; #[test] fn test_matvec_basic_float() { let device = Default::default(); let matrix = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device); let vector = TestTensor::<1>::from_floats([5.0, 6.0], &device); let result = linalg::matvec::(matrix, vector); let expected = TensorData::from([17.0, 39.0]); result .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_matvec_basic_int() { let device = Default::default(); let matrix = TestTensorInt::<2>::from_ints([[2, 0, -1], [1, 3, 2]], &device); let vector = TestTensorInt::<1>::from_ints([3, -2, 4], &device); let result = linalg::matvec::(matrix, vector); let expected = TensorData::from([2, 5]); result.into_data().assert_eq(&expected, false); } #[test] fn test_matvec_batched() { let device = Default::default(); let matrix = TestTensor::<3>::from_floats( [ [[1.0, 0.0, 2.0], [3.0, 1.0, -1.0]], [[-2.0, 1.0, 0.0], [0.5, -1.5, 2.0]], ], &device, ); let vector = TestTensor::<2>::from_floats([[1.0, -1.0, 0.5], [2.0, 0.0, -1.0]], &device); let result = linalg::matvec::(matrix, vector); let expected = TensorData::from([[2.0, 1.5], [-4.0, -1.0]]); result .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_matvec_vector_broadcasts_over_batches() { let device = Default::default(); let matrix = TestTensor::<3>::from_floats( [ [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[-1.0, 0.0, 2.0], [3.0, 1.0, -2.0]], ], &device, ); let vector = TestTensor::<2>::from_floats([[1.0, 0.0, -1.0]], &device); let result = linalg::matvec::(matrix, vector); let expected = TensorData::from([[-2.0, -2.0], [-3.0, 5.0]]); result .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_matvec_matrix_broadcasts_over_vector_batches() { let device = Default::default(); let matrix = TestTensor::<3>::from_floats([[[1.0, 0.0, 2.0], [3.0, -1.0, 1.0]]], &device); let vector = TestTensor::<2>::from_floats([[2.0, 1.0, 0.0], [1.0, -1.0, 3.0]], &device); let result = linalg::matvec::(matrix, vector); let expected = TensorData::from([[2.0, 5.0], [7.0, 7.0]]); result .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[should_panic] fn test_matvec_invalid_inner_dim_panics() { let device = Default::default(); let matrix = TestTensor::<2>::zeros([2, 3], &device); let vector = TestTensor::<1>::zeros([4], &device); let _ = linalg::matvec::(matrix, vector); } #[test] #[should_panic] fn test_matvec_mismatched_batches_panics() { let device = Default::default(); let matrix = TestTensor::<3>::zeros([2, 3, 4], &device); let vector = TestTensor::<2>::zeros([3, 4], &device); let _ = linalg::matvec::(matrix, vector); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs ================================================ use super::*; pub(crate) mod cosine_similarity; pub(crate) mod diag; pub(crate) mod lu_decomposition; pub(crate) mod matvec; pub(crate) mod outer; pub(crate) mod trace; pub(crate) mod vector_norm; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/outer.rs ================================================ use super::*; use burn_tensor::{ElementConversion, Tolerance}; use burn_tensor::{TensorData, linalg}; // ---------- Vector (D=1, R=2) tests ---------- #[test] fn test_outer_basic() { let u = TestTensor::<1>::from([1.0, 2.0, 3.0]); let v = TestTensor::<1>::from([4.0, 5.0]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::from([[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_shapes_only() { let device = Default::default(); let u = TestTensor::<1>::zeros([3], &device); let v = TestTensor::<1>::zeros([5], &device); let out = linalg::outer::(u, v); assert_eq!(out.shape().dims(), [3, 5]); } #[test] fn test_outer_asymmetry_and_shapes() { let u = TestTensor::<1>::from([1.0, 2.0]); let v = TestTensor::<1>::from([3.0, 4.0, 5.0]); let uv = linalg::outer::(u.clone(), v.clone()); let vu = linalg::outer::(v, u); assert_eq!(uv.shape().dims(), [2, 3]); assert_eq!(vu.shape().dims(), [3, 2]); } #[test] fn test_outer_zero_left() { let device = Default::default(); let u = TestTensor::<1>::zeros([3], &device); let v = TestTensor::<1>::from([7.0, 8.0]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::zeros::([3, 2]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_zero_right() { let device = Default::default(); let u = TestTensor::<1>::from([1.0, -2.0, 3.0]); let v = TestTensor::<1>::zeros([4], &device); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::zeros::([3, 4]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_signs() { let u = TestTensor::<1>::from([-1.0, 2.0]); let v = TestTensor::<1>::from([3.0, -4.0]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::from([[-3.0, 4.0], [6.0, -8.0]]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_integer_inputs() { let u = TestTensorInt::<1>::from([1, 2, 3]); let v = TestTensorInt::<1>::from([4, 5]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::from([[4, 5], [8, 10], [12, 15]]); out.assert_eq(&expected, false); } #[test] fn test_outer_equivalence_to_matmul() { let u = TestTensor::<1>::from([1.0, 2.0, 3.0]); let v = TestTensor::<1>::from([4.0, 5.0]); let out = linalg::outer::(u.clone(), v.clone()).into_data(); let u2 = u.reshape([3, 1]); let v2 = v.reshape([1, 2]); let out_matmul = u2.matmul(v2).into_data(); out.assert_approx_eq::(&out_matmul, Tolerance::default()); } #[test] fn test_outer_vector_identity_right_mult() { let u = TestTensor::<1>::from([2.0, -1.0]); let v = TestTensor::<1>::from([3.0, 4.0]); let w = TestTensor::<1>::from([5.0, 6.0]); let uv = linalg::outer::(u.clone(), v.clone()); let left = uv.matmul(w.clone().reshape([2, 1])).reshape([2]); let v_dot_w = v.dot(w); let right = u * v_dot_w; left.into_data() .assert_approx_eq::(&right.into_data(), Tolerance::default()); } #[test] fn test_outer_length_one_vectors() { let u = TestTensor::<1>::from([3.0]); let v = TestTensor::<1>::from([4.0, 5.0, 6.0]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::from([[12.0, 15.0, 18.0]]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_large_values() { let big = 1.0e10; let u = TestTensor::<1>::from([big, -big]); let v = TestTensor::<1>::from([big, big]); let out = linalg::outer::(u, v).into_data(); let expected = TensorData::from([[big * big, big * big], [-big * big, -big * big]]); let tol = Tolerance::relative(1e-6).set_half_precision_relative(1e-3); out.assert_approx_eq::(&expected, tol); } #[test] fn test_outer_nan_propagation() { let u = TestTensor::<1>::from([f32::NAN, 2.0]); let v = TestTensor::<1>::from([3.0, 4.0]); let out = linalg::outer::(u, v).into_data(); let s: &[FloatElem] = out .as_slice::() .expect("outer nan_propagation: as_slice failed"); assert!(s[0].is_nan()); assert!(s[1].is_nan()); assert_eq!(s[2], 6.0f32.elem::()); assert_eq!(s[3], 8.0f32.elem::()); } // ---------- Batched (D=2, R=3) tests ---------- #[test] fn test_outer_batched_basic() { let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]); let out = linalg::outer::(x, y).into_data(); let expected = TensorData::from([ [[5.0, 6.0, 7.0], [10.0, 12.0, 14.0]], [[24.0, 27.0, 30.0], [32.0, 36.0, 40.0]], ]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_batched_shapes() { let device = Default::default(); let x = TestTensor::<2>::zeros([3, 4], &device); let y = TestTensor::<2>::zeros([3, 5], &device); let out = linalg::outer::(x, y); assert_eq!(out.shape().dims(), [3, 4, 5]); } #[test] fn test_outer_batched_zero_left() { let device = Default::default(); let x = TestTensor::<2>::zeros([2, 3], &device); let y = TestTensor::<2>::from([[7.0, 8.0], [9.0, 10.0]]); let out = linalg::outer::(x, y).into_data(); let expected = TensorData::zeros::([2, 3, 2]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_batched_zero_right() { let device = Default::default(); let x = TestTensor::<2>::from([[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]]); let y = TestTensor::<2>::zeros([2, 4], &device); let out = linalg::outer::(x, y).into_data(); let expected = TensorData::zeros::([2, 3, 4]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_batched_signs() { let x = TestTensor::<2>::from([[-1.0, 2.0], [3.0, -4.0]]); let y = TestTensor::<2>::from([[3.0, -4.0], [-5.0, 6.0]]); let out = linalg::outer::(x, y).into_data(); let expected = TensorData::from([[[-3.0, 4.0], [6.0, -8.0]], [[-15.0, 18.0], [20.0, -24.0]]]); out.assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_outer_batched_equivalence_to_per_sample_outer() { let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]); let batched = linalg::outer::(x.clone(), y.clone()); for b in 0..2 { let idx = TestTensorInt::<1>::from([b]); let xb2d = x.clone().select(0, idx.clone()); // (1, m) let yb2d = y.clone().select(0, idx); // (1, n) let dims_x: [usize; 2] = xb2d.shape().dims(); let dims_y: [usize; 2] = yb2d.shape().dims(); let (m, n) = (dims_x[1], dims_y[1]); let per = linalg::outer::(xb2d.reshape([m]), yb2d.reshape([n])); let bat3d = batched.clone().select(0, TestTensorInt::<1>::from([b])); // (m, n) let per_len = per.shape().num_elements(); let per_flat = per.reshape([per_len]).into_data(); let bat_len = bat3d.shape().num_elements(); let bat_flat = bat3d.reshape([bat_len]).into_data(); bat_flat.assert_approx_eq::(&per_flat, Tolerance::default()); } } #[test] #[should_panic] fn test_outer_batched_mismatched_batches_panics() { let device = Default::default(); let x = TestTensor::<2>::zeros([2, 3], &device); let y = TestTensor::<2>::zeros([3, 4], &device); let _ = linalg::outer::(x, y); } #[test] fn test_outer_dim() { let u = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let v = TestTensor::<2>::from([[4.0, 5.0], [5.0, 6.0]]); let out = linalg::outer_dim::(u, v, 0).into_data(); let expected = TensorData::from([[[4.0, 10.0], [5.0, 12.0]], [[12.0, 20.0], [15.0, 24.0]]]); out.assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/trace.rs ================================================ use super::*; use burn_tensor::linalg::trace; #[test] fn test_trace_2d_square() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([15.0], &device); // 1 + 5 + 9 = 15 assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_2d_rectangular_wide() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([7.0], &device); // 1 + 6 = 7 assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_2d_rectangular_tall() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([5.0], &device); // 1 + 4 = 5 assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_3d_batch() { let device = Default::default(); let tensor = TestTensor::<3>::from_data( [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], &device, ); let result = trace::<_, 3, 2>(tensor); let expected = TestTensor::<2>::from_data([[5.0], [13.0]], &device); // [1+4=5, 5+8=13] assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_4d_batch() { let device = Default::default(); let tensor = TestTensor::<4>::from_data( [[ // Batch 0, Channel 0 [[1.0, 2.0], [3.0, 4.0]], // Batch 0, Channel 1 [[5.0, 6.0], [7.0, 8.0]], ]], &device, ); let result = trace::<_, 4, 3>(tensor); let expected = TestTensor::<3>::from_data([[[5.0], [13.0]]], &device); assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_single_element() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[42.0]], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([42.0], &device); assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_zeros() { let device = Default::default(); let tensor = TestTensor::<2>::zeros([3, 3], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([0.0], &device); assert_eq!(result.to_data(), expected.to_data()); } #[test] fn test_trace_negative_values() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[-1.0, 2.0], [3.0, -4.0]], &device); let result = trace::<_, 2, 1>(tensor); let expected = TestTensor::<1>::from_data([-5.0], &device); // -1 + (-4) = -5 assert_eq!(result.to_data(), expected.to_data()); } #[test] #[should_panic] fn test_trace_1d_should_panic() { let device = Default::default(); // 1D tensor should panic - trace requires at least 2 dimensions let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device); let _result = trace::<_, 1, 0>(tensor); } #[test] #[should_panic] fn test_trace_wrong_output_rank_should_panic() { let device = Default::default(); // Providing wrong output rank should panic let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let _result = trace::<_, 2, 2>(tensor); // Should be 2,1 not 2,2 } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/linalg/vector_norm.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use burn_tensor::backend::Backend; use burn_tensor::linalg; #[test] fn test_max_min_abs() { let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]); let expected = TestTensor::<2>::from([[3., 4.]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0) .into_data() .assert_eq(&expected, true); linalg::max_abs_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); let expected = TestTensor::<2>::from([[1., 2.]]).into_data(); linalg::vector_norm(x.clone(), -f64::INFINITY, 0) .into_data() .assert_eq(&expected, true); linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 0) .into_data() .assert_eq(&expected, true); linalg::min_abs_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); let expected = TestTensor::<2>::from([[2.], [4.]]).into_data(); linalg::vector_norm(x.clone(), f64::INFINITY, 1) .into_data() .assert_eq(&expected, true); linalg::max_abs_norm(x.clone(), 1) .into_data() .assert_eq(&expected, true); let expected = TestTensor::<2>::from([[1.], [3.]]).into_data(); linalg::vector_norm(x.clone(), -f64::INFINITY, 1) .into_data() .assert_eq(&expected, true); linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 1) .into_data() .assert_eq(&expected, true); linalg::min_abs_norm(x, 1) .into_data() .assert_eq(&expected, true); // Test with integer tensor let z = TestTensorInt::<2>::from([[1, 2], [3, 4]]); linalg::max_abs_norm(z.clone(), 0) .into_data() .assert_eq(&TestTensorInt::<2>::from([[3, 4]]).into_data(), true); linalg::max_abs_norm(z.clone(), 1) .into_data() .assert_eq(&TestTensorInt::<2>::from([[2], [4]]).into_data(), true); linalg::min_abs_norm(z.clone(), 0) .into_data() .assert_eq(&TestTensorInt::<2>::from([[1, 2]]).into_data(), true); linalg::min_abs_norm(z, 1) .into_data() .assert_eq(&TestTensorInt::<2>::from([[1], [3]]).into_data(), true); } #[test] fn test_l0_norm() { let x = TestTensor::<2>::from([[1.0, -2.0, 0.], [0.0, 0., 4.]]); let expected = TestTensor::<2>::from([[1., 1., 1.]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L0, 0) .into_data() .assert_eq(&expected, true); linalg::l0_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); let expected = TestTensor::<2>::from([[2.], [1.]]).into_data(); linalg::vector_norm(x.clone(), 0.0, 1) .into_data() .assert_eq(&expected, true); linalg::l0_norm(x.clone(), 1) .into_data() .assert_eq(&expected, true); // Test with integer tensor let z = TestTensorInt::<2>::from([[1, -2, 0], [0, 0, 4]]); linalg::l0_norm(z.clone(), 0) .into_data() .assert_eq(&TestTensor::<2>::from([[1, 1, 1]]).int().into_data(), true); linalg::l0_norm(z.clone(), 1) .into_data() .assert_eq(&TestTensor::<2>::from([[2], [1]]).int().into_data(), true); } #[test] fn test_l1_norm() { let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]); let expected = TestTensor::<2>::from([[4.0, 6.0]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L1, 0) .into_data() .assert_eq(&expected, true); linalg::l1_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); let expected = TestTensor::<2>::from([[3.0], [7.0]]).into_data(); linalg::vector_norm(x.clone(), 1.0, 1) .into_data() .assert_eq(&expected, true); linalg::l1_norm(x.clone(), 1) .into_data() .assert_eq(&expected, true); } #[test] fn test_lp_norm() { let x = TestTensor::<2>::from([[1., -2., 0.], [0., 3., 4.]]); let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3); fn lp_norm_naive( x: Tensor, p: f64, dim: usize, ) -> Tensor { x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p) } // Arbitrary P let expected = TestTensor::<2>::from([[1.0, 3.2710664, 4.0]]).into_data(); linalg::vector_norm(x.clone(), 3, 0) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::lp_norm(x.clone(), 3., 0) .into_data() .assert_approx_eq::(&expected, tolerance); // L0 let expected = TestTensor::<2>::from([[1., 2., 1.]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L0, 0) .into_data() .assert_eq(&expected, true); linalg::l0_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); linalg::lp_norm(x.clone(), 0.0, 0) .into_data() .assert_eq(&expected, true); // L1 let expected = TestTensor::<2>::from([[1.0, 5.0, 4.0]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L1, 0) .into_data() .assert_eq(&expected, true); linalg::l1_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); lp_norm_naive(x.clone(), 1.0, 0) .into_data() .assert_eq(&expected, true); linalg::lp_norm(x.clone(), 1.0, 0) .into_data() .assert_eq(&expected, true); // L2 let expected = TestTensor::<2>::from([[1.0, 3.6055512, 4.0]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L2, 0) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::l2_norm(x.clone(), 0) .into_data() .assert_approx_eq::(&expected, tolerance); lp_norm_naive(x.clone(), 2.0, 0) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::lp_norm(x.clone(), 2.0, 0) .into_data() .assert_approx_eq::(&expected, tolerance); // LInf let expected = TestTensor::<2>::from([[1.0, 3.0, 4.0]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0) .into_data() .assert_eq(&expected, true); linalg::max_abs_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); linalg::lp_norm(x.clone(), f64::INFINITY, 0) .into_data() .assert_approx_eq::(&expected, tolerance); // LNegInf let expected = TestTensor::<2>::from([[0.0, 2.0, 0.0]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::LNegInf, 0) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::min_abs_norm(x.clone(), 0) .into_data() .assert_eq(&expected, true); linalg::lp_norm(x.clone(), f64::NEG_INFINITY, 0) .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_l2_norm() { let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]); let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3); let expected = TestTensor::<2>::from([[3.16227766, 4.47213595]]).into_data(); linalg::vector_norm(x.clone(), linalg::Norm::L2, 0) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::l2_norm(x.clone(), 0) .into_data() .assert_approx_eq::(&expected, tolerance); let expected = TestTensor::<2>::from([[2.23606798], [5.0]]).into_data(); linalg::vector_norm(x.clone(), 2.0, 1) .into_data() .assert_approx_eq::(&expected, tolerance); linalg::l2_norm(x.clone(), 1) .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_normalize() { let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]); let expected = TensorData::from([[1. / 4., 2. / 6.], [3. / 4., 4. / 6.]]); let output = linalg::vector_normalize(x.clone(), 1.0, 0, 0.25).into_data(); output.assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([[1. / 5., 2. / 6.], [3. / 5., 4. / 6.]]); let output = linalg::vector_normalize(x.clone(), 1.0, 0, 5.0).into_data(); output.assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/mod.rs ================================================ #[allow(unused_imports)] pub use super::*; // re-export test types mod activation; mod grid; mod linalg; mod module; mod ops; mod primitive; mod stats; #[cfg(feature = "quantization")] mod quantization; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool1d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::adaptive_avg_pool1d; #[test] fn test_adaptive_avg_pool1d_simple() { let test = AdaptiveAvgPool1dTestCase { batch_size: 1, channels: 2, length: 8, length_out: 4, }; test.assert_output(TestTensor::from([[ [0.5, 2.5, 4.5, 6.5], [8.5, 10.5, 12.5, 14.5], ]])); } #[test] fn test_adaptive_avg_pool1d_dyn_filter_size() { let test = AdaptiveAvgPool1dTestCase { batch_size: 1, channels: 2, length: 7, length_out: 3, }; test.assert_output(TestTensor::from([[[1.0, 3.0, 5.0], [8.0, 10.0, 12.0]]])); } #[test] fn test_adaptive_avg_pool1d_bigger_output() { let test = AdaptiveAvgPool1dTestCase { batch_size: 1, channels: 2, length: 4, length_out: 8, }; test.assert_output(TestTensor::from([[ [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0], [4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0], ]])); } struct AdaptiveAvgPool1dTestCase { batch_size: usize, channels: usize, length: usize, length_out: usize, } impl AdaptiveAvgPool1dTestCase { fn assert_output(self, y: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels, self.length]); let device = Default::default(); let x = TestTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ); let output = adaptive_avg_pool1d(x, self.length_out); y.into_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool2d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::adaptive_avg_pool2d; #[test] fn test_adaptive_avg_pool2d_simple() { let test = AdaptiveAvgPool2dTestCase { batch_size: 1, channels: 2, height: 8, width: 6, height_out: 4, width_out: 4, }; test.assert_output(TestTensor::from([[ [ [3.5000, 4.5000, 6.5000, 7.5000], [15.5000, 16.5000, 18.5000, 19.5000], [27.5000, 28.5000, 30.5000, 31.5000], [39.5000, 40.5000, 42.5000, 43.5000], ], [ [51.5000, 52.5000, 54.5000, 55.5000], [63.5000, 64.5000, 66.5000, 67.5000], [75.5000, 76.5000, 78.5000, 79.5000], [87.5000, 88.5000, 90.5000, 91.5000], ], ]])); } #[test] fn test_adaptive_avg_pool2d_dyn_filter_size() { let test = AdaptiveAvgPool2dTestCase { batch_size: 1, channels: 2, height: 5, width: 7, height_out: 3, width_out: 2, }; test.assert_output(TestTensor::from([[ [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]], [[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]], ]])); } #[test] fn test_adaptive_avg_pool2d_bigger_output() { let test = AdaptiveAvgPool2dTestCase { batch_size: 1, channels: 2, height: 4, width: 3, height_out: 5, width_out: 4, }; test.assert_output(TestTensor::from([[ [ [0.0000, 0.5000, 1.5000, 2.0000], [1.5000, 2.0000, 3.0000, 3.5000], [4.5000, 5.0000, 6.0000, 6.5000], [7.5000, 8.0000, 9.0000, 9.5000], [9.0000, 9.5000, 10.5000, 11.0000], ], [ [12.0000, 12.5000, 13.5000, 14.0000], [13.5000, 14.0000, 15.0000, 15.5000], [16.5000, 17.0000, 18.0000, 18.5000], [19.5000, 20.0000, 21.0000, 21.5000], [21.0000, 21.5000, 22.5000, 23.0000], ], ]])); } struct AdaptiveAvgPool2dTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl AdaptiveAvgPool2dTestCase { fn assert_output(self, y: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data(), ); let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/attention.rs ================================================ use super::*; use burn_tensor::Distribution; use burn_tensor::Tolerance; use burn_tensor::module::attention; use burn_tensor::module::attention_fallback; use burn_tensor::ops::AttentionModuleOptions; #[test] fn test_attention_no_mask() { // Skip on metal with f16 - flash attention returns zeros // Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4325 #[cfg(feature = "metal")] if core::any::TypeId::of::() == core::any::TypeId::of::() { return; } let num_batches = 1; let num_heads = 1; let seq_q = 128; let seq_kv = 128; let head_dim = 64; let val_dim = 64; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_q, head_dim], Distribution::Uniform(0., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_kv, head_dim], Distribution::Uniform(0., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_kv, val_dim], Distribution::Uniform(0., 1.), &Default::default(), ); let output = attention( query.clone(), key.clone(), value.clone(), None, None, Default::default(), ); let expected = attention_fallback::(query, key, value, None, None, Default::default()); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } #[test] fn test_attention_custom_scale() { let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let options = AttentionModuleOptions { scale: Some(0.1), ..Default::default() }; let output = attention( query.clone(), key.clone(), value.clone(), None, None, options, ); let expected = attention_fallback::(query, key, value, None, None, options); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } #[test] fn test_attention_attn_bias() { let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let bias = TestTensor::<4>::random( [num_batches, num_heads, seq_len, seq_len], Distribution::Uniform(-0.5, 0.5), &Default::default(), ); let output = attention( query.clone(), key.clone(), value.clone(), None, Some(bias.clone()), Default::default(), ); let expected = attention_fallback::(query, key, value, None, Some(bias), Default::default()); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } #[test] fn test_attention_softcap() { let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let options = AttentionModuleOptions { softcap: Some(50.0), ..Default::default() }; let output = attention( query.clone(), key.clone(), value.clone(), None, None, options, ); let expected = attention_fallback::(query, key, value, None, None, options); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } #[test] fn test_attention_is_causal() { let [num_batches, num_heads, seq_len, head_dim] = [2, 4, 16, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let options = AttentionModuleOptions { is_causal: true, ..Default::default() }; let output = attention( query.clone(), key.clone(), value.clone(), None, None, options, ); let expected = attention_fallback::(query, key, value, None, None, options); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } /// Cross-attention: seq_q != seq_k, with causal masking and additive bias. #[test] fn test_attention_cross_attention_with_bias() { let [num_batches, num_heads, seq_q, seq_k, head_dim] = [2, 2, 8, 24, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_q, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_k, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_k, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let bias = TestTensor::<4>::random( [num_batches, num_heads, seq_q, seq_k], Distribution::Uniform(-0.5, 0.5), &Default::default(), ); let options = AttentionModuleOptions { is_causal: true, ..Default::default() }; let output = attention( query.clone(), key.clone(), value.clone(), None, Some(bias.clone()), options, ); let expected = attention_fallback::(query, key, value, None, Some(bias), options); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } /// Regression: softcap must be applied before -inf masking. /// With causal masking, position 0 can only attend to itself, so output[0] == value[0]. /// If softcap were applied after masking, tanh(-inf/softcap) = -softcap (finite), /// and the masked position would leak into the output. #[test] fn test_attention_softcap_preserves_causal_mask() { let [num_batches, num_heads, seq_len, head_dim] = [1, 1, 4, 8]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let options = AttentionModuleOptions { softcap: Some(20.0), is_causal: true, ..Default::default() }; let output = attention_fallback::(query, key, value.clone(), None, None, options); // With causal masking, position 0 can only attend to itself (softmax = [1, 0, 0, 0]). // So output[..., 0, :] must equal value[..., 0, :]. let output_row0 = output.slice([0..1, 0..1, 0..1, 0..head_dim]); let value_row0 = value.slice([0..1, 0..1, 0..1, 0..head_dim]); output_row0 .into_data() .assert_approx_eq::(&value_row0.into_data(), Tolerance::relative(1e-5)); } /// Combined: mask + bias + custom scale + softcap together. #[test] fn test_attention_all_options() { let [num_batches, num_heads, seq_len, head_dim] = [2, 2, 16, 32]; let query = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let key = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let value = TestTensor::<4>::random( [num_batches, num_heads, seq_len, head_dim], Distribution::Uniform(-1., 1.), &Default::default(), ); let bias = TestTensor::<4>::random( [num_batches, num_heads, seq_len, seq_len], Distribution::Uniform(-0.5, 0.5), &Default::default(), ); // Create a random bool mask by thresholding a uniform float tensor let mask = TestTensor::<4>::random( [num_batches, num_heads, seq_len, seq_len], Distribution::Uniform(0., 1.), &Default::default(), ) .greater_elem(0.7); let options = AttentionModuleOptions { scale: Some(0.05), softcap: Some(30.0), is_causal: true, }; let output = attention( query.clone(), key.clone(), value.clone(), Some(mask.clone()), Some(bias.clone()), options, ); let expected = attention_fallback::(query, key, value, Some(mask), Some(bias), options); output.into_data().assert_approx_eq::( &expected.into_data(), Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/avgpool1d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::avg_pool1d; #[test] fn test_avg_pool1d_simple() { let test = AvgPool1dTestCase { batch_size: 1, channels: 1, kernel_size: 3, padding: 0, stride: 1, length: 6, count_include_pad: true, }; test.assert_output(TestTensor::from([[[1., 2., 3., 4.]]])); } #[test] fn test_avg_pool1d_complex() { let test = AvgPool1dTestCase { batch_size: 1, channels: 2, kernel_size: 3, padding: 1, stride: 2, length: 6, count_include_pad: true, }; test.assert_output(TestTensor::from([[ [0.33333, 2.0000, 4.0000], [4.33333, 8.0000, 10.0000], ]])); } #[test] fn test_avg_pool1d_complex_dont_count_pad() { let test = AvgPool1dTestCase { batch_size: 1, channels: 2, kernel_size: 3, padding: 1, stride: 2, length: 6, count_include_pad: false, }; test.assert_output(TestTensor::from([[ [0.5000, 2.0000, 4.0000], [6.5000, 8.0000, 10.0000], ]])); } struct AvgPool1dTestCase { batch_size: usize, channels: usize, kernel_size: usize, padding: usize, stride: usize, length: usize, count_include_pad: bool, } impl AvgPool1dTestCase { fn assert_output(self, y: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels, self.length]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<3, _>(shape_x) .into_data(), ); let output = avg_pool1d( x, self.kernel_size, self.stride, self.padding, self.count_include_pad, false, ); y.to_data().assert_approx_eq::( &output.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } } #[test] fn test_avg_pool1d_ceil_mode() { // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride // Input: 1x1x6 (values 0-5), kernel: 3, stride: 2, padding: 0 // Floor mode: output = (6-3)/2+1 = 2 elements // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]); // With ceil_mode=false (floor): output is 2 elements // Window 0: avg(0,1,2) = 1 // Window 1: avg(2,3,4) = 3 let y_floor = TestTensor::<3>::from([[[1.0, 3.0]]]); let output_floor = avg_pool1d( x.clone(), 3, // kernel_size 2, // stride 0, // padding true, // count_include_pad false, ); y_floor.to_data().assert_approx_eq::( &output_floor.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); // With ceil_mode=true: output is 3 elements // Window 0: avg(0,1,2) = 1 // Window 1: avg(2,3,4) = 3 // Window 2: avg(4,5) = 4.5 (partial window, count_include_pad=false divides by 2) let y_ceil = TestTensor::<3>::from([[[1.0, 3.0, 4.5]]]); let output_ceil = avg_pool1d( x, 3, // kernel_size 2, // stride 0, // padding false, // count_include_pad=false to get correct average for partial window true, ); y_ceil.to_data().assert_approx_eq::( &output_ceil.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } #[test] fn test_avg_pool1d_ceil_mode_count_include_pad() { // Test count_include_pad=true + ceil_mode=true interaction for 1D // When ceil_mode creates windows that extend beyond the padded input: // - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions) // // Input: 1x1x6, kernel 3, stride 2, padding 1, ceil_mode=true // Output is 4 elements let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]); // Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true: // Window 0: positions -1,0,1 -> values 0,0,1 (0 is padding) / 3 = 0.333 // Window 1: positions 1,2,3 -> values 1,2,3 / 3 = 2.0 // Window 2: positions 3,4,5 -> values 3,4,5 / 3 = 4.0 // Window 3: positions 5,6,7 -> only 5 is valid, 6 is padding, 7 is ceil_mode extension // value 5 / 2 (only 2 positions within padded bounds) = 2.5 let expected = TestTensor::<3>::from([[[0.3333, 2.0, 4.0, 2.5]]]); let output = avg_pool1d( x, 3, // kernel_size 2, // stride 1, // padding true, // count_include_pad=true true, // ceil_mode=true ); expected.to_data().assert_approx_eq::( &output.into_data(), Tolerance::default().set_half_precision_relative(1e-2), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/avgpool2d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::avg_pool2d; #[test] fn test_avg_pool2d_simple() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, height: 6, width: 6, count_include_pad: true, }; test.assert_output(TestTensor::from([[[ [7., 8., 9., 10.], [13., 14., 15., 16.], [19., 20., 21., 22.], [25., 26., 27., 28.], ]]])); } #[test] fn test_avg_pool2d_complex() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 2, height: 4, width: 6, count_include_pad: true, }; test.assert_output(TestTensor::from([[[ [1.1667, 3.0000, 4.3333, 2.5000], [3.2500, 7.5000, 9.5000, 5.2500], [6.2500, 13.5000, 15.5000, 8.2500], [5.1667, 11.0000, 12.3333, 6.5000], ]]])); } #[test] fn test_avg_pool2d_complex_dont_include_pad() { let test = AvgPool2dTestCase { batch_size: 1, channels: 1, kernel_size_1: 3, kernel_size_2: 4, padding_1: 1, padding_2: 2, stride_1: 1, stride_2: 2, height: 4, width: 6, count_include_pad: false, }; test.assert_output(TestTensor::from([[[ [3.5000, 4.5000, 6.5000, 7.5000], [6.5000, 7.5000, 9.5000, 10.5000], [12.5000, 13.5000, 15.5000, 16.5000], [15.5000, 16.5000, 18.5000, 19.5000], ]]])); } struct AvgPool2dTestCase { batch_size: usize, channels: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, height: usize, width: usize, count_include_pad: bool, } impl AvgPool2dTestCase { fn assert_output(self, y: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data(), ); let output = avg_pool2d( x, [self.kernel_size_1, self.kernel_size_2], [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], self.count_include_pad, false, ); y.to_data().assert_approx_eq::( &output.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } } #[test] fn test_avg_pool2d_ceil_mode() { // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride // Input: 1x1x6x6 (values 0-35), kernel: 3x3, stride: 2x2, padding: 0x0 // Floor mode: output = (6-3)/2+1 = 2 x 2 // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 x 3 let x = TestTensor::from([[[ [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0, 16.0, 17.0], [18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [24.0, 25.0, 26.0, 27.0, 28.0, 29.0], [30.0, 31.0, 32.0, 33.0, 34.0, 35.0], ]]]); // With ceil_mode=false (floor): output is 2x2 // Window (0,0): avg(0,1,2,6,7,8,12,13,14) = avg(63) = 7 // Window (0,1): avg(2,3,4,8,9,10,14,15,16) = avg(81) = 9 // Window (1,0): avg(12,13,14,18,19,20,24,25,26) = avg(171) = 19 // Window (1,1): avg(14,15,16,20,21,22,26,27,28) = avg(189) = 21 let y_floor = TestTensor::<4>::from([[[[7.0, 9.0], [19.0, 21.0]]]]); let output_floor = avg_pool2d( x.clone(), [3, 3], [2, 2], [0, 0], true, // count_include_pad false, ); y_floor.to_data().assert_approx_eq::( &output_floor.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); // With ceil_mode=true: output is 3x3 // The extra windows at the edge include partial/padded regions // When count_include_pad=false, only actual values are averaged // Window (0,2): positions (0:3, 4:6) -> values 4,5,10,11,16,17 -> avg = 10.5 // Window (1,2): positions (2:5, 4:6) -> values 16,17,22,23,28,29 -> avg = 22.5 // Window (2,0): positions (4:6, 0:3) -> values 24,25,26,30,31,32 -> avg = 28 // Window (2,1): positions (4:6, 2:5) -> values 26,27,28,32,33,34 -> avg = 30 // Window (2,2): positions (4:6, 4:6) -> values 28,29,34,35 -> avg = 31.5 let y_ceil = TestTensor::<4>::from([[[[7.0, 9.0, 10.5], [19.0, 21.0, 22.5], [28.0, 30.0, 31.5]]]]); let output_ceil = avg_pool2d( x, [3, 3], [2, 2], [0, 0], false, // count_include_pad=false to avoid dividing by full kernel size true, ); y_ceil.to_data().assert_approx_eq::( &output_ceil.into_data(), Tolerance::default().set_half_precision_relative(1e-3), ); } #[test] fn test_avg_pool2d_ceil_mode_count_include_pad() { // Test count_include_pad=true + ceil_mode=true interaction // When ceil_mode creates windows that extend beyond the padded input: // - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions) // // For input 6x6, kernel 3, stride 2, padding 1, ceil_mode=true: // - Output is 4x4 // - Corner (3,3) window covers positions beyond even the user padding // - Expected: 35/4 = 8.75 (divides by count of positions within padded bounds) let x = TestTensor::from([[[ [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0, 16.0, 17.0], [18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [24.0, 25.0, 26.0, 27.0, 28.0, 29.0], [30.0, 31.0, 32.0, 33.0, 34.0, 35.0], ]]]); // Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true // Note: corner (3,3) = 8.75 = 35/4, not 35/9 let expected = TestTensor::<4>::from([[[ [1.5556, 3.3333, 4.6667, 2.6667], [8.3333, 14.0000, 16.0000, 8.5000], [16.3333, 26.0000, 28.0000, 14.5000], [10.1667, 16.0000, 17.0000, 8.7500], ]]]); let output = avg_pool2d( x, [3, 3], [2, 2], [1, 1], true, // count_include_pad=true true, // ceil_mode=true ); expected.to_data().assert_approx_eq::( &output.into_data(), Tolerance::default().set_half_precision_relative(1e-2), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/bicubic_interpolate.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::interpolate; use burn_tensor::ops::{InterpolateMode, InterpolateOptions}; #[test] fn test_upsample_interpolation() { let test = InterpolateTestCase { batch_size: 2, channels: 1, height: 7, width: 5, height_out: 8, width_out: 7, }; test.assert_output(TestTensor::from([ [[ [0.0000, 0.5741, 1.3704, 2.0000, 2.6296, 3.4259, 4.0000], [4.0015, 4.5755, 5.3718, 6.0015, 6.6311, 7.4274, 8.0015], [8.3528, 8.9268, 9.7231, 10.3528, 10.9824, 11.7787, 12.3528], [ 12.7697, 13.3438, 14.1400, 14.7697, 15.3993, 16.1956, 16.7697, ], [ 17.2303, 17.8044, 18.6007, 19.2303, 19.8600, 20.6562, 21.2303, ], [ 21.6472, 22.2213, 23.0176, 23.6472, 24.2769, 25.0731, 25.6472, ], [ 25.9986, 26.5726, 27.3689, 27.9986, 28.6282, 29.4245, 29.9986, ], [ 30.0000, 30.5741, 31.3704, 32.0000, 32.6296, 33.4259, 34.0000, ], ]], [[ [ 35.0000, 35.5741, 36.3704, 37.0000, 37.6296, 38.4259, 39.0000, ], [ 39.0015, 39.5755, 40.3718, 41.0015, 41.6311, 42.4274, 43.0015, ], [ 43.3528, 43.9269, 44.7231, 45.3528, 45.9824, 46.7787, 47.3528, ], [ 47.7697, 48.3438, 49.1400, 49.7697, 50.3993, 51.1956, 51.7697, ], [ 52.2303, 52.8044, 53.6007, 54.2303, 54.8600, 55.6562, 56.2303, ], [ 56.6472, 57.2213, 58.0176, 58.6472, 59.2769, 60.0731, 60.6472, ], [ 60.9986, 61.5726, 62.3689, 62.9986, 63.6282, 64.4245, 64.9986, ], [ 65.0000, 65.5741, 66.3704, 67.0000, 67.6296, 68.4259, 69.0000, ], ]], ])); } #[test] fn test_downsample_interpolation() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 45, width: 14, height_out: 4, width_out: 6, }; test.assert_output(TestTensor::from([[[ [0.0000, 2.5760, 5.2480, 7.7520, 10.4240, 13.0000], [204.8148, 207.3908, 210.0628, 212.5668, 215.2388, 217.8148], [411.1852, 413.7612, 416.4331, 418.9371, 421.6091, 424.1852], [616.0000, 618.576, 621.2479, 623.7519, 626.4239, 629.0000], ]]])); } #[test] fn test_1d_bicubic() { // Initialize the model without weights (because the exported file does not contain them) let device = Default::default(); // Run the model let input = TestTensor::<3>::from_floats( [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], &device, ); let input = input.unsqueeze_dim(2); let output = interpolate( input, [1, 9], InterpolateOptions::new(InterpolateMode::Bicubic), ); assert_eq!(output.dims(), [1, 1, 1, 9]); // assert output data does not contain NaN assert!( !output .clone() .to_data() .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), "interpolate output contains NaN" ); TestTensor::<4>::from([[[[ 1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794, -1.3986, ]]]]) .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } struct InterpolateTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl InterpolateTestCase { fn assert_output(self, y: TestTensor<4>) { self.assert_output_with_align_corners(y, true); } fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data(), ); let output = interpolate( x, [self.height_out, self.width_out], InterpolateOptions::new(InterpolateMode::Bicubic).with_align_corners(align_corners), ); let tolerance = Tolerance::permissive(); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } #[test] fn test_upsample_half_pixel() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 4, width: 4, height_out: 8, width_out: 8, }; test.assert_output_with_align_corners( TestTensor::from([[[ [ -0.5273, -0.2305, 0.2461, 0.875, 1.2812, 1.9102, 2.3867, 2.6836, ], [ 0.6602, 0.957, 1.4336, 2.0625, 2.4688, 3.0977, 3.5742, 3.8711, ], [ 2.5664, 2.8633, 3.3398, 3.9688, 4.375, 5.0039, 5.4805, 5.7773, ], [5.082, 5.3789, 5.8555, 6.4844, 6.8906, 7.5195, 7.9961, 8.293], [6.707, 7.0039, 7.4805, 8.1094, 8.5156, 9.1445, 9.6211, 9.918], [ 9.2227, 9.5195, 9.9961, 10.625, 11.0312, 11.6602, 12.1367, 12.4336, ], [ 11.1289, 11.4258, 11.9023, 12.5312, 12.9375, 13.5664, 14.043, 14.3398, ], [ 12.3164, 12.6133, 13.0898, 13.7188, 14.125, 14.7539, 15.2305, 15.5273, ], ]]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/bilinear_interpolate.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::module::interpolate; use burn_tensor::ops::{InterpolateMode, InterpolateOptions}; use burn_tensor::{DType, Shape}; #[test] fn test_upsample_interpolation() { let test = InterpolateTestCase { batch_size: 2, channels: 1, height: 7, width: 5, height_out: 8, width_out: 7, }; test.assert_output(TestTensor::from([ [[ [0.0000, 0.6667, 1.3333, 2.0000, 2.6667, 3.3333, 4.0000], [4.2857, 4.9524, 5.6190, 6.2857, 6.9524, 7.6190, 8.2857], [8.5714, 9.2381, 9.9048, 10.5714, 11.2381, 11.9048, 12.5714], [ 12.8571, 13.5238, 14.1905, 14.8571, 15.5238, 16.1905, 16.8571, ], [ 17.1429, 17.8095, 18.4762, 19.1429, 19.8095, 20.4762, 21.1429, ], [ 21.4286, 22.0952, 22.7619, 23.4286, 24.0952, 24.7619, 25.4286, ], [ 25.7143, 26.3810, 27.0476, 27.7143, 28.3810, 29.0476, 29.7143, ], [ 30.0000, 30.6667, 31.3333, 32.0000, 32.6667, 33.3333, 34.0000, ], ]], [[ [ 35.0000, 35.6667, 36.3333, 37.0000, 37.6667, 38.3333, 39.0000, ], [ 39.2857, 39.9524, 40.6190, 41.2857, 41.9524, 42.6190, 43.2857, ], [ 43.5714, 44.2381, 44.9048, 45.5714, 46.2381, 46.9048, 47.5714, ], [ 47.8571, 48.5238, 49.1905, 49.8571, 50.5238, 51.1905, 51.8571, ], [ 52.1429, 52.8095, 53.4762, 54.1429, 54.8095, 55.4762, 56.1429, ], [ 56.4286, 57.0952, 57.7619, 58.4286, 59.0952, 59.7619, 60.4286, ], [ 60.7143, 61.3810, 62.0476, 62.7143, 63.3810, 64.0476, 64.7143, ], [ 65.0000, 65.6667, 66.3333, 67.0000, 67.6667, 68.3333, 69.0000, ], ]], ])); } #[test] fn test_downsample_interpolation() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 45, width: 14, height_out: 4, width_out: 6, }; test.assert_output(TestTensor::from([[[ [0.0, 2.6, 5.2, 7.8, 10.4, 13.], [205.3333, 207.9333, 210.5333, 213.1333, 215.7333, 218.3333], [410.6667, 413.2667, 415.8667, 418.4667, 421.0667, 423.6667], [616., 618.6, 621.2, 623.8, 626.4, 629.], ]]])); } #[test] fn test_1d_bilinear() { // Initialize the model without weights (because the exported file does not contain them) let device = Default::default(); // Run the model let input = TestTensor::<3>::from_floats( [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], &device, ); let input = input.unsqueeze_dim(2); let output = interpolate( input, [1, 9], InterpolateOptions::new(InterpolateMode::Bilinear), ); assert_eq!(output.dims(), [1, 1, 1, 9]); // assert output data does not contain NaN assert!( !output .clone() .to_data() .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), "interpolate output contains NaN" ); TestTensor::<4>::from([[[[ 1.541f32, 0.39450002, -0.76475, -1.943125, -0.80520004, 0.36178753, -0.671275, -1.2022874, -1.3986, ]]]]) .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_interpolate_coord_float_precision_boundary() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 28, width: 4, height_out: 24, width_out: 2, }; test.assert_output(TestTensor::from([[[ [0.0, 3.0], [4.6956, 7.6956], [9.3913, 12.3913], [14.0869, 17.0869], [18.7826, 21.7826], [23.4782, 26.4782], [28.1739, 31.1739], [32.8695, 35.8695], [37.5652, 40.5652], [42.2608, 45.2608], [46.9565, 49.9565], [51.6521, 54.6521], [56.3478, 59.3478], [61.0434, 64.0434], [65.7391, 68.7391], [70.4347, 73.4347], [75.1304, 78.1304], [79.8260, 82.8260], [84.5217, 87.5217], [89.2173, 92.2173], [93.9130, 96.9130], [98.6086, 101.6086], [103.3043, 106.3043], [108.0, 111.0], ]]])); } #[test] fn should_interpolate_cast() { let device = Default::default(); let shape_x = Shape::new([1, 1, 4, 4]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), ) .cast(DType::F32); // ok for f32 backends, casts dtype for f16 tests let output = interpolate( x, [8, 8], InterpolateOptions::new(InterpolateMode::Bilinear), ); let expected = TestTensor::<4>::from([[[ [0.0, 0.42857, 0.8571, 1.2857, 1.7142, 2.1428, 2.5714, 3.0], [1.7142, 2.1428, 2.5714, 3.0, 3.4285, 3.8571, 4.2857, 4.7142], [3.4285, 3.8571, 4.2857, 4.7142, 5.1428, 5.5714, 6.0, 6.4285], [5.1428, 5.5714, 6.0, 6.4285, 6.8571, 7.2857, 7.7142, 8.1428], [6.8571, 7.2857, 7.7142, 8.1428, 8.5714, 9.0, 9.4285, 9.8571], [ 8.5714, 9.0, 9.4285, 9.8571, 10.2857, 10.7142, 11.1428, 11.5714, ], [ 10.2857, 10.7142, 11.1428, 11.5714, 12.0, 12.4285, 12.8571, 13.2857, ], [ 12.0, 12.4285, 12.8571, 13.2857, 13.7142, 14.1428, 14.5714, 15.0, ], ]]]); let tolerance = Tolerance::permissive(); output .into_data() .assert_approx_eq::(&expected.into_data(), tolerance); } struct InterpolateTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl InterpolateTestCase { fn assert_output(self, y: TestTensor<4>) { self.assert_output_with_align_corners(y, true); } fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data(), ); let output = interpolate( x, [self.height_out, self.width_out], InterpolateOptions::new(InterpolateMode::Bilinear).with_align_corners(align_corners), ); let tolerance = Tolerance::permissive(); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } #[test] fn test_upsample_half_pixel() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 4, width: 4, height_out: 8, width_out: 8, }; test.assert_output_with_align_corners( TestTensor::from([[[ [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.0], [1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0], [3.0, 3.25, 3.75, 4.25, 4.75, 5.25, 5.75, 6.0], [5.0, 5.25, 5.75, 6.25, 6.75, 7.25, 7.75, 8.0], [7.0, 7.25, 7.75, 8.25, 8.75, 9.25, 9.75, 10.0], [9.0, 9.25, 9.75, 10.25, 10.75, 11.25, 11.75, 12.0], [11.0, 11.25, 11.75, 12.25, 12.75, 13.25, 13.75, 14.0], [12.0, 12.25, 12.75, 13.25, 13.75, 14.25, 14.75, 15.0], ]]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv1d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::conv1d; use burn_tensor::ops::ConvOptions; #[test] fn test_conv1d_simple() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 1, groups: 1, length: 4, }; test.assert_output(TestTensor::from([ [[43., 67., 82., 49.], [104., 176., 227., 158.]], [[139., 187., 202., 113.], [392., 584., 635., 414.]], ])); } #[test] fn test_conv1d_dilation() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 2, groups: 1, length: 4, }; test.assert_output(TestTensor::from([ [[62., 38.], [159., 111.]], [[158., 102.], [447., 367.]], ])); } #[test] fn test_conv1d_groups() { let test = Conv1dTestCase { batch_size: 2, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, stride: 1, dilation: 1, groups: 2, length: 4, }; test.assert_output(TestTensor::from([ [[2., 5., 8., 3.], [42., 63., 75., 47.]], [[26., 29., 32., 11.], [114., 159., 171., 103.]], ])); } #[test] fn test_conv1d_complex() { let test = Conv1dTestCase { batch_size: 2, channels_in: 3, channels_out: 4, kernel_size: 3, padding: 1, stride: 2, dilation: 1, groups: 1, length: 4, }; test.assert_output(TestTensor::from_floats( [ [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]], [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]], ], &Default::default(), )); } struct Conv1dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size: usize, padding: usize, stride: usize, dilation: usize, groups: usize, length: usize, } impl Conv1dTestCase { fn assert_output(self, y: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size, ]); let device = Default::default(); let weight = TestTensor::from_data( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<3, _>(shape_weight) .into_data(), &device, ); let bias = TestTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ); let x = TestTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ); let output = conv1d( x, weight, Some(bias), ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups), ); let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv2d.rs ================================================ use super::*; use alloc::{vec, vec::Vec}; use burn_tensor::Shape; use burn_tensor::activation::gelu; use burn_tensor::module::conv2d; use burn_tensor::ops::ConvOptions; use burn_tensor::{TensorData, Tolerance}; #[test] fn test_conv2d_simple() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [1196., 1796., 1916., 1264.], [1881., 2793., 2946., 1923.], [2313., 3405., 3558., 2307.], [1424., 2072., 2156., 1380.], ], [ [2709., 4173., 4509., 3065.], [4582., 7006., 7483., 5056.], [5878., 8914., 9391., 6304.], [4089., 6177., 6477., 4333.], ], ]])); } #[test] fn test_conv2d_simple_implicit() { let test = Conv2dTestCase { batch_size: 1, channels_in: 1, channels_out: 16, kernel_size_1: 4, kernel_size_2: 4, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 5, width: 5, }; test.assert_output(TestTensor::from([[ [ [666., 916., 1030., 774.], [1124., 1500., 1620., 1190.], [1604., 2100., 2220., 1610.], [990., 1264., 1330., 936.], ], [ [1531., 2165., 2471., 1927.], [2757., 3805., 4181., 3207.], [4197., 5685., 6061., 4587.], [3295., 4433., 4691., 3529.], ], [ [2396., 3414., 3912., 3080.], [4390., 6110., 6742., 5224.], [6790., 9270., 9902., 7564.], [5600., 7602., 8052., 6122.], ], [ [3261., 4663., 5353., 4233.], [6023., 8415., 9303., 7241.], [9383., 12855., 13743., 10541.], [7905., 10771., 11413., 8715.], ], [ [4126., 5912., 6794., 5386.], [7656., 10720., 11864., 9258.], [11976., 16440., 17584., 13518.], [10210., 13940., 14774., 11308.], ], [ [4991., 7161., 8235., 6539.], [9289., 13025., 14425., 11275.], [14569., 20025., 21425., 16495.], [12515., 17109., 18135., 13901.], ], [ [5856., 8410., 9676., 7692.], [10922., 15330., 16986., 13292.], [17162., 23610., 25266., 19472.], [14820., 20278., 21496., 16494.], ], [ [6721., 9659., 11117., 8845.], [12555., 17635., 19547., 15309.], [19755., 27195., 29107., 22449.], [17125., 23447., 24857., 19087.], ], [ [7586., 10908., 12558., 9998.], [14188., 19940., 22108., 17326.], [22348., 30780., 32948., 25426.], [19430., 26616., 28218., 21680.], ], [ [8451., 12157., 13999., 11151.], [15821., 22245., 24669., 19343.], [24941., 34365., 36789., 28403.], [21735., 29785., 31579., 24273.], ], [ [9316., 13406., 15440., 12304.], [17454., 24550., 27230., 21360.], [27534., 37950., 40630., 31380.], [24040., 32954., 34940., 26866.], ], [ [10181., 14655., 16881., 13457.], [19087., 26855., 29791., 23377.], [30127., 41535., 44471., 34357.], [26345., 36123., 38301., 29459.], ], [ [11046., 15904., 18322., 14610.], [20720., 29160., 32352., 25394.], [32720., 45120., 48312., 37334.], [28650., 39292., 41662., 32052.], ], [ [11911., 17153., 19763., 15763.], [22353., 31465., 34913., 27411.], [35313., 48705., 52153., 40311.], [30955., 42461., 45023., 34645.], ], [ [12776., 18402., 21204., 16916.], [23986., 33770., 37474., 29428.], [37906., 52290., 55994., 43288.], [33260., 45630., 48384., 37238.], ], [ [13641., 19651., 22645., 18069.], [25619., 36075., 40035., 31445.], [40499., 55875., 59835., 46265.], [35565., 48799., 51745., 39831.], ], ]])); } #[test] fn test_conv2d_implicit_padded_in_channels() { let test = Conv2dTestCase { batch_size: 1, channels_in: 3, channels_out: 16, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [4521., 6753., 7014., 4635.], [6858., 10197., 10548., 6939.], [7830., 11601., 11952., 7839.], [5007., 7383., 7590., 4953.], ], [ [10516., 15988., 16735., 11278.], [16822., 25507., 26587., 17875.], [19738., 29827., 30907., 20719.], [13594., 20506., 21199., 14188.], ], [ [16511., 25223., 26456., 17921.], [26786., 40817., 42626., 28811.], [31646., 48053., 49862., 33599.], [22181., 33629., 34808., 23423.], ], [ [22506., 34458., 36177., 24564.], [36750., 56127., 58665., 39747.], [43554., 66279., 68817., 46479.], [30768., 46752., 48417., 32658.], ], [ [28501., 43693., 45898., 31207.], [46714., 71437., 74704., 50683.], [55462., 84505., 87772., 59359.], [39355., 59875., 62026., 41893.], ], [ [34496., 52928., 55619., 37850.], [56678., 86747., 90743., 61619.], [67370., 102731., 106727., 72239.], [47942., 72998., 75635., 51128.], ], [ [40491., 62163., 65340., 44493.], [66642., 102057., 106782., 72555.], [79278., 120957., 125682., 85119.], [56529., 86121., 89244., 60363.], ], [ [46486., 71398., 75061., 51136.], [76606., 117367., 122821., 83491.], [91186., 139183., 144637., 97999.], [65116., 99244., 102853., 69598.], ], [ [52481., 80633., 84782., 57779.], [86570., 132677., 138860., 94427.], [103094., 157409., 163592., 110879.], [73703., 112367., 116462., 78833.], ], [ [58476., 89868., 94503., 64422.], [96534., 147987., 154899., 105363.], [115002., 175635., 182547., 123759.], [82290., 125490., 130071., 88068.], ], [ [64471., 99103., 104224., 71065.], [106498., 163297., 170938., 116299.], [126910., 193861., 201502., 136639.], [90877., 138613., 143680., 97303.], ], [ [70466., 108338., 113945., 77708.], [116462., 178607., 186977., 127235.], [138818., 212087., 220457., 149519.], [99464., 151736., 157289., 106538.], ], [ [76461., 117573., 123666., 84351.], [126426., 193917., 203016., 138171.], [150726., 230313., 239412., 162399.], [108051., 164859., 170898., 115773.], ], [ [82456., 126808., 133387., 90994.], [136390., 209227., 219055., 149107.], [162634., 248539., 258367., 175279.], [116638., 177982., 184507., 125008.], ], [ [88451., 136043., 143108., 97637.], [146354., 224537., 235094., 160043.], [174542., 266765., 277322., 188159.], [125225., 191105., 198116., 134243.], ], [ [94446., 145278., 152829., 104280.], [156318., 239847., 251133., 170979.], [186450., 284991., 296277., 201039.], [133812., 204228., 211725., 143478.], ], ]])); } #[test] fn test_conv2d_groups_channels_out() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 16, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [73., 121., 154., 103.], [171., 258., 294., 186.], [279., 402., 438., 270.], [139., 187., 202., 113.], ], [ [164., 284., 371., 266.], [415., 664., 781., 538.], [739., 1132., 1249., 838.], [518., 782., 851., 564.], ], [ [255., 447., 588., 429.], [659., 1070., 1268., 890.], [1199., 1862., 2060., 1406.], [897., 1377., 1500., 1015.], ], [ [346., 610., 805., 592.], [903., 1476., 1755., 1242.], [1659., 2592., 2871., 1974.], [1276., 1972., 2149., 1466.], ], [ [437., 773., 1022., 755.], [1147., 1882., 2242., 1594.], [2119., 3322., 3682., 2542.], [1655., 2567., 2798., 1917.], ], [ [528., 936., 1239., 918.], [1391., 2288., 2729., 1946.], [2579., 4052., 4493., 3110.], [2034., 3162., 3447., 2368.], ], [ [619., 1099., 1456., 1081.], [1635., 2694., 3216., 2298.], [3039., 4782., 5304., 3678.], [2413., 3757., 4096., 2819.], ], [ [710., 1262., 1673., 1244.], [1879., 3100., 3703., 2650.], [3499., 5512., 6115., 4246.], [2792., 4352., 4745., 3270.], ], [ [5793., 8865., 9330., 6335.], [9467., 14450., 15134., 10250.], [11303., 17186., 17870., 12062.], [7971., 12099., 12546., 8457.], ], [ [6460., 9892., 10411., 7074.], [10575., 16152., 16917., 11466.], [12627., 19212., 19977., 13494.], [8926., 13558., 14059., 9484.], ], [ [7127., 10919., 11492., 7813.], [11683., 17854., 18700., 12682.], [13951., 21238., 22084., 14926.], [9881., 15017., 15572., 10511.], ], [ [7794., 11946., 12573., 8552.], [12791., 19556., 20483., 13898.], [15275., 23264., 24191., 16358.], [10836., 16476., 17085., 11538.], ], [ [8461., 12973., 13654., 9291.], [13899., 21258., 22266., 15114.], [16599., 25290., 26298., 17790.], [11791., 17935., 18598., 12565.], ], [ [9128., 14000., 14735., 10030.], [15007., 22960., 24049., 16330.], [17923., 27316., 28405., 19222.], [12746., 19394., 20111., 13592.], ], [ [9795., 15027., 15816., 10769.], [16115., 24662., 25832., 17546.], [19247., 29342., 30512., 20654.], [13701., 20853., 21624., 14619.], ], [ [10462., 16054., 16897., 11508.], [17223., 26364., 27615., 18762.], [20571., 31368., 32619., 22086.], [14656., 22312., 23137., 15646.], ], ]])); } #[test] fn test_conv2d_groups() { let test = Conv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 5, width: 5, }; test.assert_output(TestTensor::from([[ [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]], [ [3724., 3841., 3958.], [4309., 4426., 4543.], [4894., 5011., 5128.], ], ]])); } #[test] fn test_conv2d_groups_multiple_channels() { let test = Conv2dTestCase { batch_size: 1, channels_in: 4, channels_out: 4, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 5, width: 5, }; test.assert_output(TestTensor::from([[ [ [4035., 4188., 4341.], [4800., 4953., 5106.], [5565., 5718., 5871.], ], [ [10030., 10507., 10984.], [12415., 12892., 13369.], [14800., 15277., 15754.], ], [ [56075., 56876., 57677.], [60080., 60881., 61682.], [64085., 64886., 65687.], ], [ [78270., 79395., 80520.], [83895., 85020., 86145.], [89520., 90645., 91770.], ], ]])); } #[test] fn test_conv2d_complex() { let test = Conv2dTestCase { batch_size: 2, channels_in: 3, channels_out: 4, kernel_size_1: 3, kernel_size_2: 2, padding_1: 1, padding_2: 2, stride_1: 2, stride_2: 3, dilation_1: 1, dilation_2: 2, groups: 1, height: 4, width: 5, }; test.assert_output(TestTensor::from([ [ [[1845., 3789., 1926.], [3210., 6465., 3228.]], [[4276., 9082., 4789.], [8071., 16834., 8737.]], [[6707., 14375., 7652.], [12932., 27203., 14246.]], [[9138., 19668., 10515.], [17793., 37572., 19755.]], ], [ [[5445., 10629., 5166.], [8070., 15645., 7548.]], [[14356., 28882., 14509.], [22651., 45454., 22777.]], [[23267., 47135., 23852.], [37232., 75263., 38006.]], [[32178., 65388., 33195.], [51813., 105072., 53235.]], ], ])); } struct Conv2dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, dilation_1: usize, dilation_2: usize, groups: usize, height: usize, width: usize, } impl Conv2dTestCase { fn assert_output(self, y: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, ]); let device = Default::default(); let weight = TestTensor::from( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight) .into_data(), ); let bias = TestTensor::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), ); let output = conv2d( x, weight, Some(bias), ConvOptions::new( [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], [self.dilation_1, self.dilation_2], self.groups, ), ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } #[rustfmt::skip] // param values are too long fn conv2d_weight() -> TensorData { TensorData::new( vec![0.048065186, -0.3059082, -0.10345459, -0.34643555, -0.20788574, -0.021072388, 0.13745117, -0.05102539, 0.024536133, -0.16479492, -0.19519043, 0.27270508, 0.17700195, -0.33764648, -0.08239746, -0.27929688, 0.17321777, -0.1315918, 0.04574585, -0.17980957, -0.33569336, 0.27612305, 0.30004883, -0.28979492, -0.17297363, -0.021759033, -0.27148438, 0.005657196, 0.29956055, -0.06958008, -0.29345703, -0.14440918, 0.10827637, -0.13305664, -0.20239258, 0.24890137, -0.1541748, -0.20019531, -0.2854004, 0.17016602, 0.07861328, -0.09075928, 0.30908203, -0.00013422966, 0.29589844, 0.15258789, -0.25708008, 0.20422363, -0.2529297, 0.07891846, -0.19506836, 0.23571777, 0.27124023, 0.17370605, -0.16992188, -0.23522949, 0.14648438, -0.09576416, -0.18310547, 0.21044922, -0.08911133, -0.2541504, -0.2775879, -0.2064209, -0.16271973, -0.048919678, -0.03555298, -0.11639404, 0.09661865, -0.10241699, 0.08929443, 0.2866211], [8, 1, 3, 3], ) } #[test] fn test_conv2d_binary_broadcasted() { let device = Default::default(); let x = TestTensor::<4>::full([1, 1, 28, 28], -0.42421296, &device); // conv2d -> batchnorm -> activation let weight = TestTensor::from_data(conv2d_weight(), &device); let bias = TestTensor::from([ 0.082336426, -0.049591064, 0.0031795502, 0.00095653534, 0.02357483, 0.005569458, 0.07525635, 0.056396484, ]); // channels: [1, 8], kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], groups: 1, padding: [0, 0] let opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 1); let x = conv2d(x, weight, Some(bias), opt); // simulate batchnorm binary ops with broadcasted params let gamma = TestTensor::<1>::from([ 1.0048828, 0.9902344, 1.0185547, 0.97558594, 1.0097656, 0.97802734, 1.0009766, 1.0146484, ]); let beta = TestTensor::<1>::from([ 0.026290894, 0.0007505417, 0.006134033, 0.02418518, 0.07373047, 0.020507813, 0.01902771, 0.02003479, ]); let mean = TestTensor::<1>::from([ 0.029159546, -0.08673096, -0.03894043, -0.01108551, 0.032440186, 0.03237915, 0.013839722, 0.04397583, ]) .reshape([1, 8, 1, 1]); let var = TestTensor::<1>::from([ 0.67089844, 0.29956055, 0.5209961, 0.1862793, 0.30419922, 0.21313477, 0.7504883, 0.26342773, ]) .reshape([1, 8, 1, 1]); let std = var.add_scalar(1e-5).sqrt(); let x = x.sub(mean); let x = x.div(std); let x = x.mul(gamma.reshape([1, 8, 1, 1])); let x = x.add(beta.reshape([1, 8, 1, 1])); let x = gelu(x); let expected: Vec = [ 0.36432067f32, 0.34909567, 0.30684796, 0.13217466, -0.018471397, -0.1389876, 0.39402074, 0.12394252, ] .iter() .flat_map(|&v| core::iter::repeat_n(v, 676)) .collect(); let expected = TensorData::new(expected, [1, 8, 26, 26]); x.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_absolute(1e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv3d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::conv3d; use burn_tensor::ops::ConvOptions; #[test] fn test_conv3d_simple() { let test = Conv3dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 4, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [ [29980.0, 44860.0, 45640.0, 30324.0], [45072.0, 67380.0, 68496.0, 45468.0], [48096.0, 71844.0, 72960.0, 48396.0], [31780.0, 47428.0, 48136.0, 31900.0], ], [ [47292.0, 70548.0, 71556.0, 47400.0], [70335.0, 104823.0, 106254.0, 70317.0], [74223.0, 110547.0, 111978.0, 74061.0], [48552.0, 72240.0, 73140.0, 48324.0], ], [ [58236.0, 86676.0, 87684.0, 57960.0], [85887.0, 127719.0, 129150.0, 85293.0], [89775.0, 133443.0, 134874.0, 89037.0], [58344.0, 86640.0, 87540.0, 57732.0], ], [ [36148.0, 53620.0, 54184.0, 35692.0], [52740.0, 78144.0, 78936.0, 51936.0], [54900.0, 81312.0, 82104.0, 54000.0], [35260.0, 52156.0, 52648.0, 34580.0], ], ], [ [ [66701.0, 100589.0, 102665.0, 68773.0], [102745.0, 154861.0, 157921.0, 105733.0], [110953.0, 167101.0, 170161.0, 113845.0], [75413.0, 113525.0, 115529.0, 77261.0], ], [ [112741.0, 169693.0, 172645.0, 115441.0], [172396.0, 259372.0, 263719.0, 176266.0], [184060.0, 276760.0, 281107.0, 187786.0], [124369.0, 186937.0, 189781.0, 126733.0], ], [ [144421.0, 216925.0, 219877.0, 146737.0], [219052.0, 328924.0, 333271.0, 222346.0], [230716.0, 346312.0, 350659.0, 233866.0], [154897.0, 232441.0, 235285.0, 156877.0], ], [ [100517.0, 150821.0, 152681.0, 101789.0], [151885.0, 227833.0, 230569.0, 153673.0], [159229.0, 238777.0, 241513.0, 160921.0], [106541.0, 159725.0, 161513.0, 107589.0], ], ], ]])); } #[test] fn test_conv3d_groups() { let test = Conv3dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 0, padding_2: 0, padding_3: 0, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 2, depth: 5, height: 5, width: 5, }; test.assert_output(TestTensor::from([[ [ [ [15219., 15570., 15921.], [16974., 17325., 17676.], [18729., 19080., 19431.], ], [ [23994., 24345., 24696.], [25749., 26100., 26451.], [27504., 27855., 28206.], ], [ [32769., 33120., 33471.], [34524., 34875., 35226.], [36279., 36630., 36981.], ], ], [ [ [172819., 173899., 174979.], [178219., 179299., 180379.], [183619., 184699., 185779.], ], [ [199819., 200899., 201979.], [205219., 206299., 207379.], [210619., 211699., 212779.], ], [ [226819., 227899., 228979.], [232219., 233299., 234379.], [237619., 238699., 239779.], ], ], ]])); } #[test] fn test_conv3d_complex() { let test = Conv3dTestCase { batch_size: 2, channels_in: 3, channels_out: 4, kernel_size_1: 4, kernel_size_2: 3, kernel_size_3: 2, padding_1: 1, padding_2: 2, padding_3: 3, stride_1: 2, stride_2: 3, stride_3: 4, dilation_1: 1, dilation_2: 2, dilation_3: 3, groups: 1, depth: 4, height: 5, width: 6, }; test.assert_output(TestTensor::from([ [ [ [[149148., 299070., 149850.], [147636., 295758., 148050.]], [[150660., 301014., 150282.], [147420., 294246., 146754.]], ], [ [[351325., 709903., 358507.], [357589., 722143., 364483.]], [[391717., 789607., 397819.], [396253., 798391., 402067.]], ], [ [[553502., 1120736., 567164.], [567542., 1148528., 580916.]], [[632774., 1278200., 645356.], [645086., 1302536., 657380.]], ], [ [[755679., 1531569., 775821.], [777495., 1574913., 797349.]], [[873831., 1766793., 892893.], [893919., 1806681., 912693.]], ], ], [ [ [[408348., 810990., 402570.], [393876., 781758., 387810.]], [[370980., 735174., 364122.], [354780., 702486., 347634.]], ], [ [ [1077085., 2154943., 1077787.], [1070389., 2141263., 1070803.], ], [ [1078597., 2156887., 1078219.], [1070173., 2139751., 1069507.], ], ], [ [ [1745822., 3498896., 1753004.], [1746902., 3500768., 1753796.], ], [ [1786214., 3578600., 1792316.], [1785566., 3577016., 1791380.], ], ], [ [ [2414559., 4842849., 2428221.], [2423415., 4860273., 2436789.], ], [ [2493831., 5000313., 2506413.], [2500959., 5014281., 2513253.], ], ], ], ])); } struct Conv3dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, kernel_size_3: usize, padding_1: usize, padding_2: usize, padding_3: usize, stride_1: usize, stride_2: usize, stride_3: usize, dilation_1: usize, dilation_2: usize, dilation_3: usize, groups: usize, depth: usize, height: usize, width: usize, } impl Conv3dTestCase { fn assert_output(self, y: TestTensor<5>) { let shape_x = Shape::new([ self.batch_size, self.channels_in, self.depth, self.height, self.width, ]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.groups, self.kernel_size_1, self.kernel_size_2, self.kernel_size_3, ]); let device = Default::default(); let weight = TestTensor::from( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<5, _>(shape_weight) .into_data(), ); let bias = TestTensor::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<5, _>(shape_x) .into_data(), ); let output = conv3d( x, weight, Some(bias), ConvOptions::new( [self.stride_1, self.stride_2, self.stride_3], [self.padding_1, self.padding_2, self.padding_3], [self.dilation_1, self.dilation_2, self.dilation_3], self.groups, ), ); let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose1d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::conv_transpose1d; use burn_tensor::ops::ConvTransposeOptions; #[test] fn test_conv_transpose1d_diff_channels() { let test = ConvTranspose1dTestCase { batch_size: 1, channels_in: 3, channels_out: 2, kernel_size: 3, padding: 1, padding_out: 0, stride: 1, dilation: 1, groups: 1, length: 4, }; test.assert_output(TestTensor::from([[ [270., 453., 516., 387.], [352., 589., 679., 505.], ]])); } #[test] fn test_conv_transpose1d_stride() { let test = ConvTranspose1dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, padding_out: 1, stride: 2, dilation: 1, groups: 1, length: 4, }; test.assert_output(TestTensor::from([[ [28., 62., 36., 78., 44., 94., 52., 62.], [41., 93., 55., 121., 69., 149., 83., 93.], ]])); } #[test] fn test_conv_transpose1d_dilation() { let test = ConvTranspose1dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, padding_out: 0, stride: 1, dilation: 2, groups: 1, length: 4, }; test.assert_output(TestTensor::from([[ [30., 64., 78., 76., 94., 52.], [49., 101., 127., 113., 143., 77.], ]])); } #[test] fn test_conv_transpose1d_groups() { let test = ConvTranspose1dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size: 3, padding: 1, padding_out: 0, stride: 1, dilation: 1, groups: 2, length: 4, }; test.assert_output(TestTensor::from_floats( [[[0., 1., 4., 7.], [32., 59., 71., 59.]]], &Default::default(), )); } struct ConvTranspose1dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size: usize, padding: usize, padding_out: usize, stride: usize, dilation: usize, groups: usize, length: usize, } impl ConvTranspose1dTestCase { fn assert_output(self, y: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]); let shape_weights = Shape::new([ self.channels_in, self.channels_out / self.groups, self.kernel_size, ]); let device = Default::default(); let weights = TestTensor::from_data( TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device) .reshape::<3, _>(shape_weights) .into_data(), &device, ); let bias = TestTensor::from_data( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), &device, ); let x = TestTensor::from_data( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<3, _>(shape_x) .into_data(), &device, ); let output = conv_transpose1d( x, weights, Some(bias), ConvTransposeOptions::new( [self.stride], [self.padding], [self.padding_out], [self.dilation], self.groups, ), ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose2d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::conv_transpose2d; use burn_tensor::ops::ConvTransposeOptions; #[test] fn test_conv_transpose2d_simple_1() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 1, channels_out: 1, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, padding_out_1: 0, padding_out_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 2, width: 2, }; test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]])); } #[test] fn test_conv_transpose2d_simple_2() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 3, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, padding_out_1: 0, padding_out_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [9855., 15207., 15738., 10797.], [16290., 25119., 25956., 17793.], [18486., 28467., 29304., 20061.], [13593., 20913., 21498., 14703.], ], [ [11854., 18286., 18979., 13012.], [19612., 30223., 31303., 21439.], [22456., 34543., 35623., 24355.], [16456., 25288., 26035., 17782.], ], [ [13853., 21365., 22220., 15227.], [22934., 35327., 36650., 25085.], [26426., 40619., 41942., 28649.], [19319., 29663., 30572., 20861.], ], ]])); } #[test] fn test_conv_transpose2d_simple_3() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 1, channels_out: 1, kernel_size_1: 2, kernel_size_2: 2, padding_1: 0, padding_2: 0, padding_out_1: 0, padding_out_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 1, height: 2, width: 2, }; test.assert_output(TestTensor::from([[[ [0.0, 0.0, 1.0], [0.0, 4.0, 6.0], [4.0, 12.0, 9.0], ]]])); } #[test] fn test_conv_transpose2d_stride_2() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 1, channels_out: 1, kernel_size_1: 2, kernel_size_2: 2, padding_1: 0, padding_2: 0, padding_out_1: 0, padding_out_2: 0, stride_1: 2, stride_2: 2, dilation_1: 1, dilation_2: 1, groups: 1, height: 2, width: 2, }; test.assert_output(TestTensor::from([[[ [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 2.0, 3.0], [0.0, 2.0, 0.0, 3.0], [4.0, 6.0, 6.0, 9.0], ]]])); } #[test] fn test_conv_transpose2d_dilation_2() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, padding_out_1: 1, padding_out_2: 1, stride_1: 1, stride_2: 1, dilation_1: 2, dilation_2: 2, groups: 1, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [ [126., 116., 136., 124., 146.], [108., 88., 114., 92., 120.], [156., 140., 166., 148., 176.], [126., 100., 132., 104., 138.], [186., 164., 196., 172., 206.], ], [ [217., 189., 227., 197., 237.], [163., 125., 169., 129., 175.], [247., 213., 257., 221., 267.], [181., 137., 187., 141., 193.], [277., 237., 287., 245., 297.], ], ]])); } #[test] fn test_conv_transpose2d_stride2_out_padding() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, padding_out_1: 1, padding_out_2: 1, stride_1: 2, stride_2: 2, dilation_1: 1, dilation_2: 1, groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [352., 728., 378., 780., 404., 832., 430., 452.], [784., 1616., 836., 1720., 888., 1824., 940., 992.], [456., 936., 482., 988., 508., 1040., 534., 564.], [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.], [560., 1144., 586., 1196., 612., 1248., 638., 676.], [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.], [664., 1352., 690., 1404., 716., 1456., 742., 788.], [784., 1598., 816., 1662., 848., 1726., 880., 926.], ], [ [497., 1035., 541., 1123., 585., 1211., 629., 651.], [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.], [673., 1387., 717., 1475., 761., 1563., 805., 835.], [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.], [849., 1739., 893., 1827., 937., 1915., 981., 1019.], [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.], [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.], [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.], ], ]])); } #[test] fn test_conv_transpose2d_groups_2() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, padding_1: 1, padding_2: 1, padding_out_1: 0, padding_out_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [[5., 11.], [23., 29.]], [[236., 258.], [302., 324.]], ]])); } #[test] fn test_conv_transpose2d_groups_different_channels() { let test = ConvTranspose2dTestCase { batch_size: 1, channels_in: 2, channels_out: 6, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, padding_out_1: 0, padding_out_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, groups: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [ [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00], [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01], [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01], [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01], ], [ [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01], [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01], [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01], [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01], ], [ [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01], [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01], [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01], [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01], ], [ [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02], [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02], [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02], [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02], ], [ [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02], [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02], [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02], [2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02], ], [ [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02], [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02], [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02], [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02], ], ]])); } struct ConvTranspose2dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, padding_out_1: usize, padding_out_2: usize, stride_1: usize, stride_2: usize, dilation_1: usize, dilation_2: usize, groups: usize, height: usize, width: usize, } impl ConvTranspose2dTestCase { fn assert_output(self, y: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let shape_weights = Shape::new([ self.channels_in, self.channels_out / self.groups, self.kernel_size_1, self.kernel_size_2, ]); let device = Default::default(); let weights = TestTensor::from( TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device) .reshape::<4, _>(shape_weights) .into_data(), ); let bias = TestTensor::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x) .into_data(), ); let output = conv_transpose2d( x, weights, Some(bias), ConvTransposeOptions::new( [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], [self.padding_out_1, self.padding_out_2], [self.dilation_1, self.dilation_2], self.groups, ), ); y.into_data() .assert_approx_eq::(&output.into_data(), Tolerance::rel_abs(1e-1, 0.01)); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/conv_transpose3d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::conv_transpose3d; use burn_tensor::ops::ConvTransposeOptions; #[test] fn test_conv_transpose3d_simple_1() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 1, channels_out: 1, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, padding_out_1: 0, padding_out_2: 0, padding_out_3: 0, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[[ [[96., 124.], [180., 208.]], [[348., 376.], [432., 460.]], ]]])); } #[test] fn test_conv_transpose3d_simple_2() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 3, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, padding_out_1: 0, padding_out_2: 0, padding_out_3: 0, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 4, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [ [238452., 360588., 363756., 244488.], [367929., 556353., 561186., 377163.], [380745., 575685., 580518., 390123.], [261192., 394896., 398172., 267564.], ], [ [394083., 595827., 600822., 403749.], [607635., 918648., 926262., 622404.], [627831., 949104., 956718., 642816.], [430353., 650529., 655686., 440523.], ], [ [447075., 675747., 680742., 457317.], [688419., 1040472., 1048086., 704052.], [708615., 1070928., 1078542., 724464.], [485073., 733041., 738198., 495819.], ], [ [328656., 496632., 500124., 335892.], [505611., 763983., 769302., 516645.], [519723., 785259., 790578., 530901.], [355428., 536988., 540588., 363000.], ], ], [ [ [286729., 433489., 437629., 294061.], [442288., 668620., 674911., 453466.], [458992., 693784., 700075., 470314.], [314653., 475573., 479821., 322321.], ], [ [474274., 716842., 723295., 485884.], [730837., 1104544., 1114345., 748522.], [756865., 1143748., 1153549., 774766.], [518320., 783208., 789823., 530434.], ], [ [542818., 820090., 826543., 555004.], [834949., 1261360., 1271161., 853498.], [860977., 1300564., 1310365., 879742.], [588592., 889048., 895663., 601282.], ], [ [397669., 600637., 605101., 406201.], [611074., 922906., 929683., 624052.], [629074., 950014., 956791., 642196.], [429625., 648769., 653341., 438493.], ], ], [ [ [335006., 506390., 511502., 343634.], [516647., 780887., 788636., 529769.], [537239., 811883., 819632., 550505.], [368114., 556250., 561470., 377078.], ], [ [554465., 837857., 845768., 568019.], [854039., 1290440., 1302428., 874640.], [885899., 1338392., 1350380., 906716.], [606287., 915887., 923960., 620345.], ], [ [638561., 964433., 972344., 652691.], [981479., 1482248., 1494236., 1002944.], [1013339., 1530200., 1542188., 1035020.], [692111., 1045055., 1053128., 706745.], ], [ [466682., 704642., 710078., 476510.], [716537., 1081829., 1090064., 731459.], [738425., 1114769., 1123004., 753491.], [503822., 760550., 766094., 513986.], ], ], ]])); } #[test] fn test_conv_transpose3d_stride_2() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 1, channels_out: 1, kernel_size_1: 2, kernel_size_2: 2, kernel_size_3: 2, padding_1: 0, padding_2: 0, padding_3: 0, padding_out_1: 0, padding_out_2: 0, padding_out_3: 0, stride_1: 2, stride_2: 2, stride_3: 2, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[[ [ [0., 0., 0., 1.], [0., 0., 2., 3.], [0., 2., 0., 3.], [4., 6., 6., 9.], ], [ [0., 0., 4., 5.], [0., 0., 6., 7.], [8., 10., 12., 15.], [12., 14., 18., 21.], ], [ [0., 4., 0., 5.], [8., 12., 10., 15.], [0., 6., 0., 7.], [12., 18., 14., 21.], ], [ [16., 20., 20., 25.], [24., 28., 30., 35.], [24., 30., 28., 35.], [36., 42., 42., 49.], ], ]]])); } #[test] fn test_conv_transpose3d_dilation_2() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, padding_out_1: 1, padding_out_2: 1, padding_out_3: 1, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 2, dilation_2: 2, dilation_3: 2, groups: 1, depth: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [ [ [810., 776., 832., 796., 854.], [756., 712., 774., 728., 792.], [876., 836., 898., 856., 920.], [810., 760., 828., 776., 846.], [942., 896., 964., 916., 986.], ], [ [720., 660., 734., 672., 748.], [606., 536., 616., 544., 626.], [762., 696., 776., 708., 790.], [636., 560., 646., 568., 656.], [804., 732., 818., 744., 832.], ], [ [1008., 956., 1030., 976., 1052.], [918., 856., 936., 872., 954.], [1074., 1016., 1096., 1036., 1118.], [972., 904., 990., 920., 1008.], [1140., 1076., 1162., 1096., 1184.], ], [ [846., 768., 860., 780., 874.], [696., 608., 706., 616., 716.], [888., 804., 902., 816., 916.], [726., 632., 736., 640., 746.], [930., 840., 944., 852., 958.], ], [ [1206., 1136., 1228., 1156., 1250.], [1080., 1000., 1098., 1016., 1116.], [1272., 1196., 1294., 1216., 1316.], [1134., 1048., 1152., 1064., 1170.], [1338., 1256., 1360., 1276., 1382.], ], ], [ [ [1405., 1317., 1427., 1337., 1449.], [1243., 1145., 1261., 1161., 1279.], [1471., 1377., 1493., 1397., 1515.], [1297., 1193., 1315., 1209., 1333.], [1537., 1437., 1559., 1457., 1581.], ], [ [1099., 985., 1113., 997., 1127.], [877., 753., 887., 761., 897.], [1141., 1021., 1155., 1033., 1169.], [907., 777., 917., 785., 927.], [1183., 1057., 1197., 1069., 1211.], ], [ [1603., 1497., 1625., 1517., 1647.], [1405., 1289., 1423., 1305., 1441.], [1669., 1557., 1691., 1577., 1713.], [1459., 1337., 1477., 1353., 1495.], [1735., 1617., 1757., 1637., 1779.], ], [ [1225., 1093., 1239., 1105., 1253.], [967., 825., 977., 833., 987.], [1267., 1129., 1281., 1141., 1295.], [997., 849., 1007., 857., 1017.], [1309., 1165., 1323., 1177., 1337.], ], [ [1801., 1677., 1823., 1697., 1845.], [1567., 1433., 1585., 1449., 1603.], [1867., 1737., 1889., 1757., 1911.], [1621., 1481., 1639., 1497., 1657.], [1933., 1797., 1955., 1817., 1977.], ], ], ]])); } #[test] fn test_conv_transpose3d_stride2_out_padding() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, padding_out_1: 1, padding_out_2: 1, padding_out_3: 1, stride_1: 2, stride_2: 2, stride_3: 2, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 1, depth: 2, height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [ [ [2144., 4366., 2224., 4526., 2304., 4686., 2384., 2422.], [4584., 9324., 4744., 9644., 4904., 9964., 5064., 5148.], [2464., 5006., 2544., 5166., 2624., 5326., 2704., 2750.], [5224., 10604., 5384., 10924., 5544., 11244., 5704., 5804.], [2784., 5646., 2864., 5806., 2944., 5966., 3024., 3078.], [5864., 11884., 6024., 12204., 6184., 12524., 6344., 6460.], [3104., 6286., 3184., 6446., 3264., 6606., 3344., 3406.], [3272., 6628., 3358., 6800., 3444., 6972., 3530., 3592.], ], [ [5280., 10716., 5440., 11036., 5600., 11356., 5760., 5868.], [ 11152., 22616., 11472., 23256., 11792., 23896., 12112., 12344., ], [5920., 11996., 6080., 12316., 6240., 12636., 6400., 6524.], [ 12432., 25176., 12752., 25816., 13072., 26456., 13392., 13656., ], [6560., 13276., 6720., 13596., 6880., 13916., 7040., 7180.], [ 13712., 27736., 14032., 28376., 14352., 29016., 14672., 14968., ], [7200., 14556., 7360., 14876., 7520., 15196., 7680., 7836.], [7632., 15432., 7804., 15776., 7976., 16120., 8148., 8304.], ], [ [3424., 6926., 3504., 7086., 3584., 7246., 3664., 3734.], [7144., 14444., 7304., 14764., 7464., 15084., 7624., 7772.], [3744., 7566., 3824., 7726., 3904., 7886., 3984., 4062.], [7784., 15724., 7944., 16044., 8104., 16364., 8264., 8428.], [4064., 8206., 4144., 8366., 4224., 8526., 4304., 4390.], [8424., 17004., 8584., 17324., 8744., 17644., 8904., 9084.], [4384., 8846., 4464., 9006., 4544., 9166., 4624., 4718.], [4648., 9380., 4734., 9552., 4820., 9724., 4906., 5000.], ], [ [4000., 8096., 4098., 8292., 4196., 8488., 4294., 4364.], [8368., 16928., 8564., 17320., 8760., 17712., 8956., 9104.], [4392., 8880., 4490., 9076., 4588., 9272., 4686., 4764.], [9152., 18496., 9348., 18888., 9544., 19280., 9740., 9904.], [4784., 9664., 4882., 9860., 4980., 10056., 5078., 5164.], [ 9936., 20064., 10132., 20456., 10328., 20848., 10524., 10704., ], [5176., 10448., 5274., 10644., 5372., 10840., 5470., 5564.], [5440., 10982., 5544., 11190., 5648., 11398., 5752., 5846.], ], ], [ [ [3009., 6149., 3143., 6417., 3277., 6685., 3411., 3449.], [6529., 13321., 6797., 13857., 7065., 14393., 7333., 7417.], [3545., 7221., 3679., 7489., 3813., 7757., 3947., 3993.], [7601., 15465., 7869., 16001., 8137., 16537., 8405., 8505.], [4081., 8293., 4215., 8561., 4349., 8829., 4483., 4537.], [8673., 17609., 8941., 18145., 9209., 18681., 9477., 9593.], [4617., 9365., 4751., 9633., 4885., 9901., 5019., 5081.], [4785., 9707., 4925., 9987., 5065., 10267., 5205., 5267.], ], [ [7873., 16009., 8141., 16545., 8409., 17081., 8677., 8785.], [ 16769., 34065., 17305., 35137., 17841., 36209., 18377., 18609., ], [8945., 18153., 9213., 18689., 9481., 19225., 9749., 9873.], [ 18913., 38353., 19449., 39425., 19985., 40497., 20521., 20785., ], [ 10017., 20297., 10285., 20833., 10553., 21369., 10821., 10961., ], [ 21057., 42641., 21593., 43713., 22129., 44785., 22665., 22961., ], [ 11089., 22441., 11357., 22977., 11625., 23513., 11893., 12049., ], [ 11521., 23317., 11801., 23877., 12081., 24437., 12361., 12517., ], ], [ [5153., 10437., 5287., 10705., 5421., 10973., 5555., 5625.], [ 10817., 21897., 11085., 22433., 11353., 22969., 11621., 11769., ], [5689., 11509., 5823., 11777., 5957., 12045., 6091., 6169.], [ 11889., 24041., 12157., 24577., 12425., 25113., 12693., 12857., ], [6225., 12581., 6359., 12849., 6493., 13117., 6627., 6713.], [ 12961., 26185., 13229., 26721., 13497., 27257., 13765., 13945., ], [6761., 13653., 6895., 13921., 7029., 14189., 7163., 7257.], [7025., 14187., 7165., 14467., 7305., 14747., 7445., 7539.], ], [ [5729., 11607., 5881., 11911., 6033., 12215., 6185., 6255.], [ 12041., 24381., 12345., 24989., 12649., 25597., 12953., 13101., ], [6337., 12823., 6489., 13127., 6641., 13431., 6793., 6871.], [ 13257., 26813., 13561., 27421., 13865., 28029., 14169., 14333., ], [6945., 14039., 7097., 14343., 7249., 14647., 7401., 7487.], [ 14473., 29245., 14777., 29853., 15081., 30461., 15385., 15565., ], [7553., 15255., 7705., 15559., 7857., 15863., 8009., 8103.], [7817., 15789., 7975., 16105., 8133., 16421., 8291., 8385.], ], ], ]])); } #[test] fn test_conv_transpose3d_groups_2() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 2, channels_out: 2, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 1, padding_2: 1, padding_3: 1, padding_out_1: 0, padding_out_2: 0, padding_out_3: 0, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 2, depth: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [[[96., 124.], [180., 208.]], [[348., 376.], [432., 460.]]], [ [[2997., 3089.], [3273., 3365.]], [[3825., 3917.], [4101., 4193.]], ], ]])); } #[test] fn test_conv_transpose3d_groups_different_channels() { let test = ConvTranspose3dTestCase { batch_size: 1, channels_in: 2, channels_out: 6, kernel_size_1: 3, kernel_size_2: 3, kernel_size_3: 3, padding_1: 0, padding_2: 0, padding_3: 0, padding_out_1: 0, padding_out_2: 0, padding_out_3: 0, stride_1: 1, stride_2: 1, stride_3: 1, dilation_1: 1, dilation_2: 1, dilation_3: 1, groups: 2, depth: 2, height: 2, width: 2, }; test.assert_output(TestTensor::from([[ [ [ [0., 0., 1., 2.], [0., 5., 11., 11.], [6., 23., 29., 23.], [12., 32., 37., 24.], ], [ [0., 13., 23., 21.], [30., 96., 124., 86.], [66., 180., 208., 134.], [66., 161., 179., 107.], ], [ [36., 103., 113., 75.], [138., 348., 376., 230.], [174., 432., 460., 278.], [138., 323., 341., 197.], ], [ [72., 166., 175., 100.], [192., 433., 455., 255.], [222., 499., 521., 291.], [144., 318., 331., 182.], ], ], [ [ [1., 28., 29., 30.], [55., 168., 174., 120.], [61., 186., 192., 132.], [67., 168., 173., 106.], ], [ [109., 284., 294., 184.], [355., 853., 881., 519.], [391., 937., 965., 567.], [283., 648., 666., 378.], ], [ [145., 374., 384., 238.], [463., 1105., 1133., 663.], [499., 1189., 1217., 711.], [355., 810., 828., 468.], ], [ [181., 410., 419., 236.], [463., 1028., 1050., 580.], [493., 1094., 1116., 616.], [307., 670., 683., 372.], ], ], [ [ [2., 56., 57., 58.], [110., 331., 337., 229.], [116., 349., 355., 241.], [122., 304., 309., 188.], ], [ [218., 555., 565., 347.], [680., 1610., 1638., 952.], [716., 1694., 1722., 1000.], [500., 1135., 1153., 649.], ], [ [254., 645., 655., 401.], [788., 1862., 1890., 1096.], [824., 1946., 1974., 1144.], [572., 1297., 1315., 739.], ], [ [290., 654., 663., 372.], [734., 1623., 1645., 905.], [764., 1689., 1711., 941.], [470., 1022., 1035., 562.], ], ], [ [ [651., 1388., 1405., 750.], [1485., 3150., 3188., 1690.], [1539., 3264., 3302., 1750.], [873., 1840., 1861., 982.], ], [ [1695., 3578., 3620., 1910.], [3789., 7967., 8059., 4233.], [3921., 8243., 8335., 4377.], [2181., 4566., 4616., 2416.], ], [ [1875., 3956., 3998., 2108.], [4185., 8795., 8887., 4665.], [4317., 9071., 9163., 4809.], [2397., 5016., 5066., 2650.], ], [ [1191., 2490., 2515., 1316.], [2613., 5450., 5504., 2870.], [2691., 5612., 5666., 2954.], [1473., 3062., 3091., 1608.], ], ], [ [ [868., 1848., 1865., 994.], [1972., 4177., 4215., 2231.], [2026., 4291., 4329., 2291.], [1144., 2408., 2429., 1280.], ], [ [2236., 4713., 4755., 2505.], [4978., 10452., 10544., 5530.], [5110., 10728., 10820., 5674.], [2830., 5917., 5967., 3119.], ], [ [2416., 5091., 5133., 2703.], [5374., 11280., 11372., 5962.], [5506., 11556., 11648., 6106.], [3046., 6367., 6417., 3353.], ], [ [1516., 3166., 3191., 1668.], [3316., 6909., 6963., 3627.], [3394., 7071., 7125., 3711.], [1852., 3846., 3875., 2014.], ], ], [ [ [1085., 2308., 2325., 1238.], [2459., 5204., 5242., 2772.], [2513., 5318., 5356., 2832.], [1415., 2976., 2997., 1578.], ], [ [2777., 5848., 5890., 3100.], [6167., 12937., 13029., 6827.], [6299., 13213., 13305., 6971.], [3479., 7268., 7318., 3822.], ], [ [2957., 6226., 6268., 3298.], [6563., 13765., 13857., 7259.], [6695., 14041., 14133., 7403.], [3695., 7718., 7768., 4056.], ], [ [1841., 3842., 3867., 2020.], [4019., 8368., 8422., 4384.], [4097., 8530., 8584., 4468.], [2231., 4630., 4659., 2420.], ], ], ]])); } struct ConvTranspose3dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, kernel_size_3: usize, padding_1: usize, padding_2: usize, padding_3: usize, padding_out_1: usize, padding_out_2: usize, padding_out_3: usize, stride_1: usize, stride_2: usize, stride_3: usize, dilation_1: usize, dilation_2: usize, dilation_3: usize, groups: usize, depth: usize, height: usize, width: usize, } impl ConvTranspose3dTestCase { fn assert_output(self, y: TestTensor<5>) { let shape_x = Shape::new([ self.batch_size, self.channels_in, self.depth, self.height, self.width, ]); let shape_weights = Shape::new([ self.channels_in, self.channels_out / self.groups, self.kernel_size_1, self.kernel_size_2, self.kernel_size_3, ]); let device = Default::default(); let weights = TestTensor::from( TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device) .reshape::<5, _>(shape_weights) .into_data(), ); let bias = TestTensor::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<5, _>(shape_x) .into_data(), ); let output = conv_transpose3d( x, weights, Some(bias), ConvTransposeOptions::new( [self.stride_1, self.stride_2, self.stride_3], [self.padding_1, self.padding_2, self.padding_3], [self.padding_out_1, self.padding_out_2, self.padding_out_3], [self.dilation_1, self.dilation_2, self.dilation_3], self.groups, ), ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/deform_conv2d.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::module::deform_conv2d; use burn_tensor::ops::DeformConvOptions; use burn_tensor::{Shape, Tensor}; #[test] fn test_deform_conv2d_simple() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 3, channels_out: 5, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [[0.9074, 0.6387], [0.5160, 0.4196]], [[2.4259, 1.8008], [1.5449, 1.3112]], [[3.9444, 2.9629], [2.5738, 2.2027]], [[5.4629, 4.1250], [3.6027, 3.0943]], [[6.9814, 5.2871], [4.6316, 3.9859]], ]])); } #[test] fn test_deform_conv2d_batched() { let test = DeformConv2dTestCase { batch_size: 2, channels_in: 3, channels_out: 5, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([ [ [[0.215466, 0.192846], [0.193407, 0.175496]], [[0.725073, 0.675926], [0.687746, 0.648506]], [[1.234679, 1.159006], [1.182085, 1.121516]], [[1.744286, 1.642086], [1.676423, 1.594526]], [[2.253892, 2.125167], [2.170762, 2.067536]], ], [ [[1.652976, 1.136937], [0.984030, 0.718403]], [[4.836801, 3.472453], [3.177263, 2.418021]], [[8.020626, 5.807969], [5.370497, 4.117639]], [[11.204453, 8.143486], [7.563731, 5.817256]], [[14.388277, 10.479003], [9.756965, 7.516875]], ], ])) } #[test] fn test_deform_conv2d_weight_groups() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 3, channels_out: 6, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 3, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [[0.101823, 0.065756], [0.046691, 0.036233]], [[0.412523, 0.336674], [0.306863, 0.282386]], [[1.307585, 1.024152], [0.902454, 0.800008]], [[1.840507, 1.458072], [1.299371, 1.158781]], [[3.402235, 2.634555], [2.305198, 2.014265]], [[4.157379, 3.231476], [2.838861, 2.485659]], ]])) } #[test] fn test_deform_conv2d_offset_groups() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 3, channels_out: 6, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 3, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [[1.0794, 0.7676], [0.7209, 0.5337]], [[2.7059, 2.0216], [1.9740, 1.5419]], [[4.3325, 3.2755], [3.2271, 2.5501]], [[5.9590, 4.5295], [4.4802, 3.5582]], [[7.5855, 5.7835], [5.7333, 4.5664]], [[9.2120, 7.0375], [6.9864, 5.5746]], ]])) } #[test] fn test_deform_conv2d_different_kernel_size() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 4, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [[1.0669], [0.6329]], [[2.9741], [2.0383]], [[4.8812], [3.4437]], ]])) } #[test] fn test_deform_conv2d_different_padding_size() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 2, padding_2: 3, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [ [ 0.199779, 0.376176, 0.528501, 0.605256, 0.384365, 0.198675, 0.048145, 0.000000, ], [ 0.287923, 0.551719, 0.777562, 0.890479, 0.580469, 0.304325, 0.079554, 0.000000, ], [ 0.372947, 0.721405, 1.013668, 1.151988, 0.756444, 0.393098, 0.101582, 0.000000, ], [ 0.132138, 0.324872, 0.495372, 0.584617, 0.453122, 0.250084, 0.075703, 0.000000, ], [ 0.059332, 0.160658, 0.244789, 0.297057, 0.239464, 0.132701, 0.047114, 0.000000, ], [ 0.014338, 0.051338, 0.078303, 0.094190, 0.081278, 0.041954, 0.014506, 0.000000, ], ], [ [ 0.766652, 1.164805, 1.521938, 1.711110, 1.230500, 0.807579, 0.450423, 0.333333, ], [ 0.981162, 1.601005, 2.152534, 2.440920, 1.745547, 1.091843, 0.536749, 0.333333, ], [ 1.196386, 2.044845, 2.785330, 3.152243, 2.242613, 1.351308, 0.604905, 0.333333, ], [ 0.669465, 1.178133, 1.644096, 1.902188, 1.573183, 1.033924, 0.553577, 0.333333, ], [ 0.495048, 0.786124, 1.039796, 1.204721, 1.052342, 0.743887, 0.483380, 0.333333, ], [ 0.378767, 0.498209, 0.592867, 0.654230, 0.615487, 0.488202, 0.390890, 0.333333, ], ], [ [ 1.333524, 1.953435, 2.515375, 2.816964, 2.076636, 1.416483, 0.852701, 0.666667, ], [ 1.674402, 2.650291, 3.527507, 3.991360, 2.910625, 1.879361, 0.993943, 0.666667, ], [ 2.019825, 3.368286, 4.556992, 5.152499, 3.728782, 2.309520, 1.108229, 0.666667, ], [ 1.206791, 2.031395, 2.792820, 3.219759, 2.693245, 1.817763, 1.031452, 0.666667, ], [ 0.930765, 1.411590, 1.834802, 2.112385, 1.865221, 1.355072, 0.919646, 0.666667, ], [ 0.743195, 0.945081, 1.107431, 1.214270, 1.149695, 0.934451, 0.767274, 0.666667, ], ], ]])) } #[test] fn test_deform_conv2d_different_stride() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 2, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 4, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [[1.0647], [0.5783]], [[2.9289], [1.8829]], [[4.7931], [3.1875]], ]])) } #[test] fn test_deform_conv2d_different_dilation() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 2, weight_groups: 1, offset_groups: 1, height: 5, width: 5, }; test.assert_output(TestTensor::<4>::from([[ [[0.6162], [0.7611], [0.4666]], [[1.8578], [2.2684], [1.6208]], [[3.0994], [3.7757], [2.7749]], ]])) } #[test] fn test_deform_conv2d_different_width() { let test = DeformConv2dTestCase { batch_size: 1, channels_in: 2, channels_out: 3, kernel_size_1: 3, kernel_size_2: 3, padding_1: 0, padding_2: 0, stride_1: 1, stride_2: 1, dilation_1: 1, dilation_2: 1, weight_groups: 1, offset_groups: 1, height: 6, width: 4, }; test.assert_output(TestTensor::<4>::from([[ [ [0.8909, 0.6016], [1.0697, 0.7186], [1.2618, 0.8433], [0.6424, 0.5032], ], [ [2.4670, 1.8168], [2.9529, 2.1497], [3.4805, 2.5090], [2.0925, 1.7411], ], [ [4.0432, 3.0321], [4.8362, 3.5809], [5.6992, 4.1746], [3.5425, 2.9790], ], ]])) } struct DeformConv2dTestCase { batch_size: usize, channels_in: usize, channels_out: usize, kernel_size_1: usize, kernel_size_2: usize, padding_1: usize, padding_2: usize, stride_1: usize, stride_2: usize, dilation_1: usize, dilation_2: usize, weight_groups: usize, offset_groups: usize, height: usize, width: usize, } impl DeformConv2dTestCase { fn assert_output(self, y: Tensor) { let out_height = (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1) / self.stride_1 + 1; let out_width = (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1) / self.stride_2 + 1; let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let shape_weight = Shape::new([ self.channels_out, self.channels_in / self.weight_groups, self.kernel_size_1, self.kernel_size_2, ]); let shape_offset = Shape::new([ self.batch_size, self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2, out_height, out_width, ]); let shape_mask = Shape::new([ self.batch_size, self.kernel_size_1 * self.kernel_size_2 * self.offset_groups, out_height, out_width, ]); let device = Default::default(); let weight = TestTensor::<4>::from( TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) .reshape::<4, _>(shape_weight.clone()) .into_data(), ) .div_scalar(shape_weight.num_elements() as f32); let bias = TestTensor::<1>::from( TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), ) .div_scalar(self.channels_out as f32); let x = TestTensor::<4>::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) .reshape::<4, _>(shape_x.clone()) .into_data(), ) .div_scalar(shape_x.num_elements() as f32); let offset = TestTensor::<4>::from( TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device) .reshape::<4, _>(shape_offset.clone()) .into_data(), ) .div_scalar(shape_offset.num_elements() as f32); let mask = TestTensor::<4>::from( TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device) .reshape::<4, _>(shape_mask.clone()) .into_data(), ) .div_scalar(shape_mask.num_elements() as f32); let output = deform_conv2d( x, offset, weight, Some(mask), Some(bias), DeformConvOptions::new( [self.stride_1, self.stride_2], [self.padding_1, self.padding_2], [self.dilation_1, self.dilation_2], self.weight_groups, self.offset_groups, ), ); let tolerance = Tolerance::permissive(); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/forward.rs ================================================ use super::*; use burn_tensor::{TensorData, module::embedding}; #[test] fn test_embedding_forward() { let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TensorData::from([[0, 1], [1, 1]]); let weights = TestTensor::<2>::from(weights); let indices = TestTensorInt::<2>::from(indices); let output = embedding(weights, indices); let expected = TensorData::from([ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/lanczos3_interpolate.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::interpolate; use burn_tensor::ops::{InterpolateMode, InterpolateOptions}; #[test] fn test_upsample_interpolation() { let test = InterpolateTestCase { batch_size: 2, channels: 1, height: 7, width: 5, height_out: 8, width_out: 7, }; test.assert_output(TestTensor::from([ [[ [-0.0000, 0.5685, 1.3918, 2.0000, 2.6082, 3.4315, 4.0000], [4.0822, 4.6507, 5.4740, 6.0822, 6.6904, 7.5137, 8.0822], [8.7971, 9.3656, 10.1889, 10.7971, 11.4053, 12.2286, 12.7971], [ 12.8964, 13.4649, 14.2882, 14.8964, 15.5046, 16.3279, 16.8964, ], [ 17.1036, 17.6721, 18.4954, 19.1036, 19.7118, 20.5351, 21.1036, ], [ 21.2029, 21.7715, 22.5947, 23.2029, 23.8112, 24.6344, 25.2029, ], [ 25.9178, 26.4863, 27.3096, 27.9178, 28.5260, 29.3493, 29.9178, ], [ 30.0000, 30.5685, 31.3918, 32.0000, 32.6082, 33.4315, 34.0000, ], ]], [[ [ 35.0000, 35.5685, 36.3918, 37.0000, 37.6082, 38.4315, 39.0000, ], [ 39.0822, 39.6507, 40.4740, 41.0822, 41.6904, 42.5137, 43.0822, ], [ 43.7971, 44.3656, 45.1888, 45.7971, 46.4053, 47.2286, 47.7971, ], [ 47.8964, 48.4649, 49.2882, 49.8964, 50.5046, 51.3279, 51.8964, ], [ 52.1036, 52.6721, 53.4954, 54.1036, 54.7118, 55.5351, 56.1036, ], [ 56.2029, 56.7715, 57.5947, 58.2029, 58.8112, 59.6344, 60.2029, ], [ 60.9178, 61.4863, 62.3096, 62.9178, 63.5260, 64.3493, 64.9178, ], [ 65.0000, 65.5685, 66.3918, 67.0000, 67.6082, 68.4315, 69.0000, ], ]], ])); } #[test] fn test_downsample_interpolation() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 45, width: 14, height_out: 4, width_out: 6, }; test.assert_output(TestTensor::from([[[ [-0.0000, 2.6107, 5.1803, 7.8197, 10.3893, 13.0000], [205.5606, 208.1713, 210.7408, 213.3802, 215.9498, 218.5606], [410.4395, 413.0502, 415.6198, 418.2592, 420.8287, 423.4395], [616.0000, 618.6107, 621.1803, 623.8197, 626.3893, 629.0000], ]]])); } #[test] fn test_upsample_2x() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 4, width: 4, height_out: 8, width_out: 8, }; test.assert_output(TestTensor::from([[[ [ -0.0000, 0.2972, 0.8164, 1.3131, 1.6869, 2.1836, 2.7028, 3.0000, ], [ 1.1889, 1.4861, 2.0053, 2.5020, 2.8758, 3.3725, 3.8917, 4.1889, ], [ 3.2658, 3.5630, 4.0822, 4.5789, 4.9527, 5.4493, 5.9685, 6.2658, ], [ 5.2524, 5.5496, 6.0689, 6.5655, 6.9393, 7.4360, 7.9552, 8.2524, ], [ 6.7476, 7.0448, 7.5640, 8.0607, 8.4345, 8.9311, 9.4504, 9.7476, ], [ 8.7342, 9.0315, 9.5507, 10.0473, 10.4211, 10.9178, 11.4370, 11.7342, ], [ 10.8111, 11.1083, 11.6275, 12.1242, 12.4980, 12.9947, 13.5139, 13.8111, ], [ 12.0000, 12.2972, 12.8164, 13.3131, 13.6869, 14.1836, 14.7028, 15.0000, ], ]]])); } #[test] fn test_upsample_half_pixel() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 4, width: 4, height_out: 8, width_out: 8, }; test.assert_output_with_align_corners( TestTensor::from([[[ [ -0.4626, -0.2276, 0.3055, 0.9087, 1.3512, 1.9543, 2.4875, 2.7225, ], [ 0.4773, 0.7123, 1.2454, 1.8486, 2.2911, 2.8942, 3.4274, 3.6623, ], [ 2.6099, 2.8449, 3.3780, 3.9812, 4.4237, 5.0268, 5.5600, 5.7949, ], [ 5.0224, 5.2574, 5.7906, 6.3937, 6.8362, 7.4394, 7.9725, 8.2075, ], [ 6.7925, 7.0275, 7.5606, 8.1638, 8.6063, 9.2094, 9.7426, 9.9776, ], [ 9.2051, 9.4400, 9.9732, 10.5763, 11.0188, 11.6220, 12.1551, 12.3901, ], [ 11.3377, 11.5726, 12.1058, 12.7089, 13.1514, 13.7546, 14.2877, 14.5227, ], [ 12.2775, 12.5125, 13.0457, 13.6488, 14.0913, 14.6945, 15.2276, 15.4626, ], ]]]), false, ); } #[test] fn test_1d_lanczos3() { let device = Default::default(); let input = TestTensor::<3>::from_floats( [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], &device, ); let input = input.unsqueeze_dim(2); let output = interpolate( input, [1, 9], InterpolateOptions::new(InterpolateMode::Lanczos3), ); assert_eq!(output.dims(), [1, 1, 1, 9]); assert!( !output .clone() .to_data() .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), "interpolate output contains NaN" ); TestTensor::<4>::from([[[[ 1.5410, 0.7266, -1.1387, -2.2672, -0.7894, 0.6408, -0.4967, -1.4650, -1.3986, ]]]]) .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::permissive()); } struct InterpolateTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl InterpolateTestCase { fn assert_output(self, y: TestTensor<4>) { self.assert_output_with_align_corners(y, true); } fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data(), ); let output = interpolate( x, [self.height_out, self.width_out], InterpolateOptions::new(InterpolateMode::Lanczos3).with_align_corners(align_corners), ); let tolerance = Tolerance::permissive(); y.to_data() .assert_approx_eq::(&output.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/linear.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use burn_tensor::module::linear; #[test] fn test_linear_1d() { let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let x = TestTensor::<1>::from([1.0, 2.0]); let output = linear(x, weight, None); let expected = TensorData::from([7.0, 10.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::relative(1e-5)); } #[test] fn test_linear_1d_one_element_output() { let weight = TestTensor::<2>::from([[3.0], [4.0]]); let x = TestTensor::<1>::from([1.0, 2.0]); let output = linear(x, weight, None); let expected = TensorData::from([11.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::relative(1e-5)); } #[test] fn test_linear_forward_no_bias() { let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]); let output = linear(x, weight, None); let expected = TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::relative(1e-5)); } #[test] fn test_linear_forward_with_bias() { let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let bias = Some(TestTensor::<1>::from([1.0, -1.0])); let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]); let output = linear(x, weight, bias); let expected = TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::relative(1e-5)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/maxpool1d.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use burn_tensor::module::{max_pool1d, max_pool1d_with_indices}; #[test] fn test_max_pool1d_simple() { let kernel_size = 3; let padding = 0; let stride = 1; let dilation = 1; let x = TestTensor::from([[ [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], ]]); let y = TestTensor::<3>::from([[ [0.9861, 0.5474, 0.4477, 0.8221], [0.949, 0.949, 0.949, 0.789], ]]); let output = max_pool1d(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool1d_different_padding_stride_kernel() { let kernel_size = 3; let padding = 1; let stride = 2; let dilation = 1; let x = TestTensor::from([[[0.6309, 0.6112, 0.6998, 0.4708]]]); let y = TestTensor::<3>::from([[[0.6309, 0.6998]]]); let output = max_pool1d(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool1d_with_neg() { let kernel_size = 3; let padding = 1; let stride = 1; let dilation = 1; let x = TestTensor::from([[[-0.6309, -0.6112, -0.6998, -0.4708]]]); let y = TestTensor::<3>::from([[[-0.6112, -0.6112, -0.4708, -0.4708]]]); let output = max_pool1d(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool1d_with_dilation() { let kernel_size = 2; let padding = 1; let stride = 1; let dilation = 2; let x = TestTensor::from([[ [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], ]]); let y = TestTensor::<3>::from([[ [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548], [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537], ]]); let output = max_pool1d(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool1d_with_indices() { let kernel_size = 2; let padding = 0; let stride = 1; let dilation = 1; let x = TestTensor::from([[[0.2479, 0.6386, 0.3166, 0.5742]]]); let indices = TensorData::from([[[1, 1, 3]]]); let y = TestTensor::<3>::from([[[0.6386, 0.6386, 0.5742]]]); let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices.into_data().assert_eq(&indices, false); } #[test] fn test_max_pool1d_complex() { let kernel_size = 4; let padding = 2; let stride = 1; let dilation = 1; let x = TestTensor::from([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]); let indices = TensorData::from([[[0, 2, 3, 3, 3, 3]]]); let y = TestTensor::<3>::from([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]); let (output, output_indices) = max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices.into_data().assert_eq(&indices, false); } #[test] fn test_max_pool1d_ceil_mode() { // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride // Input: 1x1x6, kernel: 3, stride: 2, padding: 0 // Floor mode: output = (6-3)/2+1 = 2 elements // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements let kernel_size = 3; let padding = 0; let stride = 2; let dilation = 1; let x = TestTensor::from([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]]); // With ceil_mode=false (floor): output is 2 elements // Window 0: positions [0:3] -> max(1,2,3) = 3 // Window 1: positions [2:5] -> max(3,4,5) = 5 let y_floor = TestTensor::<3>::from([[[3.0, 5.0]]]); let output_floor = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false); y_floor .to_data() .assert_approx_eq::(&output_floor.into_data(), Tolerance::default()); // With ceil_mode=true: output is 3 elements // Window 0: positions [0:3] -> max(1,2,3) = 3 // Window 1: positions [2:5] -> max(3,4,5) = 5 // Window 2: positions [4:7] -> max(5,6) = 6 (partial window) let y_ceil = TestTensor::<3>::from([[[3.0, 5.0, 6.0]]]); let output_ceil = max_pool1d(x, kernel_size, stride, padding, dilation, true); y_ceil .to_data() .assert_approx_eq::(&output_ceil.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/maxpool2d.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use burn_tensor::module::{max_pool2d, max_pool2d_with_indices}; #[test] fn test_max_pool2d_simple() { let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 1; let padding_2 = 1; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 1; let dilation_2 = 1; let x = TestTensor::from([ [ [ [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221], [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689], [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182], [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392], [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605], [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068], ], [ [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473], [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947], [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149], [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426], [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544], [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572], ], ], [ [ [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910], [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546], [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584], [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441], [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881], [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467], ], [ [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952], [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454], [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628], [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527], [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827], [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421], ], ], ]); let y = TestTensor::<4>::from([ [ [ [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], ], [ [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947], [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149], [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149], [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149], [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025], [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775], ], ], [ [ [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546], [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546], [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037], [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037], ], [ [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378], [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378], [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445], [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128], [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128], [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128], ], ], ]); let output = max_pool2d( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool2d_different_padding_stride_kernel() { let kernel_size_1 = 3; let kernel_size_2 = 1; let padding_1 = 1; let padding_2 = 0; let stride_1 = 1; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; let x = TestTensor::from([[[ [0.6309, 0.6112, 0.6998], [0.4708, 0.9161, 0.5402], [0.4577, 0.7397, 0.9870], [0.6380, 0.4352, 0.5884], [0.6277, 0.5139, 0.4525], [0.9333, 0.9846, 0.5006], ]]]); let y = TestTensor::<4>::from([[[ [0.6309, 0.6998], [0.6309, 0.9870], [0.6380, 0.9870], [0.6380, 0.9870], [0.9333, 0.5884], [0.9333, 0.5006], ]]]); let output = max_pool2d( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool2d_with_neg() { let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 1; let padding_2 = 1; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 1; let dilation_2 = 1; let x = TestTensor::from([[[ [0.6309, 0.6112, 0.6998], [0.4708, 0.9161, 0.5402], [0.4577, 0.7397, 0.9870], [0.6380, 0.4352, 0.5884], [0.6277, 0.5139, 0.4525], [0.9333, 0.9846, 0.5006], ]]]) .neg(); let y = TestTensor::<4>::from([[[ [-0.4708, -0.4708, -0.5402], [-0.4577, -0.4577, -0.5402], [-0.4352, -0.4352, -0.4352], [-0.4352, -0.4352, -0.4352], [-0.4352, -0.4352, -0.4352], [-0.5139, -0.4525, -0.4525], ]]]); let output = max_pool2d( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool2d_with_dilation() { let kernel_size_1 = 2; let kernel_size_2 = 2; let padding_1 = 0; let padding_2 = 0; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 2; let dilation_2 = 2; let x = TestTensor::from([[[ [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221], [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111], [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855], [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855], [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855], ]]]); let y = TestTensor::<4>::from([[[ [0.9861, 0.9861, 0.9540, 0.9490], [0.9861, 0.9861, 0.9540, 0.9490], [0.9540, 0.9540, 0.9540, 0.9490], [0.9540, 0.9540, 0.9540, 0.9432], ]]]); let output = max_pool2d( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } #[test] fn test_max_pool2d_with_indices() { let kernel_size_1 = 2; let kernel_size_2 = 2; let padding_1 = 1; let padding_2 = 1; let stride_1 = 1; let stride_2 = 1; let dilation_1 = 1; let dilation_2 = 1; let x = TestTensor::from([[[ [0.2479, 0.6386, 0.3166, 0.5742], [0.7065, 0.1940, 0.6305, 0.8959], [0.5416, 0.8602, 0.8129, 0.1662], [0.3358, 0.3059, 0.8293, 0.0990], ]]]); let indices = TensorData::from([[[ [0, 1, 1, 3, 3], [4, 4, 1, 7, 7], [4, 9, 9, 7, 7], [8, 9, 9, 14, 11], [12, 12, 14, 14, 15], ]]]); let y = TestTensor::<4>::from([[[ [0.2479, 0.6386, 0.6386, 0.5742, 0.5742], [0.7065, 0.7065, 0.6386, 0.8959, 0.8959], [0.7065, 0.8602, 0.8602, 0.8959, 0.8959], [0.5416, 0.8602, 0.8602, 0.8293, 0.1662], [0.3358, 0.3358, 0.8293, 0.8293, 0.0990], ]]]); let (output, output_indices) = max_pool2d_with_indices( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices.into_data().assert_eq(&indices, false); } #[test] fn test_max_pool2d_complex() { let kernel_size_1 = 4; let kernel_size_2 = 2; let padding_1 = 2; let padding_2 = 1; let stride_1 = 1; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; let x = TestTensor::from([[[ [0.5388, 0.0676, 0.7122, 0.8316, 0.0653], [0.9154, 0.1536, 0.9089, 0.8016, 0.7518], [0.2073, 0.0501, 0.8811, 0.5604, 0.5075], [0.4384, 0.9963, 0.9698, 0.4988, 0.2609], [0.3391, 0.2230, 0.4610, 0.5365, 0.6880], ]]]); let indices = TensorData::from([[[ [5, 7, 3], [5, 7, 3], [5, 16, 3], [5, 16, 8], [15, 16, 24], [15, 16, 24], ]]]); let y = TestTensor::<4>::from([[[ [0.9154, 0.9089, 0.8316], [0.9154, 0.9089, 0.8316], [0.9154, 0.9963, 0.8316], [0.9154, 0.9963, 0.8016], [0.4384, 0.9963, 0.688], [0.4384, 0.9963, 0.688], ]]]); let (output, output_indices) = max_pool2d_with_indices( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices.into_data().assert_eq(&indices, false); } #[test] fn test_max_pool2d_ceil_mode() { // Test ceil_mode=true which produces larger output when input doesn't divide evenly by stride // Using 1x1x6x6 with kernel 3x3, stride 2x2, padding 0: // Floor mode: output = (6+0-1*(3-1)-1)/2+1 = 3/2+1 = 2 x 2 // Ceil mode: output = ceil(3/2)+1 = 2+1 = 3 x 3 let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 0; let padding_2 = 0; let stride_1 = 2; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; // Input (values 1-36 arranged row by row): // col: 0 1 2 3 4 5 // row 0: 1 2 3 4 5 6 // row 1: 7 8 9 10 11 12 // row 2: 13 14 15 16 17 18 // row 3: 19 20 21 22 23 24 // row 4: 25 26 27 28 29 30 // row 5: 31 32 33 34 35 36 let x = TestTensor::from([[[ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0, 17.0, 18.0], [19.0, 20.0, 21.0, 22.0, 23.0, 24.0], [25.0, 26.0, 27.0, 28.0, 29.0, 30.0], [31.0, 32.0, 33.0, 34.0, 35.0, 36.0], ]]]); // With ceil_mode=false (floor): output is 2x2 // (0,0): rows 0-2, cols 0-2 -> max(1,2,3,7,8,9,13,14,15) = 15 // (0,1): rows 0-2, cols 2-4 -> max(3,4,5,9,10,11,15,16,17) = 17 // (1,0): rows 2-4, cols 0-2 -> max(13,14,15,19,20,21,25,26,27) = 27 // (1,1): rows 2-4, cols 2-4 -> max(15,16,17,21,22,23,27,28,29) = 29 let y_floor = TestTensor::<4>::from([[[[15.0, 17.0], [27.0, 29.0]]]]); let output_floor = max_pool2d( x.clone(), [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], false, ); y_floor .to_data() .assert_approx_eq::(&output_floor.into_data(), Tolerance::default()); // With ceil_mode=true: output is 3x3 // Extra windows at edges use only available input values (padded with -inf for max pooling) // (0,0): rows 0-2, cols 0-2 -> max = 15 // (0,1): rows 0-2, cols 2-4 -> max = 17 // (0,2): rows 0-2, cols 4-5 -> max(5,6,11,12,17,18) = 18 // (1,0): rows 2-4, cols 0-2 -> max = 27 // (1,1): rows 2-4, cols 2-4 -> max = 29 // (1,2): rows 2-4, cols 4-5 -> max(17,18,23,24,29,30) = 30 // (2,0): rows 4-5, cols 0-2 -> max(25,26,27,31,32,33) = 33 // (2,1): rows 4-5, cols 2-4 -> max(27,28,29,33,34,35) = 35 // (2,2): rows 4-5, cols 4-5 -> max(29,30,35,36) = 36 let y_ceil = TestTensor::<4>::from([[[[15.0, 17.0, 18.0], [27.0, 29.0, 30.0], [33.0, 35.0, 36.0]]]]); let output_ceil = max_pool2d( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], true, ); y_ceil .to_data() .assert_approx_eq::(&output_ceil.into_data(), Tolerance::default()); } #[test] fn test_max_pool2d_ceil_mode_with_indices() { // Test ceil_mode=true with indices to verify correct index calculation // when pooling windows extend beyond original input bounds let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 0; let padding_2 = 0; let stride_1 = 2; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; // Input 6x6 (indices 0-35 in row-major order): // row 0: 0 1 2 3 4 5 // row 1: 6 7 8 9 10 11 // row 2: 12 13 14 15 16 17 // row 3: 18 19 20 21 22 23 // row 4: 24 25 26 27 28 29 // row 5: 30 31 32 33 34 35 let x = TestTensor::from([[[ [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0, 16.0, 17.0], [18.0, 19.0, 20.0, 21.0, 22.0, 23.0], [24.0, 25.0, 26.0, 27.0, 28.0, 29.0], [30.0, 31.0, 32.0, 33.0, 34.0, 35.0], ]]]); // With ceil_mode=true: output is 3x3 // (0,0): rows 0-2, cols 0-2 -> max at index 14 // (0,1): rows 0-2, cols 2-4 -> max at index 16 // (0,2): rows 0-2, cols 4-5 -> max at index 17 // (1,0): rows 2-4, cols 0-2 -> max at index 26 // (1,1): rows 2-4, cols 2-4 -> max at index 28 // (1,2): rows 2-4, cols 4-5 -> max at index 29 // (2,0): rows 4-5, cols 0-2 -> max at index 32 // (2,1): rows 4-5, cols 2-4 -> max at index 34 // (2,2): rows 4-5, cols 4-5 -> max at index 35 let expected_values = TestTensor::<4>::from([[[[14.0, 16.0, 17.0], [26.0, 28.0, 29.0], [32.0, 34.0, 35.0]]]]); let expected_indices = TensorData::from([[[[14i64, 16, 17], [26, 28, 29], [32, 34, 35]]]]); let (output, output_indices) = max_pool2d_with_indices( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], true, ); expected_values .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices .into_data() .assert_eq(&expected_indices, false); } #[test] fn test_max_pool2d_ceil_mode_with_indices_and_padding() { // Test ceil_mode=true with padding and indices to verify correct index calculation // This exercises the case where both user padding and ceil_mode extra padding apply let kernel_size_1 = 3; let kernel_size_2 = 3; let padding_1 = 1; let padding_2 = 1; let stride_1 = 2; let stride_2 = 2; let dilation_1 = 1; let dilation_2 = 1; // Input 5x5 (indices 0-24 in row-major order): // row 0: 0 1 2 3 4 // row 1: 5 6 7 8 9 // row 2: 10 11 12 13 14 // row 3: 15 16 17 18 19 // row 4: 20 21 22 23 24 let x = TestTensor::from([[[ [0.0, 1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0, 9.0], [10.0, 11.0, 12.0, 13.0, 14.0], [15.0, 16.0, 17.0, 18.0, 19.0], [20.0, 21.0, 22.0, 23.0, 24.0], ]]]); // With padding=1, ceil_mode=true: // Effective input is 7x7 (5 + 2*1) // Output size: ceil((5 + 2*1 - 3) / 2) + 1 = ceil(4/2) + 1 = 3 // // Windows (with -inf padding at boundaries): // (0,0): rows -1 to 1, cols -1 to 1 -> valid: (0,0) to (1,1), max at (1,1)=6 // (0,1): rows -1 to 1, cols 1 to 3 -> max at (1,3)=8 // (0,2): rows -1 to 1, cols 3 to 5 -> max at (1,4)=9 // (1,0): rows 1 to 3, cols -1 to 1 -> max at (3,1)=16 // (1,1): rows 1 to 3, cols 1 to 3 -> max at (3,3)=18 // (1,2): rows 1 to 3, cols 3 to 5 -> max at (3,4)=19 // (2,0): rows 3 to 5, cols -1 to 1 -> max at (4,1)=21 // (2,1): rows 3 to 5, cols 1 to 3 -> max at (4,3)=23 // (2,2): rows 3 to 5, cols 3 to 5 -> max at (4,4)=24 let expected_values = TestTensor::<4>::from([[[[6.0, 8.0, 9.0], [16.0, 18.0, 19.0], [21.0, 23.0, 24.0]]]]); let expected_indices = TensorData::from([[[[6i64, 8, 9], [16, 18, 19], [21, 23, 24]]]]); let (output, output_indices) = max_pool2d_with_indices( x, [kernel_size_1, kernel_size_2], [stride_1, stride_2], [padding_1, padding_2], [dilation_1, dilation_2], true, ); expected_values .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); output_indices .into_data() .assert_eq(&expected_indices, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/mod.rs ================================================ use super::*; mod adaptive_avgpool1d; mod adaptive_avgpool2d; mod attention; mod avgpool1d; mod avgpool2d; mod bicubic_interpolate; mod bilinear_interpolate; mod conv1d; mod conv2d; mod conv3d; mod conv_transpose1d; mod conv_transpose2d; mod conv_transpose3d; mod deform_conv2d; mod forward; mod lanczos3_interpolate; mod linear; mod maxpool1d; mod maxpool2d; mod nearest_interpolate; mod unfold4d; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/nearest_interpolate.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::interpolate; use burn_tensor::ops::{InterpolateMode, InterpolateOptions}; #[test] fn test_upsample_interpolation() { let test = InterpolateTestCase { batch_size: 2, channels: 1, height: 7, width: 5, height_out: 8, width_out: 7, }; test.assert_output(TestTensor::from([ [[ [0., 0., 1., 2., 2., 3., 4.], [0., 0., 1., 2., 2., 3., 4.], [5., 5., 6., 7., 7., 8., 9.], [10., 10., 11., 12., 12., 13., 14.], [15., 15., 16., 17., 17., 18., 19.], [20., 20., 21., 22., 22., 23., 24.], [25., 25., 26., 27., 27., 28., 29.], [30., 30., 31., 32., 32., 33., 34.], ]], [[ [35., 35., 36., 37., 37., 38., 39.], [35., 35., 36., 37., 37., 38., 39.], [40., 40., 41., 42., 42., 43., 44.], [45., 45., 46., 47., 47., 48., 49.], [50., 50., 51., 52., 52., 53., 54.], [55., 55., 56., 57., 57., 58., 59.], [60., 60., 61., 62., 62., 63., 64.], [65., 65., 66., 67., 67., 68., 69.], ]], ])); } #[test] fn test_downsample_interpolation() { let test = InterpolateTestCase { batch_size: 1, channels: 1, height: 45, width: 14, height_out: 4, width_out: 6, }; test.assert_output(TestTensor::from([[[ [0., 2., 4., 7., 9., 11.], [154., 156., 158., 161., 163., 165.], [308., 310., 312., 315., 317., 319.], [462., 464., 466., 469., 471., 473.], ]]])); } #[test] fn test_1d_nearest() { // Initialize the model without weights (because the exported file does not contain them) let device = Default::default(); // Run the model let input = TestTensor::<3>::from_floats( [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], &device, ); let input = input.unsqueeze_dim(2); let output = interpolate( input, [1, 9], InterpolateOptions::new(InterpolateMode::Nearest), ); assert_eq!(output.dims(), [1, 1, 1, 9]); // assert output data does not contain NaN assert!( !output .clone() .to_data() .as_slice::() .unwrap() .iter() .any(|&x| x.is_nan()), "interpolate output contains NaN" ); TestTensor::<4>::from([[[[ 1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986, ]]]]) .to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } struct InterpolateTestCase { batch_size: usize, channels: usize, height: usize, width: usize, height_out: usize, width_out: usize, } impl InterpolateTestCase { fn assert_output(self, y: TestTensor<4>) { let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device()) .reshape::<4, _>(shape_x) .into_data() .convert::(), ); let output = interpolate( x, [self.height_out, self.width_out], InterpolateOptions::new(InterpolateMode::Nearest), ); y.to_data() .assert_approx_eq::(&output.into_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/module/unfold4d.rs ================================================ use super::*; use burn_tensor::Shape; use burn_tensor::Tolerance; use burn_tensor::module::unfold4d; use burn_tensor::ops::UnfoldOptions; #[test] fn test_unfold4d_shape() { let test = Unfold4dTestCase { batch_size: 2, channels_in: 5, kernel_size: [2, 3], padding: [0, 0], stride: [1, 1], dilation: [1, 1], height: 3, width: 4, }; test.assert_shape([2, 30, 4]); } #[test] fn test_unfold4d_simple() { let test = Unfold4dTestCase { batch_size: 1, channels_in: 2, kernel_size: [2, 2], padding: [0, 0], stride: [1, 1], dilation: [1, 1], height: 4, width: 4, }; test.assert_output(TestTensor::from([[ [0., 1., 2., 4., 5., 6., 8., 9., 10.], [1., 2., 3., 5., 6., 7., 9., 10., 11.], [4., 5., 6., 8., 9., 10., 12., 13., 14.], [5., 6., 7., 9., 10., 11., 13., 14., 15.], [16., 17., 18., 20., 21., 22., 24., 25., 26.], [17., 18., 19., 21., 22., 23., 25., 26., 27.], [20., 21., 22., 24., 25., 26., 28., 29., 30.], [21., 22., 23., 25., 26., 27., 29., 30., 31.], ]])); } #[test] fn test_unfold4d_complex() { let test = Unfold4dTestCase { batch_size: 1, channels_in: 2, kernel_size: [2, 3], padding: [0, 1], stride: [1, 2], dilation: [1, 2], height: 3, width: 4, }; test.assert_output(TestTensor::from([[ [0., 0.], [1., 5.], [3., 7.], [0., 0.], [5., 9.], [7., 11.], [0., 0.], [13., 17.], [15., 19.], [0., 0.], [17., 21.], [19., 23.], ]])); } struct Unfold4dTestCase { batch_size: usize, channels_in: usize, kernel_size: [usize; 2], padding: [usize; 2], stride: [usize; 2], dilation: [usize; 2], height: usize, width: usize, } impl Unfold4dTestCase { fn assert_shape(self, expected_shape: [usize; 3]) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default()) .reshape::<4, _>(shape_x) .into_data() .convert::(), ); let output = unfold4d( x, self.kernel_size, UnfoldOptions::new(self.stride, self.padding, self.dilation), ); assert_eq!( output.shape().as_slice(), expected_shape, "Expected shape doesn't match the actual shape" ); } fn assert_output(self, expected: TestTensor<3>) { let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); let x = TestTensor::from( TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default()) .reshape::<4, _>(shape_x) .into_data(), ); let output = unfold4d( x, self.kernel_size, UnfoldOptions::new(self.stride, self.padding, self.dilation), ); let tolerance = Tolerance::default() .set_half_precision_relative(2e-3) .set_half_precision_absolute(2e-3); output .into_data() .assert_approx_eq::(&expected.into_data(), tolerance); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/abs.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_abs_ops_float() { let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); let output = tensor.abs(); output .into_data() .assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/add.rs ================================================ use super::*; use burn_tensor::{TensorData, backend::Backend}; #[test] fn test_add_d2() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); let output = tensor_1 + tensor_2; output.into_data().assert_eq( &TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]), false, ); } #[test] fn test_add_broadcast() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0]]); let tensor_2 = TestTensor::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let output = tensor_1 + tensor_2; output.into_data().assert_eq( &TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]), false, ); } #[test] fn test_add_different_strides_rhs() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1; let output = tensor_1 + tensor_2.transpose(); output .into_data() .assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false); } #[test] fn test_add_different_strides_lhs() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1; let output = tensor_1.transpose() + tensor_2; output .into_data() .assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false); } #[test] fn test_add_different_strides_broadcast() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = TestTensor::from([[4.0, 5.0]]) * 1; let output = tensor_1.transpose() + tensor_2; output .into_data() .assert_eq(&TensorData::from([[4.0, 7.0], [5.0, 8.0]]), false); } #[test] fn should_support_add_scalar_ops() { let scalar = 2.0; let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor + scalar; output .into_data() .assert_eq(&TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]), false); } #[test] fn add_maybe_fused_not_contiguous() { let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor2 = TestTensorInt::arange(8..16, &Default::default()).float(); let tensor1 = tensor1.reshape([2, 4]); let tensor2 = tensor2.reshape([4, 2]); let tensor2 = tensor2.swap_dims(0, 1); TestBackend::sync(&tensor2.device()).unwrap(); let output = tensor1 + tensor2; output.into_data().assert_eq( &TensorData::from([[8.0, 11.0, 14.0, 17.0], [13.0, 16.0, 19.0, 22.0]]), false, ); } #[test] fn add_maybe_fused_not_contiguous_broadcasted() { let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor2 = TestTensorInt::arange(8..10, &Default::default()).float(); let tensor1 = tensor1.reshape([2, 4]); let tensor2 = tensor2.reshape([1, 2]); let tensor2 = tensor2.swap_dims(0, 1); TestBackend::sync(&tensor2.device()).unwrap(); let output = tensor2 + tensor1; output.into_data().assert_eq( &TensorData::from([[8.0, 9.0, 10.0, 11.0], [13.0, 14.0, 15.0, 16.0]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/aggregation.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use burn_tensor::backend::Backend; #[test] fn test_should_mean() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.mean(); let expected = TensorData::from([15.0 / 6.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_should_sum() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sum(); output .into_data() .assert_eq(&TensorData::from([15.0]), false); } #[test] fn test_should_sum_dim_maybe_fused() { let tensor = TestTensor::<2>::from([[5.0], [-12.0]]); let tensor1 = TestTensor::<2>::from([[2.0, 3.0], [-1.0, -5.0]]); let ones = TestTensor::<2>::ones([2, 2], &Default::default()); let _x = ones.clone() * tensor; let y = ones * tensor1; let output = y.clone().sum_dim(1); output .into_data() .assert_eq(&TensorData::from([[5.0], [-6.0]]), false); // Negative Indexing. let output = y.clone().sum_dim(-1); output .into_data() .assert_eq(&TensorData::from([[5.0], [-6.0]]), false); } #[test] fn test_should_mean_last_dim() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.clone().mean_dim(1); let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); // Negative Indexing. let output = tensor.clone().mean_dim(-1); let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_should_sum_last_dim() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sum_dim(1); output .into_data() .assert_eq(&TensorData::from([[3.0], [12.0]]), false); } #[test] fn test_should_sum_first_dim() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); let output = tensor.sum_dim(0); output .into_data() .assert_eq(&TensorData::from([[7.0, 3.0, 5.0]]), false); } #[test] fn test_should_mean_first_dim() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); let output = tensor.mean_dim(0); output.into_data().assert_eq( &TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]), false, ); } #[test] fn test_should_sum_mid_dim_3d_non_contiguous_1() { let tensor = TestTensor::<3>::from([ [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], ]); let output = tensor.swap_dims(0, 2).sum_dim(1); output.into_data().assert_eq( &TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]), false, ); } #[test] fn test_should_sum_mid_dim_3d_non_contiguous_2() { let tensor = TestTensor::<3>::from([ [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], ]); let output = tensor.swap_dims(0, 1).sum_dim(1); output.into_data().assert_eq( &TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]), false, ); } #[test] fn test_prod_float() { let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.prod(); // 2 * 1 * 2 * 3 * 4 * 5 = 240 but we need to check the precision because of the float let expected = TensorData::from([240.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_with_zero.prod(); output .into_data() .assert_eq(&TensorData::from([0.0]), false); } #[test] fn test_prod_dim_float() { let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.prod_dim(1); let expected = TensorData::from([[4.0], [60.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_with_zero.prod_dim(1); let expected = TensorData::from([[0.0], [60.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_sum_dim_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let output = tensor.clone().sum_dim(1); let expected = TensorData::from([[3.], [12.]]); output.into_data().assert_eq(&expected, false); let output = tensor.sum_dim(0); let expected = TensorData::from([[3., 5., 7.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dims_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); tensor .clone() .sum_dims(&[1]) .to_data() .assert_eq(&TensorData::from([[3.], [12.]]), false); tensor .clone() .sum_dims(&[-1]) .to_data() .assert_eq(&TensorData::from([[3.], [12.]]), false); tensor .clone() .sum_dims(&[0, 1]) .to_data() .assert_eq(&TensorData::from([[15.]]), false); } #[test] fn test_sum_and_squeeze_dims() { let tensor = TestTensor::<3>::from_floats( [ [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]], ], &Default::default(), ); tensor .sum_dims_squeeze::<1, _>(&[0, 1]) .to_data() .assert_eq(&TensorData::from([20., 16., 21.]), false); } #[test] fn test_sum_dim_1_reshape_maybe_fused() { let tensor = TestTensorInt::arange(0..9, &Default::default()).float(); TestBackend::sync(&tensor.device()).unwrap(); let output = tensor.reshape([3, 3]) + 2; let output = output.sum_dim(1); let expected = TensorData::from([[9.0], [18.0], [27.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_1_swap_dims_maybe_fused() { let tensor = TestTensorInt::arange(0..9, &Default::default()).float(); let tensor = tensor.reshape([3, 3]); TestBackend::sync(&tensor.device()).unwrap(); let output = tensor.swap_dims(0, 1) + 2; let output = output.sum_dim(1); let expected = TensorData::from([[15.0], [18.0], [21.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_2_reshape_maybe_fused_broadcast() { let tensor = TestTensorInt::arange(0..9, &Default::default()).float(); TestBackend::sync(&tensor.device()).unwrap(); let output = tensor.reshape([1, 3, 3]) + 2; let output = output.sum_dim(2); let expected = TensorData::from([[[9.0], [18.0], [27.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_2_maybe_fused_on_write() { let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor_2 = TestTensorInt::arange(10..12, &Default::default()).float(); let tensor_1 = tensor_1.reshape([1, 2, 4]); let tensor_2 = tensor_2.reshape([1, 2, 1]); TestBackend::sync(&tensor_1.device()).unwrap(); let output = (tensor_1 + tensor_2.clone()).sum_dim(2) + tensor_2; TestBackend::sync(&output.device()).unwrap(); let expected = TensorData::from([[[56.0], [77.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_3_maybe_fused_on_read_not_contiguous() { let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float(); let tensor_1 = tensor_1.reshape([4, 2, 1]); let tensor_1 = tensor_1.swap_dims(0, 2); let tensor_2 = tensor_2.reshape([1, 4, 2]); let tensor_2 = tensor_2.swap_dims(1, 2); TestBackend::sync(&tensor_1.device()).unwrap(); let output = (tensor_1 + tensor_2).sum_dim(2); let expected = TensorData::from([[[88.0], [96.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_4_maybe_fused_on_read_not_contiguous_mixed() { let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float(); let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float(); let tensor_1 = tensor_1.reshape([4, 2, 1]); let tensor_3 = tensor_3.reshape([1, 2, 4]); let tensor_1 = tensor_1.swap_dims(0, 2); let tensor_2 = tensor_2.reshape([1, 4, 2]); let tensor_2 = tensor_2.swap_dims(1, 2); TestBackend::sync(&tensor_1.device()).unwrap(); let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(2); let expected = TensorData::from([[[222.0], [246.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_5_maybe_fused_on_read_not_contiguous_mixed() { let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float(); let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float(); let tensor_1 = tensor_1.reshape([4, 2, 1]); let tensor_3 = tensor_3.reshape([1, 2, 4]); let tensor_1 = tensor_1.swap_dims(0, 2); let tensor_2 = tensor_2.reshape([1, 4, 2]); let tensor_2 = tensor_2.swap_dims(1, 2); TestBackend::sync(&tensor_1.device()).unwrap(); let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(1); let expected = TensorData::from([[[102.0, 112.0, 122.0, 132.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_6_maybe_fused_on_read_not_contiguous_broadcasted() { let tensor_1 = TestTensorInt::arange(0..32, &Default::default()).float(); let tensor_2 = TestTensorInt::arange(0..8, &Default::default()).float(); let tensor_1 = tensor_1.reshape([4, 2, 2, 2]); let tensor_1 = tensor_1.swap_dims(3, 2); let tensor_1 = tensor_1.swap_dims(1, 2); let tensor_2 = tensor_2.reshape([1, 2, 2, 2]); TestBackend::sync(&tensor_1.device()).unwrap(); let sum = tensor_2.clone().sum_dim(0); let sum = sum.sum_dim(1); let sum = sum.sum_dim(2); TestBackend::sync(&tensor_1.device()).unwrap(); let _tmp = sum.clone() + 2; let output = (tensor_1 + tensor_2 + sum).sum_dim(1); let expected = TensorData::from([ [[[29.0, 43.0], [41.0, 55.0]]], [[[45.0, 59.0], [57.0, 71.0]]], [[[61.0, 75.0], [73.0, 87.0]]], [[[77.0, 91.0], [89.0, 103.0]]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sum_dim_7_maybe_fused_on_read_reshaped() { let tensor_1 = TestTensorInt::arange(0..16, &Default::default()).float(); let tensor_1 = tensor_1.reshape([4, 4]); TestBackend::sync(&tensor_1.device()).unwrap(); let reshaped = tensor_1.reshape([1, 4, 4]); let tmp = reshaped + 5.0; let output = tmp.sum_dim(2); let expected = TensorData::from([[[26.0], [42.0], [58.0], [74.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_mean_dim_fused_on_read_on_write() { // https://github.com/tracel-ai/burn/issues/3987 let device = Default::default(); let x = TestTensor::ones([128, 32, 1], &device); let weight = TestTensor::ones([1, 32, 1], &device); let options = burn_tensor::ops::ConvOptions::new([1], [0], [1], 1); let x = burn_tensor::module::conv1d(x, weight, None, options); let global = x.clone().powi_scalar(2).sum_dim(2).add_scalar(1e-5).sqrt(); let norm = global.clone().div(global.mean_dim(1)); let x = x.clone().mul(norm).add(x); let out = x.sum(); out.into_data() .assert_eq(&TensorData::from([8192.0]), false); } #[test] fn test_mean_dim_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let output = tensor.clone().mean_dim(1); let expected = TensorData::from([[1.], [4.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let output = tensor.mean_dim(0); let expected = TensorData::from([[1.5, 2.5, 3.5]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_mean_dims_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); tensor .clone() .mean_dims(&[1]) .to_data() .assert_eq(&TensorData::from([[1.], [4.]]), false); tensor .clone() .mean_dims(&[-1]) .to_data() .assert_eq(&TensorData::from([[1.], [4.]]), false); tensor .clone() .mean_dims(&[0, 1]) .to_data() .assert_eq(&TensorData::from([[2.5]]), false); } #[test] fn test_multiple_reduce_dims_permuted() { // Regression test for https://github.com/tracel-ai/burn/issues/4461 let tensor = TestTensorInt::arange(0..2 * 2 * 256, &Default::default()) .float() .reshape([2, 2, 256]); let output = tensor .permute([1, 2, 0]) .mean_dim(0) .mean_dim(1) .squeeze_dims::<1>(&[0, 1]); output .into_data() .assert_approx_eq::(&TensorData::from([255.5, 767.5]), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/all.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_all() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_all_dim() { let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let data_actual = tensor.all_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/any.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_any() { // test float tensor let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); // test int tensor let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorInt::<2>::from([[0, 0, 0], [0, 0, 0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); // test bool tensor let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_any_dim() { let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); // test int tensor let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); // test bool tensor let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/arg.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_argmax_2d_dim0() { let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.argmax(0); output .into_data() .assert_eq(&TensorData::from([[0, 0, 1]]), false); } #[test] fn test_argmin_2d_dim0() { let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); let output = tensor.argmin(0); output .into_data() .assert_eq(&TensorData::from([[0, 1, 0]]), false); } #[test] fn test_argmax_2d_dim1() { let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.argmax(1); output .into_data() .assert_eq(&TensorData::from([[1], [2]]), false); } #[test] fn test_argmin_2d_dim1() { let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); let output = tensor.argmin(1); output .into_data() .assert_eq(&TensorData::from([[2], [1]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/cast.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{DType, TensorData}; #[test] fn cast_float_to_bool() { let tensor1 = TestTensor::<2>::from([[0.0, 43.0, 0.0], [2.0, -4.2, 31.33]]); let data_actual = tensor1.bool().into_data(); let data_expected = TensorData::from([[false, true, false], [true, true, true]]); data_actual.assert_eq(&data_expected, false); } #[test] fn cast_float_to_int() { let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]).int(); let expected = TensorData::from([[1, 2, 3], [4, 5, 6]]); tensor.into_data().assert_eq(&expected, false); } #[test] fn cast_int_to_float_tensor() { let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]).float(); let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); tensor.into_data().assert_eq(&expected, false); } #[test] fn cast_bool_to_float_tensor() { let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float(); let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]); tensor.into_data().assert_eq(&expected, false); } #[test] fn cast_float_precision() { let data = TensorData::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]); let tensor = TestTensor::<2>::from(data.clone()); let output = tensor.cast(DType::F32); assert_eq!(output.dtype(), DType::F32); // Use precision 2 for parameterized tests in f16 and bf16 output .into_data() .assert_approx_eq::(&data, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/cat.rs ================================================ use super::*; use alloc::vec::Vec; use burn_tensor::Tolerance; use burn_tensor::{DType, TensorData}; #[test] fn should_support_cat_ops_2d_dim0() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device); let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_ops_2d_dim1() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device); let output = TestTensor::cat(vec![tensor_1, tensor_2], 1); let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_ops_3d() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device); let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device); let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[should_panic] fn should_panic_when_dimensions_are_not_the_same() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device); TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data(); } #[test] #[should_panic] fn should_panic_when_list_of_vectors_is_empty() { let tensor: Vec> = vec![]; TestTensor::cat(tensor, 0).into_data(); } #[test] #[should_panic] fn should_panic_when_cat_exceeds_dimension() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device); let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device); TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data(); } #[test] fn should_support_cat_ops_cast_dtype() { let device = Default::default(); // ok for f32 backends, casts dtype for f16 tests let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device) .cast(DType::F32); let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device).cast(DType::F32); let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_with_empty_tensor() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let tensor_2: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor with size 0 on dim 1 // Concatenating with an empty tensor should just return the non-empty tensor let output = TestTensor::cat(vec![tensor_1.clone(), tensor_2], 1); let expected = TensorData::from([[1.0, 2.0, 3.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_with_empty_tensor_first() { let device = Default::default(); let tensor_1: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor let tensor_2 = TestTensor::<2>::from_data([[4.0, 5.0, 6.0]], &device); // Empty tensor first, then non-empty let output = TestTensor::cat(vec![tensor_1, tensor_2.clone()], 1); let expected = TensorData::from([[4.0, 5.0, 6.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_with_multiple_empty_tensors() { let device = Default::default(); let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device); let tensor_2 = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let tensor_3: TestTensor<2> = TestTensor::empty([2, 0], &device); let tensor_4 = TestTensor::<2>::from_data([[5.0], [6.0]], &device); // Mix of empty and non-empty tensors let output = TestTensor::cat(vec![tensor_1, tensor_2, tensor_3, tensor_4], 1); let expected = TensorData::from([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_cat_all_empty_tensors() { let device = Default::default(); let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device); let tensor_2: TestTensor<2> = TestTensor::empty([2, 0], &device); // All empty tensors should produce an empty tensor let output = TestTensor::cat(vec![tensor_1, tensor_2], 1); assert_eq!(output.shape().as_slice(), [2, 0]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/ceil.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_ceil_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.ceil(); let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/chunk.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_chunk_evenly_divisible() { let tensors = TestTensorInt::arange(0..12, &Default::default()) .float() .chunk(6, 0); assert_eq!(tensors.len(), 6); let expected = [ TensorData::from([0, 1]), TensorData::from([2, 3]), TensorData::from([4, 5]), TensorData::from([6, 7]), TensorData::from([8, 9]), TensorData::from([10, 11]), ]; for (index, tensor) in tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_chunk_not_evenly_divisible() { let tensors = TestTensorInt::arange(0..11, &Default::default()) .float() .chunk(6, 0); assert_eq!(tensors.len(), 6); let expected = [ TensorData::from([0, 1]), TensorData::from([2, 3]), TensorData::from([4, 5]), TensorData::from([6, 7]), TensorData::from([8, 9]), TensorData::from([10]), ]; for (index, tensor) in tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_chunk_not_evenly_divisible_remains_several() { let tensors = TestTensorInt::arange(0..100, &Default::default()) .float() .chunk(8, 0); assert_eq!(tensors.len(), 8); let expected = [13, 13, 13, 13, 13, 13, 13, 9]; for (index, tensor) in tensors.iter().enumerate() { assert_eq!(tensor.shape()[0], expected[index]); } } #[test] fn test_chunk_not_divisible() { let tensors = TestTensorInt::arange(0..6, &Default::default()) .float() .chunk(7, 0); assert_eq!(tensors.len(), 6); let expected = [ TensorData::from([0]), TensorData::from([1]), TensorData::from([2]), TensorData::from([3]), TensorData::from([4]), TensorData::from([5]), ]; for (index, tensor) in tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] #[should_panic] fn test_invalid_dim() { let _tensors = TestTensorInt::arange(0..12, &Default::default()).chunk(6, 1); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/clamp.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn clamp_min() { let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp_min(2.0); output .into_data() .assert_eq(&TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]), false); // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp_min(2); output .into_data() .assert_eq(&TensorData::from([[2, 2, 2], [3, 4, 5]]), false); } #[test] fn clamp_max() { let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp_max(2.0); output .into_data() .assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]), false); // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp_max(4); output .into_data() .assert_eq(&TensorData::from([[0, 1, 2], [3, 4, 4]]), false); } #[test] fn clamp_min_max() { let device = Default::default(); // test float tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.clamp(1.0, 4.0); output .into_data() .assert_eq(&TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]), false); // test int tensor let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let tensor = TestTensorInt::<2>::from_data(data, &device); let output = tensor.clamp(1, 4); output .into_data() .assert_eq(&TensorData::from([[1, 1, 2], [3, 4, 4]]), false); } #[test] fn clamp_min_max_vec_should_compile() { let input = TestTensor::<2>::ones([2, 4], &Default::default()); let output = input.clamp(0., 0.5); output.into_data().assert_eq( &TensorData::from([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/close.rs ================================================ use super::*; use burn_tensor::{DEFAULT_ATOL, DEFAULT_RTOL, TensorData}; #[test] fn test_is_close() { let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9; let data_actual = tensor1 .clone() .is_close(tensor2.clone(), None, None) .into_data(); let defaults_expected = TensorData::from([[true, true, true], [true, true, false]]); defaults_expected.assert_eq(&data_actual, false); // Using the defaults. let data_actual = tensor1 .is_close(tensor2, Some(DEFAULT_RTOL), Some(DEFAULT_ATOL)) .into_data(); defaults_expected.assert_eq(&data_actual, false); } #[test] fn test_all_close() { let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9; assert!(!tensor1.clone().all_close(tensor2.clone(), None, None)); let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-9; assert!(tensor1.all_close(tensor2, None, None)); // non finite values let inf_plus = TestTensor::<2>::from([[f32::INFINITY]]); let one = TestTensor::<2>::from([[1.]]); let inf_minus = TestTensor::<2>::from([[-f32::INFINITY]]); assert!(!inf_plus.clone().all_close(inf_minus.clone(), None, None)); assert!(!one.clone().all_close(inf_minus.clone(), None, None)); assert!(!one.all_close(inf_plus.clone(), None, None)); assert!(inf_plus.clone().all_close(inf_plus, None, None)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/comparison.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_equal_inf() { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]); let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); let data_expected = TensorData::from([[false, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_not_equal_inf() { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]); let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); let data_expected = TensorData::from([[true, false, true], [true, true, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_equal() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); let data_expected = TensorData::from([[false, true, false], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_not_equal() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); let data_expected = TensorData::from([[true, false, true], [true, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_equal_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]); let data_actual_cloned = tensor_1.clone().equal_elem(2); let data_actual_inplace = tensor_1.equal_elem(2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_not_equal_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]); let data_actual_cloned = tensor_1.clone().not_equal_elem(2); let data_actual_inplace = tensor_1.not_equal_elem(2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn greater_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor_1.clone().greater_elem(4); let data_actual_inplace = tensor_1.greater_elem(4); let data_expected = TensorData::from([[false, false, false], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_equal_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0); let data_actual_inplace = tensor_1.greater_equal_elem(4.0); let data_expected = TensorData::from([[false, false, false], [false, true, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]); let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); let data_actual_inplace = tensor_1.greater(tensor_2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_equal() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]); let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.greater_equal(tensor_2); let data_expected = TensorData::from([[false, true, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor_1.clone().lower_elem(4.0); let data_actual_inplace = tensor_1.lower_elem(4.0); let data_expected = TensorData::from([[true, true, true], [true, false, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_equal_elem() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0); let data_actual_inplace = tensor_1.lower_equal_elem(4.0); let data_expected = TensorData::from([[true, true, true], [true, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]); let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); let data_actual_inplace = tensor_1.lower(tensor_2); let data_expected = TensorData::from([[true, false, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_equal() { let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]); let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.lower_equal(tensor_2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_broadcast() { // Test broadcasting with shape [1, 4] vs [4, 4] let device = Default::default(); let data_1 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]); let data_2 = TensorData::from([ [0.5, 1.5, 2.5, 3.5], [1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5], ]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.greater(tensor_2); let expected = TensorData::from([ [true, true, true, true], [false, false, false, false], [false, false, false, false], [false, false, false, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_greater_equal_broadcast() { // Test broadcasting with shape [4, 1] vs [1, 4] let device = Default::default(); let data_1 = TensorData::from([[1.0], [2.0], [3.0], [4.0]]); let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.greater_equal(tensor_2); let expected = TensorData::from([ [true, false, false, false], [true, true, false, false], [true, true, true, false], [true, true, true, true], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_lower_broadcast() { // Test broadcasting mimicking CLIP pattern: [1, 5] vs [5, 1] let device = Default::default(); let data_1 = TensorData::from([[0.0, 1.0, -1.0, 2.0, -2.0]]); let data_2 = TensorData::from([[0.5], [1.5], [-0.5], [-1.5], [2.5]]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.lower(tensor_2); let expected = TensorData::from([ [true, false, true, false, true], [true, true, true, false, true], [false, false, true, false, true], [false, false, false, false, true], [true, true, true, true, true], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_lower_equal_broadcast() { // Test broadcasting with shape [1, 1] vs [2, 4] let device = Default::default(); let data_1 = TensorData::from([[2.5]]); let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0], [2.0, 2.5, 3.0, 3.5]]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.lower_equal(tensor_2); let expected = TensorData::from([[false, false, true, true], [false, true, true, true]]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_equal_broadcast() { // Test broadcasting with different ranks let device = Default::default(); let data_1 = TensorData::from([[2.0], [3.0], [4.0]]); let data_2 = TensorData::from([[2.0, 3.0, 4.0, 2.0]]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.equal(tensor_2); let expected = TensorData::from([ [true, false, false, true], [false, true, false, false], [false, false, true, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_not_equal_broadcast() { // Test broadcasting with shape [3, 1] vs [1, 3] let device = Default::default(); let data_1 = TensorData::from([[1.0], [2.0], [3.0]]); let data_2 = TensorData::from([[1.0, 2.0, 3.0]]); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let result = tensor_1.not_equal(tensor_2); let expected = TensorData::from([ [false, true, true], [true, false, true], [true, true, false], ]); expected.assert_eq(&result.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/create_like.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Distribution, TensorData}; #[test] fn should_support_zeros_like() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let tensor = tensor.zeros_like(); let expected = TensorData::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]); tensor .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_ones_like() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let tensor = tensor.ones_like(); let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); tensor .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_randoms_like() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let tensor = tensor.random_like(Distribution::Uniform(0.99999, 1.)); let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]); tensor .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/cross.rs ================================================ use super::*; use burn_tensor::TensorData; #[cfg(feature = "std")] use burn_backend_tests::might_panic; #[test] fn test_cross_3d_last_dim() { let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]); let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]); let output = tensor_1.cross(tensor_2, -1); output.into_data().assert_eq( &TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]), false, ); } #[test] fn test_cross_3d_non_contiguous_last_dim() { let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]); let tensor_2 = TestTensor::from([[4.0, 3.0], [-2.0, 5.0], [1.0, -2.0]]); let output = tensor_1.cross(tensor_2.permute([1, 0]), -1); output.into_data().assert_eq( &TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]), false, ); } #[cfg(feature = "std")] #[might_panic(reason = "not implemented: Cross product on non-last dimension")] #[test] fn test_cross_3d_dim0() { let tensor_1 = TestTensor::<2>::from([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]); let tensor_2 = TestTensor::from([[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]); let output = tensor_1.cross(tensor_2, 0); output.into_data().assert_eq( &TensorData::from([[0.0, 0.0], [-1.0, 0.0], [0.0, -1.0]]), false, ); } #[test] fn test_cross_3d_broadcast() { let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0]]); let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]); let output = tensor_1.cross(tensor_2, -1); output.into_data().assert_eq( &TensorData::from([[-7.0, -21.0, -14.0], [19.0, -13.0, -4.0]]), false, ); } #[test] fn test_cross_4d_last_dim() { let tensor_1 = TestTensor::<3>::from([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]); let tensor_2 = TestTensor::from([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]); let output = tensor_1.cross(tensor_2, -1); output.into_data().assert_eq( &TensorData::from([[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]]), false, ); } // Helper to compute expected cross product for 2-D (N × 3) tensors. fn manual_cross(a: &[[f32; 3]], b: &[[f32; 3]]) -> Vec<[f32; 3]> { a.iter() .zip(b.iter()) .map(|(x, y)| { [ x[1] * y[2] - x[2] * y[1], x[2] * y[0] - x[0] * y[2], x[0] * y[1] - x[1] * y[0], ] }) .collect() } #[test] fn forward_matches_manual_cross() { let a_raw = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; let b_raw = [[7.0, 8.0, 9.0], [1.0, 0.0, -1.0]]; let a = TestTensor::<2>::from(a_raw); let b = TestTensor::<2>::from(b_raw); let out = a.cross(b.clone(), 1); let expected_vec = manual_cross(&a_raw, &b_raw); let expected: [[f32; 3]; 2] = [expected_vec[0], expected_vec[1]]; out.into_data() .assert_eq(&TensorData::from(expected), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/cumulative.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_cumsum_float_dim_0() { let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let output = tensor.cumsum(0); output .into_data() .assert_eq(&TensorData::from([[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]), false); } #[test] fn test_cumsum_float_dim_1() { let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let output = tensor.cumsum(1); output.into_data().assert_eq( &TensorData::from([[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]), false, ); } #[test] fn test_cumsum_non_contiguous() { let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]).swap_dims(0, 1); let output = tensor.cumsum(1); output .into_data() .assert_eq(&TensorData::from([[1., 4.], [2., 6.]]), false); } #[test] fn test_cumsum_float_3d() { let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]); let output = tensor.cumsum(2); output.into_data().assert_eq( &TensorData::from([[[1.0, 3.0], [3.0, 7.0]], [[5.0, 11.0], [7.0, 15.0]]]), false, ); } #[test] fn test_cumprod_float_dim_0() { let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let output = tensor.cumprod(0); output.into_data().assert_eq( &TensorData::from([[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]), false, ); } #[test] fn test_cumprod_float_dim_1() { let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let output = tensor.cumprod(1); output.into_data().assert_eq( &TensorData::from([[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]), false, ); } #[test] fn test_cumprod_float_3d() { let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]); let output = tensor.cumprod(2); output.into_data().assert_eq( &TensorData::from([[[1.0, 2.0], [3.0, 12.0]], [[5.0, 30.0], [7.0, 56.0]]]), false, ); } #[test] fn test_cummin_float_dim_0() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]); let output = tensor.cummin(0); output .into_data() .assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [2.0, 1.0, 1.0]]), false); } #[test] fn test_cummin_float_dim_1() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]); let output = tensor.cummin(1); output .into_data() .assert_eq(&TensorData::from([[3.0, 1.0, 1.0], [2.0, 2.0, 1.0]]), false); } #[test] fn test_cummin_float_3d() { let tensor = TestTensor::<3>::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 6.0], [7.0, 8.0]]]); let output = tensor.cummin(2); output.into_data().assert_eq( &TensorData::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 5.0], [7.0, 7.0]]]), false, ); } #[test] fn test_cummax_float_dim_0() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]); let output = tensor.cummax(0); output .into_data() .assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [3.0, 5.0, 4.0]]), false); } #[test] fn test_cummax_float_dim_1() { let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]); let output = tensor.cummax(1); output .into_data() .assert_eq(&TensorData::from([[3.0, 3.0, 4.0], [1.0, 5.0, 5.0]]), false); } #[test] fn test_cummax_float_3d() { let tensor = TestTensor::<3>::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 2.0], [6.0, 1.0]]]); let output = tensor.cummax(2); output.into_data().assert_eq( &TensorData::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 5.0], [6.0, 6.0]]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/div.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_div_ops() { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_div_broadcast() { let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; output.into_data().assert_eq( &TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]), false, ); } #[test] fn should_support_div_scalar_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let device = Default::default(); let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor / scalar; output .into_data() .assert_eq(&TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/dot.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_float() { let device = Default::default(); let tensor_1 = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device); let tensor_2 = TestTensor::<1>::from_data([0.0, -1.0, 4.0], &device); let output = tensor_1.dot(tensor_2); let expected = TensorData::from([10.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_int() { let device = Default::default(); let tensor_1 = TestTensor::<1>::from_data([1, 2, 3], &device); let tensor_2 = TestTensor::<1>::from_data([0, -1, 4], &device); let output = tensor_1.dot(tensor_2); let expected = TensorData::from([10]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn test_panics_for_different_sizes() { let device = Default::default(); let tensor_1 = TestTensor::<1>::from_data([1, 2], &device); let tensor_2 = TestTensor::<1>::from_data([1, 2, 3], &device); let _output = tensor_1.dot(tensor_2); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/erf.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_erf_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.erf(); let expected = TensorData::from([[0.0000, 0.8427, 0.99532], [0.99998, 1.0000, 1.0000]]); output.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_absolute(2e-3), ); } #[test] fn should_support_erf_ops_with_negative_number() { let data = TensorData::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.erf(); let expected = TensorData::from([ [-0.06312324, -0.048490416, -0.10016122], [0.99998, 1.0000, 1.0000], ]); output.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_absolute(3e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/exp.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_exp_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.exp(); let expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/expand.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn expand_2d() { let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default()); let output = tensor.expand([3, 3]); output.into_data().assert_eq( &TensorData::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), false, ); let tensor = TestTensor::<1>::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default()); let output = tensor.expand([2, 4]); output.into_data().assert_eq( &TensorData::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]), false, ); } #[test] fn expand_3d() { let tensor = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default()); let output = tensor.expand([3, 2, 2]); let expected = TensorData::from([ [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn expand_higher_dimensions() { let tensor = TestTensor::<2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default()); let output = tensor.expand([2, 3, 4]); let expected = TensorData::from([ [ [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], ], [ [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn expand_sum_3d() { let tensor = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default()); let output = tensor.expand([3, 2, 2]).sum_dim(0); let expected = TensorData::from([[[3.0, 6.0], [9.0, 12.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn broadcast_single() { let tensor = TestTensor::<1>::from_floats([1.0], &Default::default()); let output = tensor.expand([2, 3]); output .into_data() .assert_eq(&TensorData::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), false); } #[test] #[should_panic] fn should_fail_expand_incompatible_shapes() { let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default()); let _expanded_tensor = tensor.expand([2, 2]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/finite.rs ================================================ use super::*; #[test] fn is_finite() { let all_finite = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let all_finite_expected = TestTensorBool::<2>::from([[true, true, true], [true, true, true]]); let with_inf_nan = TestTensor::<2>::from([ [0.0, f32::INFINITY, f32::NAN], [f32::NEG_INFINITY, f32::NAN, 5.0], ]); let with_inf_nan_expected = TestTensorBool::<2>::from([[true, false, false], [false, false, true]]); all_finite_expected .into_data() .assert_eq(&all_finite.is_finite().into_data(), false); with_inf_nan .is_finite() .into_data() .assert_eq(&with_inf_nan_expected.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/flatten.rs ================================================ use super::*; use burn_tensor::Shape; /// Test if the function can successfully flatten a 4D tensor to a 1D tensor. #[test] fn should_flatten_to_1d() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: TestTensor<1> = tensor.flatten(0, 3); let expected_shape = Shape::new([120]); assert_eq!(flattened_tensor.shape(), expected_shape); } /// Test if the function can successfully flatten the middle dimensions of a 4D tensor. #[test] fn should_flatten_middle() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: TestTensor<3> = tensor.flatten(1, 2); let expected_shape = Shape::new([2, 12, 5]); assert_eq!(flattened_tensor.shape(), expected_shape); } /// Test if the function can successfully flatten the first dimensions of a 4D tensor. #[test] fn should_flatten_begin() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: TestTensor<2> = tensor.flatten(0, 2); let expected_shape = Shape::new([24, 5]); assert_eq!(flattened_tensor.shape(), expected_shape); } /// Test if the function can successfully flatten the last dimensions of a 4D tensor. #[test] fn should_flatten_end() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: TestTensor<2> = tensor.flatten(1, 3); let expected_shape = Shape::new([2, 60]); assert_eq!(flattened_tensor.shape(), expected_shape); } /// Test if the function can flatten negative indices. #[test] fn should_flatten_end_negative_indices() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let flattened_tensor: TestTensor<2> = tensor.flatten(-3, -1); let expected_shape = Shape::new([2, 60]); assert_eq!(flattened_tensor.shape(), expected_shape); } /// Test if the function panics when the start dimension is greater than the end dimension. #[test] #[should_panic] fn should_flatten_panic() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let _flattened_tensor: TestTensor<2> = tensor.flatten(2, 0); } #[test] #[should_panic] fn not_enough_destination_dimension() { let tensor = TestTensor::<3>::ones(Shape::new([1, 5, 15]), &Default::default()); let flattened_tensor: TestTensor<1> = tensor.flatten(1, 2); let expected_shape = Shape::new([75]); assert_eq!(flattened_tensor.shape(), expected_shape); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/flip.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn flip_float() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); let flipped = tensor.clone().flip([0, 2]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).float() let expected = TensorData::from([ [ [15., 14., 13., 12.], [19., 18., 17., 16.], [23., 22., 21., 20.], ], [[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]], ]); flipped.into_data().assert_eq(&expected, false); // Test with no flip let flipped = tensor.clone().flip([]); tensor.into_data().assert_eq(&flipped.into_data(), false); } #[test] #[should_panic] fn flip_duplicated_axes() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a duplicated axis let _ = tensor.clone().flip([0, 0, 1]); } #[test] #[should_panic] fn flip_out_of_bound_axis() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().flip([3, 0, 1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/floor.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_floor_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.floor(); let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/fmod.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ElementConversion, TensorData}; #[allow(unused_imports)] // f16 use num_traits::Float; #[test] fn should_support_fmod_ops() { let dividend = TensorData::from([[5.3, -5.3], [7.5, -7.5]]); let divisor = TensorData::from([[2.0, 2.0], [3.0, 3.0]]); let dividend_tensor = TestTensor::<2>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<2>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([[1.3, -1.3], [1.5, -1.5]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_fmod_scalar() { let data = TensorData::from([5.3, -5.3, 7.5, -7.5]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.fmod_scalar(2.0); let expected = TensorData::from([1.3, -1.3, 1.5, -1.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_positive_dividend_positive_divisor() { let dividend = TensorData::from([10.0, 7.5, 3.8, 1.2]); let divisor = TensorData::from([3.0, 2.0, 1.5, 0.7]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([1.0, 1.5, 0.8, 0.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_negative_dividend() { let dividend = TensorData::from([-10.0, -7.5, -3.8, -1.2]); let divisor = TensorData::from([3.0, 2.0, 1.5, 0.7]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([-1.0, -1.5, -0.8, -0.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_mixed_signs() { let dividend = TensorData::from([5.3, -5.3, 5.3, -5.3]); let divisor = TensorData::from([2.0, 2.0, -2.0, -2.0]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); // fmod result has same sign as dividend let expected = TensorData::from([1.3, -1.3, 1.3, -1.3]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_infinity_dividend() { // If x is ±∞ and y is not NaN, NaN is returned let dividend = TensorData::from([ f32::INFINITY, f32::NEG_INFINITY, f32::INFINITY, f32::NEG_INFINITY, ]); let divisor = TensorData::from([2.0, 3.0, -2.0, -3.0]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let data = output.into_data(); let values = data.as_slice::().unwrap(); // All results should be NaN assert!(values[0].is_nan(), "fmod(inf, 2.0) should be NaN"); assert!(values[1].is_nan(), "fmod(-inf, 3.0) should be NaN"); assert!(values[2].is_nan(), "fmod(inf, -2.0) should be NaN"); assert!(values[3].is_nan(), "fmod(-inf, -3.0) should be NaN"); } #[test] fn should_handle_zero_divisor() { // If y is ±0 and x is not NaN, NaN should be returned let dividend = TensorData::from([5.3, -5.3, 0.0, 1.0]); let divisor = TensorData::from([0.0, -0.0, 0.0, -0.0]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let data = output.into_data(); let values = data.as_slice::().unwrap(); // All results should be NaN assert!(values[0].is_nan(), "fmod(5.3, 0.0) should be NaN"); assert!(values[1].is_nan(), "fmod(-5.3, -0.0) should be NaN"); assert!(values[2].is_nan(), "fmod(0.0, 0.0) should be NaN"); assert!(values[3].is_nan(), "fmod(1.0, -0.0) should be NaN"); } #[test] fn should_handle_infinity_divisor() { // If y is ±∞ and x is finite, x is returned let dividend = TensorData::from([5.3, -5.3, 0.0, -0.0]); let divisor = TensorData::from([ f32::INFINITY, f32::NEG_INFINITY, f32::INFINITY, f32::NEG_INFINITY, ]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([5.3, -5.3, 0.0, -0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_nan_arguments() { // If either argument is NaN, NaN is returned let dividend = TensorData::from([f32::NAN, 5.3, f32::NAN, 0.0]); let divisor = TensorData::from([2.0, f32::NAN, f32::NAN, 3.0]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let data = output.into_data(); let values = data.as_slice::().unwrap(); assert!(values[0].is_nan(), "fmod(NaN, 2.0) should be NaN"); assert!(values[1].is_nan(), "fmod(5.3, NaN) should be NaN"); assert!(values[2].is_nan(), "fmod(NaN, NaN) should be NaN"); assert!(!values[3].is_nan(), "fmod(0.0, 3.0) should be 0.0"); } #[test] fn should_handle_negative_zero() { // If x is -0 and y is greater than zero, either +0 or -0 may be returned let dividend = TensorData::from([-0.0_f32]); let divisor = TensorData::from([2.0_f32]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let data = output.into_data(); let values = data.as_slice::().unwrap(); // Result should be zero (either +0 or -0 is acceptable) assert_eq!( values[0], 0.0f32.elem::(), "fmod(-0, 2.0) should be zero" ); } #[test] fn should_support_fmod_broadcasting_2d() { // Test broadcasting: 1x2 with 3x2 let dividend = TensorData::from([[5.3, -5.3]]); // Shape: 1x2 let divisor = TensorData::from([[2.0, 2.0], [3.0, 3.0], [1.5, 1.5]]); // Shape: 3x2 let dividend_tensor = TestTensor::<2>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<2>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([ [1.3, -1.3], // 5.3 % 2.0, -5.3 % 2.0 [2.3, -2.3], // 5.3 % 3.0, -5.3 % 3.0 [0.8, -0.8], // 5.3 % 1.5, -5.3 % 1.5 ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_fmod_broadcasting_3d() { // Test broadcasting: 1x1x3 with 2x1x3 let dividend = TensorData::from([[[5.0, -7.0, 8.0]]]); // Shape: 1x1x3 let divisor = TensorData::from([[[3.0, 3.0, 3.0]], [[4.0, 4.0, 4.0]]]); // Shape: 2x1x3 let dividend_tensor = TestTensor::<3>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<3>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([ [[2.0, -1.0, 2.0]], // 5.0 % 3.0, -7.0 % 3.0, 8.0 % 3.0 [[1.0, -3.0, 0.0]], // 5.0 % 4.0, -7.0 % 4.0, 8.0 % 4.0 ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_fmod_scalar_broadcasting() { // Test scalar operation with different shapes let data = TensorData::from([[5.3, -5.3, 7.5], [-7.5, 10.0, -10.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.fmod_scalar(3.0); let expected = TensorData::from([[2.3, -2.3, 1.5], [-1.5, 1.0, -1.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_edge_case_values() { // Test various edge cases let dividend = TensorData::from([0.0, -0.0, 1e-10, -1e-10, 10.0, -10.0]); let divisor = TensorData::from([1.0, 1.0, 1.0, 1.0, 3.0, 3.0]); let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default()); let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default()); let output = dividend_tensor.fmod(divisor_tensor); let expected = TensorData::from([0.0, -0.0, 1e-10, -1e-10, 1.0, -1.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_special_scalar_cases() { // Test scalar operations with special values let data = TensorData::from([5.3, -5.3, 0.0, -0.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); // Test with infinity divisor let output_inf = tensor.clone().fmod_scalar(f32::INFINITY); let expected_inf = TensorData::from([5.3, -5.3, 0.0, -0.0]); output_inf .into_data() .assert_approx_eq::(&expected_inf, Tolerance::default()); // Test with very small divisor // Doesn't work if the test divisor is subnormal if FloatElem::MIN_POSITIVE > 1e-5f32.elem::() { return; } let output_small = tensor.clone().fmod_scalar(1e-5); let data = output_small.into_data(); let values = data.as_slice::().unwrap(); // let expected = TensorData::from([0.0, 0.0, 0.0, 0.0]); // Results should be very small remainders assert!(values[0].abs() < 1e-5f32.elem::()); assert!(values[1].abs() < 1e-5f32.elem::()); assert_eq!(values[2], 0.0f32.elem::()); assert_eq!(values[3], 0.0f32.elem::()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/full.rs ================================================ use super::*; use burn_tensor::{DType, TensorData}; #[test] fn test_data_full() { let tensor = TensorData::full([2, 3], 2.0); tensor.assert_eq(&TensorData::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), false); } #[test] fn test_tensor_full() { let device = Default::default(); let tensor = TestTensor::<2>::full([2, 3], 2.1, &device); tensor .into_data() .assert_eq(&TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]), false); } #[test] fn test_tensor_full_options() { let tensor = TestTensor::<2>::full([2, 3], 2.1, (&Default::default(), DType::F32)); assert_eq!(tensor.dtype(), DType::F32); tensor .into_data() .assert_eq(&TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/gather_scatter.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_gather_1d_dim0() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([0.0, 1.0, 2.0], &device); let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device); let output = tensor.gather(0, indices); output .into_data() .assert_eq(&TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]), false); } #[test] fn should_gather_2d_dim0() { let device = Default::default(); let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]], &device); let output = tensor.gather(0, indices); output .into_data() .assert_eq(&TensorData::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]), false); } #[test] fn should_gather_2d_dim1() { let device = Default::default(); let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]], &device); let output = tensor.gather(1, indices); output.into_data().assert_eq( &TensorData::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]), false, ); } #[test] fn should_gather_3d_dim1() { let device = Default::default(); let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &device, ); let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device); let output = tensor.gather(1, indices); let expected = TensorData::from([ [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_gather_2d_only_1dim() { let device = Default::default(); let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::<2>::from_ints([[1, 2]], &device).reshape([2, 1]); let output = tensor.gather(1, indices); output .into_data() .assert_eq(&TensorData::from([[1.0], [5.0]]), false); } #[test] fn should_scatter_add_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([0.0, 0.0, 0.0], &device); let values = TestTensor::from_floats([5.0, 4.0, 3.0], &device); let indices = TestTensorInt::from_ints([1, 0, 2], &device); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([4.0, 5.0, 3.0]), false); } #[test] fn should_scatter_add_2d_dim0() { let device = Default::default(); let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]], &device); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]), false); } #[test] fn should_scatter_add_2d_dim1() { let device = Default::default(); let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]], &device); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]), false); } #[test] fn should_scatter_add_3d_dim1() { let device = Default::default(); let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &device, ); let values = TestTensor::from_floats( [ [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], ], &device, ); let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([ [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_scatter_add_2d_dim1_diff_shape() { let device = Default::default(); let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); let values = TestTensor::from_floats([[1.0], [4.0]], &device); let indices = TestTensorInt::from_ints([[1], [2]], &device); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]), false); } #[test] #[should_panic] fn scatter_should_panic_on_mismatch_of_shapes() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([0.0, 0.0, 0.0], &device); let values = TestTensor::from_floats([5.0, 4.0], &device); let indices = TestTensorInt::from_ints([1, 0, 2], &device); tensor.scatter(0, indices, values, IndexingUpdateOp::Add); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/grid_sample.rs ================================================ use super::*; use burn_tensor::{ TensorData, Tolerance, ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}, }; /// Tests grid_sample_2d with default options (align_corners=false, zeros padding). /// /// For a 3x3 input with grid coordinates: /// - (0.0, 0.0) maps to pixel (1.0, 1.0) -> center pixel = 4.0 /// - (-1.0, 0.25) maps to pixel (-0.5, 1.375) -> partially out of bounds /// - (1.0, 1.0) maps to pixel (2.5, 2.5) -> corner, partially out of bounds /// - (0.2, -0.8) maps to pixel (1.3, 0.3) -> interpolates around center-top #[test] fn should_grid_sample_2d_default() { let device = Default::default(); let tensor = TestTensor::<4>::from_floats( [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], &device, ); let grid = TestTensor::<4>::from_floats( [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]], &device, ); let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear); // Expected values computed with PyTorch grid_sample(align_corners=False, padding_mode='zeros') let expected = TensorData::from([[[[4.0, 2.0625], [2.0, 1.04]]]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Tests grid_sample_2d with align_corners=true and border padding. /// /// This is the original Burn semantics before the API change. #[test] fn should_grid_sample_2d_align_corners_border() { let device = Default::default(); let tensor = TestTensor::<4>::from_floats( [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], &device, ); let grid = TestTensor::<4>::from_floats( [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]], &device, ); let options = GridSampleOptions::new(InterpolateMode::Bilinear) .with_padding_mode(GridSamplePaddingMode::Border) .with_align_corners(true); let output = tensor.grid_sample_2d(grid, options); // Expected values computed with PyTorch grid_sample(align_corners=True, padding_mode='border') let expected = TensorData::from([[[[4.0, 3.75], [8.0, 1.8]]]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Tests out-of-bounds grid coordinates with zeros padding. /// Grid coordinate (0.0, -2.0) maps to pixel (1.0, -2.5) which is completely out of bounds. #[test] fn should_pad_zeros_grid_sample_2d() { let device = Default::default(); let tensor = TestTensor::<4>::from_floats( [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], &device, ); let grid = TestTensor::<4>::from_floats([[[[0.0, -2.0]]]], &device); let output = tensor.grid_sample_2d(grid, GridSampleOptions::default()); // With zeros padding, out-of-bounds samples return 0 let expected = TensorData::from([[[[0.0]]]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Tests out-of-bounds grid coordinates with border padding. #[test] fn should_pad_border_grid_sample_2d() { let device = Default::default(); let tensor = TestTensor::<4>::from_floats( [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], &device, ); let grid = TestTensor::<4>::from_floats([[[[0.0, -2.0]]]], &device); let options = GridSampleOptions::new(InterpolateMode::Bilinear) .with_padding_mode(GridSamplePaddingMode::Border); let output = tensor.grid_sample_2d(grid, options); // With border padding, out-of-bounds coordinates are clamped to border // Grid (0.0, -2.0) with align_corners=false: pixel (1.0, -2.5) -> clamped to (1.0, 0.0) = 1.0 let expected = TensorData::from([[[[1.0]]]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Tests bilinear interpolation with reflection padding. #[test] fn should_pad_reflection_grid_sample_2d() { let device = Default::default(); let tensor = TestTensor::<4>::from_floats( [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], &device, ); let grid = TestTensor::<4>::from_floats( [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]], &device, ); let options = GridSampleOptions::new(InterpolateMode::Bilinear) .with_padding_mode(GridSamplePaddingMode::Reflection); let output = tensor.grid_sample_2d(grid, options); // Expected values computed with PyTorch F.grid_sample(mode='bilinear', padding_mode='reflection', align_corners=False) let expected = TensorData::from([[[[4.0, 4.125], [8.0, 1.3]]]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/inf.rs ================================================ use super::*; #[test] fn is_inf() { let no_inf = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let no_inf_expected = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]); let with_inf = TestTensor::<2>::from([[0.0, f32::INFINITY, 2.0], [f32::NEG_INFINITY, 4.0, 5.0]]); let with_inf_expected = TestTensorBool::<2>::from([[false, true, false], [true, false, false]]); no_inf .is_inf() .into_data() .assert_eq(&no_inf_expected.into_data(), false); with_inf .is_inf() .into_data() .assert_eq(&with_inf_expected.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/init.rs ================================================ use super::*; use burn_tensor::{DType, TensorData}; #[test] fn should_support_float_empty() { let shape = [2, 2]; let tensor = TestTensor::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_float_empty_options() { let shape = [2, 2]; let tensor = TestTensor::<2>::empty(shape, (&Default::default(), DType::F32)); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_float_zeros() { let shape = [2, 2]; let tensor = TestTensor::<2>::zeros(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[0., 0.], [0., 0.]]), false); } #[test] fn should_support_float_zeros_options() { let shape = [2, 2]; let tensor = TestTensor::<2>::zeros(shape, (&Default::default(), DType::F32)); assert_eq!(tensor.shape(), shape.into()); assert_eq!(tensor.dtype(), DType::F32); tensor .into_data() .assert_eq(&TensorData::from([[0., 0.], [0., 0.]]), false); } #[test] fn should_support_float_ones() { let shape = [2, 2]; let tensor = TestTensor::<2>::ones(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false); } #[test] fn should_support_float_ones_options() { let shape = [2, 2]; let tensor = TestTensor::<2>::ones(shape, (&Default::default(), DType::F32)); assert_eq!(tensor.shape(), shape.into()); assert_eq!(tensor.dtype(), DType::F32); tensor .into_data() .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/iter_dim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_1d_iter_last_item() { let data = [1, 2, 3, 4]; let device = Default::default(); let tensor = TestTensorInt::<1>::from_ints(data, &device); tensor .iter_dim(0) .last() .unwrap() .into_data() .assert_eq(&TensorData::from([4]), false); } #[test] #[should_panic] fn test_too_high_dimension() { TestTensor::<1>::zeros([10], &Default::default()).iter_dim(1); } #[test] fn test_transposed() { let data = [ [1., 2., 3., 1., 2.], [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]; let tensor = TestTensor::<2>::from_floats(data, &Default::default()); let lhs = tensor.clone().slice([1..2, 0..5]); let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap(); assert_eq!( lhs.into_data().as_slice::().unwrap(), rhs.into_data().as_slice::().unwrap() ); } #[test] fn test_2d_iter_dim() { let tensor = TestTensor::<2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &Default::default()); let mut iter = tensor.iter_dim(0); let iter1 = iter.next().unwrap(); iter1 .into_data() .assert_eq(&TensorData::from([[3.0, 4.9, 2.0]]), false); let iter2 = iter.next().unwrap(); iter2 .into_data() .assert_eq(&TensorData::from([[2.0, 1.9, 3.0]]), false); assert!(iter.next().is_none()); } #[test] fn test_2d_iter_dim1() { let tensor = TestTensor::<2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &Default::default()); let mut iter = tensor.iter_dim(1); let iter1 = iter.next().unwrap(); iter1 .into_data() .assert_eq(&TensorData::from([[3.0], [2.0]]), false); let iter2 = iter.next().unwrap(); iter2 .into_data() .assert_eq(&TensorData::from([[4.9], [1.9]]), false); let iter3 = iter.next().unwrap(); iter3 .into_data() .assert_eq(&TensorData::from([[2.0], [3.0]]), false); assert!(iter.next().is_none()); } #[test] fn test_3d_iter_dim() { let tensor = TestTensor::<3>::from([[ [1., 2., 3., 1., 2.], [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]]); let mut iter = tensor.clone().iter_dim(0); let iter1 = iter.next().unwrap(); iter1.into_data().assert_eq(&tensor.into_data(), true); assert!(iter.next().is_none()); } #[test] fn test_3d_iter_dim1() { let tensor = TestTensor::<3>::from([[ [1., 2., 3., 1., 2.], [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]]); let mut iter = tensor.iter_dim(1); let iter1 = iter.next().unwrap(); iter1 .into_data() .assert_eq(&TensorData::from([[[1., 2., 3., 1., 2.]]]), false); let iter2 = iter.next().unwrap(); iter2 .into_data() .assert_eq(&TensorData::from([[[4., 5., 6., 1., 2.]]]), false); let iter3 = iter.next().unwrap(); iter3 .into_data() .assert_eq(&TensorData::from([[[7., 8., 9., 1., 2.]]]), false); assert!(iter.next().is_none()); } #[test] fn test_3d_iter_dim2() { let tensor = TestTensor::<3>::from([[ [1., 2., 3., 1., 2.], [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]]); let mut iter = tensor.iter_dim(2); let iter1 = iter.next().unwrap(); iter1 .into_data() .assert_eq(&TensorData::from([[[1.], [4.], [7.]]]), false); let iter2 = iter.next().unwrap(); iter2 .into_data() .assert_eq(&TensorData::from([[[2.], [5.], [8.]]]), false); let iter3 = iter.next().unwrap(); iter3 .into_data() .assert_eq(&TensorData::from([[[3.], [6.], [9.]]]), false); let iter4 = iter.next().unwrap(); iter4 .into_data() .assert_eq(&TensorData::from([[[1.], [1.], [1.]]]), false); let iter5 = iter.next().unwrap(); iter5 .into_data() .assert_eq(&TensorData::from([[[2.], [2.], [2.]]]), false); assert!(iter.next().is_none()); } #[test] fn test_iteration_over_low_dim() { let data = [[ [1., 2., 3., 1., 2.], [4., 5., 6., 1., 2.], [7., 8., 9., 1., 2.], ]]; let tensor = TestTensor::<3>::from_floats(data, &Default::default()); let lhs = tensor.iter_dim(2).nth(1).unwrap(); let rhs = TestTensor::<1>::from([2., 5., 8.]); assert_eq!( lhs.into_data().as_slice::().unwrap(), rhs.into_data().as_slice::().unwrap() ); } #[test] fn test_iter_dim_double_end() { let input = TestTensorInt::<1>::arange(0..(4 * 6 * 3), &Default::default()).reshape([4, 6, 3]); let mut iter = input.iter_dim(1); let ele0 = TensorData::from([[[0, 1, 2]], [[18, 19, 20]], [[36, 37, 38]], [[54, 55, 56]]]); let ele1 = TensorData::from([[[3, 4, 5]], [[21, 22, 23]], [[39, 40, 41]], [[57, 58, 59]]]); let ele2 = TensorData::from([[[6, 7, 8]], [[24, 25, 26]], [[42, 43, 44]], [[60, 61, 62]]]); let ele3 = TensorData::from([ [[9, 10, 11]], [[27, 28, 29]], [[45, 46, 47]], [[63, 64, 65]], ]); let ele4 = TensorData::from([ [[12, 13, 14]], [[30, 31, 32]], [[48, 49, 50]], [[66, 67, 68]], ]); let ele5 = TensorData::from([ [[15, 16, 17]], [[33, 34, 35]], [[51, 52, 53]], [[69, 70, 71]], ]); iter.next().unwrap().into_data().assert_eq(&ele0, false); iter.next_back() .unwrap() .into_data() .assert_eq(&ele5, false); iter.next_back() .unwrap() .into_data() .assert_eq(&ele4, false); iter.next().unwrap().into_data().assert_eq(&ele1, false); iter.next().unwrap().into_data().assert_eq(&ele2, false); iter.next().unwrap().into_data().assert_eq(&ele3, false); assert!(iter.next().is_none()); assert!(iter.next_back().is_none()); } #[test] fn test_iter_dim_single_element() { let input = TestTensorInt::<1>::arange(0..(4 * 3), &Default::default()).reshape([4, 1, 3]); let mut iter = input.clone().iter_dim(1); iter.next() .unwrap() .into_data() .assert_eq(&input.clone().into_data(), false); assert!(iter.next_back().is_none()); assert!(iter.next().is_none()); let mut iter = input.clone().iter_dim(1); iter.next_back() .unwrap() .into_data() .assert_eq(&input.clone().into_data(), false); assert!(iter.next().is_none()); assert!(iter.next_back().is_none()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/log.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_log_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.log(); let expected = TensorData::from([ [-f32::INFINITY, 0.0, core::f32::consts::LN_2], [1.09861, 1.38629, 1.60944], ]); output.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_relative(1e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/log1p.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_exp_log1p() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.log1p(); let expected = TensorData::from([ [0.0, core::f32::consts::LN_2, 1.09861], [1.38629, 1.60944, 1.79176], ]); output.into_data().assert_approx_eq::( &expected, Tolerance::default().set_half_precision_relative(1e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/mask.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_mask_fill_swap_dims() { let device = Default::default(); let tensor_1 = TestTensorInt::arange(0..16, &device).float(); let tensor_1 = tensor_1.reshape([2, 2, 4]); let tensor_1 = tensor_1.swap_dims(0, 2); let mask = tensor_1.clone().lower_equal_elem(5.0); let output = tensor_1.clone().mask_fill(mask, -5.0); let expected = TensorData::from([ [[-5.0, 8.0], [-5.0, 12.0]], [[-5.0, 9.0], [-5.0, 13.0]], [[-5.0, 10.0], [6.0, 14.0]], [[-5.0, 11.0], [7.0, 15.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mask_where_ops() { let device = Default::default(); let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let value = TestTensor::<2>::from_data(TensorData::from([[1.8, 2.8], [3.8, 4.8]]), &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mask_where_broadcast() { let device = Default::default(); // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]); let mask = TestTensorBool::<3>::from_bool( TensorData::from([ [[true, false], [false, true]], [[false, true], [true, false]], [[false, false], [false, false]], [[true, true], [true, true]], ]), &device, ); let value = TestTensor::<3>::ones([4, 2, 2], &device); let output = tensor.float().mask_where(mask, value); let expected = TensorData::from([ [[1., 3.], [4., 1.]], [[2., 1.], [1., 5.]], [[2., 3.], [4., 5.]], [[1., 1.], [1., 1.]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mask_where_broadcast_value_small() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(2..4, &device).float(); let mask = TestTensorBool::<1>::from_bool(TensorData::from([true, false]), &device); let value = TestTensor::<1>::ones([1], &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([1., 3.]); output.into_data().assert_eq(&expected, false); } #[test] fn should_handle_mask_where_nans() { let device = Default::default(); let tensor = TestTensor::from_data( [ [f32::NAN, f32::NAN, f32::NAN], [f32::NAN, f32::NAN, f32::NAN], [f32::NAN, f32::NAN, f32::NAN], ], &device, ); let mask = TestTensorBool::<2>::from_bool( TensorData::from([ [true, true, true], [true, true, false], [false, false, false], ]), &device, ); let value = TestTensor::<2>::from_data( TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]), &device, ); let output = tensor.mask_where(mask, value); let expected = TensorData::from([ [0.9, 0.8, 0.7], [0.6, 0.5, f32::NAN], [f32::NAN, f32::NAN, f32::NAN], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_mask_fill_ops() { let device = Default::default(); let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let output = tensor.mask_fill(mask, 2.0); let expected = TensorData::from([[2.0, 7.0], [2.0, 2.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mask_fill_broadcasted() { let device = Default::default(); let tensor = TestTensor::zeros([1, 4, 2, 2], &device); let mask = TestTensorBool::<4>::from_bool( TensorData::from([[[[true, false], [false, true]]]]), &device, ); let output = tensor.mask_fill(mask, 2.0); let expected = TensorData::from([[ [[2., 0.], [0., 2.]], [[2., 0.], [0., 2.]], [[2., 0.], [0., 2.]], [[2., 0.], [0., 2.]], ]]); output.into_data().assert_eq(&expected, false); } #[test] fn float_mask_fill_infinite() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [f32::NEG_INFINITY, f32::NEG_INFINITY], [f32::NEG_INFINITY, f32::NEG_INFINITY], ], &device, ); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let output = tensor.mask_fill(mask, 10.0f32); let expected = TensorData::from([[10f32, f32::NEG_INFINITY], [f32::NEG_INFINITY, 10f32]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/matmul.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::{ElementConversion, Tolerance, backend::Backend}; #[test] fn test_float_matmul_d2() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]], &device); let tensor_2 = TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_d3() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_floats([[[1.0, 7.0], [2.0, 3.0]]], &device); let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_broadcast_1() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_floats([[[1.0, 7.0], [2.0, 3.0]]], &device); let tensor_2 = TestTensor::from_floats( [[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]], &device, ); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_broadcast_4d() { let device = Default::default(); // [2, 1, 2, 2] let tensor_1 = TestTensor::<4>::from_floats( [[[[1.0, 7.0], [2.0, 3.0]]], [[[2.0, 5.0], [6.0, 3.0]]]], &device, ); // [1, 2, 2, 2] let tensor_2 = TestTensor::from_floats( [[[[9.0, 8.0], [1.0, 4.0]], [[2.0, 7.0], [3.0, 5.0]]]], &device, ); // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [[[16.0, 36.0], [21.0, 28.0]], [[23.0, 42.0], [13.0, 29.0]]], [[[23.0, 36.0], [57.0, 60.0]], [[19.0, 39.0], [21.0, 57.0]]], ]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_simple_1() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats([[5.0, 14.0], [14.0, 50.0]], &device); let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_4_3() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats( [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]], &device, ); let tensor_2 = TestTensor::from_floats( [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.], [12., 13., 14.]], &device, ); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[56., 62., 68.], [152., 174., 196.], [248., 286., 324.]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_batch_vec_mat() { let device = Default::default(); // [..., B, 1, K] = [3, 1, 2] let tensor_1 = TestTensor::<3>::from_floats([[[1.0, 7.0]], [[2.0, 3.0]], [[1.0, 5.0]]], &device); // [..., 1, K, N] = [1, 2, 3] let tensor_2 = TestTensor::<3>::from_floats([[[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]], &device); let tensor_3 = tensor_1.matmul(tensor_2); // [..., B, 1, N] = [3, 1, 3] let expected = TensorData::from([ [[18.0, 28.0, 40.0]], [[14.0, 23.0, 25.0]], [[14.0, 22.0, 30.0]], ]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_trivial() { let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..16, &device) .reshape([4, 4]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1); tensor_3.into_data().assert_approx_eq::( &TensorData::from([ [56., 62., 68., 74.], [152., 174., 196., 218.], [248., 286., 324., 362.], [344., 398., 452., 506.], ]), Tolerance::default(), ); } #[test] fn test_float_matmul_trivial_transposed() { let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..16, &device) .reshape([4, 4]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); tensor_3.into_data().assert_approx_eq::( &TensorData::from([ [14., 38., 62., 86.], [38., 126., 214., 302.], [62., 214., 366., 518.], [86., 302., 518., 734.], ]), Tolerance::default(), ); } /// Regression test for batch bug in fused matmul #[test] fn test_float_matmul_vecmat_transposed_fused() { let device = Default::default(); let batch1 = 1; let batch2 = 2; let batch = batch1 * batch2; let seq_length = 3; let d_model = 32; // Guard int arange limits #[allow(clippy::unnecessary_cast)] if (IntElem::MAX as i64) < seq_length * d_model * batch { return; } if FloatElem::MAX.elem::() < 269493.0 { return; } let weight: TestTensor<4> = TestTensorInt::arange(0..d_model * batch, &device) .reshape([batch1, batch2, 1, d_model]) .float(); let signal: TestTensor<4> = TestTensorInt::arange(0..seq_length * d_model * batch, &device) .reshape([batch1, batch2, seq_length, d_model]) .float(); TestBackend::sync(&device).unwrap(); let weight = weight.transpose(); let out = signal.matmul(weight) + 5; let expected = TensorData::from([[ [[10421.0], [26293.0], [42165.0]], [[172213.0], [220853.0], [269493.0]], ]]); expected.assert_approx_eq(&out.into_data(), Tolerance::::strict()); } #[test] fn test_float_matmul_4_8() { let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..32, &device) .reshape([4, 8]) .float(); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); tensor_3.into_data().assert_approx_eq::( &TensorData::from([ [140., 364., 588., 812.], [364., 1100., 1836., 2572.], [588., 1836., 3084., 4332.], [812., 2572., 4332., 6092.], ]), Tolerance::default(), ); } #[test] fn test_float_matmul_simple_2() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &device); let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[50.0]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_float_matmul_simple_3() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats( [[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]], &device, ); let tensor_2 = TestTensor::from_floats( [[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]], &device, ); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [9., 18., 27., 36.], [12., 24., 36., 48.], [15., 30., 45., 60.], [18., 36., 54., 72.], ]); tensor_3.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_should_panic_when_inner_dimensions_are_not_equal() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]], &device); let tensor_2 = TestTensor::from_floats( [[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]], &device, ); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [9., 18., 27., 36.], [12., 24., 36., 48.], [15., 30., 45., 60.], [18., 36., 54., 72.], ]); tensor_3.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/maxmin.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_max_dim_2d() { let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); f.clone() .max_dim(0) .into_data() .assert_eq(&TensorData::from([[3., 4., 5.]]), false); f.clone() .max_dim(1) .into_data() .assert_eq(&TensorData::from([[2.], [5.]]), false); // Negative Index f.clone() .max_dim(-1) .into_data() .assert_eq(&TensorData::from([[2.], [5.]]), false); // Regression Test: https://github.com/tracel-ai/burn/issues/3139 let z = f.clone().int(); z.clone() .max_dim(0) .into_data() .assert_eq(&TensorData::from([[3, 4, 5]]), false); z.clone() .max_dim(1) .into_data() .assert_eq(&TensorData::from([[2], [5]]), false); } #[test] fn test_max_dims_2d() { let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); f.clone() .max_dims(&[0]) .into_data() .assert_eq(&TensorData::from([[3., 4., 5.]]), false); f.clone() .max_dims(&[-2]) .into_data() .assert_eq(&TensorData::from([[3., 4., 5.]]), false); f.clone() .max_dims(&[0, 1]) .into_data() .assert_eq(&TensorData::from([[5.]]), false); } #[test] fn test_max_dim_with_indices_2d_with_dim_0th() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); // Positive, Negative Index for idx in [0, -2] { let (output, index) = tensor.clone().max_dim_with_indices(idx); let output_expected = TensorData::from([[3., 4., 5.]]); let index_expected = TensorData::from([[1, 1, 1]]); output.into_data().assert_eq(&output_expected, false); index.into_data().assert_eq(&index_expected, false); } } #[test] fn test_max_dim_with_indices_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let (output, index) = tensor.max_dim_with_indices(1); let output_expected = TensorData::from([[2.], [5.]]); let index_expected = TensorData::from([[2], [2]]); output.into_data().assert_eq(&output_expected, false); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_max_dim_2d_with_0th_dim() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let output = tensor.max_dim(0); let expected = TensorData::from([[3., 4., 5.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_max_pair() { let a = TestTensor::<1>::from_floats([1.0, 2.0, 3.0, 4.0], &Default::default()); let b = TestTensor::from_floats([2.0, 1.0, 4.0, 5.0], &Default::default()); let output = a.max_pair(b); let expected = TensorData::from([2.0, 2.0, 4.0, 5.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_min_dim_2d() { let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); f.clone() .min_dim(0) .into_data() .assert_eq(&TensorData::from([[0., 1., 2.]]), false); f.clone() .min_dim(1) .into_data() .assert_eq(&TensorData::from([[0.], [3.]]), false); // Negative Index f.clone() .min_dim(-1) .into_data() .assert_eq(&TensorData::from([[0.], [3.]]), false); // Regression Test: https://github.com/tracel-ai/burn/issues/3139 let z = f.int(); z.clone() .min_dim(0) .into_data() .assert_eq(&TensorData::from([[0, 1, 2]]), false); z.clone() .min_dim(1) .into_data() .assert_eq(&TensorData::from([[0], [3]]), false); } #[test] fn test_min_dims_2d() { let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); f.clone() .min_dims(&[0]) .into_data() .assert_eq(&TensorData::from([[0., 1., 2.]]), false); f.clone() .min_dims(&[-2]) .into_data() .assert_eq(&TensorData::from([[0., 1., 2.]]), false); f.clone() .min_dims(&[0, 1]) .into_data() .assert_eq(&TensorData::from([[0.]]), false); } #[test] fn test_min_dim_with_indices_2d() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let (output, index) = tensor.min_dim_with_indices(1); let output_expected = TensorData::from([[0.], [3.]]); let index_expected = TensorData::from([[0], [0]]); output.into_data().assert_eq(&output_expected, false); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_min_dim_2d_with_0th_dim() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); let output = tensor.min_dim(0); let expected = TensorData::from([[0., 1., 2.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_min_dim_with_indices_2d_with_0th_dim() { let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default()); // Positive, Negative Index for idx in [0, -2] { let (output, index) = tensor.clone().min_dim_with_indices(idx); let output_expected = TensorData::from([[0., 1., 2.]]); let index_expected = TensorData::from([[0, 0, 0]]); output.into_data().assert_eq(&output_expected, false); index.into_data().assert_eq(&index_expected, false); } } #[test] fn test_min_pair() { let a = TestTensor::<1>::from_floats([1.0, 2.0, 3.0, 4.0], &Default::default()); let b = TestTensor::from_floats([2.0, 1.0, 4.0, 5.0], &Default::default()); let output = a.min_pair(b); let expected = TensorData::from([1.0, 1.0, 3.0, 4.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_max_abs() { let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default()); let output = tensor.max_abs(); let expected = TensorData::from([6.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_max_abs_dim_2d_dim_0() { let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default()); let output = tensor.clone().max_abs_dim(0); let expected = TensorData::from([[5., 6., 2.]]); output.into_data().assert_eq(&expected, false); // Negative Index let output = tensor.clone().max_abs_dim(-2); let expected = TensorData::from([[5., 6., 2.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_max_abs_dims_2d() { let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default()); tensor .clone() .max_abs_dims(&[0]) .into_data() .assert_eq(&TensorData::from([[5., 6., 2.]]), false); tensor .clone() .max_abs_dims(&[-2]) .into_data() .assert_eq(&TensorData::from([[5., 6., 2.]]), false); tensor .clone() .max_abs_dims(&[0, 1]) .into_data() .assert_eq(&TensorData::from([[6.]]), false); } #[test] fn test_max_abs_dim_2d_dim_1() { let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default()); let output = tensor.max_abs_dim(1); let expected = TensorData::from([[2.], [6.]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/mod.rs ================================================ use super::*; mod abs; mod add; mod aggregation; mod all; mod any; mod arg; mod cast; mod cat; mod ceil; mod chunk; mod clamp; mod close; mod comparison; mod create_like; mod cross; mod cumulative; mod div; mod dot; mod erf; mod exp; mod expand; mod finite; mod flatten; mod flip; mod floor; mod fmod; mod full; mod gather_scatter; mod grid_sample; mod inf; mod init; mod iter_dim; mod log; mod log1p; mod mask; mod matmul; mod maxmin; mod movedim; mod mul; mod nan; mod narrow; mod neg; mod one_hot; mod padding; mod permute; mod powf; mod powf_scalar; mod prod; mod random; mod recip; mod remainder; mod repeat; mod repeat_dim; mod reshape; mod round; mod select; mod sign; mod slice; mod slice_assign; mod sort_argsort; mod split; mod sqrt; mod square; mod squeeze; mod stack; mod sub; mod take; mod topk; mod transaction; mod transpose; mod tri; mod trig; mod trunc; mod unfold; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/movedim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn movedim_float() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); let permuted = tensor.clone().movedim(0, 2); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float() let expected = TensorData::from([ [[0., 12.], [1., 13.], [2., 14.], [3., 15.]], [[4., 16.], [5., 17.], [6., 18.], [7., 19.]], [[8., 20.], [9., 21.], [10., 22.], [11., 23.]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().movedim(0, -1); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().movedim(0, 0); permuted.into_data().assert_eq(&tensor.into_data(), true); } #[test] fn vec_input_float() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]); // from pytorch // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).float() let expected = TensorData::from([ [[0., 1., 2., 3.], [12., 13., 14., 15.]], [[4., 5., 6., 7.], [16., 17., 18., 19.]], [[8., 9., 10., 11.], [20., 21., 22., 23.]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axes let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]); permuted.into_data().assert_eq(&expected, false); // Test with the same axes let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]); permuted.into_data().assert_eq(&tensor.into_data(), true); } #[test] fn different_input_types() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); let permuted = tensor.clone().movedim(0_usize, 2_i32); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float() let expected = TensorData::from([ [[0., 12.], [1., 13.], [2., 14.], [3., 15.]], [[4., 16.], [5., 17.], [6., 18.], [7., 19.]], [[8., 20.], [9., 21.], [10., 22.], [11., 23.]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().movedim(0_usize, -1); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().movedim(0_i32, 0_usize); permuted.into_data().assert_eq(&tensor.into_data(), true); } #[test] #[should_panic] fn edge_different_sizes() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().movedim(vec![0, 1], vec![0]); } #[test] #[should_panic] fn edge_out_of_bound_axis() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().movedim(0, 100); } #[test] #[should_panic] fn edge_vec_is_not_a_set() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().movedim(vec![0, 1, 1, 1, 1], vec![0, 0, 1]); } #[test] #[should_panic] fn edge_out_of_bound_axis_vec() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with an out of bound axis let _ = tensor.clone().movedim(vec![0, 100], vec![0, 1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/mul.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_mul_ops() { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_mul_broadcast() { let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_mul_broadcast_2_dims() { let device = Default::default(); let tensor_1 = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device).reshape([3, 1]); let tensor_2 = TestTensor::<1>::from_data([3.0, 4.0, 5.0], &device).reshape([1, 3]); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mul_scalar_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor * scalar; let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/nan.rs ================================================ use super::*; use burn_tensor::cast::ToElement; #[test] fn is_nan() { let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let no_nan_expected = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]); let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]); let with_nan_expected = TestTensorBool::<2>::from([[false, true, false], [true, false, false]]); assert_eq!(no_nan_expected.into_data(), no_nan.is_nan().into_data()); assert_eq!(with_nan_expected.into_data(), with_nan.is_nan().into_data()); } #[test] fn contains_nan() { let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); assert!(!no_nan.contains_nan().into_scalar().to_bool()); let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]); assert!(with_nan.contains_nan().into_scalar().to_bool()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/narrow.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{Shape, TensorData}; #[test] fn test_narrow_1() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let output = tensor.clone().narrow(0, 0, 2); let expected = TensorData::from([[1., 2., 3.], [4., 5., 6.]]); assert_eq!(output.shape(), Shape::from([2, 3])); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_narrow_2() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let output = tensor.clone().narrow(1, 1, 2); let expected = TensorData::from([[2., 3.], [5., 6.], [8., 9.]]); assert_eq!(output.shape(), Shape::from([3, 2])); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_narrow_3() { let device = &Default::default(); let shape = Shape::new([8, 8]); let tensor = TestTensorInt::arange(0..shape.num_elements() as i64, device) .reshape::<2, _>(shape) .float(); let output = tensor.clone().narrow(0, 3, 4); let expected = TensorData::from([ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0], [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0], [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[should_panic] fn test_narrow_invalid_dim() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let _output = tensor.narrow(2, 0, 2); } #[test] #[should_panic] fn test_narrow_invalid_start() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let _output = tensor.narrow(0, 3, 2); } #[test] #[should_panic] fn test_narrow_invalid_zero_length() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let _output = tensor.narrow(0, 1, 0); } #[test] #[should_panic] fn test_narrow_invalid_length() { let tensor = TestTensor::<2>::from_data( TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), &Default::default(), ); let _output = tensor.narrow(0, 0, 4); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/neg.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_neg_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.neg(); let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::(); // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32 assert_eq!( output .into_data() .convert::() .as_slice::() .unwrap(), expected.as_slice::().unwrap() ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/one_hot.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn float_should_support_one_hot() { let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); let one_hot_tensor: TestTensor<2> = tensor.one_hot(5); let expected = TensorData::from([ [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0], ]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] fn float_should_support_one_hot_index() { let tensor = TestTensor::<1>::from([2.0]); let one_hot_tensor: TestTensor<2> = tensor.one_hot::<2>(10); let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { let tensor = TestTensor::<1>::from([5.0]); let _result: TestTensor<2> = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { let tensor = TestTensor::<1>::from([0.0]); let _result: TestTensor<2> = tensor.one_hot(0); } #[test] fn one_hot_fill_with_negative_axis_and_indices() { let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); let expected = TensorData::from([ [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]], ]); let one_hot_tensor: TestTensor<3> = tensor.one_hot_fill(3, 5.0, 0.0, -1); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] fn one_hot_fill_with_negative_indices() { let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); let expected = TensorData::from([ [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], ]); let one_hot_tensor: TestTensor<2> = tensor.one_hot_fill(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, false); } #[should_panic] #[test] fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); let _one_hot_tensor: TestTensor<3> = tensor.one_hot_fill(2, 5.0, 0.0, 3); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/padding.rs ================================================ use super::*; use burn_tensor::{TensorData, ops::PadMode}; #[test] fn padding_constant_2d_test() { let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]; let tensor = TestTensor::<2>::from(unpadded_floats); let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1); let expected = TensorData::from([ [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1], [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_4d_test() { let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; let tensor = TestTensor::<4>::from(unpadded_floats); let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1); let expected = TensorData::from([[[ [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 1.1, 1.1], [1.1, 1.1, 2.0, 3.0, 1.1, 1.1], [1.1, 1.1, 4.0, 5.0, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1], ]]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_asymmetric_test() { let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; let tensor = TestTensor::<4>::from(unpadded_floats); let padded_tensor = tensor.pad((2, 1, 4, 3), 1.1); let expected = TensorData::from([[[ [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 1.1], [1.1, 1.1, 2.0, 3.0, 1.1], [1.1, 1.1, 4.0, 5.0, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1], ]]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_2d_test() { // Test reflect padding on a 2D tensor // Input: [[1, 2, 3], [4, 5, 6]] // With padding (1, 1, 1, 1): // - Top: reflect row 1 -> [4, 5, 6] // - Bottom: reflect row 0 -> [1, 2, 3] // - Left: reflect col 1 // - Right: reflect col 1 let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect); // Expected: reflect excludes the edge value // Before padding height: [[1,2,3], [4,5,6]] // After top pad (reflect row at index 1): [[4,5,6], [1,2,3], [4,5,6]] // After bottom pad (reflect row at index 1 from end): [[4,5,6], [1,2,3], [4,5,6], [1,2,3]] // Then pad width similarly let expected = TensorData::from([ [5.0, 4.0, 5.0, 6.0, 5.0], [2.0, 1.0, 2.0, 3.0, 2.0], [5.0, 4.0, 5.0, 6.0, 5.0], [2.0, 1.0, 2.0, 3.0, 2.0], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_width_only_test() { // Test reflect padding on width dimension only let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); let padded_tensor = tensor.pad((2, 2, 0, 0), PadMode::Reflect); // Input: [1, 2, 3, 4] // Reflect left 2: take indices [1, 2] = [2, 3], flip = [3, 2] // Reflect right 2: take indices [1, 2] from end = [2, 3], flip = [3, 2] // Result: [3, 2, 1, 2, 3, 4, 3, 2] let expected = TensorData::from([[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_4d_test() { // Test reflect padding on 4D tensor (common for images: NCHW) let tensor = TestTensor::<4>::from([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]); let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect); let expected = TensorData::from([[[ [5.0, 4.0, 5.0, 6.0, 5.0], [2.0, 1.0, 2.0, 3.0, 2.0], [5.0, 4.0, 5.0, 6.0, 5.0], [8.0, 7.0, 8.0, 9.0, 8.0], [5.0, 4.0, 5.0, 6.0, 5.0], ]]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_edge_2d_test() { // Test edge padding on a 2D tensor let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge); // Edge padding replicates the boundary values let expected = TensorData::from([ [1.0, 1.0, 2.0, 3.0, 3.0], [1.0, 1.0, 2.0, 3.0, 3.0], [4.0, 4.0, 5.0, 6.0, 6.0], [4.0, 4.0, 5.0, 6.0, 6.0], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_edge_width_only_test() { // Test edge padding on width dimension only let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); let padded_tensor = tensor.pad((2, 3, 0, 0), PadMode::Edge); // Input: [1, 2, 3, 4] // Left 2: [1, 1] // Right 3: [4, 4, 4] // Result: [1, 1, 1, 2, 3, 4, 4, 4, 4] let expected = TensorData::from([[1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 4.0]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_edge_4d_test() { // Test edge padding on 4D tensor let tensor = TestTensor::<4>::from([[[[1.0, 2.0], [3.0, 4.0]]]]); let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge); let expected = TensorData::from([[[ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ]]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_default_test() { // Test default PadMode (Constant with 0.0) let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::default()); let expected = TensorData::from([ [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 3.0, 4.0, 0.0], [0.0, 0.0, 0.0, 0.0], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_max_valid_test() { // Test reflect padding at maximum valid size (dim_size - 1) // For a 4-element dimension, max valid padding is 3 let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); // Padding of 3 on left is valid for width=4 (3 < 4) let padded_tensor = tensor.pad((3, 3, 0, 0), PadMode::Reflect); // Input: [1, 2, 3, 4] // Reflect left 3: take indices [1, 2, 3] = [2, 3, 4], flip = [4, 3, 2] // Reflect right 3: take indices [0, 1, 2] = [1, 2, 3], flip = [3, 2, 1] // Result: [4, 3, 2, 1, 2, 3, 4, 3, 2, 1] let expected = TensorData::from([[4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_asymmetric_test() { // Test asymmetric reflect padding let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); // Asymmetric padding: left=2, right=1, top=1, bottom=2 let padded_tensor = tensor.pad((2, 1, 1, 2), PadMode::Reflect); let expected = TensorData::from([ [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], [3.0, 2.0, 1.0, 2.0, 3.0, 2.0], [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], [9.0, 8.0, 7.0, 8.0, 9.0, 8.0], [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], [3.0, 2.0, 1.0, 2.0, 3.0, 2.0], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic(expected = "Reflect padding")] fn padding_reflect_exceeds_dimension_test() { // Test that reflect padding panics when padding >= dim_size let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0]]); // Padding of 3 on width=3 should panic (3 >= 3, need padding < dim_size) let _ = tensor.pad((3, 0, 0, 0), PadMode::Reflect); } #[test] fn padding_edge_asymmetric_test() { // Test asymmetric edge padding let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // Asymmetric padding: left=2, right=1, top=3, bottom=1 let padded_tensor = tensor.pad((2, 1, 3, 1), PadMode::Edge); let expected = TensorData::from([ [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], [4.0, 4.0, 4.0, 5.0, 6.0, 6.0], [4.0, 4.0, 4.0, 5.0, 6.0, 6.0], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_zero_padding_test() { // Test that zero padding returns the original tensor unchanged let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let padded_constant = tensor.clone().pad((0, 0, 0, 0), PadMode::Constant(5.0)); let padded_reflect = tensor.clone().pad((0, 0, 0, 0), PadMode::Reflect); let padded_edge = tensor.clone().pad((0, 0, 0, 0), PadMode::Edge); let expected = TensorData::from([[1.0, 2.0], [3.0, 4.0]]); padded_constant.into_data().assert_eq(&expected, false); padded_reflect.into_data().assert_eq(&expected, false); padded_edge.into_data().assert_eq(&expected, false); } #[test] fn padding_empty_tensor_constant_test() { // Test constant padding on an empty tensor (zero-sized dimension) // This should work - creates a tensor filled with the constant value let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); // Padding an empty height dimension with constant should create a tensor of just padding let padded = tensor.pad((0, 0, 2, 2), 1.0); // Result should be 4x3 (0 + 2 + 2 = 4 rows) assert_eq!(padded.dims(), [4, 3]); let expected = TensorData::from([ [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], ]); padded.into_data().assert_eq(&expected, false); } #[test] #[should_panic(expected = "edge padding")] fn padding_empty_tensor_edge_panics_test() { // Test that edge padding panics on empty tensor let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); // Edge padding on zero-sized dimension should panic let _ = tensor.pad((0, 0, 1, 1), PadMode::Edge); } #[test] #[should_panic(expected = "Reflect padding")] fn padding_empty_tensor_reflect_panics_test() { // Test that reflect padding panics on empty tensor let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); // Reflect padding on zero-sized dimension should panic let _ = tensor.pad((0, 0, 1, 1), PadMode::Reflect); } // --- Tests for N-dimensional padding using (before, after) pairs --- #[test] fn padding_constant_pairs_2d_test() { // Same as padding_constant_2d_test but using the new pairs API let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); // [(row_before, row_after), (col_before, col_after)] let padded_tensor = tensor.pad([(2, 2), (2, 2)], 1.1); let expected = TensorData::from([ [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1], [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1], ]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_single_dim_test() { // Pad only the last dimension let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let padded_tensor = tensor.pad([(1, 1)], 0.0); let expected = TensorData::from([[0.0, 1.0, 2.0, 0.0], [0.0, 3.0, 4.0, 0.0]]); padded_tensor.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_all_dims_4d_test() { // Pad all 4 dimensions of a 4D tensor (batch, channel, height, width) // Input: shape [1, 1, 2, 2] let tensor = TestTensor::<4>::from([[[[1.0, 2.0], [3.0, 4.0]]]]); // Pad: batch(1,1), channel(1,1), height(0,0), width(0,0) let padded = tensor.pad([(1, 1), (1, 1), (0, 0), (0, 0)], 0.0); // Shape should be [3, 3, 2, 2] assert_eq!(padded.dims(), [3, 3, 2, 2]); let expected = TensorData::from([ [ [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], ], [ [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]], ], [ [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], ], ]); padded.into_data().assert_eq(&expected, false); } #[test] fn padding_constant_batch_dim_only_test() { // Pad only the batch dimension of a 3D tensor [N, H, W] let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]]]); // 3 pairs for 3 dims: batch(1,1), height(0,0), width(0,0) let padded = tensor.pad([(1, 1), (0, 0), (0, 0)], -1.0); assert_eq!(padded.dims(), [3, 2, 2]); let expected = TensorData::from([ [[-1.0, -1.0], [-1.0, -1.0]], [[1.0, 2.0], [3.0, 4.0]], [[-1.0, -1.0], [-1.0, -1.0]], ]); padded.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_pairs_test() { // Reflect padding using pairs API let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); let padded = tensor.pad([(1, 1), (1, 1)], PadMode::Reflect); let expected = TensorData::from([ [5.0, 4.0, 5.0, 6.0, 5.0], [2.0, 1.0, 2.0, 3.0, 2.0], [5.0, 4.0, 5.0, 6.0, 5.0], [8.0, 7.0, 8.0, 9.0, 8.0], [5.0, 4.0, 5.0, 6.0, 5.0], ]); padded.into_data().assert_eq(&expected, false); } #[test] fn padding_edge_pairs_test() { // Edge padding using pairs API let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); let padded = tensor.pad([(1, 1), (1, 1)], PadMode::Edge); let expected = TensorData::from([ [1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0], ]); padded.into_data().assert_eq(&expected, false); } #[test] fn padding_reflect_batch_dim_3d_test() { // Reflect pad the batch dimension of a 3D tensor [N, H, W] // Input shape: [3, 1, 2] - 3 batches, 1 row, 2 cols let tensor = TestTensor::<3>::from([[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]); // Pad batch dim with reflect(1, 1), no spatial padding let padded = tensor.pad([(1, 1), (0, 0), (0, 0)], PadMode::Reflect); assert_eq!(padded.dims(), [5, 1, 2]); // Reflect on batch: [3,4] [1,2] [3,4] [5,6] [3,4] let expected = TensorData::from([ [[3.0, 4.0]], [[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]], [[3.0, 4.0]], ]); padded.into_data().assert_eq(&expected, false); } #[test] fn padding_edge_batch_dim_3d_test() { // Edge pad the batch dimension of a 3D tensor let tensor = TestTensor::<3>::from([[[1.0, 2.0]], [[3.0, 4.0]]]); let padded = tensor.pad([(2, 1), (0, 0), (0, 0)], PadMode::Edge); assert_eq!(padded.dims(), [5, 1, 2]); let expected = TensorData::from([ [[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]], [[3.0, 4.0]], [[3.0, 4.0]], ]); padded.into_data().assert_eq(&expected, false); } #[test] #[should_panic(expected = "Padding has")] fn padding_too_many_pairs_panics_test() { let tensor = TestTensor::<2>::from([[1.0, 2.0]]); // 3 pairs for a 2D tensor should panic let _ = tensor.pad([(1, 1), (1, 1), (1, 1)], 0.0); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/permute.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn permute_float_a() { let tensor = TestTensor::<1>::from([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., ]) .reshape([2, 2, 4]); let permuted = tensor.clone().permute([2, 1, 0]); let expected = TensorData::from([ [[0., 8.], [4., 12.]], [[1., 9.], [5., 13.]], [[2., 10.], [6., 14.]], [[3., 11.], [7., 15.]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().permute([-1, 1, 0]); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().permute([0, 1, 2]); permuted.into_data().assert_eq(&tensor.into_data(), false); } #[test] fn permute_float() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device) .reshape([2, 3, 4]) .float(); let permuted = tensor.clone().permute([2, 1, 0]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0).float() let expected = TensorData::from([ [[0., 12.], [4., 16.], [8., 20.]], [[1., 13.], [5., 17.], [9., 21.]], [[2., 14.], [6., 18.], [10., 22.]], [[3., 15.], [7., 19.], [11., 23.]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().permute([-1, 1, 0]); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().permute([0, 1, 2]); permuted.into_data().assert_eq(&tensor.into_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/powf.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_powf_ops() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]); let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_power() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]); let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_values_with_even_power() { let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[2.0, 2.0, 4.0], [4.0, 4.0, 2.0]]); let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, 1.0, 16.0], [81.0, 256.0, 25.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_values_with_odd_power() { let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let pow = TensorData::from([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]); let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default()); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_powf_broadcasted() { let device = Default::default(); let tensor_1 = TestTensor::<1>::from_floats([2.0, 3.0, 4.0], &device); let tensor_2 = TestTensor::from_floats([1.0], &device); // Broadcast rhs let output = tensor_1.clone().powf(tensor_2.clone()); output .into_data() .assert_approx_eq::(&tensor_1.to_data(), Tolerance::default()); // Broadcast lhs let output = tensor_2.powf(tensor_1); output .into_data() .assert_approx_eq::(&TensorData::from([1.0, 1.0, 1.0]), Tolerance::default()); } fn outer(a: TestTensor<1>, b: TestTensor<1>) -> TestTensor<2> { a.unsqueeze_dim::<2>(1) * b.unsqueeze_dim::<2>(0) } #[test] fn should_support_powf_scalar_tensor() { let device = Default::default(); let head_dim = 64; let seq_len = 1024; let base = 10000; let channel_range = TestTensorInt::arange_step(0..head_dim as i64, 2, &device).float(); let base = TestTensor::<1>::from_data([base as f32], &device); let inv_freq = base.powf(-channel_range / head_dim as f32); let t = TestTensorInt::arange(0..seq_len as i64, &device).float(); let freqs = outer(t, inv_freq); let _cos = freqs.clone().cos(); let _sin = freqs.sin(); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/powf_scalar.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_powf_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(0.71); let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.1815, 2.67586, 3.13522]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_power() { let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(-0.33); let expected = TensorData::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_values_with_even_power() { let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(4.0); let expected = TensorData::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_neg_values_with_odd_power() { let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.powf_scalar(3.0); let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/prod.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_prod_float() { let tensor_1 = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_1.prod(); output .into_data() .assert_eq(&TensorData::from([-600.0]), false); } #[test] fn test_prod_dim_2d() { let f = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); f.clone() .prod_dim(1) .into_data() .assert_eq(&TensorData::from([[-10.0], [60.0]]), false); f.clone() .prod_dim(-1) .into_data() .assert_eq(&TensorData::from([[-10.0], [60.0]]), false); } #[test] fn test_prod_dims_2d() { let f = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); f.clone() .prod_dims(&[1]) .into_data() .assert_eq(&TensorData::from([[-10.0], [60.0]]), false); f.clone() .prod_dims(&[-1]) .into_data() .assert_eq(&TensorData::from([[-10.0], [60.0]]), false); f.clone() .prod_dims(&[0, 1]) .into_data() .assert_eq(&TensorData::from([[-600.0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/random.rs ================================================ use super::*; use burn_tensor::{Distribution, ElementConversion, TensorData, Tolerance, backend::Backend}; #[test] fn rand_default() { let tensor = TestTensor::<1>::random([20], Distribution::Default, &Default::default()); // check that the tensor is within the range of [0..1) (1 is exclusive) // the conversion can ceil the value if `FloatElem` is less precise than f32 let low = 0.elem::(); let high = 1.elem::(); if FloatElem::EPSILON.elem::() > f32::EPSILON { tensor.into_data().assert_within_range_inclusive(low..=high); } else { tensor.into_data().assert_within_range(low..high); } } #[test] fn rand_uniform() { let tensor = TestTensor::<1>::random([20], Distribution::Uniform(4., 5.), &Default::default()); let low = 4.elem::(); let high = 5.elem::(); if FloatElem::EPSILON.elem::() > f32::EPSILON { tensor.into_data().assert_within_range_inclusive(low..=high); } else { tensor.into_data().assert_within_range(low..high); } } #[test] fn rand_bernoulli() { let tensor = TestTensor::<1>::random([20], Distribution::Bernoulli(1.), &Default::default()); tensor.into_data().assert_eq( &TensorData::new::(vec![1.elem(); 20], [20]), true, ); } #[test] #[ignore] // TODO: mark serial for backends that handle the same devices (e.g. fusion)? fn test_seed_reproducibility() { let device = Default::default(); TestBackend::seed(&device, 42); let t1 = TestTensor::<1>::random([5], Distribution::Default, &device); TestBackend::seed(&device, 42); let t2 = TestTensor::<1>::random([5], Distribution::Default, &device); t1.into_data() .assert_approx_eq::(&t2.into_data(), Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/recip.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_recip_ops() { let data = TensorData::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.recip(); let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/remainder.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; /// From https://pytorch.org/docs/stable/generated/torch.remainder.html #[test] fn should_support_remainder_basic() { let device = Default::default(); let lhs = TestTensor::<1>::from_data(TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]), &device); let rhs = TestTensor::<1>::from_data(TensorData::from([2.0, 3.0, 1.0, 2.0, 1.0, 3.0]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([1.0, 1.0, -0.0, 1.0, 0.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_remainder_basic_scalar() { let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(2.0); let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_remainder_float() { let device = Default::default(); let lhs = TestTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]), &device); let rhs = TestTensor::<1>::from_data( TensorData::from([1.4233, 2.7313, 0.2641, 1.9651, 0.5897]), &device, ); let output = lhs.remainder(rhs); let expected = TensorData::from([1.0, 2.0, 0.0949, 0.0698, 0.2824]); // Metal has less precise remainder function let tolerance = Tolerance::default() .set_half_precision_relative(1e-2) .set_half_precision_absolute(2e-3); output .into_data() .assert_approx_eq::(&expected, tolerance); } /// Also from https://pytorch.org/docs/stable/generated/torch.remainder.html #[test] fn should_support_remainder_float_scalar() { let data = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.clone().remainder_scalar(-1.5); let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_be_zero() { let device = Default::default(); let lhs = TestTensor::<1>::from_data(TensorData::from([0.0, 0.0, 0.0]), &device); let rhs = TestTensor::<1>::from_data(TensorData::from([3.5, -2.1, 1e-4]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([0.0, 0.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_be_zero_scalar() { let data = TensorData::from([0.0, 0.0, 0.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.clone().remainder_scalar(3.5); let expected = TensorData::from([0.0, 0.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_have_no_remainder() { let device = Default::default(); let lhs = TestTensor::<1>::from_data( // Previous values failed on some vulkan backends (driver bug?) // TensorData::from([-1.4843, 1.1350, -2.1563, 1.0862, 0.5, 3.6587]), TensorData::from([-1.0, 1.5, -2.0, 2.5, 0.5, 4.0]), &device, ); let rhs = TestTensor::<1>::from_data( // TensorData::from([1.4843, 1.1350, 2.1563, 1.0862, 0.5, 3.6587]), TensorData::from([1.0, 1.5, 2.0, 2.5, 0.5, 4.0]), &device, ); let output = lhs.remainder(rhs); let expected = TensorData::from([-0., 0., -0., 0., 0., 0.]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_have_no_remainder_scalar() { let data = TensorData::from([-4.0, 4.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(4.0); let expected = TensorData::from([-0.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_be_negative() { let device = Default::default(); let lhs = TestTensor::<1>::from_data(TensorData::from([-7.0, -3.0, 2.0, 6.0]), &device); let rhs = TestTensor::<1>::from_data(TensorData::from([-2.5, -2.1, -1.5, -3.25]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([-2.0, -0.9, -1.0, -0.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_be_negative_scalar() { let data = TensorData::from([-7.0, -3.0, 2.0, 6.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.clone().remainder_scalar(-2.5); let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_fp_dividends() { let data = TensorData::from([-7.5, -2.5, 2.5, 7.5]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(3.0); let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); // for tensor.remainder case, tests above have already covered float point dividend cases } #[test] fn should_support_large_divisor() { let device = Default::default(); let lhs = TestTensor::<1>::from_data( TensorData::from([-1.0, 1.0, -1.5, 1.5, -1.0, 1.0, -1.5, 1.5]), &device, ); let rhs = TestTensor::<1>::from_data( TensorData::from([10.0, 10.0, 10.0, 10.0, -10.0, -10.0, -10.0, -10.0]), &device, ); let output = lhs.remainder(rhs); let expected = TensorData::from([9.0, 1.0, 8.5, 1.5, -1.0, -9.0, -1.5, -8.5]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_large_divisor_scalar() { let data = TensorData::from([-1.0, 1.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.remainder_scalar(10.0); let expected = TensorData::from([9.0, 1.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_remainder_op() { let device = Default::default(); let lhs = TestTensor::<1>::from_data(TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]), &device); let rhs = TestTensor::<1>::from_data(TensorData::from([2.0, 3.0, 1.0, 2.0, 1.0, 3.0]), &device); let output = lhs % rhs; let expected = TensorData::from([1.0, 1.0, -0.0, 1.0, 0.0, 0.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_remainder_scalar_op() { let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor % 2.0; let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/repeat.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_repeat_ops_one_dimension() { let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_float_repeat_repeating_on_many_dimensions() { let data = TensorData::from([ [[1.0f32, 2.0f32], [3.0f32, 4.0f32]], [[5.0f32, 6.0f32], [7.0f32, 8.0f32]], [[9.0f32, 10.0f32], [11.0f32, 12.0f32]], [[13.0f32, 14.0f32], [15.0f32, 16.0f32]], ]); let tensor = TestTensor::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); let expected = TensorData::from([ [ [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], ], [ [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], ], [ [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], ], [ [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], ], [ [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], ], [ [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], ], [ [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], ], [ [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_repeat_0_times_empty() { let tensor = TestTensor::<3>::ones([2, 3, 4], &Default::default()); let output = tensor.repeat(&[1, 0, 2]); assert_eq!(output.shape(), [2, 0, 8].into()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/repeat_dim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_repeat_ops() { let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([ [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], [0.0f32, 1.0f32, 2.0f32], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_float_repeat_on_dims_larger_than_1() { let data = TensorData::from([ [[1.0f32, 2.0f32], [3.0f32, 4.0f32]], [[5.0f32, 6.0f32], [7.0f32, 8.0f32]], [[9.0f32, 10.0f32], [11.0f32, 12.0f32]], [[13.0f32, 14.0f32], [15.0f32, 16.0f32]], ]); let tensor = TestTensor::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(2, 2); let expected = TensorData::from([ [ [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], ], [ [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], ], [ [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], ], [ [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn repeat_dim_swap_dims_1() { let tensor = TestTensorInt::arange(0..16, &Default::default()).float(); let tensor = tensor.reshape([4, 1, 4]); let tensor = tensor.swap_dims(0, 2); let output = tensor.repeat_dim(1, 4); let expected = TensorData::from([ [ [0.0, 4.0, 8.0, 12.0], [0.0, 4.0, 8.0, 12.0], [0.0, 4.0, 8.0, 12.0], [0.0, 4.0, 8.0, 12.0], ], [ [1.0, 5.0, 9.0, 13.0], [1.0, 5.0, 9.0, 13.0], [1.0, 5.0, 9.0, 13.0], [1.0, 5.0, 9.0, 13.0], ], [ [2.0, 6.0, 10.0, 14.0], [2.0, 6.0, 10.0, 14.0], [2.0, 6.0, 10.0, 14.0], [2.0, 6.0, 10.0, 14.0], ], [ [3.0, 7.0, 11.0, 15.0], [3.0, 7.0, 11.0, 15.0], [3.0, 7.0, 11.0, 15.0], [3.0, 7.0, 11.0, 15.0], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn repeat_dim_swap_dims_2() { let tensor = TestTensorInt::arange(0..16, &Default::default()).float(); let tensor = tensor.reshape([2, 2, 1, 4]); let tensor = tensor.swap_dims(0, 1); let output = tensor.repeat_dim(2, 4); let expected = TensorData::from([ [ [ [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], ], [ [8.0, 9.0, 10.0, 11.0], [8.0, 9.0, 10.0, 11.0], [8.0, 9.0, 10.0, 11.0], [8.0, 9.0, 10.0, 11.0], ], ], [ [ [4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0, 7.0], ], [ [12.0, 13.0, 14.0, 15.0], [12.0, 13.0, 14.0, 15.0], [12.0, 13.0, 14.0, 15.0], [12.0, 13.0, 14.0, 15.0], ], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn repeat_dim_swap_dims_3() { let tensor = TestTensorInt::arange(0..16, &Default::default()).float(); let tensor = tensor.reshape([1, 2, 2, 4]); let tensor = tensor.swap_dims(0, 2); let tensor = tensor.swap_dims(1, 3); let output = tensor.repeat_dim(2, 4); let expected = TensorData::from([ [ [[0.0, 8.0], [0.0, 8.0], [0.0, 8.0], [0.0, 8.0]], [[1.0, 9.0], [1.0, 9.0], [1.0, 9.0], [1.0, 9.0]], [[2.0, 10.0], [2.0, 10.0], [2.0, 10.0], [2.0, 10.0]], [[3.0, 11.0], [3.0, 11.0], [3.0, 11.0], [3.0, 11.0]], ], [ [[4.0, 12.0], [4.0, 12.0], [4.0, 12.0], [4.0, 12.0]], [[5.0, 13.0], [5.0, 13.0], [5.0, 13.0], [5.0, 13.0]], [[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]], [[7.0, 15.0], [7.0, 15.0], [7.0, 15.0], [7.0, 15.0]], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_repeat_dim_0_times_empty() { let tensor = TestTensor::<3>::ones([2, 3, 4], &Default::default()); let output = tensor.repeat_dim(2, 0); assert_eq!(output.shape(), [2, 3, 0].into()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/reshape.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_rank() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); assert_eq!(tensor.rank(), 1); let data = TensorData::from([[0.0, 1.0, 2.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); assert_eq!(tensor.rank(), 2); } #[test] fn should_support_reshape_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[0.0, 1.0, 2.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_reshape_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.clone().reshape([6]); let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_dim_infererence() { let data = TensorData::from([ [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0], ]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); // Infer the dimension via -1 let reshaped = tensor.clone().reshape([2, -1]); assert_eq!(reshaped.shape(), [2, 6].into()); // Infer the dimension via 0 (keep from the source) and -1 (infer) let reshaped = reshaped.reshape([0, 2, -1]); assert_eq!(reshaped.shape(), [2, 2, 3].into()); // This is effectively as if we did a flatten let reshaped = tensor.clone().reshape([-1]); assert_eq!(reshaped.shape(), [12].into()); // Keeping the first dimension the same (using 0) let reshaped = tensor.clone().reshape([0, 3]); assert_eq!(reshaped.shape(), [4, 3].into()); } #[test] fn should_not_corrupt_after_slice() { let zeros = TestTensor::<1>::zeros([2], &Default::default()); zeros.clone().slice([1..2]).reshape([1]).exp(); // May lead to zeroes being equal to [0.0, 1.0] zeros.into_data().assert_eq( &TestTensor::<1>::zeros([2], &Default::default()).to_data(), true, ); } #[test] #[should_panic] fn multiple_neg_ones() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let _data_actual = tensor.reshape([-1, -1]).into_data(); } #[test] #[should_panic] fn neg_value() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let _data_actual = tensor.reshape([-2, -1]).into_data(); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/round.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_round_ops() { let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.round(); let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_round_ties_even() { let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.round(); let expected = TensorData::from([2., 2., 4., 4., 6., 6.]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/select.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_select_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); let output = tensor.select(0, indices); let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_2d_dim0_same_num_dim() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_data([1, 0], &device); let output = tensor.select(0, indices); let expected = TensorData::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_2d_dim0_more_num_dim() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_data([1, 0, 1, 1], &device); let output = tensor.select(0, indices); let expected = TensorData::from([ [3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_2d_dim0_vec() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], &device); let indices = TestTensorInt::from_data([1, 0, 3, 2], &device); let output = tensor.select(0, indices); let expected = TensorData::from([[2.0, 3.0], [0.0, 1.0], [6.0, 7.0], [4.0, 5.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_2d_dim1() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); let output = tensor.select(1, indices); let expected = TensorData::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device); let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0], &device); let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([3.0, 12.0, 3.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_1d_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([7, 8, 9], &device); let values = TestTensorInt::from_data([5, 4, 3, 2, 1], &device); let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([10, 19, 10]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_2d_dim0() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::from_data(TensorData::from([1, 0]), &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_2d_dim1() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::from_data(TensorData::from([1, 0, 2]), &device); let output = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_3d_dim1_vec() { let device = Default::default(); let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], [[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0], [-7.0, -8.0]], ], &device, ); let indices = TestTensorInt::from_data([1, 0, 3, 2], &device); let output = tensor.select(1, indices); let expected = TensorData::from([ [[3.0, 4.0], [1.0, 2.0], [7.0, 8.0], [5.0, 6.0]], [[-3.0, -4.0], [-1.0, -2.0], [-7.0, -8.0], [-5.0, -6.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn should_select_panic_invalid_dimension() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); tensor.select(10, indices); } #[test] fn should_match_default_implementation_behavior() { // Verify optimized implementation matches original default logic let device = Default::default(); let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); let indices = TestTensorInt::from_data([0, 1, 0], &device); let values = TestTensorBool::<1>::from_data([false, true, true], &device); let optimized_result = tensor .clone() .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); // Manual default implementation logic let int_tensor = tensor.int(); let int_values = values.int(); let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); let default_result = assigned.greater_elem(0); optimized_result .into_data() .assert_eq(&default_result.into_data(), false); } #[test] fn should_select_with_negative_dim_2d() { // Test using negative dimension indexing on 2D tensor let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::from_data([1, 0, 2], &device); // Using -1 should refer to the last dimension (dim 1) let output_neg = tensor.clone().select(-1, indices.clone()); let output_pos = tensor.select(1, indices); // Both should produce the same result output_neg .into_data() .assert_eq(&output_pos.into_data(), false); } #[test] fn should_select_add_with_negative_dim_2d() { // Test select_add with negative dimension on 2D tensor let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let values = TestTensor::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let indices = TestTensorInt::from_data([0, 2], &device); // Using -1 should refer to the last dimension (dim 1) let output_neg = tensor .clone() .select_assign(-1, indices.clone(), values.clone(), IndexingUpdateOp::Add); let output_pos = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add); output_neg .into_data() .assert_eq(&output_pos.into_data(), false); } #[test] #[should_panic] fn should_panic_select_negative_dim_out_of_bounds() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let indices = TestTensorInt::from_data([0, 1], &device); // This should panic because -3 is out of bounds for a 2D tensor tensor.select(-3, indices); } #[test] #[should_panic] fn should_panic_select_add_negative_dim_out_of_bounds() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let values = TestTensor::from_data([[5.0], [6.0]], &device); let indices = TestTensorInt::from_data([0], &device); // This should panic because -3 is out of bounds for a 2D tensor tensor.select_assign(-3, indices, values, IndexingUpdateOp::Add); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/sign.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_sign_ops_float() { let tensor = TestTensor::<2>::from([[-0.2, -1.0, 2.0], [3.0, 0.0, -5.0]]); let output = tensor.sign(); let expected = TensorData::from([[-1.0, -1.0, 1.0], [1.0, 0.0, -1.0]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/slice.rs ================================================ use super::*; use burn_tensor::{ElementConversion, Slice, TensorData, s}; #[test] fn should_support_slice_dim_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); // Test with range (negative index) let output = tensor.clone().slice_dim(0, -2..); output .into_data() .assert_eq(&TensorData::from([1.0, 2.0]), false); // Test with Slice directly let slice = Slice::new(1, None, 1); // equivalent to 1.. let output = tensor.slice_dim(0, slice); output .into_data() .assert_eq(&TensorData::from([1.0, 2.0]), false); } #[test] #[should_panic(expected = "The provided dimension exceeds the tensor dimensions")] fn should_panic_when_slice_dim_1d_bad_dim() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let _output = tensor.slice_dim(1, 1..); } #[test] fn should_support_slice_dim_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let output = tensor.slice_dim(1, 1..); output .into_data() .assert_eq(&TensorData::from([[1.0, 2.0], [4.0, 5.0]]), false); } #[test] fn should_support_slice_dim_with_step() { let data = TensorData::from([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); // Test 1: Slice dimension 1 with step=2 using s! macro let output = tensor.clone().slice_dim(1, s![0..4;2]); output .into_data() .assert_eq(&TensorData::from([[0.0, 2.0], [4.0, 6.0]]), false); // Test 2: Slice dimension 1 with step=2 using Slice directly let slice = Slice::new(0, Some(4), 2); let output = tensor.slice_dim(1, slice); output .into_data() .assert_eq(&TensorData::from([[0.0, 2.0], [4.0, 6.0]]), false); } #[test] fn should_support_slice_dim_with_negative_step() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); // Slice dimension 1 with negative step (reverse columns) let output = tensor.slice_dim(1, s![..;-1]); output .into_data() .assert_eq(&TensorData::from([[2.0, 1.0, 0.0], [5.0, 4.0, 3.0]]), false); } #[test] fn should_support_full_sliceing_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let output = tensor.slice([0..3]); output.into_data().assert_eq(&data, false); } #[test] fn should_support_full_sliceing_vec() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let slices: Vec = vec![(0..2).into()]; let output = tensor.clone().slice(&slices); output.into_data().assert_eq(&data, false); let output = tensor.slice([0..2, 0..3]); output.into_data().assert_eq(&data, false); } #[test] fn should_support_partial_sliceing_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.slice([1..3]); let expected = TensorData::from([1.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_full_sliceing_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default()); let output = tensor.clone().slice([0..2]); output.into_data().assert_eq(&data, false); let output = tensor.slice([0..2, 0..3]); output.into_data().assert_eq(&data, false); } #[test] fn should_support_partial_sliceing_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.slice([0..2, 0..2]); let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_range_first_dim() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.slice(0..1); let expected = TensorData::from([[0.0, 1.0, 2.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_partial_sliceing_3d() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let output = tensor.slice([1..2, 1..2, 0..2]); let expected = TensorData::from([[[9.0, 10.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_partial_sliceing_3d_non_contiguous() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let output = tensor.transpose().slice([1..2, 1..2, 0..2]); let expected = TensorData::from([[[7.0, 10.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.slice_fill([0..2], -1.0); let expected = TensorData::from([-1.0, -1.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_vec() { let data = TensorData::from([0.0, 1.0, 2.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let slices: Vec = vec![(0..2).into()]; let output = tensor.slice_fill(&slices, -1.0); let expected = TensorData::from([-1.0, -1.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_cast_f32() { let data = TensorData::from([0.0, 1.0, 2.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device).cast(burn_tensor::DType::F32); tensor .slice_fill(s![0..2], 1.0) .into_data() .assert_eq(&TensorData::from([1.0, 1.0, 2.0]), false); } // Skip on metal - F64 not supported #[cfg(not(feature = "metal"))] #[test] fn should_support_slice_fill_cast_f64() { let data = TensorData::from([0.0, 1.0, 2.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device).cast(burn_tensor::DType::F64); tensor .slice_fill(s![0..2], 1.0) .into_data() .assert_eq(&TensorData::from([1.0, 1.0, 2.0]), false); } #[test] fn should_support_slice_fill_1d_neg() { let data = TensorData::from([0.0, 1.0, 2.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let output = tensor.slice_fill([-1..], -1.0); let expected = TensorData::from([0.0, 1.0, -1.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let device = Default::default(); let tensor = TestTensor::<2>::from_data(data, &device); let output = tensor.slice_fill([1..2, 0..2], -1.0); let expected = TensorData::from([[0.0, 1.0, 2.0], [-1.0, -1.0, 5.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_with_positive_step() { let device = Default::default(); // Test 1D tensor with step let tensor = TestTensor::<1>::zeros([10], &device); let output = tensor.slice_fill(s![0..10;2], 5.0); let expected = TensorData::from([5.0, 0.0, 5.0, 0.0, 5.0, 0.0, 5.0, 0.0, 5.0, 0.0]); output.into_data().assert_eq(&expected, false); // Test 2D tensor with step on first dimension let tensor = TestTensor::<2>::zeros([4, 4], &device); let output = tensor.slice_fill(s![0..4;2, ..], 3.0); let expected = TensorData::from([ [3.0, 3.0, 3.0, 3.0], [0.0, 0.0, 0.0, 0.0], [3.0, 3.0, 3.0, 3.0], [0.0, 0.0, 0.0, 0.0], ]); output.into_data().assert_eq(&expected, false); // Test 2D tensor with step on second dimension let tensor = TestTensor::<2>::zeros([3, 6], &device); let output = tensor.slice_fill(s![.., 0..6;3], 2.0); let expected = TensorData::from([ [2.0, 0.0, 0.0, 2.0, 0.0, 0.0], [2.0, 0.0, 0.0, 2.0, 0.0, 0.0], [2.0, 0.0, 0.0, 2.0, 0.0, 0.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_with_negative_step() { let device = Default::default(); // Test 1D tensor with negative step (reverse fill) let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0], &device); let output = tensor.slice_fill(s![0..5;-1], 10.0); // Should reverse the indices [4,3,2,1,0] and fill them with 10.0 let expected = TensorData::from([10.0, 10.0, 10.0, 10.0, 10.0]); output.into_data().assert_eq(&expected, false); // Test 2D tensor with negative step let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device); let output = tensor.slice_fill(s![.., 0..3;-2], -1.0); // Should fill columns in reverse order with step 2: indices 2, 0 let expected = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 5.0, -1.0], [-1.0, 8.0, -1.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_fill_with_mixed_steps() { let device = Default::default(); // Test 2D tensor with mixed positive and negative steps let tensor = TestTensor::<2>::zeros([4, 6], &device); let output = tensor.slice_fill(s![0..4;2, 0..6;-3], 7.0); // Step 2 on dim 0 selects rows 0, 2 // Step -3 on dim 1 with range 0..6 reverses and takes every 3rd: indices [5, 2] let expected = TensorData::from([ [0.0, 0.0, 7.0, 0.0, 0.0, 7.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 7.0, 0.0, 0.0, 7.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ]); output.into_data().assert_eq(&expected, false); // Test 3D tensor with steps let tensor = TestTensor::<3>::zeros([2, 4, 4], &device); let output = tensor.slice_fill(s![.., 0..4;2, 0..4;-2], 1.0); // Step 2 on dim 1 selects rows 0, 2 // Step -2 on dim 2 with range 0..4 reverses and takes every 2nd: indices [3, 1] let expected_slice = [ [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0], ]; let expected = TensorData::from([expected_slice, expected_slice]); output.into_data().assert_eq(&expected, false); } #[test] fn clamp_when_slice_exceeds_dimension() { let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]); let data = tensor.to_data(); let output = tensor.slice([0..4]); output.into_data().assert_eq(&data, true); } #[test] fn negative_dimensions() { let tensor = TestTensor::<2>::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = tensor.to_data(); // Clamping to the tensor dimensions let output = tensor.clone().slice([0..4, 0..4]); output.into_data().assert_eq(&data, true); // Negative dimensions let output = tensor.clone().slice([0..1, 0..1]); let data = TensorData::from([[0.elem::()]]); output.into_data().assert_eq(&data, true); let output = tensor.slice(s![0..-1, 0..-2]); output.into_data().assert_eq(&data, true); } #[test] fn missing_dimensions() { let tensor = TestTensor::<2>::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = tensor.to_data(); // Clamping to the tensor dimensions let output = tensor.clone().slice([0..4, 0..4]); output.into_data().assert_eq(&data, true); // Negative dimensions let data = TensorData::from([[0.elem::()]]); let output = tensor.clone().slice(s![0..-1, 0..-2]); output.into_data().assert_eq(&data, true); // Missing dimensions let output = tensor.clone().slice(s![0..1, ..]); let data = TensorData::from([[0.0f32, 1.0, 2.0]]); output.into_data().assert_eq(&data, false); let output = tensor.clone().slice(s![.., 0..2]); let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]); output.into_data().assert_eq(&data, false); let output = tensor.clone().slice([.., ..]); let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); output.into_data().assert_eq(&data, false); } #[test] fn should_slice_aggregation_result() { let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]).mean(); let output = tensor.clone().slice([(0..1)]); output.into_data().assert_eq(&tensor.into_data(), true); } #[test] #[should_panic] fn should_panic_when_slice_with_too_many_dimensions() { let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]); let _output = tensor.slice([0..1, 0..1]); } #[test] fn should_support_descending_slice_as_empty() { // Like PyTorch, x[3:1] should return an empty tensor, not panic let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.slice(s![2..1]); // Should produce an empty tensor with shape [0] assert_eq!(output.dims(), [0]); } #[test] fn should_support_empty_slice() { // ONNX models can have empty slices where start == end // This should produce a tensor with size 0 in that dimension let data = TensorData::from([0.0, 1.0, 2.0]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.slice([1..1]); // Should produce an empty tensor with shape [0] assert_eq!(output.dims(), [0]); } #[test] fn should_support_empty_slice_2d() { // Test empty slice on 2D tensor let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); // Empty slice on first dimension let output = tensor.clone().slice([1..1, 0..3]); assert_eq!(output.dims(), [0, 3]); // Empty slice on second dimension let output = tensor.slice([0..2, 2..2]); assert_eq!(output.dims(), [2, 0]); } #[test] fn test_slice_with_positive_step() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); // Test step=2 along first dimension let sliced = tensor.clone().slice([s![0..3;2]]); let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0], [9.0, 10.0, 11.0, 12.0]]); sliced.into_data().assert_eq(&expected, false); // Test step=2 along second dimension let sliced = tensor.clone().slice(s![.., 0..4;2]); let expected = TensorData::from([[1.0, 3.0], [5.0, 7.0], [9.0, 11.0]]); sliced.into_data().assert_eq(&expected, false); // Test step=2 along both dimensions let sliced = tensor.clone().slice(s![0..3;2, 0..4;2]); let expected = TensorData::from([[1.0, 3.0], [9.0, 11.0]]); sliced.into_data().assert_eq(&expected, false); } #[test] fn test_slice_with_negative_step() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); // Test step=-1 along first dimension (reverse rows) let sliced = tensor.clone().slice([s![0..3;-1]]); let expected = TensorData::from([ [9.0, 10.0, 11.0, 12.0], [5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0], ]); sliced.into_data().assert_eq(&expected, false); // Test step=-1 along second dimension (reverse columns) let sliced = tensor.clone().slice(s![.., 0..4;-1]); let expected = TensorData::from([ [4.0, 3.0, 2.0, 1.0], [8.0, 7.0, 6.0, 5.0], [12.0, 11.0, 10.0, 9.0], ]); sliced.into_data().assert_eq(&expected, false); // Test step=-2 along first dimension let sliced = tensor.clone().slice([s![0..3;-2]]); let expected = TensorData::from([[9.0, 10.0, 11.0, 12.0], [1.0, 2.0, 3.0, 4.0]]); sliced.into_data().assert_eq(&expected, false); // Test step=-2 along second dimension let sliced = tensor.clone().slice(s![.., 0..4;-2]); let expected = TensorData::from([[4.0, 2.0], [8.0, 6.0], [12.0, 10.0]]); sliced.into_data().assert_eq(&expected, false); } #[test] fn test_slice_with_mixed_steps() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); // Test positive step along first dimension, negative along second let sliced = tensor.clone().slice(s![0..3;2, 0..4;-1]); let expected = TensorData::from([[4.0, 3.0, 2.0, 1.0], [12.0, 11.0, 10.0, 9.0]]); sliced.into_data().assert_eq(&expected, false); // Test negative step along first dimension, positive along second let sliced = tensor.clone().slice(s![0..3;-1, 0..4;2]); let expected = TensorData::from([[9.0, 11.0], [5.0, 7.0], [1.0, 3.0]]); sliced.into_data().assert_eq(&expected, false); } #[test] fn test_slice_with_steps_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], &device); // Test positive step let sliced = tensor.clone().slice([s![0..10;2]]); let expected = TensorData::from([1.0, 3.0, 5.0, 7.0, 9.0]); sliced.into_data().assert_eq(&expected, false); // Test negative step let sliced = tensor.clone().slice([s![0..10;-1]]); let expected = TensorData::from([10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]); sliced.into_data().assert_eq(&expected, false); // Test negative step with partial range let sliced = tensor.clone().slice([s![2..8;-2]]); let expected = TensorData::from([8.0, 6.0, 4.0]); sliced.into_data().assert_eq(&expected, false); } #[test] fn test_slice_with_steps_3d() { let device = Default::default(); let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], &device, ); // Test step=2 along first dimension let sliced = tensor.clone().slice(s![0..4;2, .., ..]); let expected = TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[9.0, 10.0], [11.0, 12.0]]]); sliced.into_data().assert_eq(&expected, false); // Test step=-1 along all dimensions let sliced = tensor.clone().slice(s![0..4;-1, 0..2;-1, 0..2;-1]); let expected = TensorData::from([ [[16.0, 15.0], [14.0, 13.0]], [[12.0, 11.0], [10.0, 9.0]], [[8.0, 7.0], [6.0, 5.0]], [[4.0, 3.0], [2.0, 1.0]], ]); sliced.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/slice_assign.rs ================================================ use super::*; use burn_tensor::{Slice, TensorData, s}; #[test] fn should_support_slice_assign_1d() { let data = TensorData::from([0.0, 1.0, 2.0]); let data_assigned = TensorData::from([10.0, 5.0]); let device = Default::default(); let tensor = TestTensor::<1>::from_data(data, &device); let tensor_assigned = TestTensor::<1>::from_data(data_assigned, &device); let output = tensor.slice_assign([0..2], tensor_assigned); let expected = TensorData::from([10.0, 5.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_2d() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_assigned = TensorData::from([[10.0, 5.0]]); let device = Default::default(); let tensor = TestTensor::<2>::from_data(data, &device); let tensor_assigned = TestTensor::<2>::from_data(data_assigned, &device); let output = tensor.slice_assign([1..2, 0..2], tensor_assigned); let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_vec() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_assigned = TensorData::from([[10.0, 5.0]]); let device = Default::default(); let tensor = TestTensor::<2>::from_data(data, &device); let tensor_assigned = TestTensor::<2>::from_data(data_assigned, &device); let slices: Vec = vec![1..2, 0..2].into_iter().map(Slice::from).collect(); let output = tensor.slice_assign(&slices, tensor_assigned); let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn slice_assign_now_supports_non_unit_step() { let device = Default::default(); // Create tensors where the shapes match for stepped slicing let tensor = TestTensor::<2>::ones([4, 4], &device); // With step=2 on first dim, we select indices 0 and 2, so we need a [2, 4] values tensor let values = TestTensor::<2>::zeros([2, 4], &device); // This now works because slice_assign supports steps != 1 // We use s! macro to create a slice with step=2 let result = tensor.slice_assign(s![0..3;2, ..], values); // Verify the result: rows 0 and 2 should be zeros, rows 1 and 3 should be ones let expected = TensorData::from([ [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], ]); result.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_with_positive_step_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device); let values = TestTensor::<1>::from_data([10.0, 20.0, 30.0], &device); // Assign to indices 0, 2, 4 (step=2) let output = tensor.slice_assign([s![0..6;2]], values); let expected = TensorData::from([10.0, 2.0, 20.0, 4.0, 30.0, 6.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_with_positive_step_2d() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0], ], &device, ); // Assign to rows 0, 2 (step=2) let values = TestTensor::<2>::from_data( [[100.0, 101.0, 102.0, 103.0], [200.0, 201.0, 202.0, 203.0]], &device, ); let output = tensor.clone().slice_assign([s![0..4;2]], values); let expected = TensorData::from([ [100.0, 101.0, 102.0, 103.0], [5.0, 6.0, 7.0, 8.0], [200.0, 201.0, 202.0, 203.0], [13.0, 14.0, 15.0, 16.0], ]); output.into_data().assert_eq(&expected, false); // Assign to columns 0, 2 (step=2) let values = TestTensor::<2>::from_data( [ [100.0, 200.0], [101.0, 201.0], [102.0, 202.0], [103.0, 203.0], ], &device, ); let output = tensor.clone().slice_assign(s![.., 0..4;2], values); let expected = TensorData::from([ [100.0, 2.0, 200.0, 4.0], [101.0, 6.0, 201.0, 8.0], [102.0, 10.0, 202.0, 12.0], [103.0, 14.0, 203.0, 16.0], ]); output.into_data().assert_eq(&expected, false); // Assign with step=2 on both dimensions let values = TestTensor::<2>::from_data([[100.0, 200.0], [300.0, 400.0]], &device); let output = tensor.slice_assign(s![0..4;2, 0..4;2], values); let expected = TensorData::from([ [100.0, 2.0, 200.0, 4.0], [5.0, 6.0, 7.0, 8.0], [300.0, 10.0, 400.0, 12.0], [13.0, 14.0, 15.0, 16.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_with_negative_step_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device); let values = TestTensor::<1>::from_data([60.0, 50.0, 40.0, 30.0, 20.0, 10.0], &device); // Assign in reverse order (step=-1) let output = tensor.slice_assign([s![0..6;-1]], values); let expected = TensorData::from([10.0, 20.0, 30.0, 40.0, 50.0, 60.0]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_with_negative_step_2d() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); // Assign to rows in reverse order (step=-1) let values = TestTensor::<2>::from_data( [ [30.0, 31.0, 32.0, 33.0], [20.0, 21.0, 22.0, 23.0], [10.0, 11.0, 12.0, 13.0], ], &device, ); let output = tensor.clone().slice_assign([s![0..3;-1]], values); let expected = TensorData::from([ [10.0, 11.0, 12.0, 13.0], [20.0, 21.0, 22.0, 23.0], [30.0, 31.0, 32.0, 33.0], ]); output.into_data().assert_eq(&expected, false); // Assign to columns in reverse order (step=-1) let values = TestTensor::<2>::from_data( [ [40.0, 30.0, 20.0, 10.0], [80.0, 70.0, 60.0, 50.0], [120.0, 110.0, 100.0, 90.0], ], &device, ); let output = tensor.clone().slice_assign(s![.., 0..4;-1], values); let expected = TensorData::from([ [10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0], [90.0, 100.0, 110.0, 120.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_with_mixed_steps() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0], ], &device, ); // Positive step along rows, negative along columns let values = TestTensor::<2>::from_data( [[100.0, 101.0, 102.0, 103.0], [200.0, 201.0, 202.0, 203.0]], &device, ); let output = tensor.clone().slice_assign(s![0..4;2, 0..4;-1], values); let expected = TensorData::from([ [103.0, 102.0, 101.0, 100.0], [5.0, 6.0, 7.0, 8.0], [203.0, 202.0, 201.0, 200.0], [13.0, 14.0, 15.0, 16.0], ]); output.into_data().assert_eq(&expected, false); // Negative step along rows, positive along columns let values = TestTensor::<2>::from_data( [ [100.0, 200.0], [101.0, 201.0], [102.0, 202.0], [103.0, 203.0], ], &device, ); let output = tensor.slice_assign(s![0..4;-1, 0..4;2], values); let expected = TensorData::from([ [103.0, 2.0, 203.0, 4.0], [102.0, 6.0, 202.0, 8.0], [101.0, 10.0, 201.0, 12.0], [100.0, 14.0, 200.0, 16.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_3d_with_steps() { let device = Default::default(); let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], &device, ); // Test step=2 along first dimension let values = TestTensor::<3>::from_data( [ [[100.0, 101.0], [102.0, 103.0]], [[200.0, 201.0], [202.0, 203.0]], ], &device, ); let output = tensor.clone().slice_assign(s![0..4;2, .., ..], values); let expected = TensorData::from([ [[100.0, 101.0], [102.0, 103.0]], [[5.0, 6.0], [7.0, 8.0]], [[200.0, 201.0], [202.0, 203.0]], [[13.0, 14.0], [15.0, 16.0]], ]); output.into_data().assert_eq(&expected, false); // Test step=-1 along all dimensions let values = TestTensor::<3>::from_data( [ [[400.0, 399.0], [398.0, 397.0]], [[396.0, 395.0], [394.0, 393.0]], [[392.0, 391.0], [390.0, 389.0]], [[388.0, 387.0], [386.0, 385.0]], ], &device, ); let output = tensor.slice_assign(s![0..4;-1, 0..2;-1, 0..2;-1], values); let expected = TensorData::from([ [[385.0, 386.0], [387.0, 388.0]], [[389.0, 390.0], [391.0, 392.0]], [[393.0, 394.0], [395.0, 396.0]], [[397.0, 398.0], [399.0, 400.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_partial_with_steps() { let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0], [11.0, 12.0, 13.0, 14.0, 15.0], [16.0, 17.0, 18.0, 19.0, 20.0], [21.0, 22.0, 23.0, 24.0, 25.0], ], &device, ); // Assign to a subset with step=2 let values = TestTensor::<2>::from_data([[100.0, 200.0], [300.0, 400.0]], &device); let output = tensor.slice_assign(s![1..4;2, 1..4;2], values); let expected = TensorData::from([ [1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 100.0, 8.0, 200.0, 10.0], [11.0, 12.0, 13.0, 14.0, 15.0], [16.0, 300.0, 18.0, 400.0, 20.0], [21.0, 22.0, 23.0, 24.0, 25.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_empty_range() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let values: TestTensor<2> = TestTensor::empty([2, 0], &device); // Empty slice assignment (start == end) should be a no-op let output = tensor.clone().slice_assign([0..2, 1..1], values); let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_empty_range_1d() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0], &device); let values: TestTensor<1> = TestTensor::empty([0], &device); // Empty slice assignment should return tensor unchanged let output = tensor.clone().slice_assign([2..2], values); let expected = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_single_dim_slice() { let device = Default::default(); let x = TestTensor::<3>::ones([2, 3, 1], &device); let values = TestTensor::<3>::zeros([1, 3, 1], &device); let output = x.slice_assign(s![1], values); output.into_data().assert_eq( &TensorData::from([[[1.0], [1.0], [1.0]], [[0.0], [0.0], [0.0]]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/sort_argsort.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_sort_1d_float() { let tensor = TestTensor::<1>::from([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let values = tensor.sort(0); let values_expected = TensorData::from([ -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 199.412, ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); } #[test] fn test_argsort_1d_float() { let tensor = TestTensor::<1>::from([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let indices = tensor.argsort(0); let indices_expected = TensorData::from([12, 6, 2, 3, 0, 5, 10, 1, 4, 7, 11, 9, 8]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_with_indices_descending_float() { // 1D let tensor = TestTensor::<1>::from([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let (values, indices) = tensor.sort_descending_with_indices(0); let values_expected = TensorData::from([ 199.412, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1, ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]); indices.into_data().assert_eq(&indices_expected, false); // 2D let tensor = TestTensor::<3>::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, 4.], [0.99, 3., -8.1]], ]); // Sort along dim=1 let (values, indices) = tensor.sort_descending_with_indices(1); let values_expected = TensorData::from([ [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]], [[0.99, 3., 4.], [-0.3, 2.3, -8.1]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_float() { let tensor = TestTensor::<3>::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, 4.], [0.99, 3., -8.1]], ]); // Sort along dim=0 let values = tensor.clone().sort(0); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); // Sort along dim=1 let values = tensor.clone().sort(1); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); // Sort along dim=2 let values = tensor.sort(2); let values_expected = TensorData::from([ [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); } #[test] fn test_sort_with_indices_float() { let tensor = TestTensor::<3>::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, 4.], [0.99, 3., -8.1]], ]); // Sort along dim=0 let (values, indices) = tensor.clone().sort_with_indices(0); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let (values, indices) = tensor.clone().sort_with_indices(1); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let (values, indices) = tensor.sort_with_indices(2); let values_expected = TensorData::from([ [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_argsort_float() { let tensor = TestTensor::<3>::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, 4.], [0.99, 3., -8.1]], ]); // Sort along dim=0 let indices = tensor.clone().argsort(0); let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let indices = tensor.clone().argsort(1); let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let indices = tensor.argsort(2); let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_float_nan() { let tensor = TestTensor::<2>::from([[-0.5, f32::NAN], [0., 0.94], [-0.3, f32::NAN]]); // Sort along dim=0 let values = tensor.sort(0); let values_expected = TensorData::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); } #[test] fn test_sort_descending_1d() { let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); // Sort along dim=0 let values = tensor.sort_descending(0); let values_expected = TensorData::from([5., 4., 3., 2., 1.]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/split.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_split_evenly_divisible() { let device = Default::default(); let tensors = TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], &device); let split_tensors = tensors.split(2, 0); assert_eq!(split_tensors.len(), 3); let expected = [ TensorData::from([[0, 1], [2, 3]]), TensorData::from([[4, 5], [6, 7]]), TensorData::from([[8, 9], [10, 11]]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_not_evenly_divisible() { let device = Default::default(); let tensors = TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], &device); let split_tensors = tensors.split(2, 0); assert_eq!(split_tensors.len(), 3); let expected = [ TensorData::from([[0, 1], [2, 3]]), TensorData::from([[4, 5], [6, 7]]), TensorData::from([[8, 9]]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_along_dim1() { let device = Default::default(); let tensors = TestTensor::<2>::from_data([[0, 1, 2], [3, 4, 5]], &device); let split_tensors = tensors.split(2, 1); assert_eq!(split_tensors.len(), 2); let expected = [ TensorData::from([[0, 1], [3, 4]]), TensorData::from([[2], [5]]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_split_size_larger_than_tensor_size() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device); let split_tensors = tensors.split(10, 0); assert_eq!(split_tensors.len(), 1); let expected = [TensorData::from([0, 1, 2, 3, 4])]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_with_zero_split_size_zero_tensor_size() { let device = Default::default(); let empty_array: [i32; 0] = []; let tensors = TestTensor::<1>::from_data(empty_array, &device); let split_tensors = tensors.split(0, 0); assert_eq!(split_tensors.len(), 0); } #[test] fn test_split_zero_sized_tensor() { let device = Default::default(); let empty_array: [i32; 0] = []; let tensors = TestTensor::<1>::from_data(empty_array, &device); let split_tensors = tensors.split(1, 0); assert_eq!(split_tensors.len(), 0); } #[test] #[should_panic( expected = "split_size must be greater than 0 unless the tensor size along the dimension is 0." )] fn test_split_with_zero_split_size_non_zero_tensor() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device); let _split_tensors = tensors.split(0, 0); } #[test] #[should_panic(expected = "Given dimension is greater than or equal to the tensor rank.")] fn test_split_invalid_dim() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2], &device); let _split_tensors = tensors.split(1, 2); } #[test] fn test_split_3d_tensor_along_dim0() { let device = Default::default(); let tensors = TestTensor::<3>::from_data( [ [[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]], [[12, 13], [14, 15]], ], &device, ); let split_tensors = tensors.split(2, 0); assert_eq!(split_tensors.len(), 2); let expected = [ TensorData::from([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), TensorData::from([[[8, 9], [10, 11]], [[12, 13], [14, 15]]]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_3d_tensor_along_dim1() { let device = Default::default(); let tensors = TestTensor::<3>::from_data( [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], &device, ); let split_tensors = tensors.split(2, 1); assert_eq!(split_tensors.len(), 2); let expected = [ TensorData::from([[[0, 1], [2, 3]], [[6, 7], [8, 9]]]), TensorData::from([[[4, 5]], [[10, 11]]]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] fn test_split_with_sizes() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device); let split_tensors = tensors.split_with_sizes(vec![2, 3, 1], 0); assert_eq!(split_tensors.len(), 3); let expected = [ TensorData::from([0, 1]), TensorData::from([2, 3, 4]), TensorData::from([5]), ]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } #[test] #[should_panic( expected = "The sum of split_sizes must equal the tensor size along the specified dimension." )] fn test_split_with_sizes_invalid_sum() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device); let _split_tensors = tensors.split_with_sizes(vec![2, 2, 1], 0); } #[test] fn test_split_with_sizes_zero_length() { let device = Default::default(); let tensors = TestTensor::<1>::from_data([0, 1, 2], &device); let split_tensors = tensors.split_with_sizes(vec![0, 1, 2], 0); assert_eq!(split_tensors.len(), 2); let expected = [TensorData::from([0]), TensorData::from([1, 2])]; for (index, tensor) in split_tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/sqrt.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use core::f32::consts::SQRT_2; #[test] fn should_support_sqrt_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.sqrt(); let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); output.into_data().assert_approx_eq::( &expected, Tolerance::relative(1e-4).set_half_precision_relative(1e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/square.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_sqrt_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.square(); let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); output.into_data().assert_approx_eq::( &expected, Tolerance::relative(1e-4).set_half_precision_relative(1e-3), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/squeeze.rs ================================================ use super::*; use burn_tensor::Shape; /// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor. #[test] fn should_squeeze_dim() { let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default()); let squeezed_tensor: TestTensor<2> = tensor.squeeze_dim(1); let expected_shape = Shape::new([2, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); } #[test] fn should_squeeze() { let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default()); let squeezed_tensor: TestTensor<2> = tensor.squeeze(); let expected_shape = Shape::new([2, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor. #[test] fn should_squeeze_first() { let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 4, 5]), &Default::default()); let squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(0); let expected_shape = Shape::new([3, 4, 5]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor. #[test] fn should_squeeze_last() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 1]), &Default::default()); let squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(3); let expected_shape = Shape::new([2, 3, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test if the function panics when the squeezed dimension is not of size 1. #[test] #[should_panic] fn should_squeeze_panic() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let _squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(2); } /// Test if the function works with an empty slice #[test] fn should_squeeze_dims_with_empty_slice() { let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 3]), &Default::default()); let squeezed_tensor: TestTensor<1> = tensor.squeeze_dims(&[]); let expected_shape = Shape::new([3]); assert_eq!(squeezed_tensor.shape(), expected_shape); } #[test] fn should_squeeze_all_dims() { let tensor = TestTensor::<3>::ones(Shape::new([1, 3, 1]), &Default::default()); let squeezed_tensor: TestTensor<1> = tensor.squeeze(); let expected_shape = Shape::new([3]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test if the function works with positive indices #[test] fn should_squeeze_dims_with_positive_indices() { let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default()); let squeezed_tensor: TestTensor<2> = tensor.squeeze_dims(&[0, 2]); let expected_shape = Shape::new([3, 5]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test if the function works with negative indices #[test] fn should_squeeze_dims_with_negative_indices() { let tensor = TestTensor::<4>::ones(Shape::new([2, 1, 3, 1]), &Default::default()); let squeezed_tensor: TestTensor<2> = tensor.squeeze_dims(&[-3, -1]); let expected_shape = Shape::new([2, 3]); assert_eq!(squeezed_tensor.shape(), expected_shape); } /// Test to make sure the function panics if a non-singleton dimension is squeezed #[test] #[should_panic] fn should_squeeze_dims_work_if_non_singleton() { let tensor = TestTensor::<3>::ones(Shape::new([2, 3, 4]), &Default::default()); let squeezed_tensor: TestTensor<3> = tensor.squeeze_dims(&[1]); let expected_shape = Shape::new([2, 3, 4]); assert_eq!(squeezed_tensor.shape(), expected_shape); } #[test] #[should_panic] fn should_panic_squeeze_consumes_all_singleton() { let tensor = TestTensor::<3>::ones(Shape::new([1, 3, 1]), &Default::default()); let _squeezed_tensor: TestTensor<2> = tensor.squeeze(); // output rank should be 1 } /// Test to make sure the function panics if too many dimensions are requested to be squeezed #[test] #[should_panic] fn should_squeeze_dims_panic_on_too_many_dimensions() { let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default()); let _: TestTensor<1> = tensor.squeeze_dims(&[0, 1, 2]); } /// Test to make sure function panics if dimensions are mismatched #[test] #[should_panic] fn should_squeeze_dims_dimension_mismatch_panic() { let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default()); let _: TestTensor<3> = tensor.squeeze_dims(&[0, 2]); } /// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor. #[test] fn should_unsqueeze_dim() { let tensor = TestTensor::<3>::ones(Shape::new([2, 4, 1]), &Default::default()); let unsqueezed_tensor: TestTensor<4> = tensor.unsqueeze_dim(1); let expected_shape = Shape::new([2, 1, 4, 1]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); } /// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor. #[test] fn should_unsqueeze_dim_first() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(0); let expected_shape = Shape::new([1, 2, 3, 4, 5]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); } /// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor. #[test] fn should_unsqueeze_dim_last() { let tensor = TestTensor::<4>::ones(Shape::new([5, 4, 3, 2]), &Default::default()); let unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(4); let expected_shape = Shape::new([5, 4, 3, 2, 1]); assert_eq!(unsqueezed_tensor.shape(), expected_shape); } /// Test if the function panics when the unsqueezed dimension is out of bounds. #[test] #[should_panic] fn should_unsqueeze_dim_panic() { let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default()); let _unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(5); } #[test] fn should_unsqueeze_dims_support_dim_inference() { let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor = input_tensor.unsqueeze_dims::<5>(&[1, -2]); let expected_shape = Shape::new([3, 1, 4, 1, 5]); assert_eq!(output_tensor.shape(), expected_shape); } #[test] fn should_unsqueeze_dims_handle_first_last() { let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor = input_tensor.unsqueeze_dims::<5>(&[0, 4]); let expected_shape = Shape::new([1, 3, 4, 5, 1]); assert_eq!(output_tensor.shape(), expected_shape); } #[test] fn should_unsqueeze_dims_work_with_single_dim() { //bruh, just call unsqueeze_dim let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor: TestTensor<4> = input_tensor.unsqueeze_dims(&[1]); let expected_shape = Shape::new([3, 1, 4, 5]); assert_eq!(output_tensor.shape(), expected_shape); } #[test] fn should_unsqueeze_dims_multiple_trailing_negatives() { let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let output_tensor: TestTensor<6> = input_tensor.unsqueeze_dims(&[0, -1, -1]); let expected_shape = Shape::new([1, 3, 4, 5, 1, 1]); assert_eq!(output_tensor.shape(), expected_shape); } #[test] #[should_panic] fn should_unsqueeze_dims_panic() { let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default()); let _output_tensor: TestTensor<5> = input_tensor.unsqueeze_dims(&[0, -6]); } #[test] #[should_panic] fn squeeze_all_singleton_not_supported() { let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default()); let _ = tensor.squeeze::<0>(); } #[test] #[should_panic] fn squeeze_dim_singleton_not_supported() { let tensor = TestTensor::<1>::ones(Shape::new([1]), &Default::default()); let _ = tensor.squeeze_dim::<0>(0); } #[test] #[should_panic] fn squeeze_dims_all_singleton_not_supported() { let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default()); let _ = tensor.squeeze_dims::<0>(&[0, 1, 2]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/stack.rs ================================================ use super::*; use alloc::{vec, vec::Vec}; use burn_tensor::{Tensor, TensorData}; #[test] fn should_support_stack_ops_2d_dim0() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_stack_ops_2d_dim1() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 1); let expected = TensorData::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_stack_ops_3d() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device); let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]], [[4.1, 5.1, 6.1]]], &device); let output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([ [[[1.0000, 2.0000, 3.0000]], [[1.1000, 2.1000, 3.1000]]], [[[4.0000, 5.0000, 6.0000]], [[4.1000, 5.1000, 6.1000]]], ]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn should_panic_when_dimensions_are_not_the_same() { let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device); let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device); let _output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); } #[test] #[should_panic] fn should_panic_when_list_of_vectors_is_empty() { let tensors: Vec> = vec![]; let _output = Tensor::stack::<3>(tensors, 0); } #[test] #[should_panic] fn should_panic_when_stack_exceeds_dimension() { let device = Default::default(); let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device); let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device); let _output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 3); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/sub.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_sub_ops() { let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_2 = TensorData::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sub_broadcast() { let data_1 = TensorData::from([[0.0, 1.0, 2.0]]); let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let device = Default::default(); let tensor_1 = TestTensor::<2>::from_data(data_1, &device); let tensor_2 = TestTensor::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_sub_scalar_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor - scalar; let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/take.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_take_1d() { // Test that take works with 1D indices let device = Default::default(); let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device); let indices = TestTensorInt::<1>::from_data([1, 1, 0, 1, 2], &device); let output = tensor.take::<1, 1>(0, indices); let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_2d_dim0() { // Test take on 2D tensor along dimension 0 let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::<1>::from_data([1, 0, 1, 1], &device); let output = tensor.take::<1, 2>(0, indices); let expected = TensorData::from([ [3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_2d_dim1() { // Test take on 2D tensor along dimension 1 let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::<1>::from_data([2, 0, 1], &device); let output = tensor.take::<1, 2>(1, indices); let expected = TensorData::from([[2.0, 0.0, 1.0], [5.0, 3.0, 4.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn take_and_select_should_be_equivalent() { // Verify that take and select produce identical results let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); let indices = TestTensorInt::<1>::from_data([2, 0, 1, 1], &device); let result_take = tensor.clone().take::<1, 2>(0, indices.clone()); let result_select = tensor.select(0, indices); let take_data = result_take.into_data(); let select_data = result_select.into_data(); take_data.assert_eq(&select_data, false); } #[test] fn should_take_with_2d_indices() { // Test take with 2D indices - output will be 3D with shape [2, 2, 4] let device = Default::default(); let tensor = TestTensor::<2>::from_data( [ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ], &device, ); // 2D indices to select along dimension 0 - shape [2, 2] let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device); let output = tensor.take::<2, 3>(0, indices); // Expected: shape [2, 2, 4] - indices shape replaces dim 0 let expected = TensorData::from([ [[1.0, 2.0, 3.0, 4.0], [9.0, 10.0, 11.0, 12.0]], [[5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_with_2d_indices_dim1() { // Test take with 2D indices along dimension 1 - output will be 3D with shape [2, 2, 2] let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device); // 2D indices to select along dimension 1 - shape [2, 2] let indices = TestTensorInt::<2>::from_data([[0, 3], [2, 1]], &device); let output = tensor.take::<2, 3>(1, indices); // Expected: shape [2, 2, 2] - indices shape replaces dim 1 let expected = TensorData::from([[[1.0, 4.0], [3.0, 2.0]], [[5.0, 8.0], [7.0, 6.0]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_3d_tensor() { // Test take with 3D tensor - output will be 4D with shape [2, 2, 2, 2] let device = Default::default(); let tensor = TestTensor::<3>::from_data( [ [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], ], &device, ); // 2D indices to select along dimension 1 - shape [2, 2] let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device); let output = tensor.take::<2, 4>(1, indices); // Expected: shape [2, 2, 2, 2] - indices shape replaces dim 1 let expected = TensorData::from([ [[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [1.0, 2.0]]], [[[7.0, 8.0], [11.0, 12.0]], [[9.0, 10.0], [7.0, 8.0]]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_with_3d_indices() { // Test take with 3D indices - output will be 4D let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); // 3D indices to select along dimension 1 - shape [2, 2, 2] let indices = TestTensorInt::<3>::from_data([[[0, 2], [1, 0]], [[2, 1], [0, 2]]], &device); let output = tensor.take::<3, 4>(1, indices); // Expected: shape [2, 2, 2, 2] - indices shape replaces dim 1 let expected = TensorData::from([ [[[1.0, 3.0], [2.0, 1.0]], [[3.0, 2.0], [1.0, 3.0]]], [[[4.0, 6.0], [5.0, 4.0]], [[6.0, 5.0], [4.0, 6.0]]], ]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn should_panic_take_invalid_dimension() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); let indices = TestTensorInt::<1>::from_data([1, 0], &device); // This should panic because dimension 10 is out of bounds tensor.take::<1, 2>(10, indices); } #[test] fn should_take_with_single_index() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::<1>::from_data([1], &device); let output = tensor.take::<1, 2>(0, indices); let expected = TensorData::from([[4.0, 5.0, 6.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_with_negative_dim_2d() { // Test using negative dimension indexing on 2D tensor let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let indices = TestTensorInt::<1>::from_data([2, 0, 1], &device); // Using -1 should refer to the last dimension (dim 1) let output_neg = tensor.clone().take::<1, 2>(-1, indices.clone()); let output_pos = tensor.take::<1, 2>(1, indices); // Both should produce the same result let neg_data = output_neg.into_data(); let pos_data = output_pos.into_data(); neg_data.assert_eq(&pos_data, false); } #[test] #[should_panic] fn should_panic_take_negative_dim_out_of_bounds() { let device = Default::default(); let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let indices = TestTensorInt::<1>::from_data([0, 1], &device); // This should panic because -3 is out of bounds for a 2D tensor tensor.take::<1, 2>(-3, indices); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/topk.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_topk_with_indices_3d() { let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); values .into_data() .assert_approx_eq::(&values_expected, Tolerance::default()); let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/transaction.rs ================================================ use super::*; use burn_tensor::Transaction; // https://github.com/tracel-ai/burn/issues/4021 #[test] fn should_support_transaction() { let rows = 261120; let cols = 408; let device = Default::default(); let j = TestTensor::<2>::zeros([rows, cols], &device); let jt = j.clone().transpose(); let g = jt.matmul(j); let g = g.transpose(); let expected = g.to_data(); assert_eq!(g.shape().dims(), [cols, cols]); // Fails let [data] = Transaction::default() .register(g) .execute() .try_into() .unwrap(); // check byte equality assert_eq!(data, expected); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/transpose.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_transpose_ops() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); // Check the .t() alias. let output = tensor.t(); let expected = TensorData::from([ [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_transpose_maybe_fused_with_one() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let ones = TestTensor::<3>::ones([1, 1, 1], &Default::default()); let output = tensor.transpose(); let expected = TensorData::from([ [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], ]); let expected_ones = TensorData::from([[[1.0]]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); ones.into_data() .assert_approx_eq::(&expected_ones, Tolerance::default()); } #[test] fn should_support_swap_dims_no_op() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let output = tensor.swap_dims(0, 0); let expected = TensorData::from([ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_swap_dims() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let output = tensor.swap_dims(0, 2); let expected = TensorData::from([ [[0.0, 6.0], [3.0, 9.0]], [[1.0, 7.0], [4.0, 10.0]], [[2.0, 8.0], [5.0, 11.0]], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_swap_dims_neg_index() { let tensor = TestTensor::<3>::from_floats( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ], &Default::default(), ); let output = tensor.swap_dims(-3, -1); let expected = TensorData::from([ [[0.0, 6.0], [3.0, 9.0]], [[1.0, 7.0], [4.0, 10.0]], [[2.0, 8.0], [5.0, 11.0]], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/tri.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_triu() { let tensor = TestTensor::<2>::from([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]); let output = tensor.triu(0); let expected = TensorData::from([[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_triu_positive_diagonal() { let tensor = TestTensor::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]); let output = tensor.triu(1); let expected = TensorData::from([[0, 1, 1], [0, 0, 1], [0, 0, 0]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/trig.rs ================================================ #![allow(clippy::approx_constant)] use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use core::f32::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, FRAC_PI_6, FRAC_PI_8, PI}; #[test] fn should_support_cos_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cos(); let expected = TensorData::from([[1.0, 0.54030, -0.41615], [-0.98999, -0.65364, 0.28366]]); // Metal has less precise trigonometric functions let tolerance = Tolerance::default().set_half_precision_relative(1e-2); output .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn should_support_cosh_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cosh(); let expected = TensorData::from([[1.0000, 1.5431, 3.7622], [10.0677, 27.3082, 74.2099]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_sin_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.sin(); let expected = TensorData::from([[0.0, 0.841471, 0.909297], [0.141120, -0.756802, -0.958924]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_sinh_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.sinh(); let expected = TensorData::from([[0.0000, 1.1752, 3.6269], [10.0179, 27.2899, 74.2032]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_tan_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.tan(); let expected = TensorData::from([[0.0, 1.557408, -2.185040], [-0.142547, 1.157821, -3.380515]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_tanh_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.tanh(); let expected = TensorData::from([[0.0, 0.761594, 0.964028], [0.995055, 0.999329, 0.999909]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_asin_ops() { let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -1.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.asin(); let expected = TensorData::from([[0.0, 0.523599, 0.785398], [-0.523599, -0.785398, -1.570796]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_acos_ops() { let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -1.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.acos(); let expected = TensorData::from([ [1.570796, 1.047198, 0.785398], [2.094395, 2.356194, 3.141593], ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_atan_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.atan(); let expected = TensorData::from([[0.0, 0.785398, 1.107149], [1.249046, 1.325818, 1.373401]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_asinh_ops() { let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.asinh(); let expected = TensorData::from([[0.0, 0.881374, 1.443635], [1.818446, 2.094713, 2.312438]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_acosh_ops() { let data = TensorData::from([[1.0, 1.5, 2.0], [3.0, 4.0, 5.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.acosh(); let expected = TensorData::from([[0.0, 0.962424, 1.316958], [1.762747, 2.063437, 2.292432]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_atanh_ops() { let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -0.9]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.atanh(); let expected = TensorData::from([[0.0, 0.549306, 0.881374], [-0.549306, -0.881374, -1.472219]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_atan2_ops() { let y = TensorData::from([[0.0, 1.0, 1.0], [-1.0, -1.0, 0.0]]); let x = TensorData::from([[1.0, 1.0, 0.0], [1.0, 0.0, -1.0]]); let y_tensor = TestTensor::<2>::from_data(y, &Default::default()); let x_tensor = TestTensor::<2>::from_data(x, &Default::default()); let output = y_tensor.atan2(x_tensor); let expected = TensorData::from([[0.0, 0.785398, 1.570796], [-0.785398, -1.570796, 3.141593]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_deg2rad_ops() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats( [ 0.0, 22.5, 30.0, 45.0, 60.0, 90.0, 135.0, 180.0, 270.0, 360.0, ], &device, ); let output = tensor.deg2rad(); let expected = TensorData::from([ 0.0f32, FRAC_PI_8, FRAC_PI_6, FRAC_PI_4, FRAC_PI_3, FRAC_PI_2, 0.75 * PI, PI, 1.5 * PI, 2.0 * PI, ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_support_rad2deg_ops() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats( [ 0.0, FRAC_PI_8, FRAC_PI_6, FRAC_PI_4, FRAC_PI_3, FRAC_PI_2, PI, 1.5 * PI, 2.0 * PI, -FRAC_PI_3, ], &device, ); let output = tensor.rad2deg(); let expected = TensorData::from([ 0.0f32, 22.5, 30.0, 45.0, 60.0, 90.0, 180.0, 270.0, 360.0, -60.0, ]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/trunc.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ElementConversion, TensorData}; #[test] fn should_support_trunc_ops() { let data = TensorData::from([[2.3, -1.7, 0.5], [-0.5, 3.9, -4.2]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.trunc(); let expected = TensorData::from([[2.0, -1.0, 0.0], [0.0, 3.0, -4.0]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_truncate_positive_values_like_floor() { let data = TensorData::from([1.7, 2.9, 3.1, 4.5]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.trunc(); let expected = TensorData::from([1.0, 2.0, 3.0, 4.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_truncate_negative_values_like_ceil() { let data = TensorData::from([-1.7, -2.9, -3.1, -4.5]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.trunc(); let expected = TensorData::from([-1.0, -2.0, -3.0, -4.0]); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn should_handle_special_cases() { // Test special IEEE 754 cases let data = TensorData::from([0.0, -0.0, f32::INFINITY, f32::NEG_INFINITY, f32::NAN]); let tensor = TestTensor::<1>::from_data(data, &Default::default()); let output = tensor.trunc(); let values = output.into_data().as_slice::().unwrap().to_vec(); // Check positive zero assert_eq!(values[0], 0.0f32.elem::()); assert!(values[0].is_sign_positive()); // Check negative zero is preserved assert_eq!(values[1], 0.0f32.elem::()); assert!(values[1].is_sign_negative()); // Check infinity is preserved assert!(values[2].is_infinite() && values[2].is_sign_positive()); assert!(values[3].is_infinite() && values[3].is_sign_negative()); // Check NaN is preserved assert!(values[4].is_nan()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/ops/unfold.rs ================================================ use super::*; use burn_tensor::Distribution; use burn_tensor::s; #[test] fn test_unfold_float() { let device = Default::default(); let input = TestTensor::<3>::random([2, 6, 6], Distribution::Default, &device); let dim = 1; let size = 3; let step = 2; let actual: TestTensor<4> = input.clone().unfold(dim, size, step); let expected = TestTensor::<4>::empty([2, 2, 6, 3], &device) .slice_assign( s![.., 0, .., ..], input .clone() .slice(s![.., 0..3, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ) .slice_assign( s![.., 1, .., ..], input .clone() .slice(s![.., 2..5, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ); actual.to_data().assert_eq(&expected.to_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/primitive.rs ================================================ use super::*; use burn_tensor::{Element, Shape}; #[test] fn should_support_float_dtype() { let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]).into_primitive(); assert_eq!( burn_tensor::TensorMetadata::shape(&tensor), Shape::new([2, 3]) ); assert_eq!( burn_tensor::TensorMetadata::dtype(&tensor), FloatElem::dtype() // default float elem type ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/calibration.rs ================================================ use super::*; use burn_tensor::{ TensorData, ops::QuantizedTensor, quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantValue, compute_range}, }; // NOTE: The scheme variant fields are not important for calibration, only the "main" variant (e.g., per-tensor) #[test] fn min_max_calibration_range_per_tensor() { let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default()); let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let range = compute_range(&scheme, &tensor, &Calibration::MinMax); range .min .into_data() .assert_eq(&TensorData::from([-1.8]), false); range .max .into_data() .assert_eq(&TensorData::from([0.5]), false); } #[test] fn min_max_calibration_range_per_block() { let tensor = TestTensor::<2>::from_floats( [ [-1.8, -1.0, 0.0, 0.5], [1.8, 1.0, 0.0, -0.5], [0.01, 0.02, 0.03, 0.04], [-0.01, -0.02, -0.03, -0.04], ], &Default::default(), ); let scheme = QuantizedTensor::::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([4])); let range = compute_range(&scheme, &tensor, &Calibration::MinMax); range .min .into_data() .assert_eq(&TensorData::from([[-1.8], [-0.5], [0.01], [-0.04]]), false); range .max .into_data() .assert_eq(&TensorData::from([[0.5], [1.8], [0.04], [-0.01]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/data.rs ================================================ use super::*; use alloc::vec; use burn_tensor::quantization::{QTensorPrimitive, QuantLevel, QuantValue}; use burn_tensor::{TensorData, ops::QuantizedTensor}; #[test] fn should_support_per_tensor_symmetric_int8() { let data = TensorData::quantized( vec![-127i8, -71, 0, 35], [4], QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S), &[0.014_173_228], ); let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); let q_data = tensor.into_data(); q_data.assert_eq(&data, true); let tensor = TestTensor::<1>::from_data(q_data.clone(), &Default::default()); tensor.into_data().assert_eq(&q_data, true); } #[test] fn should_support_per_block_symmetric_int8() { let data = TensorData::quantized( vec![ -127i8, -71, 0, 35, -127i8, -71, 0, 35, -32, -63, -95, -127, -32, -63, -95, -127, ], [16], QuantizedTensor::::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([8])), &[0.014_173_228, 0.000_314_96], ); let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); tensor.into_data().assert_eq(&data, true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/mod.rs ================================================ pub use super::*; // re-export test types mod calibration; mod data; mod ops; mod scheme; /// Quantized tensor utilities pub mod qtensor { use core::marker::PhantomData; use burn_tensor::quantization::QuantLevel; use burn_tensor::{ Tensor, TensorData, backend::Backend, quantization::{QTensorPrimitive, QuantValue}, }; pub struct QTensor { b: PhantomData, } impl QTensor { /// Creates a quantized int8 tensor from the floating point data using the default quantization scheme /// (i.e., per-tensor symmetric quantization). pub fn int8>(floats: F) -> Tensor { Self::int8_symmetric(floats) } /// Creates a quantized int8 tensor from the floating point data using blocks of size 16 pub fn int8_block>(floats: F) -> Tensor { Tensor::from_floats(floats, &Default::default()).quantize_dynamic( &::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([16])), ) } /// Creates a quantized int8 tensor from the floating point data using per-tensor symmetric quantization. pub fn int8_symmetric>(floats: F) -> Tensor { Tensor::from_floats(floats, &Default::default()).quantize_dynamic( &::default_scheme() .with_value(QuantValue::Q8S), ) } } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/abs.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_abs_ops() { let tensor = QTensor::::int8([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); let output = tensor.abs(); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), Tolerance::absolute(1e-1), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/add.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_add_d2() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); let output = tensor_1 + tensor_2; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_add_broadcast() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0]]); let tensor_2 = QTensor::::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let output = tensor_1 + tensor_2; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_add_different_strides_rhs() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = QTensor::::int8([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = QTensor::::int8([[4.0, 5.0], [6.0, 7.0]]) * 1; let output = tensor_1 + tensor_2.transpose(); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[4.0, 7.0], [7.0, 10.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_add_different_strides_lhs() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = QTensor::::int8([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = QTensor::::int8([[4.0, 5.0], [6.0, 7.0]]) * 1; let output = tensor_1.transpose() + tensor_2; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[4.0, 7.0], [7.0, 10.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_add_different_strides_broadcast() { // We need to execute an operation after `from data` to trigger inplace in some backends. // Which is the operation that might be problematic in this case. let tensor_1 = QTensor::::int8([[0.0, 1.0], [2.0, 3.0]]) * 1; let tensor_2 = QTensor::::int8([[4.0, 5.0]]) * 1; let output = tensor_1.transpose() + tensor_2; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[4.0, 7.0], [5.0, 8.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_support_add_scalar_ops() { let scalar = 2.0; let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor + scalar; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]), Tolerance::absolute(1e-1), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/aggregation.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_should_mean() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.mean(); output .dequantize() .into_data() .assert_approx_eq::(&TensorData::from([15.0 / 6.0]), Tolerance::absolute(1e-1)); } #[test] fn test_should_sum() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sum(); output .dequantize() .into_data() .assert_approx_eq::(&TensorData::from([15.0]), Tolerance::absolute(1e-1)); } #[test] fn test_should_mean_last_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.mean_dim(1); let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn test_should_sum_last_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sum_dim(1); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[3.0], [12.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_should_sum_first_dim() { let tensor = QTensor::::int8([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); let output = tensor.sum_dim(0); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[7.0, 3.0, 5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_should_mean_first_dim() { let tensor = QTensor::::int8([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]); let output = tensor.mean_dim(0); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]), Tolerance::absolute(1e-1), ); } #[test] fn test_should_sum_mid_dim_3d_non_contiguous_1() { let tensor = QTensor::::int8([ [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], ]); let output = tensor.swap_dims(0, 2).sum_dim(1); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]), Tolerance::absolute(1e-1), ); } #[test] fn test_should_sum_mid_dim_3d_non_contiguous_2() { let tensor = QTensor::::int8([ [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]], [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]], ]); let output = tensor.swap_dims(0, 1).sum_dim(1); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]), Tolerance::absolute(1e-1), ); } #[test] fn test_prod_float() { let tensor = QTensor::::int8([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.prod(); output .dequantize() .into_data() .assert_approx_eq::(&TensorData::from([240.0]), Tolerance::rel_abs(1e-1, 1e-1)); let tensor_with_zero = QTensor::::int8([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_with_zero.prod(); output .dequantize() .into_data() .assert_approx_eq::(&TensorData::from([0.0]), Tolerance::rel_abs(1e-1, 1e-1)); } #[test] fn test_prod_dim_float() { let tensor = QTensor::::int8([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.prod_dim(1); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[4.0], [60.0]]), Tolerance::absolute(1e-1), ); let tensor_with_zero = QTensor::::int8([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_with_zero.prod_dim(1); let expected = TensorData::from([[0.0], [60.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/all.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; #[test] fn test_all() { let tensor = QTensor::::int8([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([false]); assert_eq!(data_expected, data_actual); let tensor = QTensor::::int8([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([true]); assert_eq!(data_expected, data_actual); } #[test] fn test_all_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]); let data_actual = tensor.all_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); assert_eq!(data_expected, data_actual); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/any.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; #[test] fn test_any() { let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); assert_eq!(data_expected, data_actual); let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); assert_eq!(data_expected, data_actual); } #[test] fn test_any_dim() { let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); assert_eq!(data_expected, data_actual); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/arg.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; #[test] fn test_argmax_2d_dim0() { let tensor = QTensor::::int8([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.argmax(0); output .into_data() .assert_eq(&TensorData::from([[0, 0, 1]]), false); } #[test] fn test_argmin_2d_dim0() { let tensor = QTensor::::int8([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); let output = tensor.argmin(0); output .into_data() .assert_eq(&TensorData::from([[0, 1, 0]]), false); } #[test] fn test_argmax_2d_dim1() { let tensor = QTensor::::int8([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.argmax(1); output .into_data() .assert_eq(&TensorData::from([[1], [2]]), false); } #[test] fn test_argmin_2d_dim1() { let tensor = QTensor::::int8([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]); let output = tensor.argmin(1); output .into_data() .assert_eq(&TensorData::from([[2], [1]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cat.rs ================================================ use super::qtensor::*; use super::*; use alloc::vec; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_cat_ops_2d_dim0() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0, 6.0]]); let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_cat_ops_2d_dim1() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0, 6.0]]); let output = TestTensor::cat(vec![tensor_1, tensor_2], 1); let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_cat_ops_3d() { let tensor_1 = QTensor::::int8([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]); let tensor_2 = QTensor::::int8([[[4.0, 5.0, 6.0]]]); let output = TestTensor::cat(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[should_panic] fn should_panic_when_dimensions_are_not_the_same() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0]]); let _output = TestTensor::cat(vec![tensor_1, tensor_2], 0); } #[test] #[should_panic] fn should_panic_when_cat_exceeds_dimension() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0, 6.0]]); let _output = TestTensor::cat(vec![tensor_1, tensor_2], 3); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/ceil.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_ceil_ops() { let tensor = QTensor::::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let output = tensor.ceil(); let expected = TensorData::from([[25., 88., 77.], [60., 44., 96.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(1e-1, 1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/chunk.rs ================================================ use super::qtensor::*; use super::*; use alloc::vec::Vec; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_chunk_evenly_divisible() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let tensors: Vec> = tensor.chunk(3, 0); assert_eq!(tensors.len(), 3); let expected = [ TensorData::from([0., 1.]), TensorData::from([2., 3.]), TensorData::from([4., 5.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_chunk_not_evenly_divisible() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); let tensors: Vec> = tensor.chunk(4, 0); assert_eq!(tensors.len(), 4); let expected = [ TensorData::from([0., 1.]), TensorData::from([2., 3.]), TensorData::from([4., 5.]), TensorData::from([6.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_chunk_not_divisible() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let tensors: Vec> = tensor.chunk(7, 0); assert_eq!(tensors.len(), 6); let expected = [ TensorData::from([0.]), TensorData::from([1.]), TensorData::from([2.]), TensorData::from([3.]), TensorData::from([4.]), TensorData::from([5.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_chunk_multi_dimension() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]); let tensors: Vec> = tensor.chunk(2, 1); assert_eq!(tensors.len(), 2); let expected = [ TensorData::from([[0., 1., 2.]]), TensorData::from([[3., 4., 5.]]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] #[should_panic] fn test_invalid_dim() { let _tensors = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).chunk(6, 1); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/clamp.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn clamp_min() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.clamp_min(2.0); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn clamp_max() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.clamp_max(2.0); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]), Tolerance::absolute(1e-1), ); } #[test] fn clamp_min_max() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.clamp(1.0, 4.0); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]), Tolerance::absolute(1e-1), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cos.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_cos_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.cos(); let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cosh.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_cosh_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.cosh(); let expected = TensorData::from([[1.0000, 1.5431, 3.7622], [10.0677, 27.3082, 74.2100]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/div.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_div_ops() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_1 / tensor_2; let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn test_div_broadcast() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0]]); let tensor_2 = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor_1 / tensor_2; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]), Tolerance::absolute(1e-1), ); } #[test] fn should_support_div_scalar_ops() { let scalar = 2.0; let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor / scalar; output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]), Tolerance::absolute(1e-1), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/erf.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_erf_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.erf(); let expected = TensorData::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_erf_ops_with_negative_number() { let tensor = QTensor::::int8([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]); let output = tensor.erf(); let expected = TensorData::from([ [-0.06312324, -0.048490416, -0.10016122], [1.0000, 1.0000, 1.0000], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/exp.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_exp_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.exp(); let expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/expand.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn expand_2d() { let tensor = QTensor::::int8([1.0, 2.0, 3.0]); let output = tensor.expand([3, 3]); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), Tolerance::absolute(1e-1), ); // Quantized [4.0, 7.0, 2.0, 3.0] let tensor = QTensor::::int8([4.0, 7.0, 2.0, 3.0]); let output = tensor.expand([2, 4]); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]), Tolerance::absolute(1e-1), ); } #[test] fn expand_3d() { let tensor = QTensor::::int8([[1.0, 2.0], [3.0, 4.0]]); let output = tensor.expand([3, 2, 2]); let expected = TensorData::from([ [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn expand_higher_dimensions() { let tensor = QTensor::::int8([[1.0, 2.0, 3.0, 4.0]]); let output = tensor.expand([2, 3, 4]); let expected = TensorData::from([ [ [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], ], [ [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], ], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn broadcast_single() { let tensor = QTensor::::int8([1.0]); let output = tensor.expand([2, 3]); output .dequantize() .into_data() .assert_eq(&TensorData::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), false); } #[test] #[should_panic] fn should_fail_expand_incompatible_shapes() { let tensor = QTensor::::int8([1.0, 2.0, 3.0]); let _expanded_tensor = tensor.expand([2, 2]); } #[test] fn should_all_negative_one() { let tensor = QTensor::::int8([1.0, 2.0, 3.0]); let output = tensor.expand([2, -1]); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[1., 2., 3.], [1., 2., 3.]]), Tolerance::absolute(1e-1), ); } #[test] #[should_panic] fn should_panic_negative_one_on_non_existing_dim() { let tensor = QTensor::::int8([1.0, 2.0, 3.0]); let _expanded_tensor = tensor.expand([-1, 3]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/flip.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn flip_float() { let tensor = QTensor::::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]); let flipped = tensor.clone().flip([0, 2]); let expected = TensorData::from([[[5., 4., 3.]], [[2., 1., 0.]]]); flipped .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); // Test with no flip let flipped = tensor.clone().flip([]); tensor.into_data().assert_eq(&flipped.into_data(), true); } #[test] #[should_panic] fn flip_duplicated_axes() { let tensor = QTensor::::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]); // Test with a duplicated axis let _ = tensor.flip([0, 0, 1]); } #[test] #[should_panic] fn flip_out_of_bound_axis() { let tensor = QTensor::::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]); // Test with an out of bound axis let _ = tensor.clone().flip([3, 0, 1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/floor.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_floor_ops() { let tensor = QTensor::::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let output = tensor.floor(); let expected = TensorData::from([[24., 87., 76.], [59., 43., 95.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(1e-1, 1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/gather_scatter.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::IndexingUpdateOp; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_gather_1d_dim0() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &Default::default()); let output = tensor.gather(0, indices); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]), Tolerance::absolute(1e-1), ); } #[test] fn should_gather_2d_dim0() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]], &Default::default()); let output = tensor.gather(0, indices); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_gather_2d_dim1() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]], &Default::default()); let output = tensor.gather(1, indices); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_gather_3d_dim1() { let tensor = QTensor::::int8([ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ]); let indices = TestTensorInt::from_ints( [[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &Default::default(), ); let output = tensor.gather(1, indices); let expected = TensorData::from([ [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_gather_2d_only_1dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::<2>::from_ints([[1, 2]], &Default::default()).reshape([2, 1]); let output = tensor.gather(1, indices); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[1.0], [5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_scatter_1d() { let tensor = QTensor::::int8([0.0, 0.0, 0.0]); let values = QTensor::::int8([5.0, 4.0, 3.0]); let indices = TestTensorInt::from_ints([1, 0, 2], &Default::default()); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([4.0, 5.0, 3.0]), Tolerance::absolute(1e-1), ); } #[test] fn should_scatter_2d_dim0() { let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); let values = QTensor::::int8([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]], &Default::default()); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_scatter_2d_dim1() { let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); let values = QTensor::::int8([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]], &Default::default()); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]), Tolerance::absolute(1e-1), ); } #[test] fn should_scatter_3d_dim1() { let tensor = QTensor::::int8([ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], ]); let values = QTensor::::int8([ [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], ]); let indices = TestTensorInt::from_ints( [[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &Default::default(), ); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([ [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]], ]); // Set higher tolerance (0.2) due to larger de/quantization errors output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(2e-1)); } #[test] fn should_scatter_2d_dim1_diff_shape() { let tensor = QTensor::::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); let values = QTensor::::int8([[1.0], [4.0]]); let indices = TestTensorInt::from_ints([[1], [2]], &Default::default()); let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); output .dequantize() .into_data() .assert_approx_eq::( &TensorData::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]), Tolerance::absolute(1e-1), ); } #[test] #[should_panic] fn scatter_should_panic_on_mismatch_of_shapes() { let tensor = QTensor::::int8([0.0, 0.0, 0.0]); let values = QTensor::::int8([1.0, 4.0]); let indices = TestTensorInt::from_ints([1, 0, 2], &Default::default()); tensor.scatter(0, indices, values, IndexingUpdateOp::Add); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_log_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.log(); let expected = TensorData::from([ [-f32::INFINITY, 0.0, core::f32::consts::LN_2], [1.0986, 1.3862, 1.6094], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log1p.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_exp_log1p() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.log1p(); let expected = TensorData::from([ [0.0, core::f32::consts::LN_2, 1.0986], [1.3862, 1.6094, 1.7917], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/map_comparison.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; #[test] fn test_equal() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); let data_expected = TensorData::from([[true, true, false], [true, false, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_not_equal() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); let data_expected = TensorData::from([[false, false, true], [false, true, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] #[ignore = "quantization equality with float element is undefined"] fn test_equal_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]); let data_actual_cloned = tensor.clone().equal_elem(2); let data_actual_inplace = tensor.equal_elem(2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] #[ignore = "quantization equality with float element is undefined"] fn test_not_equal_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]); let data_actual_cloned = tensor.clone().not_equal_elem(2); let data_actual_inplace = tensor.not_equal_elem(2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] #[ignore = "quantization equality with float element is undefined"] fn test_greater_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor.clone().greater_elem(4); let data_actual_inplace = tensor.greater_elem(4); let data_expected = TensorData::from([[false, false, false], [false, false, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_greater_equal_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor.clone().greater_equal_elem(4.0); let data_actual_inplace = tensor.greater_equal_elem(4.0); let data_expected = TensorData::from([[false, false, false], [false, true, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_greater() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); let data_actual_inplace = tensor_1.greater(tensor_2); let data_expected = TensorData::from([[false, false, true], [false, false, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_greater_equal() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.greater_equal(tensor_2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_lower_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor.clone().lower_elem(4.0); let data_actual_inplace = tensor.lower_elem(4.0); let data_expected = TensorData::from([[true, true, true], [true, false, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] #[ignore = "quantization equality with float element is undefined"] fn test_lower_equal_elem() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data_actual_cloned = tensor.clone().lower_equal_elem(4.0); let data_actual_inplace = tensor.lower_equal_elem(4.0); let data_expected = TensorData::from([[true, true, true], [true, true, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_lower() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); let data_actual_inplace = tensor_1.lower(tensor_2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } #[test] fn test_lower_equal() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]); let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.lower_equal(tensor_2); let data_expected = TensorData::from([[true, true, false], [true, true, false]]); assert_eq!(data_expected, data_actual_cloned.into_data()); assert_eq!(data_expected, data_actual_inplace.into_data()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mask.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_mask_where_ops() { let tensor = QTensor::::int8([[1.0, 7.0], [2.0, 3.0]]); let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &Default::default(), ); let value = QTensor::::int8([[1.8, 2.8], [3.8, 4.8]]); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_mask_fill_ops() { let tensor = QTensor::::int8([[1.0, 7.0], [2.0, 3.0]]); let mask = TestTensorBool::<2>::from_bool( TensorData::from([[true, false], [false, true]]), &Default::default(), ); let output = tensor.mask_fill(mask, 2.0); let expected = TensorData::from([[2.0, 7.0], [2.0, 2.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/maxmin.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_max_dim_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.max_dim(1); let expected = TensorData::from([[2.], [5.]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_max_dim_with_indices_2d_with_dim_0th() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let (output, index) = tensor.max_dim_with_indices(0); let output_expected = TensorData::from([[3., 4., 5.]]); let index_expected = TensorData::from([[1, 1, 1]]); output .dequantize() .into_data() .assert_approx_eq::(&output_expected, Tolerance::rel_abs(2e-2, 1e-2)); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_max_dim_with_indices_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let (output, index) = tensor.max_dim_with_indices(1); let output_expected = TensorData::from([[2.], [5.]]); let index_expected = TensorData::from([[2], [2]]); output .dequantize() .into_data() .assert_approx_eq::(&output_expected, Tolerance::rel_abs(2e-2, 1e-2)); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_min_dim_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.min_dim(1); let expected = TensorData::from([[0.], [3.]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_min_dim_with_indices_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let (output, index) = tensor.min_dim_with_indices(1); let output_expected = TensorData::from([[0.], [3.]]); let index_expected = TensorData::from([[0], [0]]); output .dequantize() .into_data() .assert_approx_eq::(&output_expected, Tolerance::rel_abs(2e-2, 1e-2)); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_min_dim_2d_with_0th_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.min_dim(0); let expected = TensorData::from([[0., 1., 2.]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_max_dim_2d_with_0th_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.max_dim(0); let expected = TensorData::from([[3., 4., 5.]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_min_dim_with_indices_2d_with_0th_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let (output, index) = tensor.min_dim_with_indices(0); let output_expected = TensorData::from([[0., 1., 2.]]); let index_expected = TensorData::from([[0, 0, 0]]); output .dequantize() .into_data() .assert_approx_eq::(&output_expected, Tolerance::rel_abs(2e-2, 1e-2)); index.into_data().assert_eq(&index_expected, false); } #[test] fn test_maximum_pair() { let a = QTensor::::int8([1.0, 5.0, 3.0, 4.0]); let b = QTensor::::int8([2.0, 1.0, 4.0, 5.0]); let output = a.max_pair(b); let expected = TensorData::from([2.0, 5.0, 4.0, 5.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_minimum_pair() { let a = QTensor::::int8([1.0, 5.0, 3.0, 4.0]); let b = QTensor::::int8([2.0, 1.0, 4.0, 5.0]); let output = a.min_pair(b); let expected = TensorData::from([1.0, 1.0, 3.0, 4.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs ================================================ pub use super::*; mod abs; mod add; mod aggregation; mod all; mod any; mod arg; mod cat; mod ceil; mod chunk; mod clamp; mod cos; mod cosh; mod div; mod erf; mod exp; mod expand; mod flip; mod floor; mod gather_scatter; mod log; mod log1p; mod map_comparison; mod mask; mod maxmin; mod mul; mod narrow; mod neg; mod permute; mod powf; mod powf_scalar; mod recip; mod remainder; mod repeat_dim; mod reshape; mod round; mod select; mod sin; mod sinh; mod slice; mod sort_argsort; mod split; mod sqrt; mod stack; mod sub; mod tan; mod tanh; mod topk; mod transpose; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mul.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_mul_ops() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = tensor_1.clone(); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(5e-2, 1e-2)); } #[test] fn test_mul_broadcast() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0]]); let tensor_2 = QTensor::::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn test_mul_broadcast_2_dims() { let tensor_1 = QTensor::::int8([[0.0], [1.0], [2.0]]); let tensor_2 = QTensor::::int8([[3.0, 4.0, 5.0]]); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_mul_scalar_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let output = tensor * scalar; let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/narrow.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::Tolerance; use burn_tensor::{Shape, TensorData}; #[test] fn test_narrow() { let tensor = QTensor::::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]); let output = tensor.clone().narrow(0, 0, 2); let expected = TensorData::from([[1., 2., 3.], [7., 8., 9.]]); assert_eq!(output.shape(), Shape::from([2, 3])); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); let output = tensor.narrow(1, 1, 2); let expected = TensorData::from([[2., 3.], [8., 9.], [14., 15.]]); assert_eq!(output.shape(), Shape::from([3, 2])); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[should_panic] fn test_narrow_invalid_dim() { let tensor = QTensor::::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]); let _output = tensor.narrow(2, 0, 2); } #[test] #[should_panic] fn test_narrow_invalid_start() { let tensor = QTensor::::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]); let _output = tensor.narrow(0, 3, 2); } #[test] #[should_panic] fn test_narrow_invalid_zero_length() { let tensor = QTensor::::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]); let _output = tensor.narrow(0, 1, 0); } #[test] #[should_panic] fn test_narrow_invalid_length() { let tensor = QTensor::::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]); let _output = tensor.narrow(0, 0, 4); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/neg.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_neg_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.neg(); let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::(); // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32 output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/permute.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn permute_float() { let tensor = QTensor::::int8([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., ]) .reshape([2, 2, 4]); let permuted = tensor.clone().permute([2, 1, 0]); let expected = TensorData::from([ [[0., 8.], [4., 12.]], [[1., 9.], [5., 13.]], [[2., 10.], [6., 14.]], [[3., 11.], [7., 15.]], ]); permuted .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(1e-1, 1e-1)); // Test with negative axis let permuted = tensor.clone().permute([-1, 1, 0]); permuted .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(1e-1, 1e-1)); // Test with the same axis let permuted = tensor.clone().permute([0, 1, 2]); permuted .dequantize() .into_data() .assert_approx_eq::( &tensor.dequantize().into_data(), Tolerance::rel_abs(1e-4, 1e-4), // dequant error should be the same ); } #[test] #[should_panic] fn edge_repeated_axes() { let tensor = QTensor::::int8([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., ]) .reshape([2, 2, 4]); // Test with a repeated axis let _ = tensor.permute([0, 0, 1]); } #[test] #[should_panic] fn edge_out_of_bound_axis() { let tensor = QTensor::::int8([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., ]) .reshape([2, 2, 4]); // Test with an invalid axis let _ = tensor.permute([3, 0, 1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_powf_ops() { let tensor = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_pow = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_power() { let tensor = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_pow = QTensor::::int8([[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_values_with_even_power() { let tensor = QTensor::::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let tensor_pow = QTensor::::int8([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_values_with_odd_power() { let tensor = QTensor::::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]]); let tensor_pow = QTensor::::int8([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]); let output = tensor.powf(tensor_pow); let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf_scalar.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_powf_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.powf_scalar(0.71); let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_power() { let tensor = QTensor::::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.powf_scalar(-0.33); let expected = TensorData::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_values_with_even_power() { let tensor = QTensor::::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); let output = tensor.powf_scalar(2.0); let expected = TensorData::from([[0., 1., 4.], [9., 16., 25.]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } #[test] fn should_support_neg_values_with_odd_power() { let tensor = QTensor::::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]]); let output = tensor.powf_scalar(3.0); let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(4e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/recip.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_recip_ops() { let tensor = QTensor::::int8([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]); let output = tensor.recip(); let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/remainder.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_remainder_basic() { let lhs = QTensor::::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 2.0]); let rhs = QTensor::::int8([2.0, 3.0, 1.0, 2.0, 1.0, 2.0]); let output = lhs.remainder(rhs); let expected = TensorData::from([1., 1., 0., 1., 0., 0.]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[ignore = "quantization remainder with float element is undefined"] fn should_support_remainder_basic_scalar() { let tensor = QTensor::::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let output = tensor.remainder_scalar(2.0); let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_remainder_float() { let lhs = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let rhs = QTensor::::int8([1.4233, 2.7313, 0.2641, 1.9651, 0.5897]); let output = lhs.remainder(rhs); let expected = TensorData::from([1., 2., 0.0949, 0.0698, 0.2824]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_remainder_float_scalar() { let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let output = tensor.remainder_scalar(-1.5); let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_be_zero() { let lhs = QTensor::::int8([0.0, 0.0, 0.0]); let rhs = QTensor::::int8([3.5, -2.1, 1.5]); let output = lhs.remainder(rhs); let expected = TensorData::from([0.0, 0.0, 0.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_be_zero_scalar() { let tensor = QTensor::::int8([0.0, 0.0, 0.0]); let output = tensor.remainder_scalar(3.5); let expected = TensorData::from([0.0, 0.0, 0.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_have_no_remainder() { let lhs = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let rhs = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let output = lhs.remainder(rhs); let expected = TensorData::from([0.0, 0.0, 0.0, 0.0, 0.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_have_no_remainder_scalar() { let tensor = QTensor::::int8([4.0, 4.0]); let output = tensor.remainder_scalar(4.0); let expected = TensorData::from([0.0, 0.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_be_negative() { let lhs = QTensor::::int8([-7.0, -3.0, 2.0, 6.0]); let rhs = QTensor::::int8([-2.5, -2.1, -1.5, -3.25]); let output = lhs.remainder(rhs); let expected = TensorData::from([-2., -0.9, -1., -0.5]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_be_negative_scalar() { let tensor = QTensor::::int8([-7.0, -3.0, 2.0, 6.0]); let output = tensor.remainder_scalar(-2.5); let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_fp_dividends() { let tensor = QTensor::::int8([-7.5, -2.5, 2.5, 7.5]); let output = tensor.remainder_scalar(3.0); let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_large_divisor() { let lhs = QTensor::::int8([-1.0, 1.0, -1.5, 1.5, -1.0, 1.0, -1.5, 1.5]); let rhs = QTensor::::int8([10.0, 10.0, 10.0, 10.0, -10.0, -10.0, -10.0, -10.0]); let output = lhs.remainder(rhs); let expected = TensorData::from([9., 1., 8.5, 1.5, -1., -9., -1.5, -8.5]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_large_divisor_scalar() { let tensor = QTensor::::int8([-1.0, 1.0]); let output = tensor.remainder_scalar(10.0); let expected = TensorData::from([9.0, 1.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_remainder_op() { let lhs = QTensor::::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 2.0]); let rhs = QTensor::::int8([2.0, 3.0, 1.0, 2.0, 1.0, 2.0]); let output = lhs % rhs; let expected = TensorData::from([1., 1., 0., 1., 0., 0.]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[ignore = "quantization remainder with float element is undefined"] fn should_support_remainder_scalar_op() { let tensor = QTensor::::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); let output = tensor % 2.0; let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/repeat_dim.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_support_repeat_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0, 3.0]]); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([ [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::permissive()); } #[test] fn should_support_repeat_on_dims_larger_than_1() { let tensor = QTensor::::int8([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., ]) .reshape([4, 2, 2]); let output = tensor.repeat_dim(2, 2); let expected = TensorData::from([ [[0., 1., 0., 1.], [2., 3., 2., 3.]], [[4., 5., 4., 5.], [6., 7., 6., 7.]], [[8., 9., 8., 9.], [10., 11., 10., 11.]], [[12., 13., 12., 13.], [14., 15., 14., 15.]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(1e-1, 1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/reshape.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn should_support_reshape_1d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0, 3.0]]); let output = tensor.clone().reshape([1, 4]); let expected = TensorData::from([[0.0, 1.0, 2.0, 3.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_reshape_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.clone().reshape([6]); let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_dim_infererence() { let tensor = QTensor::::int8([ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, ]) .reshape([4, 3]); // Infer the dimension via -1 let reshaped = tensor.clone().reshape([2, -1]); assert_eq!(reshaped.shape(), [2, 6].into()); // Infer the dimension via 0 (keep from the source) and -1 (infer) let reshaped = reshaped.reshape([0, 2, -1]); assert_eq!(reshaped.shape(), [2, 2, 3].into()); // This is effectively as if we did a flatten let reshaped = tensor.clone().reshape([-1]); assert_eq!(reshaped.shape(), [12].into()); // Keeping the first dimension the same (using 0) let reshaped = tensor.clone().reshape([0, 3]); assert_eq!(reshaped.shape(), [4, 3].into()); } #[test] fn should_not_corrupt_after_slice() { let zeros = QTensor::::int8([0.0, 0.0]); zeros.clone().slice([1..2]).reshape([1]).exp(); // May lead to zeroes being equal to [0.0, 1.0] zeros.dequantize().into_data().assert_eq( &TestTensor::<1>::zeros([2], &Default::default()).to_data(), true, ); } #[test] #[should_panic] fn multiple_neg_ones() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let _ = tensor.reshape([-1, -1]); } #[test] #[should_panic] fn neg_value() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let _ = tensor.reshape([-2, -1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/round.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_round_ops() { let tensor = QTensor::::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); let output = tensor.round(); let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]); output .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_round_ties_even() { // NOTE: round ties to even only affects values that are exact halfway from ceil/floor, so quantization // errors can impact this. This basically only guarantees the values for the max value in the range since // it is always represented correctly. let tensor = QTensor::::int8([5.5]); let output = tensor.round(); let expected = TensorData::from([6.]); output .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/select.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::IndexingUpdateOp; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_select_1d() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0]); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default()); let output = tensor.select(0, indices); let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_select_2d_dim0_same_num_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_data([1, 0], &Default::default()); let output = tensor.select(0, indices); let expected = TensorData::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_select_2d_dim0_more_num_dim() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_data([1, 0, 1, 1], &Default::default()); let output = tensor.select(0, indices); let expected = TensorData::from([ [3.0, 4.0, 5.0], [0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_select_2d_dim1() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default()); let output = tensor.select(1, indices); let expected = TensorData::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_select_assign_1d() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let values = QTensor::::int8([5.0, 4.0, 3.0, 2.0, 1.0]); let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &Default::default()); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([3.0, 12.0, 3.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_select_assign_2d_dim0() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let values = tensor.clone(); let indices = TestTensorInt::from_data(TensorData::from([1, 0]), &Default::default()); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([[3.0, 5.0, 7.0], [3.0, 5.0, 7.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_select_assign_2d_dim1() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let values = tensor.clone(); let indices = TestTensorInt::from_data(TensorData::from([1, 0, 2]), &Default::default()); let output = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([[1.0, 1.0, 4.0], [7.0, 7.0, 10.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[should_panic] fn should_select_panic_invalid_dimension() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default()); tensor.select(10, indices); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sin.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_sin_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sin(); let expected = TensorData::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sinh.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_sinh_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sinh(); let expected = TensorData::from([[0.0000, 1.1752, 3.6269], [10.0179, 27.2899, 74.2032]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(3e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/slice.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::{Tolerance, s}; #[test] fn should_support_full_sliceing_1d() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0]); let data = tensor.to_data(); let output = tensor.slice([0..4]); output.into_data().assert_eq(&data, false); } #[test] fn should_support_partial_sliceing_1d() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0]); let output = tensor.slice([1..3]); let expected = TensorData::from([1.0, 2.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_full_sliceing_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = tensor.to_data(); let output = tensor.clone().slice([0..2]); output.into_data().assert_eq(&data, true); let output = tensor.slice([0..2, 0..3]); output.into_data().assert_eq(&data, true); } #[test] fn should_support_partial_sliceing_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.slice([0..2, 0..2]); let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_partial_sliceing_3d() { let tensor = QTensor::::int8([ [[0., 1., 2., 3.], [4., 5., 6., 7.]], [[8., 9., 10., 11.], [12., 13., 14., 15.]], ]); let output = tensor.slice([1..2, 1..2, 0..2]); let expected = TensorData::from([[[12.0, 13.0]]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_partial_sliceing_3d_non_contiguous() { let tensor = QTensor::::int8([ [[0., 1., 2., 3.], [4., 5., 6., 7.]], [[8., 9., 10., 11.], [12., 13., 14., 15.]], ]); let output = tensor.transpose().slice([1..2, 1..2, 0..2]); let expected = TensorData::from([[[9.0, 13.0]]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn should_support_slice_assign_1d() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let tensor_assigned = QTensor::::int8([10.0, 5.0]); let output = tensor.slice_assign([0..2], tensor_assigned); let expected = TensorData::from([10.0, 5.0, 2.0]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_slice_assign_2d() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_assigned = QTensor::::int8([[10.0, 5.0]]); let output = tensor.slice_assign([1..2, 0..2], tensor_assigned); let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn slice_should_not_corrupt_potentially_inplace_operations() { let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]); let expected = TensorData::from([4., 6., 8.]); tensor .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn slice_assign_should_not_corrupt_potentially_inplace_operations() { let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let values = QTensor::::int8([10., 20., 30.]); let tensor_1 = tensor.clone().slice_assign([0..3], values); let tensor_2 = tensor + 2; let expected = TensorData::from([10., 20., 30., 4., 5.]); tensor_1 .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); let expected = TensorData::from([3., 4., 5., 6., 7.]); tensor_2 .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn clamp_when_slice_exceeds_dimension() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let data = tensor.to_data(); let output = tensor.slice([0..4]); output.into_data().assert_eq(&data, true); } #[test] fn negative_dimensions() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = tensor.to_data(); // Clamping to the tensor dimensions let output = tensor.clone().slice([0..4, 0..4]); output.into_data().assert_eq(&data, true); // Negative dimensions let output = tensor.clone().slice([0..1, 0..1]); let data = TensorData::from([[0.0f32]]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); let output = tensor.slice(s![0..-1, 0..-2]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] fn missing_dimensions() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = tensor.to_data(); // Clamping to the tensor dimensions let output = tensor.clone().slice([0..4, 0..4]); output.into_data().assert_eq(&data, true); // Negative dimensions let data = TensorData::from([[0.0f32]]); let output = tensor.clone().slice(s![0..-1, 0..-2]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); // Missing dimensions let output = tensor.clone().slice(s![0..1, ..]); let data = TensorData::from([[0.0f32, 1.0, 2.0]]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); let output = tensor.clone().slice(s![.., 0..2]); let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); let output = tensor.clone().slice([.., ..]); let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]); output .dequantize() .into_data() .assert_approx_eq::(&data, Tolerance::rel_abs(2e-2, 1e-2)); } #[test] #[should_panic] fn should_panic_when_slice_with_too_many_dimensions() { let tensor = QTensor::::int8([0.0, 1.0, 2.0]); let _output = tensor.slice([0..1, 0..1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sort_argsort.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_sort_1d_float() { // Quantized [0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1] let tensor = QTensor::::int8([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let values = tensor.sort(0); let values_expected = TensorData::from([ -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 5.2, ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); } #[test] fn test_argsort_1d_float() { let tensor = QTensor::::int8([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let indices = tensor.argsort(0); let indices_expected = TensorData::from([12, 6, 2, 3, 0, 5, 10, 1, 4, 7, 11, 9, 8]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_with_indices_descending_float() { // 1D let tensor = QTensor::::int8([ 0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1, ]); // Sort along dim=0 let (values, indices) = tensor.sort_descending_with_indices(0); let values_expected = TensorData::from([ 5.2, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1, ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]); indices.into_data().assert_eq(&indices_expected, false); // 3D // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1] let tensor = QTensor::::int8([ -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1, ]) .reshape([2, 2, 3]); // Sort along dim=1 let (values, indices) = tensor.sort_descending_with_indices(1); let values_expected = TensorData::from([ [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]], [[0.99, 3., 4.], [-0.3, 2.3, -8.1]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_float() { let tensor = QTensor::::int8([ -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1, ]) .reshape([2, 2, 3]); // Sort along dim=0 let values = tensor.clone().sort(0); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); // Sort along dim=1 let values = tensor.clone().sort(1); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); // Sort along dim=2 let values = tensor.sort(2); let values_expected = TensorData::from([ [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); } #[test] fn test_sort_with_indices_float() { let tensor = QTensor::::int8([ -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1, ]) .reshape([2, 2, 3]); // Sort along dim=0 let (values, indices) = tensor.clone().sort_with_indices(0); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]], [[-0.3, 2.3, 4.], [0.99, 3., 0.94]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let (values, indices) = tensor.clone().sort_with_indices(1); let values_expected = TensorData::from([ [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]], [[-0.3, 2.3, -8.1], [0.99, 3., 4.]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let (values, indices) = tensor.sort_with_indices(2); let values_expected = TensorData::from([ [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]], [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]], ]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_argsort_float() { let tensor = QTensor::::int8([ -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1, ]) .reshape([2, 2, 3]); // Sort along dim=0 let indices = tensor.clone().argsort(0); let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let indices = tensor.clone().argsort(1); let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let indices = tensor.argsort(2); let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_descending_1d() { let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); // Sort along dim=0 let values = tensor.sort_descending(0); let values_expected = TensorData::from([5., 4., 3., 2., 1.]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/split.rs ================================================ use super::qtensor::*; use super::*; use alloc::vec; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_split_evenly_divisible() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let tensors = tensor.split(2, 0); assert_eq!(tensors.len(), 3); let expected = [ TensorData::from([0., 1.]), TensorData::from([2., 3.]), TensorData::from([4., 5.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_split_not_evenly_divisible() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); let tensors = tensor.split(2, 0); assert_eq!(tensors.len(), 4); let expected = [ TensorData::from([0., 1.]), TensorData::from([2., 3.]), TensorData::from([4., 5.]), TensorData::from([6.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_split_along_dim1() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensors = tensor.split(2, 1); assert_eq!(tensors.len(), 2); let expected = [ TensorData::from([[0., 1.], [3., 4.]]), TensorData::from([[2.], [5.]]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] fn test_split_split_size_larger_than_tensor_size() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let tensors = tensor.split(10, 0); assert_eq!(tensors.len(), 1); let expected = [TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] #[should_panic( expected = "split_size must be greater than 0 unless the tensor size along the dimension is 0." )] fn test_split_with_zero_split_size_non_zero_tensor() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let _ = tensor.split(0, 0); } #[test] #[should_panic(expected = "Given dimension is greater than or equal to the tensor rank.")] fn test_split_invalid_dim() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let _ = tensor.split(1, 2); } #[test] fn test_split_with_sizes() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let tensors = tensor.split_with_sizes(vec![2, 3, 1], 0); assert_eq!(tensors.len(), 3); let expected = [ TensorData::from([0., 1.]), TensorData::from([2., 3., 4.]), TensorData::from([5.]), ]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } #[test] #[should_panic( expected = "The sum of split_sizes must equal the tensor size along the specified dimension." )] fn test_split_with_sizes_invalid_sum() { let tensor = QTensor::::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]); let _ = tensor.split_with_sizes(vec![2, 2, 1], 0); } #[test] fn test_split_with_sizes_zero_length() { let tensor = QTensor::::int8([0.0, 2.0, 5.0]); let tensors = tensor.split_with_sizes(vec![0, 1, 2], 0); assert_eq!(tensors.len(), 2); let expected = [TensorData::from([0.]), TensorData::from([2., 5.])]; for (index, tensor) in tensors.into_iter().enumerate() { tensor .dequantize() .to_data() .assert_approx_eq::(&expected[index], Tolerance::absolute(1e-1)); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sqrt.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; use core::f32::consts::SQRT_2; #[test] fn should_support_sqrt_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.sqrt(); let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/stack.rs ================================================ use super::qtensor::*; use super::*; use alloc::vec; use burn_tensor::Tensor; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_stack_ops_2d_dim0() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0, 6.0]]); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_stack_ops_2d_dim1() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0, 6.0]]); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 1); let expected = TensorData::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_stack_ops_3d() { let tensor_1 = QTensor::::int8([[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]]); let tensor_2 = QTensor::::int8([[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]]); let output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([ [[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]], [[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] #[should_panic] fn should_panic_when_dimensions_are_not_the_same() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 5.0]]); let _output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); } #[test] #[should_panic] fn should_panic_when_stack_exceeds_dimension() { let tensor_1 = QTensor::::int8([[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]]); let tensor_2 = QTensor::::int8([[[4.0, 5.0, 6.0]]]); let _output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 3); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sub.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_sub_ops() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor_2 = QTensor::::int8([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn test_sub_broadcast() { let tensor_1 = QTensor::::int8([[0.0, 1.0, 2.0]]); let tensor_2 = QTensor::::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_sub_scalar_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let scalar = 2.0; let output = tensor - scalar; let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(2e-2, 1e-2)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tan.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_tan_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.tan(); let expected = TensorData::from([[0.0, 1.5574, -2.1850], [-0.1425, 1.1578, -3.3805]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tanh.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_tanh_ops() { let tensor = QTensor::::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let output = tensor.tanh(); let expected = TensorData::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/topk.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_topk_1d() { let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let values = tensor.topk(3, /*dim*/ 0); let expected = TensorData::from([5., 4., 3.]); values .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn test_topk() { let tensor = QTensor::::int8([ [[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]], ]); let values = tensor.topk(2, /*dim*/ 2); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); values .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn test_topk_with_indices() { // 1D let tensor = QTensor::::int8([1.0, 2.0, 3.0, 4.0, 5.0]); let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); let values_expected = TensorData::from([5., 4., 3.]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::permissive()); let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_topk_with_indices_3d() { // 3D let tensor = QTensor::::int8([ [[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]], ]); let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); values .dequantize() .into_data() .assert_approx_eq::(&values_expected, Tolerance::absolute(1e-1)); let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/transpose.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn should_support_transpose_ops() { let tensor = QTensor::::int8([ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, ]) .reshape([2, 2, 3]); let output = tensor.transpose(); let expected = TensorData::from([ [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } #[test] fn should_support_swap_dims() { let tensor = QTensor::::int8([ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, ]) .reshape([2, 2, 3]); let output = tensor.swap_dims(0, 2); let expected = TensorData::from([ [[0.0, 6.0], [3.0, 9.0]], [[1.0, 7.0], [4.0, 10.0]], [[2.0, 8.0], [5.0, 11.0]], ]); output .dequantize() .into_data() .assert_approx_eq::(&expected, Tolerance::absolute(1e-1)); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/matmul.rs ================================================ use super::qtensor::*; use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] #[ignore] fn test_matmul_vectors() { let tensor_1 = QTensor::::int8([[1.0, 2.0, 3.0, 6.35]]); let tensor_2 = QTensor::::int8([[12.7], [4.0], [5.0], [1.0]]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[42.05]]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] #[ignore] fn test_matmul_2d() { let tensor_1 = QTensor::::int8([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]]); let tensor_2 = QTensor::::int8([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[16.7, 27.05, 50.8], [14., 25., 43.4], [10., 17., 30.7]]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] fn test_matmul_2d_aligned() { let tensor_1 = QTensor::::int8([ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ]); let tensor_2 = QTensor::::int8([ [2.0, 0.0, 1.0, 0.0], [1.0, 2.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [1.0, 0.0, 0.0, 1.0], ]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [8.0, 7.0, 7.0, 4.0], [24.0, 19.0, 19.0, 8.0], [40.0, 31.0, 31.0, 12.0], ]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] fn test_matmul_2d_aligned_fused() { let tensor_1 = QTensor::::int8([ [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], ]); let tensor_2 = QTensor::::int8([ [2.0, 0.0, 1.0, 0.0], [1.0, 2.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [1.0, 0.0, 0.0, 1.0], ]); let tensor_3 = tensor_1.matmul(tensor_2); let tensor_4 = tensor_3 / 2.0; let expected = TensorData::from([ [4.0, 3.5, 3.5, 2.0], [12.0, 9.5, 9.5, 4.0], [20.0, 15.5, 15.5, 6.0], ]); tensor_4 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] #[ignore] fn test_matmul_3d() { let tensor_1 = QTensor::::int8([[[1.0, 6.35], [2.0, 3.0]]]); let tensor_2 = QTensor::::int8([[[12.7, 4.0], [2.0, 3.0]]]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[25.4, 23.05], [31.4, 17.0]]]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] #[ignore] fn test_matmul_broadcast_4d() { let tensor_1 = QTensor::::int8([[[[1.0, 7.0], [2.0, 3.0]]], [[[2.0, 5.0], [6.0, 3.0]]]]); let tensor_2 = QTensor::::int8([[[[9.0, 8.0], [1.0, 4.0]], [[2.0, 7.0], [3.0, 5.0]]]]); // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [[[16.0, 36.0], [21.0, 28.0]], [[23.0, 42.0], [13.0, 29.0]]], [[[23.0, 36.0], [57.0, 60.0]], [[19.0, 39.0], [21.0, 57.0]]], ]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] #[ignore] fn test_matmul_broadcast() { let tensor_1 = QTensor::::int8([[[1.0, 7.0], [2.0, 3.0]]]); let tensor_2 = QTensor::::int8([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]); tensor_3 .into_data() .assert_approx_eq::(&expected, Tolerance::relative(2e-2)); } #[test] #[should_panic] fn should_panic_when_inner_dimensions_are_not_equal() { let tensor_1 = QTensor::::int8([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]); let tensor_2 = QTensor::::int8([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]); let _ = tensor_1.matmul(tensor_2); } #[test] fn test_matmul_lhs_float_rhs_quantized() { // Simulates a typical workflow with linear layers (e.g., transformers), where the rhs // represents the weights. The lhs might be a float if a previous operation did not propagate // the quantization. We still want to perform an efficient matmul with quantized weights. let tensor_1 = TestTensor::<2>::from([ [1.0, 6.35, 2.0, 3.0], [2.0, 3.0, 4.0, 5.0], [1.0, 3.0, 5.0, 7.0], ]); let tensor_2 = QTensor::::int8([ [4.0, 8.0, 12.7, 1.6], [2.0, 3.0, 6.0, 4.0], [1.0, 5.0, 9.0, 2.5], [3.0, 7.0, 11.0, 0.5], ]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [27.7, 58.05, 101.8, 33.5], [33., 80., 134.4, 27.7], [36., 91., 152.7, 29.6], ]); let output = tensor_3.into_data(); output.assert_approx_eq::(&expected, Tolerance::default()); // Default quantization scheme does not propagate quantization with matmul assert!(output.dtype.is_float()); } #[test] fn test_matmul_mixed_block_scale() { let tensor_1 = TestTensor::<2>::from([ [1.0, 6.35, 2.0, 3.0], [2.0, 3.0, 4.0, 5.0], [1.0, 3.0, 5.0, 7.0], ]); let tensor_2 = QTensor::::int8_block([ [ 6.110, 4.0, 9.360, 7.850, 0.630, 1.770, 0.430, 7.550, 9.690, 3.560, 2.920, 9.130, 3.390, 0.510, 1.620, 1.460, ], [ 6.140, 8.260, 5.660, 5.610, 7.070, 3.050, 9.890, 5.520, 1.350, 3.810, 5.630, 0.250, 0.350, 8.860, 3.610, 6.240, ], [ 8.810, 4.620, 7.420, 8.110, 2.560, 4.710, 5.730, 8.980, 1.170, 6.090, 4.140, 3.610, 4.960, 9.720, 5.710, 1.470, ], [ 2.260, 9.640, 6.320, 6.980, 9.860, 1.030, 8.340, 1.570, 4.140, 4.760, 4.590, 6.400, 5.350, 1.430, 4.960, 1.180, ], ]); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [ 69.499, 94.611, 79.101, 80.633, 80.225, 33.647, 99.711, 65.272, 33.022, 54.213, 60.721, 37.138, 31.582, 80.501, 50.843, 47.564, ], [ 77.180, 99.460, 96.980, 99.870, 82.010, 36.680, 95.150, 75.430, 48.810, 66.710, 62.240, 65.450, 54.420, 73.630, 61.710, 33.420, ], [ 84.400, 119.360, 107.680, 114.090, 103.660, 41.680, 117.130, 80.0, 48.570, 78.760, 72.640, 72.730, 66.690, 85.700, 75.720, 35.790, ], ]); let output = tensor_3.into_data(); output.assert_approx_eq::(&expected, Tolerance::permissive()); // Default quantization scheme does not propagate quantization with matmul assert!(output.dtype.is_float()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/mod.rs ================================================ pub use super::*; mod matmul; mod quantize; // TODO: re-enable for cubecl backends when inputs are valid for packed U32 storage #[cfg(feature = "ndarray")] mod extended; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/ops/quantize.rs ================================================ use super::*; use alloc::{vec, vec::Vec}; use burn_tensor::quantization::{ QParams, QTensorPrimitive, QuantLevel, QuantScheme, QuantStore, QuantValue, QuantizationParameters, QuantizedBytes, }; use burn_tensor::{DType, Element, TensorData}; use burn_tensor::{Tolerance, ops::QuantizedTensor}; fn get_q_params(data: TensorData) -> QParams> { let num_elements = data.num_elements(); let scheme = if let DType::QFloat(scheme) = data.dtype { scheme } else { unreachable!() }; let q_bytes = QuantizedBytes { bytes: data.into_bytes(), scheme, num_elements, }; q_bytes.into_vec_i8().1 } #[test] fn should_support_quantize_symmetric_int8() { // Strict equality was based on full precision if !matches!(FloatElem::dtype(), DType::F32) { return; } let device = Default::default(); let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let qparams = QuantizationParameters { scales: TestTensor::from_floats([0.014_173_228], &device), }; let x_q = tensor.clone().quantize(&scheme, qparams); let x_q_data = x_q.to_data(); let expected = TensorData::quantized( vec![-127i8, -71, 0, 35], [4], scheme.with_store(QuantStore::Native), &[0.014_173_228], // scale ); // Values equality x_q_data.assert_eq(&expected, false); // Quantization parameters check let qparams = get_q_params(x_q_data); let expected = get_q_params(expected); assert_eq!(qparams.scales.len(), 1); // TODO: check scales assert_eq!(qparams.scales, expected.scales); // Dequantize let x = x_q.dequantize(); x.into_data() .assert_approx_eq::(&tensor.into_data(), Tolerance::rel_abs(1e-1, 1e-2)); } #[test] fn should_support_quantize_dynamic_int8() { let device = Default::default(); // NOTE: we use fully representable values since different backend implementations could differ slightly // due to rounding discrepancies let tensor = TestTensor::<1>::from_floats([5., 0., 4., -12.7], &device); let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let x_q = tensor.quantize_dynamic(&scheme); let expected = TensorData::quantized( vec![50i8, 0, 40, -127], [4], scheme.with_store(QuantStore::Native), &[0.1], // scale ); x_q.into_data().assert_eq(&expected, false); } #[test] fn should_quantize_dequantize_symmetric_single_with_transform() { let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let input = TestTensorInt::<1>::arange(0..32, &Default::default()).float(); let quant = input.quantize_dynamic(&scheme); let result = quant * 10; let data = result.into_data(); let expected = [ 0.0, 9.76378, 19.52756, 29.29134, 39.05512, 48.818897, 61.02362, 70.7874, 80.551186, 90.31496, 100.07874, 109.84252, 119.60631, 129.37009, 139.13387, 148.89764, 161.10237, 170.86615, 180.62991, 190.39369, 200.15749, 209.92126, 219.68504, 229.44882, 239.21262, 248.97638, 261.1811, 270.9449, 280.70865, 290.47244, 300.23624, 310.0, ]; data.assert_approx_eq::(&TensorData::from(expected), Tolerance::permissive()); } #[test] fn should_quantize_dequantize_symmetric_arange_16x16() { let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let input: TestTensor<2> = TestTensorInt::arange(0..256, &Default::default()) .float() .div_scalar(256.) .reshape([16, 16]); let output = input.clone().quantize_dynamic(&scheme); let output = output.dequantize(); output.into_data().assert_approx_eq::( &input.into_data(), Tolerance::absolute(1e-1).set_relative(1e-2), ); } #[test] fn should_quantize_dequantize_symmetric_per_block_arange_16x16() { let scheme = QuantizedTensor::::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([2, 16])); let input: TestTensor<2> = TestTensorInt::arange(0..256, &Default::default()) .float() .div_scalar(256.) .reshape([16, 16]); let output = input.clone().quantize_dynamic(&scheme); let output = output.dequantize(); output.into_data().assert_approx_eq::( &input.into_data(), Tolerance::absolute(1e-1).set_relative(1e-2), ); } fn should_quantize_transposed(tensor: Tensor, scheme: QuantScheme) { let tensor_t = tensor.clone().transpose(); let output = tensor_t.quantize_dynamic(&scheme).dequantize().transpose(); tensor.into_data().assert_approx_eq::( &output.into_data(), Tolerance::absolute(1e-1).set_relative(1e-2), ); } #[test] fn should_quantize_symmetric_int8_transposed_8x32() { let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let tensor = TestTensorInt::arange(0..256, &Default::default()) .float() .div_scalar(256.) .reshape([8, 32]); should_quantize_transposed(tensor, scheme); } #[test] fn should_quantize_symmetric_int8_transposed_48x64() { let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let tensor = TestTensorInt::arange(0..3072, &Default::default()) .float() .div_scalar(3072.) .reshape([48, 64]); should_quantize_transposed(tensor, scheme); } #[test] fn should_quantize_symmetric_per_block_int8_transposed_32x64() { let scheme = QuantizedTensor::::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([32])); let tensor = TestTensorInt::arange(0..2048, &Default::default()) .float() .div_scalar(2048.) .reshape([32, 64]); should_quantize_transposed(tensor, scheme); } #[test] fn should_quantize_symmetric_int8_permuted_batch_dims() { let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let tensor = TestTensorInt::arange(0..2048, &Default::default()) .float() .div_scalar(2048.) .reshape([2, 4, 8, 32]); // Permute [0,1,2,3] -> [1,2,0,3] // This rearranges batch dims but keeps packed dim in place let tensor_permuted = tensor.clone().permute([1, 2, 0, 3]); let output = tensor_permuted .quantize_dynamic(&scheme) .dequantize() .permute([2, 0, 1, 3]); // reverse permutation tensor.into_data().assert_approx_eq::( &output.into_data(), Tolerance::absolute(1e-1).set_relative(1e-2), ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/quantization/scheme.rs ================================================ use super::*; use burn_tensor::Tolerance; use burn_tensor::{ Element, TensorData, ops::QuantizedTensor, quantization::{CalibrationRange, QTensorPrimitive, QuantLevel, QuantValue, compute_q_params}, }; #[test] fn per_tensor_symmetric_int8() { let device = Default::default(); let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let range = CalibrationRange { min: TestTensor::<1>::from_floats([0.5], &device), max: TestTensor::<1>::from_floats([1.8], &device), }; let qparams = compute_q_params(&scheme, range); qparams .scales .into_data() .assert_approx_eq::(&TensorData::from([0.014_173_23]), Tolerance::default()); } #[test] fn per_block_symmetric_int8() { let device = Default::default(); let scheme = QuantizedTensor::::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([4])); let range = CalibrationRange { min: TestTensor::<1>::from_floats([-1.8, -0.5, 0.01, -0.04], &device), max: TestTensor::<1>::from_floats([0.5, 1.8, 0.04, -0.01], &device), }; let qparams = compute_q_params(&scheme, range); qparams.scales.into_data().assert_approx_eq::( &TensorData::from([0.014_173_23, 0.014_173_23, 0.000_314_96, 0.000_314_96]), Tolerance::default(), ); } #[test] fn quant_scheme_should_inhibit_by_default() { let device = Default::default(); let scheme = QuantizedTensor::::default_scheme().with_value(QuantValue::Q8S); let tensor_1 = TestTensor::<2>::from_floats( [[1.0, 6.35, 0., 0.], [2.0, 3.0, 0., 0.], [1.0, 3.0, 0., 0.]], &device, ) .quantize_dynamic(&scheme); let _tensor_2 = TestTensor::<2>::from_floats( [ [4.0, 8.0, 12.7, 0.], [2.0, 3.0, 6.0, 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], ], &device, ) .quantize_dynamic(&scheme); // let tensor_3 = tensor_1.clone().matmul(tensor_2); // assert_eq!(tensor_3.to_data().dtype, FloatElem::dtype()); let tensor_4 = tensor_1.add_scalar(1.); assert_eq!(tensor_4.to_data().dtype, FloatElem::dtype()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/cov.rs ================================================ use super::*; use burn_tensor::{TensorData, Tolerance}; #[test] fn test_cov_1() { let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 1); let expected = TensorData::from([[2.48917, -1.73333], [-1.73333, 15.33333]]).convert::(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); output .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_cov_4() { let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 0); let expected = TensorData::from([[1.86687, -1.30000], [-1.30000, 11.5]]).convert::(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); output .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_cov_2() { let data = TensorData::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]); let tensor = TestTensor::<2>::from_data(data, &Default::default()); let output = tensor.cov(1, 1); let expected = TensorData::from([ [0.845, -1.43, -4.55, -3.25], [-1.43, 2.42, 7.7, 5.5], [-4.55, 7.7, 24.5, 17.5], [-3.25, 5.5, 17.5, 12.5], ]) .convert::(); let tolerance = Tolerance::default().set_half_precision_relative(1e-3); output .into_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_cov_3() { let data = TensorData::from([ [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], ]); let device = Default::default(); let tensor = TestTensor::<3>::from_data(data, &device); let data_actual = tensor.cov(0, 1).into_data(); let data_expected = TestTensor::<3>::zeros([4, 4, 4], &device).to_data(); data_expected.assert_approx_eq::(&data_actual, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/display.rs ================================================ use super::*; use burn_tensor::backend::Backend; use burn_tensor::{Element, Shape, TensorData}; type FloatElem = ::FloatElem; type IntElem = ::IntElem; // Floating point values might not match for other precisions fn skip_precision_not_f32() -> bool { core::any::TypeId::of::() != core::any::TypeId::of::() } #[test] fn test_display_2d_int_tensor() { let int_data = TensorData::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let tensor_int = TestTensorInt::<2>::from_data(int_data, &Default::default()); let output = format!("{}", tensor_int); let expected = format!( r#"Tensor {{ data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]], shape: [3, 3], device: {:?}, backend: {:?}, kind: "Int", dtype: "{dtype}", }}"#, tensor_int.device(), TestBackend::name(&tensor_int.device()), dtype = core::any::type_name::(), ); assert_eq!(output, expected); } #[test] fn test_display_2d_float_tensor() { if skip_precision_not_f32() { return; } let float_data = TensorData::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]); let tensor_float = TestTensor::<2>::from_data(float_data, &Default::default()); let output = format!("{}", tensor_float); let expected = format!( r#"Tensor {{ data: [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]], shape: [3, 3], device: {:?}, backend: {:?}, kind: "Float", dtype: "f32", }}"#, tensor_float.device(), TestBackend::name(&tensor_float.device()), ); assert_eq!(output, expected); } #[test] fn test_display_2d_bool_tensor() { let bool_data = TensorData::from([ [true, false, true], [false, true, false], [false, true, true], ]); let tensor_bool = TestTensorBool::<2>::from_data(bool_data, &Default::default()); let output = format!("{}", tensor_bool); // TODO: remove once backends no longer rely on generics for default elem types let expected_name = match ::BoolElem::dtype() { burn_tensor::DType::U8 => burn_tensor::DType::Bool(burn_tensor::BoolStore::U8).name(), burn_tensor::DType::U32 => burn_tensor::DType::Bool(burn_tensor::BoolStore::U32).name(), dtype => dtype.name(), }; let expected = format!( r#"Tensor {{ data: [[true, false, true], [false, true, false], [false, true, true]], shape: [3, 3], device: {:?}, backend: {:?}, kind: "Bool", dtype: {:?}, }}"#, tensor_bool.device(), TestBackend::name(&tensor_bool.device()), expected_name, ); assert_eq!(output, expected); } #[test] fn test_display_3d_tensor() { let data = TensorData::from([ [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], ]); let tensor = TestTensorInt::<3>::from_data(data, &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], shape: [2, 3, 4], device: {:?}, backend: {:?}, kind: "Int", dtype: "{dtype}", }}"#, tensor.device(), TestBackend::name(&tensor.device()), dtype = core::any::type_name::(), ); assert_eq!(output, expected); } #[test] fn test_display_4d_tensor() { let data = TensorData::from([ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], ]); let tensor = TestTensorInt::<4>::from_data(data, &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]], shape: [2, 2, 2, 3], device: {:?}, backend: {:?}, kind: "Int", dtype: "{dtype}", }}"#, tensor.device(), TestBackend::name(&tensor.device()), dtype = core::any::type_name::(), ); assert_eq!(output, expected); } #[test] fn test_display_tensor_summarize_1() { let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 2, 1000]), &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]], [[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]]], shape: [2, 2, 2, 1000], device: {:?}, backend: {:?}, kind: "Float", dtype: "{dtype}", }}"#, tensor.device(), TestBackend::name(&tensor.device()), dtype = FloatElem::dtype().name(), ); assert_eq!(output, expected); } #[test] fn test_display_tensor_summarize_2() { let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 20, 100]), &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]], [[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]]], shape: [2, 2, 20, 100], device: {:?}, backend: {:?}, kind: "Float", dtype: "{dtype}", }}"#, tensor.device(), TestBackend::name(&tensor.device()), dtype = FloatElem::dtype().name(), ); assert_eq!(output, expected); } #[test] fn test_display_tensor_summarize_3() { let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 200, 6]), &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]], [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]], shape: [2, 2, 200, 6], device: {:?}, backend: {:?}, kind: "Float", dtype: "{dtype}", }}"#, tensor.device(), TestBackend::name(&tensor.device()), dtype = FloatElem::dtype().name(), ); assert_eq!(output, expected); } #[test] fn test_display_precision() { if skip_precision_not_f32() { return; } let tensor = TestTensor::<2>::full([1, 1], 0.123456789, &Default::default()); let output = format!("{}", tensor); let expected = format!( r#"Tensor {{ data: [[0.12345679]], shape: [1, 1], device: {:?}, backend: {:?}, kind: "Float", dtype: "f32", }}"#, tensor.device(), TestBackend::name(&tensor.device()), ); assert_eq!(output, expected); // CAN'T DO THIS BECAUSE OF GLOBAL STATE // let print_options = PrintOptions { // precision: Some(3), // ..Default::default() // }; // set_print_options(print_options); let tensor = TestTensor::<2>::full([3, 2], 0.123456789, &Default::default()); // Set precision to 3 let output = format!("{:.3}", tensor); let expected = format!( r#"Tensor {{ data: [[0.123, 0.123], [0.123, 0.123], [0.123, 0.123]], shape: [3, 2], device: {:?}, backend: {:?}, kind: "Float", dtype: "f32", }}"#, tensor.device(), TestBackend::name(&tensor.device()), ); assert_eq!(output, expected); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/eye.rs ================================================ use super::*; #[test] fn test_eye_float() { let device = Default::default(); let tensor = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); let rhs = TestTensor::<2>::eye(3, &device); assert_eq!(tensor.to_data(), rhs.to_data()); } #[test] fn test_eye_int() { let device = Default::default(); let tensor = TestTensorInt::<2>::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]); let rhs = TestTensorInt::<2>::eye(3, &device); assert_eq!(tensor.to_data(), rhs.to_data()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/median.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_median_even() { let tensor = TestTensor::<2>::from_data( [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], &Default::default(), ); let median_actual_1 = tensor.clone().median(1); let median_expected_1 = TensorData::from([[0.2], [0.0]]).convert::(); median_actual_1 .into_data() .assert_eq(&median_expected_1, false); let median_actual_0 = tensor.median(0); let median_expected_0 = TensorData::from([[0.5, -4.0, 0.2, -2.0]]).convert::(); median_actual_0 .into_data() .assert_eq(&median_expected_0, false); } #[test] fn test_median_odd() { let tensor = TestTensor::<2>::from_data( [ [0.5, 1.8, 0.2, -2.0, 1.0], [3.0, -4.0, 5.0, 0.0, -1.0], [5.0, -5.0, 1.0, 3.0, -2.0], ], &Default::default(), ); let median_actual_1 = tensor.clone().median(1); let median_expected_1 = TensorData::from([[0.5], [0.0], [1.0]]).convert::(); median_actual_1 .into_data() .assert_eq(&median_expected_1, false); let median_actual_0 = tensor.median(0); let median_expected_0 = TensorData::from([[3.0, -4.0, 1.0, 0.0, -1.0]]).convert::(); median_actual_0 .into_data() .assert_eq(&median_expected_0, false); } #[test] fn test_median_with_indices() { let device = Default::default(); let tensor = TestTensor::<1>::from_data([3.0, 1.0, 2.0], &device); // median = 2, original index = 2 let (values, indices) = tensor.median_with_indices(0); values .into_data() .assert_eq(&TensorData::from([2.0]), false); indices .into_data() .assert_eq(&TensorData::from([2i64]), false); let tensor = TestTensor::<2>::from_data([[5.0, 1.0, 3.0], [2.0, 8.0, 4.0]], &device); // Along dim 1: // Row 0: median = 3, original index = 2 // Row 1: median = 4, original index = 2 let (values, indices) = tensor.median_with_indices(1); values .into_data() .assert_eq(&TensorData::from([[3.0], [4.0]]), false); indices .into_data() .assert_eq(&TensorData::from([[2i64], [2i64]]), false); } #[test] fn test_median_all_elements() { let tensor = TestTensor::<2>::from_data( [ [0.5, 1.8, 0.2, -2.0, 1.0], [3.0, -4.0, 5.0, 0.0, -1.0], [5.0, -5.0, 1.0, 3.0, -2.0], ], &Default::default(), ); // Sorted: [-5, -4, -2, -2, -1, 0, 0.2, 0.5, 1, 1, 1.8, 3, 3, 5, 5] let dims = tensor.dims().len(); let flattened_tensor: Tensor<_, 1> = tensor.flatten(0, dims - 1); let result = flattened_tensor.median(0); result .into_data() .assert_eq(&TensorData::from([0.5]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/mod.rs ================================================ pub use super::*; // re-export test types mod cov; mod display; mod eye; mod median; mod var; ================================================ FILE: crates/burn-backend-tests/tests/tensor/float/stats/var.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_var() { let tensor = TestTensor::<2>::from_data( [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], &Default::default(), ); let output = tensor.var(1); let expected = TensorData::from([[2.4892], [15.3333]]).convert::(); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_var_mean() { let tensor = TestTensor::<2>::from_data( [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], &Default::default(), ); let (var, mean) = tensor.var_mean(1); let var_expected = TensorData::from([[2.4892], [15.3333]]).convert::(); let mean_expected = TensorData::from([[0.125], [1.]]).convert::(); var.into_data() .assert_approx_eq::(&var_expected, Tolerance::default()); mean.into_data() .assert_approx_eq::(&mean_expected, Tolerance::default()); } #[test] fn test_var_bias() { let tensor = TestTensor::<2>::from_data( [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], &Default::default(), ); let output = tensor.var_bias(1); let expected = TensorData::from([[1.86688], [11.5]]).convert::(); output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_var_mean_bias() { let tensor = TestTensor::<2>::from_data( [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]], &Default::default(), ); let (var, mean) = tensor.var_mean_bias(1); let var_expected = TensorData::from([[1.86688], [11.5]]).convert::(); let mean_expected = TensorData::from([[0.125], [1.]]).convert::(); var.into_data() .assert_approx_eq::(&var_expected, Tolerance::default()); mean.into_data() .assert_approx_eq::(&mean_expected, Tolerance::default()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/mod.rs ================================================ pub use super::*; // re-export test types mod ops; mod primitive; ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/abs.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_abs_ops_int() { let tensor = TestTensorInt::<2>::from([[0, -1, 2], [3, 4, -5]]); let output = tensor.abs(); output .into_data() .assert_eq(&TensorData::from([[0, 1, 2], [3, 4, 5]]), false); } #[test] fn should_support_abs_ops_int_signed_min() { let tensor = TestTensorInt::<2>::from([[IntElem::MIN]]); let output = tensor.abs(); output .into_data() .assert_eq(&TensorData::from([[IntElem::MIN]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/add.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_add_d2_int() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 11]]); let output = tensor_1 + tensor_2; output .into_data() .assert_eq(&TensorData::from([[6, 8, 10], [12, 14, 16]]), false); } #[test] fn test_add_broadcast_int() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2]]); let tensor_2 = TestTensorInt::from([[3, 4, 5], [6, 7, 8]]); let output = tensor_1 + tensor_2; output .into_data() .assert_eq(&TensorData::from([[3, 5, 7], [6, 8, 10]]), false); } #[test] fn should_support_add_scalar_ops_int() { let scalar = 2; let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let output = tensor + scalar; output .into_data() .assert_eq(&TensorData::from([[2, 3, 4], [5, 6, 7]]), false); } #[test] fn scalar_add_not_contiguous() { let tensor = TestTensorInt::<1>::arange(0..32, &Default::default()).float(); let tensor = tensor.reshape([1, 4, 4, 2]).permute([0, 3, 1, 2]); let tensor = tensor.slice([0..1, 0..2, 0..4, 0..4]); let before = tensor.clone(); let after = tensor.add_scalar(0.0); before .into_data() .assert_approx_eq::(&after.into_data(), Default::default()); } #[test] fn scalar_add_not_contiguous_int() { let tensor = TestTensorInt::<1>::arange(0..32, &Default::default()); let tensor = tensor.reshape([1, 4, 4, 2]).permute([0, 3, 1, 2]); let tensor = tensor.slice([0..1, 0..2, 0..4, 0..4]); let before = tensor.clone(); let after = tensor.add_scalar(0); before.into_data().assert_eq(&after.into_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/aggregation.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_should_mean_int() { let tensor = TestTensorInt::<2>::from([[2, 2, 2], [3, 4, 5]]); let output = tensor.mean(); output.into_data().assert_eq(&TensorData::from([3]), false); } #[test] fn test_should_mean_last_dim_int() { let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let output = tensor.mean_dim(1); output .into_data() .assert_eq(&TensorData::from([[1], [4]]), false); } #[test] fn test_should_sum_last_dim_int() { let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let output = tensor.sum_dim(1); output .into_data() .assert_eq(&TensorData::from([[3], [12]]), false); } #[test] fn test_should_sum_int() { let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let output = tensor.sum(); output.into_data().assert_eq(&TensorData::from([15]), false); } #[test] #[ignore = "Not implemented for all backends yet"] fn test_prod_int() { let tensor = TestTensorInt::<2>::from([[2, 1, 2], [3, 4, 5]]); let output = tensor.prod(); output .into_data() .assert_eq(&TensorData::from([240]), false); let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]); let output = tensor_with_zero.prod(); output.into_data().assert_eq(&TensorData::from([0]), false); } #[test] #[ignore = "Not implemented for all backends yet"] fn test_prod_dim_int() { let tensor = TestTensorInt::<2>::from([[2, 1, 2], [3, 4, 5]]); let output = tensor.prod_dim(1); output .into_data() .assert_eq(&TensorData::from([[4], [60]]), false); let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]); let output = tensor_with_zero.prod_dim(1); output .into_data() .assert_eq(&TensorData::from([[0], [60]]), false); // Negative Indexing. let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]); let output = tensor_with_zero.prod_dim(-1); output .into_data() .assert_eq(&TensorData::from([[0], [60]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/all.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_all() { let tensor = TestTensorInt::<2>::from([[0, 1, 0], [1, -1, 1]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1]]); let data_actual = tensor.all().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_all_dim() { let tensor = TestTensorInt::<2>::from([[0, 1, 0], [1, -1, 1]]); let data_actual = tensor.all_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/any.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_any() { let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([true]); data_expected.assert_eq(&data_actual, false); let tensor = TestTensorInt::<2>::from([[0, 0, 0], [0, 0, 0]]); let data_actual = tensor.any().into_data(); let data_expected = TensorData::from([false]); data_expected.assert_eq(&data_actual, false); } #[test] fn test_any_dim() { let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]); let data_actual = tensor.any_dim(1).into_data(); let data_expected = TensorData::from([[false], [true]]); data_expected.assert_eq(&data_actual, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/arange.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::backend::Backend; #[test] fn test_arange() { let device = ::Device::default(); let tensor = TestTensorInt::<1>::arange(2..5, &device); tensor .into_data() .assert_eq(&TensorData::from([2, 3, 4]), false); // Test arange with negative numbers let tensor = TestTensorInt::<1>::arange(-10..-5, &device); tensor .into_data() .assert_eq(&TensorData::from([-10, -9, -8, -7, -6]), false); let tensor = TestTensorInt::<1>::arange(-3..0, &device); tensor .into_data() .assert_eq(&TensorData::from([-3, -2, -1]), false); // Test arange with a mix of positive and negative numbers let tensor = TestTensorInt::<1>::arange(-2..3, &device); tensor .clone() .into_data() .assert_eq(&TensorData::from([-2, -1, 0, 1, 2]), false); assert_eq!(tensor.device(), device); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/arange_step.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::backend::Backend; #[test] fn test_arange_step() { let device = ::Device::default(); // Test correct sequence of numbers when the range is 0..9 and the step is 1 let tensor = TestTensorInt::<1>::arange_step(0..9, 1, &device); tensor .into_data() .assert_eq(&TensorData::from([0, 1, 2, 3, 4, 5, 6, 7, 8]), false); // Test correct sequence of numbers when the range is 0..3 and the step is 2 let tensor = TestTensorInt::<1>::arange_step(0..3, 2, &device); tensor .into_data() .assert_eq(&TensorData::from([0, 2]), false); // Test correct sequence of numbers when the range is 0..2 and the step is 5 let tensor = TestTensorInt::<1>::arange_step(0..2, 5, &device); tensor.into_data().assert_eq(&TensorData::from([0]), false); // Test correct sequence of numbers when the range includes negative numbers let tensor = TestTensorInt::<1>::arange_step(-3..3, 2, &device); tensor .into_data() .assert_eq(&TensorData::from([-3, -1, 1]), false); let tensor = TestTensorInt::<1>::arange_step(-5..1, 5, &device); tensor .clone() .into_data() .assert_eq(&TensorData::from([-5, 0]), false); assert_eq!(tensor.device(), device); } #[test] #[should_panic] fn should_panic_when_step_is_zero() { let device = ::Device::default(); // Test that arange_step panics when the step is 0 let _tensor = TestTensorInt::<1>::arange_step(0..3, 0, &device); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/arg.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_argmax_2d_dim0_int() { let tensor = TestTensorInt::<2>::from([[10, 11, 2], [3, 4, 5]]); let output = tensor.argmax(0); output .into_data() .assert_eq(&TensorData::from([[0, 0, 1]]), false); } #[test] fn test_argmin_2d_dim0_int() { let tensor = TestTensorInt::<2>::from([[10, 11, 2], [30, 4, 5]]); let output = tensor.argmin(0); output .into_data() .assert_eq(&TensorData::from([[0, 1, 0]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/bitwise.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_apply_bitwise_and_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); let output = tensor_1.bitwise_and(tensor_2); output .into_data() .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false); } #[test] fn should_apply_bitwise_and_1d() { let tensor_1 = TestTensorInt::<1>::from([13, 7]); let tensor_2 = TestTensorInt::from([11, 3]); let output = tensor_1.bitwise_and(tensor_2); output .into_data() .assert_eq(&TensorData::from([9, 3]), false); } #[test] fn should_apply_bitwise_and_scalar_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let scalar = 5; let output = tensor_1.bitwise_and_scalar(scalar); output .into_data() .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false); } #[test] fn should_apply_bitwise_not_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let output = tensor_1.bitwise_not(); output .into_data() .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false); } #[test] fn should_apply_bitwise_or_scalar_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let scalar = 5; let output = tensor_1.bitwise_or_scalar(scalar); output .into_data() .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false); } #[test] fn should_apply_bitwise_or_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); let output = tensor_1.bitwise_or(tensor_2); output .into_data() .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false); } #[test] fn should_apply_bitwise_or_1d() { let tensor_1 = TestTensorInt::<1>::from([13, 7]); let tensor_2 = TestTensorInt::from([11, 3]); let output = tensor_1.bitwise_or(tensor_2); output .into_data() .assert_eq(&TensorData::from([15, 7]), false); } #[test] fn should_apply_bitwise_xor_scalar_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let scalar = 5; let output = tensor_1.bitwise_xor_scalar(scalar); output .into_data() .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false); } #[test] fn should_apply_bitwise_xor_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); let output = tensor_1.bitwise_xor(tensor_2); output .into_data() .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false); } #[test] fn should_apply_bitwise_xor_1d() { let tensor_1 = TestTensorInt::<1>::from([13, 7]); let tensor_2 = TestTensorInt::from([11, 3]); let output = tensor_1.bitwise_xor(tensor_2); output .into_data() .assert_eq(&TensorData::from([6, 4]), false); } #[test] fn should_apply_bitwise_left_shift_2d() { if (IntElem::MAX as u32) < 512 { return; } let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); let output = tensor_1.bitwise_left_shift(tensor_2); output .into_data() .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false); } #[test] fn should_apply_bitwise_left_shift_scalar_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let scalar = 2; let output = tensor_1.bitwise_left_shift_scalar(scalar); output .into_data() .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false); } #[test] fn should_apply_bitwise_right_shift_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); let output = tensor_1.bitwise_right_shift(tensor_2); output .into_data() .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false); } #[test] fn should_apply_bitwise_right_shift_scalar_2d() { let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let scalar = 2; let output = tensor_1.bitwise_right_shift_scalar(scalar); output .into_data() .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/cartesian_grid.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::backend::Backend; #[test] fn test_cartesian_grid() { let device = ::Device::default(); // Test a single element tensor let tensor: TestTensorInt<2> = TestTensorInt::<1>::cartesian_grid([1], &device); tensor .into_data() .assert_eq(&TensorData::from([[0]]), false); // Test for a 2x2 tensor let tensor: TestTensorInt<3> = TestTensorInt::<2>::cartesian_grid([2, 2], &device); tensor.into_data().assert_eq( &TensorData::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]), false, ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/cast.rs ================================================ use super::*; use burn_tensor::{DType, TensorData}; #[test] fn cast_int_to_bool() { let tensor1 = TestTensorInt::<2>::from([[0, 43, 0], [2, -4, 31]]); let data_actual = tensor1.bool().into_data(); let data_expected = TensorData::from([[false, true, false], [true, true, true]]); data_actual.assert_eq(&data_expected, false); } #[test] fn cast_bool_to_int_tensor() { let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).int(); let expected = TensorData::from([[1, 0, 1], [0, 0, 1]]); tensor.into_data().assert_eq(&expected, false); } #[test] fn cast_int_precision() { let data = TensorData::from([[1, 2, 3], [4, 5, 6]]); let tensor = TestTensorInt::<2>::from(data.clone()); let output = tensor.cast(DType::I32); assert_eq!(output.dtype(), DType::I32); output.into_data().assert_eq(&data, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/cat.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_cat_ops_int() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 0); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3], [4, 5, 6]]), false); } #[test] fn should_support_cat_with_empty_tensor_int() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); let tensor_2: TestTensorInt<2> = TestTensorInt::empty([1, 0], &device); let output = Tensor::cat(vec![tensor_1, tensor_2], 1); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/chunk.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_chunk_multi_dimension() { let tensors = TestTensorInt::<2>::from_data(TensorData::from([[0, 1, 2, 3]]), &Default::default()) .chunk(2, 1); assert_eq!(tensors.len(), 2); let expected = [TensorData::from([[0, 1]]), TensorData::from([[2, 3]])]; for (index, tensor) in tensors.iter().enumerate() { tensor.to_data().assert_eq(&expected[index], false); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/comparison.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_equal() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 5]]); let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone()); let data_actual_inplace = tensor_1.equal(tensor_2); let data_expected = TensorData::from([[false, true, false], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_not_equal() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 5]]); let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.not_equal(tensor_2); let data_expected = TensorData::from([[true, false, true], [true, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_equal_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 2, 5]]); let data_actual_cloned = tensor_1.clone().equal_elem(2); let data_actual_inplace = tensor_1.equal_elem(2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_not_equal_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 2, 5]]); let data_actual_cloned = tensor_1.clone().not_equal_elem(2); let data_actual_inplace = tensor_1.not_equal_elem(2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn greater_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let data_actual_cloned = tensor_1.clone().greater_elem(4); let data_actual_inplace = tensor_1.greater_elem(4); let data_expected = TensorData::from([[false, false, false], [false, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_equal_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let data_actual_cloned = tensor_1.clone().greater_equal_elem(4); let data_actual_inplace = tensor_1.greater_equal_elem(4); let data_expected = TensorData::from([[false, false, false], [false, true, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]); let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone()); let data_actual_inplace = tensor_1.greater(tensor_2); let data_expected = TensorData::from([[false, false, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_equal() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]); let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.greater_equal(tensor_2); let data_expected = TensorData::from([[false, true, true], [false, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let data_actual_cloned = tensor_1.clone().lower_elem(4); let data_actual_inplace = tensor_1.lower_elem(4); let data_expected = TensorData::from([[true, true, true], [true, false, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_equal_elem() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let data_actual_cloned = tensor_1.clone().lower_equal_elem(4); let data_actual_inplace = tensor_1.lower_equal_elem(4); let data_expected = TensorData::from([[true, true, true], [true, true, false]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]); let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone()); let data_actual_inplace = tensor_1.lower(tensor_2); let data_expected = TensorData::from([[true, false, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_lower_equal() { let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]); let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone()); let data_actual_inplace = tensor_1.lower_equal(tensor_2); let data_expected = TensorData::from([[true, true, false], [true, false, true]]); data_expected.assert_eq(&data_actual_cloned.into_data(), false); data_expected.assert_eq(&data_actual_inplace.into_data(), false); } #[test] fn test_greater_broadcast() { // Test broadcasting with shape [1, 4] vs [4, 4] let device = Default::default(); let data_1 = TensorData::from([[1, 2, 3, 4]]); let data_2 = TensorData::from([ [0.5, 1.5, 2.5, 3.5], [1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5], ]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.greater(tensor_2); let expected = TensorData::from([ [true, true, true, true], [false, false, false, false], [false, false, false, false], [false, false, false, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_greater_equal_broadcast() { // Test broadcasting with shape [4, 1] vs [1, 4] let device = Default::default(); let data_1 = TensorData::from([[1], [2], [3], [4]]); let data_2 = TensorData::from([[1, 2, 3, 4]]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.greater_equal(tensor_2); let expected = TensorData::from([ [true, false, false, false], [true, true, false, false], [true, true, true, false], [true, true, true, true], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_equal_broadcast() { // Test broadcasting with different ranks let device = Default::default(); let data_1 = TensorData::from([[2], [3], [4]]); let data_2 = TensorData::from([[2, 3, 4, 2]]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.equal(tensor_2); let expected = TensorData::from([ [true, false, false, true], [false, true, false, false], [false, false, true, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_not_equal_broadcast() { // Test broadcasting with shape [3, 1] vs [1, 3] let device = Default::default(); let data_1 = TensorData::from([[1], [2], [3]]); let data_2 = TensorData::from([[1, 2, 3]]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.not_equal(tensor_2); let expected = TensorData::from([ [false, true, true], [true, false, true], [true, true, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_int_greater_broadcast() { let device = Default::default(); let data_1 = TensorData::from([[1i32, 2, 3]]); let data_2 = TensorData::from([[0i32], [2], [4]]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.greater(tensor_2); let expected = TensorData::from([ [true, true, true], [false, false, true], [false, false, false], ]); expected.assert_eq(&result.into_data(), false); } #[test] fn test_int_lower_equal_broadcast() { let device = Default::default(); let data_1 = TensorData::from([[2i32], [4]]); let data_2 = TensorData::from([[1i32, 2, 3]]); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let result = tensor_1.lower_equal(tensor_2); let expected = TensorData::from([[false, true, true], [false, false, false]]); expected.assert_eq(&result.into_data(), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/create_like.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_zeros_like() { let tensor = TestTensorInt::<3>::from([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); let tensor = tensor.zeros_like(); let expected = TensorData::from([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]); tensor.into_data().assert_eq(&expected, false); } #[test] fn should_support_ones_like() { let tensor = TestTensorInt::<3>::from([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]); let tensor = tensor.ones_like(); let expected = TensorData::from([[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1]]]); tensor.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/cumulative.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_cumsum_int_dim_0() { let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]); let output = tensor.cumsum(0); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3], [5, 7, 9]]), false); } #[test] fn test_cumsum_int_dim_1() { let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]); let output = tensor.cumsum(1); output .into_data() .assert_eq(&TensorData::from([[1, 3, 6], [4, 9, 15]]), false); } #[test] fn test_cumprod_int_dim_0() { let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]); let output = tensor.cumprod(0); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3], [4, 10, 18]]), false); } #[test] fn test_cumprod_int_dim_1() { let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]); let output = tensor.cumprod(1); output .into_data() .assert_eq(&TensorData::from([[1, 2, 6], [4, 20, 120]]), false); } #[test] fn test_cummin_int_dim_0() { let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]); let output = tensor.cummin(0); output .into_data() .assert_eq(&TensorData::from([[3, 1, 4], [2, 1, 1]]), false); } #[test] fn test_cummin_int_dim_1() { let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]); let output = tensor.cummin(1); output .into_data() .assert_eq(&TensorData::from([[3, 1, 1], [2, 2, 1]]), false); } #[test] fn test_cummax_int_dim_0() { let tensor = TestTensorInt::<2>::from([[3, 1, 4], [1, 5, 2]]); let output = tensor.cummax(0); output .into_data() .assert_eq(&TensorData::from([[3, 1, 4], [3, 5, 4]]), false); } #[test] fn test_cummax_int_dim_1() { let tensor = TestTensorInt::<2>::from([[3, 1, 4], [1, 5, 2]]); let output = tensor.cummax(1); output .into_data() .assert_eq(&TensorData::from([[3, 3, 4], [1, 5, 5]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/div.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_div_ops_int() { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[1, 1, 2], [1, 1, 2]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; output .into_data() .assert_eq(&TensorData::from([[0, 1, 1], [3, 4, 2]]), false); } #[test] fn test_div_broadcast_int() { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[1, 1, 2], [3, 4, 5]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 / tensor_2; output .into_data() .assert_eq(&TensorData::from([[0, 1, 1], [0, 0, 0]]), false); } #[test] fn should_support_div_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor / scalar; output .into_data() .assert_eq(&TensorData::from([[0, 0, 1], [1, 2, 2]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/expand.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn expand_2d_int() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let output = tensor.expand([3, 3]); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3], [1, 2, 3], [1, 2, 3]]), false); } #[test] fn should_all_negative_one() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let output = tensor.expand([2, -1]); output .into_data() .assert_eq(&TensorData::from([[1, 2, 3], [1, 2, 3]]), false); } #[test] #[should_panic] fn should_panic_negative_one_on_non_existing_dim() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let _expanded_tensor = tensor.expand([-1, 3]); } /// Regression test for https://github.com/tracel-ai/burn/issues/2091 #[test] fn inplace_op_after_expand() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let mut output = tensor.expand([2, 3]); output = output + 1; output .into_data() .assert_eq(&TensorData::from([[2, 3, 4], [2, 3, 4]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/flip.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn flip_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let flipped = tensor.clone().flip([0, 2]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)) let expected = TensorData::from([ [[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]], [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]], ]); flipped.into_data().assert_eq(&expected, false); // Test with no flip let flipped = tensor.clone().flip([]); assert_eq!(tensor.into_data(), flipped.into_data()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/full.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_tensor_full() { let device = Default::default(); let int_tensor = TestTensorInt::<2>::full([2, 2], 2, &device); int_tensor .into_data() .assert_eq(&TensorData::from([[2, 2], [2, 2]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/gather_scatter.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_gather_1d_dim0_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_ints([5, 6, 7], &device); let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device); let output = tensor.gather(0, indices); output .into_data() .assert_eq(&TensorData::from([6, 6, 5, 6, 7]), false); } #[test] fn should_gather_indices_broadcasted() { let device = Default::default(); let batch_size = 3; let fft_size = 4; let shape = [batch_size, fft_size, 2]; let x = TestTensorInt::arange( 0..shape.iter().product::() as i64, &Default::default(), ) .reshape(shape); let idx = TestTensorInt::<1>::from_ints([0, 2, 1, 3], &device); let expected = TestTensorInt::<3>::from([ [[0, 1], [4, 5], [2, 3], [6, 7]], [[8, 9], [12, 13], [10, 11], [14, 15]], [[16, 17], [20, 21], [18, 19], [22, 23]], ]) .into_data(); // Case 1: gather dim 2 let perm = idx .clone() .reshape([1, 1, fft_size]) .repeat_dim(0, batch_size) .repeat_dim(1, 2); let input = x.clone().permute([0, 2, 1]); let out = input.gather(2, perm).permute([0, 2, 1]); out.into_data().assert_eq(&expected, true); // Case 2: gather directly on dim 1 let perm = idx.reshape([1, fft_size, 1]).repeat_dim(0, batch_size); let out2 = x.gather(1, perm.repeat_dim(2, 2)); out2.into_data().assert_eq(&expected, true); } #[test] fn should_scatter_add_1d_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_ints([0, 0, 0], &device); let values = TestTensorInt::from_ints([5, 4, 3], &device); let indices = TestTensorInt::from_ints([1, 0, 2], &device); let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); output .into_data() .assert_eq(&TensorData::from([4, 5, 3]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/init.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_int_empty() { let shape = [2, 2]; let tensor = TestTensorInt::<2>::empty(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()) } #[test] fn should_support_int_zeros() { let shape = [2, 2]; let tensor = TestTensorInt::<2>::zeros(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[0, 0], [0, 0]]), false); } #[test] fn should_support_int_ones() { let shape = [2, 2]; let tensor = TestTensorInt::<2>::ones(shape, &Default::default()); assert_eq!(tensor.shape(), shape.into()); tensor .into_data() .assert_eq(&TensorData::from([[1, 1], [1, 1]]), false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/mask.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_mask_where_broadcast_int() { let device = Default::default(); // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]); let mask = TestTensorBool::<3>::from_bool( TensorData::from([ [[true, false], [false, true]], [[false, true], [true, false]], [[false, false], [false, false]], [[true, true], [true, true]], ]), &device, ); let value = TestTensorInt::<3>::ones([4, 2, 2], &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([ [[1, 3], [4, 1]], [[2, 1], [1, 5]], [[2, 3], [4, 5]], [[1, 1], [1, 1]], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_int_mask_where_ops() { let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let value = TestTensorInt::<2>::from_data(TensorData::from([[8, 9], [10, 11]]), &device); let output = tensor.mask_where(mask, value); let expected = TensorData::from([[8, 7], [2, 11]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_int_mask_fill_ops() { let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device); let mask = TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device); let output = tensor.mask_fill(mask, 9); let expected = TensorData::from([[9, 7], [2, 9]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/matmul.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_int_matmul_d2() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[1, 7], [2, 3], [1, 5]], &device); let tensor_2 = TestTensorInt::<2>::from_ints([[4, 7, 5], [2, 3, 5]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[18, 28, 40], [14, 23, 25], [14, 22, 30]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_d3() { let device = Default::default(); let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device); let tensor_2 = TestTensorInt::<3>::from_ints([[[4, 7], [2, 3]]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[18, 28], [14, 23]]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_broadcast_1() { let device = Default::default(); let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device); let tensor_2 = TestTensorInt::from_ints([[[4, 7], [2, 3]], [[2, 5], [6, 3]]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[[18, 28], [14, 23]], [[44, 26], [22, 19]]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_broadcast_4d() { let device = Default::default(); // [2, 1, 2, 2] let tensor_1 = TestTensorInt::<4>::from_ints([[[[1, 7], [2, 3]]], [[[2, 5], [6, 3]]]], &device); // [1, 2, 2, 2] let tensor_2 = TestTensorInt::from_ints([[[[9, 8], [1, 4]], [[2, 7], [3, 5]]]], &device); // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [[[16, 36], [21, 28]], [[23, 42], [13, 29]]], [[[23, 36], [57, 60]], [[19, 39], [21, 57]]], ]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_simple_1() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[5, 14], [14, 25]], &device); let tensor_2 = TestTensorInt::from_ints([[3, 4, 5], [0, 1, 2]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[15, 34, 53], [42, 81, 120]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_4_3() { if (IntElem::MAX as u32) < 324 { return; } let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], &device); let tensor_2 = TestTensorInt::from_ints([[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[56, 62, 68], [152, 174, 196], [248, 286, 324]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_trivial() { if (IntElem::MAX as u32) < 506 { return; } let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]); let tensor_3 = tensor_1.clone().matmul(tensor_1); tensor_3.into_data().assert_eq( &TensorData::from([ [56, 62, 68, 74], [152, 174, 196, 218], [248, 286, 324, 362], [344, 398, 452, 506], ]), false, ); } #[test] fn test_int_matmul_trivial_transposed() { if (IntElem::MAX as u32) < 734 { return; } let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); tensor_3.into_data().assert_eq( &TensorData::from([ [14, 38, 62, 86], [38, 126, 214, 302], [62, 214, 366, 518], [86, 302, 518, 734], ]), false, ); } #[test] fn test_int_matmul_4_8() { if (IntElem::MAX as u32) < 6092 { return; } let device = Default::default(); let tensor_1 = TestTensorInt::<1>::arange(0..32, &device).reshape([4, 8]); let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose()); tensor_3.into_data().assert_eq( &TensorData::from([ [140, 364, 588, 812], [364, 1100, 1836, 2572], [588, 1836, 3084, 4332], [812, 2572, 4332, 6092], ]), false, ); } #[test] fn test_int_matmul_simple_2() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[1, 2, 3, 4]], &device); let tensor_2 = TestTensorInt::from_ints([[3], [4], [5], [6]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([[50]]); tensor_3.into_data().assert_eq(&expected, false); } #[test] fn test_int_matmul_simple_3() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]], &device); let tensor_2 = TestTensorInt::from_ints([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [9, 18, 27, 36], [12, 24, 36, 48], [15, 30, 45, 60], [18, 36, 54, 72], ]); tensor_3.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_should_panic_when_inner_dimensions_are_not_equal() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_ints([[3, 3], [4, 4], [5, 5], [6, 6]], &device); let tensor_2 = TestTensorInt::from_ints([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], &device); let tensor_3 = tensor_1.matmul(tensor_2); let expected = TensorData::from([ [9, 18, 27, 36], [12, 24, 36, 48], [15, 30, 45, 60], [18, 36, 54, 72], ]); tensor_3.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/mod.rs ================================================ pub use super::*; // re-export test types mod abs; mod add; mod aggregation; mod all; mod any; mod arange; mod arange_step; mod bitwise; mod cartesian_grid; mod cast; mod cat; mod chunk; mod comparison; mod create_like; mod cumulative; mod div; mod expand; mod flip; mod full; mod gather_scatter; mod init; mod mask; mod matmul; mod movedim; mod mul; mod one_hot; mod permute; mod random; mod remainder; mod repeat; mod repeat_dim; mod reshape; mod roll; mod select; mod sign; mod slice; mod slice_assign; mod sort_argsort; mod stack; mod sub; mod take; mod topk; mod transpose; mod tri; mod unfold; ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/movedim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn movedim_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().movedim(0, 2); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2) let expected = TensorData::from([ [[0, 12], [1, 13], [2, 14], [3, 15]], [[4, 16], [5, 17], [6, 18], [7, 19]], [[8, 20], [9, 21], [10, 22], [11, 23]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().movedim(0, -1); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().movedim(0, 0); permuted.into_data().assert_eq(&tensor.into_data(), true); } #[test] fn vec_input_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]); // from pytorch // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]) let expected = TensorData::from([ [[0, 1, 2, 3], [12, 13, 14, 15]], [[4, 5, 6, 7], [16, 17, 18, 19]], [[8, 9, 10, 11], [20, 21, 22, 23]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axes let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]); permuted.into_data().assert_eq(&expected, false); // Test with the same axes let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]); permuted.into_data().assert_eq(&tensor.into_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/mul.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_mul_ops_int() { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0, 1, 4], [9, 16, 25]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_mul_broadcast_int() { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 * tensor_2; let expected = TensorData::from([[0, 4, 10], [0, 7, 16]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_mul_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor * scalar; let expected = TensorData::from([[0, 2, 4], [6, 8, 10]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/one_hot.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn int_should_support_one_hot() { let tensor = TestTensorInt::<1>::from([0, 1, 4]); let one_hot_tensor: TestTensorInt<2> = tensor.one_hot(5); let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { let tensor = TestTensorInt::<1>::from([5]); let _result: TestTensorInt<2> = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { let tensor = TestTensorInt::<1>::from([2]); let _result: TestTensorInt<2> = tensor.one_hot(0); } #[test] fn one_hot_fill_with_positive_axis_and_indices() { let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); let expected = TensorData::from([ [ [1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3], ], [ [1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], ], ]); let one_hot_tensor: TestTensorInt<3> = tensor.one_hot_fill(10, 3.0, 1.0, 1); one_hot_tensor.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/permute.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn permute_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); let permuted = tensor.clone().permute([2, 1, 0]); // from pytorch: // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0) let expected = TensorData::from([ [[0, 12], [4, 16], [8, 20]], [[1, 13], [5, 17], [9, 21]], [[2, 14], [6, 18], [10, 22]], [[3, 15], [7, 19], [11, 23]], ]); permuted.into_data().assert_eq(&expected, false); // Test with negative axis let permuted = tensor.clone().permute([-1, 1, 0]); permuted.into_data().assert_eq(&expected, false); // Test with the same axis let permuted = tensor.clone().permute([0, 1, 2]); permuted.into_data().assert_eq(&tensor.into_data(), true); } #[test] #[should_panic] fn edge_repeated_axes() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().permute([0, 0, 1]); } #[test] #[should_panic] fn edge_out_of_bound_axis() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]); // Test with a repeated axis let _ = tensor.clone().permute([3, 0, 1]); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/random.rs ================================================ use super::*; use burn_tensor::{Distribution, ElementConversion}; #[test] fn rand_uniform_int() { let low = 0.; let high = 5.; let tensor = TestTensorInt::<1>::random( [100_000], Distribution::Uniform(low, high), &Default::default(), ); tensor .into_data() .assert_within_range::(low.elem()..high.elem()); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/remainder.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_int_remainder_basic() { let data = TensorData::from([-3, -2, -1, 1, 2, 3]); let device = Default::default(); let lhs = TestTensorInt::<1>::from_data(data, &device); let rhs = TestTensorInt::from_data(TensorData::from([2, 3, 1, 2, 1, 3]), &device); let output = lhs.remainder(rhs); let expected = TensorData::from([1, 1, -0, 1, 0, 0]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_int_remainder_basic_scalar() { let data = TensorData::from([-3, -2, -1, 1, 2, 3]); let device = Default::default(); let tensor = TestTensorInt::<1>::from_data(data, &device); let output = tensor.remainder_scalar(2); let expected = TensorData::from([1, 0, 1, 1, 0, 1]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/repeat.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_int_repeat_ops_one_dimension() { let data = TensorData::from([[0i32, 1i32, 2i32]]); let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor.repeat(&[4, 1, 1]); let expected = TensorData::from([ [0i32, 1i32, 2i32], [0i32, 1i32, 2i32], [0i32, 1i32, 2i32], [0i32, 1i32, 2i32], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_int_repeat_on_many_dims() { let data = TensorData::from([ [[1i32, 2i32], [3i32, 4i32]], [[5i32, 6i32], [7i32, 8i32]], [[9i32, 10i32], [11i32, 12i32]], [[13i32, 14i32], [15i32, 16i32]], ]); let tensor = TestTensorInt::<3>::from_data(data, &Default::default()); let output = tensor.repeat(&[2, 3, 2]); let expected = TensorData::from([ [ [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], ], [ [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], ], [ [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], ], [ [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], ], [ [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], ], [ [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], ], [ [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], ], [ [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], ], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/repeat_dim.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_int_repeat_ops() { let data = TensorData::from([[0, 1, 2]]); let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor.repeat_dim(0, 4); let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_int_repeat_on_dims_larger_than_1() { let data = TensorData::from([ [[1i32, 2i32], [3i32, 4i32]], [[5i32, 6i32], [7i32, 8i32]], [[9i32, 10i32], [11i32, 12i32]], [[13i32, 14i32], [15i32, 16i32]], ]); let tensor = TestTensorInt::<3>::from_data(data, &Default::default()); let output = tensor.repeat_dim(2, 3); let expected = TensorData::from([ [ [1i32, 2i32, 1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32, 3i32, 4i32], ], [ [5i32, 6i32, 5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32, 7i32, 8i32], ], [ [9i32, 10i32, 9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32, 11i32, 12i32], ], [ [13i32, 14i32, 13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32, 15i32, 16i32], ], ]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/reshape.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_reshape_maybe_fused_1() { let tensor = TestTensorInt::arange(0..32, &Default::default()); let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); let tensor1 = tensor.clone().reshape([1, 4, 8]); let output = tensor0 + tensor1; let expected = TensorData::from([ [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], [ [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], ], ]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_reshape_maybe_fused_2() { let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); let tensor1 = tensor.reshape([2, 2, 1]); let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default()); let output = tensor2 + tensor1; let expected_tensor1 = TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]); output.into_data().assert_eq(&expected_tensor1, false); } #[test] fn should_support_reshape_maybe_fused_3() { let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); let tensor1 = tensor.reshape([2, 2, 1]); let _tensor2 = TestTensorInt::<3>::full([2, 2, 3], 5, &Default::default()); let expected_tensor1 = TensorData::from([[[0], [2]], [[1], [2]]]); tensor1.into_data().assert_eq(&expected_tensor1, false); } #[test] fn should_support_reshape_maybe_fused_4() { let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default()); let tensor2 = tensor2.swap_dims(0, 1); let tensor1 = tensor.reshape([2, 2, 1]); let output = tensor2 + tensor1; let expected_tensor1 = TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]); output.into_data().assert_eq(&expected_tensor1, false); } #[test] fn should_support_reshape_maybe_fused_5() { let tensor = TestTensorInt::<3>::from_data([[[0], [1], [2], [3]]], &Default::default()); let tensor1 = tensor.clone().reshape([2, 1, 2]); let tensor2 = TestTensorInt::<3>::full([2, 4, 2], 0, &Default::default()); let output = tensor2.clone() + tensor1 + tensor.clone(); let expected_tensor1 = TensorData::from([ [[0, 1], [1, 2], [2, 3], [3, 4]], [[2, 3], [3, 4], [4, 5], [5, 6]], ]); output.into_data().assert_eq(&expected_tensor1, false); } #[test] fn should_support_reshape_maybe_fused_6() { let device = Default::default(); let tensor1 = TestTensorInt::arange(0..32, &device); let tensor1 = tensor1.reshape([2, 4, 4]); let tensor2 = TestTensorInt::arange(0..16, &device); let tensor2 = tensor2.reshape([1, 4, 4]); let tensor3 = TestTensorInt::arange(0..8, &device); let tensor3 = tensor3.reshape([4, 1, 2]); let tensor3 = tensor3.swap_dims(0, 2); let out = tensor1 + tensor2 + tensor3; let expected = TensorData::from([ [ [0, 4, 8, 12], [8, 12, 16, 20], [16, 20, 24, 28], [24, 28, 32, 36], ], [ [17, 21, 25, 29], [25, 29, 33, 37], [33, 37, 41, 45], [41, 45, 49, 53], ], ]); out.to_data().assert_eq(&expected, false); } // Skip on metal - cubecl autotune error // Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4327 #[cfg(not(feature = "metal"))] #[test] fn should_support_multiple_reshapes_cloned_tensor() { let device = Default::default(); let lhs = TestTensorInt::<1>::arange(0..4, &device).reshape([2, 2]); // fusion should preserve correct strides when operating on the same tensor let rhs = lhs.clone(); let lhs = lhs.reshape([2, 2, 1]); let rhs = rhs.reshape([1, 2, 2]); let p = lhs.mul(rhs); let s = p.sum_dim(1); let out = s.reshape([2, 2]); out.into_data() .assert_eq(&TensorData::from([[2, 3], [6, 11]]), false); } #[test] fn should_support_reshape_int() { let data = TensorData::from([0, 1, 2]); let tensor = TestTensorInt::<1>::from_data(data, &Default::default()); let output = tensor.clone().reshape([1, 3]); let expected = TensorData::from([[0, 1, 2]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/roll.rs ================================================ use super::*; use burn_tensor::TensorData; #[ignore = "0 size resources are not yet supported"] #[test] fn test_roll_empty() { let device = Default::default(); let input = TestTensorInt::<2>::zeros([12, 0], &device); let result = input.clone().roll(&[1, 2], &[0, 1]); assert_eq!(&*result.shape(), &[12, 0]); // TODO: Rolling an empty tensor should return the same empty tensor; // but we have no way to compare tensor references yet. } #[test] fn test_roll() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // No-op shift: input .clone() .roll(&[0, 0], &[0, 1]) .to_data() .assert_eq(&input.clone().to_data(), false); input .clone() .roll(&[1, -1], &[0, 1]) .to_data() .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false); input .clone() .roll(&[-1, 1], &[1, 0]) .to_data() .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false); input .clone() .roll(&[2 * 32 + 1, 3 * (-400) - 1], &[0, 1]) .to_data() .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false); } #[should_panic] #[test] fn test_roll_dim_too_big() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // Attempting to roll on a dimension that doesn't exist should panic let _d = input.roll(&[1], &[2]); } #[should_panic] #[test] fn test_roll_dim_too_small() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // Attempting to roll on a dimension that doesn't exist should panic let _d = input.roll(&[1], &[-3]); } #[should_panic] #[test] fn test_roll_shift_size_mismatch() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // Attempting to roll with a shift size that doesn't match the number of dimensions should panic let _d = input.roll(&[1, 2], &[0]); } #[test] fn test_roll_dim() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); input .clone() .roll_dim(1, 0) .to_data() .assert_eq(&TensorData::from([[3, 4, 5], [0, 1, 2]]), false); input .clone() .roll_dim(-1, 1) .to_data() .assert_eq(&TensorData::from([[2, 0, 1], [5, 3, 4]]), false); } #[should_panic] #[test] fn test_roll_dim_dim_too_big() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // Attempting to roll on a dimension that doesn't exist should panic let _d = input.roll_dim(1, 2); } #[should_panic] #[test] fn test_roll_dim_dim_too_small() { let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]); // Attempting to roll on a dimension that doesn't exist should panic let _d = input.roll_dim(1, -3); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/select.rs ================================================ use super::*; use burn_tensor::{IndexingUpdateOp, TensorData}; #[test] fn should_select_1d_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([5, 6, 7], &device); let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device); let output = tensor.select(0, indices); let expected = TensorData::from([6, 6, 5, 6, 7]); output.into_data().assert_eq(&expected, false); } #[test] fn should_select_add_1d_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([7, 8, 9], &device); let values = TestTensorInt::from_data([5, 4, 3, 2, 1], &device); let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device); let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); let expected = TensorData::from([10, 19, 10]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn should_panic_select_add_invalid_num_indices() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([0; 12], &device); let values = TestTensorInt::from_data([1; 12], &device); let indices = TestTensorInt::from_data(TensorData::from([1]), &device); tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/sign.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_sign_ops_int() { let tensor = TestTensorInt::<2>::from([[-2, -1, 2], [3, 0, -5]]); let output = tensor.sign(); let expected = TensorData::from([[-1, -1, 1], [1, 0, -1]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/slice.rs ================================================ use super::*; use burn_tensor::{TensorData, s}; #[test] fn slice_should_not_corrupt_potentially_inplace_operations() { let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]); let expected = TensorData::from([4, 6, 8]); tensor.into_data().assert_eq(&expected, false); } #[test] fn test_slice_int_tensor_with_steps() { let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], &device); // Test step=2 along first dimension let sliced = tensor.clone().slice([s![0..3;2]]); let expected = TensorData::from([[1i32, 2, 3, 4], [9, 10, 11, 12]]); sliced.into_data().assert_eq(&expected, false); // Test step=-1 along second dimension let sliced = tensor.clone().slice(s![.., 0..4;-1]); let expected = TensorData::from([[4i32, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]]); sliced.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/slice_assign.rs ================================================ use super::*; use burn_tensor::{TensorData, s}; #[test] fn slice_assign_should_not_corrupt_potentially_inplace_operations() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device); let values = TestTensorInt::<1>::from_data([10, 20, 30], &device); let tensor_1 = tensor.clone().slice_assign([0..3], values); let tensor_2 = tensor + 2; let expected = TensorData::from([10, 20, 30, 4, 5]); tensor_1.into_data().assert_eq(&expected, false); let expected = TensorData::from([3, 4, 5, 6, 7]); tensor_2.into_data().assert_eq(&expected, false); } #[test] fn test_slice_assign_int_tensor_with_steps() { let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], &device); // Test step=2 along first dimension let values = TestTensorInt::<2>::from_data([[100, 101, 102, 103], [200, 201, 202, 203]], &device); let output = tensor.clone().slice_assign([s![0..3;2]], values); let expected = TensorData::from([[100i32, 101, 102, 103], [5, 6, 7, 8], [200, 201, 202, 203]]); output.into_data().assert_eq(&expected, false); // Test step=-1 along second dimension let values = TestTensorInt::<2>::from_data( [[40, 30, 20, 10], [80, 70, 60, 50], [120, 110, 100, 90]], &device, ); let output = tensor.slice_assign(s![.., 0..4;-1], values); let expected = TensorData::from([[10i32, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_slice_assign_empty_range_int() { let device = Default::default(); let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device); let values: TestTensorInt<1> = TestTensorInt::empty([0], &device); // Empty slice assignment for int tensor let output = tensor.clone().slice_assign([3..3], values); let expected = TensorData::from([1i32, 2, 3, 4, 5]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/sort_argsort.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_sort_1d_int() { // Skip with u8 if (IntElem::MAX as u32) < 1000u32 { return; } let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, 2, 8, -10, 42, 1000]); // Sort along dim=0 let values = tensor.sort(0); let values_expected = TensorData::from([-10, 0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 42, 1000]); values.into_data().assert_eq(&values_expected, false); } #[test] fn test_argsort_1d_int() { // Skip with u8 if (IntElem::MAX as u32) < 1000u32 { return; } let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]); // Sort along dim=0 let indices = tensor.argsort(0); let indices_expected = TensorData::from([10, 7, 0, 3, 6, 1, 4, 5, 2, 9, 8, 11, 12]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_with_indices_descending_int() { // Skip with u8 if (IntElem::MAX as u32) >= 1000u32 { // 1D let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]); // Sort along dim=0 let (values, indices) = tensor.sort_descending_with_indices(0); let values_expected = TensorData::from([1000, 42, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -10]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([12, 11, 8, 9, 2, 5, 4, 1, 6, 3, 0, 7, 10]); indices.into_data().assert_eq(&indices_expected, false); } // 2D let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); // Sort along dim=1 let (values, indices) = tensor.sort_descending_with_indices(1); let values_expected = TensorData::from([[[2, 5, 7], [1, 4, 6]], [[8, 2, 9], [3, 0, 8]]]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([[[1, 1, 0], [0, 0, 1]], [[1, 1, 0], [0, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_int() { let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); // Sort along dim=0 let values = tensor.clone().sort(0); let values_expected = TensorData::from([[[1, 0, 7], [2, 2, 6]], [[3, 4, 9], [8, 5, 8]]]); values.into_data().assert_eq(&values_expected, false); // Sort along dim=1 let values = tensor.clone().sort(1); let values_expected = TensorData::from([[[1, 4, 6], [2, 5, 7]], [[3, 0, 8], [8, 2, 9]]]); values.into_data().assert_eq(&values_expected, false); // Sort along dim=2 let values = tensor.sort(2); let values_expected = TensorData::from([[[1, 4, 7], [2, 5, 6]], [[0, 3, 9], [2, 8, 8]]]); values.into_data().assert_eq(&values_expected, false); } #[test] fn test_sort_with_indices_int() { let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [7, 2, 8]]]); // Sort along dim=0 let (values, indices) = tensor.clone().sort_with_indices(0); let values_expected = TensorData::from([[[1, 0, 7], [2, 2, 6]], [[3, 4, 9], [7, 5, 8]]]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([[[0, 1, 0], [0, 1, 0]], [[1, 0, 1], [1, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let (values, indices) = tensor.clone().sort_with_indices(1); let values_expected = TensorData::from([[[1, 4, 6], [2, 5, 7]], [[3, 0, 8], [7, 2, 9]]]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([[[0, 0, 1], [1, 1, 0]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let (values, indices) = tensor.sort_with_indices(2); let values_expected = TensorData::from([[[1, 4, 7], [2, 5, 6]], [[0, 3, 9], [2, 7, 8]]]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([[[0, 1, 2], [0, 1, 2]], [[1, 0, 2], [1, 0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_argsort_int() { let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [7, 2, 8]]]); // Sort along dim=0 let indices = tensor.clone().argsort(0); let indices_expected = TensorData::from([[[0, 1, 0], [0, 1, 0]], [[1, 0, 1], [1, 0, 1]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=1 let indices = tensor.clone().argsort(1); let indices_expected = TensorData::from([[[0, 0, 1], [1, 1, 0]], [[0, 0, 1], [1, 1, 0]]]); indices.into_data().assert_eq(&indices_expected, false); // Sort along dim=2 let indices = tensor.argsort(2); let indices_expected = TensorData::from([[[0, 1, 2], [0, 1, 2]], [[1, 0, 2], [1, 0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); } #[test] fn test_sort_descending_1d() { let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); // Sort along dim=0 let values = tensor.sort_descending(0); let values_expected = TensorData::from([5, 4, 3, 2, 1]); values.into_data().assert_eq(&values_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/stack.rs ================================================ use super::*; use alloc::vec; use burn_tensor::{Tensor, TensorData}; #[test] fn should_support_stack_ops_int() { let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device); let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0); let expected = TensorData::from([[[1, 2, 3]], [[4, 5, 6]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_generate_row_major_layout() { let device = Default::default(); let tensor = TestTensorInt::<1>::arange(1..25, &device).reshape([4, 6]); let zeros = TestTensorInt::zeros([4, 6], &device); let intersperse = Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]); let expected = TensorData::from([ [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0], [7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0], [13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0], [19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0], ]); intersperse.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/sub.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_sub_ops_int() { let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]); let data_2 = TensorData::from([[6, 7, 8], [9, 10, 11]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-6, -6, -6], [-6, -6, -6]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_sub_broadcast_int() { let data_1 = TensorData::from([[0, 1, 2]]); let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]); let device = Default::default(); let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device); let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device); let output = tensor_1 - tensor_2; let expected = TensorData::from([[-3, -3, -3], [-6, -6, -6]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_sub_scalar_ops_int() { let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); let scalar = 2; let tensor = TestTensorInt::<2>::from_data(data, &Default::default()); let output = tensor - scalar; let expected = TensorData::from([[-2, -1, 0], [1, 2, 3]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/take.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_take_int_tensor() { // Test take with integer tensors let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[10, 20, 30], [40, 50, 60]], &device); let indices = TestTensorInt::<1>::from_data([1, 0], &device); let output = tensor.take::<1, 2>(0, indices); let expected = TensorData::from([[40, 50, 60], [10, 20, 30]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_take_int_tensor_with_2d_indices() { // Test take with integer tensors - output will be 3D let device = Default::default(); let tensor = TestTensorInt::<2>::from_data([[10, 20, 30], [40, 50, 60], [70, 80, 90]], &device); // 2D indices - shape [2, 2] let indices = TestTensorInt::<2>::from_data([[0, 2], [2, 1]], &device); let output = tensor.take::<2, 3>(0, indices); // Expected: shape [2, 2, 3] let expected = TensorData::from([[[10, 20, 30], [70, 80, 90]], [[70, 80, 90], [40, 50, 60]]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/topk.rs ================================================ use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; #[test] fn test_topk_1d() { // Int let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); let values = tensor.topk(3, /*dim*/ 0); let expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&expected, false); // Float let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); let values = tensor.topk(3, /*dim*/ 0); let expected = TensorData::from([5., 4., 3.]); values .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_topk() { // 3D Int let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); let values = tensor.topk(2, /*dim*/ 2); let expected = TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]); values.into_data().assert_eq(&expected, false); // 3D Float let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); let values = tensor.topk(2, /*dim*/ 2); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]); values .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_topk_with_indices_1d() { let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); let values_expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&values_expected, false); let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/transpose.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn should_support_transpose_ops_int() { let tensor = TestTensorInt::<3>::from_data( [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], &Default::default(), ); let output = tensor.transpose(); let expected = TensorData::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]); output.into_data().assert_eq(&expected, false); } #[test] fn should_support_swap_dims_int() { let tensor = TestTensorInt::<3>::from_data( [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], &Default::default(), ); let output = tensor.swap_dims(0, 2); let expected = TensorData::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]); output.into_data().assert_eq(&expected, false); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/tri.rs ================================================ use super::*; use burn_tensor::TensorData; #[test] fn test_triu_negative_diagonal() { let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]); let output = tensor.triu(-1); let expected = TensorData::from([[1, 1, 1], [1, 1, 1], [0, 1, 1]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_triu_batch_tensors() { let tensor = TestTensorInt::<4>::from([ [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], ]); let output = tensor.triu(1); let expected = TensorData::from([ [[[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]], [[[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]], ]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn test_triu_too_few_dims() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let _output = tensor.triu(0); } #[test] fn test_tril() { let tensor = TestTensor::<2>::from([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]); let output = tensor.tril(0); let expected = TensorData::from([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_tril_positive_diagonal() { let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]); let output = tensor.tril(1); let expected = TensorData::from([[1, 1, 0], [1, 1, 1], [1, 1, 1]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_tril_negative_diagonal() { let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]); let output = tensor.tril(-1); let expected = TensorData::from([[0, 0, 0], [1, 0, 0], [1, 1, 0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_tril_batch_tensors() { let tensor = TestTensorInt::<4>::from([ [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], ]); let output = tensor.tril(1); let expected = TensorData::from([ [[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1]]], [[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1]]], ]); output.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn test_tril_too_few_dims() { let tensor = TestTensorInt::<1>::from([1, 2, 3]); let _output = tensor.tril(0); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/ops/unfold.rs ================================================ use super::*; use burn_tensor::Distribution; use burn_tensor::s; #[test] fn test_unfold_int() { // Distribution::Default samples from [0, 255) if (IntElem::MAX as u32) < 255 - 1 { return; } let device = Default::default(); let input = TestTensorInt::<3>::random([2, 6, 6], Distribution::Default, &device); let dim = 1; let size = 3; let step = 2; let actual: TestTensorInt<4> = input.clone().unfold(dim, size, step); let expected = TestTensorInt::<4>::empty([2, 2, 6, 3], &device) .slice_assign( s![.., 0, .., ..], input .clone() .slice(s![.., 0..3, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ) .slice_assign( s![.., 1, .., ..], input .clone() .slice(s![.., 2..5, ..]) .swap_dims(1, 2) .unsqueeze_dim::<4>(1), ); actual.to_data().assert_eq(&expected.to_data(), true); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/int/primitive.rs ================================================ use super::*; use burn_tensor::{Element, Shape}; #[test] fn should_support_int_dtype() { let tensor = TestTensorInt::<2>::from([[0, -1, 2], [3, 4, -5]]).into_primitive(); assert_eq!( burn_tensor::TensorMetadata::shape(&tensor), Shape::new([2, 3]) ); assert_eq!( burn_tensor::TensorMetadata::dtype(&tensor), IntElem::dtype() // default int elem type ); } ================================================ FILE: crates/burn-backend-tests/tests/tensor/mod.rs ================================================ pub use super::*; // re-export test types mod clone_invariance; #[cfg(feature = "std")] mod multi_threads; // Data types mod bool; mod float; mod int; ================================================ FILE: crates/burn-backend-tests/tests/tensor/multi_threads.rs ================================================ use super::*; use core::time::Duration; use std::sync::{ Arc, atomic::{AtomicU32, Ordering}, }; struct MultiThreadTestSettings { num_threads: usize, // The number of operations that are applied while the tensor is still alive and has a // reference count > 1 on the new thread. num_ops_alive: usize, // The number of operations that are applied after the tensor is consumed for the last time. num_ops_consumed: usize, // Number of operations that needs to execute before continuing execution on the main thread. sleep_before: Duration, sleep_alive: Duration, sleep_consumed: Duration, // If the output is dropped, otherwise it will be consumed by an operation. dropped: bool, } #[test] fn should_handle_multi_threads_dropped() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 5, num_ops_consumed: 5, sleep_before: Duration::from_millis(100), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: true, }) } #[test] fn should_handle_multi_threads_consumed() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 5, num_ops_consumed: 5, sleep_before: Duration::from_millis(100), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: false, }) } #[test] fn should_handle_multi_threads_drop_no_wait() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 5, num_ops_consumed: 5, sleep_before: Duration::from_millis(100), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: true, }) } #[test] fn should_handle_multi_threads_consumed_no_wait() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 5, num_ops_consumed: 5, sleep_before: Duration::from_millis(100), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: false, }) } #[test] fn should_handle_multi_threads_no_async_op() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 0, num_ops_consumed: 0, sleep_before: Duration::from_millis(100), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: false, }) } // Skip on metal - flaky (works when ran alone) // Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4328 #[cfg(not(feature = "metal"))] #[test] fn should_handle_multi_threads_no_async_op_no_wait() { run_multi_thread_test(MultiThreadTestSettings { num_threads: 3, num_ops_alive: 0, num_ops_consumed: 0, sleep_before: Duration::from_millis(0), sleep_alive: Duration::from_millis(100), sleep_consumed: Duration::from_millis(100), dropped: false, }) } fn run_multi_thread_test(settings: MultiThreadTestSettings) { let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]); let mut joined = Vec::with_capacity(settings.num_threads); let counter_alive = Arc::new(AtomicU32::new(0)); let counter_consumed = Arc::new(AtomicU32::new(0)); for i in 0..settings.num_threads { let tensor_moved = tensor.clone(); let ca_moved = counter_alive.clone(); let cc_moved = counter_consumed.clone(); let handle = std::thread::spawn(move || { let mut base = tensor_moved.clone(); std::thread::sleep(settings.sleep_before); if settings.num_ops_alive == 0 && settings.num_ops_consumed == 0 { core::mem::drop(tensor_moved); core::mem::drop(base); } else { if settings.num_ops_alive > 1 { for j in 0..(settings.num_ops_alive - 1) { base = tensor_moved.clone() + j as u32; ca_moved.fetch_add(1, Ordering::Relaxed); std::thread::sleep(settings.sleep_alive); } } base = base * tensor_moved + i as u32; ca_moved.fetch_add(1, Ordering::Relaxed); for n in 0..settings.num_ops_consumed { base = base + n as i32; cc_moved.fetch_add(1, Ordering::Relaxed); std::thread::sleep(settings.sleep_consumed); } let _data = base.into_data(); } }); joined.push(handle); } fn wait(counter: Arc, limit: usize) { loop { let counter_curr = counter.load(Ordering::Relaxed); if counter_curr as usize >= limit { break; } else { std::thread::sleep(Duration::from_millis(10)); } } } wait(counter_alive, settings.num_ops_alive); wait(counter_consumed, settings.num_ops_consumed); if settings.dropped { core::mem::drop(tensor); } else { let t = tensor * 2.0; let _t = t.into_data(); } for j in joined { j.join().unwrap(); } } ================================================ FILE: crates/burn-backend-tests/tests/tensor.rs ================================================ //! Burn backend tensor tests. #![allow(clippy::single_range_in_vec_init, reason = "false positive")] extern crate alloc; pub type FloatElemType = f32; #[allow(unused)] pub type IntElemType = i32; #[path = "common/backend.rs"] mod backend; pub use backend::*; #[path = "common/tensor.rs"] mod tensor; ================================================ FILE: crates/burn-candle/Cargo.toml ================================================ [package] authors = ["louisfd "] categories = ["science"] description = "[Deprecated] Candle backend for the Burn framework - use burn-cubecl, burn-ndarray, or burn-tch instead" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-candle" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-candle" documentation = "https://docs.rs/burn-candle" version.workspace = true [lints] workspace = true [features] default = ["std"] std = [] doc = ["default"] tracing = [ "burn-backend/tracing", "burn-std/tracing", ] cuda = ["candle-core/cuda"] metal = ["candle-core/metal"] accelerate = ["candle-core/accelerate"] [dependencies] burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } # For rand utils and stub mutex burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } candle-core = { workspace = true } derive-new = { workspace = true } [dev-dependencies] burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", default-features = false, features = [ ] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-candle/README.md ================================================ # Burn Candle Backend > **Deprecated:** This crate is deprecated as of `0.21.0-pre.2` and will be removed in a future release. > Please migrate to one of the actively maintained backends: > - **CubeCL backends** (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration > - **NdArray** for portable CPU execution > - **LibTorch** (`burn-tch`) for a mature CPU/GPU backend This crate provides a backend for [Burn](https://github.com/tracel-ai/burn) based on the [Candle](https://github.com/huggingface/candle) framework. ## Feature Flags - `cuda` - Cuda GPU device (NVIDIA only) - `accelerate` - Accelerate framework (macOS only) ================================================ FILE: crates/burn-candle/src/backend.rs ================================================ use std::marker::PhantomData; use burn_backend::{ BackTrace, Backend, DType, DTypeUsage, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive, tensor::Device, }; use burn_std::{ rand::{SeedableRng, StdRng}, stub::Mutex, }; use candle_core::{DeviceLocation, backend::BackendDevice}; use crate::{ CandleTensor, IntoDType, element::{CandleElement, FloatCandleElement, IntCandleElement}, }; /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations. /// /// It is compatible with a wide range of hardware configurations, including CPUs and GPUs /// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU. #[derive(Clone, Default, Debug)] pub struct Candle where F: FloatCandleElement, I: IntCandleElement, { _float: PhantomData, _int: PhantomData, } // Seed for CPU device pub(crate) static SEED: Mutex> = Mutex::new(None); pub(crate) fn get_seeded_rng() -> StdRng { let mut seed = SEED.lock().unwrap(); seed.take().unwrap_or_else(burn_std::rand::get_seeded_rng) } pub(crate) fn set_seeded_rng(rng_seeded: StdRng) { let mut seed = SEED.lock().unwrap(); *seed = Some(rng_seeded); } /// The device type for the candle backend. #[derive(Clone, Debug, PartialEq, Eq)] /// The device struct when using the `candle` backend. /// /// To create a Cuda or Metal device from the index, use the associated methods to create the variant: /// ```no_run /// use burn_candle::CandleDevice; /// /// // Create a Cuda device from its index /// let device = CandleDevice::cuda(0); /// // Create a Metal device from its index /// let device = CandleDevice::metal(0); /// ``` #[derive(Default)] pub enum CandleDevice { /// CPU device. #[default] Cpu, /// Cuda device with the given index. The index is the index of the Cuda device in the list of /// all Cuda devices found on the system. Cuda(CudaDevice), /// Metal device with the given index. The index is the index of the Metal device in the list of /// all Metal devices found on the system. Metal(MetalDevice), } impl CandleDevice { /// Create a Cuda device with the given index. /// The index is the index of the Cuda device in the list of all Cuda devices found on the system. pub fn cuda(index: usize) -> Self { CandleDevice::Cuda(CudaDevice { device: candle_core::CudaDevice::new(index).unwrap(), index, }) } /// Create a Metal device with the given index. /// The index is the index of the Metal device in the list of all Metal devices found on the system. pub fn metal(index: usize) -> Self { CandleDevice::Metal(MetalDevice { device: candle_core::MetalDevice::new(index).unwrap(), index, }) } pub(crate) fn set_seed(&self, seed: u64) { match self { CandleDevice::Cpu => { // candle_core::cpu_backend::CpuDevice.set_seed(seed).unwrap(); // Candle does not support seeding the CPU rng so we use a global seed let rng = StdRng::seed_from_u64(seed); set_seeded_rng(rng); } CandleDevice::Cuda(cuda_device) => cuda_device.device.set_seed(seed).unwrap(), CandleDevice::Metal(metal_device) => metal_device.device.set_seed(seed).unwrap(), } } } #[derive(Clone, Debug)] /// A Cuda device for the `candle` backend. pub struct CudaDevice { pub(crate) device: candle_core::CudaDevice, /// The index of the Cuda device in the list of all devices on the system. pub index: usize, } impl PartialEq for CudaDevice { fn eq(&self, other: &Self) -> bool { self.device.same_device(&other.device) && self.index == other.index } } impl Eq for CudaDevice {} #[derive(Clone, Debug)] /// A Metal device for the `candle` backend. pub struct MetalDevice { pub(crate) device: candle_core::MetalDevice, /// The index of the Metal device in the list of all devices on the system. pub index: usize, } impl PartialEq for MetalDevice { fn eq(&self, other: &Self) -> bool { self.device.same_device(&other.device) && self.index == other.index } } impl Eq for MetalDevice {} impl From for candle_core::Device { fn from(device: CandleDevice) -> Self { match device { CandleDevice::Cpu => candle_core::Device::Cpu, CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device), CandleDevice::Metal(device) => candle_core::Device::Metal(device.device), } } } impl From for CandleDevice { fn from(device: candle_core::Device) -> Self { match device.location() { DeviceLocation::Cpu => CandleDevice::Cpu, DeviceLocation::Cuda { gpu_id } => { if let candle_core::Device::Cuda(device) = device { CandleDevice::Cuda(CudaDevice { device, index: gpu_id, }) } else { panic!("Expected CUDA device."); } } DeviceLocation::Metal { gpu_id } => { if let candle_core::Device::Metal(device) = device { CandleDevice::Metal(MetalDevice { device, index: gpu_id, }) } else { panic!("Expected Metal device."); } } } } } impl burn_backend::Device for CandleDevice { fn to_id(&self) -> burn_backend::DeviceId { match self { CandleDevice::Cuda(device) => DeviceId::new(0, device.index as u32), CandleDevice::Metal(device) => DeviceId::new(1, device.index as u32), CandleDevice::Cpu => DeviceId::new(2, 0), } } fn from_id(device_id: DeviceId) -> Self { match device_id.type_id { 0 => CandleDevice::cuda(device_id.index_id as usize), 1 => CandleDevice::metal(device_id.index_id as usize), _ => CandleDevice::Cpu, } } fn device_count(type_id: u16) -> usize { // TODO: Fix that 1 } } impl DeviceOps for CandleDevice {} impl Backend for Candle { type Device = CandleDevice; type FloatTensorPrimitive = CandleTensor; type FloatElem = F; type IntTensorPrimitive = CandleTensor; type IntElem = I; type BoolTensorPrimitive = CandleTensor; type BoolElem = u8; type QuantizedTensorPrimitive = CandleTensor; fn ad_enabled(_device: &Self::Device) -> bool { false } fn name(device: &Self::Device) -> String { match device { CandleDevice::Cpu => "candle", CandleDevice::Cuda(..) => "candle", CandleDevice::Metal(..) => "candle", } .to_string() } fn seed(device: &CandleDevice, seed: u64) { device.set_seed(seed); } fn sync(device: &Device) -> Result<(), ExecutionError> { let device: candle_core::Device = (device.clone()).into(); match device { candle_core::Device::Cpu => (), candle_core::Device::Cuda(device) => { #[cfg(feature = "cuda")] device .synchronize() .map_err(|err| ExecutionError::Generic { reason: format!("Can't sync the cuda device: {err}"), backtrace: BackTrace::capture(), })?; } candle_core::Device::Metal(device) => { // For some reason, device.wait_until_completed() does not seem to work, // and neither does writing and reading a value with into_data return Err(ExecutionError::Generic { reason: "Device synchronization unavailable with Metal device on Candle backend" .into(), backtrace: BackTrace::capture(), }); } } Ok(()) } fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { if dtype.try_into_dtype().is_ok() { burn_backend::DTypeUsage::general() } else { burn_backend::DTypeUsageSet::empty() } } } #[cfg(test)] mod tests { use burn_std::{BoolStore, QuantScheme}; use super::*; #[test] fn should_support_dtypes() { type B = Candle; let device = Default::default(); assert!(B::supports_dtype(&device, DType::F64)); assert!(B::supports_dtype(&device, DType::F32)); assert!(B::supports_dtype(&device, DType::Flex32)); assert!(B::supports_dtype(&device, DType::F16)); assert!(B::supports_dtype(&device, DType::BF16)); assert!(B::supports_dtype(&device, DType::I64)); assert!(B::supports_dtype(&device, DType::U32)); assert!(B::supports_dtype(&device, DType::U8)); assert!(B::supports_dtype(&device, DType::I32)); assert!(B::supports_dtype(&device, DType::I16)); assert!(B::supports_dtype(&device, DType::Bool(BoolStore::U8))); assert!(!B::supports_dtype(&device, DType::U64)); assert!(!B::supports_dtype(&device, DType::U16)); assert!(!B::supports_dtype(&device, DType::I8)); assert!(!B::supports_dtype(&device, DType::Bool(BoolStore::Native))); assert!(!B::supports_dtype( &device, DType::QFloat(QuantScheme::default()) )); } } ================================================ FILE: crates/burn-candle/src/element.rs ================================================ use std::borrow::Borrow; use burn_backend::{Element, bf16, f16}; use candle_core::{FloatDType, Tensor, WithDType}; /// Candle element pub trait CandleElement: Element + WithDType {} /// Candle float element pub trait FloatCandleElement: CandleElement + FloatDType {} /// Candle int element pub trait IntCandleElement: CandleElement {} impl CandleElement for f64 {} impl FloatCandleElement for f64 {} impl CandleElement for f32 {} impl FloatCandleElement for f32 {} impl CandleElement for f16 {} impl FloatCandleElement for f16 {} impl CandleElement for bf16 {} impl FloatCandleElement for bf16 {} impl CandleElement for u8 {} impl IntCandleElement for u8 {} impl CandleElement for u32 {} impl IntCandleElement for u32 {} impl CandleElement for i64 {} impl IntCandleElement for i64 {} ================================================ FILE: crates/burn-candle/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![allow(unused)] // TODO remove when backend filled #![deprecated( since = "0.21.0", note = "burn-candle is deprecated and will be removed in a future release. Use burn-cubecl (CUDA/ROCm/Vulkan/Metal/WebGPU), burn-ndarray, or burn-tch instead." )] //! Burn Candle Backend //! //! **Deprecated:** This backend is deprecated and will be removed in a future release. //! Please migrate to one of the actively maintained backends: //! - CubeCL backends (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration //! - NdArray for portable CPU execution //! - LibTorch (`burn-tch`) for a mature CPU/GPU backend #[macro_use] extern crate derive_new; mod backend; mod element; mod ops; mod tensor; pub use backend::*; pub use element::*; pub use tensor::*; ================================================ FILE: crates/burn-candle/src/ops/activation.rs ================================================ use burn_backend::{ops::ActivationOps, tensor::FloatTensor}; use crate::{ Candle, CandleTensor, element::{CandleElement, FloatCandleElement, IntCandleElement}, tensor, }; impl ActivationOps for Candle { fn gelu(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.gelu().unwrap()) } fn relu(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.relu().unwrap()) } } ================================================ FILE: crates/burn-candle/src/ops/base.rs ================================================ use std::cmp::max; use std::marker::PhantomData; use crate::{ Candle, CandleDevice, CandleTensor, element::{CandleElement, FloatCandleElement, IntCandleElement}, }; use burn_backend::{ BackTrace, Backend, Distribution, ExecutionError, Slice, bf16, f16, ops::unfold::{calculate_unfold_shape, calculate_unfold_windows}, }; use burn_backend::{Element, Shape, TensorData, TensorMetadata}; use candle_core::{Layout, WithDType}; use super::tensor; pub fn cpu_random(shape: Shape, distribution: Distribution) -> TensorData { let mut rng = crate::get_seeded_rng(); let data = TensorData::random::(shape, distribution, &mut rng); crate::set_seeded_rng(rng); data } pub fn cat(tensors: Vec, dim: usize) -> CandleTensor { let tensors: Vec = tensors.into_iter().map(|t| t.tensor).collect(); CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap()) } pub fn from_data(data: TensorData, device: &CandleDevice) -> CandleTensor { CandleTensor::from_data::(data, device.clone()) } pub fn into_data(tensor: CandleTensor) -> Result { fn tensor_data_from_dtype( tensor: &CandleTensor, ) -> Result { let data = tensor .tensor .flatten_all() .map_err(|err| ExecutionError::Generic { reason: format!("{err}"), backtrace: BackTrace::capture(), })? .to_vec1::() .map_err(|err| ExecutionError::Generic { reason: format!("{err}"), backtrace: BackTrace::capture(), })?; Ok(TensorData::new(data, tensor.shape())) } match tensor.tensor.dtype() { candle_core::DType::BF16 => tensor_data_from_dtype::(&tensor), candle_core::DType::F16 => tensor_data_from_dtype::(&tensor), candle_core::DType::F32 => tensor_data_from_dtype::(&tensor), candle_core::DType::F64 => tensor_data_from_dtype::(&tensor), candle_core::DType::U8 => tensor_data_from_dtype::(&tensor), candle_core::DType::U32 => tensor_data_from_dtype::(&tensor), candle_core::DType::I16 => tensor_data_from_dtype::(&tensor), candle_core::DType::I32 => tensor_data_from_dtype::(&tensor), candle_core::DType::I64 => tensor_data_from_dtype::(&tensor), other => todo!("{other:?} not yet supported"), } } pub fn to_device(tensor: CandleTensor, device: &CandleDevice) -> CandleTensor { CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap()) } pub fn empty(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor { zeros(shape, device, dtype) } pub fn zeros(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor { CandleTensor::new( candle_core::Tensor::zeros(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(), ) } pub fn ones(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor { CandleTensor::new( candle_core::Tensor::ones(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(), ) } pub fn swap_dims(mut tensor: CandleTensor, dim1: usize, dim2: usize) -> CandleTensor { CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap()) } pub fn permute(tensor: CandleTensor, axes: &[usize]) -> CandleTensor { CandleTensor::new(tensor.tensor.permute(axes).unwrap()) } pub fn flip(tensor: CandleTensor, axes: &[usize]) -> CandleTensor { // FIXME: Replace with an appropriate method when Candle provides one. let mut tensor = tensor.tensor; for &axis in axes { // Ensure tensor is contiguous before index_select (required by Candle) tensor = tensor.contiguous().unwrap(); let indexes = candle_core::Tensor::arange_step( tensor.dim(axis).unwrap() as i64 - 1, -1, -1, tensor.device(), ) .unwrap(); tensor = tensor.index_select(&indexes, axis).unwrap(); } CandleTensor::new(tensor) } pub fn reshape(tensor: CandleTensor, shape: Shape) -> CandleTensor { CandleTensor::new(tensor.tensor.reshape(shape.to_vec()).unwrap()) } pub fn device(tensor: &CandleTensor) -> CandleDevice { tensor.tensor.device().clone().into() } pub fn shape(tensor: &CandleTensor) -> Shape { tensor.shape() } pub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range]) -> CandleTensor { let mut narrow_tensor = tensor.tensor; for (i, range) in ranges.iter().enumerate().take(ranges.len()) { narrow_tensor = narrow_tensor .narrow(i, range.start, range.end - range.start) .unwrap() } CandleTensor::new(narrow_tensor) } pub fn slice_with_steps(tensor: CandleTensor, slices: &[Slice]) -> CandleTensor { let mut result_tensor = tensor.tensor; for (dim, slice) in slices.iter().enumerate() { if slice.step == 1 { // Use narrow for step=1 (more efficient) // Convert slice to range using tensor shape let dim_size = result_tensor.dim(dim).unwrap(); let range = slice.to_range(dim_size); let start = range.start; let length = range.end - range.start; result_tensor = result_tensor.narrow(dim, start, length).unwrap(); } else { // Use index_select for step != 1 let dim_size = result_tensor.dim(dim).unwrap(); let range = slice.to_range(dim_size); let start = range.start; let end = range.end; let step = slice.step; // Generate indices based on step direction let indices_vec = if step > 0 { // Forward stepping let step_usize = step as usize; (start..end).step_by(step_usize).collect::>() } else { // Backward stepping (negative step) let step_usize = step.unsigned_abs(); // Start from end-1 and go backwards let mut indices = Vec::new(); let mut idx = end - 1; while idx >= start && idx < end { // Check for underflow indices.push(idx); if idx >= step_usize { idx -= step_usize; } else { break; } } indices }; // Convert indices to tensor and use index_select let indices_len = indices_vec.len(); let device = result_tensor.device(); let indices = candle_core::Tensor::from_vec( indices_vec.iter().map(|&x| x as u32).collect::>(), indices_len, device, ) .unwrap(); result_tensor = result_tensor.index_select(&indices, dim).unwrap(); } } CandleTensor::new(result_tensor) } pub fn slice_assign(tensor: CandleTensor, slices: &[Slice], value: CandleTensor) -> CandleTensor { // Check if all slices have step=1 (candle's native slice_assign requirement) let all_unit_steps = slices.iter().all(|s| s.step == 1); if all_unit_steps { // Convert Slice to Range for candle's native slice_assign let ranges: Vec> = slices .iter() .enumerate() .map(|(dim, slice)| { let dim_size = tensor.tensor.dim(dim).unwrap_or(usize::MAX); slice.to_range(dim_size) }) .collect(); CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap()) } else { // Implement slice_assign with steps using scatter operations slice_assign_with_steps_workaround(tensor, slices, value) } } /// Implements slice_assign for non-unit steps using index operations fn slice_assign_with_steps_workaround( tensor: CandleTensor, slices: &[Slice], value: CandleTensor, ) -> CandleTensor { let shape = tensor.shape(); let ndims = shape.num_dims(); let device = tensor.tensor.device(); // Generate indices for each dimension based on slice specifications let indices_per_dim = generate_slice_indices(slices, &shape); // Early return if no elements to assign let total_elements: usize = indices_per_dim.iter().map(|v| v.len()).product(); if total_elements == 0 { return tensor; } // Flatten tensors and get metadata let value_flat = value.tensor.flatten_all().unwrap(); let strides = tensor.tensor.stride(); let tensor_shape = tensor.tensor.dims(); // Use a macro to handle different dtypes without code duplication macro_rules! apply_slice_assign { ($dtype:ty, $to_vec_fn:ident) => {{ let mut tensor_vec: Vec<$dtype> = tensor.tensor.flatten_all().unwrap().$to_vec_fn().unwrap(); let value_vec: Vec<$dtype> = value_flat.$to_vec_fn().unwrap(); // Apply assignments using cartesian product of indices for (value_idx, &value) in value_vec.iter().enumerate() { let flat_idx = compute_flat_index(value_idx, &indices_per_dim, &strides); if flat_idx < tensor_vec.len() { tensor_vec[flat_idx] = value; } } candle_core::Tensor::from_vec(tensor_vec, tensor_shape, device).unwrap() }}; } use candle_core::DType; let result = match tensor.tensor.dtype() { DType::F32 => apply_slice_assign!(f32, to_vec1), DType::F64 => apply_slice_assign!(f64, to_vec1), DType::I64 => apply_slice_assign!(i64, to_vec1), DType::U32 => apply_slice_assign!(u32, to_vec1), DType::U8 => apply_slice_assign!(u8, to_vec1), _ => panic!( "Unsupported dtype {:?} for slice_assign with steps", tensor.tensor.dtype() ), }; CandleTensor::new(result) } /// Generate indices for each dimension based on slice specifications fn generate_slice_indices(slices: &[Slice], tensor_dims: &[usize]) -> Vec> { let ndims = tensor_dims.len(); let mut indices_per_dim = Vec::with_capacity(ndims); // Process provided slices for (dim_idx, slice) in slices.iter().enumerate() { let dim_size = tensor_dims[dim_idx]; let range = slice.to_range(dim_size); let indices = generate_stepped_indices(range.start, range.end, slice.step); indices_per_dim.push(indices); } // Fill remaining dimensions with full ranges for &dim_size in tensor_dims.iter().skip(slices.len()) { indices_per_dim.push((0..dim_size).collect()); } indices_per_dim } /// Generate indices for a single dimension with stepping fn generate_stepped_indices(start: usize, end: usize, step: isize) -> Vec { if step > 0 { // Forward stepping (start..end).step_by(step as usize).collect() } else if step < 0 { // Backward stepping: start from end-1 and go backwards let step_size = step.unsigned_abs(); let mut indices = Vec::new(); let mut idx = end.saturating_sub(1); while idx >= start && idx < end { indices.push(idx); if idx >= step_size { idx -= step_size; } else { break; } } indices } else { // This branch should never be reached since step is validated to be non-zero panic!("Step cannot be zero") } } /// Compute flat index from multi-dimensional indices using cartesian product logic fn compute_flat_index( value_idx: usize, indices_per_dim: &[Vec], strides: &[usize], ) -> usize { let mut flat_idx = 0; let mut remainder = value_idx; // Convert value_idx to multi-dimensional indices and compute flat tensor index for dim in (0..indices_per_dim.len()).rev() { let dim_size = indices_per_dim[dim].len(); let idx_in_dim = remainder % dim_size; remainder /= dim_size; let actual_idx = indices_per_dim[dim][idx_in_dim]; flat_idx += actual_idx * strides[dim]; } flat_idx } pub fn narrow(tensor: CandleTensor, dim: usize, start: usize, length: usize) -> CandleTensor { let tensor = tensor.tensor.narrow(dim, start, length); match tensor { Ok(tensor) => CandleTensor::new(tensor), Err(e) => panic!("error narrow from Candle"), } } pub fn chunk(tensor: CandleTensor, chunks: usize, dim: usize) -> Vec { let tensors = tensor.tensor.chunk(chunks, dim); match tensors { Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(), Err(e) => panic!("error chunk from Candle"), } } pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor { CandleTensor::new(tensor.tensor.broadcast_as(shape.to_vec()).unwrap()) } pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor { let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step); let windows = result_shape[dim]; let mut select_ranges = tensor.shape().into_ranges(); let new_axis = select_ranges.len(); let mut stack = Vec::with_capacity(windows); for widx in 0..windows { let start = widx * step; let end = start + size; select_ranges[dim] = start..end; let mut window_slice = slice(tensor.clone(), &select_ranges); window_slice = swap_dims(window_slice, dim, new_axis); let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap()); stack.push(window_slice); } cat(stack, dim) } pub fn sign(tensor: CandleTensor) -> CandleTensor { CandleTensor::new(tensor.tensor.sign().unwrap()) } pub fn mask_where_broadcasted( tensor: CandleTensor, mask: CandleTensor, value: CandleTensor, ) -> CandleTensor { let shape = tensor .tensor .shape() .broadcast_shape_binary_op(mask.tensor.shape(), "where_cond") .unwrap(); let mut tensor = tensor.tensor; let mut mask = mask.tensor; let mut value = value.tensor; if shape != *tensor.shape() { tensor = tensor.broadcast_as(shape.clone()).unwrap(); } if shape != *mask.shape() { mask = mask.broadcast_as(shape.clone()).unwrap(); } if shape != *value.shape() { value = value.broadcast_as(shape).unwrap(); } CandleTensor::new(mask.where_cond(&value, &tensor).unwrap()) } pub fn cross(lhs: CandleTensor, rhs: CandleTensor, dim: usize) -> CandleTensor { let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); let ndims = shape_lhs.num_dims(); // Broadcast the shapes except along dim let mut broadcast_shape = vec![0; ndims]; for (i, item) in broadcast_shape.iter_mut().enumerate().take(ndims) { if i == dim { *item = shape_lhs[i]; } else { let l = shape_lhs[i]; let r = shape_rhs[i]; if l == r { *item = l; } else if l == 1 { *item = r; } else if r == 1 { *item = l; } else { panic!("Tensors are not broadcastable along dimension {}", i); } } } // Broadcast lhs and rhs let lhs_broadcast = if shape_lhs == Shape::from(broadcast_shape.clone()) { lhs } else { expand(lhs, Shape::from(broadcast_shape.clone())) }; let rhs_broadcast = if shape_rhs == Shape::from(broadcast_shape.clone()) { rhs } else { expand(rhs, Shape::from(broadcast_shape.clone())) }; // Now, move dim to the last dimension let mut perm = (0..ndims).collect::>(); perm.remove(dim); perm.push(dim); let lhs_permuted = permute(lhs_broadcast, &perm); let rhs_permuted = permute(rhs_broadcast, &perm); // Reshape to (*, 3) let total_elements = lhs_permuted.shape().num_elements(); let batch_size = total_elements / 3; let lhs_reshaped = reshape(lhs_permuted, Shape::new([batch_size, 3])); let rhs_reshaped = reshape(rhs_permuted, Shape::new([batch_size, 3])); // Extract components using narrow and squeeze let lhs_0 = CandleTensor::new( lhs_reshaped .tensor .narrow(1, 0, 1) .unwrap() .squeeze(1) .unwrap(), ); let lhs_1 = CandleTensor::new( lhs_reshaped .tensor .narrow(1, 1, 1) .unwrap() .squeeze(1) .unwrap(), ); let lhs_2 = CandleTensor::new( lhs_reshaped .tensor .narrow(1, 2, 1) .unwrap() .squeeze(1) .unwrap(), ); let rhs_0 = CandleTensor::new( rhs_reshaped .tensor .narrow(1, 0, 1) .unwrap() .squeeze(1) .unwrap(), ); let rhs_1 = CandleTensor::new( rhs_reshaped .tensor .narrow(1, 1, 1) .unwrap() .squeeze(1) .unwrap(), ); let rhs_2 = CandleTensor::new( rhs_reshaped .tensor .narrow(1, 2, 1) .unwrap() .squeeze(1) .unwrap(), ); // Compute cross product components let result_0 = CandleTensor::new( lhs_1 .tensor .mul(&rhs_2.tensor) .unwrap() .sub(&lhs_2.tensor.mul(&rhs_1.tensor).unwrap()) .unwrap(), ); let result_1 = CandleTensor::new( lhs_2 .tensor .mul(&rhs_0.tensor) .unwrap() .sub(&lhs_0.tensor.mul(&rhs_2.tensor).unwrap()) .unwrap(), ); let result_2 = CandleTensor::new( lhs_0 .tensor .mul(&rhs_1.tensor) .unwrap() .sub(&lhs_1.tensor.mul(&rhs_0.tensor).unwrap()) .unwrap(), ); // Stack the components let result_0_unsqueezed = CandleTensor::new(result_0.tensor.unsqueeze(1).unwrap()); let result_1_unsqueezed = CandleTensor::new(result_1.tensor.unsqueeze(1).unwrap()); let result_2_unsqueezed = CandleTensor::new(result_2.tensor.unsqueeze(1).unwrap()); let result = cat( vec![ result_0_unsqueezed, result_1_unsqueezed, result_2_unsqueezed, ], 1, ); // Reshape back to the broadcast shape with dim at the end let mut result_shape = broadcast_shape; result_shape.remove(dim); result_shape.push(3); let result_reshaped = reshape(result, Shape::from(result_shape)); // Permute back let mut inv_perm = vec![0; ndims]; for (i, &p) in perm.iter().enumerate() { inv_perm[p] = i; } permute(result_reshaped, &inv_perm) } ================================================ FILE: crates/burn-candle/src/ops/bool_tensor.rs ================================================ use burn_backend::{ BackTrace, DType, ExecutionError, Scalar, Shape, Slice, TensorData, TensorMetadata, ops::BoolTensorOps, tensor::{BoolTensor, Device, FloatTensor, IntTensor}, }; use crate::{ Candle, CandleTensor, element::{CandleElement, FloatCandleElement, IntCandleElement}, }; use super::base::{expand, permute, unfold}; impl BoolTensorOps for Candle { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { super::base::empty(shape, device, candle_core::DType::U8) } fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor { super::base::zeros(shape, device, candle_core::DType::U8) } fn bool_ones(shape: Shape, device: &Device) -> BoolTensor { super::base::ones(shape, device, candle_core::DType::U8) } async fn bool_into_data(tensor: BoolTensor) -> Result { let x: Vec = tensor .tensor .flatten_all() .map_err(|err| ExecutionError::Generic { reason: format!("{err}"), backtrace: BackTrace::capture(), })? .to_vec1() .map_err(|err| ExecutionError::Generic { reason: format!("{err}"), backtrace: BackTrace::capture(), })?; let y = x.iter().map(|b| !matches!(b, 0)).collect(); Ok(TensorData::new(y, tensor.shape())) } fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { match data.dtype { DType::U8 => super::base::from_data::(data, device), _ => unimplemented!("Unsupported dtype for `bool_from_data`"), } } fn bool_into_int(tensor: BoolTensor) -> IntTensor { CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) } fn bool_device(tensor: &BoolTensor) -> Device { super::base::device(tensor) } fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { super::base::to_device(tensor, device) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { super::base::reshape(tensor, shape) } fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor { super::base::slice_with_steps(tensor, slices) } fn bool_slice_assign( tensor: BoolTensor, slices: &[Slice], value: BoolTensor, ) -> BoolTensor { super::base::slice_assign(tensor, slices, value) } fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { super::base::cat(tensors, dim) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap()) } fn bool_not(tensor: BoolTensor) -> BoolTensor { let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap()); CandleTensor::new(tensor.tensor.eq(&x).unwrap()) } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap(); CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap()) } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { CandleTensor::new( lhs.tensor .add(&rhs.tensor) .unwrap() .clamp(0u32, 1u32) .unwrap(), ) } fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { super::base::swap_dims(tensor, dim1, dim2) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { super::base::permute(tensor, axes) } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { super::base::flip(tensor, axes) } fn bool_select( tensor: BoolTensor, dim: usize, indices: IntTensor, ) -> BoolTensor { CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) } fn bool_select_or( tensor: BoolTensor, dim: usize, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { CandleTensor::new( tensor .tensor .index_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { expand(tensor, shape) } fn bool_unfold( tensor: BoolTensor, dim: usize, size: usize, step: usize, ) -> BoolTensor { unfold(tensor, dim, size, step) } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { super::base::mask_where_broadcasted(tensor, mask, value) } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { CandleTensor::new( mask.tensor .where_cond( &super::candle_utils::fill_like::(value.elem(), &tensor.tensor), &tensor.tensor, ) .unwrap(), ) } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { let tensor = tensor.tensor.contiguous().unwrap(); let indices = indices.tensor.contiguous().unwrap(); CandleTensor::new(tensor.gather(&indices, dim).unwrap()) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { CandleTensor::new( tensor .tensor .scatter_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new(lhs.tensor.eq(rhs.elem::()).unwrap()) } } ================================================ FILE: crates/burn-candle/src/ops/candle_utils.rs ================================================ use candle_core::{DType, Device, Shape, Tensor}; use crate::element::CandleElement; pub(crate) fn fill>( value: E, shape: S, dtype: DType, device: &Device, ) -> Tensor { let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::()).unwrap(); values.expand(shape).unwrap() } pub(crate) fn fill_like(value: E, reference_tensor: &Tensor) -> Tensor { fill( value, reference_tensor.shape(), reference_tensor.dtype(), reference_tensor.device(), ) } /// Broadcasts two tensors to a common shape for comparison operations pub(crate) fn broadcast_for_comparison( lhs: &Tensor, rhs: &Tensor, ) -> Result<(Tensor, Tensor), candle_core::Error> { let broadcast_shape = lhs .shape() .broadcast_shape_binary_op(rhs.shape(), "comparison")?; let lhs = if broadcast_shape != *lhs.shape() { lhs.broadcast_as(&broadcast_shape)? } else { lhs.clone() }; let rhs = if broadcast_shape != *rhs.shape() { rhs.broadcast_as(&broadcast_shape)? } else { rhs.clone() }; Ok((lhs, rhs)) } ================================================ FILE: crates/burn-candle/src/ops/int_tensor.rs ================================================ use burn_backend::{ DType, Distribution, ElementConversion, ExecutionError, IntDType, Scalar, Shape, Slice, TensorData, ops::{FloatTensorOps, IntTensorOps}, tensor::{Bool, BoolTensor, Device, FloatTensor, IntElem, IntTensor}, }; use crate::{ Candle, CandleDevice, CandleTensor, IntoDType, element::{CandleElement, FloatCandleElement, IntCandleElement}, }; use super::base::{cpu_random, expand, permute, sign, unfold}; impl IntTensorOps for Candle { fn int_empty(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { super::base::empty(shape, device, dtype.into_dtype()) } async fn int_into_data(tensor: IntTensor) -> Result { super::base::into_data(tensor) } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { match data.dtype { DType::I64 => super::base::from_data::(data, device), DType::U32 => super::base::from_data::(data, device), DType::U8 => super::base::from_data::(data, device), _ => unimplemented!("Unsupported dtype for `int_from_data`"), } } fn int_device(tensor: &IntTensor) -> Device { super::base::device(tensor) } fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { super::base::to_device(tensor, device) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { super::base::reshape(tensor, shape) } fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor { super::base::slice_with_steps(tensor, slices) } fn int_slice_assign( tensor: IntTensor, slices: &[Slice], value: IntTensor, ) -> IntTensor { super::base::slice_assign(tensor, slices, value) } fn int_into_float(tensor: IntTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, source: IntTensor, ) -> IntTensor { super::base::mask_where_broadcasted(tensor, mask, source) } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> IntTensor { CandleTensor::new( mask.tensor .where_cond( &super::candle_utils::fill_like::(value.elem(), &tensor.tensor), &tensor.tensor, ) .unwrap(), ) } fn int_gather( dim: usize, tensor: IntTensor, indices: IntTensor, ) -> IntTensor { let tensor = tensor.tensor.contiguous().unwrap(); let indices = indices.tensor.contiguous().unwrap(); CandleTensor::new(tensor.gather(&indices, dim).unwrap()) } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { CandleTensor::new( tensor .tensor .scatter_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn int_select( tensor: IntTensor, dim: usize, indices: IntTensor, ) -> IntTensor { CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { CandleTensor::new( tensor .tensor .index_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { super::base::cat(tensors, dim) } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap()) } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new(lhs.tensor.eq(rhs.elem::()).unwrap()) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap()) } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .gt(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap()) } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .ge(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap()) } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .lt(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap()) } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .le(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. panic!("Not supported by Candle") } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { CandleTensor::new( (lhs.tensor.clone() - lhs .tensor .broadcast_div(&rhs.tensor) .unwrap() .broadcast_mul(&rhs.tensor) .unwrap()) .unwrap(), ) } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { // Same problem as int_div_scalar. panic!("Not supported by Candle") } fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { CandleTensor::new( candle_core::Tensor::zeros( shape.to_vec(), dtype.into_dtype(), &(device.clone()).into(), ) .unwrap(), ) } fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { CandleTensor::new( candle_core::Tensor::ones(shape.to_vec(), dtype.into_dtype(), &(device.clone()).into()) .unwrap(), ) } fn int_sum(tensor: IntTensor) -> IntTensor { let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); CandleTensor::from_data::( TensorData::new([sum].into(), [1]), Self::int_device(&tensor), ) } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) } fn int_prod(tensor: IntTensor) -> IntTensor { todo!( "prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)" ) } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { todo!( "prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)" ) } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0. panic!("Not supported by Candle") } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { // Candle's cumsum doesn't support integer types, so we convert to float, // compute cumsum, and convert back to int let dtype = tensor.tensor.dtype(); let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap(); let result_float = tensor_float.cumsum(dim).unwrap(); CandleTensor::new(result_float.to_dtype(dtype).unwrap()) } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { // Convert to float for computation, then convert back let dtype = tensor.tensor.dtype(); let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap(); let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| { prev.broadcast_mul(curr) }); CandleTensor::new(result_float.to_dtype(dtype).unwrap()) } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { // Convert to float for computation, then convert back let dtype = tensor.tensor.dtype(); let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap(); let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| { prev.broadcast_minimum(curr) }); CandleTensor::new(result_float.to_dtype(dtype).unwrap()) } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| { prev.broadcast_maximum(curr) }); CandleTensor::new(result) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { CandleTensor::new( tensor .tensor .argmax_keepdim(dim) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { CandleTensor::new( tensor .tensor .argmin_keepdim(dim) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ) } fn int_abs(tensor: IntTensor) -> IntTensor { // Ugly type conversion here as Candle does not support unary ops on ints match tensor.tensor.dtype() { candle_core::DType::U8 | candle_core::DType::U32 => tensor, candle_core::DType::I64 => CandleTensor::new( tensor .tensor .to_dtype(F::DTYPE) .unwrap() .abs() .unwrap() .to_dtype(candle_core::DType::I64) .unwrap(), ), _ => unreachable!(), } } fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { super::base::swap_dims(tensor, dim1, dim2) } fn int_random( shape: Shape, distribution: Distribution, device: &Device, ) -> IntTensor { if let CandleDevice::Cpu = device { let distribution = if distribution == Distribution::Default { Distribution::Uniform(0.0, 255.0) } else { distribution }; // Use our own seed since candle doesn't support it on CPU return Self::int_from_data(cpu_random::(shape, distribution), device); } let shape = shape.to_vec(); let device = &(device.clone()).into(); match distribution { Distribution::Default => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 255.elem::(), shape, device) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ), Distribution::Bernoulli(prob) => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 1.elem::(), shape.clone(), device) .unwrap() .to_dtype(I::DTYPE) .unwrap() .lt(&super::candle_utils::fill(prob, shape, I::DTYPE, device)) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ), Distribution::Uniform(from, to) => CandleTensor::new( candle_core::Tensor::rand(from.elem::(), to.elem::(), shape, device).unwrap(), ), Distribution::Normal(mean, std) => CandleTensor::new( candle_core::Tensor::randn(mean.elem::(), std.elem::(), shape, device) .unwrap(), ), } } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { super::base::permute(tensor, axes) } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { super::base::flip(tensor, axes) } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { expand(tensor, shape) } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { unfold(tensor, dim, size, step) } fn int_sign(tensor: IntTensor) -> IntTensor { sign(tensor) } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { unimplemented!("bitwise_and is not implemented for Candle IntTensor"); } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor"); } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { unimplemented!("bitwise_or is not implemented for Candle IntTensor"); } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor"); } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { unimplemented!("bitwise_xor is not implemented for Candle IntTensor"); } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor"); } fn bitwise_not(tensor: IntTensor) -> IntTensor { unimplemented!("bitwise_not is not implemented for Candle IntTensor"); } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor"); } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor"); } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor"); } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor"); } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let lhs = Self::int_into_float(lhs); let rhs = Self::int_into_float(rhs); let out = Self::float_matmul(lhs, rhs); Self::float_into_int(out) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { let dtype = dtype.into_dtype(); if tensor.tensor.dtype() == dtype { tensor } else { CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap()) } } } ================================================ FILE: crates/burn-candle/src/ops/mod.rs ================================================ mod activation; mod base; mod bool_tensor; mod candle_utils; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; mod utils; ================================================ FILE: crates/burn-candle/src/ops/module.rs ================================================ use burn_backend::{ Shape, ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, UnfoldOptions, attention::attention_fallback, }, tensor::{FloatTensor, IntTensor}, }; use candle_core::ToUsize2; use crate::{ Candle, CandleTensor, element::{CandleElement, FloatCandleElement, IntCandleElement}, ops::base::reshape, }; impl ModuleOps for Candle { fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { let conv = x .tensor .conv1d( &weight.tensor, options.padding[0], options.stride[0], options.dilation[0], options.groups, ) .unwrap(); CandleTensor::new(match bias { Some(bias) => conv .broadcast_add(&bias.tensor.unsqueeze(1).unwrap()) .unwrap(), None => conv, }) } fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { assert!( options.dilation[0] == options.dilation[1] && options.padding[0] == options.padding[1] && options.stride[0] == options.stride[1], "Candle does not support per dimension options in convolutions" ); let conv = x .tensor .conv2d( &weight.tensor, options.padding[0], options.stride[0], options.dilation[0], options.groups, ) .unwrap(); CandleTensor::new(match bias { Some(bias) => conv .broadcast_add( &bias .tensor .unsqueeze(0) .unwrap() .unsqueeze(2) .unwrap() .unsqueeze(3) .unwrap(), ) .unwrap(), None => conv, }) } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { unimplemented!("Candle does not support deformable convolutions") } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { unimplemented!("Candle does not support deformable convolutions") } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { panic!("Candle does not support 3D convolutions"); } fn conv_transpose1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { let conv_transpose = x .tensor .conv_transpose1d( &weight.tensor, options.padding[0], options.padding_out[0], options.stride[0], options.dilation[0], options.groups, ) .unwrap(); CandleTensor::new(match bias { Some(bias) => conv_transpose .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap()) .unwrap(), None => conv_transpose, }) } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { assert!( options.dilation[0] == options.dilation[1] && options.padding[0] == options.padding[1] && options.padding_out[0] == options.padding_out[1] && options.stride[0] == options.stride[1], "Candle does not support per dimension options in transposed convolutions" ); assert!( options.groups == 1, "Candle does not support groups in transposed convolutions" ); let conv_transpose = x .tensor .conv_transpose2d( &weight.tensor, options.padding[0], options.padding_out[0], options.stride[0], options.dilation[0], ) .unwrap(); CandleTensor::new(match bias { Some(bias) => conv_transpose .broadcast_add( &bias .tensor .unsqueeze(0) .unwrap() .unsqueeze(2) .unwrap() .unsqueeze(3) .unwrap(), ) .unwrap(), None => conv_transpose, }) } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { panic!("Candle does not support 3D transposed convolutions"); } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { assert!( padding[0] == 0 && padding[1] == 0, "Candle does not support padding in pooling" ); assert!( count_include_pad, "Candle does not support excluding pad count in pooling" ); assert!(!ceil_mode, "Candle does not support ceil_mode in pooling"); CandleTensor::new( x.tensor .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) .unwrap(), ) } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, _ceil_mode: bool, ) -> FloatTensor { panic!("avg_pool2d_backward is not supported by Candle") } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { assert!( padding[0] == 0 && padding[1] == 0, "Candle does not support padding in pooling" ); assert!( dilation[0] == 1 && dilation[1] == 1, "Candle does not support dilation in pooling" ); assert!(!ceil_mode, "Candle does not support ceil_mode in pooling"); CandleTensor::new( x.tensor .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1])) .unwrap(), ) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], _ceil_mode: bool, ) -> MaxPool2dWithIndices> { panic!("max_pool2d_with_indices is not supported by Candle") } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], _ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward> { panic!("max_pool2d_with_indices_backward is not supported by Candle") } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { panic!("adaptive_avg_pool2 is not supported by Candle") } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { panic!("adaptive_avg_pool2d_backward is not supported by Candle") } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { let tensor = match options.mode { InterpolateMode::Nearest => x .tensor .upsample_nearest2d(output_size[0], output_size[1]) .unwrap(), InterpolateMode::Bilinear => { panic!("bilinear interpolation is not supported by Candle") } InterpolateMode::Bicubic => { panic!("bicubic interpolation is not supported by Candle") } InterpolateMode::Lanczos3 => { panic!("lanczos3 interpolation is not supported by Candle") } }; CandleTensor::new(tensor) } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { panic!("interpolate_backward is not supported by Candle") } fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: burn_backend::ops::AttentionModuleOptions, ) -> FloatTensor { attention_fallback::(query, key, value, mask, attn_bias, options) } } ================================================ FILE: crates/burn-candle/src/ops/qtensor.rs ================================================ use burn_backend::{ Backend, DType, ExecutionError, Shape, Slice, TensorData, ops::QTensorOps, quantization::{QuantScheme, QuantizationParametersPrimitive}, tensor::{Device, FloatTensor, IntTensor, QuantizedTensor}, }; use crate::{ Candle, element::{FloatCandleElement, IntCandleElement}, }; impl QTensorOps for Candle { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { unimplemented!() } fn quantize( _tensor: FloatTensor, _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { unimplemented!() } fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { unimplemented!() } fn q_device(_tensor: &QuantizedTensor) -> Device { unimplemented!() } fn q_to_device( _tensor: QuantizedTensor, _device: &Device, ) -> QuantizedTensor { unimplemented!() } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } async fn q_into_data(tensor: QuantizedTensor) -> Result { unimplemented!() } fn q_swap_dims( _tensor: QuantizedTensor, _dim1: usize, _dim2: usize, ) -> QuantizedTensor { unimplemented!() } fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_gather( _dim: usize, _tensor: QuantizedTensor, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_select( _tensor: QuantizedTensor, _dim: usize, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_slice(_tensor: QuantizedTensor, _slices: &[Slice]) -> QuantizedTensor { unimplemented!() } fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } } ================================================ FILE: crates/burn-candle/src/ops/tensor.rs ================================================ use std::borrow::Borrow; use burn_backend::{ DType, Distribution, ElementConversion, ExecutionError, FloatDType, Scalar, Shape, Slice, TensorData, bf16, f16, ops::FloatTensorOps, tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}, }; use candle_core::{Tensor, backend::BackendStorage, shape}; use crate::{ Candle, CandleDevice, CandleTensor, IntoDType, element::{CandleElement, FloatCandleElement, IntCandleElement}, }; use super::base::{cpu_random, expand, permute, sign, unfold}; impl FloatTensorOps for Candle { fn float_from_data(data: TensorData, device: &Device) -> CandleTensor { match data.dtype { DType::F64 => super::base::from_data::(data, device), DType::F32 => super::base::from_data::(data, device), DType::F16 => super::base::from_data::(data, device), DType::BF16 => super::base::from_data::(data, device), _ => unimplemented!("Unsupported dtype for `float_from_data`"), } } fn float_random( shape: Shape, distribution: Distribution, device: &Device, ) -> FloatTensor { if let CandleDevice::Cpu = device { // Use our own seed since candle doesn't support it on CPU return Self::float_from_data(cpu_random::(shape, distribution), device); } let shape = shape.to_vec(); let device = &(device.clone()).into(); match distribution { Distribution::Default => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 1.elem::(), shape, device) .unwrap() .to_dtype(F::DTYPE) .unwrap(), ), Distribution::Bernoulli(prob) => CandleTensor::new( candle_core::Tensor::rand(0.elem::(), 1.elem::(), shape.clone(), device) .unwrap() .to_dtype(F::DTYPE) .unwrap() .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device)) .unwrap() .to_dtype(F::DTYPE) .unwrap(), ), Distribution::Uniform(from, to) => CandleTensor::new( candle_core::Tensor::rand(from.elem::(), to.elem::(), shape, device).unwrap(), ), Distribution::Normal(mean, std) => CandleTensor::new( candle_core::Tensor::randn(mean.elem::(), std.elem::(), shape, device) .unwrap(), ), } } async fn float_into_data(tensor: CandleTensor) -> Result { super::base::into_data(tensor) } fn float_device(tensor: &CandleTensor) -> Device { super::base::device(tensor) } fn float_to_device(tensor: CandleTensor, device: &Device) -> CandleTensor { super::base::to_device(tensor, device) } fn float_into_int(tensor: CandleTensor) -> IntTensor { CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap()) } fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { super::base::empty(shape, device, dtype.into_dtype()) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap()) } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { CandleTensor::new((lhs.tensor + rhs.elem::()).unwrap()) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap()) } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { CandleTensor::new((lhs.tensor - rhs.elem::()).unwrap()) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap()) } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { CandleTensor::new((lhs.tensor * rhs.elem::()).unwrap()) } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap()) } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { CandleTensor::new((lhs.tensor / rhs.elem::()).unwrap()) } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { CandleTensor::new( (lhs.tensor.clone() - lhs .tensor .broadcast_div(&rhs.tensor) .unwrap() .floor() .unwrap() .broadcast_mul(&rhs.tensor) .unwrap()) .unwrap(), ) } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { // In PyTorch, remainder can also be defined as torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b let rhs_val = rhs.elem::(); let division_result = (lhs.tensor.clone() / rhs_val).unwrap().floor().unwrap(); let product = division_result * rhs_val; CandleTensor::new((lhs.tensor - product).unwrap()) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let lhs_contiguous = if !lhs.tensor.is_contiguous() { lhs.tensor.contiguous().unwrap() } else { lhs.tensor }; let rhs_contiguous = if !rhs.tensor.is_contiguous() { rhs.tensor.contiguous().unwrap() } else { rhs.tensor }; CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap()) } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { super::base::cross(lhs, rhs, dim) } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { super::base::swap_dims(tensor, dim1, dim2) } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { super::base::reshape(tensor, shape) } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { let tensor = tensor.tensor.contiguous().unwrap(); let indices = indices.tensor.contiguous().unwrap(); CandleTensor::new(tensor.gather(&indices, dim).unwrap()) } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { CandleTensor::new( tensor .tensor .scatter_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap()) } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { CandleTensor::new( tensor .tensor .index_add(&indices.tensor, &value.tensor, dim) .unwrap(), ) } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { super::base::slice_with_steps(tensor, slices) } fn float_slice_assign( tensor: FloatTensor, slices: &[Slice], value: FloatTensor, ) -> FloatTensor { super::base::slice_assign(tensor, slices, value) } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { super::base::mask_where_broadcasted(tensor, mask, value) } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { let value = super::candle_utils::fill_like::(value.elem(), &tensor.tensor); super::base::mask_where_broadcasted(tensor, mask, CandleTensor::new(value)) } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap()) } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new(lhs.tensor.eq(rhs.elem::()).unwrap()) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap()) } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .gt(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap()) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .ge(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap()) } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .lt(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let (lhs_broadcast, rhs_broadcast) = super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap(); CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap()) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { CandleTensor::new( lhs.tensor .le(&super::candle_utils::fill_like::( rhs.elem(), &lhs.tensor, )) .unwrap(), ) } fn float_sum(tensor: FloatTensor) -> FloatTensor { let sum = tensor.tensor.sum_all().unwrap().to_scalar::().unwrap(); CandleTensor::from_data::( TensorData::new([sum].into(), [1]), Self::float_device(&tensor), ) } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { CandleTensor::new(tensor.tensor.cumsum(dim).unwrap()) } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| { prev.broadcast_mul(curr) }); CandleTensor::new(result) } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| { prev.broadcast_minimum(curr) }); CandleTensor::new(result) } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| { prev.broadcast_maximum(curr) }); CandleTensor::new(result) } fn float_exp(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.exp().unwrap()) } fn float_log(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.log().unwrap()) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap()) } fn float_powf_scalar_impl(tensor: FloatTensor, value: Scalar) -> FloatTensor { CandleTensor::new(tensor.tensor.powf(value.elem::()).unwrap()) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.sqrt().unwrap()) } fn float_abs(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.abs().unwrap()) } fn float_cos(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.cos().unwrap()) } fn float_cosh(tensor: FloatTensor) -> FloatTensor { // cosh(x) = (e^x + e^(-x)) / 2 let exp_x = tensor.tensor.exp().unwrap(); CandleTensor::new(((exp_x.clone() + exp_x.recip().unwrap()).unwrap() / 2.0).unwrap()) } fn float_sin(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.sin().unwrap()) } fn float_sinh(tensor: FloatTensor) -> FloatTensor { // sinh(x) = (e^x - e^(-x)) / 2 let exp_x = tensor.tensor.exp().unwrap(); CandleTensor::new(((exp_x.clone() - exp_x.recip().unwrap()).unwrap() / 2.0).unwrap()) } fn float_tan(tensor: FloatTensor) -> FloatTensor { CandleTensor::new((tensor.tensor.sin().unwrap() / tensor.tensor.cos().unwrap()).unwrap()) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.tanh().unwrap()) } fn float_acos(tensor: FloatTensor) -> FloatTensor { // acos(x) = PI/2 - asin(x) let neg_asin_x = Self::float_neg(Self::float_asin(tensor)); Self::float_add_scalar(neg_asin_x, core::f64::consts::FRAC_PI_2.into()) } fn float_acosh(tensor: FloatTensor) -> FloatTensor { // acosh(x) = ln(x + sqrt(x^2 - 1)) let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into()); let x_sq_minus_one = Self::float_sub_scalar(x_squared, 1f64.into()); let sqrt_term = Self::float_sqrt(x_sq_minus_one); Self::float_log(Self::float_add(tensor, sqrt_term)) } fn float_asin(tensor: FloatTensor) -> FloatTensor { // asin(x) = atan(x / sqrt(1 - x^2)) let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into()); let one_minus_x_sq = Self::float_add_scalar(Self::float_neg(x_squared), 1f64.into()); let sqrt_term = Self::float_sqrt(one_minus_x_sq); Self::float_atan(Self::float_div(tensor, sqrt_term)) } fn float_asinh(tensor: FloatTensor) -> FloatTensor { // asinh(x) = ln(x + sqrt(x^2 + 1)) let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into()); let x_sq_plus_one = Self::float_add_scalar(x_squared, 1f64.into()); let sqrt_term = Self::float_sqrt(x_sq_plus_one); Self::float_log(Self::float_add(tensor, sqrt_term)) } fn float_atan(tensor: FloatTensor) -> FloatTensor { // atan(x) = asin(x / sqrt(1 + x^2)) let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into()); let one_plus_x_sq = Self::float_add_scalar(x_squared, 1f64.into()); let sqrt_term = Self::float_sqrt(one_plus_x_sq); Self::float_asin(Self::float_div(tensor, sqrt_term)) } fn float_atanh(tensor: FloatTensor) -> FloatTensor { // atanh(x) = ln((1 + x) / (1 - x)) / 2 let num = (1.0 + tensor.tensor.clone()).unwrap(); let denom = (1.0 - tensor.tensor).unwrap(); CandleTensor::new(((num / denom).unwrap().log().unwrap() / 2.0).unwrap()) } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { // atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x)) let x_squared = Self::float_powi_scalar(rhs.clone(), 2.into()); let y_squared = Self::float_powi_scalar(lhs.clone(), 2.into()); let r = Self::float_sqrt(Self::float_add(x_squared, y_squared)); let ratio = Self::float_div(lhs, Self::float_add(r, rhs)); Self::float_mul_scalar(Self::float_atan(ratio), 2f64.into()) } fn float_round(tensor: FloatTensor) -> FloatTensor { let inner = |tensor: FloatTensor| -> candle_core::Result> { // implements round_to_even for consistent behavior vs libtorch // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67 let floor_a = tensor.tensor.floor()?; let frac_part = tensor.tensor.sub(&floor_a)?; let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?; let mask_half = frac_part.eq(&half)?; let half_tensor = tensor.tensor.mul(&half)?; let rounded_half = half_tensor.round()?; let doubled = rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?; let standard_round = tensor.tensor.round()?; Ok(CandleTensor::new( mask_half.where_cond(&doubled, &standard_round)?, )) }; inner(tensor).unwrap() } fn float_floor(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.floor().unwrap()) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.ceil().unwrap()) } fn float_trunc(tensor: FloatTensor) -> FloatTensor { // truncate(x) = ⌊x⌋ if x ≥ 0, and ⌈x⌉ if x < 0 // This preserves the sign of zero and handles all special cases correctly let is_negative = tensor.tensor.lt(0.0).unwrap(); let floored = tensor.tensor.floor().unwrap(); let ceiled = tensor.tensor.ceil().unwrap(); CandleTensor::new(is_negative.where_cond(&ceiled, &floored).unwrap()) } fn float_erf(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.erf().unwrap()) } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { super::base::cat(tensors, dim) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { CandleTensor::new( tensor .tensor .argmax_keepdim(dim) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ) } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { CandleTensor::new( tensor .tensor .argmin_keepdim(dim) .unwrap() .to_dtype(I::DTYPE) .unwrap(), ) } fn float_clamp_max(tensor: FloatTensor, max: Scalar) -> FloatTensor { CandleTensor::new(tensor.tensor.minimum(max.elem::()).unwrap()) } fn float_clamp_min(tensor: FloatTensor, min: Scalar) -> FloatTensor { CandleTensor::new(tensor.tensor.maximum(min.elem::()).unwrap()) } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { CandleTensor::new( tensor .tensor .clamp(min.elem::(), max.elem::()) .unwrap(), ) } fn float_recip(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.recip().unwrap()) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { //broadcast_pow is in main but not yet published //note: probably replace once pow once 0.3.3 is out //see: https://github.com/huggingface/candle/pull/1583/files#diff-6319fa1e16dadc4c7b4e25698139703d93b70f30a1f8e2ac0999978e39efaa81R2594 CandleTensor::new( rhs.tensor .broadcast_mul(&lhs.tensor.log().unwrap()) .unwrap() .exp() .unwrap(), ) } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { super::base::permute(tensor, axes) } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { super::base::flip(tensor, axes) } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { expand(tensor, shape) } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { unfold(tensor, dim, size, step) } fn float_sign(tensor: FloatTensor) -> FloatTensor { sign(tensor) } fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { let dtype = dtype.into_dtype(); if tensor.tensor.dtype() == dtype { tensor } else { CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap()) } } } ================================================ FILE: crates/burn-candle/src/ops/transaction.rs ================================================ use burn_backend::{ Backend, ops::{TransactionOps, TransactionPrimitive}, }; use crate::{ Candle, element::{FloatCandleElement, IntCandleElement}, }; impl TransactionOps for Candle {} ================================================ FILE: crates/burn-candle/src/ops/utils.rs ================================================ /// Helper function for cumulative operations in Candle backend /// /// This function reduces code duplication for cumulative operations (cumprod, cummin, cummax) /// which all follow the same pattern of slicing, applying an operation, and concatenating. /// /// # Arguments /// /// * `tensor` - The input tensor /// * `dim` - The dimension along which to apply the cumulative operation /// * `op` - A closure that takes two tensor references and produces a result tensor pub fn cumulative_with_op(tensor: &candle_core::Tensor, dim: usize, op: F) -> candle_core::Tensor where F: Fn(&candle_core::Tensor, &candle_core::Tensor) -> candle_core::Result, { let dim_size = tensor.dims()[dim]; let mut slices = Vec::with_capacity(dim_size); // First slice is the initial value slices.push(tensor.narrow(dim, 0, 1).unwrap()); // Apply cumulative operation for i in 1..dim_size { let curr = tensor.narrow(dim, i, 1).unwrap(); let result = op(&slices[i - 1], &curr).unwrap(); slices.push(result); } candle_core::Tensor::cat(&slices, dim).unwrap() } ================================================ FILE: crates/burn-candle/src/tensor.rs ================================================ use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme}; use burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata}; use burn_std::BoolStore; use crate::{CandleDevice, element::CandleElement}; /// A tensor that uses the candle backend. #[derive(Debug, Clone)] pub struct CandleTensor { pub(crate) tensor: candle_core::Tensor, } impl TensorMetadata for CandleTensor { fn dtype(&self) -> DType { match self.tensor.dtype() { candle_core::DType::U8 => DType::U8, candle_core::DType::U32 => DType::U32, candle_core::DType::I64 => DType::I64, candle_core::DType::BF16 => DType::BF16, candle_core::DType::F16 => DType::F16, candle_core::DType::F32 => DType::F32, candle_core::DType::F64 => DType::F64, candle_core::DType::I16 => DType::I16, candle_core::DType::I32 => DType::I32, other => todo!("{other:?} not yet supported"), } } fn shape(&self) -> Shape { Shape::from(self.tensor.dims().to_vec()) } fn rank(&self) -> usize { self.tensor.dims().len() } } impl QTensorPrimitive for CandleTensor { fn scheme(&self) -> &QuantScheme { unimplemented!("Quantization is not supported") } } impl CandleTensor { /// Create a new tensor. pub fn new(tensor: candle_core::Tensor) -> Self { Self { tensor } } /// Creates a new tensor from data and a device. /// /// # Arguments /// /// * `data` - The tensor's data. /// * `device` - The device on which the tensor will be allocated. /// /// # Returns /// /// A new tensor. pub fn from_data(data: TensorData, device: CandleDevice) -> Self { let candle_shape: candle_core::Shape = data.shape.to_vec().into(); let tensor = candle_core::Tensor::from_slice( data.as_slice::().unwrap(), candle_shape, &device.into(), ); Self::new(tensor.unwrap()) } } pub(crate) trait IntoDType { fn try_into_dtype(self) -> Result; fn into_dtype(self) -> candle_core::DType where Self: Sized, { self.try_into_dtype().unwrap() } } impl IntoDType for IntDType { fn try_into_dtype(self) -> Result { let dtype: DType = self.into(); dtype.try_into_dtype() } } impl IntoDType for FloatDType { fn try_into_dtype(self) -> Result { let dtype: DType = self.into(); dtype.try_into_dtype() } } impl IntoDType for DType { fn try_into_dtype(self) -> Result { match self { DType::F64 => Ok(candle_core::DType::F64), DType::F32 => Ok(candle_core::DType::F32), DType::Flex32 => Ok(candle_core::DType::F32), DType::F16 => Ok(candle_core::DType::F16), DType::BF16 => Ok(candle_core::DType::BF16), DType::I64 => Ok(candle_core::DType::I64), DType::U32 => Ok(candle_core::DType::U32), DType::U8 => Ok(candle_core::DType::U8), DType::I16 => Ok(candle_core::DType::I16), DType::I32 => Ok(candle_core::DType::I32), DType::Bool(BoolStore::U8) => Ok(candle_core::DType::U8), _ => Err(candle_core::Error::Msg(format!( "Unsupported dtype {self:?}" ))), } } } ================================================ FILE: crates/burn-collective/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Backend extension for collective calculations." edition.workspace = true keywords = ["deep-learning", "machine-learning", "collective"] license.workspace = true name = "burn-collective" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-collective" documentation = "https://docs.rs/burn-collective" version.workspace = true [lints] workspace = true [features] default = [] doc = [] tracing = [ "dep:tracing", "burn-std/tracing", "burn-tensor/tracing", "burn-communication/tracing", "burn-ndarray?/tracing", "burn-wgpu?/tracing", "burn-cuda?/tracing", ] orchestrator = ["burn-communication/websocket"] # Backends for testing test-ndarray = ["burn-ndarray"] test-wgpu = ["burn-wgpu", "burn-wgpu/webgpu"] test-metal = ["burn-wgpu", "burn-wgpu/metal"] test-vulkan = ["burn-wgpu", "burn-wgpu/vulkan"] test-cuda = ["burn-cuda"] [dependencies] burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = true } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true } log = { workspace = true } burn-communication = { path = "../burn-communication", version = "=0.21.0-pre.2", features = [ "data-service", "websocket", ] } tokio = { workspace = true, features = [ "rt-multi-thread", "sync", "signal", "time", "tracing", ] } serde = { workspace = true, features = ["derive"] } rmp-serde = { workspace = true } bytes = { workspace = true } futures = { workspace = true } tokio-util = { workspace = true } tracing = { workspace = true, optional = true } # Tests burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true } burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true } [dev-dependencies] serial_test = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-collective/README.md ================================================ # burn-collective Collective operations on tensors The following collective operation are implemented: - `all-reduce` Aggregates a tensor between all peers, and distributes the result to all peers. Different strategies can be used on the local and global levels. The result can only be returned when all peers have called the all-reduce. - `reduce` Aggregates a tensor from all peers onto one peer, called the "root" - `broadcast` Copies a tensor from one peer to all other peers in the collective. Peers must call `register` before calling any other operation. The total number of devices on the node, or nodes in the collective, must be known ahead of time. In many libraries like NCCL and PyTorch, participating units are called "ranks". This name is confusing in the context of tensors, so in burn-collective the participating units are called "peers". *`reduce` and `broadcast` are not yet implemented for multi-node contexts* ## Local and Global Internally, there are two levels to the collective operations: local and global. Operations are done on the local level, then optionally on the global level. | Local | Global | |-----------------------------------------------|-----------------------------------------------| | Intra-node (typically within one machine) | Inter-node (typically across machies) | | Participants are threads (one per peer/GPU) | Participants are processes (one per node) | | Communication depends on backend | Network peer-to-peer communication | | Local server is launched automatically | Global coordinator must be launched manually | | Local server does the aggregation | Nodes do the operations themselves | For global operations (ie. with multiple nodes), there must be a global orchestrator available. Start one easily with `burn_collective::start_global_orchestrator()`. On the global level, nodes use the `burn_communication::data_service::TensorDataService` to expose and download tensors in a peer-to-peer manner, in order to be independent. ## Components The following are the important pieces of the collective operations system. | Term | One per... | Meaning |--------------------------------|---------------|---------------------------------------------------------- | Local Collective Client | Peer/thread | Requests operations to the Local Collective Server | Local Collective Server | Node/process | Does local-level ops for threads in this process. In the case of global operations, passes operations on to the Global Collective Client. | Global Collective Client | Node/process | Does global-level ops for this node. Registers and requests strategies from the Global Collective Orchestrator. | Global Collective Orchestrator | Collective | Responds to the Global Collective Client from each node. Responsible for aggregation strategies. ## Strategies Different strategies can be used on the local and global level. ### Centralized An arbitrary peer is designated as the "root", and all others are transferred to the root's device. The operation is done on that device. The resulting tensor then sent to each peer. ### Tree Tensors in groups of N are aggregated together. This is done recursively until only one tensor remains. The strategy tries to put devices of the same type closer in the tree. When N=2, this is like a binary tree reduce. The resulting tensor then sent to each peer ### Ring See this good explanation: The tensors are sliced into N parts, where N is the number of tensors to aggregate. Then, the slices are sent around in a series of cycles and aggregated until every tensor's slices is a sum of the other corresponding slices. In the case where the tensors are too small to split into N slices, a fallback algorithm is used. For now, the fallback is a binary tree. (p=3, n=3) o->o o o o->o o o o-> o 1->o o o 1-> 1->o o o 1 2-> 2->o 1 1 2->o 3 1 2 2 3 1 1 2 3 (This is essentially a reduce-scatter) 3->x x x 3->x x x 3-> 3 3->x x 3 3-> 3->x 3 3 3 3-> 3->3 3 3 3->3 3 3 3 3 3 3 3 3 3 (This is essentially an all-gather) This is done so that every peer is both sending and receiving data at any moment. This is an important part of this strategy's advantages. The ring strategy takes full advantage of the bandwidth available. The latency scales with the number of peers. So when the tensors are very small, or when the number of peers is very large, the latency is more important in the ring strategy, and a tree algorithm is better. Otherwise, the ring algorithm is the better. In multi-node contexts, use of the Ring strategy in the local level may be less advantageous. With multiple nodes, the global all-reduce step is enabled, and its result is redistributed to all devices. The Ring strategy inherently distributes the result, which in this context would not be necessary. It is recommended to use the Ring strategy at the global level ### Double binary tree ================================================ FILE: crates/burn-collective/multinode-tests/Cargo.toml ================================================ [package] name = "burn-collective-multinode-tests" version.workspace = true edition.workspace = true license.workspace = true [features] default = ["ndarray"] ndarray = ["burn/ndarray"] [dependencies] burn = { path = "../../burn", default-features = false, features = ["std"] } burn-std = { path = "../../burn-std", default-features = false } burn-collective = { path = "..", features = ["orchestrator"] } burn-communication = { path = "../../burn-communication" } tokio = { workspace = true, features = ["rt-multi-thread", "process"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } interprocess = "2.3.1" rmp-serde = { workspace = true } tokio-util = { workspace = true, features = ["codec"] } tokio-serde = { version = "0.9.0", features = ["messagepack"] } futures = { workspace = true } [[bin]] name = "global" path = "src/bin/global.rs" [[bin]] name = "node" path = "src/bin/node.rs" ================================================ FILE: crates/burn-collective/multinode-tests/README.md ================================================ # Integration test for burn collective operations with multiple nodes and devices. Run `cargo run --bin test_launcher` There are 3 binaries: ## node.rs Launches `n` threads each simulating a different device. Currently the backend is NdArray, so everything is CPU. The program takes a file with configurations and input data. ## global.rs Runs the global orchestrator, who is responsible for responding to global collective operation requests. In the case of an all-reduce, the orchestrator responds with a strategy for reducing, and the node can do the reduction independently. ## test_launcher.rs Generates input data, calculates the expected results, and launches the nodes each with their own inputs in a separate file. The topology is [4, 4, 4, 4]. This means 4 nodes are launched, each with 4 threads (for each device). The global orchestrator (`global.rs`) is also launched. ## Output The outputs and inputs for each node and the orchestrator are written to the `target/test_files` folder If the nodes or orchestrator stall, there is a timeout. ================================================ FILE: crates/burn-collective/multinode-tests/src/bin/global.rs ================================================ //! Global orchestrator //! //! Launches the orchestrator that responds to global collective operations for nodes for the //! integration test //! //! This is necessary for any node who needs global collective operations use std::env; #[tokio::main] /// Start the global orchestrator on the port given as first arg pub async fn main() { let args: Vec = env::args().collect(); let port = args[1].parse::().expect("invalid port"); // Launch the global orchestrator, which will listen and respond to global collective op // requests from nodes burn_collective::start_global_orchestrator(port).await; } ================================================ FILE: crates/burn-collective/multinode-tests/src/bin/node.rs ================================================ use burn::{ backend::NdArray, prelude::Backend, tensor::{Tensor, TensorPrimitive, Tolerance}, }; use burn_collective::{ CollectiveConfig, PeerId, ReduceOperation, all_reduce, finish_collective, register, reset_collective, }; use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK}; use std::{ env, sync::mpsc::SyncSender, time::{Duration, Instant}, }; use tokio::net::TcpStream; use futures::{SinkExt, StreamExt}; use std::thread::JoinHandle; use tokio_serde::formats::MessagePack; use tokio_util::codec::LengthDelimitedCodec; type TestBackend = NdArray; /// Framed TCP connection channel type TestChannel = tokio_serde::Framed< tokio_util::codec::Framed, NodeTest, NodeTestResult, MessagePack, >; /// Start a node that will test all-reduce /// Args are the following: /// - launcher endpoint #[tokio::main] pub async fn main() { let args: Vec = env::args().collect(); let launcher_addr = args[1].clone(); let socket = TcpStream::connect(launcher_addr).await.unwrap(); let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new()); let mut socket: TestChannel = tokio_serde::Framed::new( length_delimited, MessagePack::::default(), ); // Loop: receive, do test, send result while let Some(Ok(test)) = socket.next().await { println!("Received test: {test:?}"); let result = run_test::(&test); // send the result back socket.send(result).await.expect("failed to send Result"); } println!("Server closed connection; exiting."); } /// Runs a test for one node fn run_test(test_input: &NodeTest) -> NodeTestResult { reset_collective::(); // Channel for results let (result_send, result_recv) = std::sync::mpsc::sync_channel(32); // Launch a thread for each "device" let handles = launch_threads::(test_input.clone(), result_send); // Receive results let mut durations = vec![]; let tol: Tolerance = Tolerance::balanced(); for _ in 0..test_input.device_count { // Assert all results are equal to each other as well as expected result let (tensor, duration) = result_recv.recv().unwrap(); test_input.expected.assert_approx_eq(&tensor.to_data(), tol); durations.push(duration); } // Threads finish for handle in handles { let _ = handle.join(); } NodeTestResult { success: true, durations, } } /// Launch a thread for each device, and run the all-reduce fn launch_threads( test_input: NodeTest, result_send: SyncSender<(Tensor, Duration)>, ) -> Vec> { let mut handles = vec![]; for id in 0..test_input.device_count { // Launch a thread to test // Put all the parameters in the config let config = CollectiveConfig::default() .with_num_devices(test_input.device_count) .with_global_address(test_input.global_address.clone()) .with_node_address(test_input.node_address.clone()) .with_data_service_port(test_input.data_service_port) .with_num_nodes(test_input.node_count) .with_global_all_reduce_strategy(test_input.global_strategy) .with_local_all_reduce_strategy(test_input.local_strategy); // Inputs and outputs for the test let tensor_data = test_input.inputs[id].clone(); let tensor = Tensor::::from_data(tensor_data, &B::Device::default()); let result_send = result_send.clone(); let handle = std::thread::spawn(move || { run_peer::( id.into(), config, tensor, result_send, test_input.all_reduce_op, ) }); handles.push(handle); } handles } /// Runs a thread in the all-reduce test. pub fn run_peer( id: PeerId, config: CollectiveConfig, input: Tensor, output: SyncSender<(Tensor, Duration)>, all_reduce_op: ReduceOperation, ) { // Register the device register::(id, input.device(), config).unwrap(); let start = Instant::now(); // All-reduce let input = input.into_primitive().tensor(); let tensor = all_reduce::(id, input, all_reduce_op).unwrap(); let tensor = Tensor::::from_primitive(TensorPrimitive::Float(tensor)); let duration = start.elapsed(); // Send result output.send((tensor, duration)).unwrap(); finish_collective::(id).unwrap(); } ================================================ FILE: crates/burn-collective/multinode-tests/src/bin/test_launcher.rs ================================================ use burn::tensor::TensorData; use burn_communication::Address; use futures::{SinkExt, StreamExt}; use std::{ fmt::Display, fs::{self, File}, str::FromStr, time::{Duration, Instant}, vec, }; use tokio::net::TcpListener; use tokio_serde::formats::MessagePack; use tokio_util::codec::LengthDelimitedCodec; use burn::{backend::NdArray, prelude::Backend, tensor::Tensor}; use burn_collective::{AllReduceStrategy, ReduceOperation}; use burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK}; use burn_std::rand::{SeedableRng, StdRng}; use tokio::process::{Child, Command}; #[derive(Clone)] struct AllReduceTest { shape: [usize; TENSOR_RANK], op: ReduceOperation, local_strategy: AllReduceStrategy, global_strategy: AllReduceStrategy, } impl Display for AllReduceTest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let op_str = match self.op { ReduceOperation::Sum => "sum", ReduceOperation::Mean => "mean", }; let local_strategy_str = match self.local_strategy { AllReduceStrategy::Centralized => "local_centralized", AllReduceStrategy::Tree(n) => &format!("local_tree_{n}"), AllReduceStrategy::Ring => "local_ring", }; let global_strategy_str = match self.global_strategy { AllReduceStrategy::Centralized => "global_centralized", AllReduceStrategy::Tree(n) => &format!("global_tree_{n}"), AllReduceStrategy::Ring => "global_ring", }; write!(f, "{op_str}_{local_strategy_str}_{global_strategy_str}") } } /// Framed TCP connection for sending tests and receiving results type TestChannel = tokio_serde::Framed< tokio_util::codec::Framed, NodeTestResult, NodeTest, MessagePack, >; /// Handle for each node process struct NodeProcessHandle { process: Child, channel: TestChannel, } /// Main function to run the multi-node all-reduce test. /// Launches a orchestrator and multiple nodes based on the provided topology. #[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() { let all_reduce_tests = vec![ AllReduceTest { shape: [4, 64, 512], op: ReduceOperation::Mean, local_strategy: AllReduceStrategy::Tree(2), global_strategy: AllReduceStrategy::Tree(2), }, AllReduceTest { shape: [4, 64, 512], op: ReduceOperation::Mean, local_strategy: AllReduceStrategy::Tree(2), global_strategy: AllReduceStrategy::Ring, }, AllReduceTest { shape: [4, 64, 512], op: ReduceOperation::Mean, local_strategy: AllReduceStrategy::Centralized, global_strategy: AllReduceStrategy::Centralized, }, ]; let test_files_dir = "target/test_files"; fs::create_dir_all(test_files_dir).expect("Couldn't create test_files directory"); let topology: Vec = vec![4; 4]; let mut orchestrator = launch_orchestrator(test_files_dir); let launcher_endpoint = "127.0.0.1:4000"; // Build and run node processes let mut all_tests_durations = vec![]; if let Ok(mut nodes) = launch_nodes(&topology, launcher_endpoint).await { // Run one test for test in all_reduce_tests.clone() { let test_name = test.to_string(); let time = test_all_reduce_centralized_no_collective::(&topology, test.clone()); println!( "{test_name}: Benchmark (no collective, centralized, single-threaded): {} secs", time.as_secs_f32() ); match test_all_reduce(&topology, test, &mut nodes).await { Err(node_idx) => { println!("{test_name}: Node with index {node_idx} failed!"); // Kill other node processes for mut node in nodes.drain(..) { node.process.kill().await.unwrap(); node.process.wait().await.unwrap(); } break; } Ok(durations) => { all_tests_durations.append(&mut durations.clone()); let avg = durations.iter().map(|dur| dur.as_secs_f32()).sum::() / durations.len() as f32; println!("{test_name}: Success in {avg} secs"); } } } } if !all_tests_durations.is_empty() { let avg = all_tests_durations .iter() .map(|dur| dur.as_secs_f32()) .sum::() / all_tests_durations.len() as f32; println!("Average for all tests: {avg} secs"); } // Shutdown orchestrator orchestrator.kill().await.unwrap(); orchestrator.wait().await.unwrap(); } /// Launch a global orchestrator with an output file in the given directory. /// Necessary for global collective operations /// /// Server listens on localhost port 3000 fn launch_orchestrator(test_files_dir: &str) -> Child { let out_path = format!("{test_files_dir}/orchestrator_out.txt"); let out = File::create(out_path).expect("Could't create orchestrator output file"); Command::new("cargo") .args(["run", "--bin", "global", "--", "3000"]) .stdout(out.try_clone().unwrap()) .stderr(out) .spawn() .expect("failed to launch orchestrator") } /// Launch nodes for a all_reduce test /// Each node will connect to the global orchestrator and run an all-reduce operation. /// The topology is a vector where each element represents the number of devices in that node. async fn launch_nodes( topology: &[usize], launcher_endpoint: &str, ) -> Result, ()> { println!( "Launching {} nodes with topology: {:?}", topology.len(), topology ); // Listen for node connections let listener = TcpListener::bind(launcher_endpoint).await.unwrap(); println!("Server listening on {launcher_endpoint}"); let mut nodes = vec![]; for node_idx in 0..topology.len() { // Create log file let output_filename = format!("target/test_files/node_{}_log.txt", node_idx + 1); let out = File::create(output_filename).expect("Could't open node log file"); // Start a process for each node. Pass on our feature flags let node_process: Child = Command::new("cargo") .args([ "run", "--release", "--features", #[cfg(feature = "ndarray")] "ndarray", "--bin", "node", "--", launcher_endpoint, &node_idx.to_string(), ]) .stdout(out.try_clone().unwrap()) .stderr(out) .spawn() .expect("node failed"); // Wait for child to connect for io let (socket, _peer_addr) = listener.accept().await.unwrap(); let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new()); let channel: TestChannel = tokio_serde::Framed::new( length_delimited, MessagePack::::default(), ); nodes.push(NodeProcessHandle { process: node_process, channel, }); } Ok(nodes) } async fn test_all_reduce( topology: &[usize], test: AllReduceTest, nodes: &mut [NodeProcessHandle], ) -> Result, usize> { dispatch_all_reduce_test(topology, test, nodes).await; let mut all_durations = vec![]; for (idx, handle) in nodes.iter_mut().enumerate() { match handle.channel.next().await { Some(Ok(mut result)) => { if !result.success { return Err(idx); } all_durations.append(&mut result.durations); } _ => { return Err(idx); } } } Ok(all_durations) } async fn dispatch_all_reduce_test( topology: &[usize], test: AllReduceTest, nodes: &mut [NodeProcessHandle], ) { let total_device_count: usize = topology.iter().sum(); let (mut all_inputs, expected) = generate_random_input(test.shape, test.op, total_device_count, 42); // URL for the global orchestrator on port 3000 let global_url = "ws://localhost:3000"; let global_address = Address::from_str(global_url).unwrap(); for (node_idx, &device_count) in topology.iter().enumerate() { // Construct URL for node // Ports 3001... are for each node let data_service_port = node_idx as u16 + 3001; let node_url = format!("ws://localhost:{data_service_port}"); let node_address = Address::from_str(&node_url).unwrap(); // take input tensors for each device let inputs = all_inputs[0..device_count].to_vec(); all_inputs = all_inputs[device_count..].to_vec(); let test = NodeTest { device_count, node_id: node_idx.into(), node_count: topology.len() as u32, global_address: global_address.clone(), node_address, data_service_port, all_reduce_op: test.op, global_strategy: test.global_strategy, local_strategy: test.local_strategy, inputs, expected: expected.clone(), }; let handle = &mut nodes[node_idx]; handle.channel.send(test).await.unwrap(); } assert!( all_inputs.is_empty(), "Not all inputs have been sent to tests" ); } /// Run the test sequentially with no collective operations to get the optimal single-threaded speed fn test_all_reduce_centralized_no_collective( topology: &[usize], test: AllReduceTest, ) -> Duration { let total_device_count: usize = topology.iter().sum(); let (all_inputs, _expected) = generate_random_input(test.shape, test.op, total_device_count, 42); let mut all_inputs = all_inputs .into_iter() .map(|data| Tensor::::from_data(data, &B::Device::default())); let start = Instant::now(); // Sequential test let mut result = all_inputs.next().unwrap(); for other in all_inputs { result = result.add(other); } if test.op == ReduceOperation::Mean { result.div_scalar(total_device_count as u32); } start.elapsed() } /// Generates random input tensors and expected output based on the provided shape and reduce kind. fn generate_random_input( shape: [usize; 3], reduce_kind: ReduceOperation, input_count: usize, seed: u64, ) -> (Vec, TensorData) { let mut rng = StdRng::seed_from_u64(seed); // A random tensor for each device let input: Vec = (0..input_count) .map(|_| { TensorData::random::(shape, burn::tensor::Distribution::Default, &mut rng) }) .collect(); // Sum up the inputs let device = ::Device::default(); let mut expected_tensor = Tensor::::zeros(shape, &device); for item in input.iter().take(input_count) { let input_tensor = Tensor::::from_data(item.clone(), &device); expected_tensor = expected_tensor.add(input_tensor); } if reduce_kind == ReduceOperation::Mean { expected_tensor = expected_tensor.div_scalar(input_count as u32); } // All-Reduce results should have this value let expected = expected_tensor.to_data(); (input, expected) } ================================================ FILE: crates/burn-collective/multinode-tests/src/lib.rs ================================================ pub mod shared; ================================================ FILE: crates/burn-collective/multinode-tests/src/shared.rs ================================================ use std::time::Duration; use burn::tensor::TensorData; use burn_collective::{AllReduceStrategy, NodeId, ReduceOperation}; use burn_communication::Address; use serde::{Deserialize, Serialize}; /// Ranks of inputs and outputs for all testing pub const TENSOR_RANK: usize = 3; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeTest { /// How many threads to start on this node pub device_count: usize, /// ID for this node pub node_id: NodeId, /// How many nodes in the cluster pub node_count: u32, /// Global server address pub global_address: Address, /// Node address pub node_address: Address, /// Node's data service port, for initializing the p2p tensor data service pub data_service_port: u16, /// What kind of all-reduce pub all_reduce_op: ReduceOperation, /// Node's data service port, for initializing the p2p tensor data service pub global_strategy: AllReduceStrategy, /// What kind of aggregation pub local_strategy: AllReduceStrategy, /// Input data for test: all tensors are D=3 pub inputs: Vec, /// Expected output for test pub expected: TensorData, } /// Result sent back from each node for each test #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeTestResult { pub success: bool, pub durations: Vec, } ================================================ FILE: crates/burn-collective/src/api.rs ================================================ use burn_tensor::backend::Backend; use crate::{ CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError, local::server::get_collective_client, }; /// Errors from collective operations #[allow(unused)] #[derive(Debug, Clone)] pub enum CollectiveError { /// The [config](CollectiveConfig) was invalid. /// Usually happens if only some global parameters have been defined InvalidConfig, /// Cannot un-register a node twice MultipleUnregister, /// Cannot register a node twice MultipleRegister, /// Trying to register a different way than is currently being done RegisterParamsMismatch, /// Trying to all-reduce tensors of different shapes: shape must match AllReduceShapeMismatch, /// Trying to all-reduce a different way than is currently being done: op must match AllReduceOperationMismatch, /// Trying to reduce tensors of different shapes: shape must match ReduceShapeMismatch, /// Trying to reduce a different way than is currently being done: op must match ReduceOperationMismatch, /// Trying to reduce with different roots ReduceRootMismatch, /// Trying to broadcast with different roots BroadcastRootMismatch, /// Trying to broadcast but no peer sent a tensor BroadcastNoTensor, /// Trying to broadcast but multiple peers sent a tensor BroadcastMultipleTensors, /// Local collective server couldn't respond LocalServerMissing, /// Another operation was called before Register RegisterNotFirstOperation, /// The global orchestrator had an error Global(GlobalCollectiveError), #[allow(unused)] Other(String), } /// Registers a device. `num_devices` must be the same for every register, /// and `device_id` must be unique. /// /// * `id` - The peer id of the caller /// /// With auto-diff backends, make sure to use the inner backend. pub fn register( id: PeerId, device: B::Device, config: CollectiveConfig, ) -> Result<(), CollectiveError> { log::info!("Registering peer {id} with config: {config}"); let mut client = get_collective_client::(); client.register(id, device, config) } /// Calls for an all-reduce operation with the given parameters, and returns the result. /// The `params` must be the same as the parameters passed by the other nodes. /// /// * `id` - The peer id of the caller /// * `tensor` - The input tensor to reduce with the peers' tensors /// * `config` - Config of the collective operation, must be coherent with the other calls pub fn all_reduce( id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, ) -> Result { let client = get_collective_client::(); client.all_reduce(id, tensor, op) } /// Broadcasts, or receives a broadcasted tensor. /// /// * `id` - The peer id of the caller /// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive /// the broadcasted tensor. /// /// Returns the broadcasted tensor. pub fn broadcast( id: PeerId, tensor: Option, ) -> Result { let client = get_collective_client::(); client.broadcast(id, tensor) } /// Reduces a tensor onto one device. /// /// * `id` - The peer id of the caller /// * `tensor` - The tensor to send as input /// * `root` - The ID of the peer that will receive the result. /// /// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor. pub fn reduce( id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, root: PeerId, ) -> Result, CollectiveError> { let client = get_collective_client::(); client.reduce(id, tensor, op, root) } /// Closes the collective session, unregistering the device pub fn finish_collective(id: PeerId) -> Result<(), CollectiveError> { let client = get_collective_client::(); client.finish(id) } /// Resets the local collective server. All registered callers and ongoing operations are forgotten pub fn reset_collective() { let client = get_collective_client::(); client.reset(); } ================================================ FILE: crates/burn-collective/src/config.rs ================================================ use std::fmt::Display; use burn_communication::Address; use serde::{Deserialize, Serialize}; /// Parameter struct for setting up and getting parameters for collective operations. /// Used in most collective api calls. /// This config is per-node. It is passed to [reduce](crate::register). #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CollectiveConfig { pub(crate) num_devices: usize, pub(crate) local_all_reduce_strategy: AllReduceStrategy, pub(crate) local_reduce_strategy: ReduceStrategy, pub(crate) local_broadcast_strategy: BroadcastStrategy, // Global parameters (all are optional, but if one is defined they should all be) pub(crate) num_nodes: Option, pub(crate) global_address: Option
, pub(crate) node_address: Option
, pub(crate) data_service_port: Option, // These strategies may be defined when no other global params are defined pub(crate) global_all_reduce_strategy: Option, pub(crate) global_reduce_strategy: Option, pub(crate) global_broadcast_strategy: Option, } impl Default for CollectiveConfig { fn default() -> Self { Self::new() } } impl Display for CollectiveConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let num_devices = self.num_devices; let local_all_reduce_strategy = self.local_all_reduce_strategy; let local_reduce_strategy = self.local_reduce_strategy; let local_broadcast_strategy = self.local_broadcast_strategy; let num_nodes = self.num_nodes; let global_address = &self.global_address; let node_address = &self.node_address; let data_service_port = self.data_service_port; let global_all_reduce_strategy = self.global_all_reduce_strategy; let global_reduce_strategy = self.global_reduce_strategy; let global_broadcast_strategy = self.global_broadcast_strategy; write!( f, r#" CollectiveConfig {{ num_devices: {num_devices:?}, local_all_reduce_strategy: {local_all_reduce_strategy:?}, local_reduce_strategy: {local_reduce_strategy:?}, local_broadcast_strategy: {local_broadcast_strategy:?}, num_nodes: {num_nodes:?}, global_address: {global_address:?}, node_address: {node_address:?}, data_service_port: {data_service_port:?}, global_all_reduce_strategy: {global_all_reduce_strategy:?}, global_reduce_strategy: {global_reduce_strategy:?}, global_broadcast_strategy: {global_broadcast_strategy:?}, }} "# ) } } impl CollectiveConfig { fn new() -> Self { Self { num_devices: 1, local_all_reduce_strategy: AllReduceStrategy::Tree(2), local_reduce_strategy: ReduceStrategy::Tree(2), local_broadcast_strategy: BroadcastStrategy::Tree(2), num_nodes: None, global_address: None, node_address: None, data_service_port: None, global_all_reduce_strategy: Some(AllReduceStrategy::Ring), global_reduce_strategy: Some(ReduceStrategy::Tree(2)), global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)), } } /// Selects the number of devices (local peers) on the current node pub fn with_num_devices(mut self, num: usize) -> Self { self.num_devices = num; self } /// Selects an all-reduce strategy to use on the local level. /// /// In multi-node contexts, use of the Ring strategy in the local level may be less /// advantageous. With multiple nodes, the global all-reduce step is enabled, and its result /// is redistributed to all devices. /// The Ring strategy inherently distributes the result, which in this context would not be /// necessary. /// /// It is recommended to use a tree strategy locally, and a ring strategy globally. pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self { self.local_all_reduce_strategy = strategy; self } /// Selects a reduce strategy to use on the local level. pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self { self.local_reduce_strategy = strategy; self } /// Selects a broadcast strategy to use on the local level. pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self { self.local_broadcast_strategy = strategy; self } /// Set the number of nodes in the collective /// /// This parameter is a global parameter and should only be set in multi-node contexts pub fn with_num_nodes(mut self, n: u32) -> Self { self.num_nodes = Some(n); self } /// Set the network address of the Global Collective Orchestrator /// /// This parameter is a global parameter and should only be set in multi-node contexts pub fn with_global_address(mut self, addr: Address) -> Self { self.global_address = Some(addr); self } /// Define the address for this node /// /// This parameter is a global parameter and should only be set in multi-node contexts pub fn with_node_address(mut self, addr: Address) -> Self { self.node_address = Some(addr); self } /// Selects the network port on which to expose the tensor data service /// used for peer-to-peer tensor downloading. /// /// This parameter is a global parameter and should only be set in multi-node contexts pub fn with_data_service_port(mut self, port: u16) -> Self { self.data_service_port = Some(port); self } /// Selects an all-reduce strategy to use on the global level. /// /// This parameter is a global parameter and should only be set in multi-node contexts. /// See [the local strategy](Self::with_local_all_reduce_strategy) pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self { self.global_all_reduce_strategy = Some(strategy); self } /// Selects an reduce strategy to use on the global level. /// /// This parameter is a global parameter and should only be set in multi-node contexts. /// See [the local strategy](Self::with_local_reduce_strategy) pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self { self.global_reduce_strategy = Some(strategy); self } /// Selects an broadcst strategy to use on the global level. /// /// This parameter is a global parameter and should only be set in multi-node contexts. /// See [the local strategy](Self::with_local_broadcast_strategy) pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self { self.global_broadcast_strategy = Some(strategy); self } /// Returns whether the config is valid. If only some required global-level parameters are /// defined and others are not, the config is invalid. pub fn is_valid(&self) -> bool { match ( self.num_nodes, &self.global_address, &self.node_address, self.data_service_port, ) { (None, None, None, None) => true, (Some(_), Some(_), Some(_), Some(_)) => true, // Global parameters have only been partially defined! _ => false, } } /// Return the global parameters for registering in a multi-node context. /// /// If only some global parameters are defined, returns None. Use [is_valid](Self::is_valid) to check for /// validity in this case. pub(crate) fn global_register_params(&self) -> Option { match ( self.num_nodes, &self.global_address, &self.node_address, self.data_service_port, ) { // Only local collective (None, None, None, None) => None, // Local + global collective (Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => { Some(GlobalRegisterParams { num_nodes, global_address: global_addr.clone(), node_address: node_addr.clone(), data_service_port, }) } // Config is invalid! _ => None, } } } /// Helper struct for parameters in a multi-node register operation. Either they are all defined, /// or all not defined. Passed to the global client for registering on the global level and /// opening the p2p tensor service. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GlobalRegisterParams { /// The address for the connection to the global orchestrator. pub global_address: Address, /// The address for the connection to this node. pub node_address: Address, /// The port on which to open the tensor data service for peer-to-peer tensor transfers with /// other nodes. Should match the port given in the node url. pub data_service_port: u16, /// The number of nodes globally. Should be the same between different nodes pub num_nodes: u32, } /// Parameters for an all-reduce that should be the same between all devices #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct SharedAllReduceParams { pub op: ReduceOperation, pub local_strategy: AllReduceStrategy, pub global_strategy: Option, } /// Parameters for a reduce that should be the same between all devices #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct SharedReduceParams {} /// Parameters for a broadcast that should be the same between all devices #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub struct SharedBroadcastParams { pub op: ReduceOperation, pub local_strategy: BroadcastStrategy, pub global_strategy: Option, } /// Reduce can be done different ways #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub enum ReduceOperation { Sum, Mean, } /// All reduce can be implemented with different algorithms, which all have the same result. #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum AllReduceStrategy { /// One device is the "central". The other devices, "peripherals", send their tensors to the /// central. The central does the reduction, and sends the result back to each peripheral. Centralized, /// Devices are organized in a tree structure (with a given arity). Each node reduces its /// children's tensors with its own, and sends the result to its parent. Leaf nodes will /// simply send their tensors to their parents. /// When the root node calculates the result, it is propagated down the tree. Tree(u32), /// Devices are organized in a ring. The tensors are split into N slices, where N is the /// number of devices participating. The slices are progressively sent around the ring until /// every device has one fully reduced slice of the tensor. Then, the resulting slices are sent /// around until every device has the full result. /// See `ring.rs` for details. Ring, } /// Reduce can be implemented with different algorithms, which all have the same result. #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum ReduceStrategy { /// See [all-reduce](AllReduceStrategy::Centralized) Centralized, /// See [all-reduce](AllReduceStrategy::Tree) Tree(u32), } /// Broadcast can be implemented with different algorithms, which all have the same result. #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum BroadcastStrategy { /// See [all-reduce](AllReduceStrategy::Centralized) Centralized, /// See [all-reduce](AllReduceStrategy::Tree) Tree(u32), } /// A unique identifier for a peer in the context of collective operations. /// They must be unique, even in multi-node contexts. /// /// This is like the rank in NCCL #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct PeerId(u32); impl Display for PeerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "PeerId({})", self.0) } } impl From for PeerId { fn from(value: u32) -> Self { Self(value) } } impl From for PeerId { fn from(value: i32) -> Self { Self(value as u32) } } impl From for PeerId { fn from(value: usize) -> Self { Self(value as u32) } } ================================================ FILE: crates/burn-collective/src/global/base.rs ================================================ use serde::{Deserialize, Serialize}; /// Unique identifier for any node in the global collective. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct NodeId(u32); impl From for NodeId { fn from(value: u32) -> Self { Self(value) } } impl From for NodeId { fn from(value: usize) -> Self { Self(value as u32) } } impl From for NodeId { fn from(value: i32) -> Self { Self(value as u32) } } ================================================ FILE: crates/burn-collective/src/global/mod.rs ================================================ pub(crate) mod node; pub(crate) mod shared; #[cfg(feature = "orchestrator")] pub mod orchestrator; #[cfg(feature = "orchestrator")] pub use orchestrator::*; mod base; pub use base::*; ================================================ FILE: crates/burn-collective/src/global/node/base.rs ================================================ use burn_communication::Protocol; use burn_communication::data_service::TensorDataServer; use burn_communication::{Address, ProtocolServer, data_service::TensorDataService}; use burn_tensor::backend::Backend; use std::collections::HashMap; use std::{marker::PhantomData, sync::Arc}; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use crate::node::sync::SyncService; use crate::{ AllReduceStrategy, BroadcastStrategy, GlobalRegisterParams, NodeId, PeerId, ReduceStrategy, }; use crate::{ ReduceOperation, global::{ node::{ centralized::centralized_all_reduce_sum, ring::ring_all_reduce_sum, tree::tree_all_reduce_sum, worker::GlobalClientWorker, }, shared::{GlobalCollectiveError, RemoteRequest, RemoteResponse}, }, local::server::get_collective_server_runtime, }; /// Must be synchronized between all nodes for collective operations to work pub(crate) struct NodeState { pub node_id: NodeId, pub nodes: HashMap, pub num_global_devices: u32, } /// A node talks to the global orchestrator as well as other nodes with a peer-to-peer service pub(crate) struct Node where B: Backend, P: Protocol, { // State is written during `register` and read during other operations, // sometimes by multiple threads (ex. syncing during an all-reduce) state: Arc>>, data_service: Arc>, sync_service: Arc>, worker: GlobalClientWorker, _n: PhantomData

, } impl Node where B: Backend, P: Protocol, { pub fn new(global_address: &Address, comms_server: P::Server) -> Self { let state = Arc::new(tokio::sync::RwLock::new(None)); let cancel_token = CancellationToken::new(); let data_service = Arc::new(TensorDataService::new(cancel_token.clone())); let sync_service = Arc::new(SyncService::new(state.clone())); let runtime = get_collective_server_runtime(); let server = comms_server .route_tensor_data_service(data_service.clone()) .route("/sync", { let sync_service = sync_service.clone(); async move |channel: ::Channel| { sync_service.handle_sync_connection(channel).await; } }) .serve({ let cancel_token = cancel_token.clone(); async move { cancel_token.cancelled().await } }); runtime.spawn(server); let worker = GlobalClientWorker::new(&runtime, cancel_token.clone(), global_address); Self { state, data_service, sync_service, worker, _n: PhantomData, } } pub async fn register( &mut self, peers: Vec, global_params: GlobalRegisterParams, ) -> Result<(), GlobalCollectiveError> { let req = RemoteRequest::Register { node_addr: global_params.node_address, num_nodes: global_params.num_nodes, peers, }; match self.worker.request(req).await { RemoteResponse::Register { node_id, nodes, num_global_devices, } => { let mut state = self.state.write().await; *state = Some(NodeState { node_id, nodes, num_global_devices, }); } RemoteResponse::Error(err) => { return Err(err); } resp => { log::error!("Response to a register request should be an ack, not {resp:?}"); return Err(GlobalCollectiveError::WrongOrchestratorResponse); } } Ok(()) } /// Performs an all-reduce /// /// Reads the NodeState pub async fn all_reduce( &self, tensor: B::FloatTensorPrimitive, strategy: AllReduceStrategy, op: ReduceOperation, ) -> Result { let state = self.state.read().await; let Some(ref state) = *state else { return Err(GlobalCollectiveError::AllReduceBeforeRegister); }; let node = state.node_id; let nodes = &state.nodes; let mut result = match strategy { AllReduceStrategy::Centralized => { centralized_all_reduce_sum( node, nodes, &self.data_service, self.sync_service.clone(), tensor, ) .await? } AllReduceStrategy::Tree(arity) => { tree_all_reduce_sum( node, nodes, self.data_service.clone(), self.sync_service.clone(), tensor, arity, ) .await? } AllReduceStrategy::Ring => { ring_all_reduce_sum( node, nodes, self.data_service.clone(), self.sync_service.clone(), tensor, ) .await? } }; if op == ReduceOperation::Mean { result = B::float_div_scalar(result, (state.num_global_devices as f32).into()); } Ok(result) } pub async fn reduce( &self, _tensor: B::FloatTensorPrimitive, _strategy: ReduceStrategy, _root: PeerId, _op: ReduceOperation, ) -> Result, GlobalCollectiveError> { unimplemented!("Global reduce unimplemented"); } pub async fn broadcast( &self, _tensor: Option, _strategy: BroadcastStrategy, ) -> Result { unimplemented!("Global broadcast unimplemented"); } pub async fn finish(&mut self) { let res = self.worker.close_connection().await; if let Err(err) = res { log::error!("Global collective client error: {err:?}"); } self.data_service.close().await; } } ================================================ FILE: crates/burn-collective/src/global/node/centralized.rs ================================================ use std::{collections::HashMap, sync::Arc}; use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService}; use burn_communication::data_service::TensorDataService; use burn_communication::{Address, Protocol}; use burn_tensor::TensorMetadata; use burn_tensor::backend::Backend; use futures::StreamExt; use futures::stream::FuturesUnordered; /// Global all-reduce, using a centralized strategy. /// /// Returns the resulting tensor on the same device as the input tensor pub(crate) async fn centralized_all_reduce_sum( node: NodeId, nodes: &HashMap, data_service: &Arc>, sync_service: Arc>, tensor: B::FloatTensorPrimitive, ) -> Result where B: Backend, P: Protocol, { let ids = nodes.keys().cloned().collect::>(); let central = get_central_node(ids.clone()); let shape = tensor.shape(); let device = &B::float_device(&tensor); let res = if central == node { // Transfer 1: download tensors from other nodes let mut futures = ids .iter() .filter(|id| **id != central) // Only non-central nodes .map(|id| { let address = nodes.get(id).unwrap(); let device = device.clone(); let data_service = data_service.clone(); async move { let data = data_service .download_tensor((*address).clone(), 0.into()) .await .expect("Couldn't find the tensor for transfer id 0"); B::float_from_data(data, &device) } }) .collect::>(); // Sum all downloads async let mut sum = tensor; while let Some(res) = futures.next().await { if shape != res.shape() { return Err(GlobalCollectiveError::PeerSentIncoherentTensor); } sum = B::float_add(sum, res); } // Transfer 2: Expose result let other_nodes_count = ids.len() as u32 - 1; data_service .expose(sum.clone(), other_nodes_count, 1.into()) .await; sum } else { // Transfer 1: Expose input data_service.expose(tensor, 1, 0.into()).await; // Transfer 2: Download result let central_addr = nodes.get(¢ral).unwrap().clone(); let data = data_service .download_tensor(central_addr, 1.into()) .await .expect("Couldn't find the tensor for transfer id 1"); let res = B::float_from_data(data, device); if shape != res.shape() { return Err(GlobalCollectiveError::PeerSentIncoherentTensor); } res }; // Wait for all nodes to finish sync_service.sync().await; Ok(res) } /// Get the central node for a centralized all-reduce pub(crate) fn get_central_node(mut nodes: Vec) -> NodeId { nodes.sort(); *nodes.first().unwrap() } ================================================ FILE: crates/burn-collective/src/global/node/mod.rs ================================================ pub mod base; pub mod centralized; pub mod ring; pub mod sync; pub mod tree; pub mod worker; ================================================ FILE: crates/burn-collective/src/global/node/ring.rs ================================================ //! Implements the collective ring all-reduce algorithm on the global level use core::ops::Range; use std::{collections::HashMap, sync::Arc}; use crate::{ NodeId, global::shared::GlobalCollectiveError, local::{get_ring_reduce_slice_ranges, get_slice_dim}, node::sync::SyncService, }; use burn_communication::{Address, Protocol, data_service::TensorDataService}; use burn_tensor::{Slice, TensorMetadata, backend::Backend}; // https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model // Example: tensors=3, slices=3 // phase 1 // o->o o // o o->o //>o o o-> // o 1->o //>o o 1-> // 1->o o // o 1 2 // 2 o 1 // 1 2 o // phase 2 //>o 1 2-> // 2->o 1 // 1 2->o // 2->1 2 // 2 2->1 //>1 2 2-> // 2 2 2 // 2 2 2 // 2 2 2 /// Ring all-reduce algorithm with summation /// /// * `node` - The id of the current node /// * `nodes` - Map of all nodes in the operation /// * `data_service` - The data service handles peer-to-peer tensor transfers /// * `sync_service` - The sync service handles syncing with peers /// * `tensor` - The tensor to reduce. At least one dimension size must be greater than the number /// of nodes pub(crate) async fn ring_all_reduce_sum( node: NodeId, nodes: &HashMap, data_service: Arc>, sync_service: Arc>, tensor: B::FloatTensorPrimitive, ) -> Result where B: Backend, P: Protocol, { let shape = tensor.shape(); let device = &B::float_device(&tensor); // Slice tensors in N parts, N is node count let slice_dim = get_slice_dim(&shape); if shape[slice_dim] < nodes.len() { return Err(GlobalCollectiveError::RingReduceImpossible); } let ring = get_ring_topology(nodes.keys().cloned().collect::>()); let slice_ranges = get_ring_reduce_slice_ranges(shape[slice_dim], ring.len()); let mut slices = slice_tensor::(tensor, slice_dim, slice_ranges); let mut send_slice_idx = ring .iter() .position(|id| *id == node) .expect("Node is in ring"); let prev_node_idx = (send_slice_idx + ring.len() - 1) % ring.len(); // +ring.len for overflow let prev_node = nodes.get(&ring[prev_node_idx]).unwrap(); let mut transfer_counter: u64 = 0; // Phase 1: add do_cycles::( &mut slices, &mut transfer_counter, &mut send_slice_idx, true, prev_node.clone(), &data_service, device, ) .await?; // Phase 2: replace do_cycles::( &mut slices, &mut transfer_counter, &mut send_slice_idx, false, prev_node.clone(), &data_service, device, ) .await?; // Wait for all nodes to finish sync_service.sync().await; // merge slices Ok(B::float_cat(slices, slice_dim)) } /// Do N-1 cycles of ring-reduce /// /// * `slices` - Slices of the original tensor, len equal to node count /// * `transfer_counter` - counter for each step (one send one receive) /// * `send_slice_idx` - counter for the index of each slice to send /// * `is_phase_one` - In phase 1, the tensors are aggregated. Otherwise, they are overridden /// * `data_service` - TensorDataService for peer-to-peer tensor transfers /// * `device` - The device on which all local tensors are stored. Should match `slices` async fn do_cycles( slices: &mut [B::FloatTensorPrimitive], transfer_counter: &mut u64, send_slice_idx: &mut usize, is_phase_one: bool, prev_node: Address, data_service: &Arc>, device: &B::Device, ) -> Result<(), GlobalCollectiveError> where B: Backend, P: Protocol, { let slice_count = slices.len(); for _ in 0..(slice_count - 1) { let transfer_id = (*transfer_counter).into(); // +slice_count to avoid overflow let recv_slice_idx = (*send_slice_idx + slice_count - 1) % slice_count; let slice_send = slices[*send_slice_idx].clone(); let upload = { let data_service = data_service.clone(); tokio::spawn(async move { data_service .expose(slice_send.clone(), 1, transfer_id) .await }) }; let download = { let data_client = data_service.clone(); let next_node = prev_node.clone(); tokio::spawn(async move { data_client.download_tensor(next_node, transfer_id).await }) }; upload.await.unwrap(); let download = download.await.unwrap(); if is_phase_one { let download = download.expect("Peer closed download connection"); let tensor = B::float_from_data(download, device); slices[recv_slice_idx] = B::float_add(slices[recv_slice_idx].clone(), tensor); } else { let tensor = B::float_from_data(download.unwrap(), device); let old_shape = slices[recv_slice_idx].shape(); if old_shape != tensor.shape() { return Err(GlobalCollectiveError::PeerSentIncoherentTensor); } slices[recv_slice_idx] = tensor; } // Move slice index *send_slice_idx = recv_slice_idx; *transfer_counter += 1; } Ok(()) } /// But a tensor into even slices across a dimension /// /// * `tensor` - the tensor to slice /// * `slice_dim` - the dimension to slice across /// * `slice_ranges` - The ranges of indices on `slice_dim` to use when slicing the tensor fn slice_tensor( tensor: B::FloatTensorPrimitive, slice_dim: usize, slice_ranges: Vec>, ) -> Vec { let shape = tensor.shape(); // full range across all dims as Slice let full_range = shape .iter() .map(|dim| Slice::from(0..*dim)) .collect::>(); // Slice tensors let mut slices = vec![]; for range in &slice_ranges { let mut all_ranges = full_range.clone(); all_ranges[slice_dim] = Slice::from(range.clone()); let slice = B::float_slice(tensor.clone(), &all_ranges); slices.push(slice); } slices } /// Get the ring topology fn get_ring_topology(mut nodes: Vec) -> Vec { // This ordering could be more sophisticated, using node proximities etc nodes.sort(); nodes } ================================================ FILE: crates/burn-collective/src/global/node/sync.rs ================================================ use std::{ marker::PhantomData, sync::{Arc, Mutex}, vec, }; use burn_communication::{CommunicationChannel, Message, Protocol, ProtocolClient}; use serde::{Deserialize, Serialize}; use tokio::sync::{Notify, RwLock}; use crate::{NodeId, node::base::NodeState}; /// Handles the status of sync requests from other nodes pub(crate) struct SyncService { /// Current node's state, shared with the thread that does aggregations node_state: Arc>>, /// The number of peers that have requested to sync with us since the last successful sync. syncing_peers: Mutex>, /// Notification on each incoming sync request sync_notif: Notify, _p: PhantomData

, } #[derive(Debug, Serialize, Deserialize)] struct SyncRequest(NodeId); impl SyncService

{ pub fn new(node_state: Arc>>) -> Self { Self { node_state, syncing_peers: Mutex::new(vec![]), sync_notif: Notify::new(), _p: PhantomData, } } fn add_syncing_peer(&self, peer: NodeId) { let mut syncing_peers = self.syncing_peers.lock().unwrap(); syncing_peers.push(peer); } /// Sync with all peers. pub async fn sync(&self) { // we can't sync while we register let node_state = self.node_state.read().await; let node_state = node_state .as_ref() .expect("Trying to sync a node before having registered to the orchestrator"); // this peer is syncing self.add_syncing_peer(node_state.node_id); for (id, addr) in &node_state.nodes { if *id == node_state.node_id { continue; } let mut connection = P::Client::connect(addr.clone(), "sync") .await .expect("Couldn't connect to peer for sync"); let msg = SyncRequest(node_state.node_id); let sync_bytes = rmp_serde::to_vec(&msg).unwrap(); connection .send(Message::new(sync_bytes.into())) .await .expect("Peer closed connection unexpectedly"); } loop { { // compare currently synced peers with list of all nodes let mut syncing_peers = self.syncing_peers.lock().unwrap().to_vec(); syncing_peers.sort(); let mut all_node_ids = node_state.nodes.keys().cloned().collect::>(); all_node_ids.sort(); if syncing_peers == all_node_ids { // all nodes have synced syncing_peers.clear(); return; } } // Wait for the next sync to come in self.sync_notif.notified().await } } pub async fn handle_sync_connection(&self, mut channel: C) { let msg = channel.recv().await.unwrap(); let Some(msg) = msg else { return; }; let msg = rmp_serde::from_slice::(&msg.data).unwrap(); self.add_syncing_peer(msg.0); self.sync_notif.notify_waiters(); } } ================================================ FILE: crates/burn-collective/src/global/node/tree.rs ================================================ use std::{collections::HashMap, sync::Arc}; use crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService}; use burn_communication::{Address, Protocol, data_service::TensorDataService}; use burn_tensor::{TensorMetadata, backend::Backend}; use futures::{StreamExt, stream::FuturesUnordered}; struct TreeTopology { parents: HashMap, children: HashMap>, } /// Global all-reduce, using a b-tree strategy. /// /// Returns the resulting tensor on the same device as the input tensor pub(crate) async fn tree_all_reduce_sum( node: NodeId, nodes: &HashMap, data_service: Arc>, sync_service: Arc>, tensor: B::FloatTensorPrimitive, arity: u32, ) -> Result where B: Backend, P: Protocol, { let shape = tensor.shape(); let device = &B::float_device(&tensor); // Topology could be cached based on (nodes.keys().cloned(), arity) let strategy = get_tree_topology(nodes.keys().cloned().collect::>(), arity); // Transfer 1: Download and sum tensors from children let mut result = tensor; if let Some(children) = strategy.children.get(&node) { let mut downloads = children .iter() .map(|child| { let child_addr = nodes.get(child).unwrap().clone(); let data_service = data_service.clone(); async move { let data = data_service .download_tensor(child_addr.clone(), 0.into()) .await .ok_or(GlobalCollectiveError::PeerLost(*child))?; Ok::(B::float_from_data( data, device, )) } }) .collect::>(); for _ in children { let res = downloads.next().await.unwrap().unwrap(); if res.shape() != shape { return Err(GlobalCollectiveError::PeerSentIncoherentTensor); } result = B::float_add(result, res); } } // Transfer 2: Expose result to parent and download final result if not root if let Some(parent) = strategy.parents.get(&node) { data_service.expose(result.clone(), 1, 0.into()).await; let parent_addr = nodes.get(parent).unwrap().clone(); let data = data_service .download_tensor(parent_addr.clone(), 1.into()) .await .ok_or(GlobalCollectiveError::PeerLost(*parent))?; let parent_tensor = B::float_from_data(data, device); if parent_tensor.shape() != shape { return Err(GlobalCollectiveError::PeerSentIncoherentTensor); } result = parent_tensor; } // Transfer 3: Expose final result to children (if any) if let Some(children) = strategy.children.get(&node) && !children.is_empty() { data_service .expose(result.clone(), children.len() as u32, 1.into()) .await; } // Final barrier sync_service.sync().await; Ok(result) } /// Get the tree topology. /// /// * `nodes` - List of node ids. Order doesn't matter. Nodes must be unique. fn get_tree_topology(mut nodes: Vec, arity: u32) -> TreeTopology { assert!(arity >= 1, "Arity must be ≥ 1"); nodes.sort(); // Sort let n = nodes.len(); let k = arity as usize; let mut parents: HashMap<_, _> = HashMap::with_capacity(n); let mut children: HashMap<_, _> = HashMap::with_capacity(n); for (i, &parent_id) in nodes.iter().enumerate() { // compute the window [first_child, last_child) let first = i * k + 1; if first < n { let last = usize::min(first + k, n); let mut ch = Vec::with_capacity(last - first); for &child_id in &nodes[first..last] { parents.insert(child_id, parent_id); ch.push(child_id); } children.insert(parent_id, ch); } else { // leaf‐node: no children children.insert(parent_id, Vec::new()); } } TreeTopology { parents, children } } #[cfg(test)] mod tests { use super::*; /// Test the tree topology algorithm with arity 2 and 7 nodes #[test] fn test_get_tree_topology_arity2_size7() { let mut nodes = vec![]; for i in 0..7 { nodes.push(i.into()); } let topology = get_tree_topology(nodes, 2); // Root is 0, so it should have no parent assert!(!topology.parents.contains_key(&0.into())); // Parents: // Node 1 and 2 → parent 0 // Node 3 and 4 → parent 1 // Node 5 and 6 → parent 2 let expected_parents = [ (1.into(), 0.into()), (2.into(), 0.into()), (3.into(), 1.into()), (4.into(), 1.into()), (5.into(), 2.into()), (6.into(), 2.into()), ]; for (child, parent) in &expected_parents { assert_eq!( topology.parents.get(child), Some(parent), "wrong parent for {child:?}" ); } // There should be exactly 6 entries in parents assert_eq!(topology.parents.len(), expected_parents.len()); // Children: // 0 → [1, 2] // 1 → [3, 4] // 2 → [5, 6] // 3,4,5,6 → [] assert_eq!( topology.children.get(&0.into()), Some(&vec![1.into(), 2.into()]) ); assert_eq!( topology.children.get(&1.into()), Some(&vec![3.into(), 4.into()]) ); assert_eq!( topology.children.get(&2.into()), Some(&vec![5.into(), 6.into()]) ); // Leaves for leaf in 3..7 { assert_eq!( topology.children.get(&leaf.into()), Some(&Vec::new()), "leaf {leaf:?} should have no children" ); } // Ensure we have exactly 7 entries in children assert_eq!(topology.children.len(), 7); } } ================================================ FILE: crates/burn-collective/src/global/node/worker.rs ================================================ use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration}; use burn_communication::{Address, CommunicationChannel, Message, ProtocolClient}; use tokio::{ runtime::Runtime, sync::{ Mutex, mpsc::{Receiver, Sender}, }, task::JoinHandle, }; use tokio_util::sync::CancellationToken; use crate::global::shared::{ CollectiveMessage, CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest, RemoteResponse, RequestId, SessionId, }; /// Worker that handles communication with the orchestrator for global collective operations. pub(crate) struct GlobalClientWorker { handle: Option>>, cancel_token: CancellationToken, request_sender: Sender, _phantom_data: PhantomData

, } // Rename struct GlobalClientWorkerState { requests: HashMap>, } impl GlobalClientWorkerState { fn new() -> Self { Self { requests: HashMap::new(), } } } #[derive(Debug)] pub(crate) struct ClientRequest { pub request: RemoteRequest, pub callback: Sender, } impl ClientRequest { pub(crate) fn new(request: RemoteRequest, callback: Sender) -> Self { Self { request, callback } } } impl GlobalClientWorker { /// Create a new global client worker and start the tasks. pub(crate) fn new( runtime: &Runtime, cancel_token: CancellationToken, global_address: &Address, ) -> Self { let (request_sender, request_recv) = tokio::sync::mpsc::channel::(10); let state = Arc::new(Mutex::new(GlobalClientWorkerState::new())); let handle = runtime.spawn(Self::start( state, cancel_token.clone(), global_address.clone(), request_recv, )); Self { handle: Some(handle), cancel_token, request_sender, _phantom_data: PhantomData, } } /// Start the global client tasks async fn start( state: Arc>, cancel_token: CancellationToken, global_address: Address, request_recv: Receiver, ) -> Result<(), GlobalCollectiveError> { // Init the connection. let (request, response) = Self::init_connection(&global_address).await?; // Websocket async worker loading responses from the server. let response_handle = tokio::spawn(Self::response_loader( state.clone(), response, cancel_token.clone(), )); // Channel async worker sending operations to the server. let request_handle = tokio::spawn(Self::request_sender( request_recv, state, request, cancel_token.clone(), )); if let Err(e) = response_handle.await { log::error!("Response handler failed: {e:?}"); } if let Err(e) = request_handle.await { log::error!("Request handler failed: {e:?}"); } Ok(()) } async fn init_connection( address: &Address, ) -> Result<(C::Channel, C::Channel), GlobalCollectiveError> { let session_id = SessionId::new(); let stream_request = tokio::spawn(Self::connect_with_retry( address.clone(), "request", std::time::Duration::from_secs(1), None, session_id, )); let stream_response = tokio::spawn(Self::connect_with_retry( address.clone(), "response", std::time::Duration::from_secs(1), None, session_id, )); let Ok(Some(request)) = stream_request.await else { return Err(GlobalCollectiveError::OrchestratorUnreachable); }; let Ok(Some(response)) = stream_response.await else { return Err(GlobalCollectiveError::OrchestratorUnreachable); }; Ok((request, response)) } /// Connect with websocket with retries. async fn connect_with_retry( address: Address, route: &str, retry_pause: Duration, retry_max: Option, session_id: SessionId, ) -> Option { let mut retries = 0; loop { if let Some(max) = retry_max && retries >= max { log::warn!("Failed to connect to {address} after {max} retries."); return None; } // Try to connect to the request address. println!("Connecting to {address} ..."); let result = C::connect(address.clone(), route).await; if let Some(mut stream) = result { let init_msg = CollectiveMessage::Init(session_id); let bytes: bytes::Bytes = rmp_serde::to_vec(&init_msg).unwrap().into(); stream .send(Message::new(bytes)) .await .expect("Can send the init message on the websocket."); return Some(stream); } println!("Failed to connect to {address}, retrying... Attempt #{retries}"); tokio::time::sleep(retry_pause).await; retries += 1; } } /// Unregister the worker and close the connection. pub(crate) async fn close_connection(&mut self) -> Result<(), GlobalCollectiveError> { if let Some(handle) = self.handle.take() { // Un-register from server let req = RemoteRequest::Finish; let resp = self.request(req).await; if resp != RemoteResponse::FinishAck { log::error!("Requested to finish, did not get FinishAck; got {resp:?}"); return Err(GlobalCollectiveError::WrongOrchestratorResponse); } self.cancel_token.cancel(); if let Err(e) = handle.await.unwrap() { log::error!("Connection error {e:?}"); } } Ok(()) } async fn response_loader( state: Arc>, mut stream_response: C::Channel, cancel_token: CancellationToken, ) { loop { tokio::select! { // Check if the cancel token is cancelled _ = cancel_token.cancelled() => { break; } // .. Or get a message from the websocket response = stream_response.recv() => { match response { Err(err) => { log::error!("Error receiving message from websocket: {err:?}"); break; } Ok(response) => { let Some(response) = response else { log::warn!("Closed connection"); break; }; let response: CollectiveMessageResponse = rmp_serde::from_slice(&response.data) .expect("Can deserialize messages from the websocket."); let state_resp = state.lock().await; let response_callback = state_resp .requests .get(&response.request_id) .expect("Got a response to an unknown request"); response_callback.send(response.content).await.unwrap(); } } } } } log::info!("Worker closing connection"); stream_response .close() .await .expect("Can close the websocket stream."); } async fn request_sender( mut request_recv: Receiver, worker: Arc>, mut stream_request: C::Channel, cancel_token: CancellationToken, ) { loop { tokio::select! { _ = cancel_token.cancelled() => { break; }, request = request_recv.recv() => { let Some(request) = request else { continue; }; let id = RequestId::new(); // Register the callback if there is one { let mut state = worker.lock().await; state.requests.insert(id, request.callback); } let request = CollectiveMessage::Request(id, request.request); let bytes = rmp_serde::to_vec::(&request) .expect("Can serialize tasks to bytes.") .into(); stream_request .send(Message::new(bytes)) .await .expect("Can send the message on the websocket."); } } } log::info!("Worker closing connection"); stream_request .close() .await .expect("Can send the close message on the websocket."); } pub(crate) async fn request(&self, req: RemoteRequest) -> RemoteResponse { let (callback, mut response_recv) = tokio::sync::mpsc::channel::(10); let client_req = ClientRequest::new(req, callback); self.request_sender.send(client_req).await.unwrap(); response_recv.recv().await.unwrap() } } ================================================ FILE: crates/burn-collective/src/global/orchestrator/base.rs ================================================ use std::fmt::Debug; use std::sync::Arc; use tokio::sync::Mutex; use crate::global::{ orchestrator::state::GlobalCollectiveState, shared::{CollectiveMessage, GlobalCollectiveError}, }; use burn_communication::{ CommunicationChannel, Message, ProtocolServer, util::os_shutdown_signal, websocket::WsServer, }; /// The global collective state manages collective operations on the global level #[derive(Clone)] pub(crate) struct GlobalOrchestrator { state: Arc>, } impl GlobalOrchestrator { /// Starts the comms server with two routes: "/request" and "/response" pub(crate) async fn start( shutdown_signal: F, comms_server: S, ) -> Result<(), GlobalCollectiveError> where F: Future + Send + 'static, { let state = GlobalCollectiveState::new(); let server = Self { state: Arc::new(tokio::sync::Mutex::new(state)), }; comms_server .route("/response", { let server = server.clone(); async move |socket| { if let Err(err) = server.handle_socket_response::(socket).await { log::error!("[Response Handler] Error: {err:?}") } } }) .route("/request", { let server = server.clone(); async move |socket| { if let Err(err) = server.handle_socket_request::(socket).await { log::error!("[Request Handler] Error: {err:?}") } } }) .serve(shutdown_signal) .await .map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?; Ok(()) } async fn handle_socket_response( self, mut stream: S::Channel, ) -> Result<(), GlobalCollectiveError> { log::info!("[Response Handler] On new connection."); let msg = stream .recv() .await .map_err(|err| GlobalCollectiveError::Server(format!("{err:?}")))?; let Some(msg) = msg else { log::warn!("Response socket closed early!"); return Ok(()); }; let msg = rmp_serde::from_slice::(&msg.data) .map_err(|_| GlobalCollectiveError::InvalidMessage)?; let CollectiveMessage::Init(id) = msg else { return Err(GlobalCollectiveError::FirstMsgNotInit); }; let mut receiver = { let mut state = self.state.lock().await; state.get_session_responder(id) }; while let Some(response) = receiver.recv().await { let bytes = rmp_serde::to_vec(&response).unwrap(); stream.send(Message::new(bytes.into())).await?; } log::info!("[Response Handler] Closing connection."); Ok(()) } async fn handle_socket_request( self, mut stream: S::Channel, ) -> Result<(), GlobalCollectiveError> { log::info!("[Request Handler] On new connection."); let mut session_id = None; loop { let packet = stream.recv().await?; let Some(msg) = packet else { log::info!("Peer closed the connection"); break; }; let mut state = self.state.lock().await; let msg = rmp_serde::from_slice::(&msg.data) .map_err(|_| GlobalCollectiveError::InvalidMessage)?; match msg { CollectiveMessage::Init(id) => { state.init_session(id); session_id = Some(id); } CollectiveMessage::Request(request_id, remote_request) => { let session_id = session_id.ok_or(GlobalCollectiveError::FirstMsgNotInit)?; state .process_request(session_id, request_id, remote_request) .await; } } } Ok(()) } } /// Start a global orchestrator with WebSocket on the given port pub async fn start_global_orchestrator(port: u16) { let server = WsServer::new(port); let res = GlobalOrchestrator::start(os_shutdown_signal(), server).await; if let Err(err) = res { log::error!("Global Collective Orchestrator error: {err:?}"); } } ================================================ FILE: crates/burn-collective/src/global/orchestrator/mod.rs ================================================ pub(crate) mod base; pub(crate) mod state; pub use base::start_global_orchestrator; ================================================ FILE: crates/burn-collective/src/global/orchestrator/state.rs ================================================ use crate::{ PeerId, global::{ NodeId, shared::{ CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest, RemoteResponse, RequestId, SessionId, }, }, }; use burn_communication::Address; use std::collections::HashMap; use tokio::sync::mpsc::{Receiver, Sender}; pub(crate) struct Session { response_sender: Sender, response_receiver: Option>, } impl Session { fn new() -> Self { let (response_sender, recv) = tokio::sync::mpsc::channel::(1); Self { response_sender, response_receiver: Some(recv), } } async fn respond(&mut self, response: CollectiveMessageResponse) { self.response_sender.send(response).await.unwrap(); } } pub(crate) struct GlobalCollectiveState { /// The ids passed to each register so far, and their addresses registered_nodes: HashMap, /// Address for each node node_addresses: HashMap, /// Peer on each node node_peers: HashMap>, /// How many total nodes for the current register operation, as defined by the first caller cur_num_nodes: Option, /// How many peers have registered total num_global_peers: u32, register_requests: Vec<(SessionId, RequestId, NodeId)>, sessions: HashMap, } impl GlobalCollectiveState { pub fn new() -> Self { Self { registered_nodes: HashMap::new(), node_addresses: HashMap::new(), node_peers: HashMap::new(), cur_num_nodes: None, num_global_peers: 0, register_requests: Vec::new(), sessions: HashMap::new(), } } pub(crate) fn init_session(&mut self, id: SessionId) { if self.sessions.contains_key(&id) { return; } self.sessions.insert(id, Session::new()); } /// Create the session with given id if necessary, and get the response receiver pub(crate) fn get_session_responder( &mut self, id: SessionId, ) -> Receiver { self.init_session(id); let session = self.sessions.get_mut(&id).unwrap(); let response_recv = session.response_receiver.take(); response_recv.unwrap() } pub(crate) async fn respond( &mut self, session_id: SessionId, response: CollectiveMessageResponse, ) { let session = self.sessions.get_mut(&session_id).unwrap(); session.respond(response).await; } /// Process an incoming node's request pub(crate) async fn process_request( &mut self, session_id: SessionId, request_id: RequestId, request: RemoteRequest, ) { if let Err(err) = match request { RemoteRequest::Register { node_addr, num_nodes, peers, } => { self.register(session_id, request_id, node_addr, num_nodes, peers) .await } RemoteRequest::Finish => self.finish(session_id, request_id).await, } { // Error occurred, send it as response let content = RemoteResponse::Error(err); self.respond( session_id, CollectiveMessageResponse { request_id, content, }, ) .await; } } /// Un-register a node. Any pending requests will be cancelled, returning error responses. async fn finish( &mut self, session_id: SessionId, request_id: RequestId, ) -> Result<(), GlobalCollectiveError> { let node_id = self .registered_nodes .remove(&session_id) .ok_or(GlobalCollectiveError::NotRegisteredOnFinish)?; self.node_addresses.remove(&node_id); self.node_peers.remove(&node_id); self.num_global_peers = 0; let mut register_requests = vec![]; core::mem::swap(&mut register_requests, &mut self.register_requests); for (session, req, node_id) in register_requests { if session == session_id { // Send a response if we are finishing a session with a pending register request let content = RemoteResponse::Error(GlobalCollectiveError::PendingRegisterOnFinish); let response = CollectiveMessageResponse { request_id: req, content, }; self.respond(session_id, response).await; } else { // keep the register request self.register_requests.push((session, req, node_id)); } } self.respond( session_id, CollectiveMessageResponse { request_id, content: RemoteResponse::FinishAck, }, ) .await; Ok(()) } async fn register( &mut self, session_id: SessionId, request_id: RequestId, node_addr: Address, num_nodes: u32, peers: Vec, ) -> Result<(), GlobalCollectiveError> { match &self.cur_num_nodes { Some(cur_num_nodes) => { if *cur_num_nodes != num_nodes { return Err(GlobalCollectiveError::RegisterParamsMismatch); } } None => { self.cur_num_nodes = Some(num_nodes); } } self.num_global_peers += peers.len() as u32; let node_id: NodeId = self.registered_nodes.len().into(); self.registered_nodes.insert(session_id, node_id); if self.node_addresses.values().any(|addr| node_addr == *addr) { return Err(GlobalCollectiveError::DoubleRegister); } self.node_addresses.insert(node_id, node_addr); self.node_peers.insert(node_id, peers); self.register_requests .push((session_id, request_id, node_id)); if self.registered_nodes.len() == num_nodes as usize { let mut callbacks = vec![]; core::mem::swap(&mut callbacks, &mut self.register_requests); for (session, request, node_id) in callbacks { let content = RemoteResponse::Register { node_id, nodes: self.node_addresses.clone(), num_global_devices: self.num_global_peers, }; let resp = CollectiveMessageResponse { request_id: request, content, }; self.respond(session, resp).await; } } Ok(()) } } ================================================ FILE: crates/burn-collective/src/global/shared.rs ================================================ use std::{collections::HashMap, sync::atomic::AtomicU32}; use crate::{NodeId, PeerId}; use burn_communication::{Address, CommunicationError}; use burn_std::id::IdGenerator; use serde::{Deserialize, Serialize}; /// A unique identifier for each request made to a global orchestrator #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] pub(crate) struct RequestId(u32); static REQ_ID_COUNTER: AtomicU32 = AtomicU32::new(0); impl RequestId { pub(crate) fn new() -> Self { let id = REQ_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); Self(id) } } impl Default for RequestId { fn default() -> Self { Self::new() } } /// Unique identifier that can represent a session between a node and the orchestrator. #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) struct SessionId { id: u64, } impl SessionId { /// Create a new [session id](SessionId). pub(crate) fn new() -> Self { Self { id: IdGenerator::generate(), } } } /// Requests sent from the client #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) enum CollectiveMessage { Init(SessionId), Request(RequestId, RemoteRequest), } /// Responses sent to the client #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct CollectiveMessageResponse { pub request_id: RequestId, pub content: RemoteResponse, } /// Requests made from a client to a server. #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) enum RemoteRequest { // Register a node Register { /// Endpoint for this node node_addr: Address, /// Number of total nodes num_nodes: u32, /// List of peers on this node peers: Vec, }, /// Unregister node Finish, } /// Responses for each server request #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub(crate) enum RemoteResponse { /// Response to a register request Register { /// The orchestrator gives the node its id node_id: NodeId, /// All the nodes in the collective: including self nodes: HashMap, /// How many devices exist globally? For averaging values num_global_devices: u32, }, // Finish FinishAck, // There was a server-side error Error(GlobalCollectiveError), } /// Errors that occur during collective operations on the global level #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum GlobalCollectiveError { /// Operations that can't be done before registering AllReduceBeforeRegister, /// Ring all-reduce can't be done if all tensor dimensions are smaller than the number of nodes. RingReduceImpossible, /// Either a node has unregistered twice, or a Finish has been called before a Register NotRegisteredOnFinish, /// Finish has been called before a Register operation was finished PendingRegisterOnFinish, /// Trying to register a different way than is currently being done RegisterParamsMismatch, /// Trying to register while already registered DoubleRegister, /// Trying to aggregate a different way than is currently being done AllReduceParamsMismatch, /// First message on socket should be Message::Init FirstMsgNotInit, /// Messages should be rmp_serde serialized `Message` types InvalidMessage, /// A peer behaved unexpectedly PeerSentIncoherentTensor, /// Tried to download from a peer, but the peer closed or lost the connection PeerLost(NodeId), /// Error from the coordinator Server(String), /// The node received an invalid response WrongOrchestratorResponse, /// Node couldn't connect to coordinator OrchestratorUnreachable, } impl From for GlobalCollectiveError { fn from(err: E) -> Self { Self::Server(format!("{err:?}")) } } ================================================ FILE: crates/burn-collective/src/lib.rs ================================================ mod global; pub use global::*; mod config; pub use config::*; mod api; pub use api::*; mod local; #[cfg(all( test, any( feature = "test-ndarray", feature = "test-wgpu", feature = "test-cuda", feature = "test-metal" ) ))] mod tests; ================================================ FILE: crates/burn-collective/src/local/all_reduce/base.rs ================================================ use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices}; use crate::{ AllReduceStrategy, CollectiveConfig, CollectiveError, ReduceOperation, local::{ all_reduce_sum_centralized, all_reduce_sum_ring, all_reduce_sum_tree, broadcast_centralized, broadcast_tree, reduce_sum_centralized, reduce_sum_tree, }, node::base::Node, }; use burn_communication::Protocol; use burn_tensor::backend::Backend; #[cfg(feature = "tracing")] use tracing::Instrument; /// Perform an all-reduce with no multi-node operations (global ops) #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors, config)) )] pub(crate) async fn all_reduce_local_only( tensors: CollectiveTensorMap, op: ReduceOperation, config: &CollectiveConfig, ) -> Result, CollectiveError> { let local_strategy = &config.local_all_reduce_strategy; let mut reduced_tensors = match local_strategy { AllReduceStrategy::Centralized => all_reduce_sum_centralized::(tensors), AllReduceStrategy::Tree(arity) => all_reduce_sum_tree::(tensors, *arity), AllReduceStrategy::Ring => all_reduce_sum_ring::(tensors), }; if op == ReduceOperation::Mean { #[cfg(feature = "tracing")] let _span = tracing::info_span!("mean_reduction").entered(); // Apply mean division let div = (reduced_tensors.len() as f32).into(); reduced_tensors = reduced_tensors .into_iter() .map(|(id, t)| (id, B::float_div_scalar(t, div))) .collect(); } Ok(reduced_tensors) } /// Do an all-reduce in a multi-node context /// /// With Tree and Centralized strategies, the all-reduce is split between a /// reduce (all tensors are reduced to one device), and a broadcast (the result is sent to all /// other devices). The all-reduce on the global level is done between both steps. /// Due to the nature of the Ring strategy, this separation can't be done. /// /// For the Ring strategy, this isn't possible, because it is more like a /// reduce-scatter plus an all-gather, so using a Ring strategy locally in a multi-node /// setup may be unadvantageous. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors, config, global_client)) )] pub(crate) async fn all_reduce_with_global( tensors: CollectiveTensorMap, op: ReduceOperation, config: &CollectiveConfig, global_client: &mut Node, ) -> Result, CollectiveError> { let peer_devices = get_peer_devices::(&tensors); // For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later let main_device = *tensors.keys().next().unwrap(); let mut main_tensor = match config.local_all_reduce_strategy { AllReduceStrategy::Centralized => reduce_sum_centralized::(tensors, &main_device), AllReduceStrategy::Tree(arity) => reduce_sum_tree::(tensors, &main_device, arity), AllReduceStrategy::Ring => all_reduce_sum_ring::(tensors) .remove(&main_device) .unwrap(), }; // Do aggregation on global level with the main tensor main_tensor = { let fut = async { let global_strategy = config .global_all_reduce_strategy .expect("global_all_reduce_strategy must be set"); global_client .all_reduce(main_tensor, global_strategy, op) .await }; #[cfg(feature = "tracing")] { fut.instrument(tracing::info_span!("global_all_reduce")) } #[cfg(not(feature = "tracing"))] { fut } } .await .map_err(CollectiveError::Global)?; // Broadcast result to all devices let tensors = match config.local_all_reduce_strategy { AllReduceStrategy::Tree(arity) => { broadcast_tree::(peer_devices, main_device, main_tensor, arity) } // If we chose the ring strategy and we must still broadcast the global result, // we use the centralized strategy for broadcasting, but the tree may be better. AllReduceStrategy::Centralized | AllReduceStrategy::Ring => { broadcast_centralized::(peer_devices, main_device, main_tensor) } }; Ok(tensors) } ================================================ FILE: crates/burn-collective/src/local/all_reduce/centralized.rs ================================================ use burn_tensor::backend::Backend; use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices}; use crate::local::{broadcast_centralized, reduce_sum_centralized}; /// Perform an all-reduce operation by reducing all tensors on one device, and broadcasting the /// result to all other devices /// /// Internally, this is just a call to `reduce` followed by a `broadcast` #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] pub(crate) fn all_reduce_sum_centralized( tensors: CollectiveTensorMap, ) -> CollectiveTensorMap { // Get corresponding devices for each peer let peer_devices = get_peer_devices::(&tensors); let central_device = *tensors.keys().next().unwrap(); // Reduce to central device let central_tensor = reduce_sum_centralized::(tensors, ¢ral_device); // Broadcast result to all broadcast_centralized::(peer_devices, central_device, central_tensor) } ================================================ FILE: crates/burn-collective/src/local/all_reduce/mod.rs ================================================ mod base; mod centralized; mod op; mod ring; mod tree; pub(crate) use base::*; pub(crate) use centralized::*; pub(crate) use op::*; pub(crate) use ring::*; pub(crate) use tree::*; ================================================ FILE: crates/burn-collective/src/local/all_reduce/op.rs ================================================ use crate::global::node::base::Node; use crate::local::tensor_map::CollectiveTensorMap; use crate::{CollectiveConfig, CollectiveError, PeerId, ReduceOperation, local}; use burn_communication::Protocol; use burn_std::Shape; use burn_tensor::TensorMetadata; use burn_tensor::backend::Backend; use std::sync::mpsc::SyncSender; /// An on-going all-reduce operation #[derive(Debug)] pub struct AllReduceOp { /// all-reduce calls, one for each calling device calls: Vec>, /// The reduce operation of the current all-reduce, as defined by the first caller op: ReduceOperation, /// The shape of the current all-reduce, as defined by the first caller shape: Shape, } /// Struct for each device that calls an all-reduce operation #[derive(Debug)] pub struct AllReduceOpCall { /// Id of the caller for this operation caller: PeerId, /// The tensor primitive passed as input input: B::FloatTensorPrimitive, /// Callback for the result of the all-reduce result_sender: SyncSender>, } /// Type sent to the collective client upon completion of a all-reduce aggregation pub(crate) type AllReduceResult = Result; impl AllReduceOp { pub fn new(shape: Shape, reduce_op: ReduceOperation) -> Self { Self { calls: vec![], op: reduce_op, shape, } } /// Get a list of the peers. fn peers(&self) -> Vec { self.calls.iter().map(|c| c.caller).collect() } /// Register a call to all-reduce in this operation. /// /// # Returns /// /// `true` if enough peers have registered, and the all-reduce is ready pub fn register_call( &mut self, caller: PeerId, input: B::FloatTensorPrimitive, result_sender: SyncSender>, op: ReduceOperation, peer_count: usize, ) -> Result { if self.shape != input.shape() { return Err(CollectiveError::AllReduceShapeMismatch); } if self.op != op { return Err(CollectiveError::AllReduceOperationMismatch); } self.calls.push(AllReduceOpCall { caller, input, result_sender, }); Ok(self.calls.len() == peer_count) } /// Runs the all-reduce if the operation is ready. Otherwise, do nothing #[cfg_attr(feature = "tracing", tracing::instrument( level = "trace", skip(self, config, global_client), fields( ?self.op, ?self.shape, self.peers = ?self.peers(), ) ))] pub async fn execute( mut self, config: &CollectiveConfig, global_client: &mut Option>, ) { // all registered callers have sent a tensor to aggregate match self.all_reduce(config, global_client).await { Ok(mut tensors) => { // Return resulting tensors self.calls.iter().for_each(|call| { let result = tensors .remove(&call.caller) .expect("tensor/peer internal mismatch."); call.result_sender.send(Ok(result)).unwrap(); }); assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch."); } Err(err) => { // Send error to all subscribers self.fail(err); } } } /// Perform an all-reduce operation. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(self, config, global_client)) )] async fn all_reduce( &mut self, config: &CollectiveConfig, global_client: &mut Option>, ) -> Result, CollectiveError> { let tensors = self .calls .iter() .map(|call| (call.caller, call.input.clone())) .collect(); if let Some(global_client) = global_client.as_mut() { local::all_reduce_with_global(tensors, self.op, config, global_client).await } else { local::all_reduce_local_only::(tensors, self.op, config).await } } /// Send a collective error as result to operation caller pub fn fail(self, err: CollectiveError) { self.calls.iter().for_each(|op| { op.result_sender.send(Err(err.clone())).unwrap(); }); } } ================================================ FILE: crates/burn-collective/src/local/all_reduce/ring.rs ================================================ use super::tree::all_reduce_sum_tree; use crate::PeerId; use crate::local::tensor_map; use crate::local::tensor_map::CollectiveTensorMap; use burn_tensor::{Shape, Slice, TensorMetadata, backend::Backend}; use std::{collections::HashMap, ops::Range}; /// Ring implementation of All-Reduce (Ring-Reduce) #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] pub(crate) fn all_reduce_sum_ring( tensors: CollectiveTensorMap, ) -> CollectiveTensorMap { // https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model // Example: tensors=3, slices=3 // phase 1 // o->o o // o o->oå // o o o-> // o 1->o // o o 1-> // 1->o o // o 1 2 // 2 o 1 // 1 2 o // phase 2 // o 1 2-> // 2->o 1 // 1 2->o // 2->1 2 // 2 2->1 // 1 2 2-> // 2 2 2 // 2 2 2 // 2 2 2 // Verify all shapes are the same let shape = tensor_map::get_common_shape::(&tensors) .expect("Cannot aggregate tensors with different sizes"); // Chose an axis let slice_dim = get_slice_dim(&shape); let slice_dim_size = shape[slice_dim]; let tensor_count = tensors.len(); if slice_dim_size < tensor_count { // Tensor cannot be split into N slices! Use a fallback algorithm: binary tree return all_reduce_sum_tree::(tensors, 2); } // Split tensors into slices let mut sliced_tensors = slice_tensors::(tensors, shape, slice_dim); // phase 1: aggregate in ring N-1 times (Reduce-Scatter) ring_cycles::(&mut sliced_tensors, true); // phase 2: share (overwrite) in a ring N-1 times (All-Gather) ring_cycles::(&mut sliced_tensors, false); // merge slices and put back in result sliced_tensors .into_iter() .map(|(id, slices)| (id, B::float_cat(slices, slice_dim))) .collect() } /// Get the dimension to slice across: the largest dimension of the shape pub(crate) fn get_slice_dim(shape: &Shape) -> usize { // get dimension with the greatest size. shape .iter() .enumerate() .max_by(|(_, a), (_, b)| a.cmp(b)) .map(|(index, _)| index) .unwrap() } /// With a ring of N tensors, send the tensors N-1 times, either for the first of second phase. /// During the first phase, the tensor slices are summed. /// During the second, the slices are replaced. fn ring_cycles( sliced_tensors: &mut [(PeerId, Vec)], is_phase_one: bool, ) { let tensor_count = sliced_tensors.len(); for cycle in 0..(tensor_count - 1) { for i in 0..tensor_count { let src_tensor_idx = i; let dest_tensor_idx = (i + 1) % tensor_count; let slice_idx = if is_phase_one { (i + (tensor_count - 1) * cycle) % tensor_count } else { // in phase 2, the starting slice is different (see diagrams) (i + 1 + (tensor_count - 1) * cycle) % tensor_count }; let src_slice = sliced_tensors[src_tensor_idx].1.remove(slice_idx); let mut dest_slice = sliced_tensors[dest_tensor_idx].1.remove(slice_idx); let dest_device = B::float_device(&dest_slice); let src_slice_on_dest = B::float_to_device(src_slice.clone(), &dest_device); if is_phase_one { dest_slice = B::float_add(dest_slice, src_slice_on_dest); } else { let slices: Vec = dest_slice .shape() .iter() .map(|&d| Slice::new(0, Some(d as isize), 1)) .collect(); // in phase 2, we don't sum the two slices, we replace with the new one. dest_slice = B::float_slice_assign(dest_slice, slices.as_slice(), src_slice_on_dest); } sliced_tensors[src_tensor_idx] .1 .insert(slice_idx, src_slice); sliced_tensors[dest_tensor_idx] .1 .insert(slice_idx, dest_slice); } } } /// Slice a list of tensors the same way, evenly across a given dimension. /// The given `shape` should be the same for every tensor. fn slice_tensors( mut tensors: HashMap, shape: Shape, slice_dim: usize, ) -> Vec<(PeerId, Vec<::FloatTensorPrimitive>)> { // Get slice index ranges let ranges = get_ring_reduce_slice_ranges(shape[slice_dim], tensors.len()); // Slice tensors let mut sliced_tensors = vec![]; for (id, tensor) in tensors.drain() { let mut slices = vec![]; for range in &ranges { let full_range = shape .iter() .enumerate() .map(|(dim_idx, dim)| { if dim_idx == slice_dim { Slice::from(range.clone()) } else { Slice::from(0..*dim) } }) .collect::>(); let slice = B::float_slice(tensor.clone(), &full_range); slices.push(slice); } sliced_tensors.push((id, slices)); } sliced_tensors } /// Get the index ranges for the slices to split a tensor evently across a given axis. /// /// * `slice_dim_size` - The size of the dim to slice on /// * `slice_count` - The number of slices /// /// Returns a vector of index ranges for each slice. pub(crate) fn get_ring_reduce_slice_ranges( slice_dim_size: usize, slice_count: usize, ) -> Vec> { let mut ranges: Vec> = vec![]; let slice_size = slice_dim_size.div_ceil(slice_count); for i in 0..slice_count { let start = i * slice_size; let end = start + slice_size; ranges.push(Range { start, end }); } ranges.last_mut().unwrap().end = slice_dim_size; ranges } ================================================ FILE: crates/burn-collective/src/local/all_reduce/tree.rs ================================================ use crate::PeerId; use crate::local::tensor_map::CollectiveTensorMap; use burn_tensor::backend::{Backend, DeviceOps}; use std::collections::HashMap; /// Performs an all-reduce on the provided tensors in a b-tree structure with `arity`. /// Similar to [reduce_sum_tree](reduce_sum_tree), but this function broadcasts the result with /// the same tree algorithm. /// The returned tensors are on the same devices as the corresponding inputs #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] pub(crate) fn all_reduce_sum_tree( tensors: CollectiveTensorMap, arity: u32, ) -> CollectiveTensorMap { let mut input = tensors.into_iter().collect::>(); // Sort to put devices of the same type together input.sort_by(|a, b| { let dev_a = B::float_device(&a.1); let dev_b = B::float_device(&b.1); dev_a.id().cmp(&dev_b.id()) }); // Recursive all-reduce let out = all_reduce_sum_tree_inner::(input, arity); let mut tensors = HashMap::new(); for (id, tensor) in out { tensors.insert(id, tensor); } tensors } /// Recursive function that sums `tensors` and redistributes the result to the host devices #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] fn all_reduce_sum_tree_inner( mut tensors: Vec<(PeerId, B::FloatTensorPrimitive)>, arity: u32, ) -> Vec<(PeerId, B::FloatTensorPrimitive)> { let mut parent_tensors = vec![]; let mut children_groups = vec![]; // Phase 1: Sum tensors in groups of `arity` + 1 while !tensors.is_empty() { // Maps ids to devices for each child of this parent let mut children = vec![]; let (parent, mut parent_tensor) = tensors.remove(0); let parent_device = B::float_device(&parent_tensor); for _ in 0..arity { if tensors.is_empty() { break; } let (child, mut child_tensor) = tensors.remove(0); let child_device = B::float_device(&child_tensor); children.push((child, child_device)); child_tensor = B::float_to_device(child_tensor, &parent_device); parent_tensor = B::float_add(parent_tensor, child_tensor); } parent_tensors.push((parent, parent_tensor)); children_groups.push(children); } if parent_tensors.len() > 1 { // Parents are not yet at the root, do the upper part of the tree parent_tensors = all_reduce_sum_tree_inner::(parent_tensors, arity); } // Phase 2: Redistribute result from each parent to the respective devices for (parent, parent_tensor) in parent_tensors { let children = children_groups.remove(0); for (child, child_device) in children { // replace child tensors with result tensors.push(( child, B::float_to_device(parent_tensor.clone(), &child_device), )); } tensors.push((parent, parent_tensor)); } tensors } ================================================ FILE: crates/burn-collective/src/local/broadcast/centralized.rs ================================================ use std::collections::HashMap; use crate::PeerId; use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap}; use burn_tensor::backend::Backend; /// Broadcasts the tensor from one device in a map to all the others #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(devices, tensor)) )] pub(crate) fn broadcast_centralized( mut devices: PeerDeviceMap, central: PeerId, tensor: B::FloatTensorPrimitive, ) -> CollectiveTensorMap { let mut output = HashMap::new(); devices .remove(¢ral) .expect("Central device id is in `devices`"); for (dest, dest_device) in devices { let tensor = B::float_to_device(tensor.clone(), &dest_device); output.insert(dest, tensor); } output.insert(central, tensor); output } ================================================ FILE: crates/burn-collective/src/local/broadcast/mod.rs ================================================ mod centralized; mod op; mod tree; pub(crate) use centralized::*; pub(crate) use op::*; pub(crate) use tree::*; ================================================ FILE: crates/burn-collective/src/local/broadcast/op.rs ================================================ use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap}; use crate::{ BroadcastStrategy, CollectiveConfig, CollectiveError, PeerId, local::{broadcast_centralized, broadcast_tree}, node::base::Node, }; use burn_communication::Protocol; #[allow(unused_imports)] // TensorMetadata is used by tracing::instrument. use burn_tensor::TensorMetadata; use burn_tensor::backend::Backend; use std::sync::mpsc::SyncSender; /// An on-going broadcast operation pub struct BroadcastOp { /// broadcast calls, one for each calling device calls: Vec>, /// The tensor to broadcast, as defined by the root. Should be defined before all /// peers call the operation. tensor: Option, /// ID of the root (or use the first call's peer). root: Option, } /// Struct for each device that calls an broadcast operation pub struct BroadcastOpCall { /// Id of the caller of the operation caller: PeerId, /// Device of the calling peer device: B::Device, /// Callback for the result of the broadcast result_sender: SyncSender>, } /// Type sent to the collective client upon completion of a broadcast op pub(crate) type BroadcastResult = Result; impl BroadcastOp { pub fn new() -> Self { Self { calls: vec![], tensor: None, root: None, } } /// Get the effective root of the broadcast operation. /// If the root is set, return it. Otherwise, return the first caller's peer. pub fn effective_root(&self) -> PeerId { self.root.unwrap_or(self.calls.first().unwrap().caller) } pub fn peers(&self) -> Vec { self.calls.iter().map(|c| c.caller).collect() } fn peer_devices(&self) -> PeerDeviceMap { self.calls .iter() .map(|call| (call.caller, call.device.clone())) .collect() } /// Register a call to reduce in this operation. /// When the last caller registers a reduce, the operation is executed. pub fn register_call( &mut self, caller: PeerId, input: Option, result_sender: SyncSender>, device: B::Device, peer_count: usize, ) -> Result { if input.is_some() { if self.tensor.is_some() { return Err(CollectiveError::BroadcastMultipleTensors); } self.tensor = input; } self.calls.push(BroadcastOpCall { caller, device, result_sender, }); Ok(self.calls.len() == peer_count) } /// Runs the broadcast if the operation is ready. Otherwise, do nothing #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(self, config, global_client), fields( self.peers = ?self.peers(), self.shape = ?self.tensor.as_ref().map(|t| t.shape()), self.dtype = ?self.tensor.as_ref().map(|t| t.dtype()), ) ))] pub async fn execute( mut self, config: &CollectiveConfig, global_client: &mut Option>, ) { // all registered callers have sent a tensor to aggregate match self.broadcast(config, global_client).await { Ok(mut tensors) => { // Return resulting tensors self.calls.iter().for_each(|call| { let result = tensors .remove(&call.caller) .expect("tensor/peer internal mismatch."); call.result_sender.send(Ok(result)).unwrap(); }); assert_eq!(tensors.len(), 0, "tensor/peer internal mismatch."); } Err(err) => { // Send error to all subscribers self.fail(err); } } } #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(self, config, global_client)) )] async fn broadcast( &mut self, config: &CollectiveConfig, global_client: &mut Option>, ) -> Result, CollectiveError> { // Do broadcast on global level with the main tensor if let Some(global_client) = &global_client { let strategy = config .global_broadcast_strategy .expect("global_broadcast_strategy not defined"); self.tensor = Some( global_client .broadcast(self.tensor.clone(), strategy) .await .map_err(CollectiveError::Global)?, ) } // At this point tensor must be defined let Some(tensor) = self.tensor.take() else { return Err(CollectiveError::BroadcastNoTensor); }; let root = self.effective_root(); let peer_devices = self.peer_devices(); // Broadcast locally Ok(match config.local_broadcast_strategy { BroadcastStrategy::Tree(arity) => { broadcast_tree::(peer_devices, root, tensor, arity) } BroadcastStrategy::Centralized => { broadcast_centralized::(peer_devices, root, tensor) } }) } /// Send a collective error as result to operation caller pub fn fail(self, err: CollectiveError) { self.calls.iter().for_each(|call| { call.result_sender.send(Err(err.clone())).unwrap(); }); } } ================================================ FILE: crates/burn-collective/src/local/broadcast/tree.rs ================================================ use burn_tensor::backend::{Backend, DeviceOps}; use std::collections::HashMap; use crate::PeerId; use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap}; /// Performs a broadcast on the provided tensors in a b-tree structure with `arity`. /// /// Tensor must be on the device in the `devices` map corresponding to the `root` key. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(devices, tensor)) )] pub(crate) fn broadcast_tree( mut devices: PeerDeviceMap, root: PeerId, tensor: B::FloatTensorPrimitive, arity: u32, ) -> CollectiveTensorMap { // Convert hash map to vector of key-value pairs because order matters let mut devices_vec = vec![]; let root_device = devices.remove(&root).unwrap(); for (id, tensor) in devices.drain() { devices_vec.push((id, tensor)); } // Sort to put devices of the same type together devices_vec.sort_by(|a, b| { let dev_a = &a.1; let dev_b = &b.1; dev_a.id().cmp(&dev_b.id()) }); // put the root first devices_vec.insert(0, (root, root_device)); // Recursive broadcast let out = broadcast_tree_inner::(tensor, devices_vec, arity); // put results in a hash map let mut tensors = HashMap::new(); for (id, tensor) in out { tensors.insert(id, tensor); } tensors } /// Recursive function that broadcasts tensor across the other devices. Tensor should be on the /// first device of the list /// /// Broadcasts the tensor across the devices in the tree in a pre-order traversal. fn broadcast_tree_inner( tensor: B::FloatTensorPrimitive, mut all_devices: Vec<(PeerId, B::Device)>, arity: u32, ) -> Vec<(PeerId, B::FloatTensorPrimitive)> { let mut parents = vec![]; let mut children_groups = vec![]; // Put devices in groups of `arity` + the parent while !all_devices.is_empty() { let mut children = vec![]; let parent = all_devices.remove(0); for _ in 0..arity { if all_devices.is_empty() { break; } children.push(all_devices.remove(0)); } parents.push(parent); children_groups.push(children); } let mut parents = if parents.len() > 1 { broadcast_tree_inner::(tensor, parents, arity) } else { let root = parents.first().unwrap(); // `tensor` should already be on the root's device, no need to call B::float_to_device vec![(root.0, tensor)] }; // Redistribute result from each parent to the respective devices let mut tensors = vec![]; for children in children_groups { let parent = parents.remove(0); for (child_id, child_device) in children { // replace child's tensor with parent's let child_tensor = B::float_to_device(parent.1.clone(), &child_device); tensors.push((child_id, child_tensor)); } tensors.push(parent); } tensors } ================================================ FILE: crates/burn-collective/src/local/client.rs ================================================ use crate::local::all_reduce::AllReduceResult; use crate::{ CollectiveConfig, CollectiveError, PeerId, ReduceOperation, local::{ BroadcastResult, ReduceResult, server::{FinishResult, Message, RegisterResult}, }, }; use burn_tensor::backend::Backend; use std::sync::mpsc::{Receiver, SyncSender}; /// Local client to communicate with the local server. Each thread has a client. #[derive(Clone)] pub(crate) struct LocalCollectiveClient { pub channel: SyncSender>, } /// A pending operation that can be waited on. pub(crate) struct PendingCollectiveOperation { rx: Receiver>, } impl From> for Receiver> { fn from(value: PendingCollectiveOperation) -> Self { value.rx } } impl PendingCollectiveOperation { /// Wait on the operation. /// /// Given a `Receiver>`, this function will wait: /// - Unwraps `Ok(Result)` into `Result`; /// - maps `Err(RecvError)` to `Err(CollectiveError::LocalServerMissing)`. pub(crate) fn wait(self) -> Result { let tensor = self .rx .recv() .unwrap_or(Err(CollectiveError::LocalServerMissing))?; Ok(tensor) } } impl LocalCollectiveClient { /// Common logic for starting a collective operation. /// /// - Allocates `(callback, recv)` channels, /// - Passes the `callback` to the `Message` builder, /// - Sends the message through the collective channel, /// - Returns the `recv`. pub(crate) fn start_operation(&self, builder: F) -> PendingCollectiveOperation where F: FnOnce(SyncSender>) -> Message, { let (tx, rx) = std::sync::mpsc::sync_channel(1); self.channel.send((builder)(tx)).unwrap(); PendingCollectiveOperation { rx } } /// Common logic for starting a collective operation, with validation. /// /// When `valid` is `Err`, this function returns a `Receiver>` that /// immediately returns `Err(valid)`; /// otherwise, it behaves like [`LocalCollectiveClient::start_operation`]. pub(crate) fn start_valid_operation( &self, valid: Result<(), CollectiveError>, builder: F, ) -> PendingCollectiveOperation where F: FnOnce(SyncSender>) -> Message, { match valid { Err(e) => { let (tx, rx) = std::sync::mpsc::sync_channel(1); tx.send(Err(e)).unwrap(); PendingCollectiveOperation { rx } } _ => self.start_operation(builder), } } pub(crate) fn reset(&self) { self.channel.send(Message::Reset).unwrap(); } pub(crate) fn register( &mut self, id: PeerId, device: B::Device, config: CollectiveConfig, ) -> RegisterResult { self.register_start(id, device, config).wait() } pub(crate) fn register_start( &mut self, id: PeerId, device: B::Device, config: CollectiveConfig, ) -> PendingCollectiveOperation<()> { self.start_valid_operation( match config.is_valid() { true => Ok(()), false => Err(CollectiveError::InvalidConfig), }, |callback| Message::Register { device_id: id, device, config, callback, }, ) } /// Calls for an all-reduce operation with the given parameters and returns the result. /// The `params` must be the same as the parameters passed by the other nodes. /// /// # Arguments /// * `id` - The peer id of the caller /// * `tensor` - The input tensor to reduce with the peers' tensors /// * `config` - Config of the collective operation. Must be coherent with the other calls. /// /// # Result /// - `Ok(tensor)` if the operation was successful /// - `Err(CollectiveError)` on error. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(self, tensor)) )] pub fn all_reduce( &self, id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, ) -> AllReduceResult { self.all_reduce_start(id, tensor, op).wait() } /// Starts an all-reduce operation with the given parameters. /// /// The `params` must be the same as the parameters passed by the other nodes. /// /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`]. /// /// # Arguments /// * `id` - The peer id of the caller /// * `tensor` - The input tensor to reduce with the peers' tensors /// * `config` - Config of the collective operation. Must be coherent with the other calls. /// /// # Result /// /// A `Receiver<>` that will yield: /// - `Ok(AllReduceResult)` if the operation was successful /// - `Err(SendError)` if the channel was dropped. pub(crate) fn all_reduce_start( &self, id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, ) -> PendingCollectiveOperation { self.start_operation(|callback| Message::AllReduce { device_id: id, tensor, op, callback, }) } /// Reduces a tensor onto one device. /// /// # Arguments /// - `id` - The peer id of the caller. /// - `tensor` - The tensor to send as input. /// - `op` - The reduce operation to apply. /// - `root` - The ID of the peer that will receive the result. /// /// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor. pub fn reduce( &self, id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, root: PeerId, ) -> ReduceResult { self.reduce_start(id, tensor, op, root).wait() } /// Starts a reduce operation on a tensor onto one device. /// /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`]. /// /// # Arguments /// - `id` - The peer id of the caller. /// - `tensor` - The tensor to send as input. /// - `op` - The reduce operation to apply. /// - `root` - The ID of the peer that will receive the result. /// /// # Result /// /// A `Receiver<>` that will yield: /// - `Ok(ReduceResult)` if the operation was successful /// - `Err(SendError)` if the channel was dropped. pub(crate) fn reduce_start( &self, id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, root: PeerId, ) -> PendingCollectiveOperation> { self.start_operation(|callback| Message::Reduce { device_id: id, tensor, op, root, callback, }) } /// Broadcasts, or receives a broadcasted tensor. /// /// # Arguments /// - `id` - The peer id of the caller /// - `tensor` - If defined, this tensor will be broadcasted. /// Otherwise, this call will receive the broadcasted tensor. /// /// # Result /// Synchronously waits on the broadcasted tensor. pub fn broadcast( &self, id: PeerId, tensor: Option, ) -> BroadcastResult { self.broadcast_start(id, tensor).wait() } /// Starts a Broadcast, or receives a broadcasted tensor. /// /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`]. /// /// # Arguments /// - `id` - The peer id of the caller /// - `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive /// the broadcasted tensor. /// /// # Result /// /// A `Receiver<>` that will yield: /// - `Ok(BroadcastResult)` if the operation was successful /// - `Err(SendError)` if the channel was dropped. pub(crate) fn broadcast_start( &self, id: PeerId, tensor: Option, ) -> PendingCollectiveOperation { self.start_operation(|callback| Message::Broadcast { device_id: id, tensor, callback, }) } pub(crate) fn finish(&self, id: PeerId) -> FinishResult { self.finish_start(id).wait() } pub(crate) fn finish_start(&self, id: PeerId) -> PendingCollectiveOperation<()> { self.start_operation(|callback| Message::Finish { id, callback }) } } ================================================ FILE: crates/burn-collective/src/local/mod.rs ================================================ mod all_reduce; mod broadcast; mod reduce; pub(crate) mod tensor_map; pub(crate) use all_reduce::*; pub(crate) use broadcast::*; pub(crate) use reduce::*; pub(crate) mod client; pub(crate) mod server; ================================================ FILE: crates/burn-collective/src/local/reduce/centralized.rs ================================================ use burn_tensor::backend::Backend; use crate::PeerId; use crate::local::tensor_map::CollectiveTensorMap; #[cfg(feature = "tracing")] use crate::local::tensor_map::get_common_shape; /// Sums the tensors on one device and returns the result #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensors), fields(shape = ?get_common_shape::(&tensors).unwrap()) ))] pub(crate) fn reduce_sum_centralized( mut tensors: CollectiveTensorMap, central: &PeerId, ) -> B::FloatTensorPrimitive { let mut central_tensor = tensors .remove(central) .expect("Source device id is in the map"); let central_device = B::float_device(¢ral_tensor); for (_, tensor) in tensors { let rhs = B::float_to_device(tensor.clone(), ¢ral_device); central_tensor = B::float_add(central_tensor, rhs); } central_tensor } ================================================ FILE: crates/burn-collective/src/local/reduce/mod.rs ================================================ mod centralized; mod op; mod tree; pub(crate) use centralized::*; pub(crate) use op::*; pub(crate) use tree::*; ================================================ FILE: crates/burn-collective/src/local/reduce/op.rs ================================================ use burn_communication::Protocol; use burn_tensor::{Shape, TensorMetadata, backend::Backend}; use std::sync::mpsc::SyncSender; use crate::{ CollectiveConfig, CollectiveError, PeerId, ReduceOperation, ReduceStrategy, local::{reduce_sum_centralized, reduce_sum_tree}, node::base::Node, }; /// An on-going reduce operation pub struct ReduceOp { /// reduce calls, one for each calling device calls: Vec>, /// The reduce operation, as defined by the first caller op: ReduceOperation, /// The peer that receives the reduce result, as defined by the first caller root: PeerId, /// The shape of the tensor to reduce, as defined by the first caller shape: Shape, } /// Struct for each device that calls an reduce operation pub struct ReduceOpCall { /// Id of the caller of the operation caller: PeerId, /// The tensor primitive passed as input input: B::FloatTensorPrimitive, /// Callback for the result of the reduce result_sender: SyncSender>, } /// Type sent to the collective client upon completion of a reduce aggregation pub(crate) type ReduceResult = Result, CollectiveError>; impl ReduceOp { pub fn new(shape: Shape, reduce_op: ReduceOperation, root: PeerId) -> Self { Self { calls: vec![], op: reduce_op, root, shape, } } fn peers(&self) -> Vec { self.calls.iter().map(|c| c.caller).collect() } /// Register a call to reduce in this operation. /// When the last caller registers a reduce, the operation is executed. pub fn register_call( &mut self, caller: PeerId, input: B::FloatTensorPrimitive, result_sender: SyncSender>, op: ReduceOperation, root: PeerId, peer_count: usize, ) -> Result { if self.shape != input.shape() { return Err(CollectiveError::ReduceShapeMismatch); } if self.op != op { return Err(CollectiveError::ReduceOperationMismatch); } if self.root != root { return Err(CollectiveError::ReduceRootMismatch); } self.calls.push(ReduceOpCall { caller, input, result_sender, }); Ok(self.calls.len() == peer_count) } /// Runs the all-reduce if the operation is ready. Otherwise, do nothing #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(self, config, global_client), fields( ?self.op, ?self.shape, self.peers = ?self.peers(), ) ))] pub async fn execute( mut self, root: PeerId, config: &CollectiveConfig, global_client: &mut Option>, ) { match self.reduce(config, global_client).await { Ok(mut result) => { // Return resulting tensor to root, None to others self.calls.iter().for_each(|op| { let msg = if op.caller == root { Ok(result.take()) } else { Ok(None) }; op.result_sender.send(msg).unwrap(); }); } Err(err) => { self.fail(err); } } } #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(self, config, global_client)) )] async fn reduce( &mut self, config: &CollectiveConfig, global_client: &mut Option>, ) -> Result, CollectiveError> { let tensors = self .calls .iter() .map(|call| (call.caller, call.input.clone())) .collect(); // For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later let mut local_sum = match config.local_reduce_strategy { ReduceStrategy::Centralized => reduce_sum_centralized::(tensors, &self.root), ReduceStrategy::Tree(arity) => reduce_sum_tree::(tensors, &self.root, arity), }; // Do aggregation on a global level with the main tensor let result = if let Some(global_client) = global_client { let strategy = config .global_reduce_strategy .expect("global_reduce_strategy not defined"); global_client .reduce(local_sum, strategy, self.root, self.op) .await .map_err(CollectiveError::Global)? } else { // Mean division locally if self.op == ReduceOperation::Mean { let local_tensor_count = self.calls.len() as f32; local_sum = B::float_div_scalar(local_sum, local_tensor_count.into()) } Some(local_sum) }; Ok(result) } /// Send a collective error as result to operation caller pub fn fail(self, err: CollectiveError) { self.calls.iter().for_each(|op| { op.result_sender.send(Err(err.clone())).unwrap(); }); } } ================================================ FILE: crates/burn-collective/src/local/reduce/tree.rs ================================================ use crate::PeerId; use crate::local::tensor_map::CollectiveTensorMap; use burn_tensor::backend::{Backend, DeviceOps}; /// Performs a reduce on the provided tensors in a b-tree structure with `arity`. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] pub(crate) fn reduce_sum_tree( mut tensors: CollectiveTensorMap, root: &PeerId, arity: u32, ) -> B::FloatTensorPrimitive { // Convert hash map to vector of key-value pairs because order matters let mut input = vec![]; let root_tensor = tensors.remove(root).unwrap(); for (_, tensor) in tensors.drain() { input.push(tensor); } // Sort to put devices of the same type together input.sort_by(|a, b| { let dev_a = B::float_device(a); let dev_b = B::float_device(b); dev_a.id().cmp(&dev_b.id()) }); // put the root first input.insert(0, root_tensor); reduce_sum_tree_inner::(input, arity) } /// Recursive function that sums `tensors` /// /// Traverses `tensors` and reduces in a post-order traversal. The first tensor in the list is /// chosen as the root #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensors)) )] fn reduce_sum_tree_inner( mut tensors: Vec, arity: u32, ) -> B::FloatTensorPrimitive { let mut parents = vec![]; let mut children_groups = vec![]; // Sum tensors in groups of `arity` + 1 while !tensors.is_empty() { let mut children = vec![]; let mut parent_tensor = tensors.remove(0); let parent_device = B::float_device(&parent_tensor); for _ in 0..arity { if tensors.is_empty() { break; } let child_tensor = tensors.remove(0); children.push(B::float_device(&child_tensor)); let rhs = B::float_to_device(child_tensor, &parent_device); parent_tensor = B::float_add(parent_tensor, rhs); } parents.push(parent_tensor); children_groups.push(children); } if parents.len() > 1 { // Parents are not yet at the root, do the upper part of the tree reduce_sum_tree_inner::(parents, arity) } else { // Root of tree parents.remove(0) } } ================================================ FILE: crates/burn-collective/src/local/server.rs ================================================ use crate::{ CollectiveConfig, CollectiveError, PeerId, ReduceOperation, global::node::base::Node, local::{ AllReduceOp, AllReduceResult, BroadcastOp, BroadcastResult, ReduceOp, ReduceResult, client::LocalCollectiveClient, }, }; use burn_communication::websocket::{WebSocket, WsServer}; use burn_tensor::{TensorMetadata, backend::Backend}; use std::sync::{MutexGuard, OnceLock}; use std::{ any::{Any, TypeId}, collections::HashMap, fmt::Debug, sync::{ Arc, Mutex, mpsc::{Receiver, SyncSender}, }, }; use tokio::runtime::{Builder, Runtime}; /// Define the client/server communication on the network type Network = WebSocket; /// Type sent to the collective client upon completion of a register request pub(crate) type RegisterResult = Result<(), CollectiveError>; /// Type sent to the collective client upon completion of a finish request pub(crate) type FinishResult = Result<(), CollectiveError>; /// The local collective server that manages all the collective aggregation operations /// (like all-reduce) between local threads. /// This thread takes in messages from different clients. The clients must register, than they can /// send an aggregate message. They must all use the same parameters for the same aggregate /// operation. pub(crate) struct LocalCollectiveServer { /// Channel receiver for messages from clients message_rec: Receiver>, /// The collective configuration. Must be the same by every peer when calling register config: Option, /// The ids passed to each register so far peers: Vec, /// Callbacks for when all registers are done callbacks_register: Vec>, /// Map of each peer's id and its device devices: HashMap, /// Current uncompleted all-reduce operation all_reduce_op: Option>, /// Current uncompleted reduce call reduce_op: Option>, /// Uncompleted broadcast calls, one for each calling device. broadcast_op: Option>, /// Client for global collective operations global_client: Option>, } #[derive(Debug)] pub(crate) enum Message { /// Register a new peer with the collective. Register { device_id: PeerId, device: B::Device, config: CollectiveConfig, callback: SyncSender, }, /// Perform an all-reduce operation. AllReduce { device_id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, callback: SyncSender>, }, /// Perform a reduce operation. Reduce { device_id: PeerId, tensor: B::FloatTensorPrimitive, op: ReduceOperation, root: PeerId, callback: SyncSender>, }, /// Perform a broadcast operation (one-sender, many-receiver). Broadcast { device_id: PeerId, tensor: Option, callback: SyncSender>, }, /// Reset the collective server. Reset, Finish { id: PeerId, callback: SyncSender, }, } /// The type-erased box type for [`LocalCollectiveClient`]. type LocalClientBox = Box; /// Global state map from [`Backend`] to boxed [`LocalCollectiveClient`]. static BACKEND_CLIENT_MAP: OnceLock>> = OnceLock::new(); /// Gets a locked mutable view of the `STATE_MAP`. pub(crate) fn get_backend_client_map() -> MutexGuard<'static, HashMap> { BACKEND_CLIENT_MAP .get_or_init(Default::default) .lock() .unwrap() } /// Get a [`LocalCollectiveClient`] for the given [`Backend`]. /// /// Will start the local collective client/server pair if necessary. pub(crate) fn get_collective_client() -> LocalCollectiveClient { let typeid = TypeId::of::(); let mut state_map = get_backend_client_map(); match state_map.get(&typeid) { Some(val) => val.downcast_ref().cloned().unwrap(), None => { let client = LocalCollectiveServer::::setup(LocalCollectiveClientConfig::default()); state_map.insert(typeid, Box::new(client.clone())); client } } } /// Global runtime. static SERVER_RUNTIME: OnceLock> = OnceLock::new(); /// Get the global [`Runtime`]. pub(crate) fn get_collective_server_runtime() -> Arc { SERVER_RUNTIME .get_or_init(|| { Builder::new_multi_thread() .enable_all() .build() .expect("Unable to initialize runtime") .into() }) .clone() } /// Configuration for the local collective client/server pair. pub struct LocalCollectiveClientConfig { /// Channel capacity for the messaging queue from client to server. pub channel_capacity: usize, } impl Default for LocalCollectiveClientConfig { fn default() -> Self { Self { channel_capacity: 50, } } } impl From for LocalCollectiveClientConfig { fn from(capacity: usize) -> Self { Self { channel_capacity: capacity, } } } impl LocalCollectiveServer { fn new(rec: Receiver>) -> Self { Self { message_rec: rec, config: None, peers: vec![], devices: HashMap::new(), all_reduce_op: None, reduce_op: None, broadcast_op: None, callbacks_register: vec![], global_client: None, } } /// Setup a client/server pair with the given config. pub(crate) fn setup(cfg: C) -> LocalCollectiveClient where C: Into, { let cfg = cfg.into(); let (tx, rx) = std::sync::mpsc::sync_channel(cfg.channel_capacity); get_collective_server_runtime().spawn(async { let typeid = TypeId::of::(); log::info!("Starting server for backend: {typeid:?}"); let mut server = LocalCollectiveServer::new(rx); loop { match server.message_rec.recv() { Ok(message) => server.process_message(message).await, Err(err) => { log::error!( "Error receiving message from local collective server: {err:?}" ); break; } } } }); LocalCollectiveClient { channel: tx } } async fn process_message(&mut self, message: Message) { match message { Message::Register { device_id, device, config, callback, } => { self.process_register_message(device_id, device, config, &callback) .await } Message::AllReduce { device_id, tensor, op, callback, } => { self.process_all_reduce_message(device_id, tensor, op, callback) .await } Message::Reduce { device_id, tensor, op, root, callback, } => { self.process_reduce_message(device_id, tensor, op, root, callback) .await } Message::Broadcast { device_id, tensor, callback, } => { self.process_broadcast_message(device_id, tensor, callback) .await } Message::Reset => self.reset(), Message::Finish { id, callback } => self.process_finish_message(id, callback).await, } } async fn process_register_message( &mut self, device_id: PeerId, device: B::Device, config: CollectiveConfig, callback: &SyncSender, ) { if !config.is_valid() { callback.send(Err(CollectiveError::InvalidConfig)).unwrap(); return; } if self.peers.contains(&device_id) { callback .send(Err(CollectiveError::MultipleRegister)) .unwrap(); return; } if self.peers.is_empty() || self.config.is_none() { self.config = Some(config); } else if let Some(cfg) = &self.config && *cfg != config { callback .send(Err(CollectiveError::RegisterParamsMismatch)) .unwrap(); return; } self.peers.push(device_id); self.callbacks_register.push(callback.clone()); self.devices.insert(device_id, device); let config = self.config.as_ref().unwrap(); let global_params = config.global_register_params(); if let Some(global_params) = &global_params && self.global_client.is_none() { let server = WsServer::new(global_params.data_service_port); let client = Node::new(&global_params.global_address, server); self.global_client = Some(client) } // All have registered, callback if self.peers.len() == config.num_devices { let mut register_result = Ok(()); // if an error occurs on the global register, it must be passed back to every local peer if let Some(global_params) = global_params { let client = self .global_client .as_mut() .expect("Global client should be initialized"); register_result = client .register(self.peers.clone(), global_params) .await .map_err(CollectiveError::Global); }; // Send results to all callbacks. self.callbacks_register .drain(..) .for_each(|tx| tx.send(register_result.clone()).unwrap()); } } /// Processes an Message::AllReduce. async fn process_all_reduce_message( &mut self, peer_id: PeerId, tensor: ::FloatTensorPrimitive, op: ReduceOperation, callback: SyncSender>, ) { if !self.peers.contains(&peer_id) { callback .send(Err(CollectiveError::RegisterNotFirstOperation)) .unwrap(); return; } if self.all_reduce_op.is_none() { // First call to all-reduce self.all_reduce_op = Some(AllReduceOp::new(tensor.shape(), op)); } // Take the operation, we'll put it back if we're not done let mut all_reduce_op = self.all_reduce_op.take().unwrap(); // On the last caller, the all-reduce is done here let res = all_reduce_op.register_call(peer_id, tensor, callback.clone(), op, self.peers.len()); // Upon an error or the last call, the all_reduce_op is dropped match res { Ok(is_ready) => { if is_ready { all_reduce_op .execute(self.config.as_ref().unwrap(), &mut self.global_client) .await; } else { // Put operation back, we're waiting for more calls self.all_reduce_op = Some(all_reduce_op) } } Err(err) => all_reduce_op.fail(err), } } /// Processes a Message::Reduce. async fn process_reduce_message( &mut self, peer_id: PeerId, tensor: ::FloatTensorPrimitive, op: ReduceOperation, root: PeerId, callback: SyncSender>, ) { if !self.peers.contains(&root) { callback .send(Err(CollectiveError::RegisterNotFirstOperation)) .unwrap(); return; } if self.reduce_op.is_none() { // First call to reduce self.reduce_op = Some(ReduceOp::new(tensor.shape(), op, root)); } let mut reduce_op = self.reduce_op.take().unwrap(); // On the last caller, the all-reduce is done here let res = reduce_op.register_call( peer_id, tensor, callback.clone(), op, root, self.peers.len(), ); // Upon an error or the last call, the all_reduce_op is dropped match res { Ok(is_ready) => { if is_ready { reduce_op .execute(root, self.config.as_ref().unwrap(), &mut self.global_client) .await; } else { // Put operation back, we're waiting for more calls self.reduce_op = Some(reduce_op) } } Err(err) => reduce_op.fail(err), } } /// Processes a Message::Broadcast. async fn process_broadcast_message( &mut self, caller: PeerId, tensor: Option<::FloatTensorPrimitive>, callback: SyncSender>, ) { if self.config.is_none() { callback .send(Err(CollectiveError::RegisterNotFirstOperation)) .unwrap(); return; } if !self.peers.contains(&caller) { callback .send(Err(CollectiveError::RegisterNotFirstOperation)) .unwrap(); return; } if self.broadcast_op.is_none() { // First call to broadcast self.broadcast_op = Some(BroadcastOp::new()); } let device = self.devices.get(&caller).unwrap().clone(); let mut broadcast_op = self.broadcast_op.take().unwrap(); // On the last caller, the all-reduce is done here let res = broadcast_op.register_call(caller, tensor, callback.clone(), device, self.peers.len()); // Upon an error or the last call, the all_reduce_op is dropped match res { Ok(is_ready) => { if is_ready { broadcast_op .execute(self.config.as_ref().unwrap(), &mut self.global_client) .await; } else { // Put operation back, we're waiting for more calls self.broadcast_op = Some(broadcast_op) } } Err(err) => broadcast_op.fail(err), } } /// Reinitializes the collective server fn reset(&mut self) { self.peers.clear(); self.all_reduce_op = None; self.reduce_op = None; self.broadcast_op = None; } /// Processes a Message::Finish. async fn process_finish_message(&mut self, id: PeerId, callback: SyncSender) { if self.config.is_none() { callback .send(Err(CollectiveError::RegisterNotFirstOperation)) .unwrap(); return; } if !self.peers.contains(&id) { callback .send(Err(CollectiveError::MultipleUnregister)) .unwrap(); return; } // Remove registered with id self.peers.retain(|x| *x != id); if self.peers.is_empty() && let Some(mut global_client) = self.global_client.take() { global_client.finish().await; } callback.send(Ok(())).unwrap(); } } ================================================ FILE: crates/burn-collective/src/local/tensor_map.rs ================================================ //! # Common Tensor Map for Local Collective Operations use crate::PeerId; use burn_std::Shape; use burn_tensor::TensorMetadata; use burn_tensor::backend::Backend; use std::collections::HashMap; pub type CollectiveTensorMap = HashMap::FloatTensorPrimitive>; pub type PeerDeviceMap = HashMap::Device>; /// Get the shape of the tensors. They should all have the same shape, otherwise None is returned. pub fn get_common_shape(tensors: &CollectiveTensorMap) -> Option { let mut it = tensors.values(); if let Some(first) = it.next() { let shape = first.shape(); for tensor in it { if tensor.shape() != shape { return None; } } return Some(shape); } None } /// Get the `{ peer_id -> device }` mapping for the given tensors. pub fn get_peer_devices(tensors: &CollectiveTensorMap) -> PeerDeviceMap { tensors .iter() .map(|(id, tensor)| (*id, B::float_device(tensor))) .collect() } ================================================ FILE: crates/burn-collective/src/tests/all_reduce.rs ================================================ mod tests { use std::sync::mpsc::SyncSender; use burn_std::rand::get_seeded_rng; use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend}; use serial_test::serial; #[cfg(feature = "test-ndarray")] pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "test-cuda")] pub type TestBackend = burn_cuda::Cuda; #[cfg(feature = "test-wgpu")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-metal")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-vulkan")] pub type TestBackend = burn_wgpu::Wgpu; use crate::{ AllReduceStrategy, CollectiveConfig, PeerId, ReduceOperation, all_reduce, register, reset_collective, }; pub fn run_peer( id: PeerId, config: CollectiveConfig, input: TensorData, op: ReduceOperation, output: SyncSender>, ) { let device = B::Device::default(); register::(id, device.clone(), config).unwrap(); let tensor = Tensor::::from_data(input, &device); let tensor = Tensor::from_primitive(TensorPrimitive::Float( all_reduce::(id, tensor.into_primitive().tensor(), op).unwrap(), )); output.send(tensor).unwrap(); } fn generate_random_input( shape: Shape, op: ReduceOperation, thread_count: usize, ) -> (Vec, TensorData) { let input: Vec = (0..thread_count) .map(|_| { TensorData::random::( shape.clone(), burn_tensor::Distribution::Default, &mut get_seeded_rng(), ) }) .collect(); let device = ::Device::default(); let mut expected_tensor = Tensor::::zeros(shape, &device); for item in input.iter().take(thread_count as usize) { let input_tensor = Tensor::::from_data(item.clone(), &device); expected_tensor = expected_tensor.add(input_tensor); } if op == ReduceOperation::Mean { expected_tensor = expected_tensor.div_scalar(thread_count as u32); } let expected = expected_tensor.to_data(); (input, expected) } fn test_all_reduce( device_count: usize, op: ReduceOperation, strategy: AllReduceStrategy, tensor_size: usize, ) { reset_collective::(); let (send, recv) = std::sync::mpsc::sync_channel(32); let shape = Shape { dims: vec![tensor_size], }; let (input, expected) = generate_random_input(shape, op, device_count); let config = CollectiveConfig::default() .with_num_devices(device_count) .with_local_all_reduce_strategy(strategy); for id in 0..device_count { let send = send.clone(); let input = input[id as usize].clone(); std::thread::spawn({ let config = config.clone(); move || run_peer::(id.into(), config, input, op, send) }); } let first = recv.recv().unwrap().to_data(); for _ in 1..device_count { let tensor = recv.recv().unwrap(); tensor.to_data().assert_eq(&first, true); } let tol: Tolerance = Tolerance::balanced(); expected.assert_approx_eq(&first, tol); } #[test] #[serial] pub fn test_all_reduce_centralized_sum() { test_all_reduce::(4, ReduceOperation::Sum, AllReduceStrategy::Centralized, 4); } #[test] #[serial] pub fn test_all_reduce_centralized_mean() { test_all_reduce::(4, ReduceOperation::Mean, AllReduceStrategy::Centralized, 4); } #[test] #[serial] pub fn test_all_reduce_binary_tree_sum() { test_all_reduce::(4, ReduceOperation::Sum, AllReduceStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_all_reduce_binary_tree_mean() { test_all_reduce::(4, ReduceOperation::Mean, AllReduceStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_all_reduce_5_tree_sum() { test_all_reduce::(4, ReduceOperation::Sum, AllReduceStrategy::Tree(5), 4); } #[test] #[serial] pub fn test_all_reduce_5_tree_mean() { test_all_reduce::(4, ReduceOperation::Mean, AllReduceStrategy::Tree(5), 4); } #[test] #[serial] pub fn test_all_reduce_ring_sum() { test_all_reduce::(3, ReduceOperation::Sum, AllReduceStrategy::Ring, 3); } #[test] #[serial] pub fn test_all_reduce_ring_mean() { test_all_reduce::(3, ReduceOperation::Mean, AllReduceStrategy::Ring, 3); } #[test] #[serial] pub fn test_all_reduce_ring_irregular_sum() { // this should trigger the fallback algorithm when the tensor is too small. test_all_reduce::(4, ReduceOperation::Sum, AllReduceStrategy::Ring, 3); } } ================================================ FILE: crates/burn-collective/src/tests/broadcast.rs ================================================ mod tests { use std::sync::mpsc::SyncSender; use burn_std::rand::get_seeded_rng; use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend}; use serial_test::serial; #[cfg(feature = "test-ndarray")] pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "test-cuda")] pub type TestBackend = burn_cuda::Cuda; #[cfg(feature = "test-wgpu")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-metal")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-vulkan")] pub type TestBackend = burn_wgpu::Wgpu; use crate::{ BroadcastStrategy, CollectiveConfig, PeerId, broadcast, register, reset_collective, }; pub fn run_peer( id: PeerId, config: CollectiveConfig, input: Option, output: SyncSender>, ) { let device = B::Device::default(); register::(id, device.clone(), config).unwrap(); let tensor = input.map(|data| B::float_from_data(data, &device)); let tensor = broadcast::(id, tensor).unwrap(); let tensor = Tensor::::from_primitive(TensorPrimitive::Float(tensor)); output.send(tensor).unwrap(); } fn generate_random_input(shape: Shape) -> TensorData { TensorData::random::( shape.clone(), burn_tensor::Distribution::Default, &mut get_seeded_rng(), ) } fn test_broadcast( device_count: usize, strategy: BroadcastStrategy, tensor_size: usize, ) { reset_collective::(); let (send, recv) = std::sync::mpsc::sync_channel(32); let shape = Shape { dims: vec![tensor_size], }; let input = generate_random_input(shape); let config = CollectiveConfig::default() .with_num_devices(device_count) .with_local_broadcast_strategy(strategy); for id in 0..device_count { // The peer #0 is the root: it sends the tensor let input = if id == 0 { Some(input.clone()) } else { None }; std::thread::spawn({ let config = config.clone(); let send = send.clone(); move || run_peer::(id.into(), config, input, send) }); } // Expect all peers to receive the input tensor let tol: Tolerance = Tolerance::balanced(); for _ in 0..device_count { let tensor = recv.recv().unwrap().to_data(); input.assert_approx_eq(&tensor, tol); } } #[test] #[serial] pub fn test_broadcast_centralized_sum() { test_broadcast::(4, BroadcastStrategy::Centralized, 4); } #[test] #[serial] pub fn test_broadcast_centralized_mean() { test_broadcast::(4, BroadcastStrategy::Centralized, 4); } #[test] #[serial] pub fn test_broadcast_binary_tree_sum() { test_broadcast::(4, BroadcastStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_broadcast_binary_tree_mean() { test_broadcast::(4, BroadcastStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_broadcast_5_tree_sum() { test_broadcast::(4, BroadcastStrategy::Tree(5), 4); } #[test] #[serial] pub fn test_broadcast_5_tree_mean() { test_broadcast::(4, BroadcastStrategy::Tree(5), 4); } } ================================================ FILE: crates/burn-collective/src/tests/mod.rs ================================================ mod all_reduce; mod broadcast; mod reduce; ================================================ FILE: crates/burn-collective/src/tests/reduce.rs ================================================ mod tests { use std::sync::mpsc::SyncSender; use burn_std::rand::get_seeded_rng; use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend}; use serial_test::serial; #[cfg(feature = "test-ndarray")] pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "test-cuda")] pub type TestBackend = burn_cuda::Cuda; #[cfg(feature = "test-wgpu")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-metal")] pub type TestBackend = burn_wgpu::Wgpu; #[cfg(feature = "test-vulkan")] pub type TestBackend = burn_wgpu::Wgpu; use crate::{ CollectiveConfig, PeerId, ReduceOperation, ReduceStrategy, reduce, register, reset_collective, }; pub fn run_peer( id: PeerId, config: CollectiveConfig, input: TensorData, op: ReduceOperation, root: PeerId, output: SyncSender>>, ) { let device = B::Device::default(); register::(id, device.clone(), config).unwrap(); let tensor = Tensor::::from_data(input, &device); let tensor = tensor.into_primitive().tensor(); let tensor = reduce::(id, tensor, op, root).unwrap(); let tensor = tensor.map(|t| Tensor::::from_primitive(TensorPrimitive::Float(t))); output.send(tensor).unwrap(); } fn generate_random_input( shape: Shape, op: ReduceOperation, thread_count: usize, ) -> (Vec, TensorData) { let input: Vec = (0..thread_count) .map(|_| { TensorData::random::( shape.clone(), burn_tensor::Distribution::Default, &mut get_seeded_rng(), ) }) .collect(); let device = ::Device::default(); let mut expected_tensor = Tensor::::zeros(shape, &device); for item in input.iter().take(thread_count) { let input_tensor = Tensor::::from_data(item.clone(), &device); expected_tensor = expected_tensor.add(input_tensor); } if op == ReduceOperation::Mean { expected_tensor = expected_tensor.div_scalar(thread_count as u32); } let expected = expected_tensor.to_data(); (input, expected) } fn test_reduce( device_count: usize, op: ReduceOperation, strategy: ReduceStrategy, tensor_size: usize, ) { reset_collective::(); let (send, recv) = std::sync::mpsc::sync_channel(32); let shape = Shape { dims: vec![tensor_size], }; let (input, expected) = generate_random_input(shape, op, device_count); let config = CollectiveConfig::default() .with_num_devices(device_count) .with_local_reduce_strategy(strategy); let root: PeerId = 0.into(); for id in 0..device_count { let send = send.clone(); let input = input[id as usize].clone(); std::thread::spawn({ let config = config.clone(); move || run_peer::(id.into(), config, input, op, root, send) }); } let mut result = None; for _ in 0..device_count { let tensor = recv.recv().unwrap(); if tensor.is_some() { if result.is_some() { panic!("Two peers received the result of an reduce!"); } result = tensor.map(|t| t.to_data()); } } let tol: Tolerance = Tolerance::balanced(); expected.assert_approx_eq(&result.expect("One peer has received the result"), tol); } #[test] #[serial] pub fn test_reduce_centralized_sum() { test_reduce::(4, ReduceOperation::Sum, ReduceStrategy::Centralized, 4); } #[test] #[serial] pub fn test_reduce_centralized_mean() { test_reduce::(4, ReduceOperation::Mean, ReduceStrategy::Centralized, 4); } #[test] #[serial] pub fn test_reduce_binary_tree_sum() { test_reduce::(4, ReduceOperation::Sum, ReduceStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_reduce_binary_tree_mean() { test_reduce::(4, ReduceOperation::Mean, ReduceStrategy::Tree(2), 4); } #[test] #[serial] pub fn test_reduce_5_tree_sum() { test_reduce::(4, ReduceOperation::Sum, ReduceStrategy::Tree(5), 4); } #[test] #[serial] pub fn test_reduce_5_tree_mean() { test_reduce::(4, ReduceOperation::Mean, ReduceStrategy::Tree(5), 4); } } ================================================ FILE: crates/burn-communication/Cargo.toml ================================================ [package] authors = ["Guilhem Ané (@Cielbird)", "Nathaniel Simard (@nathanielsimard)"] description = "Abstractions for network communication for Burn" edition.workspace = true license.workspace = true name = "burn-communication" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-communication" version.workspace = true [lints] workspace = true [features] tracing = [ "burn-std/tracing", "burn-tensor?/tracing", ] data-service = ["burn-tensor"] websocket = ["axum", "tokio-tungstenite", "futures"] [dependencies] burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true } bytes = { workspace = true } derive-new = { workspace = true } futures-util = { workspace = true } log = { workspace = true } rmp-serde = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_bytes = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "sync", "signal", "tracing"] } tokio-util = { workspace = true } tracing = { workspace = true, features = ["default"] } tracing-core = { workspace = true, features = ["default"] } tracing-subscriber = { workspace = true, features = ["default", "fmt", "env-filter"] } # Tensor Data Service burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", optional = true } # Websocket axum = { workspace = true, features = ["ws"], optional = true } tokio-tungstenite = { workspace = true, optional = true } futures = { workspace = true, optional = true } ================================================ FILE: crates/burn-communication/README.md ================================================ # Burn Communication Abstractions for network communication The Protocol trait defines how to communicate in a server/client style. The server can set up routes with callbacks upon connection. ## WebSocket Communication with WebSockets is implemented with the `websocket` feature. ## Tensor Data Service The tensor data service provides easy utilities to share tensors peer-to-peer. One peer can expose a tensor, and another can download it. Each peer is both a client and a server. ================================================ FILE: crates/burn-communication/src/base.rs ================================================ use burn_std::future::DynFut; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display}; use std::hash::Hash; use std::str::FromStr; /// Allows nodes to find each other #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)] pub struct Address { pub(crate) inner: String, } impl FromStr for Address { type Err = String; fn from_str(s: &str) -> Result { Ok(Self { inner: s.to_string(), }) } } impl Display for Address { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.inner) } } /// The protocol used for the communications. pub trait Protocol: Clone + Send + Sync + 'static { /// The client implementation for the current protocol. type Client: ProtocolClient; /// The server implementation for the current protocol. type Server: ProtocolServer; } /// Error that happens during a communication. pub trait CommunicationError: Debug + Send + 'static {} /// The client is only used to create a [channel](CommunicationChannel), which should be use to /// transmit information with the [server](ProtocolServer). pub trait ProtocolClient: Send + Sync + 'static { /// Channel used by this protocol. type Channel: CommunicationChannel; /// The error type. type Error: CommunicationError; /// Opens a new [channel](CommunicationChannel) with the current protocol at the given /// [address](Address) and route. /// /// * `address` - Address to connect to /// * `route` - The name of the route (no slashes) /// /// Returns None if the connection can't be done. fn connect(address: Address, route: &str) -> DynFut>; } /// Data sent and received by the client and server. #[derive(new)] pub struct Message { /// The data is always encoded as bytes. pub data: bytes::Bytes, } /// Defines how to create a server that respond to a [channel](CommunicationChannel). pub trait ProtocolServer: Sized + Send + Sync + 'static { /// Channel used by this protocol. type Channel: CommunicationChannel; /// The error type. type Error: CommunicationError; /// Defines an endpoint with the function that responds. /// TODO Docs: does it need a slash? fn route(self, path: &str, callback: C) -> Self where C: FnOnce(Self::Channel) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send + 'static; /// Start the server. fn serve( self, shutdown: F, ) -> impl Future> + Send + 'static where F: Future + Send + 'static; } /// Handles communications. pub trait CommunicationChannel: Send + 'static { type Error: CommunicationError; /// Send a [message](Message) on the channel. fn send( &mut self, message: Message, ) -> impl std::future::Future> + Send; /// Receive a [message](Message) on the channel and returns a new [response message](Message). fn recv( &mut self, ) -> impl std::future::Future, Self::Error>> + Send; fn close(&mut self) -> impl std::future::Future> + Send; } ================================================ FILE: crates/burn-communication/src/data_service.rs ================================================ //! This module enables direct data transfer between servers without blocking the client or any server. //! //! It eliminates the need for intermediate data transfer through the client, avoiding the process of downloading data from one server and reuploading it to another. //! //! The module provides an optimized mechanism for servers to communicate directly, streamlining data movement between them without involving the client. use crate::Message; use crate::base::Protocol; use crate::base::{Address, CommunicationChannel, ProtocolClient, ProtocolServer}; use burn_tensor::{TensorData, backend::Backend}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, marker::PhantomData, sync::Arc}; use tokio::sync::Mutex; use tokio::sync::Notify; use tokio_util::sync::CancellationToken; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct TensorTransferId(u64); impl From for TensorTransferId { fn from(value: u64) -> Self { Self(value) } } impl TensorTransferId { pub fn next(&mut self) { self.0 += 1; } } #[derive(Debug, Serialize, Deserialize)] enum DataServiceMessage { TensorRequest(TensorTransferId), Tensor(TensorData), } type ClientChannelRef = Arc::Channel>>; pub struct TensorDataService> { /// Maps tensor transfer IDs to their exposed state. pub exposed_tensors: Mutex>, /// Maps node addresses to their channels. pub channels: Mutex>>, /// Notify when a new tensor is exposed. pub new_tensor_notify: Arc, cancel_token: CancellationToken, _phantom_data: PhantomData, } pub struct TensorExposeState { /// The bytes of the tensor data message. Message::Data(...) serialized with rmp_serde pub bytes: bytes::Bytes, /// How many times the tensor will be downloaded pub max_downloads: u32, /// How man times the tensor has been downloaded pub cur_download_count: u32, } /// Provides a routing function for a tensor data service for a communications server pub trait TensorDataServer { /// Routes the tensor data service to the "/data" route fn route_tensor_data_service(self, state: Arc>) -> Self; } impl + 'static> TensorDataServer for S { fn route_tensor_data_service(self, state: Arc>) -> Self { self.route("/data", async move |stream: S::Channel| { state.handle_data_channel(stream).await; }) } } impl TensorDataService { pub fn new(cancel_token: CancellationToken) -> Self { Self { exposed_tensors: Mutex::new(HashMap::new()), channels: Mutex::new(HashMap::new()), new_tensor_notify: Arc::new(Notify::new()), cancel_token, _phantom_data: PhantomData::, } } /// Exposes a tensor to the data server, allowing it to be downloaded by other nodes. pub async fn expose( &self, tensor: B::FloatTensorPrimitive, max_downloads: u32, transfer_id: TensorTransferId, ) { let data = B::float_into_data(tensor).await.unwrap(); self.expose_data(data, max_downloads, transfer_id).await } /// Exposes a tensor data to the data server, allowing it to be downloaded by other nodes. pub async fn expose_data( &self, tensor_data: TensorData, max_downloads: u32, transfer_id: TensorTransferId, ) { let bytes: bytes::Bytes = rmp_serde::to_vec(&DataServiceMessage::Tensor(tensor_data)) .unwrap() .into(); let mut exposed_tensors = self.exposed_tensors.lock().await; exposed_tensors.insert( transfer_id, TensorExposeState { bytes, max_downloads, cur_download_count: 0, }, ); core::mem::drop(exposed_tensors); self.new_tensor_notify.notify_waiters(); } pub async fn close(&self) { // Send a closing message to every open WebSocket stream let mut streams = self.channels.lock().await; for (_, stream) in streams.drain() { let mut stream = stream.lock().await; stream .close() .await .expect("Failed to close WebSocket stream"); } } /// Downloads a tensor that is exposed on another server. Requires a Tokio 1.x runtime /// /// Returns None if the peer closes the connection pub async fn download_tensor( &self, remote: Address, transfer_id: TensorTransferId, ) -> Option { log::info!("Downloading tensor from {remote:?}"); let stream = self.get_data_stream(remote).await; let mut stream = stream.lock().await; // Send the download request with the download id let bytes: bytes::Bytes = rmp_serde::to_vec(&DataServiceMessage::TensorRequest(transfer_id)) .unwrap() .into(); stream .send(Message::new(bytes)) .await .expect("Failed to send download id"); if let Ok(msg) = stream.recv().await { let Some(msg) = msg else { log::warn!("Received None message from the websocket, closing connection."); return None; }; let DataServiceMessage::Tensor(data) = rmp_serde::from_slice(&msg.data) .expect("Can deserialize messages from the websocket.") else { panic!("Message should have been TensorData") }; return Some(data); } log::warn!("Closed connection"); None } /// Get the WebSocket stream for the given address, or create a new one if it doesn't exist. async fn get_data_stream( &self, address: Address, ) -> Arc::Channel>> { let mut streams = self.channels.lock().await; match streams.get(&address) { Some(stream) => stream.clone(), None => { // Open a new WebSocket connection to the address let stream = P::Client::connect(address.clone(), "data").await; let Some(stream) = stream else { panic!("Failed to connect to data server at {address:?}"); }; let stream = Arc::new(Mutex::new(stream)); streams.insert(address.clone(), stream.clone()); stream } } } /// Get the requested exposed tensor data, and update download counter async fn get_exposed_tensor_bytes( &self, transfer_id: TensorTransferId, ) -> Option { loop { { let mut exposed_tensors = self.exposed_tensors.lock().await; // take the tensor out of the hashmap while we download if let Some(mut exposed_state) = exposed_tensors.remove(&transfer_id) { exposed_state.cur_download_count += 1; let bytes = if exposed_state.cur_download_count == exposed_state.max_downloads { exposed_state.bytes } else { let bytes = exposed_state.bytes.clone(); exposed_tensors.insert(transfer_id, exposed_state); bytes }; return Some(bytes); } } // No matching tensor, wait for a new one to come in. self.new_tensor_notify.notified().await; } } /// Handle incoming connections for downloading tensors. pub(crate) async fn handle_data_channel( &self, mut channel: ::Channel, ) { log::info!("[Data Handler] New connection for download."); while !self.cancel_token.is_cancelled() { match channel.recv().await { Ok(message) => { if let Some(msg) = message { let bytes = msg.data; let msg: DataServiceMessage = rmp_serde::from_slice(&bytes) .expect("Can deserialize messages from the websocket."); let DataServiceMessage::TensorRequest(transfer_id) = msg else { panic!("Received a message that wasn't a tensor request! {msg:?}"); }; let bytes = self.get_exposed_tensor_bytes(transfer_id).await.unwrap(); channel.send(Message::new(bytes)).await.unwrap(); } else { log::info!("Closed connection"); return; } } Err(err) => panic!("Failed to receive message from websocket: {err:?}"), }; } log::info!("[Data Service] Closing connection for download."); } } ================================================ FILE: crates/burn-communication/src/lib.rs ================================================ #[macro_use] extern crate derive_new; mod base; pub use base::*; pub mod util; #[cfg(feature = "websocket")] pub mod websocket; #[cfg(feature = "data-service")] pub mod data_service; ================================================ FILE: crates/burn-communication/src/util.rs ================================================ use tracing_core::{Level, LevelFilter}; use tracing_subscriber::{ Layer, filter::filter_fn, layer::SubscriberExt, registry, util::SubscriberInitExt, }; /// Utilities to help handle communication termination. pub async fn os_shutdown_signal() { let ctrl_c = async { tokio::signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } } pub(crate) fn init_logging() { let layer = tracing_subscriber::fmt::layer() .with_filter(LevelFilter::INFO) .with_filter(filter_fn(|m| { if let Some(path) = m.module_path() { // The wgpu crate is logging too much, so we skip `info` level. if path.starts_with("wgpu") && *m.level() >= Level::INFO { return false; } } true })); // If we start multiple servers in the same process, this will fail, it's ok let _ = registry().with(layer).try_init(); } ================================================ FILE: crates/burn-communication/src/websocket/base.rs ================================================ use crate::{ base::{Address, Protocol}, websocket::{client::WsClient, server::WsServer}, }; #[derive(Clone)] /// A websocket implements a [communication protocol](Protocol) that can be used to communicate /// over the internet. pub struct WebSocket {} impl Protocol for WebSocket { type Client = WsClient; type Server = WsServer; } /// Parse an address, add the ws:// prefix if needed, and return an error if the address is invalid pub(crate) fn parse_ws_address(mut address: Address) -> Result { let s = &address.inner; let parts = s.split("://").collect::>(); let num_parts = parts.len(); let url = if num_parts == 2 { if parts[0] == "ws" { s.to_owned() } else { return Err(format!("Invalid prefix: {}", parts[0])); } } else if num_parts == 1 { return Err(format!("ws://{s}")); } else { return Err(format!("Invalid url: {s}")); }; address.inner = url; Ok(address) } ================================================ FILE: crates/burn-communication/src/websocket/client.rs ================================================ use crate::{ base::{Address, CommunicationChannel, CommunicationError, Message, ProtocolClient}, websocket::base::parse_ws_address, }; use burn_std::future::DynFut; use futures::{SinkExt, StreamExt}; use tokio::net::TcpStream; use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, connect_async_with_config, tungstenite::{self, protocol::WebSocketConfig}, }; #[derive(Clone)] pub struct WsClient; impl ProtocolClient for WsClient { type Channel = WsClientChannel; type Error = WsClientError; fn connect(address: Address, route: &str) -> DynFut> { Box::pin(connect_ws(address, route.to_owned())) } } /// Open a new WebSocket connection to the address async fn connect_ws(address: Address, route: String) -> Option { let address = parse_ws_address(address).ok()?; let address = format!("{address}/{route}"); const MB: usize = 1024 * 1024; let (stream, _) = connect_async_with_config( address.clone(), Some( WebSocketConfig::default() .write_buffer_size(0) .max_message_size(None) .max_frame_size(Some(MB * 512)) .accept_unmasked_frames(true) .read_buffer_size(64 * 1024), // 64 KiB (previous default) ), true, ) .await .ok()?; Some(WsClientChannel { inner: stream }) } pub struct WsClientChannel { inner: WebSocketStream>, } impl CommunicationChannel for WsClientChannel { type Error = WsClientError; async fn send(&mut self, msg: Message) -> Result<(), WsClientError> { self.inner .send(tungstenite::Message::Binary(msg.data)) .await?; Ok(()) } async fn recv(&mut self) -> Result, WsClientError> { match self.inner.next().await { Some(next) => match next { Ok(tungstenite::Message::Binary(data)) => Ok(Some(Message { data })), Ok(tungstenite::Message::Close(_close_frame)) => Ok(None), Err(err) => Err(WsClientError::Tungstenite(err)), msg => Err(WsClientError::UnknownMessage(format!("{msg:?}"))), }, None => todo!(), } } async fn close(&mut self) -> Result<(), WsClientError> { let reason = "Peer is closing".to_string(); self.inner .send(tungstenite::Message::Close(Some( tungstenite::protocol::CloseFrame { code: tungstenite::protocol::frame::coding::CloseCode::Normal, reason: reason.clone().into(), }, ))) .await?; Ok(()) } } #[derive(Debug)] pub enum WsClientError { Io(std::io::Error), Tungstenite(tungstenite::Error), UnknownMessage(String), Other(String), } impl CommunicationError for WsClientError {} impl From for WsClientError { fn from(err: std::io::Error) -> Self { Self::Io(err) } } impl From for WsClientError { fn from(err: tungstenite::Error) -> Self { Self::Tungstenite(err) } } ================================================ FILE: crates/burn-communication/src/websocket/mod.rs ================================================ mod base; mod client; mod server; pub use base::*; pub use client::*; pub use server::*; ================================================ FILE: crates/burn-communication/src/websocket/server.rs ================================================ use std::net::SocketAddr; use crate::{ base::{CommunicationChannel, CommunicationError, Message, ProtocolServer}, util::init_logging, }; use axum::{ Router, extract::{ State, WebSocketUpgrade, ws::{self, WebSocket}, }, routing::get, }; use futures::StreamExt; #[derive(Clone, Debug)] pub struct WsServer { port: u16, router: Router<()>, } pub struct WsServerChannel { inner: WebSocket, } impl WsServer { pub fn new(port: u16) -> Self { Self { port, router: Router::new(), } } } impl ProtocolServer for WsServer { type Channel = WsServerChannel; type Error = WsServerError; async fn serve(self, shutdown: F) -> Result<(), Self::Error> where F: Future + Send + 'static, { init_logging(); let address = format!("0.0.0.0:{}", self.port); log::info!("Starting server {address}"); let listener = tokio::net::TcpListener::bind(address).await?; axum::serve( listener, self.router .into_make_service_with_connect_info::(), ) .with_graceful_shutdown(shutdown) .await?; Ok(()) } fn route(mut self, path: &str, callback: C) -> Self where C: FnOnce(WsServerChannel) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send + 'static, { // Format path: should start with a / let path = if path.starts_with("/") { path.to_owned() } else { format!("/{path}") }; let method = get(|ws: WebSocketUpgrade, _: State<()>| async { ws.on_upgrade(async move |socket| { callback(WsServerChannel { inner: socket }).await; }) }); self.router = self.router.route(&path, method); self } } impl CommunicationChannel for WsServerChannel { type Error = WsServerError; async fn send(&mut self, message: Message) -> Result<(), WsServerError> { self.inner.send(ws::Message::Binary(message.data)).await?; Ok(()) } async fn recv(&mut self) -> Result, WsServerError> { match self.inner.next().await { Some(next) => match next { Ok(ws::Message::Binary(data)) => Ok(Some(Message { data })), Ok(ws::Message::Close(_close_frame)) => Ok(None), Err(err) => Err(WsServerError::Axum(err)), msg => Err(WsServerError::UnknownMessage(format!("{msg:?}"))), }, None => todo!(), } } async fn close(&mut self) -> Result<(), WsServerError> { let reason = "Peer is closing".to_string(); self.inner .send(ws::Message::Close(Some(ws::CloseFrame { code: 1000, // code: Normal reason: reason.clone().into(), }))) .await?; Ok(()) } } #[derive(Debug)] pub enum WsServerError { Io(std::io::Error), Axum(axum::Error), UnknownMessage(String), Other(String), } impl CommunicationError for WsServerError {} impl From for WsServerError { fn from(err: std::io::Error) -> Self { Self::Io(err) } } impl From for WsServerError { fn from(err: axum::Error) -> Self { Self::Axum(err) } } ================================================ FILE: crates/burn-core/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Flexible and Comprehensive Deep Learning Framework in Rust" documentation = "https://docs.rs/burn-core" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-core" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-core" version.workspace = true [lints] workspace = true [features] default = [ "std", "burn-std/default", "burn-dataset?/default", "burn-tensor/default", ] doc = [ "std", "dataset", "audio", # Doc features "burn-std/doc", "burn-dataset/doc", "burn-tensor/doc", ] tracing = [ "burn-std/tracing", "burn-tensor/tracing", "burn-dataset?/tracing", "burn-vision?/tracing", ] dataset = ["burn-dataset"] network = ["burn-std/network"] sqlite = ["burn-dataset?/sqlite"] sqlite-bundled = ["burn-dataset?/sqlite-bundled"] std = [ "bincode/std", "burn-std/std", "burn-tensor/std", "flate2", "half/std", "log", "rand/std", "rmp-serde", "serde/std", "serde_json/std", "num-traits/std", ] vision = ["burn-vision", "burn-dataset?/vision"] audio = ["burn-dataset?/audio"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["thiserror"] test-cuda = [ "burn-cuda/default", ] # To use cuda during testing, default uses ndarray. test-rocm = [ "burn-rocm/default", ] # To use hip during testing, default uses ndarray. test-tch = [ "burn-tch/default", ] # To use tch during testing, default uses ndarray. test-wgpu = [ "burn-wgpu/default", ] # To use wgpu during testing, default uses ndarray. test-vulkan = [ "test-wgpu", "burn-wgpu/vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. test-metal = [ "test-wgpu", "burn-wgpu/metal", ] # To use wgpu-spirv during testing, default uses ndarray. # Memory checks are disabled by default test-memory-checks = ["burn-fusion/memory-checks"] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-derive = { path = "../burn-derive", version = "=0.21.0-pre.2" } burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false } burn-vision = { path = "../burn-vision", version = "=0.21.0-pre.2", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } derive-new = { workspace = true } log = { workspace = true, optional = true } rand = { workspace = true } # The same implementation of HashMap in std but with no_std support (only alloc crate is needed) hashbrown = { workspace = true, features = ["serde"] } # no_std compatible # Serialize Deserialize flate2 = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } ahash = { workspace = true } bincode = { workspace = true } half = { workspace = true } num-traits = { workspace = true } rmp-serde = { workspace = true, optional = true } serde_json = { workspace = true, features = ["alloc"] } #Default enables std spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled thiserror = { workspace = true, optional = true } [target.'cfg(target_has_atomic = "ptr")'.dependencies] regex = { workspace = true } # FOR TESTING burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic-util = { workspace = true } portable-atomic = { workspace = true } [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" } burn-dataset = { path = "../burn-dataset", version = "=0.21.0-pre.2", features = [ "fake", ] } rstest = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-core/README.md ================================================ # Burn Core This crate should be used with [burn](https://github.com/tracel-ai/burn). It contains the core traits and components for building and training deep learning models with Burn. [![Current Crates.io Version](https://img.shields.io/crates/v/burn-core.svg)](https://crates.io/crates/burn-core) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-core/blob/master/README.md) ## Feature Flags This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the default `std` feature. - `std` - enables the standard library. Enabled by default. ================================================ FILE: crates/burn-core/src/config.rs ================================================ use alloc::{format, string::String, string::ToString}; pub use burn_derive::Config; use core::fmt::Debug; /// Configuration IO error. #[derive(Debug)] pub enum ConfigError { /// Invalid format. InvalidFormat(String), /// File not found. FileNotFound(String), } impl core::fmt::Display for ConfigError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut message = "Config error => ".to_string(); match self { Self::InvalidFormat(err) => { message += format!("Invalid format: {err}").as_str(); } Self::FileNotFound(err) => { message += format!("File not found: {err}").as_str(); } }; f.write_str(message.as_str()) } } impl core::error::Error for ConfigError {} /// Configuration trait. pub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned { /// Saves the configuration to a file. /// /// # Arguments /// /// * `file` - File to save the configuration to. /// /// # Returns /// /// The output of the save operation. #[cfg(feature = "std")] fn save>(&self, file: P) -> std::io::Result<()> { std::fs::write(file, config_to_json(self)) } /// Loads the configuration from a file. /// /// # Arguments /// /// * `file` - File to load the configuration from. /// /// # Returns /// /// The loaded configuration. #[cfg(feature = "std")] fn load>(file: P) -> Result { let content = std::fs::read_to_string(file.as_ref()) .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?; config_from_str(&content) } /// Loads the configuration from a binary buffer. /// /// # Arguments /// /// * `data` - Binary buffer to load the configuration from. /// /// # Returns /// /// The loaded configuration. fn load_binary(data: &[u8]) -> Result { let content = core::str::from_utf8(data).map_err(|_| { ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()) })?; config_from_str(content) } } /// Converts a configuration to a JSON string. /// /// # Arguments /// /// * `config` - Configuration to convert. /// /// # Returns /// /// The JSON string. pub fn config_to_json(config: &C) -> String { serde_json::to_string_pretty(config).unwrap() } fn config_from_str(content: &str) -> Result { serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}"))) } ================================================ FILE: crates/burn-core/src/data/dataloader/base.rs ================================================ use burn_tensor::backend::Backend; pub use crate::data::dataset::{Dataset, DatasetIterator}; use core::iter::Iterator; use std::sync::Arc; /// A progress struct that can be used to track the progress of a data loader. #[derive(new, Clone, Debug)] pub struct Progress { /// The number of items that have been processed. pub items_processed: usize, /// The total number of items that need to be processed. pub items_total: usize, } /// A data loader iterator that can be used to iterate over a data loader. pub trait DataLoaderIterator: Iterator { /// Returns the progress of the data loader. fn progress(&self) -> Progress; } /// A data loader that can be used to iterate over a dataset. pub trait DataLoader: Send + Sync { /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. fn iter<'a>(&'a self) -> Box + 'a>; /// The number of items (not the number of batches nor the number of iterations), /// corresponding to the items_total of the progress returned by the iterator. fn num_items(&self) -> usize; /// Move the data loader to the given device, ensuring the batches are assigned to the correct device. fn to_device(&self, device: &B::Device) -> Arc>; /// Returns a new data loader containing a subset of the data. /// /// The subset includes items from `start` (inclusive) to `end` (exclusive), /// preserving the batch size and ordering of the original data loader. /// /// # Arguments /// /// * `start` - The starting index of the subset (inclusive). /// * `end` - The ending index of the subset (exclusive). /// /// # Returns /// /// A boxed [`DataLoader`] instance containing only the specified range. fn slice(&self, start: usize, end: usize) -> Arc>; } ================================================ FILE: crates/burn-core/src/data/dataloader/batch.rs ================================================ use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher}; use burn_dataset::{ Dataset, transform::{PartialDataset, ShuffledDataset}, }; use burn_tensor::backend::Backend; use rand::SeedableRng; use std::ops::DerefMut; use std::sync::Arc; /// A data loader that can be used to iterate over a dataset in batches. pub struct BatchDataLoader { strategy: Box>, dataset: Arc>, batcher: Arc>, device: B::Device, rng: Option>>, } impl Clone for BatchDataLoader { fn clone(&self) -> Self { Self { strategy: self.strategy.clone_dyn(), dataset: self.dataset.clone(), batcher: self.batcher.clone(), device: self.device.clone(), rng: self.rng.clone(), } } } impl BatchDataLoader { /// Creates a new batch data loader. /// /// # Arguments /// /// * `strategy` - The batch strategy. /// * `dataset` - The dataset. /// * `batcher` - The batcher. /// * `device` - The device to use when loading a batch. /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader /// iterator is created. /// /// # Returns /// /// The batch data loader. pub fn new( strategy: Box>, dataset: Arc>, batcher: Arc>, device: B::Device, rng: Option, ) -> Self { Self { strategy, dataset, batcher, device, rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))), } } } /// A data loader iterator that can be used to iterate over a data loader. struct BatchDataloaderIterator { current_index: usize, strategy: Box>, dataset: Arc>, batcher: Arc>, device: B::Device, } impl DataLoader for BatchDataLoader where B: Backend, I: Send + Sync + Clone + 'static, O: Send + 'static, { fn iter<'a>(&'a self) -> Box + 'a> { // When starting a new iteration, we first check if the dataloader was created with an rng, // implying that we should shuffle the dataset beforehand, while advancing the current // rng to ensure that each new iteration shuffles the dataset differently. let dataset = match &self.rng { Some(rng) => Arc::new(ShuffledDataset::new( self.dataset.clone(), rng.lock().deref_mut(), )), None => self.dataset.clone(), }; Box::new(BatchDataloaderIterator::new( self.strategy.clone_dyn(), dataset, self.batcher.clone(), self.device.clone(), )) } fn num_items(&self) -> usize { self.dataset.len() } fn to_device(&self, device: &B::Device) -> Arc> { let rng = self.rng.as_ref().map(|rng| { let mut rng = rng.lock(); rng.fork() }); Arc::new(Self::new( self.strategy.clone_dyn(), self.dataset.clone(), self.batcher.clone(), device.clone(), rng, )) } fn slice(&self, start: usize, end: usize) -> Arc> { let rng = self.rng.as_ref().map(|rng| { let mut rng = rng.lock(); rng.fork() }); let dataloader = Self::new( self.strategy.clone_dyn(), Arc::new(PartialDataset::new(self.dataset.clone(), start, end)), self.batcher.clone(), self.device.clone(), rng, ); Arc::new(dataloader) } } impl BatchDataloaderIterator { /// Creates a new batch data loader iterator. /// /// # Arguments /// /// * `strategy` - The batch strategy. /// * `dataset` - The dataset. /// * `batcher` - The batcher. /// * `device` - The device to use when loading a batch. /// /// # Returns /// /// The batch data loader iterator. pub fn new( strategy: Box>, dataset: Arc>, batcher: Arc>, device: B::Device, ) -> Self { BatchDataloaderIterator { current_index: 0, strategy, dataset, batcher, device, } } } impl Iterator for BatchDataloaderIterator { type Item = O; fn next(&mut self) -> Option { while let Some(item) = self.dataset.get(self.current_index) { self.current_index += 1; self.strategy.add(item); if let Some(items) = self.strategy.batch(false) { return Some(self.batcher.batch(items, &self.device)); } } if let Some(items) = self.strategy.batch(true) { return Some(self.batcher.batch(items, &self.device)); } None } } impl DataLoaderIterator for BatchDataloaderIterator { fn progress(&self) -> Progress { Progress::new(self.current_index, self.dataset.len()) } } #[cfg(test)] mod tests { use std::collections::HashSet; use super::*; use crate::data::dataloader::FixBatchStrategy; use crate::data::dataloader::batcher::TestBatcher; use crate::data::dataset::FakeDataset; #[test] fn test_batch_dataloader() { let batcher = Arc::new(TestBatcher::new()); let dataset = Arc::new(FakeDataset::::new(27)); let dataloader = BatchDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset.clone(), batcher, Default::default(), None, ); let mut items_dataset = HashSet::new(); let mut items_dataloader = HashSet::new(); for item in dataset.iter() { items_dataset.insert(item); } for items in dataloader.iter() { for item in items { items_dataloader.insert(item); } } assert_eq!(items_dataset, items_dataloader); } #[test] fn test_batch_dataloader_slice() { let batcher = Arc::new(TestBatcher::new()); let dataset = Arc::new(FakeDataset::::new(27)); let dataloader = BatchDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset.clone(), batcher, Default::default(), None, ); let dataloader_slice = dataloader.slice(5, 15); let mut items_dataloader = HashSet::new(); let mut items_dataloader_slice = HashSet::new(); let mut idx = 0; for items in dataloader.iter() { for item in items { if (5..15).contains(&idx) { items_dataloader.insert(item); } idx += 1; } } for items in dataloader_slice.iter() { for item in items { items_dataloader_slice.insert(item); } } assert_eq!(items_dataloader, items_dataloader_slice); } } ================================================ FILE: crates/burn-core/src/data/dataloader/batcher.rs ================================================ use burn_tensor::backend::Backend; #[cfg(test)] use crate::TestBackend; /// A trait for batching items of type `I` into items of type `O`. pub trait Batcher: Send + Sync { /// Batches the given items on the specified device. /// /// # Arguments /// /// * `items` - The items to batch. /// * `device` - The backend device to use. /// /// # Returns /// /// The batched items. fn batch(&self, items: Vec, device: &B::Device) -> O; } /// Test batcher #[cfg(test)] #[derive(new, Clone)] pub struct TestBatcher; #[cfg(test)] impl Batcher> for TestBatcher { fn batch(&self, items: Vec, _device: &::Device) -> Vec { items } } ================================================ FILE: crates/burn-core/src/data/dataloader/builder.rs ================================================ use super::{ BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader, batcher::Batcher, }; use burn_dataset::Dataset; use burn_tensor::backend::Backend; use rand::{SeedableRng, rngs::StdRng}; use std::sync::Arc; /// A builder for data loaders. pub struct DataLoaderBuilder { strategy: Option>>, batcher: Arc>, num_threads: Option, shuffle: Option, device: Option, } impl DataLoaderBuilder where B: Backend, I: Send + Sync + Clone + std::fmt::Debug + 'static, O: Send + Clone + std::fmt::Debug + 'static, { /// Creates a new data loader builder. /// /// # Arguments /// /// * `batcher` - The batcher. /// /// # Returns /// /// The data loader builder. pub fn new(batcher: Bt) -> Self where Bt: Batcher + 'static, { Self { batcher: Arc::new(batcher), strategy: None, num_threads: None, shuffle: None, device: None, } } /// Sets the batch size to a fix number. /// /// The [fix batch strategy](FixBatchStrategy) will be used. /// /// # Arguments /// /// * `batch_size` - The batch size. /// /// # Returns /// /// The data loader builder. pub fn batch_size(mut self, batch_size: usize) -> Self { self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); self } /// Sets the seed for shuffling. /// /// Each time the dataloader starts a new iteration, the dataset will be shuffled. /// /// # Arguments /// /// * `seed` - The seed. /// /// # Returns /// /// The data loader builder. pub fn shuffle(mut self, seed: u64) -> Self { self.shuffle = Some(seed); self } /// Sets the number of workers. /// /// - `Some(0)` or `None`: the dataloader will run without work threads. /// - `Some(n); n > 0`: the dataloader will run with `n` background threads. /// /// A 1-worker threaded dataloader will run loads in a background thread, /// while a 0-worker threaded dataloader will run loads in the main thread. /// /// # Arguments /// /// * `num_workers` - The number of workers. /// /// # Returns /// /// The data loader builder. pub fn num_workers(mut self, num_workers: usize) -> Self { self.num_threads = Some(num_workers); self } /// Sets the data loader device. /// /// # Arguments /// /// * `device` - The device to use when loading a batch. /// /// # Returns /// /// The data loader builder. pub fn set_device(mut self, device: B::Device) -> Self { self.device = Some(device); self } /// Builds the data loader. /// /// # Arguments /// /// * `dataset` - The dataset. /// /// # Returns /// /// The data loader. pub fn build(self, dataset: D) -> Arc> where D: Dataset + 'static, { let dataset = Arc::new(dataset); let device = self.device.unwrap_or_default(); let rng = self.shuffle.map(StdRng::seed_from_u64); let strategy = match self.strategy { Some(strategy) => strategy, None => Box::new(FixBatchStrategy::new(1)), }; if let Some(num_threads) = self.num_threads && num_threads > 0 { return Arc::new(MultiThreadDataLoader::new( strategy, dataset, self.batcher, num_threads, device, rng, )); } Arc::new(BatchDataLoader::new( strategy, dataset, self.batcher, device, rng, )) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use crate::data::dataset::FakeDataset; #[derive(new, Clone)] struct TestBatcherDevice; #[cfg(test)] impl Batcher for TestBatcherDevice { fn batch(&self, _items: Vec, device: &TestDevice) -> TestDevice { *device } } type TestDevice = ::Device; #[test] fn test_dataloader_no_workers() { type TestDevice = ::Device; let default_device = TestDevice::default(); let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new()) .batch_size(1) .build(FakeDataset::::new(9)); assert_eq!(dataloader.num_items(), 9); for device in dataloader.iter() { assert_eq!(device, default_device) } } #[test] fn test_dataloader_default_device() { let default_device = TestDevice::default(); let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new()) .batch_size(1) .num_workers(1) .build(FakeDataset::::new(9)); assert_eq!(dataloader.num_items(), 9); for device in dataloader.iter() { assert_eq!(device, default_device) } } #[test] fn test_dataloader_slice_multi_device() { let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new()) .batch_size(1) .num_workers(1) .build(FakeDataset::::new(11)); #[cfg(all( test, not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda") ))] // Only one device exists... let (device1, device2) = ( burn_ndarray::NdArrayDevice::Cpu, burn_ndarray::NdArrayDevice::Cpu, ); #[cfg(all(test, feature = "test-tch"))] let (device1, device2) = ( burn_tch::LibTorchDevice::Cuda(0), burn_tch::LibTorchDevice::Cuda(1), ); #[cfg(all(test, feature = "test-wgpu"))] let (device1, device2) = ( burn_wgpu::WgpuDevice::DiscreteGpu(0), burn_wgpu::WgpuDevice::DiscreteGpu(1), ); #[cfg(all(test, feature = "test-cuda"))] let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1)); assert_eq!(dataloader.num_items(), 11); let dataloader_1 = dataloader.slice(0, 5).to_device(&device1); let dataloader_2 = dataloader.slice(5, 11).to_device(&device2); assert_eq!(dataloader_1.num_items(), 5); assert_eq!(dataloader_2.num_items(), 6); let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter()); for _ in 0..5 { assert_eq!(iterator_1.next(), Some(device1)); assert_eq!(iterator_2.next(), Some(device2)); } assert_eq!(iterator_1.next(), None); // For uneven split, the last dataloader (partial dataset) will have the remaining item assert_eq!(iterator_2.next(), Some(device2)); assert_eq!(iterator_2.next(), None); } } ================================================ FILE: crates/burn-core/src/data/dataloader/mod.rs ================================================ mod base; mod batch; mod builder; mod multithread; mod strategy; /// Module for batching items. pub mod batcher; /// Module to split a dataloader. pub mod split; pub use base::*; pub use batch::*; pub use builder::*; pub use multithread::*; pub use strategy::*; ================================================ FILE: crates/burn-core/src/data/dataloader/multithread.rs ================================================ use burn_dataset::Dataset; use burn_dataset::transform::PartialDataset; use burn_tensor::backend::Backend; use rand::distr::{Distribution, StandardUniform}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use super::batcher::Batcher; use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress}; use std::sync::{Arc, OnceLock, mpsc}; use std::thread; const MAX_QUEUED_ITEMS: usize = 100; type RngSeed = ::Seed; /// A multi-threaded data loader that can be used to iterate over a dataset. pub struct MultiThreadDataLoader { // Configuration parameters needed for initialization strategy: Box>, dataset: Arc>, batcher: Arc>, device: B::Device, seed: Option, num_threads: usize, // The lazily initialized data loaders dataloaders: OnceLock>>, } /// A message that can be sent between threads. #[derive(Debug)] pub enum Message { /// A batch of items. Batch(usize, O, Progress), /// The thread is done. Done, } struct MultiThreadsDataloaderIterator { num_done: usize, workers: Vec>, receiver: mpsc::Receiver>, progresses: Vec, } impl MultiThreadDataLoader where I: Send + Sync + Clone + 'static, O: Send + 'static, { /// Creates a new multi-threaded batch data loader. /// /// # Arguments /// /// * `strategy` - The batch strategy. /// * `dataset` - The dataset. /// * `batcher` - The batcher. /// * `num_threads` - The number of threads. /// * `device` - The device to use when loading a batch. /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader /// iterator is created. /// /// # Returns /// /// The multi-threaded batch data loader. pub fn new( strategy: Box>, dataset: Arc>, batcher: Arc>, num_threads: usize, device: B::Device, rng: Option, ) -> Self { let mut seed = None; if let Some(mut rng) = rng { // RNG stream splitting (not state cloning): derive a new seed from the RNG's output. // This is exactly what `rng.fork()` does. let mut s = RngSeed::default(); rng.fill_bytes(&mut s); seed = Some(s); } Self::from_seed(strategy, dataset, batcher, num_threads, device, seed) } fn from_seed( strategy: Box>, dataset: Arc>, batcher: Arc>, num_threads: usize, device: B::Device, seed: Option, ) -> Self { Self { strategy, dataset, batcher, num_threads, device, seed, dataloaders: OnceLock::new(), } } /// Force initialization if needed. fn initialize(&self) -> &[BatchDataLoader] { self.dataloaders .get_or_init(|| { let mut dataset = self.dataset.clone(); if let Some(seed) = self.seed.as_ref() { // Pre-shuffle the dataset before split if shuffle is enabled. // This ensures that each thread gets a uniform random sample of the dataset. let mut rng = StdRng::from_seed(*seed); dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new( dataset, &mut rng, )); } let datasets = match self.strategy.batch_size() { Some(batch_size) => { PartialDataset::split_chunks(dataset, self.num_threads, batch_size) } None => PartialDataset::split(dataset, self.num_threads), }; // Create more rngs from the first one, one for each new dataloader. let mut rng = self.seed.map(StdRng::from_seed); let rngs = (0..self.num_threads).map(|_| { rng.as_mut().map(|rng| { StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng)) }) }); datasets .into_iter() .zip(rngs) .map(|(dataset, rng)| { let strategy = self.strategy.clone_dyn(); BatchDataLoader::new( strategy, Arc::new(dataset), self.batcher.clone(), self.device.clone(), rng, ) }) .collect() }) .as_ref() } } impl DataLoader for MultiThreadDataLoader where I: Send + Sync + Clone + 'static, O: Send + 'static + std::fmt::Debug, { fn iter<'a>(&'a self) -> Box + 'a> { // This will initialize the loader if it hasn't been initialized yet let dataloaders = self.initialize(); let (sender, receiver) = mpsc::sync_channel::>(MAX_QUEUED_ITEMS); let mut progresses = Vec::with_capacity(dataloaders.len()); let handlers: Vec<_> = dataloaders .iter() .enumerate() .map(|(index, dataloader)| { let dataloader_cloned = dataloader.clone(); let sender_cloned = sender.clone(); progresses.push(Progress::new(0, dataloader_cloned.num_items())); std::thread::Builder::new() .name(std::format!("dataloader-{index}")) .spawn(move || { let mut iterator = dataloader_cloned.iter(); while let Some(item) = iterator.next() { let progress = iterator.progress(); match sender_cloned.send(Message::Batch(index, item, progress)) { Ok(_) => {} // The receiver is probably gone, no need to panic, just need to stop // iterating. Err(_) => return, }; } // Same thing. sender_cloned.send(Message::Done).ok(); }) .unwrap() }) .collect(); Box::new(MultiThreadsDataloaderIterator::new( receiver, handlers, progresses, )) } fn num_items(&self) -> usize { // For num_items, we can directly use the dataset size without // necessarily initializing the full loader self.dataset.len() } fn to_device(&self, device: &B::Device) -> Arc> { Arc::new(Self::from_seed( self.strategy.clone_dyn(), self.dataset.clone(), self.batcher.clone(), self.num_threads, device.clone(), self.seed, )) } fn slice(&self, start: usize, end: usize) -> Arc> { let dataloader = Self::from_seed( self.strategy.clone_dyn(), Arc::new(PartialDataset::new(self.dataset.clone(), start, end)), self.batcher.clone(), self.num_threads, self.device.clone(), self.seed, ); Arc::new(dataloader) } } impl MultiThreadsDataloaderIterator { pub fn new( receiver: mpsc::Receiver>, workers: Vec>, progresses: Vec, ) -> Self { MultiThreadsDataloaderIterator { num_done: 0, workers, receiver, progresses, } } } impl DataLoaderIterator for MultiThreadsDataloaderIterator { fn progress(&self) -> Progress { let mut items_total = 0; let mut items_processed = 0; for progress in self.progresses.iter() { items_total += progress.items_total; items_processed += progress.items_processed; } Progress::new(items_processed, items_total) } } impl Iterator for MultiThreadsDataloaderIterator { type Item = O; fn next(&mut self) -> Option { if self.workers.is_empty() { return None; } loop { let item = self.receiver.recv(); let item = item.unwrap(); match item { Message::Batch(index, item, progress) => { if let Some(current) = self.progresses.get_mut(index) { *current = progress; } return Some(item); } Message::Done => { self.num_done += 1; } }; if self.num_done == self.workers.len() { while let Some(worker) = self.workers.pop() { worker.join().unwrap(); } return None; } } } } #[cfg(test)] mod tests { use super::*; use crate::data::dataloader::FixBatchStrategy; use crate::data::dataloader::batcher::TestBatcher; use crate::data::dataset::FakeDataset; use burn_dataset::InMemDataset; use std::collections::HashSet; #[test] fn test_multi_thread_batch_dataloader() { let batcher = Arc::new(TestBatcher::new()); let dataset = Arc::new(FakeDataset::::new(27)); let dataloader_single_thread = BatchDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset.clone(), batcher.clone(), Default::default(), None, ); let dataloader_multi_thread = MultiThreadDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset, batcher, 4, Default::default(), None, ); let mut items_single_thread = HashSet::new(); let mut items_multi_thread = HashSet::new(); for items in dataloader_single_thread.iter() { for item in items { items_single_thread.insert(item); } } for items in dataloader_multi_thread.iter() { for item in items { items_multi_thread.insert(item); } } assert_eq!(items_single_thread, items_multi_thread); } #[test] fn test_multi_thread_batch_dataloader_shuffle() { let num_classes = 2; let class_size = 100; let batch_size = 10; // Items is a deliberately ordered dataset. let mut items = Vec::new(); for class in 0..num_classes { items.extend(vec![class; class_size]); } { // Unshuffled multithreaded loader let dataset = Arc::new(InMemDataset::new(items.clone())); let batcher = Arc::new(TestBatcher::new()); let loader = MultiThreadDataLoader::new( Box::new(FixBatchStrategy::new(batch_size)), dataset, batcher, num_classes, Default::default(), // No rng means no shuffling. None, ); for batch in loader.iter() { let mut batch_items = HashSet::new(); for item in batch { batch_items.insert(item); } // Since the dataset is not shuffled, we expect each batch to contain the same item. assert_eq!(batch_items.len(), 1); } } { // Shuffled multithreaded loader let dataset = Arc::new(InMemDataset::new(items.clone())); let batcher = Arc::new(TestBatcher::new()); let loader = MultiThreadDataLoader::new( Box::new(FixBatchStrategy::new(batch_size)), dataset.clone(), batcher.clone(), num_classes, Default::default(), // The rng enables shuffling. Some(StdRng::seed_from_u64(42)), ); for batch in loader.iter() { let mut batch_items = HashSet::new(); for item in batch { batch_items.insert(item); } // Since the dataset is shuffled, we expect to see all items. assert_eq!(batch_items.len(), num_classes); } } } #[test] fn test_multi_thread_batch_dataloader_incomplete_batches() { let batcher = Arc::new(TestBatcher::new()); let dataset = Arc::new(FakeDataset::::new(27)); let dataloader_single_thread = BatchDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset.clone(), batcher.clone(), Default::default(), None, ); let dataloader_multi_thread = MultiThreadDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset, batcher, 4, Default::default(), None, ); let mut items_single_thread = HashSet::new(); let mut items_multi_thread = HashSet::new(); let mut single_thread_cnt = 0; let mut multi_thread_cnt = 0; for items in dataloader_single_thread.iter() { items_single_thread.insert(items); single_thread_cnt += 1; } for items in dataloader_multi_thread.iter() { items_multi_thread.insert(items); multi_thread_cnt += 1; } assert_eq!(single_thread_cnt, multi_thread_cnt); assert_eq!(items_single_thread, items_multi_thread); } } ================================================ FILE: crates/burn-core/src/data/dataloader/split.rs ================================================ use std::sync::Arc; use burn_tensor::backend::Backend; use super::DataLoader; /// Splits a dataloader into multiple partial dataloaders (one per device). pub fn split_dataloader( dataloader: Arc>, devices: &[B::Device], ) -> Vec>> { let num_splits = devices.len(); if num_splits > 1 { let num_items = dataloader.num_items(); let mut dataloaders = Vec::with_capacity(num_splits); let mut start = 0; let step = num_items / num_splits; for (i, device) in devices.iter().enumerate() { let end = if i == (num_splits - 1) { num_items } else { start + step }; let dataloader = dataloader.slice(start, end).to_device(device); dataloaders.push(dataloader); start = end; } dataloaders } else { vec![dataloader] } } #[cfg(test)] mod tests { use std::collections::HashSet; use super::*; use crate::TestBackend; use crate::data::dataloader::batcher::Batcher; use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy}; use crate::data::dataset::FakeDataset; #[test] fn test_split_batch_dataloader() { type TestDevice = ::Device; #[derive(new, Clone)] pub struct TestBatcher; #[cfg(test)] impl Batcher, TestDevice)> for TestBatcher { fn batch(&self, items: Vec, device: &TestDevice) -> (Vec, TestDevice) { (items, *device) } } let batcher = Arc::new(TestBatcher::new()); let dataset = Arc::new(FakeDataset::::new(11)); #[allow(clippy::arc_with_non_send_sync)] let dataloader = Arc::new(BatchDataLoader::new( Box::new(FixBatchStrategy::new(5)), dataset.clone(), batcher, Default::default(), None, )); #[cfg(all( test, not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda") ))] // Only one device exists... let (device1, device2) = ( burn_ndarray::NdArrayDevice::Cpu, burn_ndarray::NdArrayDevice::Cpu, ); #[cfg(all(test, feature = "test-tch"))] let (device1, device2) = ( burn_tch::LibTorchDevice::Cuda(0), burn_tch::LibTorchDevice::Cuda(1), ); #[cfg(all(test, feature = "test-wgpu"))] let (device1, device2) = ( burn_wgpu::WgpuDevice::DiscreteGpu(0), burn_wgpu::WgpuDevice::DiscreteGpu(1), ); #[cfg(all(test, feature = "test-cuda"))] let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1)); let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]); assert_eq!(dataloaders.len(), 2); let [dataloader_1, dataloader_2] = match dataloaders.try_into() { Ok(arr) => arr, Err(_) => unreachable!(), }; assert_eq!(dataloader_1.num_items(), 5); assert_eq!(dataloader_2.num_items(), 6); let mut items_dataloader = HashSet::new(); let mut items_dataloader_split = HashSet::new(); for (items, _device) in dataloader.iter() { for item in items { items_dataloader.insert(item); } } for (items, device) in dataloader_1.iter() { assert_eq!(device, device1); for item in items { items_dataloader_split.insert(item); } } for (items, device) in dataloader_2.iter() { assert_eq!(device, device2); for item in items { items_dataloader_split.insert(item); } } assert_eq!(items_dataloader, items_dataloader_split); } } ================================================ FILE: crates/burn-core/src/data/dataloader/strategy.rs ================================================ /// A strategy to batch items. pub trait BatchStrategy: Send + Sync { /// Adds an item to the strategy. /// /// # Arguments /// /// * `item` - The item to add. fn add(&mut self, item: I); /// Batches the items. /// /// # Arguments /// /// * `force` - Whether to force batching. /// /// # Returns /// /// The batched items. fn batch(&mut self, force: bool) -> Option>; /// Creates a new strategy of the same type. /// /// # Returns /// /// The new strategy. fn clone_dyn(&self) -> Box>; /// Returns the expected batch size for this strategy. /// /// # Returns /// /// The batch size, or None if the strategy doesn't have a fixed batch size. fn batch_size(&self) -> Option; } /// A strategy to batch items with a fixed batch size. pub struct FixBatchStrategy { items: Vec, batch_size: usize, } impl FixBatchStrategy { /// Creates a new strategy to batch items with a fixed batch size. /// /// # Arguments /// /// * `batch_size` - The batch size. /// /// # Returns /// /// The strategy. pub fn new(batch_size: usize) -> Self { FixBatchStrategy { items: Vec::with_capacity(batch_size), batch_size, } } } impl BatchStrategy for FixBatchStrategy { fn add(&mut self, item: I) { self.items.push(item); } fn batch(&mut self, force: bool) -> Option> { if self.items.len() < self.batch_size && !force { return None; } let mut items = Vec::with_capacity(self.batch_size); std::mem::swap(&mut items, &mut self.items); if items.is_empty() { return None; } Some(items) } fn clone_dyn(&self) -> Box> { Box::new(Self::new(self.batch_size)) } fn batch_size(&self) -> Option { Some(self.batch_size) } } ================================================ FILE: crates/burn-core/src/data/mod.rs ================================================ /// Dataloader module. #[cfg(feature = "dataset")] pub mod dataloader; /// Dataset module. #[cfg(feature = "dataset")] pub mod dataset { pub use burn_dataset::*; } /// Network module. #[cfg(feature = "network")] pub mod network { pub use burn_std::network::*; } ================================================ FILE: crates/burn-core/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![recursion_limit = "135"] //! The core crate of Burn. #[macro_use] extern crate derive_new; /// Re-export serde for proc macros. pub use serde; /// The configuration module. pub mod config; /// Data module. #[cfg(feature = "std")] pub mod data; /// Module for the neural network module. pub mod module; /// Module for the recorder. pub mod record; /// Module for the tensor. pub mod tensor; // Tensor at root: `burn::Tensor` pub use tensor::Tensor; /// Module for visual operations #[cfg(feature = "vision")] pub mod vision; extern crate alloc; /// Backend for test cases #[cfg(all( test, not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda"), not(feature = "test-rocm") ))] pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] /// Backend for test cases pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] /// Backend for test cases pub type TestBackend = burn_wgpu::Wgpu; #[cfg(all(test, feature = "test-cuda"))] /// Backend for test cases pub type TestBackend = burn_cuda::Cuda; #[cfg(all(test, feature = "test-rocm"))] /// Backend for test cases pub type TestBackend = burn_rocm::Rocm; /// Backend for autodiff test cases #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[cfg(all(test, feature = "test-memory-checks"))] mod tests { burn_fusion::memory_checks!(); } #[cfg(test)] mod test_utils { use crate as burn; use crate::module::Module; use crate::module::Param; use burn_tensor::Tensor; use burn_tensor::backend::Backend; /// Simple linear module. #[derive(Module, Debug)] pub struct SimpleLinear { pub weight: Param>, pub bias: Option>>, } impl SimpleLinear { pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self { let weight = Tensor::random( [out_features, in_features], burn_tensor::Distribution::Default, device, ); let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device); Self { weight: Param::from_tensor(weight), bias: Some(Param::from_tensor(bias)), } } } } pub mod prelude { //! Structs and macros used by most projects. Add `use //! burn::prelude::*` to your code to quickly get started with //! Burn. pub use crate::{ config::Config, module::Module, tensor::{ Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData, backend::Backend, cast::ToElement, s, }, }; pub use burn_std::device::Device as DeviceOps; } ================================================ FILE: crates/burn-core/src/module/base.rs ================================================ use super::{Param, ParamId, Quantizer}; use crate::{ record::Record, tensor::backend::{AutodiffBackend, Backend}, }; use alloc::{string::String, vec::Vec}; pub use burn_derive::Module; use burn_tensor::{Bool, Int, Tensor, ops::Device}; /// Type alias to `Vec` which supports `no_std` environments, but automatically using /// the `alloc` crate. pub type Devices = Vec>; // At the moment, our plan is to continue experimenting with the macro internally and monitor its development. // We may consider making it public in the future. macro_rules! module { (map=$module:ident, ops=$item:expr) => {{ struct Mapper; impl ModuleMapper for Mapper { fn map_float( &mut self, param: Param>, ) -> Param> { let (id, tensor, mapper) = param.consume(); let func = $item; let tensor = func(tensor); Param::from_mapped_value(id, tensor, mapper) } } let mut mapper = Mapper; $module.map(&mut mapper) }}; (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{ struct Visitor<'a, B: Backend> { state: &'a mut $state_ty, backend: core::marker::PhantomData, } impl<'a, B: Backend> ModuleVisitor for Visitor<'a, B> { fn visit_float(&mut self, param: &Param>) { let func = $item; func(¶m.val(), &mut self.state) } } #[allow(clippy::redundant_closure_call)] let mut state = $init(); let mut visitor = Visitor { state: &mut state, backend: core::marker::PhantomData, }; $module.visit(&mut visitor); state }}; } /// Trait for all neural network modules. /// /// Modules should be created using the [derive](burn_derive::Module) attribute. /// This will make your module trainable, savable and loadable via /// `state` and `load`. /// /// # Example /// /// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic /// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code /// necessary to optimize and train the module on any backend. /// /// ```rust, ignore /// // Not necessary when using the burn crate directly. /// use burn_core as burn; /// /// use burn::{ /// module::Module, /// nn::Linear, /// tensor::Tensor, /// tensor::backend::Backend, /// }; /// /// #[derive(Module, Debug)] /// struct MyModule { /// my_param: Linear, /// my_other_field: usize, /// } /// ``` pub trait Module: Clone + Send + core::fmt::Debug { /// Type to save and load the module. type Record: Record; /// Return all the devices found in the underneath module tree added to the given vector /// without duplicates. fn collect_devices(&self, devices: Devices) -> Devices; /// Return all the devices found in the underneath module tree without duplicates. fn devices(&self) -> Devices { self.collect_devices(Devices::::new()) } /// Fork the module and all of its sub-modules to the given device. /// /// # Notes /// /// This is similar to [to_device](Module::to_device), but it ensures the output module on the /// new device will have its own autodiff graph. fn fork(self, device: &B::Device) -> Self; /// Move the module and all of its sub-modules to the given device. /// /// # Warnings /// /// The operation supports autodiff and it will be registered when activated. However, this may /// not be what you want. The output model will be an intermediary model, meaning that you /// can't optimize it with gradient descent. If you want to optimize the output network on the /// target device, use [fork](Module::fork) instead. fn to_device(self, device: &B::Device) -> Self; /// Each tensor in the module tree will not require grad. /// /// # Warnings /// /// This should not be used for inference, use [valid](AutodiffModule::valid) when using /// AD modules. This is mostly useful when performing partial finetuning, which is updating only /// a small fraction of the parameters instead of finetuning all of them. fn no_grad(self) -> Self { module!( map = self, ops = |tensor: Tensor| tensor.set_require_grad(false) ) } /// Move the module and all of its sub-modules to the autodiff backend. /// /// # Notes /// /// * Only plain modules (not already on an autodiff backend) can be moved. /// * Calling `train()` on a module that is already on an autodiff backend /// will result in a type error, because the module's inner backend does not match. fn train(self) -> >::TrainModule where AB: AutodiffBackend, Self: HasAutodiffModule, { >::TrainModule::from_inner(self) } /// Get the number of parameters the module has, including all of its sub-modules. fn num_params(&self) -> usize { module!( visit_float = self, ops = |tensor: &Tensor, state: &mut usize| { *state += tensor.shape().num_elements(); }, state = usize, init = || 0 ) } /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor). fn visit>(&self, visitor: &mut Visitor); /// Map each tensor parameter in the module with a [mapper](ModuleMapper). fn map>(self, mapper: &mut Mapper) -> Self; /// Load the module state from a record. fn load_record(self, record: Self::Record) -> Self; /// Convert the module into a record containing the state. fn into_record(self) -> Self::Record; #[cfg(feature = "std")] /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder). /// /// List of supported file recorders: /// /// * [default](crate::record::DefaultFileRecorder) /// * [bincode](crate::record::BinFileRecorder) /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder) /// * [json pretty](crate::record::PrettyJsonFileRecorder) /// * [json compressed with gzip](crate::record::JsonGzFileRecorder) /// * [named mpk](crate::record::NamedMpkFileRecorder) /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder) /// /// ## Notes /// /// The file extension is automatically added depending on the file recorder provided, you /// don't have to specify it. fn save_file( self, file_path: PB, recorder: &FR, ) -> Result<(), crate::record::RecorderError> where FR: crate::record::FileRecorder, PB: Into, { let record = Self::into_record(self); recorder.record(record, file_path.into()) } #[cfg(feature = "std")] /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). /// /// The recorder should be the same as the one used to save the module, see /// [save_file](Self::save_file). /// /// ## Notes /// /// The file extension is automatically added depending on the file recorder provided, you /// don't have to specify it. fn load_file( self, file_path: PB, recorder: &FR, device: &B::Device, ) -> Result where FR: crate::record::FileRecorder, PB: Into, { let record = recorder.load(file_path.into(), device)?; Ok(self.load_record(record)) } /// Quantize the weights of the module. fn quantize_weights(self, quantizer: &mut Quantizer) -> Self { self.map(quantizer) } } /// Module visitor trait for traversing and inspecting module parameters. pub trait ModuleVisitor { /// Visit a float parameter in the module. /// /// # Parameters /// - `param`: The float parameter to visit #[allow(unused_variables)] fn visit_float(&mut self, param: &Param>) {} /// Visit an int parameter in the module. /// /// # Parameters /// - `param`: The integer parameter to visit #[allow(unused_variables)] fn visit_int(&mut self, param: &Param>) {} /// Visit a bool parameter in the module. /// /// # Parameters /// - `param`: The boolean parameter to visit #[allow(unused_variables)] fn visit_bool(&mut self, param: &Param>) {} /// Called when entering a submodule. /// /// # Parameters /// - `name`: The name of the submodule being entered /// - `container_type`: The type of the container with format: /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear") /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum") /// - For Vec containers: "Vec" (name is the index) /// - For Tuple containers: "Tuple" (name is the index) /// - For Array containers: "Array" (name is the index) /// /// Note: Option containers do not call enter_module/exit_module to preserve /// the field name in the path (e.g., "bias" instead of "bias.Some") #[allow(unused_variables)] fn enter_module(&mut self, name: &str, container_type: &str) {} /// Called when exiting a submodule. /// /// # Parameters /// - `name`: The name of the submodule being exited /// - `container_type`: The type of the container with format: /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear") /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum") /// - For Vec containers: "Vec" (name is the index) /// - For Tuple containers: "Tuple" (name is the index) /// - For Array containers: "Array" (name is the index) /// /// Note: Option containers do not call enter_module/exit_module to preserve /// the field name in the path (e.g., "bias" instead of "bias.Some") #[allow(unused_variables)] fn exit_module(&mut self, name: &str, container_type: &str) {} /// Visit a float tensor with its full module path. /// /// # Parameters /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]). /// Each element represents a module name in the hierarchy, with the final element /// being the parameter name. This allows efficient reuse of the path stack. /// - `id`: The unique identifier of the parameter /// - `tensor`: The float tensor to visit #[allow(unused_variables)] fn visit_float_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { } /// Visit an int tensor with its full module path. /// /// # Parameters /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]). /// Each element represents a module name in the hierarchy, with the final element /// being the parameter name. This allows efficient reuse of the path stack. /// - `id`: The unique identifier of the parameter /// - `tensor`: The integer tensor to visit #[allow(unused_variables)] fn visit_int_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { } /// Visit a bool tensor with its full module path. /// /// # Parameters /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]). /// Each element represents a module name in the hierarchy, with the final element /// being the parameter name. This allows efficient reuse of the path stack. /// - `id`: The unique identifier of the parameter /// - `tensor`: The boolean tensor to visit #[allow(unused_variables)] fn visit_bool_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { } } /// Module mapper trait for transforming module parameters. pub trait ModuleMapper { /// Called when entering a submodule. /// /// # Parameters /// - `name`: The name of the submodule being entered /// - `container_type`: The type of the container with format: /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear") /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum") /// - For Vec containers: "Vec" (name is the index) /// - For Tuple containers: "Tuple" (name is the index) /// - For Array containers: "Array" (name is the index) /// /// Note: Option containers do not call enter_module/exit_module to preserve /// the field name in the path (e.g., "bias" instead of "bias.Some") #[allow(unused_variables)] fn enter_module(&mut self, name: &str, container_type: &str) {} /// Called when exiting a submodule. /// /// # Parameters /// - `name`: The name of the submodule being exited /// - `container_type`: The type of the container with format: /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear") /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum") /// - For Vec containers: "Vec" (name is the index) /// - For Tuple containers: "Tuple" (name is the index) /// - For Array containers: "Array" (name is the index) /// /// Note: Option containers do not call enter_module/exit_module to preserve /// the field name in the path (e.g., "bias" instead of "bias.Some") #[allow(unused_variables)] fn exit_module(&mut self, name: &str, container_type: &str) {} /// Map a float parameter in the module. /// /// # Parameters /// - `param`: The float parameter to transform /// /// # Returns /// The transformed parameter #[allow(unused_variables)] fn map_float(&mut self, param: Param>) -> Param> { let (id, tensor, mapper) = param.consume(); Param::from_mapped_value(id, tensor, mapper) } /// Map an int parameter in the module. /// /// # Parameters /// - `param`: The integer parameter to transform /// /// # Returns /// The transformed parameter #[allow(unused_variables)] fn map_int( &mut self, param: Param>, ) -> Param> { let (id, tensor, mapper) = param.consume(); Param::from_mapped_value(id, tensor, mapper) } /// Map a bool parameter in the module. /// /// # Parameters /// - `param`: The boolean parameter to transform /// /// # Returns /// The transformed parameter #[allow(unused_variables)] fn map_bool( &mut self, param: Param>, ) -> Param> { let (id, tensor, mapper) = param.consume(); Param::from_mapped_value(id, tensor, mapper) } } /// Module with auto-differentiation backend. pub trait AutodiffModule: Module + Send + core::fmt::Debug { /// Inner module without auto-differentiation. type InnerModule: Module; /// Returns the same module, but on the inner backend without auto-differentiation. fn valid(&self) -> Self::InnerModule; /// Wraps an inner module back into an auto-diff module. fn from_inner(module: Self::InnerModule) -> Self; } /// Helper trait to associate a module with its autodiff version. pub trait HasAutodiffModule { /// The module with auto-differentiation. type TrainModule: AutodiffModule; } #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use crate::test_utils::SimpleLinear; #[test] fn test_module_val_train_stateful() { let device = Default::default(); let module = SimpleLinear::::new(4, 4, &device); assert!(module.weight.is_require_grad()); assert!(module.weight.require_grad); let module = module.valid(); assert!(!module.weight.is_require_grad()); assert!(module.weight.require_grad); // stateful // Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying // let module: SimpleLinear = module.train(); let module = module.train::(); assert!(module.weight.is_require_grad()); assert!(module.weight.require_grad); // stateful let module = module.no_grad(); assert!(!module.weight.is_require_grad()); assert!(!module.weight.require_grad); // stateful let module = module.valid(); assert!(!module.weight.is_require_grad()); // always assert!(!module.weight.require_grad); // stateful let module = module.train::(); assert!(!module.weight.is_require_grad()); assert!(!module.weight.require_grad); // stateful } } ================================================ FILE: crates/burn-core/src/module/display.rs ================================================ use alloc::{ borrow::ToOwned, format, string::{String, ToString}, vec::Vec, }; use core::any; use core::fmt::{Debug, Display, Write}; /// Default display settings for a module. pub trait ModuleDisplayDefault { /// Attributes of the module used for display purposes. /// /// # Arguments /// /// * `_content` - The content object that contains display settings and attributes. /// /// # Returns /// /// An optional content object containing the display attributes. fn content(&self, _content: Content) -> Option; /// Gets the number of the parameters of the module. fn num_params(&self) -> usize { 0 } } /// Trait to implement custom display settings for a module. /// /// In order to implement custom display settings for a module, /// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)] /// 2. Implement ModuleDisplay trait for the module pub trait ModuleDisplay: ModuleDisplayDefault { /// Formats the module with provided display settings. /// /// # Arguments /// /// * `passed_settings` - Display settings passed to the module. /// /// # Returns /// /// A string representation of the formatted module. fn format(&self, passed_settings: DisplaySettings) -> String { let settings = if let Some(custom_settings) = self.custom_settings() { custom_settings.inherit(passed_settings) } else { passed_settings }; let indent = " ".repeat(settings.level * settings.indentation_size()); let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size()); let settings = settings.level_up(); let self_type = extract_type_name::(); // Use custom content if it is implemented and show_all_attributes is false, // otherwise use default content let content = if !settings.show_all_attributes() { self.custom_content(Content::new(settings.clone())) .unwrap_or_else(|| { self.content(Content::new(settings.clone())) .unwrap_or_else(|| { panic!("Default content should be implemented for {self_type}.") }) }) } else { self.content(Content::new(settings.clone())) .unwrap_or_else(|| panic!("Default content should be implemented for {self_type}.")) }; let top_level_type = if let Some(top_level_type) = content.top_level_type { top_level_type.to_owned() } else { self_type.to_owned() }; // If there is only one item in the content, return it or no attributes if let Some(item) = content.single_item { return item; } else if content.attributes.is_empty() { return top_level_type.to_string(); } let mut result = String::new(); // Print the struct name if settings.new_line_after_attribute() { writeln!(result, "{top_level_type} {{").unwrap(); } else { write!(result, "{top_level_type} {{").unwrap(); } for (i, attribute) in content.attributes.iter().enumerate() { if settings.new_line_after_attribute() { writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap(); } else if i == 0 { write!(result, "{}: {}", attribute.name, attribute.value).unwrap(); } else { write!(result, ", {}: {}", attribute.name, attribute.value).unwrap(); } } if settings.show_num_parameters() { let num_params = self.num_params(); if num_params > 0 { if settings.new_line_after_attribute() { writeln!(result, "{indent}params: {num_params}").unwrap(); } else { write!(result, ", params: {num_params}").unwrap(); } } } if settings.new_line_after_attribute() { write!(result, "{indent_close_braces}}}").unwrap(); } else { write!(result, "}}").unwrap(); } result } /// Custom display settings for the module. /// /// # Returns /// /// An optional display settings object. fn custom_settings(&self) -> Option { None } /// Custom attributes for the module. /// /// # Arguments /// /// * `_content` - The content object that contains display settings and attributes. /// /// # Returns /// /// An optional content object containing the custom attributes. fn custom_content(&self, _content: Content) -> Option { None } } /// Custom module display settings. #[derive(Debug, Clone)] pub struct DisplaySettings { /// Whether to print the module parameter ids. show_param_id: Option, /// Whether to print the module attributes. show_all_attributes: Option, /// Whether to print the module number of parameters. show_num_parameters: Option, /// Print new line after an attribute. new_line_after_attribute: Option, /// Indentation size. indentation_size: Option, /// Level of indentation. level: usize, } impl Default for DisplaySettings { fn default() -> Self { DisplaySettings { show_param_id: None, show_all_attributes: None, show_num_parameters: None, new_line_after_attribute: None, indentation_size: None, level: 1, } } } impl DisplaySettings { /// Create a new format settings. /// /// # Returns /// /// A new instance of `DisplaySettings`. pub fn new() -> Self { Default::default() } /// Sets a flag to show module parameters. /// /// # Arguments /// /// * `flag` - Boolean flag to show module parameters. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn with_show_param_id(mut self, flag: bool) -> Self { self.show_param_id = Some(flag); self } /// Sets a flag to show module attributes. /// /// # Arguments /// /// * `flag` - Boolean flag to show all module attributes. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn with_show_all_attributes(mut self, flag: bool) -> Self { self.show_all_attributes = Some(flag); self } /// Sets a flag to show the number of module parameters. /// /// # Arguments /// /// * `flag` - Boolean flag to show the number of module parameters. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn with_show_num_parameters(mut self, flag: bool) -> Self { self.show_num_parameters = Some(flag); self } /// Sets a flag to print a new line after an attribute. /// /// # Arguments /// /// * `flag` - Boolean flag to print a new line after an attribute. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self { self.new_line_after_attribute = Some(flag); self } /// Sets the indentation size. /// /// # Arguments /// /// * `size` - The size of the indentation. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn with_indentation_size(mut self, size: usize) -> Self { self.indentation_size = Some(size); self } /// Inherits settings from the provided settings and return a new settings object. /// /// # Arguments /// /// * `top` - The top level `DisplaySettings` to inherit from. /// /// # Returns /// /// Updated `DisplaySettings` instance. pub fn inherit(self, top: Self) -> Self { let mut updated = self.clone(); if let Some(show_param_id) = top.show_param_id { updated.show_param_id = Some(show_param_id); }; if let Some(show_all_attributes) = top.show_all_attributes { updated.show_all_attributes = Some(show_all_attributes); } if let Some(show_num_parameters) = top.show_num_parameters { updated.show_num_parameters = Some(show_num_parameters); } if let Some(new_line_after_attribute) = top.new_line_after_attribute { updated.new_line_after_attribute = Some(new_line_after_attribute); } if let Some(indentation_size) = top.indentation_size { updated.indentation_size = Some(indentation_size); } updated.level = top.level; updated } /// A convenience method to wrap the DisplaySettings struct in an option. /// /// # Returns /// /// An optional `DisplaySettings`. pub fn optional(self) -> Option { Some(self) } /// Increases the level of indentation. /// /// # Returns /// /// Updated `DisplaySettings` instance with increased indentation level. pub fn level_up(mut self) -> Self { self.level += 1; self } /// Gets `show_param_id` flag, substitutes false if not set. /// /// This flag is used to print the module parameter ids. /// /// # Returns /// /// A boolean value indicating whether to show parameter ids. pub fn show_param_id(&self) -> bool { self.show_param_id.unwrap_or(false) } /// Gets `show_all_attributes`, substitutes false if not set. /// /// This flag is used to force to print all module attributes, overriding custom attributes. /// /// # Returns /// /// A boolean value indicating whether to show all attributes. pub fn show_all_attributes(&self) -> bool { self.show_all_attributes.unwrap_or(false) } /// Gets `show_num_parameters`, substitutes true if not set. /// /// This flag is used to print the number of module parameters. /// /// # Returns /// /// A boolean value indicating whether to show the number of parameters. pub fn show_num_parameters(&self) -> bool { self.show_num_parameters.unwrap_or(true) } /// Gets `new_line_after_attribute`, substitutes true if not set. /// /// This flag is used to print a new line after an attribute. /// /// # Returns /// /// A boolean value indicating whether to print a new line after an attribute. pub fn new_line_after_attribute(&self) -> bool { self.new_line_after_attribute.unwrap_or(true) } /// Gets `indentation_size`, substitutes 2 if not set. /// /// This flag is used to set the size of indentation. /// /// # Returns /// /// An integer value indicating the size of indentation. pub fn indentation_size(&self) -> usize { self.indentation_size.unwrap_or(2) } } /// Struct to store the attributes of a module for formatting. #[derive(Clone, Debug)] pub struct Content { /// List of attributes. pub attributes: Vec, /// Single item content. pub single_item: Option, /// Display settings. pub display_settings: DisplaySettings, /// Top level type name. pub top_level_type: Option, } impl Content { /// Creates a new attributes struct. /// /// # Arguments /// /// * `display_settings` - Display settings for the content. /// /// # Returns /// /// A new instance of `Content`. pub fn new(display_settings: DisplaySettings) -> Self { Content { attributes: Vec::new(), single_item: None, display_settings, top_level_type: None, } } /// Adds an attribute to the format settings. The value will be formatted and stored as a string. /// /// # Arguments /// /// * `name` - Name of the attribute. /// * `value` - Value of the attribute. /// /// # Returns /// /// Updated `Content` instance with the new attribute added. pub fn add(mut self, name: &str, value: &T) -> Self { if self.single_item.is_some() { panic!("Cannot add multiple attributes when single item is set."); } let attribute = Attribute { name: name.to_owned(), value: value.format(self.display_settings.clone()), // TODO level + 1 ty: any::type_name::().to_string(), }; self.attributes.push(attribute); self } /// Adds an attribute using its `Debug` representation. /// /// This is intended for fields that do not implement [`ModuleDisplay`]. /// /// # Arguments /// /// * `name` - Name of the attribute. /// * `value` - Value of the attribute. /// /// # Returns /// /// Updated `Content` instance with the new attribute added. pub fn add_debug_attribute(mut self, name: &str, value: &T) -> Self { if self.single_item.is_some() { panic!("Cannot add multiple attributes when single item is set."); } self.attributes.push(Attribute { name: name.to_owned(), value: DisplayAdapter(value).format(self.display_settings.clone()), ty: any::type_name::().to_string(), }); self } /// Adds a single item. /// /// # Arguments /// /// * `value` - Rendered string of the single item. /// /// # Returns /// /// Updated `Content` instance with the single item added. pub fn add_single(mut self, value: &T) -> Self { if !self.attributes.is_empty() { panic!("Cannot add single item when attributes are set."); } self.single_item = Some(value.format(self.display_settings.clone())); self } /// Adds a single item. /// /// # Arguments /// /// * `value` - Formatted display value. /// /// # Returns /// /// Updated `Content` instance with the formatted single item added. pub fn add_formatted(mut self, value: &T) -> Self { if !self.attributes.is_empty() { panic!("Cannot add single item when attributes are set."); } self.single_item = Some(format!("{value}")); self } /// A convenience method to wrap the Attributes struct in an option /// because it is often used as an optional field. /// /// # Returns /// /// An optional `Content`. pub fn optional(self) -> Option { if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none() { None } else { Some(self) } } /// Sets the top level type name. /// /// # Arguments /// /// * `ty` - The type name to set. /// /// # Returns /// /// Updated `Content` instance with the top level type name set. pub fn set_top_level_type(mut self, ty: &str) -> Self { self.top_level_type = Some(ty.to_owned()); self } } /// Minimal display adapter for non-module types. struct DisplayAdapter<'a, T: Debug>(&'a T); impl<'a, T: Debug> ModuleDisplayDefault for DisplayAdapter<'a, T> { fn content(&self, content: Content) -> Option { content.add_single(&format!("{:?}", self.0)).optional() } } impl<'a, T: Debug> ModuleDisplay for DisplayAdapter<'a, T> {} /// Attribute to print in the display method. #[derive(Clone, Debug)] pub struct Attribute { /// Name of the attribute. pub name: String, /// Value of the attribute. pub value: String, /// Type of the attribute. pub ty: String, } /// Extracts the short name of a type T /// /// # Returns /// /// A string slice representing the short name of the type. pub fn extract_type_name() -> &'static str { // Get the full type name of T, including module path and generic parameters let ty = any::type_name::(); // Find the first occurrence of '<' in the full type name // If not found, use the length of the type name let end = ty.find('<').unwrap_or(ty.len()); // Slice the type name up to the first '<' or the end let ty = &ty[0..end]; // Find the last occurrence of "::" in the sliced type name // If found, add 2 to skip the "::" itself // If not found, start from the beginning of the type name let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0); // Find the last occurrence of '<' in the sliced type name // If not found, use the length of the type name let end = ty.rfind('<').unwrap_or(ty.len()); // If the start index is less than the end index, // return the slice of the type name from start to end // Otherwise, return the entire sliced type name if start < end { &ty[start..end] } else { ty } } ================================================ FILE: crates/burn-core/src/module/initializer.rs ================================================ use crate::tensor::Shape; use crate::config::Config; use crate::module::{Param, ParamId}; use crate::tensor::backend::Backend; use crate::tensor::{Distribution, Tensor, s}; use crate as burn; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Enum specifying with what values a tensor should be initialized #[derive(Config, Debug, PartialEq)] pub enum Initializer { /// Fills tensor with specified value everywhere Constant { /// The value to fill the tensor with value: f64, }, /// Fills tensor with 1s everywhere Ones, /// Fills tensor with 0s everywhere Zeros, /// Fills tensor with values drawn uniformly between specified values Uniform { /// The minimum value to draw from min: f64, /// The maximum value to draw from max: f64, }, /// Fills tensor with values drawn from normal distribution with specified mean and std Normal { /// The mean of the normal distribution mean: f64, /// The standard deviation of the normal distribution std: f64, }, /// Fills tensor with values according to the uniform version of Kaiming initialization KaimingUniform { /// The gain to use in initialization formula gain: f64, /// Whether to use fan out only in initialization formula fan_out_only: bool, }, /// Fills tensor with values according to the uniform version of Kaiming initialization KaimingNormal { /// The gain to use in initialization formula gain: f64, /// Whether to use fan out only in initialization formula fan_out_only: bool, }, /// Fills tensor with values according to the uniform version of Xavier Glorot initialization /// described in [Understanding the difficulty of training deep feedforward neural networks /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) XavierUniform { /// The gain to use in initialization formula gain: f64, }, /// Fills tensor with values according to the normal version of Xavier Glorot initialization /// described in [Understanding the difficulty of training deep feedforward neural networks /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) XavierNormal { /// The gain to use in initialization formula gain: f64, }, /// Fills tensor with values according to the (semi) orthogonal initialization /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks` /// - [Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120) Orthogonal { /// The gain to use in initialization formula gain: f64, }, } impl Initializer { /// Inits a tensor parameter of given shape with values depending on initializer kind. /// /// # Params /// /// - shape: Shape of the initiated tensor. pub fn init>( &self, shape: S, device: &B::Device, ) -> Param> { self.init_with(shape, None, None, device) } /// Inits a tensor parameter of given shape with values depending on initializer kind. /// /// # Params /// /// - shape: Shape of the initiated tensor. pub fn init_with>( &self, shape: S, fan_in: Option, fan_out: Option, device: &B::Device, ) -> Param> { let device = device.clone(); let shape: Shape = shape.into(); let config = self.clone(); let shape_for_closure = shape.clone(); Param::uninitialized( ParamId::new(), move |device, require_grad| { B::memory_persistent_allocations(device, (), move |_| { let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device); if require_grad { tensor = tensor.require_grad(); } tensor }) }, device, true, shape_for_closure, ) } fn init_tensor>( &self, shape: S, fan_in: Option, fan_out: Option, device: &B::Device, ) -> Tensor { let shape = shape.into(); match self { Initializer::Constant { value } => Tensor::::full(shape, *value, device), Initializer::Ones => Tensor::::ones(shape, device), Initializer::Zeros => Tensor::::zeros(shape, device), Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device), Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device), Initializer::KaimingUniform { gain, fan_out_only } => { let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); uniform_draw(shape, -a, a, device) } Initializer::KaimingNormal { gain, fan_out_only } => { let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out); normal_draw(shape, 0.0, std, device) } Initializer::XavierUniform { gain } => { let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out); uniform_draw(shape, -a, a, device) } Initializer::XavierNormal { gain } => { let std = *gain * self.xavier_std(fan_in, fan_out); normal_draw(shape, 0.0, std, device) } Initializer::Orthogonal { gain } => { // following the implementation in pytorch: // https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574 assert!( D >= 2, "Expected D (in Tensor) to be greater or equal 2; (D >= 2)" ); let rows: usize = shape.dims::()[0]; let cols: usize = shape.num_elements() / rows; let mut t: Tensor = normal_draw([rows, cols], 0.0, 1.0, device); if rows < cols { t = t.transpose(); } let (q, r) = qr_decomposition(t, device); let [r_rows, r_cols] = r.clone().dims(); let diag_r = Tensor::::ones([1, r_rows], device) .matmul(Tensor::::eye(r_cols, device).mul(r.clone())); let ph = diag_r.clone().sign(); let mut q = q.mul(ph); if rows < cols { q = q.transpose(); } q.reshape(shape).mul_scalar(*gain) } } } fn kaiming_std( &self, fan_out_only: bool, fan_in: Option, fan_out: Option, ) -> f64 { let fan = if fan_out_only { fan_out } else { fan_in }; let fan = fan.expect( "Can't use Kaiming initialization without specifying fan. Use init_with method.", ); 1.0 / (fan as f64).sqrt() } fn xavier_std(&self, fan_in: Option, fan_out: Option) -> f64 { let fan_in = fan_in.expect( "Can't use Xavier initialization without specifying fan in. Use init_with method and \ provide fan_in.", ); let fan_out = fan_out.expect( "Can't use Xavier initialization without specifying fan out. Use init_with method and \ provide fan_out.", ); (2.0 / (fan_in + fan_out) as f64).sqrt() } } fn uniform_draw>( shape: S, low: f64, high: f64, device: &B::Device, ) -> Tensor { let distribution = Distribution::Uniform(low, high); Tensor::::random(shape, distribution, device) } fn normal_draw>( shape: S, mean: f64, std: f64, device: &B::Device, ) -> Tensor { let distribution = Distribution::Normal(mean, std); Tensor::::random(shape, distribution, device) } fn qr_decomposition( a: Tensor, device: &B::Device, ) -> (Tensor, Tensor) { // Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process let [m, n] = a.clone().dims(); let mut q = Tensor::::zeros([m, n], device); let mut r = Tensor::::zeros([n, n], device); for j in 0..n { let mut v: Tensor = a.clone().slice(s![.., j..=j]).squeeze_dim(1); for i in 0..j { let q_i: Tensor = q.clone().slice(s![.., i..=i]).squeeze_dim(1); let r_ij = q_i.clone().mul(v.clone()).sum(); r = r .clone() .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze()); v = v - q_i.mul(r_ij); } // norm of v let r_jj = v .clone() .powf(Tensor::from_floats([2.0], device)) .sum() .sqrt(); r = r .clone() .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze()); let q_j = v / r_jj; q = q .clone() .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1)); } (q, r) } #[cfg(test)] mod tests { use super::*; use burn_tensor::{ElementConversion, TensorData}; use num_traits::Pow; pub type TB = burn_ndarray::NdArray; use burn_tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor) { let (actual_vars, actual_means) = tensor.clone().var_mean(0); let actual_vars = actual_vars.to_data(); let actual_vars = actual_vars.as_slice::().unwrap(); let actual_means = actual_means.to_data(); let actual_means = actual_means.as_slice::().unwrap(); for i in 0..tensor.shape()[0] { let actual_var = actual_vars[i] as f64; let actual_mean = actual_means[i] as f64; assert!( (expected_var - actual_var).abs() <= 0.1, "Expected variance to be between {expected_var} += 0.1, but got {actual_var}" ); assert!( (expected_mean - actual_mean).abs() <= 0.1, "Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}" ); } } #[test] fn initializer_uniform_init() { let device = Default::default(); TB::seed(&device, 0); let (min, max) = (0.0, 1.0); let uniform = Initializer::Uniform { min, max }; let tensor: Tensor = uniform.init([2, 2, 2, 2], &Default::default()).into_value(); tensor .into_data() .assert_within_range::(min.elem()..max.elem()); } #[test] fn initializer_normal_init() { // seed random generator let device = Default::default(); TB::seed(&device, 0); let (mean, std) = (0.0, 1.0); let normal: Tensor = Initializer::Normal { mean, std } .init([1000], &Default::default()) .into_value(); let (var_act, mean_act) = normal.var_mean(0); let var_act: f32 = var_act.into_scalar().elem(); let mean_act: f32 = mean_act.into_scalar().elem(); assert!( var_act > 0.9 && var_act < 1.1, "Expected variance to be between 1.0 += 0.1, but got {var_act}" ); assert!( mean_act > -0.1 && mean_act < 0.1, "Expected mean to be between 0.0 += 0.1, but got {mean_act}" ); } #[test] fn initializer_constant_init() { let value = 5.0; let constants: Tensor = Initializer::Constant { value } .init([2, 2, 2, 2], &Default::default()) .into_value(); constants.sum().to_data().assert_approx_eq::( &TensorData::from([value as f32 * 16.0]), Tolerance::default(), ); } #[test] fn initializer_zeros_init() { let zeros: Tensor = Initializer::Zeros .init([2, 2, 2, 2], &Default::default()) .into_value(); zeros .sum() .to_data() .assert_approx_eq::(&TensorData::from([0.0]), Tolerance::default()); } #[test] fn initializer_ones_init() { let ones: Tensor = Initializer::Ones .init([2, 2, 2, 2], &Default::default()) .into_value(); ones.sum() .to_data() .assert_approx_eq::(&TensorData::from([16.0]), Tolerance::default()); } #[test] fn initializer_kaiming_uniform_init() { let device = Default::default(); TB::seed(&device, 0); let gain = 2_f64; let (fan_in, fan_out) = (5, 6); let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::(); let tensor: Tensor = Initializer::KaimingUniform { gain, fan_out_only: false, } .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default()) .into_value(); tensor.into_data().assert_within_range(-k..k); } #[test] fn initializer_kaiming_normal_init() { let device = Default::default(); TB::seed(&device, 0); let gain = 2.; let (fan_in, fan_out) = (1000, 10); let expected_mean = 0_f64; let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.); let tensor: Tensor = Initializer::KaimingNormal { gain, fan_out_only: false, } .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default()) .into_value(); assert_normal_init(expected_mean, expected_var, &tensor) } #[test] fn initializer_kaiming_uniform_init_bias() { let device = Default::default(); TB::seed(&device, 0); let gain = 2_f64; let shape = [3]; let fan_in = 5; let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::(); let tensor: Tensor = Initializer::KaimingUniform { gain, fan_out_only: false, } .init_with(shape, Some(fan_in), None, &Default::default()) .into_value(); tensor.into_data().assert_within_range(-k..k); } #[test] fn initializer_kaiming_uniform_init_fan_out() { let device = Default::default(); TB::seed(&device, 0); let gain = 2_f64; let (fan_in, fan_out) = (5, 6); let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::(); let tensor: Tensor = Initializer::KaimingUniform { gain, fan_out_only: true, } .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default()) .into_value(); tensor.into_data().assert_within_range(-k..k); } #[test] #[should_panic] fn initializer_kaiming_uniform_no_fan() { let device = Default::default(); TB::seed(&device, 0); let gain = 2_f64; let (fan_in, fan_out) = (5, 6); let _: Tensor = Initializer::KaimingUniform { gain, fan_out_only: false, } .init([fan_out, fan_in], &Default::default()) .into_value(); } #[test] fn initializer_xavier_uniform_init() { let device = Default::default(); TB::seed(&device, 0); let gain = 2.; let (fan_in, fan_out) = (5, 6); let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::(); let tensor: Tensor = Initializer::XavierUniform { gain } .init_with( [fan_out, fan_in], Some(fan_in), Some(fan_out), &Default::default(), ) .into_value(); tensor.into_data().assert_within_range(-bound..bound); } #[test] fn initializer_xavier_normal_init() { let device = Default::default(); TB::seed(&device, 0); let gain = 2.; let (fan_in, fan_out) = (1000, 10); let expected_mean = 0_f64; let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.); let tensor: Tensor = Initializer::XavierNormal { gain } .init_with( [fan_out, fan_in], Some(fan_in), Some(fan_out), &Default::default(), ) .into_value(); assert_normal_init(expected_mean, expected_var, &tensor) } #[test] #[should_panic] fn initializer_xavier_uniform_no_fan() { let device = Default::default(); TB::seed(&device, 0); let gain = 2.; let (fan_in, fan_out) = (5, 6); let _: Tensor = Initializer::XavierUniform { gain } .init([fan_out, fan_in], &Default::default()) .into_value(); } #[test] fn test_qr_decomposition() { let device = Default::default(); TB::seed(&device, 0); // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr let a = Tensor::::from_floats( [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]], &Default::default(), ); let qr = qr_decomposition(a.clone(), &Default::default()); // Q @ R should reconstruct input `a` let q_matmul_r = qr.0.clone().matmul(qr.1.clone()); // assert that the difference between input (`a`) and Q @ R is (almost) zero q_matmul_r .into_data() .assert_approx_eq::(&a.into_data(), Tolerance::rel_abs(0.1, 0.1)); } #[test] fn initializer_orthogonal_correct() { let device = Default::default(); TB::seed(&device, 0); let gain = 1.; // test 2D tensor let size = 10; let q: Tensor = Initializer::Orthogonal { gain } .init([size, size], &Default::default()) .into_value(); let eye = Tensor::::eye(size, &Default::default()); // Q.T @ Q should be close to identity matrix q.clone() .transpose() .matmul(q) .into_data() .assert_approx_eq::(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1)); } #[test] fn initializer_orthogonal_init() { let device = Default::default(); TB::seed(&device, 0); let gain = 1.; // test 2D tensor let shape = [25, 30]; let t: Tensor = Initializer::Orthogonal { gain } .init(shape, &Default::default()) .into_value(); let dims = t.dims(); assert_eq!( shape, dims, "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})" ); // test 3D tensor let shape = [24, 6, 85]; let t: Tensor = Initializer::Orthogonal { gain } .init(shape, &Default::default()) .into_value(); let dims = t.dims(); assert_eq!( shape, dims, "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})" ); } #[test] #[should_panic] fn initializer_orthogonal_init_1d() { let device = Default::default(); TB::seed(&device, 0); let gain = 1.; // test 1D tensor let shape = [3]; let _: Tensor = Initializer::Orthogonal { gain } .init(shape, &Default::default()) .into_value(); } } ================================================ FILE: crates/burn-core/src/module/mod.rs ================================================ mod base; mod display; mod initializer; mod param; mod quantize; #[cfg(feature = "std")] mod reinit; pub use base::*; pub use display::*; pub use initializer::*; pub use param::*; pub use quantize::*; #[cfg(feature = "std")] pub use reinit::*; ================================================ FILE: crates/burn-core/src/module/param/base.rs ================================================ use super::ParamId; use alloc::{boxed::Box, format}; use burn_std::stub::RwLock; use burn_tensor::Shape; use core::cell::OnceCell; use core::ops::Deref; #[cfg(target_has_atomic = "ptr")] use alloc::sync::Arc; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; #[cfg(target_has_atomic = "ptr")] type Mapper = Arc T + Send + Sync>; #[cfg(not(target_has_atomic = "ptr"))] type Mapper = Arc T + Send + Sync>>; #[cfg(target_has_atomic = "ptr")] fn new_mapper T + Send + Sync + 'static>(func: F) -> Mapper { Arc::new(func) } #[cfg(not(target_has_atomic = "ptr"))] fn new_mapper T + Send + Sync + 'static>(func: F) -> Mapper { Arc::new(Box::new(func)) } /// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they /// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during /// training, and loaded during inference. If you don't want to save the tensors /// and/or don't want to update it during training, you don't need this type to wrap your tensor. /// /// # Core Lazy Initialization Architecture /// /// `Param` has a dual-state design using `OnceCell`: /// /// ## State Management /// /// **Two possible states:** /// /// 1. **Initialized**: `state: OnceCell` contains value, `initialization: None` /// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock>>)` pub struct Param { /// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter. pub id: ParamId, /// The OnceCell holding the initialized parameter value. /// Empty for uninitialized parameters, populated after first access or explicit initialization. pub(crate) state: OnceCell, /// The deferred initialization state for lazy parameters. /// /// **State Transitions:** /// - Initialized params: `None` /// - Uninitialized params: `Some(RwLock)>)` /// - After lazy init triggers: `Some(RwLock)` (inner Option is taken) pub(crate) initialization: Option>>>, pub(crate) param_mapper: ParamMapper, // For stateful `module.valid()` <> `module.train()` pub(crate) require_grad: bool, } #[derive(Clone)] /// Applies transformations when loading and saving parameters. /// /// # Mapper System /// /// `ParamMapper` allows applying transformations during serialization and deserialization: /// - `load: Option>` - transformation during deserialization (applied in `transform_for_load()`) /// - `save: Option>` - transformation during serialization (applied in `transform_for_save()`) /// /// These are commonly used for: /// - Quantization/dequantization /// - Precision conversion (e.g., FP32 ↔ FP16) /// - Custom parameter transformations pub struct ParamMapper { load: Option>, save: Option>, } impl core::fmt::Debug for ParamMapper { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!( "ParamMapper {{ load: {}, save: {} }}", self.load.is_some(), self.save.is_some() )) } } impl ParamMapper { /// Applies the transformation when loading the given parameter. pub fn on_load(&self, param: T) -> T { match &self.load { Some(mapper) => mapper(param), None => param, } } /// Applies the transformation when saving the given parameter. pub fn on_save(&self, param: T) -> T { match &self.save { Some(mapper) => mapper(param), None => param, } } } impl Default for ParamMapper { fn default() -> Self { Self { load: None, save: None, } } } impl core::fmt::Display for Param { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(format!("Param: {}", self.id).as_str()) } } impl core::fmt::Debug for Param { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(format!("Param: {} - {:?}", self.id, self.param_mapper).as_str()) } } /// Trait that defines what is necessary for a type to be a parameter. pub trait Parameter: Clone + core::fmt::Debug + Send { /// The device type to be used. type Device: Clone; /// Fetch the device. fn device(&self) -> Self::Device; /// Fetch the gradient requirement. fn is_require_grad(&self) -> bool; /// Set the gradient requirement. fn set_require_grad(self, require_grad: bool) -> Self; } /// The deferred initialization state for lazy parameters. #[allow(clippy::type_complexity)] pub(crate) struct Uninitialized { /// The initialization function. Called with `(device, is_require_grad) -> Parameter`. /// This function is consumed during initialization via `FnOnce`. init: Box P + Send>, /// The target device on which the parameter should be initialized. /// Used by `lazy_device()` to provide device information without triggering initialization. pub(crate) device: P::Device, /// The gradient requirement for the parameter. /// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization. pub(crate) is_require_grad: bool, /// The shape of the tensor parameter. /// Used by `lazy_shape()` to provide shape information without triggering initialization. pub(crate) shape: Shape, } impl Uninitialized

{ /// Consumes the uninitialized state and runs the initialization function. /// /// This is called by [Param::val] when accessing an uninitialized parameter for the first time. /// The function is given the stored device and gradient requirement, and returns the initialized parameter. fn initialize(self) -> P { let init = self.init; init(&self.device, self.is_require_grad) } } impl Param { /// Create a new parameter that is already initialized. pub fn initialized(id: ParamId, value: T) -> Self { let require_grad = value.is_require_grad(); Self { id, state: OnceCell::from(value), initialization: None, param_mapper: Default::default(), require_grad, } } /// Create a new parameter that is not already initialized. pub fn uninitialized( id: ParamId, init: F, device: T::Device, is_require_grad: bool, shape: Shape, ) -> Self where F: FnOnce(&T::Device, bool) -> T + Send + 'static, { Self { id, state: OnceCell::new(), initialization: Some(RwLock::new(Some(Uninitialized { init: Box::new(init), device, is_require_grad, shape, }))), param_mapper: Default::default(), require_grad: is_require_grad, } } /// Gets the parameter value, initializing it lazily if needed. /// /// For initialized parameters, this returns a clone of the cached value. /// For uninitialized parameters, this triggers initialization: pub fn val(&self) -> T { self.state .get_or_init(|| { let mut result = self .initialization .as_ref() .expect("Should have an initialization when no state provided.") .write() .unwrap(); let state = result.take().expect("Should exist when not initialized"); state.initialize() }) .clone() } /// Check if the parameter has been initialized. /// /// Returns `true` if the parameter's value has been computed and cached, /// `false` if it's still lazy and will be initialized on first access. pub fn is_initialized(&self) -> bool { self.state.get().is_some() } /// Gets the parameter's value while consuming the parameter. pub fn into_value(self) -> T { self.consume().1 } /// Gets the parameter id and value while consuming the parameter. pub fn consume(self) -> (ParamId, T, ParamMapper) { let tensor = self.val(); core::mem::drop(self.state); (self.id, tensor, self.param_mapper) } /// Execute the given function on the inner value. pub fn map T>(self, func: F) -> Self { let (id, tensor, param_mapper) = self.consume(); let tensor = func(tensor); let require_grad = tensor.is_require_grad(); Self { id, state: OnceCell::from(tensor), initialization: None, param_mapper, require_grad, } } /// Create an initialized parameter with the given id, value, and param mapper. /// /// This is a helper method for creating parameters while preserving the param mapper, /// typically used in ModuleMapper implementations. pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper) -> Self { let require_grad = value.is_require_grad(); Self { id, state: OnceCell::from(value), initialization: None, param_mapper, require_grad, } } /// Runs a transformation on the parameter when loading. pub fn load_mapper T + Send + Sync + 'static>(mut self, func: F) -> Self { self.param_mapper.load = Some(new_mapper(func)); self } /// Runs a transformation on the parameter when saving. pub fn save_mapper T + Send + Sync + 'static>(mut self, func: F) -> Self { self.param_mapper.save = Some(new_mapper(func)); self } /// Execute the given function on the inner value. pub fn init_mapper T + Send + 'static>(self, func: F) -> Self where T: 'static, { let initialization = match &self.initialization { Some(init) => init, None => return self.map(func), }; let mut init = initialization.write().unwrap(); match init.as_mut() { Some(value) => { #[allow(clippy::type_complexity)] let mut prev: Box T + Send> = Box::new(|_, _| panic!("Fake func to not have null ref.")); core::mem::swap(&mut prev, &mut value.init); value.init = Box::new(|a, b| { let tensor = prev(a, b); func(tensor) }); core::mem::drop(init); self } None => { core::mem::drop(init); self.map(func) } } } /// The device on which the parameter is or will be initialized, **without triggering initialization**. /// /// This is critical for the load optimization: when loading tensors into an uninitialized parameter, /// we need to know the target device to move the loaded tensor appropriately, but we don't want to /// trigger the initialization function (which would allocate an unnecessary tensor). /// /// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to /// preserve lazy initialization. pub fn lazy_device(&self) -> T::Device { let initialization = match &self.initialization { Some(init) => init, None => return self.device(), }; let init = initialization.read().unwrap(); match init.as_ref() { Some(value) => value.device.clone(), None => self.device(), } } /// The gradient requirement on which the parameter is or will be initialized, **without triggering initialization**. /// /// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization. /// When loading tensors into an uninitialized parameter, we need to apply the correct gradient /// setting to the loaded tensor without triggering the initialization function. /// /// # Notes /// /// This is a crate-private function, since users are not expected to use `is_require_grad` of an /// uninitialized module to then override its value. All low-level functions should be provided /// by `burn` and should handle those details. pub(crate) fn lazy_is_require_grad(&self) -> bool { let initialization = match &self.initialization { Some(init) => init, None => return self.is_require_grad(), }; let init = initialization.read().unwrap(); match init.as_ref() { Some(value) => value.is_require_grad, None => self.is_require_grad(), } } /// Override the gradient requirement for the current parameter. pub fn set_require_grad(self, require_grad: bool) -> Self { let initialization = match &self.initialization { Some(init) => init, None => return self.map(|tensor| tensor.set_require_grad(require_grad)), }; let mut init = initialization.write().unwrap(); let mut is_lazy = false; if let Some(value) = init.as_mut() { is_lazy = true; value.is_require_grad = require_grad; }; core::mem::drop(init); if is_lazy { return self; } self.map(|tensor| tensor.set_require_grad(require_grad)) } } impl Clone for Param { fn clone(&self) -> Self { let mut param = Param::initialized(self.id, self.val()); param.param_mapper = self.param_mapper.clone(); param } } impl Deref for Param { type Target = T; fn deref(&self) -> &Self::Target { self.state.get_or_init(|| { let mut result = self .initialization .as_ref() .expect("Should have an initialization when no state provided.") .write() .unwrap(); let state = result.take().expect("Should exist when not initialized"); state.initialize() }) } } ================================================ FILE: crates/burn-core/src/module/param/constant.rs ================================================ use alloc::{format, string::ToString}; use core::{fmt::Display, marker::PhantomData}; use crate as burn; use crate::{ module::{ AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, ModuleVisitor, }, record::{PrecisionSettings, Record}, }; use burn_tensor::{ BasicAutodiffOps, BasicOps, Tensor, backend::{AutodiffBackend, Backend}, ops::Device, }; #[deprecated( since = "0.21.0", note = "ConstantRecord is misleading as it doesn't persist data. Use EmptyRecord instead." )] /// A record representing the absence of persistent module state. pub type ConstantRecord = EmptyRecord; /// A record representing the absence of persistent module state. /// /// `EmptyRecord` is used for modules that do not store any data to be /// serialized or restored (e.g., modules marked with `#[module(skip)]` /// or modules without parameters). /// /// This record contains no fields and serializes to `None`. #[derive(Debug, Clone, Copy, new, Default, PartialEq, Eq)] pub struct EmptyRecord; impl serde::Serialize for EmptyRecord { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { // nothing to serialize S::serialize_none(serializer) } } impl<'de> serde::Deserialize<'de> for EmptyRecord { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { deserializer.deserialize_option(serde::de::IgnoredAny).ok(); Ok(EmptyRecord::new()) } } impl Record for EmptyRecord { type Item = EmptyRecord; fn into_item(self) -> Self::Item { self } fn from_item(item: Self::Item, _device: &B::Device) -> Self { item } } /// Constant macro. #[macro_export] macro_rules! empty { (module) => { type Record = burn::module::EmptyRecord; fn visit>(&self, _visitor: &mut V) { // Nothing to do } fn map>(self, _mapper: &mut M) -> Self { self } fn load_record(self, _record: Self::Record) -> Self { self } fn into_record(self) -> Self::Record { burn::module::EmptyRecord::new() } fn to_device(self, _: &B::Device) -> Self { self } fn fork(self, _: &B::Device) -> Self { self } fn collect_devices(&self, devices: burn::module::Devices) -> burn::module::Devices { devices } }; (ad_module, $type:ty) => { type InnerModule = $type; fn valid(&self) -> Self::InnerModule { self.clone() } fn from_inner(module: Self::InnerModule) -> Self { module } }; ($type:ty) => { impl burn::module::Module for $type { empty!(module); } impl burn::module::AutodiffModule for $type { empty!(ad_module, $type); } impl burn::module::ModuleDisplayDefault for $type { fn content(&self, content: burn::module::Content) -> Option { let string = format!("{}", self); content.add_formatted(&string).optional() } } impl burn::module::ModuleDisplay for $type {} }; } // TODO: breaking change for these constant types (currently empty record, non-persistent)? // General Types empty!(alloc::string::String); empty!(bool); // Float Types empty!(f64); empty!(f32); empty!(half::bf16); empty!(half::f16); // Unsigned Integer Types empty!(usize); empty!(u64); empty!(u32); empty!(u16); empty!(u8); // Signed Integer Types empty!(isize); empty!(i64); empty!(i32); empty!(i16); empty!(i8); impl burn::module::ModuleDisplay for str {} impl burn::module::ModuleDisplayDefault for str { fn content(&self, content: burn::module::Content) -> Option { content.add_formatted(&self).optional() } } // TODO: tensor record should persist impl> Module for Tensor { type Record = EmptyRecord; fn visit>(&self, _visitor: &mut V) {} fn map>(self, _mapper: &mut M) -> Self { self } fn into_record(self) -> Self::Record { EmptyRecord } fn load_record(self, _record: Self::Record) -> Self { self } fn to_device(self, device: &B::Device) -> Self { self.to_device(device) } fn fork(self, device: &B::Device) -> Self { self.to_device(device) } fn collect_devices(&self, mut devices: Devices) -> Devices { let device = self.device(); if !devices.contains(&device) { devices.push(device) } devices } } impl> ModuleDisplayDefault for Tensor { fn content(&self, content: Content) -> Option { let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().as_slice()); content.add_single(&string).optional() } } impl> ModuleDisplay for Tensor {} impl> AutodiffModule for Tensor { type InnerModule = Tensor; fn valid(&self) -> Self::InnerModule { self.clone().inner() } fn from_inner(tensor: Self::InnerModule) -> Self { Tensor::from_inner(tensor) } } impl Module for PhantomData { type Record = EmptyRecord; fn visit>(&self, _visitor: &mut V) { // Nothing to do } fn map>(self, _mapper: &mut M) -> Self { self } fn load_record(self, _record: Self::Record) -> Self { self } fn into_record(self) -> Self::Record { EmptyRecord::new() } fn to_device(self, _: &Device) -> Self { self } fn fork(self, _: &Device) -> Self { self } fn collect_devices(&self, devices: Devices) -> Devices { devices } } impl ModuleDisplayDefault for PhantomData { fn content(&self, content: Content) -> Option { content.add_single(&"PhantomData".to_string()).optional() } } impl ModuleDisplay for PhantomData {} impl AutodiffModule for PhantomData { type InnerModule = PhantomData; fn valid(&self) -> Self::InnerModule { PhantomData } fn from_inner(_module: Self::InnerModule) -> Self { PhantomData } } /// Container to satisfy the Module trait for types that are not modules. #[derive(Clone, Debug)] #[deprecated( since = "0.21.0", note = "Ignored is deprecated. Use #[module(skip)] for non-persistent fields (same behavior)." )] pub struct Ignored(pub T); #[allow(deprecated)] impl Module for Ignored where B: Backend, T: Sync + Send + core::fmt::Debug + Clone, { type Record = EmptyRecord; fn visit>(&self, _visitor: &mut V) { // Nothing to do } fn map>(self, _mapper: &mut M) -> Self { self } fn load_record(self, _record: Self::Record) -> Self { self } fn into_record(self) -> Self::Record { EmptyRecord::new() } fn to_device(self, _: &Device) -> Self { self } fn fork(self, _: &Device) -> Self { self } fn collect_devices(&self, devices: Devices) -> Devices { devices } } #[allow(deprecated)] impl ModuleDisplayDefault for Ignored where T: Sync + Send + core::fmt::Debug + Clone, { fn content(&self, content: Content) -> Option { // For now, just print the debug representation of the ignored value content.add_single(&format!("{:?}", self.0)).optional() } } #[allow(deprecated)] impl ModuleDisplay for Ignored where T: Sync + Send + core::fmt::Debug + Clone {} #[allow(deprecated)] impl Display for Ignored where T: Sync + Send + core::fmt::Debug + Clone, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{:?}", self.0) } } #[allow(deprecated)] impl AutodiffModule for Ignored where B: AutodiffBackend, T: Sync + Send + core::fmt::Debug + Clone, { type InnerModule = Ignored; fn valid(&self) -> Self::InnerModule { self.clone() } fn from_inner(module: Self::InnerModule) -> Self { module } } #[allow(deprecated)] // Implement deref for Ignored impl core::ops::Deref for Ignored { type Target = T; fn deref(&self) -> &Self::Target { &self.0 } } #[cfg(all(test, feature = "std"))] mod tests { use core::marker::PhantomData; use burn_tensor::backend::Backend; use burn_tensor::{Device, Tensor}; use crate::TestBackend; use crate::{ TestAutodiffBackend, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, }; use burn::module::Module; use crate as burn; #[test] fn tensor_load_record_setting() { let device: &Device = &Default::default(); let tensor = Tensor::::ones([3, 3], device); let byte_recorder = BinBytesRecorder::::default(); let bytes = Recorder::::record( &byte_recorder, tensor.clone().into_record(), (), ) .unwrap(); let no_grad_is_require_grad = tensor .clone() .no_grad() .load_record( Recorder::::load(&byte_recorder, bytes.clone(), device) .unwrap(), ) .is_require_grad(); let with_default_is_require_grad = tensor .load_record( Recorder::::load(&byte_recorder, bytes.clone(), device) .unwrap(), ) .is_require_grad(); assert!(!no_grad_is_require_grad); assert!(!with_default_is_require_grad); } #[test] fn empty_module_with_phantom() { #[derive(Module, Debug, new)] struct EmptyModule { _phantom: PhantomData, } let _module = EmptyModule::::new(); assert_eq!(core::mem::size_of::>(), 0); } } ================================================ FILE: crates/burn-core/src/module/param/id.rs ================================================ use core::hash::{BuildHasher, Hasher}; use alloc::string::String; use burn_std::id::IdGenerator; use data_encoding::BASE32_DNSSEC; // Hashbrown changed its default hasher in 0.15, but there are some issues // https://github.com/rust-lang/hashbrown/issues/577 // Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher. type DefaultHashBuilder = core::hash::BuildHasherDefault; /// Parameter ID. #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] pub struct ParamId { value: u64, } impl From for ParamId { fn from(value: u64) -> Self { Self { value } } } impl Default for ParamId { fn default() -> Self { Self::new() } } impl ParamId { /// Create a new parameter ID. pub fn new() -> Self { Self { value: IdGenerator::generate(), } } /// Gets the internal value of the id. pub fn val(&self) -> u64 { self.value } /// Convert the parameter ID into a string. pub fn serialize(self) -> String { BASE32_DNSSEC.encode(&self.value.to_le_bytes()) } /// Deserialize a param id. /// /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid). pub fn deserialize(encoded: &str) -> ParamId { let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) { Ok(bytes) => { let mut buffer = [0u8; 8]; buffer[..bytes.len()].copy_from_slice(&bytes); u64::from_le_bytes(buffer) } Err(err) => match uuid::Uuid::try_parse(encoded) { // Backward compatibility with uuid parameter identifiers Ok(id) => { // Hash the 128-bit uuid to 64-bit // Though not *theoretically* unique, the probability of a collision should be extremely low let mut hasher = DefaultHashBuilder::default().build_hasher(); // let mut hasher = DefaultHasher::new(); hasher.write(id.as_bytes()); hasher.finish() } Err(_) => panic!("Invalid id. {err}"), }, }; ParamId::from(u64_id) } } impl core::fmt::Display for ParamId { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(&self.serialize()) } } #[cfg(test)] mod tests { use super::*; #[test] fn param_serde_deserialize() { let val = ParamId::from(123456u64); let deserialized = ParamId::deserialize(&val.serialize()); assert_eq!(val, deserialized); } #[test] fn param_serde_deserialize_legacy() { let legacy_val = [45u8; 6]; let param_id = ParamId::deserialize(&BASE32_DNSSEC.encode(&legacy_val)); assert_eq!(param_id.val().to_le_bytes()[0..6], legacy_val); assert_eq!(param_id.val().to_le_bytes()[6..], [0, 0]); } #[test] fn param_serde_deserialize_legacy_uuid() { // Ensure support for legacy uuid deserialization and make sure it results in the same output let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c"; let param_id1 = ParamId::deserialize(legacy_id); let param_id2 = ParamId::deserialize(legacy_id); assert_eq!(param_id1, param_id2); } #[test] #[should_panic = "Invalid id."] fn param_serde_deserialize_invalid_id() { let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c"; let _ = ParamId::deserialize(invalid_uuid); } } ================================================ FILE: crates/burn-core/src/module/param/mod.rs ================================================ mod base; mod constant; mod id; mod primitive; mod running; mod tensor; mod visitor; pub use base::*; pub use constant::*; pub use id::*; pub use running::*; pub use visitor::*; ================================================ FILE: crates/burn-core/src/module/param/primitive.rs ================================================ use crate::module::{ AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, ModuleVisitor, }; use alloc::{format, string::ToString, vec::Vec}; use burn_tensor::{ backend::{AutodiffBackend, Backend}, ops::Device, }; use core::fmt::Debug; impl Module for Option where T: Module + Debug + Send + Clone, B: Backend, { type Record = Option; fn visit>(&self, visitor: &mut V) { if let Some(module) = self { module.visit(visitor) } } fn map>(self, mapper: &mut M) -> Self { self.map(|module| module.map(mapper)) } fn load_record(self, record: Self::Record) -> Self { let is_constant = self.num_params() == 0; if is_constant { return self; } self.zip(record) .map(|(module, record)| module.load_record(record)) } fn into_record(self) -> Self::Record { self.map(Module::into_record) } fn to_device(self, device: &Device) -> Self { self.map(|module| module.to_device(device)) } fn fork(self, device: &Device) -> Self { self.map(|module| module.fork(device)) } fn collect_devices(&self, mut devices: Vec) -> Vec { if let Some(module) = self.as_ref() { devices = module.collect_devices(devices); } devices } } impl ModuleDisplayDefault for Option { fn content(&self, content: Content) -> Option { match self { Some(module) => content.add_single(module).optional(), None => content.add_single("None").optional(), } } } impl ModuleDisplay for Option {} impl AutodiffModule for Option where T: AutodiffModule + Debug + Send + Clone, B: AutodiffBackend, { type InnerModule = Option; fn valid(&self) -> Self::InnerModule { self.as_ref().map(|module| module.valid()) } fn from_inner(module: Self::InnerModule) -> Self { module.map(|module| T::from_inner(module)) } } impl Module for Vec where T: Module + Debug + Send + Clone, B: Backend, { type Record = Vec; fn num_params(&self) -> usize { let mut num_params = 0; for module in self.iter() { num_params += module.num_params(); } num_params } fn visit>(&self, visitor: &mut V) { for (i, module) in self.iter().enumerate() { let index_str = alloc::format!("{}", i); visitor.enter_module(&index_str, "Vec"); module.visit(visitor); visitor.exit_module(&index_str, "Vec"); } } fn map>(self, mapper: &mut M) -> Self { self.into_iter() .enumerate() .map(|(i, module)| { let index_str = alloc::format!("{}", i); mapper.enter_module(&index_str, "Vec"); let mapped = module.map(mapper); mapper.exit_module(&index_str, "Vec"); mapped }) .collect() } fn into_record(self) -> Self::Record { self.into_iter().map(Module::into_record).collect() } fn load_record(self, record: Self::Record) -> Self { assert_eq!( self.len(), record.len(), r#"[Load Record Error] The vec record does not the same length as the module. Make sure you module initialization is compatible with the record being loaded. "#, ); self.into_iter() .zip(record) .map(|(module, record)| module.load_record(record)) .collect() } fn to_device(self, device: &Device) -> Self { self.into_iter() .map(|module| module.to_device(device)) .collect() } fn fork(self, device: &Device) -> Self { self.into_iter().map(|module| module.fork(device)).collect() } fn collect_devices(&self, mut devices: Vec) -> Vec { for module in self.iter() { devices = module.collect_devices(devices); } devices } } impl ModuleDisplayDefault for Vec { fn content(&self, content: Content) -> Option { self.iter() .enumerate() .fold(content, |acc, (i, module)| { let index = format!("{i}"); acc.add(&index, module) }) .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str()) .optional() } } impl ModuleDisplay for Vec {} impl AutodiffModule for Vec where T: AutodiffModule + Debug + Send + Clone, B: AutodiffBackend, { type InnerModule = Vec; fn valid(&self) -> Self::InnerModule { self.iter().map(|module| module.valid()).collect() } fn from_inner(module: Self::InnerModule) -> Self { module .into_iter() .map(|module| T::from_inner(module)) .collect() } } impl Module for [T; N] where T: Module + Debug + Send + Clone, B: Backend, { type Record = [T::Record; N]; fn collect_devices(&self, mut devices: Vec) -> Vec { for module in self.iter() { devices = module.collect_devices(devices); } devices } fn num_params(&self) -> usize { let mut num_params = 0; for module in self.iter() { num_params += module.num_params(); } num_params } fn visit>(&self, visitor: &mut V) { for (i, module) in self.iter().enumerate() { let index_str = alloc::format!("{}", i); visitor.enter_module(&index_str, "Array"); module.visit(visitor); visitor.exit_module(&index_str, "Array"); } } fn map>(self, mapper: &mut M) -> Self { let mut result = Vec::with_capacity(N); for (i, module) in IntoIterator::into_iter(self).enumerate() { let index_str = alloc::format!("{}", i); mapper.enter_module(&index_str, "Array"); let mapped = module.map(mapper); mapper.exit_module(&index_str, "Array"); result.push(mapped); } result .try_into() .unwrap_or_else(|v: Vec| panic!("Expected array of length {}, got {}", N, v.len())) } fn load_record(self, record: Self::Record) -> Self { self.into_iter() .zip(record) .map(|(module, record)| module.load_record(record)) .collect::>() .try_into() .unwrap() } fn into_record(self) -> Self::Record { self.map(Module::into_record) } fn to_device(self, device: &Device) -> Self { self.map(|module| module.to_device(device)) } fn fork(self, device: &Device) -> Self { self.map(|module| module.fork(device)) } } impl ModuleDisplayDefault for [T; N] { fn content(&self, content: Content) -> Option { self.iter() .enumerate() .fold(content, |acc, (i, module)| { let index = format!("{i}"); acc.add(&index, module) }) .set_top_level_type(format!("[0..{}]", self.len()).as_str()) .optional() } } impl ModuleDisplay for [T; N] {} impl AutodiffModule for [T; N] where T: AutodiffModule + Debug + Send + Clone, T::InnerModule: Debug, B: AutodiffBackend, { type InnerModule = [T::InnerModule; N]; fn valid(&self) -> Self::InnerModule { self.clone().map(|module| module.valid()) } fn from_inner(module: Self::InnerModule) -> Self { module.map(|module| T::from_inner(module)) } } /// A macro for generating implementations for tuple modules of different sizes. /// For example: `impl_module_tuple!([L0, L1][0, 1])`. /// Would generate an implementation for a tuple of size 2. /// For this macro to work properly, please adhere to the convention: /// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`. macro_rules! impl_module_tuple { // `$l` represents the generic modules. // `$i` represents the indices of the modules in the tuple. ([$($l:ident),*][$($i:tt),*]) => { impl Module for ($($l,)*) where B: Backend, $($l: Module + Debug + Send + Clone,)* { type Record = ($($l::Record),*); fn collect_devices(&self, mut devices: Vec) -> Vec { $(devices = self.$i.collect_devices(devices);)* devices } fn fork(self, device: &Device) -> Self { ($(self.$i.fork(device),)*) } fn to_device(self, device: &Device) -> Self { ($(self.$i.to_device(device),)*) } fn visit>(&self, visitor: &mut V) { $( let index_str = $i.to_string(); visitor.enter_module(&index_str, "Tuple"); self.$i.visit(visitor); visitor.exit_module(&index_str, "Tuple"); )* } fn map>(self, mapper: &mut M) -> Self { ($( { let index_str = $i.to_string(); mapper.enter_module(&index_str, "Tuple"); let mapped = self.$i.map(mapper); mapper.exit_module(&index_str, "Tuple"); mapped } ,)*) } fn load_record(self, record: Self::Record) -> Self { ($(self.$i.load_record(record.$i),)*) } fn into_record(self) -> Self::Record { ($(self.$i.into_record(),)*) } } impl AutodiffModule for ($($l,)*) where B: AutodiffBackend, $($l: AutodiffModule + Debug + Send + Clone,)* { type InnerModule = ($($l::InnerModule,)*); fn valid(&self) -> Self::InnerModule { ($(self.$i.valid(),)*) } fn from_inner(module: Self::InnerModule) -> Self { ($($l::from_inner(module.$i),)*) } } impl<$($l,)*> ModuleDisplayDefault for ($($l,)*) where $($l: ModuleDisplay,)* { fn content(&self, content: Content) -> Option { let content = content $(.add(&format!("{}", $i), &self.$i))* .set_top_level_type(format!("({})", stringify!($($l),*)).as_str()); content.optional() } } impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {} }; } impl_module_tuple!([L0, L1][0, 1]); impl_module_tuple!([L0, L1, L2][0, 1, 2]); impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]); impl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]); impl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]); impl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn dont_override_constant_module_when_loading_record() { let module = Some(42); let record = Module::::into_record(module); let loaded = Module::::load_record(module, record); assert_eq!(loaded, module); } #[test] fn dont_override_constant_module_when_loading_none_record() { let module = Some(42); let record = None; let loaded = Module::::load_record(module, record); assert_eq!(loaded, module); } } ================================================ FILE: crates/burn-core/src/module/param/running.rs ================================================ use super::ParamId; use crate::module::{ AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, ModuleVisitor, Param, }; use alloc::string::ToString; use alloc::vec::Vec; #[cfg(target_has_atomic = "ptr")] use alloc::sync::Arc; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; use burn_std::stub::Mutex; use burn_tensor::{ Tensor, backend::{AutodiffBackend, Backend}, ops::Device, }; #[cfg(feature = "std")] mod threading { pub(super) use std::collections::HashMap; pub(super) use std::thread::ThreadId; #[inline(always)] pub(super) fn get_thread_current_id() -> ThreadId { std::thread::current().id() } } #[cfg(not(feature = "std"))] mod threading { pub(super) use burn_std::stub::ThreadId; pub(super) use hashbrown::HashMap; #[inline(always)] pub(super) fn get_thread_current_id() -> ThreadId { panic!("Current thread id is not available") } } // Re-export items from the disabled/enabled blocks use threading::*; /// A state that can be updated during the forward pass while being thread safe. /// /// # Note /// /// The state value is the average of all updates on all threads. #[derive(Clone, Debug)] pub struct RunningState { id: ParamId, values: Arc>>, value: Arc>, } // Implement display for the module impl core::fmt::Display for RunningState { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { write!(f, "RunningState(id={})", self.id) } } impl ModuleDisplayDefault for RunningState { fn content(&self, content: Content) -> Option { content .add_formatted(&"RunningState".to_string()) .optional() } } impl ModuleDisplay for RunningState {} impl Module for RunningState> { type Record = Param>; fn visit>(&self, visitor: &mut V) { let tensor = self.value.lock().unwrap(); let param = Param::initialized(self.id, tensor.clone()); visitor.visit_float(¶m) } fn map>(self, mapper: &mut M) -> Self { let mut tensor = self.value.lock().unwrap(); let param = Param::initialized(self.id, tensor.clone()); let param_out = mapper.map_float(param); let (_, tensor_out, _) = param_out.consume(); *tensor = tensor_out; core::mem::drop(tensor); self } fn into_record(self) -> Self::Record { self.sync(); let tensor = self.value.lock().unwrap(); Param::initialized(self.id, tensor.clone()) } fn load_record(mut self, record: Self::Record) -> Self { let mut tensor = self.value.lock().unwrap(); *tensor = record.val().to_device(&tensor.device()); self.id = record.id; core::mem::drop(tensor); self } fn to_device(self, device: &Device) -> Self { let mut tensor = self.value.lock().unwrap(); let tensor_out = tensor.clone().to_device(device); *tensor = tensor_out; core::mem::drop(tensor); self } fn fork(self, device: &Device) -> Self { self.to_device(device) // Same thing here since no grad. } fn collect_devices(&self, mut devices: Vec>) -> Vec> { let device = self.value.lock().unwrap().device(); if !devices.contains(&device) { devices.push(device) } devices } } impl RunningState> { /// Create a new running state. pub fn new(value: Tensor) -> Self { Self { id: ParamId::new(), values: Arc::new(Mutex::new(HashMap::new())), value: Arc::new(Mutex::new(value)), } } /// Create a new running state. pub fn with_id(id: ParamId, value: Tensor) -> Self { Self { id, values: Arc::new(Mutex::new(HashMap::new())), value: Arc::new(Mutex::new(value)), } } /// Create a new running state from a record. pub fn from_record(record: Param>) -> Self { let tensor = record.val(); Self { id: record.id, values: Arc::new(Mutex::new(HashMap::new())), value: Arc::new(Mutex::new(tensor)), } } /// Update the value on the current thread. pub fn update(&self, value: Tensor) { let thread_id = get_thread_current_id(); let mut map = self.values.lock().unwrap(); if map.contains_key(&thread_id) { self.update_value(&mut map); } map.insert(thread_id, value); } /// Get the current value, /// /// # Note /// /// The current value might be outdated by one update. pub fn value(&self) -> Tensor { let value = self.value.lock().unwrap(); value.clone() } /// Get the current value and make sure it is sync. /// /// # Note /// /// Don't use this function after an update on the same thread where other threads might have to /// register their update before the actual synchronization needs to happen. pub fn value_sync(&self) -> Tensor { let thread_id = get_thread_current_id(); let mut map = self.values.lock().unwrap(); if map.contains_key(&thread_id) { self.update_value(&mut map); } let value = self.value.lock().unwrap(); value.clone() } fn sync(&self) { let mut map = self.values.lock().unwrap(); if !map.is_empty() { self.update_value(&mut map); } } fn update_value(&self, map: &mut HashMap>) { let mut value_updated: Option> = None; let mut counter = 0; for (_key, tensor) in map.drain() { counter += 1; value_updated = match value_updated { Some(current) => { let device = current.device(); Some(tensor.to_device(&device).add(current)) } None => Some(tensor), }; } if let Some(value) = value_updated { let value = value.div_scalar(counter); let mut value_old = self.value.lock().unwrap(); *value_old = value; } } } impl AutodiffModule for RunningState> { type InnerModule = RunningState>; fn valid(&self) -> Self::InnerModule { self.sync(); let value = self.value(); RunningState::with_id(self.id, value.inner()) } fn from_inner(module: Self::InnerModule) -> Self { module.sync(); let value = module.value(); RunningState::with_id(module.id, Tensor::from_inner(value)) } } ================================================ FILE: crates/burn-core/src/module/param/tensor.rs ================================================ use super::{Param, ParamId, Parameter}; use crate::module::{ AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, ModuleVisitor, }; use crate::tensor::{ Tensor, backend::{AutodiffBackend, Backend}, }; use alloc::{format, string::ToString, vec::Vec}; use burn_tensor::{Bool, Float, Int, TensorData, ops::Device}; impl Parameter for Tensor { type Device = B::Device; fn device(&self) -> Self::Device { Tensor::device(self) } fn is_require_grad(&self) -> bool { Tensor::is_require_grad(self) } fn set_require_grad(self, require_grad: bool) -> Self { Tensor::set_require_grad(self, require_grad) } } impl Parameter for Tensor { type Device = B::Device; fn device(&self) -> Self::Device { Tensor::device(self) } fn is_require_grad(&self) -> bool { false } fn set_require_grad(self, _require_grad: bool) -> Self { self } } impl Parameter for Tensor { type Device = B::Device; fn device(&self) -> Self::Device { Tensor::device(self) } fn is_require_grad(&self) -> bool { false } fn set_require_grad(self, _require_grad: bool) -> Self { self } } impl Param> { /// Create a new parameter from a float tensor. /// /// # Warnings /// /// We strongly recommend using [Param::uninitialized] if you are using this method to /// initialize parameters inside a module, since the tensor initialization will be lazy, /// making the loading of weights more performant. pub fn from_tensor(value: Tensor) -> Self { // When creating a parameter from a float tensor, we automatically mark it as requiring // gradients, so that it can be updated by an optimizer. Param::initialized(ParamId::new(), value.require_grad()) } /// The shape of the parameter, **without triggering initialization**. /// /// This is critical for shape validation during loading: when applying tensors to an /// uninitialized parameter, we need to validate the shape without triggering the /// initialization function (which would allocate an unnecessary tensor). /// /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to /// preserve lazy initialization. pub fn lazy_shape(&self) -> burn_tensor::Shape { let initialization = match &self.initialization { Some(init) => init, None => return self.shape(), }; let init = initialization.read().unwrap(); match init.as_ref() { Some(value) => value.shape.clone(), None => self.shape(), } } /// Create a new parameter from data. pub fn from_data(data: T, device: &B::Device) -> Self where T: Into, { let data: TensorData = data.into(); // When creating a parameter from a float tensor, we automatically mark it as requiring // gradients, so that it can be updated by an optimizer. B::memory_persistent_allocations(device, data, |data| { let value = Tensor::from_data(data, device); Param::initialized(ParamId::new(), value.require_grad()) }) } /// Transform a parameter for loading by applying load transformations. /// /// This method is used to restore a parameter from a tensor (typically during deserialization). /// It ensures the tensor is moved to the expected device, applies the param mapper's /// `on_load` transformation, and preserves the autodiff settings (require_grad). pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { let mut new_tensor = tensor; let mapper = self.param_mapper.clone(); let expected_device = self.lazy_device(); let expected_require_grad = self.lazy_is_require_grad(); // Make sure we load the tensor into the same module device. if new_tensor.device() != expected_device { new_tensor = new_tensor.to_device(&expected_device).detach(); } new_tensor = mapper.on_load(new_tensor); // Make sure we load the tensor with the same autodiff setting. new_tensor = new_tensor.set_require_grad(expected_require_grad); let mut loaded = Self::initialized(param_id, new_tensor); loaded.param_mapper = mapper; loaded } /// Transform a parameter for saving by applying save transformations. /// /// This method is used to prepare a parameter for saving (typically during serialization). /// It applies the param mapper's `on_save` transformation, which can be used /// to modify the tensor before serialization (e.g., quantization, precision conversion). pub fn transform_for_save(&self) -> Self { let mut tensor = self.val(); let mapper = self.param_mapper.clone(); tensor = mapper.on_save(tensor); Self::initialized(self.id, tensor) } } impl Param> { /// The shape of the parameter, **without triggering initialization**. /// /// This is critical for shape validation during loading: when applying tensors to an /// uninitialized parameter, we need to validate the shape without triggering the /// initialization function (which would allocate an unnecessary tensor). /// /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to /// preserve lazy initialization. pub fn lazy_shape(&self) -> burn_tensor::Shape { let initialization = match &self.initialization { Some(init) => init, None => return self.shape(), }; let init = initialization.read().unwrap(); match init.as_ref() { Some(value) => value.shape.clone(), None => self.shape(), } } /// Transform a parameter for loading by applying load transformations. /// /// This method is used to restore a parameter from a tensor (typically during deserialization). /// It ensures the tensor is moved to the expected device and applies the param mapper's /// `on_load` transformation. pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { let mut new_tensor = tensor; let mapper = self.param_mapper.clone(); let expected_device = self.lazy_device(); // Make sure we load the tensor into the same module device. if new_tensor.device() != expected_device { new_tensor = new_tensor.to_device(&expected_device); } new_tensor = mapper.on_load(new_tensor); let mut loaded = Self::initialized(param_id, new_tensor); loaded.param_mapper = mapper; loaded } /// Transform a parameter for saving by applying save transformations. /// /// This method is used to prepare a parameter for saving (typically during serialization). /// It applies the param mapper's `on_save` transformation, which can be used /// to modify the tensor before serialization (e.g., quantization, precision conversion). pub fn transform_for_save(&self) -> Self { let mut tensor = self.val(); let mapper = self.param_mapper.clone(); tensor = mapper.on_save(tensor); Self::initialized(self.id, tensor) } } impl Param> { /// The shape of the parameter, **without triggering initialization**. /// /// This is critical for shape validation during loading: when applying tensors to an /// uninitialized parameter, we need to validate the shape without triggering the /// initialization function (which would allocate an unnecessary tensor). /// /// **Returns:** /// - For uninitialized params: the shape from the `Uninitialized` struct /// - For initialized params: the actual shape from the tensor /// /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to /// preserve lazy initialization. pub fn lazy_shape(&self) -> burn_tensor::Shape { let initialization = match &self.initialization { Some(init) => init, None => return self.shape(), }; let init = initialization.read().unwrap(); match init.as_ref() { Some(value) => value.shape.clone(), None => self.shape(), } } /// Transform a parameter for loading by applying load transformations. /// /// This method is used to restore a parameter from a tensor (typically during deserialization). /// It ensures the tensor is moved to the expected device and applies the param mapper's /// `on_load` transformation. pub fn transform_for_load(self, tensor: Tensor, param_id: ParamId) -> Self { let mut new_tensor = tensor; let mapper = self.param_mapper.clone(); let expected_device = self.lazy_device(); // Make sure we load the tensor into the same module device. if new_tensor.device() != expected_device { new_tensor = new_tensor.to_device(&expected_device); } new_tensor = mapper.on_load(new_tensor); let mut loaded = Self::initialized(param_id, new_tensor); loaded.param_mapper = mapper; loaded } /// Transform a parameter for saving by applying save transformations. /// /// This method is used to prepare a parameter for saving (typically during serialization). /// It applies the param mapper's `on_save` transformation, which can be used /// to modify the tensor before serialization (e.g., quantization, precision conversion). pub fn transform_for_save(&self) -> Self { let mut tensor = self.val(); let mapper = self.param_mapper.clone(); tensor = mapper.on_save(tensor); Self::initialized(self.id, tensor) } } impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { visitor.visit_float(self) } fn map>(self, mapper: &mut M) -> Self { mapper.map_float(self) } fn into_record(self) -> Self::Record { self.transform_for_save() } fn load_record(self, record: Self::Record) -> Self { let (record_param_id, record_tensor, _) = record.consume(); self.transform_for_load(record_tensor, record_param_id) } fn to_device(self, device: &Device) -> Self { self.map(|tensor| tensor.to_device(device)) } fn fork(self, device: &Device) -> Self { self.map(|tensor| { let is_require_grad = tensor.is_require_grad(); let mut tensor = tensor.to_device(device).detach(); if is_require_grad { tensor = tensor.require_grad(); } tensor }) } fn collect_devices(&self, mut devices: Vec>) -> Vec> { let device = self.val().device(); if !devices.contains(&device) { devices.push(device) } devices } } impl ModuleDisplayDefault for Param> { fn content(&self, content: Content) -> Option { let id = if content.display_settings.show_param_id() { format!(", id: {}", self.id) } else { "".to_string() }; let string = format!( "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}", self.shape().as_slice() ); content.add_formatted(&string).optional() } } impl ModuleDisplay for Param> {} impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { visitor.visit_int(self) } fn map>(self, mapper: &mut M) -> Self { mapper.map_int(self) } fn into_record(self) -> Self::Record { self.transform_for_save() } fn load_record(self, record: Self::Record) -> Self { let (record_param_id, record_tensor, _) = record.consume(); self.transform_for_load(record_tensor, record_param_id) } fn to_device(self, device: &Device) -> Self { self.map(|tensor| tensor.to_device(device)) } fn fork(self, device: &Device) -> Self { self.to_device(device) // Don't support autodiff. } fn collect_devices(&self, mut devices: Vec>) -> Vec> { let device = self.val().device(); if !devices.contains(&device) { devices.push(device) } devices } } impl ModuleDisplayDefault for Param> { fn content(&self, content: Content) -> Option { let id = if content.display_settings.show_param_id() { format!(", id: {}", self.id) } else { "".to_string() }; let string = format!( "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}", self.shape().as_slice() ); content.add_formatted(&string).optional() } } impl ModuleDisplay for Param> {} impl Module for Param> { type Record = Param>; fn visit>(&self, visitor: &mut V) { visitor.visit_bool(self) } fn map>(self, mapper: &mut M) -> Self { mapper.map_bool(self) } fn into_record(self) -> Self::Record { self.transform_for_save() } fn load_record(self, record: Self::Record) -> Self { let (record_param_id, record_tensor, _) = record.consume(); self.transform_for_load(record_tensor, record_param_id) } fn to_device(self, device: &Device) -> Self { self.map(|tensor| tensor.to_device(device)) } fn fork(self, device: &Device) -> Self { self.to_device(device) // Don't support autodiff. } fn collect_devices(&self, mut devices: Vec>) -> Vec> { let device = self.val().device(); if !devices.contains(&device) { devices.push(device) } devices } } impl ModuleDisplayDefault for Param> { fn content(&self, content: Content) -> Option { let id = if content.display_settings.show_param_id() { format!(", id: {}", self.id) } else { "".to_string() }; let string = format!( "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}", self.shape().as_slice() ); content.add_formatted(&string).optional() } } impl ModuleDisplay for Param> {} impl AutodiffModule for Param> { type InnerModule = Param>; fn valid(&self) -> Self::InnerModule { // Preserve initialized param `require_grad` state, but reset the inner value's let require_grad = self.require_grad; let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false)); param.require_grad = require_grad; param } fn from_inner(module: Self::InnerModule) -> Self { // Reinstate the param's `require_grad` state let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad); Param::initialized(module.id, tensor) } } impl HasAutodiffModule for Param> { type TrainModule = Param>; } impl AutodiffModule for Param> { type InnerModule = Param>; fn valid(&self) -> Self::InnerModule { Param::initialized(self.id, self.val().inner()) } fn from_inner(module: Self::InnerModule) -> Self { Param::initialized(module.id, Tensor::from_inner(module.val())) } } impl AutodiffModule for Param> { type InnerModule = Param>; fn valid(&self) -> Self::InnerModule { Param::initialized(self.id, self.val().inner()) } fn from_inner(module: Self::InnerModule) -> Self { Param::initialized(module.id, Tensor::from_inner(module.val())) } } #[cfg(all(test, feature = "std"))] mod tests { use super::*; use crate::{ TestAutodiffBackend, module::Module, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, }; #[test] fn test_load_record_setting() { let device = Default::default(); let tensor = Tensor::::ones([3, 3], &device).require_grad(); let byte_recorder = BinBytesRecorder::::default(); let bytes = byte_recorder .record( Param::initialized(ParamId::new(), tensor.clone()).into_record(), (), ) .unwrap(); let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone()) .no_grad() .load_record(byte_recorder.load(bytes.clone(), &device).unwrap()) .is_require_grad(); let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor) .load_record(byte_recorder.load(bytes, &device).unwrap()) .is_require_grad(); assert!(!no_grad_is_require_grad); assert!(with_default_is_require_grad); } #[test] fn test_param_require_grad_stateful() { let device = Default::default(); let tensor = Tensor::::ones([3, 3], &device).require_grad(); let param = Param::initialized(ParamId::new(), tensor); assert!(param.is_require_grad()); assert!(param.require_grad); let param = param.valid(); assert!(!param.is_require_grad()); assert!(param.require_grad); // stateful // Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying: // let param: Param> = param.train(); let param = param.train::(); assert!(param.is_require_grad()); assert!(param.require_grad); // stateful let param = param.no_grad(); assert!(!param.is_require_grad()); assert!(!param.require_grad); // stateful let param = param.valid(); assert!(!param.is_require_grad()); // always assert!(!param.require_grad); // stateful let param = param.train::(); assert!(!param.is_require_grad()); assert!(!param.require_grad); // stateful } } ================================================ FILE: crates/burn-core/src/module/param/visitor.rs ================================================ use super::{Param, ParamId}; use crate::module::{Module, ModuleVisitor}; use alloc::vec::Vec; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; use core::marker::PhantomData; struct ParamIdCollector<'a, M> { param_ids: &'a mut Vec, phantom: PhantomData, } impl ModuleVisitor for ParamIdCollector<'_, M> where B: Backend, M: Module, { fn visit_float(&mut self, param: &Param>) { self.param_ids.push(param.id); } fn visit_int(&mut self, param: &Param>) { self.param_ids.push(param.id); } fn visit_bool(&mut self, param: &Param>) { self.param_ids.push(param.id); } } /// List all the parameter ids in a module. pub fn list_param_ids, B: Backend>(module: &M) -> Vec { let mut params_ids = Vec::new(); let mut visitor = ParamIdCollector { param_ids: &mut params_ids, phantom: PhantomData::, }; module.visit(&mut visitor); params_ids } ================================================ FILE: crates/burn-core/src/module/quantize.rs ================================================ use burn_tensor::{ Tensor, backend::Backend, quantization::{Calibration, QuantScheme, compute_q_params, compute_range}, }; use crate::module::{ModuleMapper, Param}; /// Describes how to quantize a module. pub struct Quantizer { /// The calibration method used in quantization. pub calibration: Calibration, /// The quantization scheme. pub scheme: QuantScheme, } impl ModuleMapper for Quantizer { fn map_float(&mut self, param: Param>) -> Param> { let (id, tensor, mapper) = param.consume(); let range = compute_range(&self.scheme, &tensor, &self.calibration); let qparams = compute_q_params(&self.scheme, range); let tensor = tensor.quantize(&self.scheme, qparams); Param::from_mapped_value(id, tensor, mapper) } } #[cfg(all(test, not(feature = "test-tch")))] mod tests { use crate::test_utils::SimpleLinear; use crate::{ TestBackend, module::{Module, Quantizer}, }; use burn_tensor::{ Device, Tolerance, ops::QuantizedTensor, quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue}, }; type B = TestBackend; #[test] fn should_quantize_module() { let device: Device = Default::default(); let module = SimpleLinear::::new(32, 32, &device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::Tensor) .with_param(QuantParam::F32); let result = module.weight.val(); let calibration = Calibration::MinMax; let mut quantizer = Quantizer { calibration, scheme, }; let q_module = module.quantize_weights(&mut quantizer); let q_result = q_module.weight.val().dequantize(); result .into_data() .assert_approx_eq::(&q_result.into_data(), Tolerance::permissive()); } } ================================================ FILE: crates/burn-core/src/module/reinit.rs ================================================ use super::{Module, ModuleMapper}; use burn_tensor::{ Element, ElementConversion, Tensor, TensorData, backend::Backend, ops::{FloatElem, IntElem}, }; use rand::{RngExt, SeedableRng}; #[derive(Debug)] /// Overrides float and int tensors of [burn modules](super::Module). /// /// This is useful for testing. pub struct Reinitializer { float: ReinitStrategy>, int: ReinitStrategy>, } #[derive(Debug)] #[allow(missing_docs)] enum ReinitStrategy { Range { min: E, max: E }, Constant { value: E }, Random { seed: u64, min: E, max: E }, } impl Default for Reinitializer { fn default() -> Self { Self::new() } } impl Reinitializer { /// Create a new [reinitializer](Reinitializer). pub fn new() -> Self { Self { float: ReinitStrategy::Constant { value: 0.elem::>(), }, int: ReinitStrategy::Constant { value: 0.elem::>(), }, } } /// Apply the reinitialization to the given [module](Module). pub fn apply>(mut self, module: M) -> M { module.map(&mut self) } /// Set the reinitialization strategy to constant for all tensors. pub fn constant(self, constant: f64) -> Self { self.constant_float(constant).constant_int(constant as i64) } /// Set the reinitialization strategy to constant for float tensors. pub fn constant_float(mut self, constant: f64) -> Self { self.float = ReinitStrategy::Constant { value: constant.elem(), }; self } /// Set the reinitialization strategy to constant for int tensors. pub fn constant_int(mut self, constant: i64) -> Self { self.int = ReinitStrategy::Constant { value: constant.elem(), }; self } /// Set the reinitialization strategy to random for all tensors. pub fn random(self, seed: u64, min: f64, max: f64) -> Self { self.random_float(seed, min, max) .random_int(seed, min as i64, max as i64) } /// Set the reinitialization strategy to random for float tensors. pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self { self.float = ReinitStrategy::Random { seed, min: min.elem(), max: max.elem(), }; self } /// Set the reinitialization strategy to random for int tensors. pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self { self.int = ReinitStrategy::Random { seed, min: min.elem(), max: max.elem(), }; self } /// Set the reinitialization strategy to range for all tensors. pub fn range(self, min: f64, max: f64) -> Self { self.range_float(min, max).range_int(min as i64, max as i64) } /// Set the reinitialization strategy to range for float tensors. pub fn range_float(mut self, min: f64, max: f64) -> Self { self.float = ReinitStrategy::Range { min: min.elem(), max: max.elem(), }; self } /// Set the reinitialization strategy to range for int tensors. pub fn range_int(mut self, min: i64, max: i64) -> Self { self.int = ReinitStrategy::Range { min: min.elem(), max: max.elem(), }; self } } impl ModuleMapper for Reinitializer { fn map_float( &mut self, param: super::Param>, ) -> super::Param> { let (id, tensor, mapper) = param.consume(); let device = tensor.device(); let shape = tensor.shape(); let num_elements = shape.num_elements(); let tensor = match &self.float { ReinitStrategy::Range { min, max } => { let tensor = Tensor::arange(0..num_elements as i64, &device) .reshape(shape) .float(); let (factor, bias) = resolve::>(*min, *max, num_elements); tensor * factor + bias } ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device), ReinitStrategy::Random { seed, min, max } => { let data = TensorData::new( random_vector::>(*seed, min.elem(), max.elem(), num_elements), shape, ); Tensor::from_data(data, &device) } }; super::Param::from_mapped_value(id, tensor, mapper) } fn map_int( &mut self, param: super::Param>, ) -> super::Param> { let (id, tensor, mapper) = param.consume(); let device = tensor.device(); let shape = tensor.shape(); let num_elements = shape.num_elements(); let tensor = match &self.int { ReinitStrategy::Range { min, max } => { let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape); let (factor, bias) = resolve::>(*min, *max, num_elements); tensor * factor + bias } ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device), ReinitStrategy::Random { seed, min, max } => { let data = TensorData::new( random_vector::>(*seed, min.elem(), max.elem(), num_elements), shape, ); Tensor::from_data(data, &device) } }; super::Param::from_mapped_value(id, tensor, mapper) } fn map_bool( &mut self, param: super::Param>, ) -> super::Param> { let (id, tensor, mapper) = param.consume(); super::Param::from_mapped_value(id, tensor, mapper) } } fn resolve(min: E, max: E, num_elements: usize) -> (E, E) { let range = max.elem::() - min.elem::(); let factor = range / num_elements as f64; let bias = min.elem::(); (factor.elem(), bias.elem()) } fn random_vector(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec { let mut rng = rand::rngs::StdRng::seed_from_u64(seed); let dist = rand::distr::Uniform::new(min, max).unwrap(); (0..num_elements) .map(|_| rng.sample(dist)) .map(|e| e.elem::()) .collect() } ================================================ FILE: crates/burn-core/src/record/base.rs ================================================ pub use burn_derive::Record; use burn_tensor::backend::Backend; use super::PrecisionSettings; use serde::{Serialize, de::DeserializeOwned}; /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings). pub trait Record: Send { /// Type of the item that can be serialized and deserialized. type Item: Serialize + DeserializeOwned + Clone; /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). fn into_item(self) -> Self::Item; /// Convert the given item into a record. fn from_item(item: Self::Item, device: &B::Device) -> Self; } ================================================ FILE: crates/burn-core/src/record/file.rs ================================================ use super::{PrecisionSettings, Recorder, RecorderError, bin_config}; use burn_tensor::backend::Backend; use core::marker::PhantomData; use flate2::{Compression, read::GzDecoder, write::GzEncoder}; use serde::{Serialize, de::DeserializeOwned}; use std::io::{BufReader, BufWriter}; use std::{fs::File, path::PathBuf}; /// Recorder trait specialized to save and load data to and from files. pub trait FileRecorder: Recorder { /// File extension of the format used by the recorder. fn file_extension() -> &'static str; } /// Default [file recorder](FileRecorder). pub type DefaultFileRecorder = NamedMpkFileRecorder; /// File recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinFileRecorder { _settings: PhantomData, } /// File recorder using the [bincode format](bincode) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct BinGzFileRecorder { _settings: PhantomData, } /// File recorder using the [json format](serde_json) compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct JsonGzFileRecorder { _settings: PhantomData, } /// File recorder using [pretty json format](serde_json) for easy readability. #[derive(new, Debug, Default, Clone)] pub struct PrettyJsonFileRecorder { _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format compressed with gzip. #[derive(new, Debug, Default, Clone)] pub struct NamedMpkGzFileRecorder { _settings: PhantomData, } /// File recorder using the [named msgpack](rmp_serde) format. #[derive(new, Debug, Default, Clone)] pub struct NamedMpkFileRecorder { _settings: PhantomData, } impl FileRecorder for BinGzFileRecorder { fn file_extension() -> &'static str { "bin.gz" } } impl FileRecorder for BinFileRecorder { fn file_extension() -> &'static str { "bin" } } impl FileRecorder for JsonGzFileRecorder { fn file_extension() -> &'static str { "json.gz" } } impl FileRecorder for PrettyJsonFileRecorder { fn file_extension() -> &'static str { "json" } } impl FileRecorder for NamedMpkGzFileRecorder { fn file_extension() -> &'static str { "mpk.gz" } } impl FileRecorder for NamedMpkFileRecorder { fn file_extension() -> &'static str { "mpk" } } macro_rules! str2reader { ( $file:expr ) => {{ $file.set_extension(>::file_extension()); let path = $file.as_path(); File::open(path) .map_err(|err| match err.kind() { std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), _ => RecorderError::Unknown(err.to_string()), }) .map(|file| BufReader::new(file)) }}; } macro_rules! str2writer { ( $file:expr ) => {{ $file.set_extension(>::file_extension()); let path = $file.as_path(); log::debug!("Writing to file: {:?}", path); // Add parent directories if they don't exist if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).ok(); } if path.exists() { log::warn!("File exists, replacing"); std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?; } File::create(path) .map_err(|err| match err.kind() { std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()), _ => RecorderError::Unknown(err.to_string()), }) .map(|file| BufWriter::new(file)) }}; } impl Recorder for BinGzFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let config = bin_config(); let writer = str2writer!(file)?; let mut writer = GzEncoder::new(writer, Compression::default()); bincode::serde::encode_into_std_write(&item, &mut writer, config) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let reader = str2reader!(file)?; let mut reader = GzDecoder::new(reader); let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } impl Recorder for BinFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let config = bin_config(); let mut writer = str2writer!(file)?; bincode::serde::encode_into_std_write(&item, &mut writer, config) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let mut reader = str2reader!(file)?; let state = bincode::serde::decode_from_std_read(&mut reader, bin_config()) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } impl Recorder for JsonGzFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let writer = str2writer!(file)?; let writer = GzEncoder::new(writer, Compression::default()); serde_json::to_writer(writer, &item) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let reader = str2reader!(file)?; let reader = GzDecoder::new(reader); let state = serde_json::from_reader(reader) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } impl Recorder for PrettyJsonFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let writer = str2writer!(file)?; serde_json::to_writer_pretty(writer, &item) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let reader = str2reader!(file)?; let state = serde_json::from_reader(reader) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } impl Recorder for NamedMpkGzFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let writer = str2writer!(file)?; let mut writer = GzEncoder::new(writer, Compression::default()); rmp_serde::encode::write_named(&mut writer, &item) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let reader = str2reader!(file)?; let reader = GzDecoder::new(reader); let state = rmp_serde::decode::from_read(reader) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } impl Recorder for NamedMpkFileRecorder { type Settings = S; type RecordArgs = PathBuf; type RecordOutput = (); type LoadArgs = PathBuf; fn save_item( &self, item: I, mut file: Self::RecordArgs, ) -> Result<(), RecorderError> { let mut writer = str2writer!(file)?; rmp_serde::encode::write_named(&mut writer, &item) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(()) } fn load_item( &self, file: &mut Self::LoadArgs, ) -> Result { let reader = str2reader!(file)?; let state = rmp_serde::decode::from_read(reader) .map_err(|err| RecorderError::Unknown(err.to_string()))?; Ok(state) } } #[allow(deprecated)] #[cfg(test)] mod tests { use super::*; use crate as burn; use crate::config::Config; use crate::module::Ignored; use crate::test_utils::SimpleLinear; use crate::{ TestBackend, module::Module, record::{BinBytesRecorder, FullPrecisionSettings}, }; use burn_tensor::Tensor; use burn_tensor::backend::Backend; #[inline(always)] fn file_path(file: &str) -> PathBuf { std::env::temp_dir().as_path().join(file) } #[test] fn test_can_save_and_load_jsongz_format() { test_can_save_and_load(JsonGzFileRecorder::::default()) } #[test] fn test_can_save_and_load_bin_format() { test_can_save_and_load(BinFileRecorder::::default()) } #[test] fn test_can_save_and_load_bingz_format() { test_can_save_and_load(BinGzFileRecorder::::default()) } #[test] fn test_can_save_and_load_pretty_json_format() { test_can_save_and_load(PrettyJsonFileRecorder::::default()) } #[test] fn test_can_save_and_load_mpkgz_format() { test_can_save_and_load(NamedMpkGzFileRecorder::::default()) } #[test] fn test_can_save_and_load_mpk_format() { test_can_save_and_load(NamedMpkFileRecorder::::default()) } fn test_can_save_and_load(recorder: Recorder) where Recorder: FileRecorder, { let filename = "burn_test_file_recorder"; let device = Default::default(); let mut model_before = create_model(&device); // NOTE: Non-module fields currently act like `#[module(skip)]`, meaning their state // is not persistent. These fields hold `EmptyRecord`s. // So `model_bytes_after == model_bytes_before` because the changes do not persist in the record. model_before.tensor = Tensor::full([4], 2., &device); model_before.arr = [3, 3]; model_before.int = 1; model_before.ignore = Ignored(PaddingConfig2d::Valid); recorder .record(model_before.clone().into_record(), file_path(filename)) .unwrap(); let model_after = create_model(&device).load_record(recorder.load(file_path(filename), &device).unwrap()); // State is not persisted for empty record fields assert_eq!(model_after.arr, [2, 2]); assert_eq!(model_after.int, 0); assert_eq!(model_after.ignore.0, PaddingConfig2d::Same); let byte_recorder = BinBytesRecorder::::default(); let model_bytes_before = byte_recorder .record(model_before.into_record(), ()) .unwrap(); let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap(); assert_eq!(model_bytes_after, model_bytes_before); } #[derive(Config, Debug, PartialEq, Eq)] pub enum PaddingConfig2d { Same, Valid, Explicit(usize, usize), } // Dummy model with different record types #[derive(Module, Debug)] pub struct Model { linear1: SimpleLinear, phantom: PhantomData, tensor: Tensor, arr: [usize; 2], int: usize, ignore: Ignored, } pub fn create_model(device: &::Device) -> Model { let linear1 = SimpleLinear::new(32, 32, device); Model { linear1, phantom: PhantomData, tensor: Tensor::zeros([2], device), arr: [2, 2], int: 0, ignore: Ignored(PaddingConfig2d::Same), } } } ================================================ FILE: crates/burn-core/src/record/memory.rs ================================================ use super::{PrecisionSettings, Recorder, RecorderError, bin_config}; use alloc::vec::Vec; use burn_tensor::backend::Backend; use serde::{Serialize, de::DeserializeOwned}; /// Recorder trait specialized to save and load data to and from bytes. /// /// # Notes /// /// This is especially useful in no_std environment where weights are stored directly in /// compiled binaries. pub trait BytesRecorder< B: Backend, L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default, >: Recorder, LoadArgs = L> { } /// In memory recorder using the [bincode format](bincode). #[derive(new, Debug, Default, Clone)] pub struct BinBytesRecorder< S: PrecisionSettings, L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec, > { _settings: core::marker::PhantomData, _loadargs: core::marker::PhantomData, } impl< S: PrecisionSettings, B: Backend, L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default, > BytesRecorder for BinBytesRecorder { } impl< S: PrecisionSettings, B: Backend, L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default, > Recorder for BinBytesRecorder { type Settings = S; type RecordArgs = (); type RecordOutput = Vec; type LoadArgs = L; fn save_item( &self, item: I, _args: Self::RecordArgs, ) -> Result { Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap()) } fn load_item( &self, args: &mut Self::LoadArgs, ) -> Result { let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat, _>( args.as_ref(), bin_config(), ) .unwrap() .0; Ok(state.0) } } #[cfg(feature = "std")] /// In memory recorder using the [Named MessagePack](rmp_serde). #[derive(new, Debug, Default, Clone)] pub struct NamedMpkBytesRecorder { _settings: core::marker::PhantomData, } #[cfg(feature = "std")] impl BytesRecorder> for NamedMpkBytesRecorder {} #[cfg(feature = "std")] impl Recorder for NamedMpkBytesRecorder { type Settings = S; type RecordArgs = (); type RecordOutput = Vec; type LoadArgs = Vec; fn save_item( &self, item: I, _args: Self::RecordArgs, ) -> Result { rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string())) } fn load_item( &self, args: &mut Self::LoadArgs, ) -> Result { rmp_serde::decode::from_slice(args).map_err(|e| RecorderError::Unknown(e.to_string())) } } #[cfg(test)] mod tests { use super::*; use crate::test_utils::SimpleLinear; use crate::{ TestBackend, module::Module, record::FullPrecisionSettings, tensor::backend::Backend, }; #[test] fn test_can_save_and_load_bin_format() { test_can_save_and_load(BinBytesRecorder::::default()) } #[cfg(feature = "std")] #[test] fn test_can_save_and_load_named_mpk_format() { test_can_save_and_load(NamedMpkBytesRecorder::::default()) } fn test_can_save_and_load(recorder: Recorder) where Recorder: BytesRecorder>, { let device = Default::default(); let model1 = create_model::(&device); let model2 = create_model::(&device); let bytes1 = recorder.record(model1.into_record(), ()).unwrap(); let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap(); let model2_after = model2.load_record(recorder.load(bytes1.clone(), &device).unwrap()); let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap(); assert_ne!(bytes1, bytes2); assert_eq!(bytes1, bytes2_after); } pub fn create_model(device: &B::Device) -> SimpleLinear { SimpleLinear::new(32, 32, device) } } ================================================ FILE: crates/burn-core/src/record/mod.rs ================================================ mod primitive; mod tensor; mod base; mod memory; mod recorder; mod settings; pub use base::*; pub use memory::*; pub use recorder::*; pub use settings::*; #[cfg(feature = "std")] mod file; #[cfg(feature = "std")] pub use file::*; pub use primitive::ParamSerde; #[cfg(feature = "record-item-custom-serde")] pub mod serde; ================================================ FILE: crates/burn-core/src/record/primitive.rs ================================================ use alloc::{string::String, vec, vec::Vec}; use core::{fmt, marker::PhantomData}; use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde}; use super::{PrecisionSettings, Record}; use crate::module::{Param, ParamId}; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; use hashbrown::HashMap; use serde::{ Deserialize, Serialize, de::{Error, SeqAccess, Visitor}, ser::SerializeTuple, }; impl Record for () where B: Backend, { type Item = (); fn into_item(self) -> Self::Item {} fn from_item(_item: Self::Item, _device: &B::Device) -> Self {} } impl Record for Vec where T: Record, B: Backend, { type Item = Vec>; fn into_item(self) -> Self::Item { self.into_iter().map(Record::into_item).collect() } fn from_item(item: Self::Item, device: &B::Device) -> Self { item.into_iter() .map(|i| Record::from_item(i, device)) .collect() } } impl Record for Option where T: Record, B: Backend, { type Item = Option>; fn into_item(self) -> Self::Item { self.map(Record::into_item) } fn from_item(item: Self::Item, device: &B::Device) -> Self { item.map(|i| Record::from_item(i, device)) } } impl Record for [T; N] where T: Record, B: Backend, { /// The record item is an array of the record item of the elements. /// The reason why we wrap the array in a struct is because serde does not support /// deserializing arrays of variable size, /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937). /// for backward compatibility reasons. Serde APIs were created before const generics. type Item = Array>; fn into_item(self) -> Self::Item { Array(self.map(Record::into_item)) } fn from_item(item: Self::Item, device: &B::Device) -> Self { item.0.map(|i| Record::from_item(i, device)) } } /// A macro for generating implementations for tuple records of different sizes. /// For example: `impl_record_tuple!([R0, R1][0, 1])`. /// Would generate an implementation for a tuple of size 2. /// For this macro to work properly, please adhere to the convention: /// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`. macro_rules! impl_record_tuple { // `$r` represents the generic records. // `$i` represents the indices of the records in the tuple. ([$($r:ident),*][$($i:tt),*]) => { impl Record for ($($r,)*) where B: Backend, $($r: Record),* { type Item = ($($r::Item,)*); fn into_item(self) -> Self::Item { ($(self.$i.into_item(),)*) } fn from_item(item: Self::Item, device: &B::Device) -> Self { ($(Record::from_item(item.$i, device),)*) } } }; } impl_record_tuple!([R0, R1][0, 1]); impl_record_tuple!([R0, R1, R2][0, 1, 2]); impl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]); impl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]); impl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]); impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]); impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]); impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]); impl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); impl Record for HashMap where T: Record, B: Backend, { type Item = HashMap>; fn into_item(self) -> Self::Item { let mut items = HashMap::with_capacity(self.len()); self.into_iter().for_each(|(id, record)| { items.insert(id.serialize(), record.into_item()); }); items } fn from_item(item: Self::Item, device: &B::Device) -> Self { let mut record = HashMap::with_capacity(item.len()); item.into_iter().for_each(|(id, item)| { record.insert(ParamId::deserialize(&id), T::from_item(item, device)); }); record } } /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { id: String, param: T, } impl Record for Param> where B: Backend, { type Item = ParamSerde>; fn into_item(self) -> Self::Item { let (id, tensor, mapper) = self.consume(); let tensor = mapper.on_save(tensor); ParamSerde::new(id.serialize(), tensor.into_item()) } fn from_item(item: Self::Item, device: &B::Device) -> Self { B::memory_persistent_allocations(device, item, |item| { Param::initialized( ParamId::deserialize(&item.id), Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new // Param from a tensor. ) }) } } impl Record for Param> where B: Backend, { type Item = ParamSerde>; fn into_item(self) -> Self::Item { let (id, tensor, mapper) = self.consume(); let tensor = mapper.on_save(tensor); ParamSerde::new(id.serialize(), tensor.into_item()) } fn from_item(item: Self::Item, device: &B::Device) -> Self { B::memory_persistent_allocations(device, item, |item| { Param::initialized( ParamId::deserialize(&item.id), Tensor::from_item(item.param, device), ) }) } } impl Record for Param> where B: Backend, { type Item = ParamSerde; fn into_item(self) -> Self::Item { let (id, tensor, mapper) = self.consume(); let tensor = mapper.on_save(tensor); ParamSerde::new(id.serialize(), tensor.into_item::()) } fn from_item(item: Self::Item, device: &B::Device) -> Self { B::memory_persistent_allocations(device, item, |item| { Param::initialized( ParamId::deserialize(&item.id), Tensor::from_item::(item.param, device), ) }) } } // Type that can be serialized as is without any conversion. macro_rules! primitive { ($type:ty) => { impl Record for $type { type Item = $type; fn into_item(self) -> Self::Item { self } fn from_item(item: Self::Item, _device: &B::Device) -> Self { item } } }; } // General Types primitive!(alloc::string::String); primitive!(bool); // Float Types primitive!(f64); primitive!(f32); primitive!(half::bf16); primitive!(half::f16); // Unsigned Integer Types primitive!(usize); primitive!(u64); primitive!(u32); primitive!(u16); primitive!(u8); // Signed Integer Types primitive!(isize); primitive!(i64); primitive!(i32); primitive!(i16); primitive!(i8); /// A wrapper around an array of size N, so that it can be serialized and deserialized /// using serde. /// /// The reason why we wrap the array in a struct is because serde does not support /// deserializing arrays of variable size, /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937) /// for backward compatibility reasons. Serde APIs were created before const generics. #[derive(Clone)] pub struct Array([T; N]); impl Serialize for Array { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { let mut seq = serializer.serialize_tuple(self.0.len())?; for element in &self.0 { seq.serialize_element(element)?; } seq.end() } } impl<'de, T, const N: usize> Deserialize<'de> for Array where T: Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { struct ArrayVisitor { marker: PhantomData, } impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor where T: Deserialize<'de>, { type Value = Array; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a fixed size array") } fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de>, { let mut items = vec![]; for i in 0..N { let item = seq .next_element()? .ok_or_else(|| Error::invalid_length(i, &self))?; items.push(item); } let array: [T; N] = items .into_iter() .collect::>() .try_into() .map_err(|_| "An array of size {N}") .unwrap(); Ok(Array(array)) } } deserializer.deserialize_tuple( N, ArrayVisitor { marker: PhantomData, }, ) } } ================================================ FILE: crates/burn-core/src/record/recorder.rs ================================================ use core::any::type_name; use core::marker::PhantomData; use alloc::format; use alloc::string::{String, ToString}; use burn_tensor::backend::Backend; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record}; #[cfg(feature = "std")] use super::{ BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings, PrettyJsonFileRecorder, }; /// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned). pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone { /// Type of the settings used by the recorder. type Settings: PrecisionSettings; /// Arguments used to record objects. type RecordArgs: Clone; /// Record output type. type RecordOutput; /// Arguments used to load recorded objects. type LoadArgs; /// Records an item. /// /// # Arguments /// /// * `record` - The item to record. /// * `args` - Arguments used to record the item. /// /// # Returns /// /// The output of the recording. fn record( &self, record: R, args: Self::RecordArgs, ) -> Result where R: Record, { let item = record.into_item::(); let item = BurnRecord::new::(item); self.save_item(item, args) } /// Load an item from the given arguments. fn load(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result where R: Record, { let item: BurnRecord, B> = self.load_item(&mut args).map_err(|err| { if let Ok(record) = self.load_item::(&mut args) { let mut message = "Unable to load record.".to_string(); let metadata = recorder_metadata::(); if metadata.float != record.metadata.float { message += format!( "\nMetadata has a different float type: Actual {:?}, Expected {:?}", record.metadata.float, metadata.float ) .as_str(); } if metadata.int != record.metadata.int { message += format!( "\nMetadata has a different int type: Actual {:?}, Expected {:?}", record.metadata.int, metadata.int ) .as_str(); } if metadata.format != record.metadata.format { message += format!( "\nMetadata has a different format: Actual {:?}, Expected {:?}", record.metadata.format, metadata.format ) .as_str(); } if metadata.version != record.metadata.version { message += format!( "\nMetadata has a different Burn version: Actual {:?}, Expected {:?}", record.metadata.version, metadata.version ) .as_str(); } message += format!("\nError: {err:?}").as_str(); return RecorderError::Unknown(message); } err })?; Ok(R::from_item(item.item, device)) } /// Saves an item. /// /// This method is used by [record](Recorder::record) to save the item. /// /// # Arguments /// /// * `item` - Item to save. /// * `args` - Arguments to use to save the item. /// /// # Returns /// /// The output of the save operation. fn save_item( &self, item: I, args: Self::RecordArgs, ) -> Result; /// Loads an item. /// /// This method is used by [load](Recorder::load) to load the item. /// /// # Arguments /// /// * `args` - Arguments to use to load the item. /// /// # Returns /// /// The loaded item. fn load_item(&self, args: &mut Self::LoadArgs) -> Result where I: DeserializeOwned; } fn recorder_metadata() -> BurnMetadata where R: Recorder, B: Backend, { BurnMetadata::new( type_name::<::FloatElem>().to_string(), type_name::<::IntElem>().to_string(), type_name::().to_string(), env!("CARGO_PKG_VERSION").to_string(), format!("{:?}", R::Settings::default()), ) } /// Error that can occur when using a [Recorder](Recorder). #[derive(Debug)] pub enum RecorderError { /// File not found. FileNotFound(String), /// Failed to read file. DeserializeError(String), /// Other error. Unknown(String), } impl core::fmt::Display for RecorderError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(format!("{self:?}").as_str()) } } impl core::error::Error for RecorderError {} pub(crate) fn bin_config() -> bincode::config::Configuration { bincode::config::standard() } /// Metadata of a record. #[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct BurnMetadata { /// Float type used to record the item. pub float: String, /// Int type used to record the item. pub int: String, /// Format used to record the item. pub format: String, /// Burn record version used to record the item. pub version: String, /// Settings used to record the item. pub settings: String, } /// Record that can be saved by a [Recorder](Recorder). #[derive(Serialize, Deserialize, Debug)] pub struct BurnRecord { /// Metadata of the record. pub metadata: BurnMetadata, /// Item to record. pub item: I, _b: PhantomData, } impl BurnRecord { /// Creates a new record. /// /// # Arguments /// /// * `item` - Item to record. /// /// # Returns /// /// The new record. pub fn new>(item: I) -> Self { let metadata = recorder_metadata::(); Self { metadata, item, _b: PhantomData, } } } /// Record that can be saved by a [Recorder](Recorder) without the item. #[derive(new, Debug, Serialize, Deserialize)] pub struct BurnRecordNoItem { /// Metadata of the record. pub metadata: BurnMetadata, } /// Default recorder. /// /// It uses the [named msgpack](rmp_serde) format for serialization with full precision. #[cfg(feature = "std")] pub type DefaultRecorder = DefaultFileRecorder; /// Recorder optimized for compactness. /// /// It uses the [named msgpack](rmp_serde) format for serialization with half precision. /// If you are looking for the recorder that offers the smallest file size, have a look at /// [sensitive compact recorder](SensitiveCompactRecorder). #[cfg(feature = "std")] pub type CompactRecorder = DefaultFileRecorder; /// Recorder optimized for compactness making it a good choice for model deployment. /// /// It uses the [bincode](bincode) format for serialization and half precision. /// This format is not resilient to type changes since no metadata is encoded. /// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder) /// for long term data storage. #[cfg(feature = "std")] pub type SensitiveCompactRecorder = BinGzFileRecorder; /// Training recorder compatible with no-std inference. #[cfg(feature = "std")] pub type NoStdTrainingRecorder = BinFileRecorder; /// Inference recorder compatible with no-std. pub type NoStdInferenceRecorder = BinBytesRecorder; /// Debug recorder. /// /// It uses the [pretty json](serde_json) format for serialization with full precision making it /// human readable. #[cfg(feature = "std")] pub type DebugRecordSettings = PrettyJsonFileRecorder; #[cfg(all(test, feature = "std"))] mod tests { static FILE_PATH: &str = "/tmp/burn_test_record"; use crate::TestBackend; use super::*; use burn_tensor::{Device, ElementConversion}; #[test] #[should_panic] fn err_when_invalid_item() { #[derive(new, Serialize, Deserialize, Clone)] struct Item { value: S::FloatElem, } impl Record for Item where D: PrecisionSettings, B: Backend, { type Item = Item; fn into_item(self) -> Self::Item { Item { value: self.value.elem(), } } fn from_item(item: Self::Item, _device: &B::Device) -> Self { Item { value: item.value.elem(), } } } let item = Item::::new(16.elem()); let device: Device = Default::default(); // Serialize in f32. let recorder = DefaultFileRecorder::::new(); Recorder::::record(&recorder, item, FILE_PATH.into()).unwrap(); // Can't deserialize f32 into f16. let recorder = DefaultFileRecorder::::new(); Recorder::::load::>( &recorder, FILE_PATH.into(), &device, ) .unwrap(); } } ================================================ FILE: crates/burn-core/src/record/serde/adapter.rs ================================================ use super::data::NestedValue; /// A trait that defines the adapter for a Burn module. /// /// This is used to adapt an incoming module to a Burn module. pub trait BurnModuleAdapter: Sized { /// Adapts a module. fn adapt(name: &str, data: NestedValue) -> NestedValue { match name { "BatchNorm" => Self::adapt_batch_norm(data), "Conv1d" => Self::adapt_conv1d(data), "Conv2d" => Self::adapt_conv2d(data), "Conv3d" => Self::adapt_conv3d(data), "ConvTranspose1d" => Self::adapt_conv_transpose_1d(data), "ConvTranspose2d" => Self::adapt_conv_transpose_2d(data), "ConvTranspose3d" => Self::adapt_conv_transpose_3d(data), "Embedding" => Self::adapt_embedding(data), "GroupNorm" => Self::adapt_group_norm(data), "LayerNorm" => Self::adapt_layer_norm(data), "Linear" => Self::adapt_linear(data), _ => data, } } /// Adapts a linear module. fn adapt_linear(data: NestedValue) -> NestedValue { data } /// Adapts a Convolution 1D module. fn adapt_conv1d(data: NestedValue) -> NestedValue { data } /// Adapts a Convolution 2D module. fn adapt_conv2d(data: NestedValue) -> NestedValue { data } /// Adapts a Convolution 3D module. fn adapt_conv3d(data: NestedValue) -> NestedValue { data } /// Adapts convolution transpose 1D module. fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue { data } /// Adapts convolution transpose 2D module. fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue { data } /// Adapts convolution transpose 2D module. fn adapt_conv_transpose_3d(data: NestedValue) -> NestedValue { data } /// Adapts embedding module. fn adapt_embedding(data: NestedValue) -> NestedValue { data } /// Adapts group normalization module. fn adapt_group_norm(data: NestedValue) -> NestedValue { data } /// Adapts layer normalization module. fn adapt_layer_norm(data: NestedValue) -> NestedValue { data } /// Adapts batch normalization module. fn adapt_batch_norm(data: NestedValue) -> NestedValue { data } } /// Default adapter that takes no action. pub struct DefaultAdapter; impl BurnModuleAdapter for DefaultAdapter {} ================================================ FILE: crates/burn-core/src/record/serde/data.rs ================================================ use std::collections::HashMap; use super::adapter::BurnModuleAdapter; use super::de::Deserializer; use super::error::Error; use super::ser::Serializer; use crate::record::{PrecisionSettings, Record}; use crate::tensor::backend::Backend; use alloc::fmt; use burn_tensor::Bytes; use num_traits::cast::ToPrimitive; use regex::Regex; use serde::Deserialize; /// The main data structure used for deserialization. /// /// It can hold tree-like structures of nested maps and vectors. #[derive(Clone)] pub enum NestedValue { /// The default value, which actually does not hold any value and it is used to indicate that /// the value should be populated with the default value. It contains an optional string with /// the originator field name. Default(Option), /// A boolean value. Bool(bool), /// A string value. String(String), /// Floating point 32-bit value. F32(f32), /// Floating point 64-bit value. F64(f64), /// Signed 16-bit integer value. I16(i16), /// Signed 32-bit integer value. I32(i32), /// Signed 64-bit integer value. I64(i64), /// Unsigned 8-bit integer value. U8(u8), /// Unsigned 16-bit integer value used for bf16 and f16 serialization U16(u16), /// Unsigned 64-bit integer value. U64(u64), /// A map of nested values (typically used for structs) Map(HashMap), /// A vector of nested values (typically used for vector of structs or numbers) Vec(Vec), /// A vector of 8-bit unsigned integer values. U8s(Vec), /// A vector of 16-bit unsigned integer values. U16s(Vec), /// A vector of 32-bit floating point values. F32s(Vec), /// An opaque vector of bytes, with alignment. Bytes(Bytes), } impl NestedValue { /// Get the nested value as a map. pub fn as_map(self) -> Option> { match self { NestedValue::Map(map) => Some(map), _ => None, } } /// Get the nested value as a boolean. pub fn as_bool(self) -> Option { match self { NestedValue::Bool(bool) => Some(bool), _ => None, } } /// Get the nested value as a string. pub fn as_string(self) -> Option { match self { NestedValue::String(string) => Some(string), _ => None, } } /// Get the nested value as a f32. pub fn as_f32(self) -> Option { match self { NestedValue::F32(f32) => Some(f32), NestedValue::F64(f) => f.to_f32(), _ => None, } } /// Get the nested value as a f64. pub fn as_f64(self) -> Option { match self { NestedValue::F64(f64) => Some(f64), NestedValue::F32(f) => f.to_f64(), _ => None, } } /// Get the nested value as an i16. pub fn as_i16(self) -> Option { match self { NestedValue::I16(i16) => Some(i16), NestedValue::I32(i) => i.to_i16(), NestedValue::I64(i) => i.to_i16(), NestedValue::U16(u) => u.to_i16(), NestedValue::U64(u) => u.to_i16(), _ => None, } } /// Get the nested value as an i32. pub fn as_i32(self) -> Option { match self { NestedValue::I32(i32) => Some(i32), NestedValue::I16(i) => i.to_i32(), NestedValue::I64(i) => i.to_i32(), NestedValue::U16(u) => u.to_i32(), NestedValue::U64(u) => u.to_i32(), _ => None, } } /// Get the nested value as an i64. pub fn as_i64(self) -> Option { match self { NestedValue::I64(i64) => Some(i64), NestedValue::I16(i) => i.to_i64(), NestedValue::I32(i) => i.to_i64(), NestedValue::U16(u) => u.to_i64(), NestedValue::U64(u) => u.to_i64(), _ => None, } } /// Get the nested value as a u8. pub fn as_u8(self) -> Option { match self { NestedValue::U8(u8) => Some(u8), NestedValue::I16(i) => i.to_u8(), NestedValue::I32(i) => i.to_u8(), NestedValue::I64(i) => i.to_u8(), NestedValue::U16(u) => u.to_u8(), NestedValue::U64(u) => u.to_u8(), _ => None, } } /// Get the nested value as a u16. pub fn as_u16(self) -> Option { match self { NestedValue::U16(u16) => Some(u16), NestedValue::I16(i) => i.to_u16(), NestedValue::I32(i) => i.to_u16(), NestedValue::I64(i) => i.to_u16(), NestedValue::U64(u) => u.to_u16(), _ => None, } } /// Get the nested value as a u64. pub fn as_u64(self) -> Option { match self { NestedValue::U64(u64) => Some(u64), NestedValue::I16(i) => i.to_u64(), NestedValue::I32(i) => i.to_u64(), NestedValue::I64(i) => i.to_u64(), NestedValue::U16(u) => u.to_u64(), _ => None, } } /// Get the nested value as a vector of bytes. pub fn as_bytes(self) -> Option { match self { NestedValue::Bytes(u) => Some(u), NestedValue::U8s(u) => Some(Bytes::from_elems(u)), _ => None, } } /// Deserialize a nested value into a record type. pub fn try_into_record(self, device: &B::Device) -> Result where B: Backend, T: Record, PS: PrecisionSettings, A: BurnModuleAdapter, { let deserializer = Deserializer::::new(self, false); let item = T::Item::deserialize(deserializer)?; // Convert the deserialized item into a Record instance Ok(T::from_item::(item, device)) } } /// Remap the tensor locations according to the key remapping. /// /// # Arguments /// /// * `tensors` - A map of tensors. /// * `key_remap` - A vector of tuples containing a regular expression and a replacement string. /// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) /// for more information. /// /// # Returns /// /// A map of tensors with the remapped keys and /// a vector of tuples containing the remapped and original. pub fn remap( mut tensors: HashMap, key_remap: Vec<(Regex, String)>, ) -> (HashMap, Vec<(String, String)>) { if key_remap.is_empty() { let remapped_names = tensors .keys() .cloned() .map(|s| (s.clone(), s)) // Name is the same as the remapped name .collect(); return (tensors, remapped_names); } let mut remapped = HashMap::new(); let mut remapped_names = Vec::new(); for (name, tensor) in tensors.drain() { let mut new_name = name.clone(); for (pattern, replacement) in &key_remap { if pattern.is_match(&new_name) { new_name = pattern .replace_all(&new_name, replacement.as_str()) .to_string(); } } remapped_names.push((new_name.clone(), name)); remapped.insert(new_name, tensor); } (remapped, remapped_names) } /// Helper function to insert a value into a nested map/vector of tensors. fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) { if keys.is_empty() { *current = value; return; } match current { NestedValue::Map(map) => { if !map.contains_key(keys[0]) { let next = if keys[1..] .first() .and_then(|k| k.parse::().ok()) .is_some() { NestedValue::Vec(Vec::new()) } else { NestedValue::Map(HashMap::new()) }; map.insert(keys[0].to_string(), next); } insert_nested_value(map.get_mut(keys[0]).unwrap(), &keys[1..], value); } NestedValue::Vec(vec) => { let index = keys[0].parse::().unwrap(); if index >= vec.len() { vec.resize_with(index + 1, || NestedValue::Map(HashMap::new())); } insert_nested_value(&mut vec[index], &keys[1..], value); } _ => panic!("Invalid structure encountered"), } } /// A trait for encapsulating the serialization logic. pub trait Serializable { /// Serializes the object into a `NestedValue` using the provided `Serializer`. /// This method is generic over the precision settings `PS`. /// /// # Parameters /// - `serializer`: The `Serializer` to use for serializing the object. /// /// # Returns /// - `Result`: The result of serialization. /// Returns a `NestedValue` on success, /// or an `Error` on failure. /// /// # Type Parameters /// - `PS`: The precision settings to use during serialization. /// This is a generic parameter and can be any type /// that implements the `PrecisionSettings` trait. fn serialize(&self, serializer: Serializer) -> Result where PS: PrecisionSettings; } /// Convert a vector of tensors to a nested value. pub fn unflatten(input: HashMap) -> Result where PS: PrecisionSettings, T: Serializable, { let mut result = NestedValue::Map(HashMap::new()); for (key, value) in input { let parts: Vec<&str> = key.split('.').collect(); let st = value.serialize::(Serializer::new())?; insert_nested_value(&mut result, &parts, st); } cleanup_empty_maps(&mut result); Ok(result) } /// Removes empty maps from the nested value. /// /// We need to clean up empty maps from the nested value /// in some cases when there is non-contiguous indices in keys. fn cleanup_empty_maps(current: &mut NestedValue) { match current { NestedValue::Map(map) => { map.values_mut().for_each(cleanup_empty_maps); } NestedValue::Vec(vec) => { vec.iter_mut().for_each(cleanup_empty_maps); vec.retain(|v| !matches!(v, NestedValue::Map(m) if m.is_empty())); } _ => {} } } fn write_vec_truncated( vec: &[T], f: &mut core::fmt::Formatter, ) -> fmt::Result { write!(f, "Vec([")?; for (i, v) in vec.iter().take(3).enumerate() { if i > 0 { write!(f, ", ")?; } write!(f, "{v:?}")?; } write!(f, ", ...] len={})", vec.len()) } impl fmt::Debug for NestedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { // Truncate values for vector NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f), // Handle other variants as usual NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(), NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(), NestedValue::String(s) => f.debug_tuple("String").field(s).finish(), NestedValue::F32(val) => f.debug_tuple("F32").field(val).finish(), NestedValue::F64(val) => f.debug_tuple("F64").field(val).finish(), NestedValue::I16(val) => f.debug_tuple("I16").field(val).finish(), NestedValue::I32(val) => f.debug_tuple("I32").field(val).finish(), NestedValue::I64(val) => f.debug_tuple("I64").field(val).finish(), NestedValue::U8(val) => f.debug_tuple("U8").field(val).finish(), NestedValue::U16(val) => f.debug_tuple("U16").field(val).finish(), NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(), NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(), NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(), } } } ================================================ FILE: crates/burn-core/src/record/serde/de.rs ================================================ use core::ptr; use std::collections::HashMap; use super::data::NestedValue; use super::{adapter::BurnModuleAdapter, error::Error}; use serde::de::{EnumAccess, VariantAccess}; use serde::{ de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}, forward_to_deserialize_any, }; const RECORD_ITEM_SUFFIX: &str = "RecordItem"; /// A deserializer for the nested value data structure. pub struct Deserializer { // This string starts with the input data and characters are truncated off // the beginning as data is parsed. value: Option, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } impl Deserializer { /// Creates a new deserializer with the given nested value. /// /// # Arguments /// /// * `value` - A nested value. /// * `default_for_missing_fields` - A boolean indicating whether to add missing fields with default value. pub fn new(value: NestedValue, default_for_missing_fields: bool) -> Self { Self { value: Some(value), default_for_missing_fields, phantom: std::marker::PhantomData, } } } impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { type Error = Error; fn deserialize_any(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_any is not implemented") } fn deserialize_struct( self, name: &'static str, fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { let value = match self.value { Some(value) => { // Adapt modules if let Some(name) = name.strip_suffix(RECORD_ITEM_SUFFIX) { A::adapt(name, value) } else { value } } None => { return Err(de::Error::custom(format!( "Expected some value but got {:?}", self.value ))); } }; match value { NestedValue::Map(map) => { // Add missing fields into the map with default value if needed. let map = if self.default_for_missing_fields { let mut map = map; for field in fields.iter().map(|s| s.to_string()) { map.entry(field.clone()) .or_insert(NestedValue::Default(Some(field))); } map } else { map }; visitor.visit_map(HashMapAccess::::new( map, self.default_for_missing_fields, )) } _ => Err(de::Error::custom(format!( "Expected struct but got {value:?}" ))), } } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_string(self.value.unwrap().as_string().unwrap().to_string()) } fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_unit() } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { match self.value { Some(NestedValue::Map(map)) => visitor.visit_map(HashMapAccess::::new( map, self.default_for_missing_fields, )), _ => Err(de::Error::custom(format!( "Expected map value but got {:?}", self.value ))), } } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_bool(self.value.unwrap().as_bool().unwrap()) } fn deserialize_i8(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_i8 is not implemented") } fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i16(self.value.unwrap().as_i16().unwrap().to_owned()) } fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i32(self.value.unwrap().as_i32().unwrap().to_owned()) } fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i64(self.value.unwrap().as_i64().unwrap().to_owned()) } fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u8(self.value.unwrap().as_u8().unwrap().to_owned()) } fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u16(self.value.unwrap().as_u16().unwrap().to_owned()) } fn deserialize_u32(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_u32 is not implemented") } fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u64(self.value.unwrap().as_u64().unwrap().to_owned()) } fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_f32(self.value.unwrap().as_f32().unwrap().to_owned()) } fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_f64(self.value.unwrap().as_f64().unwrap().to_owned()) } fn deserialize_char(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_char is not implemented") } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_str(self.value.unwrap().as_string().unwrap().as_ref()) } fn deserialize_bytes(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_bytes is not implemented") } fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de>, { let bytes = self.value.unwrap().as_bytes().unwrap(); match bytes.try_into_vec::() { Ok(bytes) => visitor.visit_byte_buf(bytes), Err(bytes) => visitor.visit_bytes(&bytes), } } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de>, { if let Some(value) = self.value { visitor.visit_some(Deserializer::::new( value, self.default_for_missing_fields, )) } else { visitor.visit_none() } } fn deserialize_unit(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_unit is not implemented") } fn deserialize_unit_struct( self, _name: &'static str, _visitor: V, ) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_unit_struct is not implemented") } fn deserialize_newtype_struct( self, _name: &'static str, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_newtype_struct(Deserializer::::new( self.value.unwrap(), self.default_for_missing_fields, )) } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { if let Some(value) = self.value { match value { NestedValue::Vec(_) => visitor.visit_seq(VecSeqAccess::::new( value, self.default_for_missing_fields, )), NestedValue::U8s(_) => visitor.visit_seq(VecSeqAccess::::new( value, self.default_for_missing_fields, )), NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::::new( value, self.default_for_missing_fields, )), NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::::new( value, self.default_for_missing_fields, )), _ => Err(de::Error::custom(format!("Expected Vec but got {value:?}"))), } } else { Err(de::Error::custom("Expected Vec but got None")) } } fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_tuple is not implemented") } fn deserialize_tuple_struct( self, _name: &'static str, _len: usize, _visitor: V, ) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_tuple_struct is not implemented") } /// Deserializes an enum by attempting to match its variants against the provided data. /// /// This function attempts to deserialize an enum by iterating over its possible variants /// and trying to deserialize the data into each until one succeeds. We need to do this /// because we don't have a way to know which variant to deserialize from the data. /// /// This is similar to Serde's /// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged), /// but it's on the deserializer side. Using `#[serde(untagged)]` on the enum will force /// using `deserialize_any`, which is not what we want because we want to use methods, such /// as `visit_struct`. Also we do not wish to use auto generate code for Deserialize just /// for enums because it will affect other serialization and deserialization, such /// as JSON and Bincode. /// /// # Safety /// The function uses an unsafe block to clone the `visitor`. This is necessary because /// the `Visitor` trait does not have a `Clone` implementation, and we need to clone it /// as we are going to use it multiple times. The Visitor is a code generated unit struct /// with no states or mutations, so it is safe to clone it in this case. We mainly care /// about the `visit_enum` method, which is the only method that will be called on the /// cloned visitor. fn deserialize_enum( self, _name: &'static str, variants: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { fn clone_unsafely(thing: &T) -> T { unsafe { // Allocate memory for the clone. let mut clone = std::mem::MaybeUninit::::uninit(); // Get a mutable pointer to the allocated memory. let clone_ptr = clone.as_mut_ptr(); // Copy the memory ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1); // Assume the cloned data is initialized and convert it to an owned instance of T. clone.assume_init() } } // Try each variant in order for &variant in variants { // clone visitor to avoid moving it let cloned_visitor = clone_unsafely(&visitor); let result = cloned_visitor.visit_enum(ProbeEnumAccess::::new( self.value.clone().unwrap(), variant.to_owned(), self.default_for_missing_fields, )); if result.is_ok() { return result; } } Err(de::Error::custom("No variant match")) } fn deserialize_identifier(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("deserialize_identifier is not implemented") } } /// A sequence access for a vector in the nested value data structure. struct VecSeqAccess { iter: Box>, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } // Concrete implementation for `Vec` impl VecSeqAccess { fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { match vec { NestedValue::Vec(v) => VecSeqAccess { iter: Box::new(v.into_iter()), default_for_missing_fields, phantom: std::marker::PhantomData, }, _ => panic!("Invalid vec sequence"), } } } // Concrete implementation for `Vec` impl VecSeqAccess { fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { match vec { NestedValue::U8s(v) => VecSeqAccess { iter: Box::new(v.into_iter()), default_for_missing_fields, phantom: std::marker::PhantomData, }, _ => panic!("Invalid vec sequence"), } } } // Concrete implementation for `Vec` impl VecSeqAccess { fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { match vec { NestedValue::U16s(v) => VecSeqAccess { iter: Box::new(v.into_iter()), default_for_missing_fields, phantom: std::marker::PhantomData, }, _ => panic!("Invalid vec sequence"), } } } // Concrete implementation for `Vec` impl VecSeqAccess { fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { match vec { NestedValue::F32s(v) => VecSeqAccess { iter: Box::new(v.into_iter()), default_for_missing_fields, phantom: std::marker::PhantomData, }, _ => panic!("Invalid vec sequence"), } } } // Concrete implementation for `Vec` impl<'de, A> SeqAccess<'de> for VecSeqAccess where NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { let item = match self.iter.next() { Some(v) => v, None => return Ok(None), }; seed.deserialize( NestedValueWrapper::::new(item, self.default_for_missing_fields).into_deserializer(), ) .map(Some) } } // Concrete implementation for `Vec` impl<'de, A> SeqAccess<'de> for VecSeqAccess where NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { let item = match self.iter.next() { Some(v) => v, None => return Ok(None), }; seed.deserialize( NestedValueWrapper::::new(NestedValue::U8(item), self.default_for_missing_fields) .into_deserializer(), ) .map(Some) } } // Concrete implementation for `Vec` impl<'de, A> SeqAccess<'de> for VecSeqAccess where NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { let item = match self.iter.next() { Some(v) => v, None => return Ok(None), }; seed.deserialize( NestedValueWrapper::::new(NestedValue::U16(item), self.default_for_missing_fields) .into_deserializer(), ) .map(Some) } } // Concrete implementation for `Vec` impl<'de, A> SeqAccess<'de> for VecSeqAccess where NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { let item = match self.iter.next() { Some(v) => v, None => return Ok(None), }; seed.deserialize( NestedValueWrapper::::new(NestedValue::F32(item), self.default_for_missing_fields) .into_deserializer(), ) .map(Some) } } /// A map access for a map in the nested value data structure. struct HashMapAccess { iter: std::collections::hash_map::IntoIter, next_value: Option, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } impl HashMapAccess { fn new(map: HashMap, default_for_missing_fields: bool) -> Self { HashMapAccess { iter: map.into_iter(), next_value: None, default_for_missing_fields, phantom: std::marker::PhantomData, } } } impl<'de, A> MapAccess<'de> for HashMapAccess where String: IntoDeserializer<'de, Error>, NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, { type Error = Error; fn next_key_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { match self.iter.next() { Some((k, v)) => { // Keep the value for the next call to next_value_seed. self.next_value = Some(v); // Deserialize the key. seed.deserialize(k.into_deserializer()).map(Some) } None => Ok(None), } } fn next_value_seed(&mut self, seed: T) -> Result where T: DeserializeSeed<'de>, { match self.next_value.take() { Some(NestedValue::Default(originator)) => { seed.deserialize(DefaultDeserializer::new(originator)) } Some(v) => seed.deserialize( NestedValueWrapper::new(v, self.default_for_missing_fields).into_deserializer(), ), None => seed.deserialize(DefaultDeserializer::new(None)), } } } struct ProbeEnumAccess { value: NestedValue, current_variant: String, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } impl ProbeEnumAccess { fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self { ProbeEnumAccess { value, current_variant, default_for_missing_fields, phantom: std::marker::PhantomData, } } } impl<'de, A> EnumAccess<'de> for ProbeEnumAccess where A: BurnModuleAdapter, { type Error = Error; type Variant = Self; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where V: DeserializeSeed<'de>, { seed.deserialize(self.current_variant.clone().into_deserializer()) .map(|v| (v, self)) } } impl<'de, A> VariantAccess<'de> for ProbeEnumAccess where A: BurnModuleAdapter, { type Error = Error; fn newtype_variant_seed(self, seed: T) -> Result where T: DeserializeSeed<'de>, { let value = seed.deserialize( NestedValueWrapper::::new(self.value, self.default_for_missing_fields) .into_deserializer(), )?; Ok(value) } fn unit_variant(self) -> Result<(), Self::Error> { // Support tensor `DType` deserialization match self.value { NestedValue::Map(value) if value.contains_key("DType") => { match value.get("DType") { Some(NestedValue::String(variant)) => { if *variant == self.current_variant { Ok(()) } else { Err(Error::Other("Wrong variant".to_string())) // wrong match } } _ => panic!("expected DType variant as string"), } } _ => unimplemented!( "unit variant is not implemented because it is not used in the burn module" ), } } fn tuple_variant(self, _len: usize, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!("tuple variant is not implemented because it is not used in the burn module") } fn struct_variant( self, _fields: &'static [&'static str], _visitor: V, ) -> Result where V: Visitor<'de>, { unimplemented!( "struct variant is not implemented because it is not used in the burn module" ) } } /// A wrapper for the nested value data structure with a burn module adapter. struct NestedValueWrapper { value: NestedValue, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } impl NestedValueWrapper { fn new(value: NestedValue, default_for_missing_fields: bool) -> Self { Self { value, default_for_missing_fields, phantom: std::marker::PhantomData, } } } impl IntoDeserializer<'_, Error> for NestedValueWrapper { type Deserializer = Deserializer; fn into_deserializer(self) -> Self::Deserializer { Deserializer::::new(self.value, self.default_for_missing_fields) } } /// A default deserializer that always returns the default value. struct DefaultDeserializer { /// The originator field name (the top-level missing field name) originator_field_name: Option, } impl DefaultDeserializer { fn new(originator_field_name: Option) -> Self { Self { originator_field_name, } } } impl<'de> serde::Deserializer<'de> for DefaultDeserializer { type Error = Error; fn deserialize_any(self, _visitor: V) -> Result where V: Visitor<'de>, { unimplemented!() } fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i32(Default::default()) } fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_f32(Default::default()) } fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i16(Default::default()) } fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i64(Default::default()) } fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u16(Default::default()) } fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u64(Default::default()) } fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_f64(Default::default()) } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_bool(Default::default()) } fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_char(Default::default()) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_str(Default::default()) } fn deserialize_i8(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_i8(Default::default()) } fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u8(Default::default()) } fn deserialize_u32(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_u32(Default::default()) } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_none() } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(DefaultSeqAccess::new(None)) } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_string(Default::default()) } fn deserialize_struct( self, name: &'static str, _fields: &'static [&'static str], _visitor: V, ) -> Result where V: Visitor<'de>, { // Return an error if the originator field name is not set Err(Error::Other(format!( "Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct", self.originator_field_name.unwrap_or("UNKNOWN".to_string()), name, ))) } fn deserialize_tuple_struct( self, _name: &'static str, len: usize, visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_seq(DefaultSeqAccess::new(Some(len))) } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_seq(DefaultSeqAccess::new(Some(len))) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_map(DefaultMapAccess::new()) } forward_to_deserialize_any! { u128 bytes byte_buf unit unit_struct newtype_struct enum identifier ignored_any } } /// A default sequence access that always returns None (empty sequence). pub struct DefaultSeqAccess { size: Option, } impl Default for DefaultSeqAccess { fn default() -> Self { Self::new(None) } } impl DefaultSeqAccess { /// Creates a new default sequence access with the given size hint. pub fn new(size: Option) -> Self { DefaultSeqAccess { size } } } impl<'de> SeqAccess<'de> for DefaultSeqAccess { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { match self.size { Some(0) => Ok(None), Some(ref mut size) => { *size -= 1; seed.deserialize(DefaultDeserializer::new(None)).map(Some) } None => Ok(None), } } fn size_hint(&self) -> Option { self.size } } /// A default map access that always returns None (empty map). pub struct DefaultMapAccess; impl Default for DefaultMapAccess { fn default() -> Self { Self::new() } } impl DefaultMapAccess { /// Creates a new default map access. pub fn new() -> Self { DefaultMapAccess } } impl<'de> MapAccess<'de> for DefaultMapAccess { type Error = Error; fn next_key_seed(&mut self, _seed: T) -> Result, Self::Error> where T: DeserializeSeed<'de>, { // Since this is a default implementation, we'll just return None. Ok(None) } fn next_value_seed(&mut self, _seed: T) -> Result where T: DeserializeSeed<'de>, { unimplemented!("This should never be called since next_key_seed always returns None") } fn size_hint(&self) -> Option { // Since this is a default implementation, we'll just return None. None } } ================================================ FILE: crates/burn-core/src/record/serde/error.rs ================================================ use crate::record::RecorderError; /// The error type for Record serde. #[derive(thiserror::Error, Debug)] pub enum Error { /// Failed to deserialize. #[error("failed to deserialize: {0}")] Deserialize(#[from] serde::de::value::Error), /// Failed to serialize. #[error("failed to serialize")] Serialize(String), /// Encountered an invalid state. #[error("invalid state")] InvalidState, /// Other error. #[error("other error: {0}")] Other(String), } impl serde::de::Error for Error { fn custom(msg: T) -> Self { Error::Deserialize(serde::de::value::Error::custom(msg.to_string())) } } impl serde::ser::Error for Error { fn custom(msg: T) -> Self { Error::Serialize(msg.to_string()) } } // Implement From trait for Error to RecorderError impl From for RecorderError { fn from(error: Error) -> Self { RecorderError::DeserializeError(error.to_string()) } } ================================================ FILE: crates/burn-core/src/record/serde/mod.rs ================================================ //! Module contains the serde implementation for the record module //! useful for custom importing model weights, such as PyTorch's pt file format. /// The adapter trait that is used to convert the nested value to the module type. pub mod adapter; /// The main data structure used for deserialization. pub mod data; /// The deserializer that is used to convert the nested value to the record. pub mod ser; /// The deserializer that is used to convert the nested value to the record. pub mod de; /// Error types. pub mod error; ================================================ FILE: crates/burn-core/src/record/serde/ser.rs ================================================ use std::collections::HashMap; use super::{ data::NestedValue, error::{self, Error}, }; use serde::{ Serialize, ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait}, }; /// Simple struct serializer that converts a struct into NestedValues. /// /// NOTE: This is used to serialize Param structs into NestedValues and not so much for /// the actual serialization of modules (although it could be used for that as well if all /// primitive types are implemented). #[derive(Clone)] pub struct Serializer { /// The state of the serialization process state: Option, } impl Serializer { /// Creates a new serializer. pub fn new() -> Self { Serializer { state: None } } } impl Default for Serializer { fn default() -> Self { Self::new() } } impl SerializerTrait for Serializer { type Ok = NestedValue; type Error = Error; type SerializeSeq = Self; type SerializeTuple = ser::Impossible; type SerializeTupleStruct = ser::Impossible; type SerializeTupleVariant = ser::Impossible; type SerializeMap = ser::Impossible; type SerializeStruct = Self; type SerializeStructVariant = ser::Impossible; fn serialize_struct( self, _name: &'static str, _len: usize, ) -> Result { Ok(self) } fn serialize_newtype_struct( self, _name: &'static str, value: &T, ) -> Result where T: Serialize + ?Sized, { value.serialize(self) } fn serialize_seq(self, _len: Option) -> Result { Ok(self) } fn serialize_i32(self, v: i32) -> Result { Ok(NestedValue::I32(v)) } fn serialize_str(self, v: &str) -> Result { Ok(NestedValue::String(v.to_string())) } fn serialize_i16(self, v: i16) -> Result { Ok(NestedValue::I16(v)) } fn serialize_i64(self, v: i64) -> Result { Ok(NestedValue::I64(v)) } fn serialize_u16(self, v: u16) -> Result { Ok(NestedValue::U16(v)) } fn serialize_u64(self, v: u64) -> Result { Ok(NestedValue::U64(v)) } fn serialize_f32(self, v: f32) -> Result { Ok(NestedValue::F32(v)) } fn serialize_f64(self, v: f64) -> Result { Ok(NestedValue::F64(v)) } // The following methods are not implemented because they are not needed for the // serialization of Param structs. fn serialize_char(self, _v: char) -> Result { unimplemented!() } fn serialize_bytes(self, v: &[u8]) -> Result { Ok(NestedValue::U8s(v.to_vec())) } fn serialize_none(self) -> Result { Ok(NestedValue::Default(None)) } fn serialize_u32(self, _v: u32) -> Result { unimplemented!() } fn serialize_bool(self, _v: bool) -> Result { unimplemented!() } fn serialize_i8(self, _v: i8) -> Result { unimplemented!() } fn serialize_u8(self, v: u8) -> Result { Ok(NestedValue::U8(v)) } fn serialize_some(self, value: &T) -> Result where T: Serialize + ?Sized, { value.serialize(self) } fn serialize_unit(self) -> Result { unimplemented!() } fn serialize_unit_struct(self, _name: &'static str) -> Result { unimplemented!() } fn serialize_unit_variant( self, _name: &'static str, _variant_index: u32, _variant: &'static str, ) -> Result { Ok(NestedValue::Map(HashMap::from([( _name.to_string(), NestedValue::String(_variant.to_string()), )]))) } fn serialize_newtype_variant( self, _name: &'static str, _variant_index: u32, _variant: &'static str, _value: &T, ) -> Result where T: Serialize + ?Sized, { unimplemented!() } fn serialize_tuple(self, _len: usize) -> Result { unimplemented!() } fn serialize_tuple_struct( self, _name: &'static str, _len: usize, ) -> Result { unimplemented!() } fn serialize_tuple_variant( self, _name: &'static str, _variant_index: u32, _variant: &'static str, _len: usize, ) -> Result { unimplemented!() } fn serialize_map(self, _len: Option) -> Result { unimplemented!() } fn serialize_struct_variant( self, _name: &'static str, _variant_index: u32, _variant: &'static str, _len: usize, ) -> Result { unimplemented!() } } // Implementing the SerializeStruct trait for Serializer impl SerializeStruct for Serializer { type Ok = NestedValue; type Error = Error; fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> where T: Serialize + ?Sized, { let serialized_value = value.serialize(Serializer::new())?; match self.state { Some(NestedValue::Map(ref mut map)) => { map.insert(key.to_string(), serialized_value); // Inserting into the state } Some(_) => { panic!("Invalid state encountered"); } None => { let mut map = HashMap::new(); map.insert(key.to_string(), serialized_value); // Inserting into the state self.state = Some(NestedValue::Map(map)); } } Ok(()) } fn end(self) -> Result { if self.state.is_none() { // If the state is empty, return an empty map Ok(NestedValue::Map(HashMap::new())) } else { self.state.ok_or(error::Error::InvalidState) } } } impl SerializeSeq for Serializer { type Ok = NestedValue; type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> where T: Serialize + ?Sized, { let serialized_value = value.serialize(Serializer::new())?; match self.state { Some(NestedValue::Vec(ref mut vec)) => { vec.push(serialized_value); // Inserting into the state } Some(NestedValue::U8s(ref mut vec)) => { if let NestedValue::U8(val) = serialized_value { vec.push(val); } else { panic!("Invalid value type encountered"); } } Some(NestedValue::U16s(ref mut vec)) => { if let NestedValue::U16(val) = serialized_value { vec.push(val); } else { panic!("Invalid value type encountered"); } } Some(NestedValue::F32s(ref mut vec)) => { if let NestedValue::F32(val) = serialized_value { vec.push(val); } else { panic!("Invalid value type encountered"); } } Some(_) => { panic!("Invalid state encountered"); } None => { let val = match serialized_value { NestedValue::U8(val) => NestedValue::U8s(vec![val]), NestedValue::U16(val) => NestedValue::U16s(vec![val]), NestedValue::F32(val) => NestedValue::F32s(vec![val]), _ => NestedValue::Vec(vec![serialized_value]), }; self.state = Some(val); } } Ok(()) } fn end(self) -> Result { if self.state.is_none() { // If the state is empty, return an empty vector Ok(NestedValue::Vec(Vec::new())) } else { self.state.ok_or(error::Error::InvalidState) } } } #[cfg(test)] mod tests { use crate::{ TestBackend, module::{Param, ParamId}, record::{FullPrecisionSettings, Record}, tensor::Tensor, }; use serde::Deserialize; use super::*; #[derive(Serialize, Deserialize, Debug, Clone)] struct MyStruct1 { a: MyStruct3, b: MyStruct2, } #[derive(Serialize, Deserialize, Debug, Clone)] struct MyStruct2 { a: i32, b: Option, c: String, d: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] struct MyStruct3 { x: String, y: String, } #[test] fn test_serialize() { let my_struct = MyStruct1 { a: MyStruct3 { x: "Hello".to_owned(), y: "World".to_owned(), }, b: MyStruct2 { a: 1, b: None, c: "Hello".to_owned(), d: Some("World".to_owned()), }, }; let serialized = my_struct .serialize(Serializer::new()) .expect("Should serialize item successfully"); let serialized_str = format!("{serialized:?}"); // Compare the lengths of expected and actual serialized strings because // the order of the fields is not guaranteed for HashMaps. assert_eq!(serialized_str.len(), 135); } #[test] fn test_param_serde() { let device = Default::default(); let tensor: Tensor = Tensor::ones([2, 2], &device); let param = Param::initialized(ParamId::new(), tensor); let param_item = param.into_item::(); let serialized = param_item .serialize(Serializer::new()) .expect("Should serialize item successfully"); let bytes = serialized.as_map().expect("is a map")["param"] .clone() .as_map() .expect("param is a map")["bytes"] .clone() .as_bytes() .expect("has bytes vec"); assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened()); } } ================================================ FILE: crates/burn-core/src/record/settings.rs ================================================ use burn_tensor::Element; use serde::{Serialize, de::DeserializeOwned}; /// Settings allowing to control the precision when (de)serializing items. pub trait PrecisionSettings: Send + Sync + core::fmt::Debug + core::default::Default + Clone { /// Float element type. type FloatElem: Element + Serialize + DeserializeOwned; /// Integer element type. type IntElem: Element + Serialize + DeserializeOwned; } /// Default precision settings. #[derive(Debug, Default, Clone)] pub struct FullPrecisionSettings; /// Precision settings optimized for compactness. #[derive(Debug, Default, Clone)] pub struct HalfPrecisionSettings; /// Precision settings optimized for precision. #[derive(Debug, Default, Clone)] pub struct DoublePrecisionSettings; impl PrecisionSettings for FullPrecisionSettings { type FloatElem = f32; type IntElem = i32; } impl PrecisionSettings for DoublePrecisionSettings { type FloatElem = f64; type IntElem = i64; } impl PrecisionSettings for HalfPrecisionSettings { type FloatElem = half::f16; type IntElem = i16; } ================================================ FILE: crates/burn-core/src/record/tensor.rs ================================================ use core::marker::PhantomData; use super::{PrecisionSettings, Record}; use burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend}; use serde::{Deserialize, Serialize}; use alloc::format; /// Deserialize the value into [`TensorData`]. fn deserialize_data<'de, E, De>(deserializer: De) -> Result where E: Element + Deserialize<'de>, De: serde::Deserializer<'de>, { let data = TensorData::deserialize(deserializer).map_err(|e| { serde::de::Error::custom(format!( "{e:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n" )) })?; let data = if let DType::QFloat(_) = data.dtype { data // do not convert quantized tensors } else { data.convert::() }; Ok(data) } /// This struct implements serde to lazily serialize and deserialize a float tensor /// using the given [record settings](RecordSettings). #[derive(new, Clone, Debug)] pub struct FloatTensorSerde { data: TensorData, _e: PhantomData, } /// This struct implements serde to lazily serialize and deserialize an int tensor /// using the given [record settings](RecordSettings). #[derive(new, Clone, Debug)] pub struct IntTensorSerde { data: TensorData, _e: PhantomData, } /// This struct implements serde to lazily serialize and deserialize an bool tensor. #[derive(new, Clone, Debug)] pub struct BoolTensorSerde { data: TensorData, } // --- SERDE IMPLEMENTATIONS --- // impl Serialize for FloatTensorSerde { fn serialize(&self, serializer: Se) -> Result where Se: serde::Serializer, { self.data.serialize(serializer) } } impl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde { fn deserialize(deserializer: De) -> Result where De: serde::Deserializer<'de>, { let data = deserialize_data::(deserializer)?; Ok(Self::new(data)) } } impl Serialize for IntTensorSerde { fn serialize(&self, serializer: Se) -> Result where Se: serde::Serializer, { self.data.serialize(serializer) } } impl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde { fn deserialize(deserializer: De) -> Result where De: serde::Deserializer<'de>, { let data = deserialize_data::(deserializer)?; Ok(Self::new(data)) } } impl Serialize for BoolTensorSerde { fn serialize(&self, serializer: Se) -> Result where Se: serde::Serializer, { self.data.serialize(serializer) } } impl<'de> Deserialize<'de> for BoolTensorSerde { fn deserialize(deserializer: De) -> Result where De: serde::Deserializer<'de>, { let data = deserialize_data::(deserializer)?; Ok(Self::new(data)) } } // --- RECORD IMPLEMENTATIONS --- // impl Record for Tensor { type Item = FloatTensorSerde; fn into_item(self) -> Self::Item { let data = self.into_data(); let data = if let DType::QFloat(_) = data.dtype { data // do not convert quantized tensors } else { data.convert::() }; FloatTensorSerde::new(data) } fn from_item(item: Self::Item, device: &B::Device) -> Self { let data = if let DType::QFloat(_) = item.data.dtype { item.data // do not convert quantized tensors } else { item.data.convert::() }; Tensor::from_data(data, device) } } impl Record for Tensor { type Item = IntTensorSerde; fn into_item(self) -> Self::Item { IntTensorSerde::new(self.into_data().convert::()) } fn from_item(item: Self::Item, device: &B::Device) -> Self { Tensor::from_data(item.data.convert::(), device) } } impl Record for Tensor { type Item = BoolTensorSerde; fn into_item(self) -> Self::Item { BoolTensorSerde::new(self.into_data()) } fn from_item(item: Self::Item, device: &B::Device) -> Self { Tensor::from_data(item.data, device) } } ================================================ FILE: crates/burn-core/src/tensor.rs ================================================ pub use burn_tensor::*; ================================================ FILE: crates/burn-core/src/vision.rs ================================================ pub use burn_vision::*; ================================================ FILE: crates/burn-core/tests/test_derive_config.rs ================================================ use burn::config::{Config, config_to_json}; use burn_core as burn; #[derive(Config, Debug, PartialEq, Eq)] pub struct TestEmptyStructConfig {} #[derive(Config, Debug, PartialEq)] pub struct TestStructConfig { int: i32, #[config(default = 2)] int_default: i32, float: f32, #[config(default = 2.0)] float_default: f32, string: String, other_config: TestEmptyStructConfig, } #[derive(Config, Debug, PartialEq)] pub enum TestEnumConfig { None, Single(f32), Multiple(f32, String), Named { first: f32, second: String }, } #[cfg(feature = "std")] #[inline(always)] fn file_path(file_name: &str) -> std::path::PathBuf { std::env::temp_dir().join(file_name) } #[cfg(feature = "std")] #[test] fn struct_config_should_impl_serde() { let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); let file_path = file_path("test_struct_config.json"); config.save(&file_path).unwrap(); let config_loaded = TestStructConfig::load(&file_path).unwrap(); assert_eq!(config, config_loaded); } #[test] fn struct_config_should_impl_clone() { let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); assert_eq!(config, config.clone()); } #[test] fn struct_config_should_impl_display() { let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[cfg(feature = "std")] #[test] fn enum_config_no_value_should_impl_serde() { let config = TestEnumConfig::None; let file_path = file_path("test_enum_no_value_config.json"); config.save(&file_path).unwrap(); let config_loaded = TestEnumConfig::load(&file_path).unwrap(); assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_one_value_should_impl_serde() { let config = TestEnumConfig::Single(42.0); let file_path = file_path("test_enum_one_value_config.json"); config.save(&file_path).unwrap(); let config_loaded = TestEnumConfig::load(&file_path).unwrap(); assert_eq!(config, config_loaded); } #[cfg(feature = "std")] #[test] fn enum_config_multiple_values_should_impl_serde() { let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); let file_path = file_path("test_enum_multiple_values_config.json"); config.save(&file_path).unwrap(); let config_loaded = TestEnumConfig::load(&file_path).unwrap(); assert_eq!(config, config_loaded); } #[test] fn enum_config_should_impl_clone() { let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); assert_eq!(config, config.clone()); } #[test] fn enum_config_should_impl_display() { let config = TestEnumConfig::Multiple(42.0, "Allow".to_string()); assert_eq!(burn::config::config_to_json(&config), config.to_string()); } #[test] fn struct_config_can_load_binary() { let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new()); let binary = config_to_json(&config).as_bytes().to_vec(); let config_loaded = TestStructConfig::load_binary(&binary).unwrap(); assert_eq!(config, config_loaded); } ================================================ FILE: crates/burn-core/tests/test_derive_module.rs ================================================ use std::marker::PhantomData; use burn::module::Initializer; use burn::module::{Module, Param}; use burn::tensor::backend::Backend; use burn::tensor::{Int, Tensor}; use burn_core as burn; pub type TestBackend = burn_ndarray::NdArray; #[cfg(feature = "std")] pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[derive(Module, Debug)] pub struct ModuleBasic { weight_basic: Param>, } #[derive(Module, Debug)] #[allow(unused)] struct ModuleTensorConstInt { weight_basic: Tensor, } impl ModuleBasic { fn new(device: &B::Device) -> Self { Self { weight_basic: Initializer::Normal { std: 1.0, mean: 0.0, } .init([20, 20], device), } } } #[derive(Module, Debug)] struct ModuleWithConstGeneric { modules: [ModuleBasic; N], } #[derive(Module, Debug)] struct ModuleWithGenericModule { module: M, _backend: PhantomData, } #[derive(Module, Debug)] #[allow(clippy::large_enum_variant)] enum ModuleEnum { Basic(ModuleBasic), Composed(ModuleComposed), } #[derive(Module, Debug)] #[allow(unused)] enum ModuleEnumNested { AnotherEnum(ModuleEnum), } #[derive(Module, Debug)] enum ModuleEnumWithGenericModule> { Basic(ModuleBasic), Generic(ModuleWithGenericModule), } #[derive(Module, Debug)] pub struct ModuleComposed { weight: Param>, basic: ModuleBasic, tuple: (ModuleBasic, ModuleBasic), } impl ModuleComposed { fn new(device: &B::Device) -> Self { let weight = Initializer::Normal { std: 1.0, mean: 0.0, } .init([20, 20], device); Self { weight, basic: ModuleBasic::new(device), tuple: (ModuleBasic::new(device), ModuleBasic::new(device)), } } } #[derive(Debug, Clone)] pub enum PaddingConfig { Default, Other, } #[derive(Module, Debug)] pub struct ModuleWithAttributes, N> { /// A normal parameter. weight: Param>, /// A nested module. nested: ModuleEnumWithGenericModule, /// By default, primitives were not persistent (same as `#[module(skip)]`). other_prob: f64, /// By default, tensors were not persistent and not visited/mapped (same as `#[module(skip)]`). tensor: Tensor, /// A field that is recomputed at runtime. #[module(skip)] cached_mask: Option>, /// A field that contains some debug state. debug_state: String, /// Hint required: this generic is NOT a module. #[module(skip)] config: N, } impl ModuleWithAttributes, PaddingConfig> { fn new(device: &B::Device) -> Self { let basic = ModuleBasic::new(device); let weight = basic.weight_basic.clone(); Self { weight, nested: ModuleEnumWithGenericModule::Basic(basic), other_prob: 1., tensor: Tensor::ones([2], device), cached_mask: Some(Tensor::ones([2, 2], device)), debug_state: "Hello World".into(), config: PaddingConfig::Default, } } } #[allow(dead_code)] mod compiletime_clone_impl_check { use burn_core::{ module::{Module, ModuleDisplay}, prelude::Backend, record::{PrecisionSettings, Record}, }; use super::*; type RecordItem = <>::Record as Record>::Item; fn implements_clone() {} fn basic_implements_clone() { implements_clone::, B, S>>(); implements_clone::, B, S>>(); } fn generic_implements_clone() where B: Backend, S: PrecisionSettings, M: Module + ModuleDisplay, RecordItem: Clone, { implements_clone::, B, S>>(); implements_clone::, B, S>>(); } } mod state { use burn_core::module::EmptyRecord; use super::*; #[test] fn should_load_from_record_basic() { let device = ::Device::default(); let module_1 = ModuleBasic::::new(&device); let mut module_2 = ModuleBasic::::new(&device); let state_1 = module_1.clone().into_record(); assert_ne!( module_1.weight_basic.to_data(), module_2.weight_basic.to_data() ); module_2 = module_2.load_record(state_1); assert_eq!( module_1.weight_basic.to_data(), module_2.weight_basic.to_data() ); } #[test] fn should_load_from_record_compose() { let device = ::Device::default(); let module_1 = ModuleComposed::::new(&device); let mut module_2 = ModuleComposed::::new(&device); assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); assert_ne!( module_1.basic.weight_basic.to_data(), module_2.basic.weight_basic.to_data() ); let state_1 = module_1.clone().into_record(); module_2 = module_2.load_record(state_1); assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); assert_eq!( module_1.basic.weight_basic.to_data(), module_2.basic.weight_basic.to_data() ); } #[test] fn should_load_from_record_enum() { let device = ::Device::default(); let module_1 = ModuleEnum::Basic(ModuleBasic::::new(&device)); let mut module_2 = ModuleEnum::Basic(ModuleBasic::::new(&device)); let state_1 = module_1.clone().into_record(); let ModuleEnum::Basic(module_1_basic) = module_1 else { panic!("Invalid module type") }; let ModuleEnum::Basic(module_2_basic) = module_2.clone() else { panic!("Invalid module type") }; assert_ne!( module_1_basic.weight_basic.to_data(), module_2_basic.weight_basic.to_data() ); module_2 = module_2.load_record(state_1); let ModuleEnum::Basic(module_2_basic) = module_2 else { panic!("Invalid module type") }; assert_eq!( module_1_basic.weight_basic.to_data(), module_2_basic.weight_basic.to_data() ); } #[test] fn should_load_from_record_based_on_attributes() { let device = ::Device::default(); let mut module_1 = ModuleWithAttributes::::new(&device); let mut module_2 = ModuleWithAttributes::new(&device); assert_ne!(module_1.weight.to_data(), module_2.weight.to_data(),); let ModuleEnumWithGenericModule::Basic(ref m1_basic) = module_1.nested else { panic!("Invalid module type") }; let ModuleEnumWithGenericModule::Basic(ref m2_basic) = module_2.nested else { panic!("Invalid module type") }; assert_ne!( m1_basic.weight_basic.to_data(), m2_basic.weight_basic.to_data(), ); assert_eq!(module_1.tensor.to_data(), module_2.tensor.to_data()); assert_eq!( module_1.cached_mask.as_ref().unwrap().to_data(), module_2.cached_mask.as_ref().unwrap().to_data() ); assert_eq!(module_1.other_prob, module_2.other_prob); assert_eq!(module_1.debug_state, module_2.debug_state); // Alter state of skipped fields to validate persistence module_1.cached_mask = Some(module_1.cached_mask.unwrap() * 2); module_1.tensor = module_1.tensor * 2; module_1.other_prob = 0.; module_1.debug_state = "Hello World!".into(); module_1.config = PaddingConfig::Other; let state_1 = module_1.clone().into_record(); assert_eq!(state_1.cached_mask, EmptyRecord); assert_eq!(state_1.other_prob, EmptyRecord); assert_eq!(state_1.debug_state, EmptyRecord); assert_eq!(state_1.config, EmptyRecord); module_2 = module_2.load_record(state_1); let ModuleEnumWithGenericModule::Basic(m2_basic) = module_2.nested else { panic!("Invalid module type") }; // Modules & params assert_eq!(module_1.weight.to_data(), module_2.weight.to_data(),); assert_eq!( m1_basic.weight_basic.to_data(), m2_basic.weight_basic.to_data(), ); // `#[module(skip)]` field and other skip-by-default assert_ne!(module_1.other_prob, module_2.other_prob); assert_ne!(module_1.debug_state, module_2.debug_state); assert!(matches!(module_1.config, PaddingConfig::Other)); assert!(matches!(module_2.config, PaddingConfig::Default)); assert_ne!(module_1.tensor.to_data(), module_2.tensor.to_data()); assert_ne!( module_1.cached_mask.as_ref().unwrap().to_data(), module_2.cached_mask.as_ref().unwrap().to_data() ); } #[test] fn should_load_from_record_const_generic() { let device = ::Device::default(); let module_1 = ModuleWithConstGeneric { modules: [ ModuleBasic::::new(&device), ModuleBasic::::new(&device), ], }; let mut module_2 = ModuleWithConstGeneric { modules: [ ModuleBasic::::new(&device), ModuleBasic::::new(&device), ], }; let state_1 = module_1.clone().into_record(); assert_ne!( module_1.modules[0].weight_basic.to_data(), module_2.modules[0].weight_basic.to_data(), ); assert_ne!( module_1.modules[1].weight_basic.to_data(), module_2.modules[1].weight_basic.to_data(), ); module_2 = module_2.load_record(state_1); assert_eq!( module_1.modules[0].weight_basic.to_data(), module_2.modules[0].weight_basic.to_data(), ); assert_eq!( module_1.modules[1].weight_basic.to_data(), module_2.modules[1].weight_basic.to_data(), ); } #[test] #[should_panic(expected = "Can't parse record from a different variant")] fn should_panic_load_from_incorrect_enum_variant() { let device = ::Device::default(); let module_1 = ModuleEnum::Basic(ModuleBasic::::new(&device)); let module_2 = ModuleEnum::Composed(ModuleComposed::::new(&device)); let state_1 = module_1.clone().into_record(); module_2.load_record(state_1); } } mod num_params { use super::*; #[test] fn should_calculate_num_params_basic() { let device = ::Device::default(); let module = ModuleBasic::::new(&device); assert_eq!(20 * 20, module.num_params()); } #[test] fn should_output_state_composed() { let device = ::Device::default(); let module = ModuleComposed::::new(&device); assert_eq!(4 * 20 * 20, module.num_params()); } #[test] fn should_calculate_num_params_enum() { let device = ::Device::default(); let module = ModuleEnum::Basic(ModuleBasic::::new(&device)); assert_eq!(20 * 20, module.num_params()); let module = ModuleEnum::Composed(ModuleComposed::::new(&device)); assert_eq!(4 * 20 * 20, module.num_params()); } #[test] fn should_calculate_num_params_based_on_attributes() { let device = ::Device::default(); let module = ModuleWithAttributes::::new(&device); assert_eq!(20 * 20 * 2, module.num_params()); } } #[cfg(feature = "std")] mod require_grad { use burn_tensor::backend::AutodiffBackend; use super::*; #[test] fn should_have_grad_by_default() { let device = ::Device::default(); let module = ModuleBasic::::new(&device); let mut grads = calculate_grads(&module); let grad_x = module.weight_basic.grad_remove(&mut grads); assert!(grad_x.is_some()); } #[test] fn should_have_no_grad_after_no_grad() { let device = ::Device::default(); let module = ModuleBasic::::new(&device).no_grad(); let mut grads = calculate_grads(&module); let grad_x = module.weight_basic.grad_remove(&mut grads); assert!(grad_x.is_none()); } #[test] fn should_have_grad_when_from_record() { let device = ::Device::default(); let module = ModuleBasic::::new(&device); let record = ModuleBasicRecord { weight_basic: module.weight_basic.clone(), // Even when param is no_grad, }; let module = module.load_record(record); let mut grads = calculate_grads(&module); let grad_x = module.weight_basic.grad_remove(&mut grads); assert!(grad_x.is_some()); } fn calculate_grads( module: &ModuleBasic, ) -> ::Gradients { let device = module.weight_basic.device(); let x = Tensor::ones([20, 20], &device).require_grad(); let y = module.weight_basic.val().matmul(x); y.backward() } } ================================================ FILE: crates/burn-core/tests/test_derive_record.rs ================================================ use burn_core as burn; use burn_core::record::Record; use burn_tensor::Tensor; use burn_tensor::backend::Backend; // It compiles #[derive(Record)] pub struct TestWithBackendRecord { tensor: Tensor, } // It compiles #[derive(Record)] pub struct TestWithoutBackendRecord { _tensor: usize, } ================================================ FILE: crates/burn-core/tests/test_record_resilience.rs ================================================ #[cfg(feature = "std")] mod tests { use burn::{ module::{Module, Param}, record::{ BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings, PrettyJsonFileRecorder, RecorderError, }, }; use burn_core as burn; use burn_ndarray::NdArrayDevice; use burn_tensor::{Tensor, backend::Backend}; use std::path::PathBuf; type TestBackend = burn_ndarray::NdArray; /// Simple linear module. #[derive(Module, Debug)] pub struct Linear { pub weight: Param>, pub bias: Option>>, } impl Linear { pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self { let weight = Tensor::random( [out_features, in_features], burn_tensor::Distribution::Default, device, ); let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device); Self { weight: Param::from_tensor(weight), bias: Some(Param::from_tensor(bias)), } } } #[derive(Module, Debug)] pub struct Model { single_const: f32, linear1: Linear, array_const: [usize; 2], linear2: Linear, array_lin: [Linear; 2], } #[derive(Module, Debug)] pub struct ModelNewOptionalField { single_const: f32, linear1: Linear, array_const: [usize; 2], linear2: Linear, array_lin: [Linear; 2], new_field: Option, } #[derive(Module, Debug)] pub struct ModelNewConstantField { single_const: f32, linear1: Linear, array_const: [usize; 2], linear2: Linear, array_lin: [Linear; 2], new_field: usize, } #[derive(Module, Debug)] #[allow(unused)] pub struct ModelNewFieldOrders { array_const: [usize; 2], linear2: Linear, single_const: f32, array_lin: [Linear; 2], linear1: Linear, } #[test] fn deserialize_with_new_optional_field_works_with_default_file_recorder() { deserialize_with_new_optional_field( "default", DefaultFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_removed_optional_field_works_with_default_file_recorder() { deserialize_with_removed_optional_field( "default", DefaultFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_constant_field_works_with_default_file_recorder() { deserialize_with_new_constant_field( "default", DefaultFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_removed_constant_field_works_with_default_file_recorder() { deserialize_with_removed_constant_field( "default", DefaultFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_field_order_works_with_default_file_recorder() { deserialize_with_new_field_order( "default", DefaultFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_optional_field_works_with_pretty_json() { deserialize_with_new_optional_field( "pretty-json", PrettyJsonFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_removed_optional_field_works_with_pretty_json() { deserialize_with_removed_optional_field( "pretty-json", PrettyJsonFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_constant_field_works_with_pretty_json() { deserialize_with_new_constant_field( "pretty-json", PrettyJsonFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_removed_constant_field_works_with_pretty_json() { deserialize_with_removed_constant_field( "pretty-json", PrettyJsonFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_field_order_works_with_pretty_json() { deserialize_with_new_field_order( "pretty-json", PrettyJsonFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_optional_field_works_with_bin_file_recorder() { deserialize_with_new_optional_field("bin", BinFileRecorder::::new()) .unwrap(); } #[test] fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() { deserialize_with_removed_optional_field( "bin", BinFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_constant_field_works_with_bin_file_recorder() { deserialize_with_new_constant_field("bin", BinFileRecorder::::new()) .unwrap(); } #[test] fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() { deserialize_with_removed_constant_field( "bin", BinFileRecorder::::new(), ) .unwrap(); } #[test] fn deserialize_with_new_field_order_works_with_bin_file_recorder() { deserialize_with_new_field_order("bin", BinFileRecorder::::new()) .unwrap(); } #[inline(always)] fn file_path(filename: String) -> PathBuf { std::env::temp_dir().join(filename) } #[test] fn test_tensor_serde() { let tensor: burn_tensor::Tensor = burn_tensor::Tensor::ones([1], &NdArrayDevice::default()); let encoded = serde_json::to_string(&tensor).unwrap(); let decoded: burn_tensor::Tensor = serde_json::from_str(&encoded).unwrap(); assert_eq!(tensor.into_data(), decoded.into_data()); } fn deserialize_with_new_optional_field(name: &str, recorder: R) -> Result<(), RecorderError> where R: FileRecorder, { let device = Default::default(); let file_path: PathBuf = file_path(format!("deserialize_with_new_optional_field-{name}")); let model = Model { single_const: 32.0, linear1: Linear::::new(20, 20, &device), array_const: [2, 2], linear2: Linear::::new(20, 20, &device), array_lin: [ Linear::::new(20, 20, &device), Linear::::new(20, 20, &device), ], }; recorder .record(model.into_record(), file_path.clone()) .unwrap(); let result = recorder.load::>(file_path.clone(), &device); std::fs::remove_file(file_path).ok(); result?; Ok(()) } fn deserialize_with_removed_optional_field( name: &str, recorder: R, ) -> Result<(), RecorderError> where R: FileRecorder, { let device = Default::default(); let file_path: PathBuf = file_path(format!("deserialize_with_removed_optional_field-{name}")); let model = ModelNewOptionalField { single_const: 32.0, linear1: Linear::::new(20, 20, &device), array_const: [2, 2], linear2: Linear::::new(20, 20, &device), array_lin: [ Linear::::new(20, 20, &device), Linear::::new(20, 20, &device), ], new_field: None, }; recorder .record(model.into_record(), file_path.clone()) .unwrap(); let result = recorder.load::>(file_path.clone(), &device); std::fs::remove_file(file_path).ok(); result?; Ok(()) } fn deserialize_with_new_constant_field(name: &str, recorder: R) -> Result<(), RecorderError> where R: FileRecorder, { let device = Default::default(); let file_path: PathBuf = file_path(format!("deserialize_with_new_constant_field-{name}")); let model = Model { single_const: 32.0, array_const: [2, 2], linear1: Linear::::new(20, 20, &device), linear2: Linear::::new(20, 20, &device), array_lin: [ Linear::::new(20, 20, &device), Linear::::new(20, 20, &device), ], }; recorder .record(model.into_record(), file_path.clone()) .unwrap(); let result = recorder.load::>(file_path.clone(), &device); std::fs::remove_file(file_path).ok(); result?; Ok(()) } fn deserialize_with_removed_constant_field( name: &str, recorder: R, ) -> Result<(), RecorderError> where R: FileRecorder, { let device = Default::default(); let file_path: PathBuf = file_path(format!("deserialize_with_removed_constant_field-{name}")); let model = ModelNewConstantField { single_const: 32.0, array_const: [2, 2], linear1: Linear::::new(20, 20, &device), linear2: Linear::::new(20, 20, &device), array_lin: [ Linear::::new(20, 20, &device), Linear::::new(20, 20, &device), ], new_field: 0, }; recorder .record(model.into_record(), file_path.clone()) .unwrap(); let result = recorder.load::>(file_path.clone(), &device); std::fs::remove_file(file_path).ok(); result?; Ok(()) } fn deserialize_with_new_field_order(name: &str, recorder: R) -> Result<(), RecorderError> where R: FileRecorder, { let device = Default::default(); let file_path: PathBuf = file_path(format!("deserialize_with_new_field_order-{name}")); let model = Model { array_const: [2, 2], single_const: 32.0, linear1: Linear::::new(20, 20, &device), linear2: Linear::::new(20, 20, &device), array_lin: [ Linear::::new(20, 20, &device), Linear::::new(20, 20, &device), ], }; recorder .record(model.into_record(), file_path.clone()) .unwrap(); let result = recorder.load::>(file_path.clone(), &device); std::fs::remove_file(file_path).ok(); result?; Ok(()) } } ================================================ FILE: crates/burn-cpu/Cargo.toml ================================================ [package] authors = ["marcantoinem "] categories = ["science"] description = "MLIR based CPU backend for the Burn framework" documentation = "https://docs.rs/burn-cpu" edition.workspace = true keywords = ["deep-learning", "machine-learning", "cpu"] license.workspace = true name = "burn-cpu" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cpu" version.workspace = true [lints] workspace = true [features] default = ["std", "fusion", "autotune", "burn-cubecl/default", "cubecl/default"] doc = ["burn-cubecl/doc"] fusion = ["burn-fusion", "burn-cubecl/fusion"] std = ["burn-cubecl/std", "cubecl/std"] tracing = [ "burn-backend/tracing", "burn-cubecl/tracing", "burn-fusion?/tracing", "cubecl/tracing", ] autotune = ["burn-cubecl/autotune"] autotune-checks = ["burn-cubecl/autotune-checks"] [dependencies] burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", features = [ "cubecl-cpu", ] } cubecl = { workspace = true, features = ["cpu"] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-cpu/README.md ================================================ # Burn CPU Backend [Burn](https://github.com/tracel-ai/burn) CubeCL CPU backend [![Current Crates.io Version](https://img.shields.io/crates/v/burn-cuda.svg)](https://crates.io/crates/burn-cuda) This crate provides a MLIR based CPU backend for [Burn](https://github.com/tracel-ai/burn) using the [cubecl](https://github.com/tracel-ai/cubecl.git) crates. ## Usage Example Example coming soon ================================================ FILE: crates/burn-cpu/src/lib.rs ================================================ #![cfg_attr(docsrs, feature(doc_cfg))] extern crate alloc; use burn_cubecl::CubeBackend; pub use cubecl::cpu::CpuDevice; use cubecl::cpu::CpuRuntime; #[cfg(not(feature = "fusion"))] pub type Cpu = CubeBackend; #[cfg(feature = "fusion")] pub type Cpu = burn_fusion::Fusion>; #[cfg(test)] mod tests { use super::*; use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive}; use burn_cubecl::tensor::CubeTensor; #[test] fn should_support_dtypes() { type B = Cpu; let device = Default::default(); assert!(B::supports_dtype(&device, DType::F64)); assert!(B::supports_dtype(&device, DType::F32)); assert!(B::supports_dtype(&device, DType::F16)); assert!(B::supports_dtype(&device, DType::BF16)); assert!(B::supports_dtype(&device, DType::I64)); assert!(B::supports_dtype(&device, DType::I32)); assert!(B::supports_dtype(&device, DType::I16)); assert!(B::supports_dtype(&device, DType::I8)); assert!(B::supports_dtype(&device, DType::U64)); assert!(B::supports_dtype(&device, DType::U32)); assert!(B::supports_dtype(&device, DType::U16)); assert!(B::supports_dtype(&device, DType::U8)); assert!(B::supports_dtype( &device, DType::QFloat(CubeTensor::::default_scheme()) )); // Currently not registered in supported types assert!(!B::supports_dtype(&device, DType::Flex32)); assert!(!B::supports_dtype(&device, DType::Bool(BoolStore::Native))); } } ================================================ FILE: crates/burn-cubecl/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Generic backend that can be compiled just-in-time to any shader language target" documentation = "https://docs.rs/burn-cubecl" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu"] license.workspace = true name = "burn-cubecl" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl" version.workspace = true [lints] workspace = true [features] default = [ "autotune", "std", "fusion", "cubecl/default", "burn-fusion?/default", "burn-cubecl-fusion?/default", ] std = [ "cubecl/std", "burn-backend/std", "burn-fusion?/std", "burn-cubecl-fusion?/std", ] doc = ["default"] memory-checks = ["burn-fusion?/memory-checks"] tracing = [ "dep:tracing", "cubecl/tracing", "burn-std/tracing", "burn-backend/tracing", "burn-fusion?/tracing", "burn-cubecl-fusion?/tracing", ] autotune = ["burn-cubecl-fusion?/autotune"] autotune-checks = [ "autotune", "cubecl/autotune-checks", "burn-cubecl-fusion?/autotune-checks", ] fusion = ["burn-fusion", "burn-cubecl-fusion"] fusion-experimental = ["fusion"] template = [] [dependencies] burn-cubecl-fusion = { path = "../burn-cubecl-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false, features = [ "cubecl", ] } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false, features = [ "cubecl", ] } cubecl = { workspace = true, features = ["stdlib"] } cubek = { workspace = true, features = [ "attention", "matmul", "convolution", "reduce", "random", "quantization", ] } tracing = { workspace = true, features = ["attributes"], optional = true } derive-new = { workspace = true } log = { workspace = true } # Async futures-lite = { workspace = true, features = ["std"] } # Template serde = { workspace = true } text_placeholder = { workspace = true, features = ["struct_context"] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-cubecl/README.md ================================================ # Burn CubeCL Backend Generic backend that can be compiled just-in-time (JIT) to any shader language target. ================================================ FILE: crates/burn-cubecl/src/backend.rs ================================================ use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor}; use burn_backend::{Backend, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData}; use burn_std::DType; use cubecl::{ features::{MmaConfig, TypeUsage}, server::ComputeServer, }; use std::marker::PhantomData; #[cfg(not(feature = "fusion"))] use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; #[cfg(not(feature = "fusion"))] use burn_ir::{BackendIr, TensorHandle}; /// Generic tensor backend that can be compiled just-in-time to any shader runtime #[derive(new)] pub struct CubeBackend { _runtime: PhantomData, _float_elem: PhantomData, _int_elem: PhantomData, _bool_elem: PhantomData, } impl Backend for CubeBackend where R: CubeRuntime, R::Server: ComputeServer, R::Device: DeviceOps, F: FloatElement, I: IntElement, BT: BoolElement, { type Device = R::Device; type FloatElem = F; type IntElem = I; type BoolElem = BT; type FloatTensorPrimitive = CubeTensor; type IntTensorPrimitive = CubeTensor; type BoolTensorPrimitive = CubeTensor; type QuantizedTensorPrimitive = CubeTensor; fn name(device: &Self::Device) -> String { let client = R::client(device); format!("cubecl<{}>", R::name(&client)) } fn seed(_device: &Self::Device, seed: u64) { cubek::random::seed(seed); } fn ad_enabled(_device: &Self::Device) -> bool { false } fn sync(device: &Self::Device) -> Result<(), ExecutionError> { let client = R::client(device); futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext { reason: format!("{err}"), }) } fn memory_persistent_allocations< Output: Send, Input: Send, Func: Fn(Input) -> Output + Send, >( device: &Self::Device, input: Input, func: Func, ) -> Output { let client = R::client(device); client.memory_persistent_allocation(input, func).unwrap() } fn memory_cleanup(device: &Self::Device) { let client = R::client(device); client.memory_cleanup(); } fn staging<'a, Iter>(data: Iter, device: &Self::Device) where Iter: Iterator, { let client = R::client(device); client.staging(data.map(|td| &mut td.bytes), false); } fn supports_dtype(device: &Self::Device, dtype: DType) -> bool { let client = R::client(device); let type_usage = client.properties().type_usage(dtype.into()); // Same as `TypeUsage::all_scalar()`, but we make the usage explicit here type_usage.is_superset( TypeUsage::Buffer | TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct, ) } fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet { let client = R::client(device); let props = client.properties(); let storage = dtype.into(); let usage = props.type_usage(storage); let mut out = DTypeUsageSet::new(); if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) { out |= DTypeUsage::Storage; } if usage.contains(TypeUsage::Arithmetic) { out |= DTypeUsage::Arithmetic; } let has_mma = |cfg: &MmaConfig| { cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage }; if props.features.cmma.iter().any(has_mma) || props.features.mma.iter().any(has_mma) { out |= DTypeUsage::Accelerated; } out } } impl core::fmt::Debug for CubeBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("CubeCLBackend") } } impl Clone for CubeBackend { fn clone(&self) -> Self { Self::new() } } impl Default for CubeBackend { fn default() -> Self { Self::new() } } impl CubeRuntime for R where R::Device: DeviceOps, { type CubeDevice = R::Device; type CubeServer = R::Server; } #[cfg(not(feature = "fusion"))] impl BackendIr for CubeBackend { type Handle = CubeTensor; fn float_tensor(handle: TensorHandle) -> FloatTensor { handle.handle } fn int_tensor(handle: TensorHandle) -> IntTensor { handle.handle } fn bool_tensor(handle: TensorHandle) -> BoolTensor { handle.handle } fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { handle.handle } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { tensor } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { tensor } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { tensor } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { tensor } } ================================================ FILE: crates/burn-cubecl/src/element.rs ================================================ use burn_backend::{Element, bf16, f16}; use burn_std::DType; use cubecl::{ CubeElement as CubeElem, flex32, prelude::{Float, Int, Numeric}, }; use cubek::{ matmul::definition::{MatmulPrecision, MatrixPrecision}, reduce::ReducePrecision, }; /// The base element trait for the jit backend. pub trait CubeElement: Element + CubeElem + PartialEq + Numeric {} /// Element that can be used for matrix multiplication. Includes ints and floats. pub trait MatmulElement: CubeElement + MatmulPrecision> { } /// The float element type for the jit backend. pub trait FloatElement: MatmulElement + Float {} /// The int element type for the jit backend. pub trait IntElement: MatmulElement + Int + ReducePrecision { } /// The element type for booleans for the jit backend. pub trait BoolElement: CubeElement + Int { /// The true value for the boolean element. fn true_val() -> Self { Self::from_int(1) } /// The false value for the boolean element. fn false_val() -> Self { Self::from_int(0) } /// New bool element from Rust bool. fn new_bool(val: bool) -> Self { match val { true => Self::true_val(), false => Self::false_val(), } } } impl CubeElement for u64 {} impl CubeElement for u32 {} impl CubeElement for u16 {} impl CubeElement for u8 {} impl CubeElement for i64 {} impl CubeElement for i32 {} impl CubeElement for i16 {} impl CubeElement for i8 {} impl CubeElement for f64 {} impl CubeElement for f32 {} impl CubeElement for flex32 {} impl CubeElement for f16 {} impl CubeElement for bf16 {} impl FloatElement for f64 {} impl FloatElement for f32 {} impl FloatElement for flex32 {} impl FloatElement for bf16 {} impl FloatElement for f16 {} impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} impl IntElement for u64 {} impl IntElement for u32 {} impl IntElement for u16 {} impl IntElement for u8 {} impl BoolElement for u8 {} impl BoolElement for u32 {} impl MatmulElement for f64 {} impl MatmulElement for f32 {} impl MatmulElement for flex32 {} impl MatmulElement for bf16 {} impl MatmulElement for f16 {} impl MatmulElement for i64 {} impl MatmulElement for i32 {} impl MatmulElement for i16 {} impl MatmulElement for i8 {} impl MatmulElement for u64 {} impl MatmulElement for u32 {} impl MatmulElement for u16 {} impl MatmulElement for u8 {} // TODO: remove once backends no longer rely on generics for default elem types /// Returns the bool element dtype. pub(crate) fn bool_dtype() -> DType { match BT::dtype() { DType::U32 => DType::Bool(burn_backend::BoolStore::U32), DType::U8 => DType::Bool(burn_backend::BoolStore::U8), other => unimplemented!("Invalid bool dtye {other:?}"), } } ================================================ FILE: crates/burn-cubecl/src/fusion.rs ================================================ use crate::BoolElement; use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor}; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; use burn_backend::{DType, Shape}; use burn_cubecl_fusion::optim::reduce::ReduceSettings; use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser; use burn_cubecl_fusion::{ CubeFusionHandle, FallbackOperation, optim::{ CubeOptimization, CubeOptimizationState, elemwise::{ElementWiseFuser, ElemwiseOptimization}, matmul::{MatmulFuser, MatmulOptimization}, reduce::{ReduceFuser, ReduceOptimization}, reduce_broadcasted::ReduceBroadcastedOptimization, }, }; use burn_fusion::{ FusionBackend, FusionRuntime, stream::{Operation, OrderedExecution}, }; use burn_ir::{BackendIr, TensorHandle}; use burn_std::Metadata; use core::marker::PhantomData; use std::sync::Arc; impl burn_fusion::Optimization> for CubeOptimization where R: CubeRuntime, { fn execute( &mut self, context: &mut burn_fusion::stream::Context< '_, as FusionRuntime>::FusionHandle, >, execution: &OrderedExecution>, ) { match self { Self::ElementWise(op) => op.execute(context), Self::Matmul(op) => op.execute(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), Self::Reduce(op) => op.execute(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), Self::ReduceBroadcasted(op) => op.execute(context, |index| { let operation = execution.operation_within_optimization(index); Box::new(FallbackOperationWrapper::new(operation)) }), } } fn to_state(&self) -> CubeOptimizationState { self.to_opt_state() } fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self { match state { CubeOptimizationState::ElementWise(state) => { Self::ElementWise(ElemwiseOptimization::from_state(device, state)) } CubeOptimizationState::Matmul(state) => { Self::Matmul(MatmulOptimization::from_state(device, state)) } CubeOptimizationState::Reduce(state) => { Self::Reduce(ReduceOptimization::from_state(device, state)) } CubeOptimizationState::ReduceBroadcasted(state) => { Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state)) } } } } struct FallbackOperationWrapper { operation: O, } impl FallbackOperationWrapper { fn new(op: O) -> Self { Self { operation: op } } } impl FallbackOperation for FallbackOperationWrapper>>> { fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle>) { self.operation.as_ref().execute(context.handles); } } impl BackendIr for CubeBackend { type Handle = CubeFusionHandle; fn float_tensor(handle: TensorHandle) -> FloatTensor { into_tensor(handle.handle, handle.shape) } fn int_tensor(handle: TensorHandle) -> IntTensor { into_tensor(handle.handle, handle.shape) } fn bool_tensor(handle: TensorHandle) -> BoolTensor { into_tensor(handle.handle, handle.shape) } fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { into_tensor(handle.handle, handle.shape) } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { tensor.into() } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { tensor.into() } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { tensor.into() } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { tensor.into() } } impl FusionRuntime for FusionCubeRuntime { type OptimizationState = CubeOptimizationState; type Optimization = CubeOptimization; type FusionHandle = CubeFusionHandle; type FusionDevice = R::CubeDevice; fn fusers(device: R::Device) -> Vec>> { vec![ Box::new(ElementWiseFuser::new(device.clone())), Box::new(MatmulFuser::new(device.clone())), Box::new(ReduceFuser::new(device.clone(), ReduceSettings::Always)), Box::new(ReduceBroadcastedFuser::new(device.clone())), ] } } /// Fusion runtime for JIT runtimes. #[derive(Debug)] pub struct FusionCubeRuntime { _b: PhantomData, } impl FusionBackend for CubeBackend { type FusionRuntime = FusionCubeRuntime; type FullPrecisionBackend = CubeBackend; fn cast_float(tensor: FloatTensor, dtype: DType) -> Self::Handle { kernel::cast(tensor, dtype).into() } } fn into_tensor(handle: CubeFusionHandle, shape: Shape) -> CubeTensor { CubeTensor { client: handle.client.clone(), handle: handle.handle.clone(), device: handle.device.clone(), meta: Box::new(Metadata::new(shape, handle.strides.clone())), dtype: handle.dtype, qparams: handle.qparams.clone(), } } impl From> for CubeFusionHandle { fn from(value: CubeTensor) -> Self { Self { client: value.client.clone(), handle: value.handle.clone(), device: value.device.clone(), strides: value.meta.strides.clone(), dtype: value.dtype, qparams: value.qparams.clone(), } } } ================================================ FILE: crates/burn-cubecl/src/kernel/attention/base.rs ================================================ use crate::{ CubeBackend, CubeRuntime, kernel::attention::attention_autotune, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{ DType, Shape, ops::{AttentionModuleOptions, attention::attention_fallback}, }; use cubek::attention::launch; use cubek::attention::{ definition::{ AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError, }, routines::blackbox_accelerated::BlackboxAcceleratedStrategy, }; #[derive(Debug)] /// Strategy used to select which attention implementation to run. pub enum AttentionStrategy { /// Flash Attention using accelerated inner matmuls. FlashBlackboxAccelerated(BlackboxAcceleratedStrategy), /// Flash Attention using unit inner matmuls. FlashUnit, /// Fallback implementation using multiple separate kernels. Fallback, /// Automatically benchmark and select the best strategy at runtime. #[cfg(feature = "autotune")] Autotune, } impl Default for AttentionStrategy { fn default() -> Self { // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] return AttentionStrategy::Autotune; // if autotune is disabled, default to fallback to make sure it runs #[cfg(not(feature = "autotune"))] AttentionStrategy::Fallback } } #[allow(clippy::too_many_arguments)] /// Launch an attention kernel with given strategy pub fn attention( query: CubeTensor, key: CubeTensor, value: CubeTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, strategy: AttentionStrategy, out: Option>, ) -> Result, AttentionSetupError> { let mut out = out.unwrap_or_else(|| init_attention_output(&query, &value)); match strategy { AttentionStrategy::FlashBlackboxAccelerated(strategy) => flash_attention( query, key, value, mask, attn_bias, options, out, launch::Strategy::BlackboxAccelerated( cubek::attention::launch::BlueprintStrategy::Inferred(strategy), ), ), AttentionStrategy::FlashUnit => flash_attention( query, key, value, mask, attn_bias, options, out, launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())), ), AttentionStrategy::Fallback => { out = attention_fallback::>( query, key, value, mask, attn_bias, options, ); Ok(out) } #[cfg(feature = "autotune")] AttentionStrategy::Autotune => { attention_autotune(query, key, value, mask, attn_bias, options, out) } } } #[allow(clippy::too_many_arguments)] /// Launch a flash attention kernel pub fn flash_attention( query: CubeTensor, key: CubeTensor, value: CubeTensor, mask: Option>, _attn_bias: Option>, options: AttentionModuleOptions, out: CubeTensor, strategy: launch::Strategy, ) -> Result, AttentionSetupError> { let client = query.client.clone(); let dtypes = AttentionGlobalTypes { query: query.dtype.into(), key: key.dtype.into(), value: value.dtype.into(), mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(), out: out.dtype.into(), }; cubek::attention::launch::launch_ref::( strategy, &client, query.binding(), key.binding(), value.binding(), mask.map(|mask| mask.binding()), out.clone().binding(), &dtypes, AttentionOptions { causal: options.is_causal, accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar( cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32), )), }, )?; Ok(out) } pub(crate) fn init_attention_output( query: &CubeTensor, value: &CubeTensor, ) -> CubeTensor { let num_batches = query.meta.shape[0]; let num_heads = query.meta.shape[1]; let seq_q = query.meta.shape[2]; let val_dim = value.meta.shape[3]; let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]); empty_device_dtype::( query.client.clone(), query.device.clone(), out_shape, query.dtype, ) } ================================================ FILE: crates/burn-cubecl/src/kernel/attention/mod.rs ================================================ mod base; mod tune; pub use base::*; pub use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/attention/tune.rs ================================================ use crate::{ CubeRuntime, CubeTuneId, kernel::attention::{AttentionStrategy, attention}, tensor::CubeTensor, }; use burn_backend::ops::AttentionModuleOptions; use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}; use cubek::attention::{ definition::AttentionSetupError, launch::AttentionAutotuneKey, routines::blackbox_accelerated::BlackboxAcceleratedStrategy, }; /// Executes autotune on attention operations pub fn attention_autotune( query: CubeTensor, key: CubeTensor, value: CubeTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, out: CubeTensor, ) -> Result, AttentionSetupError> { let client = query.client.clone(); static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 3; const PRIORITY_MIN: i8 = 0; let flash_attention = TuneGroup::::new("flash_attention", |_key| PRIORITY_MAX); let fallback = TuneGroup::::new("fallback", |key| { if key.seq_q > 4096 { PRIORITY_MIN } else { PRIORITY_MAX } }); let mut set = TunableSet::new(create_key::, input_gen::); // First entry should always work, since it is considered the fallback. set = set.with( Tunable::new( "fallback", |query, key, value, mask, attn_bias, out, options| { attention::( query, key, value, mask, attn_bias, options, AttentionStrategy::Fallback, Some(out), ) .map_err(|err| std::format!("{err:?}")) }, ) .group(&fallback, |_key| PRIORITY_MAX), ); let seq_q = 1; let seq_kv = 1; for num_planes in [2, 4, 8] { let name = format!("blackbox_accelerated_{num_planes}_planes_p_{seq_q}-{seq_kv}"); set = set.with( Tunable::new( &name, move |query, key, value, mask, attn_bias, out, options| { attention::( query, key, value, mask, attn_bias, options, AttentionStrategy::FlashBlackboxAccelerated( BlackboxAcceleratedStrategy { num_planes, seq_q, seq_kv, }, ), Some(out), ) .map_err(|err| std::format!("{err:?}")) }, ) .group(&flash_attention, |_key| PRIORITY_MAX), ); } set = set.with( Tunable::new( "unit", |query, key, value, mask, attn_bias, out, options| { attention::( query, key, value, mask, attn_bias, options, AttentionStrategy::FlashUnit, Some(out), ) .map_err(|err| std::format!("{err:?}")) }, ) .group(&flash_attention, |_key| PRIORITY_MIN), ); set }); TUNER.execute( &CubeTuneId::new(&client, &query.device), &client, tunables, (query, key, value, mask, attn_bias, out.clone(), options), ); Ok(out) } fn create_key( query: &CubeTensor, key: &CubeTensor, value: &CubeTensor, mask: &Option>, _attn_bias: &Option>, out: &CubeTensor, _options: &AttentionModuleOptions, ) -> AttentionAutotuneKey { let total_batches = query.meta.shape[0] * query.meta.shape[1]; let seq_q = query.meta.shape[2]; let head_dim = query.meta.shape[3]; let seq_kv = value.meta.shape[2]; let val_dim = value.meta.shape[3]; AttentionAutotuneKey::generate( query.dtype.into(), key.dtype.into(), value.dtype.into(), out.dtype.into(), total_batches, seq_q, head_dim, seq_kv, val_dim, mask.is_some(), ) } #[allow(clippy::type_complexity)] #[allow(clippy::too_many_arguments)] fn input_gen( _key: &AttentionAutotuneKey, query: &CubeTensor, key: &CubeTensor, value: &CubeTensor, mask: &Option>, attn_bias: &Option>, out: &CubeTensor, options: &AttentionModuleOptions, ) -> ( CubeTensor, CubeTensor, CubeTensor, Option>, Option>, CubeTensor, AttentionModuleOptions, ) { ( query.clone(), key.clone(), value.clone(), mask.clone(), attn_bias.clone(), out.copy(), *options, ) } ================================================ FILE: crates/burn-cubecl/src/kernel/binary.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::{TensorMetadata, bf16, f16}; use cubecl::{ calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView, }; pub(crate) trait BinaryOpFamily: Send + Sync + 'static { type BinaryOp: BinaryOp; } #[cube] pub(crate) trait BinaryOp: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Vector, rhs: Vector) -> Vector; } pub(crate) struct AddOp; pub(crate) struct SubOp; pub(crate) struct MulOp; pub(crate) struct DivOp; pub(crate) struct RemainderOp; pub(crate) struct AndOp; pub(crate) struct OrOp; pub(crate) struct PowOp; impl BinaryOpFamily for AddOp { type BinaryOp = Self; } impl BinaryOpFamily for SubOp { type BinaryOp = Self; } impl BinaryOpFamily for MulOp { type BinaryOp = Self; } impl BinaryOpFamily for DivOp { type BinaryOp = Self; } impl BinaryOpFamily for RemainderOp { type BinaryOp = Self; } impl BinaryOpFamily for PowOp { type BinaryOp = Self; } impl BinaryOpFamily for AndOp { type BinaryOp = Self; } impl BinaryOpFamily for OrOp { type BinaryOp = Self; } #[cube] impl BinaryOp for AddOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs + rhs } } #[cube] impl BinaryOp for SubOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs - rhs } } #[cube] impl BinaryOp for MulOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs * rhs } } #[cube] impl BinaryOp for DivOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs / rhs } } #[cube] impl BinaryOp for RemainderOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { Vector::rem(lhs, rhs) } } #[cube] impl BinaryOp for PowOp { #[allow(unused)] fn execute(lhs: Vector, rhs: Vector) -> Vector { intrinsic!(|scope| { let elem = T::as_type(scope).elem_type(); if let cubecl::ir::ElemType::Float(kind) = elem { match kind { cubecl::ir::FloatKind::F16 => { let lhs = as Cast>::__expand_cast_from(scope, lhs); let rhs = as Cast>::__expand_cast_from(scope, rhs); let out = Vector::__expand_powf(scope, lhs, rhs); return as Cast>::__expand_cast_from(scope, out); } cubecl::ir::FloatKind::BF16 => { let lhs = as Cast>::__expand_cast_from(scope, lhs); let rhs = as Cast>::__expand_cast_from(scope, rhs); let out = Vector::__expand_powf(scope, lhs, rhs); return as Cast>::__expand_cast_from(scope, out); } cubecl::ir::FloatKind::F64 => { let lhs = as Cast>::__expand_cast_from(scope, lhs); let rhs = as Cast>::__expand_cast_from(scope, rhs); let out = Vector::__expand_powf(scope, lhs, rhs); return as Cast>::__expand_cast_from(scope, out); } _ => {} } }; let lhs = as Cast>::__expand_cast_from(scope, lhs); let rhs = as Cast>::__expand_cast_from(scope, rhs); let out = Vector::__expand_powf(scope, lhs, rhs); return as Cast>::__expand_cast_from(scope, out); }) } } #[cube] impl BinaryOp for AndOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { Vector::cast_from(Vector::::cast_from(lhs).and(Vector::::cast_from(rhs))) } } #[cube] impl BinaryOp for OrOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { Vector::cast_from(Vector::::cast_from(lhs).or(Vector::::cast_from(rhs))) } } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_scalar_binop( input: &LinearView>, scalar: InputScalar, output: &mut LinearView, ReadWrite>, #[define(C)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Vector::new(scalar.get::())); } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_binop( lhs: &LinearView>, rhs: &LinearView>, out: &mut LinearView, ReadWrite>, #[define(C)] _dtype: StorageType, ) { if !out.is_in_bounds(ABSOLUTE_POS) { terminate!(); } out[ABSOLUTE_POS] = O::BinaryOp::::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]); } pub(crate) fn launch_binop( lhs: CubeTensor, rhs: CubeTensor, ) -> CubeTensor { let vector_size_lhs = max_vector_size(&lhs); let vector_size_rhs = max_vector_size(&rhs); let vector_size = Ord::min(vector_size_lhs, vector_size_rhs); let shape_out = broadcast_shape(&[&lhs, &rhs]); let dtype = lhs.dtype; let client = lhs.client.clone(); let num_elems = shape_out.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&lhs.client, working_units); let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim); unsafe { if lhs.can_mut_broadcast(&rhs) { kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.clone().into_linear_view(), rhs.into_linear_view_like(&lhs), lhs.as_linear_view_alias(0), dtype.into(), ); lhs } else if rhs.can_mut_broadcast(&lhs) { kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.into_linear_view_like(&rhs), rhs.clone().into_linear_view(), rhs.as_linear_view_alias(1), dtype.into(), ); rhs } else { let output = empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype); kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs, output), vector_size, lhs.into_linear_view_like(&output), rhs.into_linear_view_like(&output), output.clone().into_linear_view(), dtype.into(), ); output } } } pub(crate) fn launch_scalar_binop( tensor: CubeTensor, scalar: InputScalar, ) -> CubeTensor { // Vectorization is only enabled when the last dimension is contiguous. let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let dtype = tensor.dtype; let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); unsafe { if tensor.can_mut() && tensor.is_nonoverlapping() { kernel_scalar_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), scalar, tensor.as_linear_view_alias(0), dtype.into(), ); tensor } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), dtype, ); kernel_scalar_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), scalar, output.clone().into_linear_view(), dtype.into(), ); output } } } ================================================ FILE: crates/burn-cubecl/src/kernel/binary_float.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait BinaryOpFloatFamily: Send + Sync + 'static { type BinaryOp: BinaryOpFloat; } #[cube] pub(crate) trait BinaryOpFloat: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Vector, rhs: Vector) -> Vector; } pub(crate) struct ArcTan2Op; impl BinaryOpFloatFamily for ArcTan2Op { type BinaryOp = Self; } #[cube] impl BinaryOpFloat for ArcTan2Op { fn execute(lhs: Vector, rhs: Vector) -> Vector { Vector::atan2(lhs, rhs) } } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_binop( lhs: &LinearView>, rhs: &LinearView>, out: &mut LinearView, ReadWrite>, #[define(C)] _dtype: StorageType, ) { if !out.is_in_bounds(ABSOLUTE_POS) { terminate!(); } out[ABSOLUTE_POS] = O::BinaryOp::::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]); } pub(crate) fn launch_binop_float( lhs: CubeTensor, rhs: CubeTensor, ) -> CubeTensor { let vector_size_lhs = max_vector_size(&lhs); let vector_size_rhs = max_vector_size(&rhs); let vector_size = Ord::min(vector_size_lhs, vector_size_rhs); let shape_out = broadcast_shape(&[&lhs, &rhs]); let dtype = lhs.dtype; let client = lhs.client.clone(); let num_elems = shape_out.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&lhs.client, working_units); let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim); unsafe { if lhs.can_mut_broadcast(&rhs) { kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.clone().into_linear_view(), rhs.clone().into_linear_view_like(&lhs), lhs.as_linear_view_alias(0), dtype.into(), ); lhs } else if rhs.can_mut_broadcast(&lhs) { kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.into_linear_view_like(&rhs), rhs.clone().into_linear_view(), rhs.as_linear_view_alias(1), dtype.into(), ); rhs } else { let output = empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype); kernel_binop::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs, output), vector_size, lhs.into_linear_view_like(&output), rhs.into_linear_view_like(&output), output.clone().into_linear_view(), dtype.into(), ); output } } } /// Calculate the four-quadrant inverse tangent of `lhs / rhs`. pub fn atan2(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop_float::(lhs, rhs) } ================================================ FILE: crates/burn-cubecl/src/kernel/binary_int.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { type BinaryOp: BinaryOpInt; } #[cube] pub(crate) trait BinaryOpInt: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Vector, rhs: Vector) -> Vector; } pub(crate) struct BitwiseAndOp; pub(crate) struct BitwiseOrOp; pub(crate) struct BitwiseXorOp; pub(crate) struct BitwiseShrOp; pub(crate) struct BitwiseShlOp; impl BinaryOpIntFamily for BitwiseAndOp { type BinaryOp = Self; } impl BinaryOpIntFamily for BitwiseOrOp { type BinaryOp = Self; } impl BinaryOpIntFamily for BitwiseXorOp { type BinaryOp = Self; } impl BinaryOpIntFamily for BitwiseShrOp { type BinaryOp = Self; } impl BinaryOpIntFamily for BitwiseShlOp { type BinaryOp = Self; } #[cube] impl BinaryOpInt for BitwiseAndOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs & rhs } } #[cube] impl BinaryOpInt for BitwiseOrOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs | rhs } } #[cube] impl BinaryOpInt for BitwiseXorOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs ^ rhs } } #[cube] impl BinaryOpInt for BitwiseShrOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs >> rhs } } #[cube] impl BinaryOpInt for BitwiseShlOp { fn execute(lhs: Vector, rhs: Vector) -> Vector { lhs << rhs } } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_scalar_binop_int( input: &LinearView>, scalar: InputScalar, output: &mut LinearView, ReadWrite>, #[define(C)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Vector::new(scalar.get::())); } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_binop_int( lhs: &LinearView>, rhs: &LinearView>, out: &mut LinearView, ReadWrite>, #[define(C)] _dtype: StorageType, ) { if !out.is_in_bounds(ABSOLUTE_POS) { terminate!(); } out[ABSOLUTE_POS] = O::BinaryOp::::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]); } pub(crate) fn launch_binop_int( lhs: CubeTensor, rhs: CubeTensor, ) -> CubeTensor { let vector_size_lhs = max_vector_size(&lhs); let vector_size_rhs = max_vector_size(&rhs); let vector_size = Ord::min(vector_size_lhs, vector_size_rhs); let shape_out = broadcast_shape(&[&lhs, &rhs]); let client = lhs.client.clone(); let num_elems = shape_out.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&lhs.client, working_units); let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim); let dtype = lhs.dtype; unsafe { if lhs.can_mut_broadcast(&rhs) { kernel_binop_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.clone().into_linear_view(), rhs.into_linear_view_like(&lhs), lhs.as_linear_view_alias(0), dtype.into(), ); lhs } else if rhs.can_mut_broadcast(&lhs) { kernel_binop_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.into_linear_view_like(&rhs), rhs.clone().into_linear_view(), rhs.as_linear_view_alias(1), dtype.into(), ); rhs } else { let output = empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, lhs.dtype); kernel_binop_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs, output), vector_size, lhs.into_linear_view_like(&output), rhs.into_linear_view_like(&output), output.clone().into_linear_view(), dtype.into(), ); output } } } pub(crate) fn launch_scalar_binop_int( tensor: CubeTensor, scalar: InputScalar, ) -> CubeTensor { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.shape.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); unsafe { if tensor.can_mut() && tensor.is_nonoverlapping() { kernel_scalar_binop_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), scalar, tensor.as_linear_view_alias(0), tensor.dtype.into(), ); tensor } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), tensor.dtype, ); kernel_scalar_binop_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), scalar, output.clone().into_linear_view(), output.dtype.into(), ); output } } } ================================================ FILE: crates/burn-cubecl/src/kernel/cast/base.rs ================================================ use crate::{ CubeRuntime, kernel::utils::address_type, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::{DType, TensorMetadata}; use cubecl::std::tensor::layout::linear::LinearView; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch, address_type = "dynamic")] pub(crate) fn cast_element( input: &LinearView>, output: &mut LinearView, ReadWrite>, #[define(I, O)] _dtypes: [StorageType; 2], ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]); } /// Cast a tensor to the given element type. /// /// Note: When input element is semantically a boolean, prefer bool_cast function. pub fn cast(input: CubeTensor, dtype: DType) -> CubeTensor { let dtype_output = match dtype { DType::Flex32 => DType::F32, _ => dtype, }; let dtype_input = match input.dtype { DType::Flex32 => DType::F32, _ => input.dtype, }; if dtype_input == dtype_output { return input; } let client = input.client.clone(); let vector_size = max_vector_size(&input); let num_elems: usize = input.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&client, working_units); let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim); let output = empty_device_dtype( client.clone(), input.device.clone(), input.shape(), dtype, // We take the same dtype as passed as input (Flex32 not F32) ); cast_element::launch( &client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_linear_view(), output.clone().into_linear_view(), [dtype_input.into(), dtype_output.into()], ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/cast/bool_cast.rs ================================================ use crate::{ CubeElement, CubeRuntime, kernel::utils::address_type, ops::{max_vector_size, numeric::empty_device}, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::{ CubeDim, calculate_cube_count_elemwise, num_traits::One, prelude::*, std::tensor::layout::linear::LinearView, }; #[cube(launch_unchecked, address_type = "dynamic")] fn bool_cast_kernel( input: &LinearView>, output: &mut LinearView, ReadWrite>, #[define(B)] _input_ty: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS] & Vector::one()); } /// Cast a bool tensor to the given element type. /// /// This alternative to cast is necessary because bool are represented as u32 or u8 /// where any non-zero value means true. Depending how it was created /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. pub fn bool_cast(tensor: CubeTensor) -> CubeTensor { let output = empty_device::(tensor.client.clone(), tensor.device.clone(), tensor.shape()); let vector_size = max_vector_size(&tensor); let num_elems = tensor.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { bool_cast_kernel::launch_unchecked::( &output.client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), output.clone().into_linear_view(), dtype.into(), ) }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/cast/mod.rs ================================================ mod base; mod bool_cast; pub use base::*; pub use bool_cast::*; ================================================ FILE: crates/burn-cubecl/src/kernel/clamp.rs ================================================ use cubecl::prelude::*; use crate::{ CubeRuntime, kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric}, tensor::CubeTensor, }; #[derive(CubeLaunch, CubeType)] struct Options { min_value: InputScalar, max_value: InputScalar, } pub(crate) fn clamp( input: CubeTensor, min_value: InputScalar, max_value: InputScalar, ) -> CubeTensor { struct ClampOp; #[cube] impl NumericUnaryOp for ClampOp { type Options = Options; fn execute(input: Vector, options: &Self::Options) -> Vector { cubecl::prelude::clamp( input, Vector::new(options.min_value.get::()), Vector::new(options.max_value.get::()), ) } } impl NumericUnaryOpFamily for ClampOp { type Options = Options; type Unary = Self; } launch_unary_numeric::(input, |_| OptionsLaunch::new(min_value, max_value)) } ================================================ FILE: crates/burn-cubecl/src/kernel/comparison.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::{DType, TensorMetadata}; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; #[cube] pub(crate) trait ComparisonOpFamily: 'static + Send + Sync { type Operation: ComparisonOp; } #[cube] pub(crate) trait ComparisonOp: 'static + Send + Sync { /// Execute a comparison operation. fn execute(lhs: Vector, rhs: Vector) -> bool; } struct EqualOp; struct GreaterEqualOp; struct LowerEqualOp; struct GreaterOp; struct LowerOp; impl ComparisonOpFamily for EqualOp { type Operation = Self; } #[cube] impl ComparisonOp for EqualOp { fn execute(lhs: Vector, rhs: Vector) -> bool { lhs == rhs } } impl ComparisonOpFamily for GreaterEqualOp { type Operation = Self; } #[cube] impl ComparisonOp for GreaterEqualOp { fn execute(lhs: Vector, rhs: Vector) -> bool { lhs >= rhs } } impl ComparisonOpFamily for LowerEqualOp { type Operation = Self; } #[cube] impl ComparisonOp for LowerEqualOp { fn execute(lhs: Vector, rhs: Vector) -> bool { lhs <= rhs } } impl ComparisonOpFamily for GreaterOp { type Operation = Self; } #[cube] impl ComparisonOp for GreaterOp { fn execute(lhs: Vector, rhs: Vector) -> bool { lhs > rhs } } impl ComparisonOpFamily for LowerOp { type Operation = Self; } #[cube] impl ComparisonOp for LowerOp { fn execute(lhs: Vector, rhs: Vector) -> bool { lhs < rhs } } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_scalar_cmp( input: &LinearView>, scalar: InputScalar, output: &mut LinearView, ReadWrite>, #[define(T, Bool)] _dtypes: [StorageType; 2], ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = Vector::cast_from(O::Operation::::execute( input[ABSOLUTE_POS], Vector::new(scalar.get::()), )); } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_cmp( lhs: &LinearView>, rhs: &LinearView>, out: &mut LinearView, ReadWrite>, #[define(T, Bool)] _dtype: [StorageType; 2], ) { if !out.is_in_bounds(ABSOLUTE_POS) { terminate!(); } out[ABSOLUTE_POS] = Vector::cast_from(O::Operation::::execute( lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS], )); } pub(crate) fn launch_cmp( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { let vector_size_lhs = max_vector_size(&lhs); let vector_size_rhs = max_vector_size(&rhs); let vector_size = Ord::min(vector_size_lhs, vector_size_rhs); let shape_out = broadcast_shape(&[&lhs, &rhs]); let client = lhs.client.clone(); let num_elems = shape_out.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&lhs.client, working_units); let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim); let dtypes = [lhs.dtype.into(), dtype_bool.into()]; let same_tensor_type = dtypes[0] == dtypes[1]; if same_tensor_type && lhs.can_mut_broadcast(&rhs) { unsafe { kernel_cmp::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.clone().into_linear_view(), rhs.into_linear_view_like(&lhs), lhs.as_linear_view_alias(0), dtypes, ); } CubeTensor::new( lhs.client.clone(), lhs.handle.clone(), *lhs.meta.clone(), lhs.device.clone(), dtype_bool, ) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { unsafe { kernel_cmp::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs), vector_size, lhs.into_linear_view_like(&rhs), rhs.clone().into_linear_view(), rhs.as_linear_view_alias(1), dtypes, ); }; CubeTensor::new( rhs.client.clone(), rhs.handle.clone(), *rhs.meta.clone(), rhs.device.clone(), dtype_bool, ) } else { let output = empty_device_dtype( lhs.client.clone(), lhs.device.clone(), shape_out, dtype_bool, ); unsafe { kernel_cmp::launch_unchecked::( &client, cube_count, cube_dim, address_type!(lhs, rhs, output), vector_size, lhs.into_linear_view_like(&output), rhs.into_linear_view_like(&output), output.clone().into_linear_view(), dtypes, ); }; output } } pub(crate) fn launch_scalar_cmp( tensor: CubeTensor, scalar: InputScalar, dtype_bool: DType, ) -> CubeTensor { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtypes = [tensor.dtype.into(), dtype_bool.into()]; let same_tensor_type = dtypes[0] == dtypes[1]; if same_tensor_type && tensor.can_mut() && tensor.is_nonoverlapping() { unsafe { kernel_scalar_cmp::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), scalar, tensor.as_linear_view_alias(0), dtypes, ); } CubeTensor::new( tensor.client.clone(), tensor.handle.clone(), *tensor.meta.clone(), tensor.device.clone(), dtype_bool, ) } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), dtype_bool, ); unsafe { kernel_scalar_cmp::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), scalar, output.clone().into_linear_view(), dtypes, ); } output } } pub fn equal( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { launch_cmp::(lhs, rhs, dtype_bool) } pub fn greater( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { launch_cmp::(lhs, rhs, dtype_bool) } pub fn greater_equal( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { launch_cmp::(lhs, rhs, dtype_bool) } pub fn lower( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { launch_cmp::(lhs, rhs, dtype_bool) } pub fn lower_equal( lhs: CubeTensor, rhs: CubeTensor, dtype_bool: DType, ) -> CubeTensor { launch_cmp::(lhs, rhs, dtype_bool) } pub fn equal_elem( lhs: CubeTensor, rhs: InputScalar, dtype_bool: DType, ) -> CubeTensor { launch_scalar_cmp::(lhs, rhs, dtype_bool) } pub fn greater_elem( lhs: CubeTensor, rhs: InputScalar, dtype_bool: DType, ) -> CubeTensor { launch_scalar_cmp::(lhs, rhs, dtype_bool) } pub fn lower_elem( lhs: CubeTensor, rhs: InputScalar, dtype_bool: DType, ) -> CubeTensor { launch_scalar_cmp::(lhs, rhs, dtype_bool) } pub fn greater_equal_elem( lhs: CubeTensor, rhs: InputScalar, dtype_bool: DType, ) -> CubeTensor { launch_scalar_cmp::(lhs, rhs, dtype_bool) } pub fn lower_equal_elem( lhs: CubeTensor, rhs: InputScalar, dtype_bool: DType, ) -> CubeTensor { launch_scalar_cmp::(lhs, rhs, dtype_bool) } // Unary comparison / predicate / relational ops #[cube] pub(crate) trait PredicateOp: 'static + Send + Sync { /// Execute a predicate operation. fn execute(input: Vector) -> Vector; } pub(crate) trait PredicateOpFamily: 'static + Send + Sync { type Operation: PredicateOp; } struct IsNanOp; struct IsInfOp; impl PredicateOpFamily for IsNanOp { type Operation = Self; } #[cube] impl PredicateOp for IsNanOp { fn execute(input: Vector) -> Vector { Vector::is_nan(input) } } impl PredicateOpFamily for IsInfOp { type Operation = Self; } #[cube] impl PredicateOp for IsInfOp { fn execute(input: Vector) -> Vector { Vector::is_inf(input) } } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn kernel_predicate( input: &LinearView>, output: &mut LinearView, ReadWrite>, #[define(F, Bool)] _dtypes: [StorageType; 2], ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = Vector::cast_from(O::Operation::::execute(input[ABSOLUTE_POS])); } pub(crate) fn launch_predicate( tensor: CubeTensor, dtype_bool: DType, ) -> CubeTensor { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let dtypes = [tensor.dtype.into(), dtype_bool.into()]; let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), dtype_bool, ); unsafe { kernel_predicate::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view_like(&output), output.clone().into_linear_view(), dtypes, ); } output } pub fn is_nan(tensor: CubeTensor, dtype_bool: DType) -> CubeTensor { launch_predicate::(tensor, dtype_bool) } pub fn is_inf(tensor: CubeTensor, dtype_bool: DType) -> CubeTensor { launch_predicate::(tensor, dtype_bool) } ================================================ FILE: crates/burn-cubecl/src/kernel/contiguous.rs ================================================ use burn_backend::{DType, QTensorPrimitive, TensorMetadata}; use cubecl::quant::scheme::{QuantStore, QuantValue}; use cubecl::server::MemoryLayoutStrategy; use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor}; /// Make a jit tensor contiguous. pub fn into_contiguous(tensor: CubeTensor) -> CubeTensor { if tensor.is_contiguous() { return tensor; } if tensor.qparams.is_some() { return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Contiguous); } let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype); let output = cubecl::std::tensor::into_contiguous(&client, tensor.binding(), dtype.into()); CubeTensor::new( client.clone(), output.handle, *output.metadata, device, dtype, ) } /// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous /// if runtime can read it as is. This is equivalent in practice. #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensor)) )] pub fn into_contiguous_aligned(tensor: CubeTensor) -> CubeTensor { if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) { return tensor; } if tensor.qparams.is_some() { return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Optimized); } let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype); let output = cubecl::std::tensor::into_contiguous_pitched(&client, tensor.binding(), dtype.into()); CubeTensor::new( client.clone(), output.handle, *output.metadata, device, dtype, ) } #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensor)) )] fn into_contiguous_quantized( tensor: CubeTensor, strategy: MemoryLayoutStrategy, ) -> CubeTensor { let scheme = tensor.scheme(); let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, strategy); let (values, scales) = tensor.quantized_handles().unwrap(); let (out_values, out_scales) = output.quantized_handles().unwrap(); let (client, dtype_scales, dtype_value) = (scales.client.clone(), scales.dtype, values.dtype); match scheme.store { QuantStore::PackedU32(packed_dim) => { cubecl::std::tensor::into_contiguous_packed_ref( &client, values.binding(), out_values.binding(), packed_dim, tensor.meta.shape(), scheme.num_quants(), DType::U32.into(), ); } // e2m1 is special because it has a native packed representation, `e2m1x2`. // It's internally stored as `u8` with a packing factor of 2. QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => { cubecl::std::tensor::into_contiguous_packed_ref( &client, values.binding(), out_values.binding(), packed_dim, tensor.meta.shape(), scheme.num_quants(), DType::U8.into(), ); } _ => { cubecl::std::tensor::copy_into( &client, values.binding(), out_values.binding(), dtype_value.into(), ); } } cubecl::std::tensor::copy_into( &client, scales.binding(), out_scales.binding(), dtype_scales.into(), ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs ================================================ use burn_backend::{ TensorMetadata, ops::{ConvOptions, ConvTransposeOptions}, }; use burn_std::Shape; use cubek::convolution::components::ConvSetupError; use crate::{ CubeRuntime, kernel::conv::{conv_transpose2d, conv_transpose3d}, ops::{permute_nchw_to_nhwc, permute_nhwc_to_nchw, reshape}, tensor::CubeTensor, }; pub(crate) fn conv_data_backward_fallback( out_grad: CubeTensor, weights: CubeTensor, in_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { let dim_c = out_grad.rank(); let kernel_size = &weights.meta.shape()[1..dim_c]; let in_shape = &in_shape[1..dim_c]; let out_shape = &out_grad.meta.shape()[1..dim_c]; let mut padding_out = [0; N_DIM]; for i in 0..N_DIM { padding_out[i] = calculate_padding_out( kernel_size[i], options.stride[i], options.padding[i], options.dilation[i], in_shape[i], out_shape[i], ); } // We don't yet have NHWC kernels for conv_transpose so need to do this. // Should eventually use NHWC kernels instead let out_grad = permute_nhwc_to_nchw(out_grad); let weights = permute_nhwc_to_nchw(weights); let in_grad = match N_DIM { 1 => conv_transpose1d_from_conv_transpose2d( out_grad, weights, ConvTransposeOptions::new( [options.stride[0]], [options.padding[0]], [padding_out[0]], [options.dilation[0]], options.groups, ), ), 2 => conv_transpose2d( out_grad, weights, None, ConvTransposeOptions::new( [options.stride[0], options.stride[1]], [options.padding[0], options.padding[1]], [padding_out[0], padding_out[1]], [options.dilation[0], options.dilation[1]], options.groups, ), Default::default(), ), 3 => Ok(conv_transpose3d( out_grad, weights, None, ConvTransposeOptions::new( [options.stride[0], options.stride[1], options.stride[2]], [options.padding[0], options.padding[1], options.padding[2]], [padding_out[0], padding_out[1], padding_out[2]], [ options.dilation[0], options.dilation[1], options.dilation[2], ], options.groups, ), ) .unwrap()), _ => unimplemented!("Invalid dimensionality"), }?; Ok(permute_nchw_to_nhwc(in_grad)) } fn calculate_padding_out( kernel_size: usize, stride: usize, padding: usize, dilation: usize, size_in: usize, size_out: usize, ) -> usize { if stride <= 1 { return 0; } let out = 1 + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil() as usize; i64::max(0, out as i64 - size_out as i64) as usize } fn conv_transpose1d_from_conv_transpose2d( x: CubeTensor, weight: CubeTensor, options: ConvTransposeOptions<1>, ) -> Result, ConvSetupError> { let [channels_in, channels_out, kernel_size] = weight.shape().dims(); let [batch_size, _channels_in, length_in] = x.shape().dims(); let weight = reshape( weight, Shape::new([channels_in, channels_out, kernel_size, 1]), ); let x = reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); let tensor = conv_transpose2d( x, weight, None, ConvTransposeOptions::new( [options.stride[0], 1], [options.padding[0], 0], [options.padding_out[0], 0], [options.dilation[0], 1], options.groups, ), Default::default(), )?; let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); Ok(reshape( tensor, Shape::from([batch_size, channels_out, height_out]), )) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs ================================================ use burn_backend::ops::ConvOptions; use burn_std::Shape; use cubek::{ convolution::{ AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_data, components::ConvSetupError, }, matmul::{ definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputBinding, }, }; use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; pub fn dgrad_gemm_simple_sync( out_grad: CubeTensor, weights: CubeTensor, input_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic, AcceleratedTileKind::Mma => ReadingStrategy::Strided, }; launch_backwards_data::( &Strategy::Simple { read_strategy, tile_kind, }, out_grad, weights, input_shape, options, ) } pub fn dgrad_gemm_simple_async( out_grad: CubeTensor, weights: CubeTensor, input_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic, AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided, }; launch_backwards_data::( &Strategy::Simple { read_strategy, tile_kind, }, out_grad, weights, input_shape, options, ) } pub fn dgrad_gemm_simple_tma( out_grad: CubeTensor, weights: CubeTensor, input_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { launch_backwards_data::( &Strategy::Simple { read_strategy: ReadingStrategy::Tma, tile_kind, }, out_grad, weights, input_shape, options, ) } /// Perform a convolution backwards data pass using the implicit GEMM (im2col) algorithm, using /// cubecl tiling matmul components. /// /// * `input` - The input feature map /// * `out_grad` - The output gradients /// * `weight_shape` - The shape of the weights/weight gradients /// * `options` - The options to use for the convolution pub fn launch_backwards_data( strategy: &Strategy, out_grad: CubeTensor, weights: CubeTensor, input_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { if options.groups != 1 || options.stride.iter().any(|&s| s != 1) { return Err(ConvSetupError::Groups(options.groups)); } let out_dtype = out_grad.dtype; let in_grad = empty_device_dtype( out_grad.client.clone(), out_grad.device.clone(), input_shape, out_dtype, ); let client = out_grad.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { lhs: out_grad.dtype.into(), rhs: weights.dtype.into(), out: out_dtype.into(), }); let out_grad_dtype = out_grad.dtype; let weights_dtype = weights.dtype; let out_grad = MatmulInputBinding::new(out_grad.binding(), out_grad_dtype.into()); let weights = MatmulInputBinding::new(weights.binding(), weights_dtype.into()); backward_data::launch_ref::( strategy, &client, out_grad, weights, in_grad.clone().binding(), ConvolutionArgs { stride: options.stride, padding: options.padding, dilation: options.dilation, }, dtypes, )?; Ok(in_grad) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/mod.rs ================================================ pub mod launch; pub use launch::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_data/mod.rs ================================================ pub mod fallback; pub mod implicit_gemm; #[cfg(feature = "autotune")] pub mod tune; #[cfg(feature = "autotune")] pub(crate) use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs ================================================ use burn_backend::ops::ConvOptions; use burn_std::Shape; use cubecl::{ ir::StorageType, tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner}, }; use cubek::convolution::AcceleratedTileKind; use crate::{ CubeAutotuneKey, CubeRuntime, CubeTuneId, kernel::conv::{ ConvAutotuneKey, backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*}, }, tensor::CubeTensor, }; /// Executes autotune on conv2d operations pub fn dgrad_autotune( out_grad: CubeTensor, weights: CubeTensor, input_shape: Shape, options: ConvOptions, ) -> CubeTensor { let client = out_grad.client.clone(); static TUNER: LocalTuner = local_tuner!(); // Note: TMA isn't currently implemented properly, and will always error. // It's kept here so it gets automatically enabled as soon as cubek updates. // No CMMA for TMA because swizzling will be mandatory for good performance on dgrad. let tunables = TUNER.init(|| { TunableSet::new(create_key::, create_wgrad_input::) .with(Tunable::new( "wgrad_fallback", conv_data_backward_fallback::, )) .with(Tunable::new( "simple_sync_cmma", |input, grad, shape, options| { dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_sync_mma", |input, grad, shape, options| { dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_async_cmma", |input, grad, shape, options| { dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_async_mma", |input, grad, shape, options| { dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_tma_mma", |input, grad, shape, options| { dgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) }); TUNER.execute( &CubeTuneId::new(&out_grad.client, &out_grad.device), &client, tunables, (out_grad, weights, input_shape, options), ) } pub fn create_wgrad_input( _key: &CubeAutotuneKey, out_grad: &CubeTensor, weights: &CubeTensor, input_shape: &Shape, options: &ConvOptions, ) -> (CubeTensor, CubeTensor, Shape, ConvOptions) { ( out_grad.clone(), weights.clone(), input_shape.clone(), options.clone(), ) } fn create_key( out_grad: &CubeTensor, weights: &CubeTensor, input_shape: &Shape, options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = out_grad.dtype; let rank = out_grad.meta.num_dims(); let dim_c = rank - 1; let batch_size = out_grad.meta.shape()[0]; let in_channels = input_shape[dim_c]; let out_channels = out_grad.meta.shape()[dim_c]; let kernel_size = weights.meta.shape()[1..dim_c].to_vec(); let in_shape = input_shape[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) .collect(); let ConvOptions { stride, padding, dilation, groups, } = options.clone(); let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 { stride_align(out_grad.meta.strides(), out_grad.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align); let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 { stride_align(weights.meta.strides(), weights.dtype.into()) } else { 0 }; let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align); CubeAutotuneKey::Conv(ConvAutotuneKey::new( kernel_size, stride.to_vec(), padding.to_vec(), dilation.to_vec(), groups, in_channels, out_channels, in_shape, batch_size, false, dtype, lhs_shape_align, lhs_stride_align, rhs_shape_align, rhs_stride_align, )) } /// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's /// repeat number, so it's the largest align that can have performance impacts. const MAX_STRIDE_FACTOR: u32 = 10; /// Defines the non-contiguous stride alignment in terms of powers of two fn stride_align(strides: &[usize], elem: StorageType) -> u8 { let max = MAX_STRIDE_FACTOR; let dim_c = strides.len() - 1; let factor = strides[..dim_c] .iter() .map(|it| (*it * elem.size_bits()) / 8) .map(|it| it.trailing_zeros()) .min() .unwrap_or(max); factor.min(max) as u8 } /// Defines the potential vectorization. fn pow2_factor(axis: usize) -> u8 { axis.trailing_zeros().min(4) as u8 } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs ================================================ use burn_backend::{TensorMetadata, ops::ConvOptions}; use burn_std::{Shape, Slice}; use cubek::convolution::components::ConvSetupError; use crate::{ CubeRuntime, kernel::{conv::base::conv_forward_nhwc, slice, slice_assign}, ops::{numeric::empty_device_dtype, swap_dims}, tensor::CubeTensor, }; /// Calculate the convolution backward pass with regard to the weight gradients. pub fn conv_weight_backward_fallback( input: CubeTensor, output_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { match options.groups == 1 { true => conv_weight_grad_no_groups::(input, output_grad, weight_shape, options), false => conv_weight_grad_groups::(input, output_grad, weight_shape, options), } } fn conv_weight_grad_no_groups( input: CubeTensor, output_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { let dim_c = input.rank() - 1; let input_swapped = swap_dims(input, 0, dim_c); let out_grad_swapped = swap_dims(output_grad, 0, dim_c); let weight_grad_swapped = conv_forward_nhwc( input_swapped, out_grad_swapped, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), Default::default(), )?; let mut weight_grad = swap_dims(weight_grad_swapped, 0, dim_c); if weight_grad.shape() != weight_shape { let ranges = weight_shape.iter().map(|&s| 0..s).collect::>(); weight_grad = slice(weight_grad, &ranges); } Ok(weight_grad) } #[allow(clippy::single_range_in_vec_init, reason = "False positive")] fn conv_weight_grad_groups( input: CubeTensor, output_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { let mut weight_grad = empty_device_dtype( input.client.clone(), input.device.clone(), weight_shape.clone(), input.dtype, ); let dim_c = input.rank() - 1; let channels_out = weight_shape[0]; let increment_co = channels_out / options.groups; let input_swapped = swap_dims(input, 0, dim_c); let output_grad_swapped = swap_dims(output_grad, 0, dim_c); let kernel_size = &weight_shape[1..dim_c]; let kernel_size_slice = kernel_size.iter().map(|&s| 0..s).collect::>(); let increment_ci = weight_grad.meta.shape()[dim_c]; for g in 0..options.groups { let start_idx_ci = g * increment_ci; let end_idx_ci = (g + 1) * increment_ci; let start_idx_co = g * increment_co; let end_idx_co = (g + 1) * increment_co; let input = slice(input_swapped.clone(), &[start_idx_ci..end_idx_ci]); let grad = slice(output_grad_swapped.clone(), &[start_idx_co..end_idx_co]); let weight_grad_tmp = conv_forward_nhwc( input, grad, None, ConvOptions::new(options.dilation, options.padding, options.stride, 1), Default::default(), )?; let mut weight_grad_tmp = swap_dims(weight_grad_tmp, 0, dim_c); let kernel_size_tmp = &weight_grad_tmp.meta.shape()[1..dim_c]; if kernel_size != kernel_size_tmp { let mut slices = vec![0..increment_co]; slices.extend(kernel_size_slice.clone()); slices.push(0..increment_ci); weight_grad_tmp = slice(weight_grad_tmp, &slices); } let mut slices = vec![start_idx_co..end_idx_co]; slices.extend(kernel_size_slice.clone()); slices.push(0..increment_ci); let slices = slices.into_iter().map(Slice::from).collect::>(); weight_grad = slice_assign(weight_grad, &slices, weight_grad_tmp); } Ok(weight_grad) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs ================================================ use burn_backend::ops::ConvOptions; use burn_std::Shape; use cubek::{ convolution::{ AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_weight, components::ConvSetupError, }, matmul::{ definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputBinding, }, }; use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; pub(crate) fn wgrad_gemm_simple_sync( input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic, AcceleratedTileKind::Mma => ReadingStrategy::Strided, }; launch_backwards_weight::( &Strategy::Simple { read_strategy, tile_kind, }, input, out_grad, weight_shape, options, ) } pub(crate) fn wgrad_gemm_simple_async( input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic, AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided, }; launch_backwards_weight::( &Strategy::Simple { read_strategy, tile_kind, }, input, out_grad, weight_shape, options, ) } pub(crate) fn wgrad_gemm_simple_tma( input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { launch_backwards_weight::( &Strategy::Simple { read_strategy: ReadingStrategy::Tma, tile_kind, }, input, out_grad, weight_shape, options, ) } /// Perform a convolution backwards weight pass using the implicit GEMM (im2col) algorithm, using /// cubecl tiling matmul components. /// /// * `input` - The input feature map /// * `out_grad` - The output gradients /// * `weight_shape` - The shape of the weights/weight gradients /// * `options` - The options to use for the convolution pub fn launch_backwards_weight( strategy: &Strategy, input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, ) -> Result, ConvSetupError> { if options.groups != 1 { return Err(ConvSetupError::Groups(options.groups)); } let out_dtype = out_grad.dtype; let weight_grad = empty_device_dtype( input.client.clone(), input.device.clone(), weight_shape, out_dtype, ); let client = input.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { lhs: input.dtype.into(), rhs: out_grad.dtype.into(), out: out_dtype.into(), }); let input_dtype = input.dtype; let out_grad_dtype = out_grad.dtype; let input = MatmulInputBinding::new(input.binding(), input_dtype.into()); let out_grad = MatmulInputBinding::new(out_grad.binding(), out_grad_dtype.into()); backward_weight::launch_ref::( strategy, &client, input, out_grad, weight_grad.clone().binding(), ConvolutionArgs { stride: options.stride, padding: options.padding, dilation: options.dilation, }, dtypes, )?; Ok(weight_grad) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/mod.rs ================================================ pub mod launch; pub use launch::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_weight/mod.rs ================================================ pub mod fallback; pub mod implicit_gemm; #[cfg(feature = "autotune")] pub mod tune; #[cfg(feature = "autotune")] pub(crate) use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs ================================================ use burn_backend::ops::ConvOptions; use burn_std::Shape; use cubecl::{ ir::StorageType, tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner}, }; use cubek::convolution::AcceleratedTileKind; use crate::{ CubeAutotuneKey, CubeRuntime, CubeTuneId, kernel::conv::{ ConvAutotuneKey, backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*}, }, tensor::CubeTensor, }; /// Executes autotune on the weight gradients pass for convolution pub fn wgrad_autotune( input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, ) -> CubeTensor { let client = input.client.clone(); static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { TunableSet::new(create_key::, create_wgrad_input::) .with(Tunable::new( "wgrad_fallback", conv_weight_backward_fallback::, )) .with(Tunable::new( "simple_sync_cmma", |input, grad, shape, options| { wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_sync_mma", |input, grad, shape, options| { wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_async_cmma", |input, grad, shape, options| { wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_async_mma", |input, grad, shape, options| { wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_tma_cmma", |input, grad, shape, options| { wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_tma_mma", |input, grad, shape, options| { wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma) }, )) }); TUNER.execute( &CubeTuneId::new(&input.client, &input.device), &client, tunables, (input, out_grad, weight_shape, options), ) } pub fn create_wgrad_input( _key: &CubeAutotuneKey, input: &CubeTensor, out_grad: &CubeTensor, weight_shape: &Shape, options: &ConvOptions, ) -> (CubeTensor, CubeTensor, Shape, ConvOptions) { ( input.clone(), out_grad.clone(), weight_shape.clone(), options.clone(), ) } fn create_key( input: &CubeTensor, out_grad: &CubeTensor, weight_shape: &Shape, options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = input.dtype; let rank = input.meta.num_dims(); let dim_c = rank - 1; let batch_size = input.meta.shape()[0]; let in_channels = input.meta.shape()[dim_c]; let out_channels = weight_shape[0]; let kernel_size = weight_shape[1..dim_c].to_vec(); let in_shape = input.meta.shape()[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) .collect(); let ConvOptions { stride, padding, dilation, groups, } = options.clone(); let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 { stride_align(out_grad.meta.strides(), out_grad.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align); let rhs_stride_align = if input.meta.strides()[dim_c] == 1 { stride_align(input.meta.strides(), input.dtype.into()) } else { 0 }; let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align); CubeAutotuneKey::Conv(ConvAutotuneKey::new( kernel_size, stride.to_vec(), padding.to_vec(), dilation.to_vec(), groups, in_channels, out_channels, in_shape, batch_size, false, dtype, lhs_shape_align, lhs_stride_align, rhs_shape_align, rhs_stride_align, )) } /// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's /// repeat number, so it's the largest align that can have performance impacts. const MAX_STRIDE_FACTOR: u32 = 10; /// Defines the non-contiguous stride alignment in terms of powers of two fn stride_align(strides: &[usize], elem: StorageType) -> u8 { let max = MAX_STRIDE_FACTOR; let dim_c = strides.len() - 1; let factor = strides[..dim_c] .iter() .map(|it| (*it * elem.size_bits()) / 8) .map(|it| it.trailing_zeros()) .min() .unwrap_or(max); factor.min(max) as u8 } /// Defines the potential vectorization. fn pow2_factor(axis: usize) -> u8 { axis.trailing_zeros().min(4) as u8 } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/base.rs ================================================ use burn_backend::ops::ConvOptions; use burn_std::Shape; use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError}; #[cfg(feature = "autotune")] use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune}; use crate::{ CubeRuntime, kernel::conv::{ backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*}, backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*}, forward::implicit_gemm::conv_gemm_simple_sync, }, ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw}, tensor::CubeTensor, }; use super::conv_direct; #[cfg(feature = "autotune")] use super::forward::conv_autotune; /// The strategy to be used when launching a convolution kernel. pub enum ConvStrategy { /// A simple direct convolution. Direct, #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. Autotune, /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and /// has constraints on tensor shape. ImplicitGemm, } impl Default for ConvStrategy { fn default() -> Self { // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] return ConvStrategy::Autotune; // if autotune is disabled, default to the more memory-conservative algorithm #[cfg(not(feature = "autotune"))] ConvStrategy::Direct } } /// Performs an N-dimensional convolution with the given strategy /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. pub fn conv_forward( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, strategy: ConvStrategy, ) -> Result, ConvSetupError> { let input = permute_nchw_to_nhwc(input); let weight = permute_nchw_to_nhwc(weight); let out = conv_forward_nhwc(input, weight, bias, options, strategy)?; Ok(permute_nhwc_to_nchw(out)) } /// Performs an N-dimensional convolution with the given strategy on NHWC inputs/outputs /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. pub fn conv_forward_nhwc( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, strategy: ConvStrategy, ) -> Result, ConvSetupError> { match strategy { ConvStrategy::Direct => conv_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] ConvStrategy::Autotune => Ok(conv_autotune::(input, weight, bias, options)), ConvStrategy::ImplicitGemm => { if options.groups != 1 { conv_direct::(input, weight, bias, options) } else { conv_gemm_simple_sync::( input, weight, bias, options, AcceleratedTileKind::Cmma, ) } } } } /// Performs an N-dimensional convolution backwards pass with regard to weight, with the given strategy /// /// * `input` - The input feature map /// * `out_grad` - The output gradients /// * `weight_shape` - The shape of the weights/weight gradients /// * `options` - The options used for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. pub fn conv_weight_backward( input: CubeTensor, out_grad: CubeTensor, weight_shape: Shape, options: ConvOptions, strategy: ConvStrategy, ) -> Result, ConvSetupError> { let input = permute_nchw_to_nhwc(input); let out_grad = permute_nchw_to_nhwc(out_grad); let weight_shape = permute_nchw_to_nhwc_shape(weight_shape); let weight_grad = match strategy { ConvStrategy::Direct => { conv_weight_backward_fallback::(input, out_grad, weight_shape, options) } #[cfg(feature = "autotune")] ConvStrategy::Autotune => Ok(wgrad_autotune::( input, out_grad, weight_shape, options, )), ConvStrategy::ImplicitGemm => { if options.groups != 1 { conv_weight_backward_fallback::(input, out_grad, weight_shape, options) } else { wgrad_gemm_simple_sync::( input, out_grad, weight_shape, options, AcceleratedTileKind::Cmma, ) } } }?; Ok(permute_nhwc_to_nchw(weight_grad)) } /// Performs an N-dimensional convolution backwards data pass with the given strategy /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `in_shape` - The shape of the input to the layer /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. pub fn conv_data_backward( out_grad: CubeTensor, weights: CubeTensor, in_shape: Shape, options: ConvOptions, strategy: ConvStrategy, ) -> Result, ConvSetupError> { let out_grad = permute_nchw_to_nhwc(out_grad); let weights = permute_nchw_to_nhwc(weights); let in_shape = permute_nchw_to_nhwc_shape(in_shape); let weight_grad = match strategy { ConvStrategy::Direct => { conv_data_backward_fallback::(out_grad, weights, in_shape, options)? } #[cfg(feature = "autotune")] ConvStrategy::Autotune => dgrad_autotune::(out_grad, weights, in_shape, options), ConvStrategy::ImplicitGemm => { if options.groups != 1 || options.stride.iter().any(|&s| s != 1) { conv_data_backward_fallback::(out_grad, weights, in_shape, options)? } else { dgrad_gemm_simple_sync::( out_grad, weights, in_shape, options, AcceleratedTileKind::Cmma, )? } } }; Ok(permute_nhwc_to_nchw(weight_grad)) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose2d/base.rs ================================================ use crate::{CubeRuntime, tensor::CubeTensor}; use burn_backend::ops::ConvTransposeOptions; use cubek::convolution::components::ConvSetupError; #[cfg(feature = "autotune")] use super::conv_transpose2d_autotune; use super::{conv_transpose2d_col2im, conv_transpose2d_direct}; /// The strategy to be used when launching a conv_transpose kernel. pub enum ConvTranspose2dStrategy { /// A simple direct convolution. Direct, #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. Autotune, /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage. Gemm, } impl Default for ConvTranspose2dStrategy { fn default() -> Self { // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] return ConvTranspose2dStrategy::Autotune; // if autotune is disabled, default to the more memory-conservative algorithm #[cfg(not(feature = "autotune"))] ConvTranspose2dStrategy::Direct } } /// Performs a 2D convolution with the given strategy /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. pub fn conv_transpose2d( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, ) -> Result, ConvSetupError> { match strategy { ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options), #[cfg(feature = "autotune")] ConvTranspose2dStrategy::Autotune => { Ok(conv_transpose2d_autotune(input, weight, bias, options)) } ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options), } } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs ================================================ use crate::{ CubeRuntime, kernel::{ conv::batches_per_run, into_contiguous_aligned, matmul::{MatmulStrategy, matmul}, slice, utils::{address_type, decompose_linear, shape_divmod}, }, ops::{numeric::empty_device_dtype, reshape, swap_dims}, tensor::CubeTensor, }; use burn_backend::{ Shape, ops::{ConvTransposeOptions, conv::calculate_conv_transpose_output_size}, }; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; use cubek::convolution::components::ConvSetupError; /// Perform a 2D convolution transposition using the GEMM (col2im) algorithm. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution pub fn conv_transpose2d_col2im( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> Result, ConvSetupError> { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.meta.shape().dims(); let [batch_size, _, input_h, input_w] = input.meta.shape().dims(); let groups = options.groups; let input_ch_per_group = input_channels / groups; let ConvTransposeOptions { padding: [padding_h, padding_w], padding_out: [padding_out_h, padding_out_w], dilation: [dilation_h, dilation_w], stride: [stride_h, stride_w], .. } = options.clone(); let im_h = calculate_conv_transpose_output_size( kernel_h, stride_h, padding_h, padding_out_h, dilation_h, input_h, ); let im_w = calculate_conv_transpose_output_size( kernel_w, stride_w, padding_w, padding_out_w, dilation_w, input_w, ); let im_channels = im_ch_per_group * groups; let batches_per_run = batches_per_run( batch_size, input_h * input_w, input.client.properties().hardware.plane_size_max as usize, )?; let col_shape_0 = im_ch_per_group * kernel_h * kernel_w; let weight = reshape( weight.clone(), Shape::new([groups, input_ch_per_group, col_shape_0]), ); let weight = into_contiguous_aligned(swap_dims(weight, 1, 2)); if batches_per_run != batch_size { let runs = batch_size / batches_per_run; let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]); let image = empty_device_dtype( input.client.clone(), input.device.clone(), im_shape, input.dtype, ); let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]); let input = reshape(input, input_shape); let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]); for run in 0..runs { let input = index(input.clone(), run); let input = reshape(input, input_shape_run.clone()); let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image_slice = index(image.clone(), run); let image_slice = reshape(image_slice, im_shape); execute( input, weight.clone(), bias.clone(), image_slice, options.clone(), kernel_h, kernel_w, )?; } Ok(reshape( image, Shape::new([batch_size, im_channels, im_h, im_w]), )) } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device_dtype( input.client.clone(), input.device.clone(), im_shape, input.dtype, ); execute( input, weight, bias, image.clone(), options, kernel_h, kernel_w, )?; Ok(image) } } pub(crate) fn index(tensor: CubeTensor, i: usize) -> CubeTensor { #[allow(clippy::single_range_in_vec_init)] let mut indices = vec![i..i + 1]; for dim in tensor.meta.shape()[1..].iter() { indices.push(0..*dim); } let mut tensor = slice(tensor, &indices); tensor.meta.remove(0); tensor } #[allow(clippy::too_many_arguments)] fn execute( input: CubeTensor, weight: CubeTensor, bias: Option>, image: CubeTensor, options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, ) -> Result<(), ConvSetupError> { let [batch_size, _, input_h, input_w] = input.meta.shape().dims(); let [groups, col_shape_0, input_ch_per_group] = weight.meta.shape().dims(); let col_shape_1 = batch_size * input_h * input_w; let input = swap_dims(input, 0, 1); let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); let dtype = input.dtype; let columns = matmul(weight, input, None, MatmulStrategy::default(), dtype)?; let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im( columns, bias, image, kernel_h, kernel_w, input_h, input_w, options, )?; Ok(()) } #[allow(clippy::too_many_arguments)] fn col2im( columns: CubeTensor, bias: Option>, out: CubeTensor, kernel_h: usize, kernel_w: usize, out_h: usize, out_w: usize, options: ConvTransposeOptions<2>, ) -> Result<(), LaunchError> { let dtype = columns.dtype; let columns = into_contiguous_aligned(columns); let bias = bias.map(into_contiguous_aligned); let num_elems = out.meta.num_elements(); let cube_dim = CubeDim::new(&columns.client, num_elems); let cube_count = calculate_cube_count_elemwise(&columns.client, num_elems, cube_dim); let shape = shape_divmod(&out); unsafe { col2im_kernel::launch_unchecked( &columns.client.clone(), cube_count, cube_dim, address_type!(columns, bias, out), columns.into_tensor_arg(), bias.map(|bias| bias.into_tensor_arg()).into(), out.into_linear_view(), shape, Col2ImArgsLaunch::new( out_h, out_w, kernel_h, kernel_w, options.padding[0], options.padding[1], options.dilation[0], options.dilation[1], options.stride[0], options.stride[1], ), dtype.into(), ) }; Ok(()) } #[derive(CubeLaunch, CubeType)] struct Col2ImArgs { out_h: usize, out_w: usize, kernel_h: usize, kernel_w: usize, pad_h: usize, pad_w: usize, dilation_h: usize, dilation_w: usize, stride_h: usize, stride_w: usize, } #[cube(launch_unchecked, address_type = "dynamic")] fn col2im_kernel( columns: &Tensor, bias: &ComptimeOption>, image: &mut LinearView, image_shape: Sequence>, args: &Col2ImArgs, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= image.shape() { terminate!(); } let (_, pos) = decompose_linear(ABSOLUTE_POS, &image_shape); let [batch, ch_im, im_y, im_x] = *pos else { unreachable!() }; let im_x = im_x + args.pad_w; let im_y = im_y + args.pad_h; let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1; let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1; let mut val = E::zero(); let x_col_start = if im_x >= kernel_extent_w { (im_x - kernel_extent_w) / args.stride_w + 1 } else { 0usize.runtime() }; let x_col_end = clamp_max(im_x / args.stride_w + 1, args.out_w); let y_col_start = if im_y >= kernel_extent_h { (im_y - kernel_extent_h) / args.stride_h + 1 } else { 0usize.runtime() }; let y_col_end = clamp_max(im_y / args.stride_h + 1, args.out_h); for col_y in y_col_start..y_col_end { let kernel_y = im_y - col_y * args.stride_h; for col_x in x_col_start..x_col_end { let kernel_x = im_x - col_x * args.stride_w; if kernel_y.is_multiple_of(args.dilation_h) && kernel_x.is_multiple_of(args.dilation_w) { let kernel_y = kernel_y / args.dilation_h; let kernel_x = kernel_x / args.dilation_w; let col_k = ch_im * args.kernel_h * args.kernel_w + kernel_y * args.kernel_w + kernel_x; let col_n = batch * args.out_h * args.out_w + col_y * args.out_w + col_x; let col_pos = col_k * columns.stride(0) + col_n * columns.stride(1); val += columns[col_pos]; } } } #[comptime] match bias { ComptimeOption::Some(bias) => image[ABSOLUTE_POS] = val + bias[ch_im], ComptimeOption::None => image[ABSOLUTE_POS] = val, } } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose2d/mod.rs ================================================ mod base; mod col2im; mod transpose_direct; #[cfg(feature = "autotune")] mod tune; pub use base::*; pub use col2im::*; pub use transpose_direct::*; #[cfg(feature = "autotune")] pub use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, decompose_linear, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{Shape, ops::ConvTransposeOptions}; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; use cubek::convolution::components::ConvSetupError; #[derive(CubeLaunch, CubeType)] struct ConvArgs { conv_stride_0: usize, conv_stride_1: usize, dilation_0: usize, dilation_1: usize, padding_0: usize, padding_1: usize, groups: usize, } #[cube(launch, address_type = "dynamic")] fn conv_transpose2d_direct_kernel( input: &Tensor, weight: &Tensor, bias: &ComptimeOption>, output: &mut LinearView, out_shape: Sequence>, args: ConvArgs, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.shape() { terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; let out_c_per_group = weight.shape(1); let kernel_h = weight.shape(2); let kernel_w = weight.shape(3); let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape); let [batch, oc_out, out_y, out_x] = *pos else { unreachable!() }; let k = oc_out / out_c_per_group; let group = k % args.groups; let out_c = oc_out - out_c_per_group * group; let in_c_start = group * in_c_per_group; let in_c_end = in_c_start + in_c_per_group; let stride_0_i = args.conv_stride_0 as i32; let stride_1_i = args.conv_stride_1 as i32; let kms_h = (kernel_h * args.dilation_0) as i32 - stride_0_i; let kms_w = (kernel_w * args.dilation_1) as i32 - stride_1_i; let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i; let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i; let y_end = clamp(kms_h + y_start + 1, 0, input.shape(2) as i32) as usize; let x_end = clamp(kms_w + x_start + 1, 0, input.shape(3) as i32) as usize; let y_start = clamp_min(y_start, 0) as usize; let x_start = clamp_min(x_start, 0) as usize; let idx_input_batch = batch * input.stride(0); let idx_weight_oc = out_c * weight.stride(1); let bias: ComptimeOption = bias.map(|bias| bias[oc_out]); let mut sum = bias.unwrap_or_default(); let numerator_h_base = out_y + args.padding_0; let numerator_w_base = out_x + args.padding_1; for in_c in in_c_start..in_c_end { let idx_input_ic = in_c * input.stride(1); let idx_weight_ic = in_c * weight.stride(0); for in_y in y_start..y_end { let numerator_tmp = in_y * args.conv_stride_0; let numerator_h = numerator_h_base - numerator_tmp; if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_0) { let kernel_y = numerator_h / args.dilation_0; let idx_input_y = in_y * input.stride(2); let idx_weight_ky = kernel_y * weight.stride(2); for in_x in x_start..x_end { let numerator_tmp = in_x * args.conv_stride_1; let numerator_w = numerator_w_base - numerator_tmp; if numerator_w_base >= numerator_tmp && numerator_w.is_multiple_of(args.dilation_1) { let kernel_x = numerator_w / args.dilation_1; let idx_input_x = in_x * input.stride(3); let idx_weight_kx = kernel_x * weight.stride(3); let index_input = idx_input_batch + idx_input_ic + idx_input_y + idx_input_x; let index_weight = idx_weight_ic + idx_weight_oc + idx_weight_ky + idx_weight_kx; let value = input[index_input]; let weight = weight[index_weight]; sum += value * weight; } } } } } output[ABSOLUTE_POS] = sum; } /// Perform a 2D convolution transposition using the direct algorithm. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// pub fn conv_transpose2d_direct( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> Result, ConvSetupError> { let [batch_size, _, in_height, in_width] = input.meta.shape().dims(); let [_, out_channels, kernel_0, kernel_1] = weight.meta.shape().dims(); let out_0 = (in_height - 1) * options.stride[0] + options.dilation[0] * (kernel_0 - 1) + options.padding_out[0] - 2 * options.padding[0] + 1; let out_1 = (in_width - 1) * options.stride[1] + options.dilation[1] * (kernel_1 - 1) + options.padding_out[1] - 2 * options.padding[1] + 1; let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]); let output = empty_device_dtype( input.client.clone(), input.device.clone(), shape_out.clone(), input.dtype, ); let num_elems = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, num_elems); let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim); let dtype = input.dtype; conv_transpose2d_direct_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, weight, bias, output), input.into_tensor_arg(), weight.into_tensor_arg(), bias.map(|bias| bias.into_tensor_arg()).into(), output.clone().into_linear_view(), shape_divmod(&output), ConvArgsLaunch::new( options.stride[0], options.stride[1], options.dilation[0], options.dilation[1], options.padding[0], options.padding[1], options.groups, ), dtype.into(), ); Ok(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs ================================================ use burn_backend::ops::ConvTransposeOptions; use cubecl::tune::{LocalTuner, Tunable, TunableSet, local_tuner}; use crate::{ CubeAutotuneKey, CubeRuntime, CubeTuneId, kernel::conv::{ConvTranspose2dAutotuneKey, conv_transpose2d_col2im, conv_transpose2d_direct}, tensor::CubeTensor, }; /// Executes autotune on conv2d operations pub fn conv_transpose2d_autotune( input: CubeTensor, weights: CubeTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> CubeTensor { let client = input.client.clone(); static TUNER: LocalTuner = local_tuner!(); let tune_set = TUNER.init(|| { TunableSet::new(create_key::, create_transpose2d_input::) .with(Tunable::new( "conv_transpose2d_direct", conv_transpose2d_direct::, )) .with(Tunable::new( "conv_transpose2d_col2im", conv_transpose2d_col2im::, )) }); TUNER.execute( &CubeTuneId::new(&input.client, &input.device), &client, tune_set, (input, weights, bias, options), ) } pub fn create_transpose2d_input( _key: &CubeAutotuneKey, input: &CubeTensor, weights: &CubeTensor, bias: &Option>, options: &ConvTransposeOptions<2>, ) -> ( CubeTensor, CubeTensor, Option>, ConvTransposeOptions<2>, ) { ( input.clone(), weights.clone(), bias.clone(), options.clone(), ) } fn create_key( input: &CubeTensor, weights: &CubeTensor, bias: &Option>, options: &ConvTransposeOptions<2>, ) -> CubeAutotuneKey { let [batch_size, in_channels, height, width] = input.meta.shape().dims(); let [out_channels, _, kernel_h, kernel_w] = weights.meta.shape().dims(); let ConvTransposeOptions { stride, padding, dilation, groups, padding_out, } = options.clone(); CubeAutotuneKey::ConvTranspose(ConvTranspose2dAutotuneKey::new( [kernel_h, kernel_w], stride, padding, padding_out, dilation, groups, in_channels, out_channels, height, width, batch_size, bias.is_some(), input.dtype, )) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs ================================================ use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; use crate::{ CubeRuntime, kernel::utils::{address_type, decompose_linear, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{Shape, ops::ConvTransposeOptions}; #[derive(CubeLaunch, CubeType)] struct ConvArgs { conv_stride_0: usize, conv_stride_1: usize, conv_stride_2: usize, dilation_0: usize, dilation_1: usize, dilation_2: usize, padding_0: usize, padding_1: usize, padding_2: usize, groups: usize, } #[cube(launch, address_type = "dynamic")] fn conv_transpose3d_kernel( input: &Tensor, weight: &Tensor, bias: &ComptimeOption>, output: &mut LinearView, out_shape: Sequence>, args: ConvArgs, #[define(E)] _dtype: StorageType, ) { let in_channels = weight.shape(0); let out_c_per_group = weight.shape(1); let kernel_size_0 = weight.shape(2); let kernel_size_1 = weight.shape(3); let kernel_size_2 = weight.shape(4); let stride_0_i = args.conv_stride_0 as i32; let stride_1_i = args.conv_stride_1 as i32; let stride_2_i = args.conv_stride_2 as i32; let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape); let [batch, out_c_out, out_z, out_y, out_x] = *pos else { unreachable!() }; let groups = args.groups; let in_c_per_group = in_channels / groups; let k = out_c_out / out_c_per_group; let group = k % groups; let out_channel = out_c_out - out_c_per_group * group; let in_c_start = group * in_c_per_group; let in_c_end = in_c_start + in_c_per_group; let kernel_d = (kernel_size_0 * args.dilation_0 - args.conv_stride_0) as i32; let kernel_h = (kernel_size_1 * args.dilation_1 - args.conv_stride_1) as i32; let kernel_w = (kernel_size_2 * args.dilation_2 - args.conv_stride_2) as i32; let z_start = ((out_z + args.padding_0) as i32 - kernel_d) / stride_0_i; let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i; let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i; let z_end = clamp(kernel_d + z_start + 1, 0, input.shape(2) as i32) as usize; let y_end = clamp(kernel_h + y_start + 1, 0, input.shape(3) as i32) as usize; let x_end = clamp(kernel_w + x_start + 1, 0, input.shape(4) as i32) as usize; let z_start = clamp_min(z_start, 0) as usize; let y_start = clamp_min(y_start, 0) as usize; let x_start = clamp_min(x_start, 0) as usize; let index_input_batch = batch * input.stride(0); let index_weight_out_c = out_channel * weight.stride(1); let bias: ComptimeOption = bias.map(|bias| bias[out_c_out]); let mut sum = bias.unwrap_or_default(); let numerator_d_base = out_z + args.padding_0; let numerator_h_base = out_y + args.padding_1; let numerator_w_base = out_x + args.padding_2; for in_c in in_c_start..in_c_end { let index_input_in_c = in_c * input.stride(1); let index_weight_in_c = in_c * weight.stride(0); for in_z in z_start..z_end { let numerator_tmp = in_z * args.conv_stride_0; let numerator_d = numerator_d_base - numerator_tmp; if numerator_d_base >= numerator_tmp && numerator_d.is_multiple_of(args.dilation_0) { let kernel_z = numerator_d / args.dilation_0; let index_input_z = in_z * input.stride(2); let index_weight_kz = kernel_z * weight.stride(2); for in_y in y_start..y_end { let numerator_tmp = in_y * args.conv_stride_1; let numerator_h = numerator_h_base - numerator_tmp; if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_1) { let kernel_y = numerator_h / args.dilation_1; let index_input_y = in_y * input.stride(3); let index_weight_ky = kernel_y * weight.stride(3); for in_x in x_start..x_end { let numerator_tmp = in_x * args.conv_stride_2; let numerator_w = numerator_w_base - numerator_tmp; if numerator_w_base >= numerator_tmp && numerator_w.is_multiple_of(args.dilation_2) { let kernel_x = numerator_w / args.dilation_2; let index_input_x = in_x * input.stride(4); let index_weight_kx = kernel_x * weight.stride(4); let index_input = index_input_batch + index_input_in_c + index_input_z + index_input_y + index_input_x; let index_weight = index_weight_in_c + index_weight_out_c + index_weight_kz + index_weight_ky + index_weight_kx; let value = input[index_input]; let weight = weight[index_weight]; sum += value * weight; } } } } } } } output[ABSOLUTE_POS] = sum; } pub(crate) fn conv_transpose3d( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> Result, LaunchError> { let [batch_size, _, in_depth, in_height, in_width] = input.meta.shape().dims(); let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.meta.shape().dims(); let out_0 = (in_depth - 1) * options.stride[0] + options.dilation[0] * (kernel_0 - 1) + options.padding_out[0] - 2 * options.padding[0] + 1; let out_1 = (in_height - 1) * options.stride[1] + options.dilation[1] * (kernel_1 - 1) + options.padding_out[1] - 2 * options.padding[1] + 1; let out_2 = (in_width - 1) * options.stride[2] + options.dilation[2] * (kernel_2 - 1) + options.padding_out[2] - 2 * options.padding[2] + 1; let shape_out = Shape::new([ batch_size, out_channels * options.groups, out_0, out_1, out_2, ]); let output = empty_device_dtype( input.client.clone(), input.device.clone(), shape_out.clone(), input.dtype, ); let num_elems = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, num_elems); let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim); let dtype = input.dtype; conv_transpose3d_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, weight, bias, output), input.into_tensor_arg(), weight.into_tensor_arg(), bias.map(|bias| bias.into_tensor_arg()).into(), output.clone().into_linear_view(), shape_divmod(&output), ConvArgsLaunch::new( options.stride[0], options.stride[1], options.stride[2], options.dilation[0], options.dilation[1], options.dilation[2], options.padding[0], options.padding[1], options.padding[2], options.groups, ), dtype.into(), ); Ok(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs ================================================ use cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod}; use cubek::convolution::components::ConvSetupError; use burn_backend::{ Shape, ops::{DeformConvOptions, conv::calculate_conv_output_size}, }; use crate::{ CubeRuntime, kernel::{ AddOp, into_contiguous_aligned, launch_binop, matmul::{MatmulStrategy, matmul}, utils::address_type, }, ops::{numeric::zeros_client, reshape, swap_dims}, tensor::CubeTensor, }; #[derive(CubeLaunch, CubeType)] struct DeformConv2dArgs { conv_stride_h: usize, conv_stride_w: usize, dilation_h: usize, dilation_w: usize, padding_h: InputScalar, padding_w: InputScalar, offset_groups: usize, kernel_height: usize, kernel_width: usize, out_h: usize, out_w: usize, } #[cube(launch, address_type = "dynamic")] fn deform_im2col_kernel( input: &Tensor, offset: &Tensor, mask: &ComptimeOption>, columns: &mut Tensor, pos_shape: Sequence>, args: &DeformConv2dArgs, #[comptime] kernel_h_unroll: Option, #[comptime] kernel_w_unroll: Option, #[define(F)] _dtype: StorageType, ) { // position shape: [in_channels, batch_size, out_h, out_w] // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height); let unroll_h = kernel_h_unroll.is_some(); let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width); let unroll_w = kernel_w_unroll.is_some(); let out_h = args.out_h; let out_w = args.out_w; let in_channels = input.shape(1); let height = input.shape(2); let width = input.shape(3); let col_stride_0 = columns.stride(0); let (rem, out_x) = pos_shape[3].div_mod(ABSOLUTE_POS); let (rem, out_y) = pos_shape[2].div_mod(rem); let (in_channel, batch) = pos_shape[1].div_mod(rem); if in_channel >= in_channels { terminate!() } let out_k_base = in_channel * kernel_height * kernel_width; let out_n = batch * out_h * out_w + out_y * out_w + out_x; let channels_per_offset_group = in_channels / args.offset_groups; let group_index = in_channel / channels_per_offset_group; let mut col_base_idx = out_k_base * columns.stride(0) + out_n * columns.stride(1); let input_base_idx = batch * input.stride(0) + in_channel * input.stride(1); let offset_base_idx = batch * offset.stride(0) + group_index * kernel_height * kernel_width * 2 * offset.stride(1); let mask_base_idx = mask.as_ref().map(|mask| { batch * mask.stride(0) + group_index * kernel_height * kernel_width * mask.stride(1) }); #[unroll(unroll_h)] for kernel_y in 0..kernel_height { #[unroll(unroll_w)] for kernel_x in 0..kernel_width { let mask_index = kernel_y * kernel_width + kernel_x; let offset_index = mask_index * 2; let offset_y = offset[offset_base_idx + offset_index * offset.stride(1) + out_y * offset.stride(2) + out_x * offset.stride(3)]; let offset_x = offset[offset_base_idx + (offset_index + 1) * offset.stride(1) + out_y * offset.stride(2) + out_x * offset.stride(3)]; let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h) - args.padding_h.get::() + offset_y; let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w) - args.padding_w.get::() + offset_x; let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx); #[comptime] let value = match mask.zip::(mask_base_idx) { ComptimeOption::Some((mask, base_idx)) => { let mask_value = mask[base_idx + mask_index * mask.stride(1) + out_y * mask.stride(2) + out_x * mask.stride(3)]; mask_value * interpolated } ComptimeOption::None => interpolated, }; columns[col_base_idx] = value; col_base_idx += col_stride_0; } } } #[cube] pub(crate) fn bilinear_interpolate( input: &Tensor, height: usize, width: usize, y: F, x: F, offset: usize, ) -> F { // To simplify code let y = f32::cast_from(y); let x = f32::cast_from(x); let stride_y = input.stride(2); let stride_x = input.stride(3); let mut result = F::new(0.0); if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { let y_low = y.floor(); let x_low = x.floor(); let y_high = (y_low + 1.) as usize; let x_high = (x_low + 1.) as usize; let zero = F::new(0.0); let v1: F = if y_low >= 0. && x_low >= 0. { input[offset + y_low as usize * stride_y + x_low as usize * stride_x] } else { zero }; let v2: F = if y_low >= 0. && x_high < width { input[offset + y_low as usize * stride_y + x_high * stride_x] } else { zero }; let v3: F = if y_high < height && x_low >= 0. { input[offset + y_high * stride_y + x_low as usize * stride_x] } else { zero }; let v4: F = if y_high < height && x_high < width { input[offset + y_high * stride_y + x_high * stride_x] } else { zero }; let l_y = y - y_low; let l_x = x - x_low; let h_y = 1.0 - l_y; let h_x = 1.0 - l_x; let w1 = F::cast_from(h_y * h_x); let w2 = F::cast_from(h_y * l_x); let w3 = F::cast_from(l_y * h_x); let w4 = F::cast_from(l_y * l_x); result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; } result } pub(crate) fn deform_im2col( input: CubeTensor, offset: CubeTensor, mask: Option>, options: DeformConvOptions<2>, out_dims: (usize, usize), kernel_dims: (usize, usize), ) -> Result, LaunchError> { let client = input.client.clone(); let device = input.device.clone(); let dtype = input.dtype; let [batch_size, in_channels, _, _] = input.meta.shape().dims(); let (out_height, out_width) = out_dims; let (kernel_height, kernel_width) = kernel_dims; let shape_out = Shape::new([ in_channels * kernel_height * kernel_width, batch_size * out_height * out_width, ]); let pos_shape = [in_channels, batch_size, out_height, out_width] .into_iter() .collect(); let output = zeros_client(client.clone(), device.clone(), shape_out.clone(), dtype); let num_kernels = in_channels * batch_size * out_height * out_width; let cube_dim = CubeDim::new(&input.client, num_kernels); let cube_count = calculate_cube_count_elemwise(&input.client, num_kernels, cube_dim); deform_im2col_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, offset, mask, output), input.into_tensor_arg(), offset.into_tensor_arg(), mask.map(|mask| mask.into_tensor_arg()).into(), output.clone().binding().into_tensor_arg(), pos_shape, DeformConv2dArgsLaunch::new( options.stride[0], options.stride[1], options.dilation[0], options.dilation[1], { let val = options.padding[0] as f32; InputScalar::new(val, dtype) }, { let val = options.padding[1] as f32; InputScalar::new(val, dtype) }, options.offset_groups, kernel_height, kernel_width, out_height, out_width, ), Some(kernel_height), Some(kernel_width), dtype.into(), ); Ok(output) } pub(crate) fn deform_conv2d( input: CubeTensor, offset: CubeTensor, weight: CubeTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> Result, ConvSetupError> { let input = into_contiguous_aligned(input); let offset = into_contiguous_aligned(offset); let weight = into_contiguous_aligned(weight); let mask = mask.map(|it| into_contiguous_aligned(it)); let bias = bias.map(|it| into_contiguous_aligned(it)); let [batch_size, _, in_height, in_width] = input.meta.shape().dims(); let [out_channels, _, kernel_h, kernel_w] = weight.meta.shape().dims(); let groups = options.weight_groups; let out_h = calculate_conv_output_size( kernel_h, options.stride[0], options.padding[0], options.dilation[0], in_height, ); let out_w = calculate_conv_output_size( kernel_w, options.stride[1], options.padding[1], options.dilation[1], in_width, ); let out_dims = (out_h, out_w); let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w))?; let [col_size_0, col_size_1] = columns.meta.shape().dims(); let col_size_0 = col_size_0 / groups; let out_c_per_group = out_channels / groups; let dtype = weight.dtype; let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let out = matmul(weight, columns, None, MatmulStrategy::default(), dtype)?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); Ok(launch_binop::(out, bias)) } else { Ok(out) } } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs ================================================ use super::{bilinear_interpolate, deform_im2col, index}; use crate::{ CubeRuntime, kernel::{ cast, into_contiguous_aligned, matmul::{MatmulStrategy, matmul}, reduce::reduce_dim, slice_assign, utils::{address_type, decompose_linear}, }, ops::{ numeric::{empty_device_dtype, zeros_client}, reshape, swap_dims, }, tensor::CubeTensor, }; use burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions}; use cubecl::{ CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube, features::TypeUsage, ir::FloatKind, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; use cubek::{ convolution::components::ConvSetupError, reduce::components::instructions::ReduceOperationConfig, }; use std::marker::PhantomData; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow( clippy::single_range_in_vec_init, clippy::type_complexity, clippy::too_many_arguments )] pub(crate) fn deform_conv2d_backward( input: CubeTensor, offset: CubeTensor, weight: CubeTensor, mask: Option>, bias: Option>, out_grad: CubeTensor, options: DeformConvOptions<2>, ) -> Result< ( CubeTensor, CubeTensor, CubeTensor, Option>, Option>, ), ConvSetupError, > { let [_, _, out_h, out_w] = out_grad.meta.shape().dims(); let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims(); let gradient_bias = bias.map(|bias| { let grad = reduce_dim( out_grad.clone(), None, 0, Default::default(), ReduceOperationConfig::Sum, ) .unwrap(); let grad = reduce_dim( grad, None, 2, Default::default(), ReduceOperationConfig::Sum, ) .unwrap(); let grad = reduce_dim( grad, None, 3, Default::default(), ReduceOperationConfig::Sum, ) .unwrap(); reshape(grad, bias.meta.shape.clone()) }); let input = into_contiguous_aligned(input); let offset = into_contiguous_aligned(offset); let weight = into_contiguous_aligned(weight); let mask = mask.map(|it| into_contiguous_aligned(it)); let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs( input.clone(), weight.clone(), offset.clone(), mask.clone(), out_grad.clone(), &options, (kernel_h, kernel_w), )?; let weight_grad = compute_weight_grad( input, offset, mask, out_grad, options, (kernel_h, kernel_w), (out_h, out_w), )?; Ok(( input_gradient, offset_gradient, weight_grad, mask_gradient, gradient_bias, )) } fn compute_weight_grad( input: CubeTensor, offset: CubeTensor, mask: Option>, out_grad: CubeTensor, options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), ) -> Result, ConvSetupError> { let [_, in_channels, _, _] = input.meta.shape().dims(); let [_, out_channels, _, _] = out_grad.meta.shape().dims(); let (kernel_h, kernel_w) = kernel_dims; let groups = options.weight_groups; let dtype = input.dtype; let in_c_per_group = in_channels / groups; let out_c_per_group = out_channels / groups; let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?; let [col_size_0, col_size_1] = columns.meta.shape().dims(); let col_size_0 = col_size_0 / groups; let out_grad = swap_dims(out_grad, 0, 1); let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); let grad_weight = matmul(out_grad, columns, None, MatmulStrategy::default(), dtype)?; Ok(reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), )) } type InputGradients = (CubeTensor, CubeTensor, Option>); fn backward_gradient_inputs( image: CubeTensor, weight: CubeTensor, offset: CubeTensor, mask: Option>, out_grad: CubeTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), ) -> Result, ConvSetupError> { let client = out_grad.client.clone(); let device = out_grad.device.clone(); let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims(); let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims(); let groups = options.weight_groups; let out_c_per_group = out_channels / groups; let col_shape_0 = in_c_per_group * kernel_h * kernel_w; let col_shape_1 = batch_size * out_h * out_w; let col_shape = Shape::new([groups, col_shape_0, col_shape_1]); let mut columns = empty_device_dtype(client, device, col_shape, weight.dtype); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); let out_grad = swap_dims(out_grad, 0, 1); let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]); let out_grad = reshape(out_grad, out_grad_shape); for group in 0..groups { let dtype = weight.dtype; let weight = swap_dims(index(weight.clone(), group), 0, 1); let out_grad = index(out_grad.clone(), group); let values = matmul(weight, out_grad, None, MatmulStrategy::default(), dtype)?; let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign( columns, &[ burn_backend::Slice::from(group..group + 1), burn_backend::Slice::from(0..col_shape_0), burn_backend::Slice::from(0..col_shape_1), ], values, ); } let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); let input_shape = image.shape(); let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( columns.clone(), image, offset.clone(), mask.clone(), options, kernel_dims, )?; let input_gradient = compute_input_grad(columns, offset, mask, options, kernel_dims, input_shape)?; Ok((input_gradient, offset_gradient, mask_gradient)) } fn compute_offset_and_mask_gradient( columns: CubeTensor, image: CubeTensor, offset: CubeTensor, mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), ) -> Result<(CubeTensor, Option>), ConvSetupError> { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_h, kernel_w) = kernel_dims; let [batches, _, out_h, out_w] = offset.meta.shape().dims(); let offset_groups = options.offset_groups; let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w]; let pos_shape = pos_shape.into_iter().collect(); let grad_offset = empty_device_dtype(client.clone(), device.clone(), offset.shape(), offset.dtype); let grad_mask = mask .as_ref() .map(|mask| empty_device_dtype(client.clone(), device.clone(), mask.shape(), mask.dtype)); let num_elements_offset = offset.meta.num_elements(); let cube_dim = CubeDim::new(&image.client, num_elements_offset); let cube_count = calculate_cube_count_elemwise(&image.client, num_elements_offset, cube_dim); let dtype: StorageType = image.dtype.into(); unsafe { deform_col2img_coord_kernel::launch_unchecked( &grad_offset.client, cube_count, cube_dim, address_type!(image, offset, mask, grad_offset, grad_mask), image.into_tensor_arg(), offset.into_tensor_arg(), mask.map(|mask| mask.into_tensor_arg()).into(), columns.into_tensor_arg(), grad_offset.clone().into_linear_view(), grad_mask .clone() .map(|grad_mask| grad_mask.into_tensor_arg()) .into(), pos_shape, DeformConv2dCol2ImgCoordArgsLaunch::new( options.stride[0], options.stride[1], options.dilation[0], options.dilation[1], InputScalar::new(options.padding[0] as f32, dtype.elem_type()), InputScalar::new(options.padding[1] as f32, dtype.elem_type()), offset_groups, kernel_h, kernel_w, ), dtype, ) }; Ok((grad_offset, grad_mask)) } #[derive(CubeLaunch, CubeType)] struct DeformConv2dCol2ImgCoordArgs { stride_h: usize, stride_w: usize, dilation_h: usize, dilation_w: usize, pad_h: InputScalar, pad_w: InputScalar, offset_groups: usize, kernel_height: usize, kernel_width: usize, } #[allow(clippy::collapsible_if)] #[cube(launch_unchecked, address_type = "dynamic")] fn deform_col2img_coord_kernel( image: &Tensor, offset: &Tensor, mask: &ComptimeOption>, columns: &Tensor, grad_offset: &mut LinearView, grad_mask: &mut ComptimeOption>, pos_shape: Sequence>, args: &DeformConv2dCol2ImgCoordArgs, #[define(F)] _dtype: StorageType, ) { // Position format: [batch, [offset_groups, kernel_h, kernel_w, 2], out_h, out_w] // Columns format: [[in_channel, kernel_h, kernel_w], [batch, out_h, out_w]] // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.shape() { terminate!(); } let out_h = offset.shape(2); let out_w = offset.shape(3); let in_channels = image.shape(1); let height = image.shape(2); let width = image.shape(3); let kernel_w = args.kernel_width; let kernel_h = args.kernel_height; let mut grad_offset_val = F::new(0.0); let mut grad_mask_val = F::new(0.0); let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape); let [batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x] = *pos else { unreachable!() }; let channels_per_offset_group = in_channels / args.offset_groups; let col_n = batch * out_h * out_w + out_y * out_w + out_x; let col_base_idx = offset_group * channels_per_offset_group * kernel_h * kernel_w * columns.stride(0) + col_n * columns.stride(1); let mut image_base_idx = batch * image.stride(0) + offset_group * channels_per_offset_group * image.stride(1); let offset_pos_1 = offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2; let offset_base_idx = batch * offset.stride(0) + offset_pos_1 * offset.stride(1) + out_y * offset.stride(2) + out_x * offset.stride(3); let offset_y_idx = offset_base_idx; let offset_x_idx = offset_base_idx + offset.stride(1); let offset_y = offset[offset_y_idx]; let offset_x = offset[offset_x_idx]; let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x; #[comptime] let mask_value = match &mask { ComptimeOption::Some(mask) => { let mask_idx = batch * mask.stride(0) + mask_pos_1 * mask.stride(1) + out_y * mask.stride(2) + out_x * mask.stride(3); mask[mask_idx] } ComptimeOption::None => F::new(1.0), }; let is_y_direction = dir == 0; for col_c in 0..channels_per_offset_group { let col_pos = col_base_idx + col_c * kernel_h * kernel_w * columns.stride(0); let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h.get::() + offset_y; let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w.get::() + offset_x; let weight = get_coordinate_weight(image, image_base_idx, height, width, y, x, is_y_direction); let columns_value = columns[col_pos]; grad_offset_val += mask_value * weight * columns_value; if grad_mask.is_some() && is_y_direction { grad_mask_val += columns_value * bilinear_interpolate(image, height, width, y, x, image_base_idx); } image_base_idx += image.stride(1); } grad_offset[ABSOLUTE_POS] = grad_offset_val; #[comptime] if let ComptimeOption::Some(grad_mask) = grad_mask { if is_y_direction { let idx = batch * grad_mask.stride(0) + mask_pos_1 * grad_mask.stride(1) + out_y * grad_mask.stride(2) + out_x * grad_mask.stride(3); grad_mask[idx] = grad_mask_val } } } #[cube] fn get_coordinate_weight( input: &Tensor, offset: usize, height: usize, width: usize, y: F, x: F, is_y_direction: bool, ) -> F { let stride_y = input.stride(2); let stride_x = input.stride(3); let y = f32::cast_from(y); let x = f32::cast_from(x); let y_low = f32::floor(y); let x_low = f32::floor(x); let y_high = y_low + 1.; let x_high = x_low + 1.; let valid_y_low = y_low >= 0. && y_low < height as f32; let valid_y_high = y_high >= 0. && y_high < height as f32; let valid_x_low = x_low >= 0. && x_low < width as f32; let valid_x_high = x_high >= 0. && x_high < width as f32; let bottom_left = if valid_y_low && valid_x_low { input[offset + y_low as usize * stride_y + x_low as usize * stride_x] } else { F::new(0.0) }; let bottom_right = if valid_y_low && valid_x_high { input[offset + y_low as usize * stride_y + x_high as usize * stride_x] } else { F::new(0.0) }; let top_left = if valid_y_high && valid_x_low { input[offset + y_high as usize * stride_y + x_low as usize * stride_x] } else { F::new(0.0) }; let top_right = if valid_y_high && valid_x_high { input[offset + y_high as usize * stride_y + x_high as usize * stride_x] } else { F::new(0.0) }; if is_y_direction { let delta_x = F::cast_from(x - x_low); delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left) } else { let delta_y = F::cast_from(y - y_low); delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left) } } fn compute_input_grad( columns: CubeTensor, offset: CubeTensor, mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), input_shape: Shape, ) -> Result, LaunchError> { let client = offset.client.clone(); let device = offset.device.clone(); let supports_fadd = client .properties() .type_usage(StorageType::Atomic(FloatKind::F32.into())) .contains(TypeUsage::AtomicAdd); let supports_same_type = client .properties() .type_usage(StorageType::Atomic(columns.dtype.into())) .contains(TypeUsage::AtomicAdd); let [batches, in_channels, height, width] = input_shape.dims(); let [_, _, out_h, out_w] = offset.meta.shape().dims(); let (kernel_h, kernel_w) = kernel_dims; let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w]; let pos_shape = pos_shape.into_iter().collect(); let shape = Shape::new([batches, in_channels, height, width]); let grad_in = match supports_fadd && supports_same_type { // Use type as is to save a cast true => zeros_client(client.clone(), device.clone(), shape, columns.dtype), // Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported false => zeros_client(client.clone(), device.clone(), shape, DType::F32), }; let grad_arg = grad_in.clone().into_tensor_arg(); let num_elements = columns.meta.num_elements(); let cube_dim = CubeDim::new(&offset.client, num_elements); let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim); let launch = match supports_fadd { true => deform_col2img_kernel::launch_unchecked::, false => deform_col2img_kernel::launch_unchecked::, }; let dtype = offset.dtype; let dtypes: [StorageType; 2] = match supports_same_type { true => [dtype.into(), dtype.into()], false => [dtype.into(), DType::F32.into()], }; unsafe { launch( &grad_in.client, cube_count, cube_dim, address_type!(offset, mask, columns, grad_in), offset.into_tensor_arg(), mask.map(|mask| mask.into_tensor_arg()).into(), columns.into_linear_view(), grad_arg, pos_shape, DeformConv2dCol2ImgArgsLaunch::new( options.stride[0], options.stride[1], options.dilation[0], options.dilation[1], InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()), InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()), options.offset_groups, kernel_h, kernel_w, ), dtypes, ) }; Ok(if !supports_same_type || !supports_fadd { cast(grad_in, dtype) } else { grad_in }) } #[derive(CubeLaunch, CubeType)] struct DeformConv2dCol2ImgArgs { stride_h: usize, stride_w: usize, dilation_h: usize, dilation_w: usize, pad_h: InputScalar, pad_w: InputScalar, offset_groups: usize, kernel_height: usize, kernel_width: usize, } #[cube(launch_unchecked, address_type = "dynamic")] fn deform_col2img_kernel( offset: &Tensor, mask: &ComptimeOption>, columns: &LinearView, grad_input: &mut Tensor>>, pos_shape: Sequence>, args: &DeformConv2dCol2ImgArgs, #[define(F, FP)] _dtype: [StorageType; 2], ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] if ABSOLUTE_POS >= columns.shape() { terminate!(); } let n_in_channels = grad_input.shape(1); let height = grad_input.shape(2); let width = grad_input.shape(3); let kernel_h = args.kernel_height; let kernel_w = args.kernel_width; let n_offset_groups = args.offset_groups; let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape); let [in_channel, kernel_y, kernel_x, batch, out_y, out_x] = *pos else { unreachable!() }; let channels_per_offset_group = n_in_channels / n_offset_groups; let offset_group = in_channel / channels_per_offset_group; let offset_pos_1 = offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2; let offset_base_idx = batch * offset.stride(0) + offset_pos_1 * offset.stride(1) + out_y * offset.stride(2) + out_x * offset.stride(3); let offset_y_idx = offset_base_idx; let offset_x_idx = offset_base_idx + offset.stride(1); let offset_y = offset[offset_y_idx]; let offset_x = offset[offset_x_idx]; #[comptime] let mask_value = match mask { ComptimeOption::Some(mask) => { let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x; mask[batch * mask.stride(0) + mask_pos_1 * mask.stride(1) + out_y * mask.stride(2) + out_x * mask.stride(3)] } ComptimeOption::None => F::new(1.0), }; let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h.get::() + offset_y; let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w.get::() + offset_x; for dy in -1..=1i32 { #[unroll] for dx in -1..=1i32 { let yp = y.floor() + F::cast_from(dy); let xp = x.floor() + F::cast_from(dx); if yp >= F::new(0.0) && yp < F::cast_from(height) && xp >= F::new(0.0) && xp < F::cast_from(width) && F::abs(y - yp) < F::new(1.0) && F::abs(x - xp) < F::new(1.0) { let gradient_pos = batch * grad_input.stride(0) + in_channel * grad_input.stride(1) + usize::cast_from(yp) * grad_input.stride(2) + usize::cast_from(xp) * grad_input.stride(3); let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp)); let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS]; FAdd::Op::::float_atomic_add::(&mut grad_input[gradient_pos], value); } } } } type ProxyType = <::Op as FloatAtomicAdd>::ProxyType; #[cube] trait FloatAtomicAddFamily: Send + Sync + 'static { type Op: FloatAtomicAdd; } #[cube] trait FloatAtomicAdd: Send + Sync + 'static { type ProxyType: Numeric; fn float_atomic_add(ptr: &mut Atomic, value: F); } #[derive(CubeType)] struct IntrinsicFloatAtomicAdd { #[cube(comptime)] _ty: PhantomData, } #[derive(CubeType)] struct CASFloatAtomicAdd; struct IntrinsicFloatAtomicAddFamily; impl FloatAtomicAddFamily for IntrinsicFloatAtomicAddFamily { type Op = IntrinsicFloatAtomicAdd; } impl FloatAtomicAddFamily for CASFloatAtomicAdd { type Op = Self; } #[cube] impl FloatAtomicAdd for IntrinsicFloatAtomicAdd { type ProxyType = FAdd; fn float_atomic_add(ptr: &mut Atomic, value: F) { let value = FAdd::cast_from(value); ptr.fetch_add(value); } } #[cube] impl FloatAtomicAdd for CASFloatAtomicAdd { type ProxyType = u32; fn float_atomic_add(ptr: &mut Atomic, value: F) { let value = f32::cast_from(value); if value != 0.0 { let mut v = ptr.load(); loop { let prev = v; let v_float = f32::from_bits(v); let new = (v_float + value).to_bits(); v = ptr.compare_exchange_weak(v, new); if prev == v { break; } } } } } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/direct.rs ================================================ use crate::{ CubeRuntime, kernel::{into_contiguous_aligned, utils::address_type}, ops::max_vector_size, tensor::CubeTensor, }; use crate::{kernel::utils::decompose_linear, ops::numeric::empty_device_dtype}; use burn_backend::{ TensorMetadata, ops::{ConvOptions, conv::calculate_conv_output_sizes}, }; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView, tensor_vector_size_parallel, }; use cubecl::{num_traits::Zero, std::FastDivmod}; use cubek::convolution::components::ConvSetupError; #[derive(CubeLaunch, CubeType, Clone)] pub(crate) struct ConvParam { pub stride: u32, pub dilation: u32, pub padding: i32, } #[derive(CubeLaunch, CubeType)] struct Conv2dArgs { conv_params: Sequence, channels_per_group: u32, } #[cube(launch_unchecked, address_type = "dynamic")] #[allow(clippy::redundant_closure)] fn direct_conv2d_kernel( input: &Tensor>, weight: &Tensor>, bias: ComptimeOption>>, output: &mut LinearView, ReadWrite>, args: Conv2dArgs, shape_out: Sequence>, shape_out_c: FastDivmod, #[comptime] has_padding: bool, #[define(E)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let n_spatial = comptime![shape_out.len()]; let vector_size_out = output.vector_size(); let pos = ABSOLUTE_POS * vector_size_out; let in_c_per_group = weight.shape(weight.rank() - 1) as u32; let (rem, out_c) = shape_out_c.div_mod(pos as u32); let (b, spatial_pos) = decompose_linear(rem, &shape_out); let g = out_c / args.channels_per_group; let ic_start = in_c_per_group * g; let bias: ComptimeOption> = bias.map(|bias| bias[out_c as usize / vector_size_out]); let mut sum = bias.unwrap_or_else(|| Vector::zero()); let in_offs = b as usize * input.stride(0) + ic_start as usize; let stride_oc = weight.stride(0); let mut in_shape = Sequence::new(); let mut in_strides = Sequence::new(); let mut kernel_shape = Sequence::new(); let mut kernel_strides = Sequence::new(); #[unroll] for i in 0..n_spatial { in_shape.push(input.shape(i + 1) as u32); in_strides.push(input.stride(i + 1)); kernel_shape.push(weight.shape(i + 1) as u32); kernel_strides.push(weight.stride(i + 1)); } let weight_offs = out_c as usize * stride_oc; let loop_params = LoopParams { out_pos: spatial_pos, in_shape, in_strides, kernel_shape, kernel_strides, conv_params: args.conv_params, in_c_per_group, stride_oc, }; kernel_loop( input, weight, &mut sum, in_offs, true, weight_offs, &loop_params, 0usize, has_padding, ); output[ABSOLUTE_POS] = sum; } #[derive(CubeType, Clone)] struct LoopParams { out_pos: Sequence, in_shape: Sequence, in_strides: Sequence, kernel_shape: Sequence, kernel_strides: Sequence, conv_params: Sequence, in_c_per_group: u32, stride_oc: usize, } #[cube] fn kernel_loop( input: &Tensor>, weight: &Tensor>, sum: &mut Vector, in_offs: usize, in_bounds: bool, weight_offs: usize, params: &LoopParams, #[comptime] kernel_dim: usize, #[comptime] has_padding: bool, ) { if comptime![kernel_dim < params.kernel_shape.len()] { let out_idx = *params.out_pos.index(kernel_dim); let conv = params.conv_params.index(kernel_dim); let shape = *params.in_shape.index(kernel_dim); let stride = *params.in_strides.index(kernel_dim); let k_stride = *params.kernel_strides.index(kernel_dim); for pos in 0..*params.kernel_shape.index(kernel_dim) { let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding; let in_offs = in_offs + in_pos as usize * stride; let weight_offs = weight_offs + pos as usize * k_stride; let mut in_bounds = in_bounds; if has_padding { in_bounds &= in_pos >= 0 && (in_pos as u32) < shape; } kernel_loop( input, weight, sum, in_offs, in_bounds, weight_offs, params, comptime![kernel_dim + 1], has_padding, ); } } else { kernel_loop_inner( input, weight, sum, in_offs, in_bounds, weight_offs, params.in_c_per_group, params.stride_oc, ); } } #[cube] fn kernel_loop_inner( input: &Tensor>, weight: &Tensor>, sum: &mut Vector, in_offs: usize, in_bounds: bool, weight_offs: usize, in_c_per_group: u32, stride_oc: usize, ) { let vector_size_in = input.vector_size(); let vector_size_out = sum.size(); if in_bounds { for in_c in range_stepped(0, in_c_per_group, vector_size_in as u32) { let in_pos = in_offs + in_c as usize; let mut weight_pos = weight_offs + in_c as usize; let val = input[in_pos / vector_size_in]; #[unroll] for v in 0..vector_size_out { let weight = weight[weight_pos / vector_size_in]; let val = val * weight; #[unroll] for i in 0..vector_size_in { sum[v] += val[i]; } weight_pos += stride_oc; } } } } /// Perform a 2D convolution using the direct convolution algorithm. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// pub fn conv_direct( mut input: CubeTensor, mut weight: CubeTensor, bias: Option>, options: ConvOptions, ) -> Result, ConvSetupError> { let out_dtype = input.dtype; let rank = input.meta.shape().num_dims(); let dim_c = rank - 1; // We only care about the channels here, everything else can be permuted if input.meta.strides()[dim_c] != 1 { input = into_contiguous_aligned(input); } if weight.meta.strides()[dim_c] != 1 { weight = into_contiguous_aligned(weight); } let batch_size = input.meta.shape()[0]; let in_shape = &input.meta.shape()[1..dim_c]; let out_channels = weight.meta.shape()[0]; let kernel_shape = &weight.meta.shape()[1..dim_c]; let channels_per_group = out_channels / options.groups; let out_size = calculate_conv_output_sizes( kernel_shape, &options.stride, &options.padding, &options.dilation, in_shape, ); let mut shape_out = vec![batch_size]; shape_out.extend(out_size.iter().copied()); shape_out.push(out_channels); let output = empty_device_dtype( input.client.clone(), input.device.clone(), shape_out.into(), out_dtype, ); // Need custom vector size calculation here to account for the groups division. Need to vectorize // over `channels_per_group` instead. let mut grouped_out_shape = output.shape(); grouped_out_shape[dim_c] = channels_per_group; let vector_size_out = tensor_vector_size_parallel( input.client.io_optimized_vector_sizes(input.dtype.size()), &grouped_out_shape, output.meta.strides(), dim_c, ); // Use channels_per_group instead of in_channels to avoid issues here let vector_size_in = max_vector_size(&weight); let shape_out = output.meta.shape()[1..dim_c] .iter() .map(|s| *s as u32) .collect(); let shape_out_c = out_channels as u32; let mut conv_params = SequenceArg::new(); for i in 0..kernel_shape.len() { conv_params.push(ConvParamLaunch::new( options.stride[i] as u32, options.dilation[i] as u32, options.padding[i] as i32, )); } let working_units = output.meta.num_elements() / vector_size_out; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); unsafe { direct_conv2d_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(input, weight, bias, output), vector_size_in, vector_size_out, input.into_tensor_arg(), weight.into_tensor_arg(), bias.map(|b| b.into_tensor_arg()).into(), output.clone().into_linear_view(), Conv2dArgsLaunch::new(conv_params, channels_per_group as u32), shape_out, shape_out_c, options.padding.iter().any(|it| *it != 0), out_dtype.into(), ) }; Ok(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs ================================================ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes}; use cubek::{ convolution::{ AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, components::ConvSetupError, forward, }, matmul::{ definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputBinding, }, }; /// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul /// components. Uses [`CmmaLargeMAlgorithm`] for the stage size /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution pub fn conv_gemm_simple_sync( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic, AcceleratedTileKind::Mma => ReadingStrategy::Strided, }; launch_convolution_forward::( &Strategy::Simple { read_strategy, tile_kind, }, input, weight, bias, options, ) } pub fn conv_gemm_simple_async( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { let read_strategy = match tile_kind { AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic, AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided, }; launch_convolution_forward::( &Strategy::Simple { read_strategy, tile_kind, }, input, weight, bias, options, ) } /// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul /// components. Uses [`CmmaLargeMAlgorithm`] for the stage size /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution pub fn conv_gemm_simple_tma( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, tile_kind: AcceleratedTileKind, ) -> Result, ConvSetupError> { launch_convolution_forward::( &Strategy::Simple { read_strategy: ReadingStrategy::Tma, tile_kind, }, input, weight, bias, options, ) } /// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul /// components, using the specified algorithm. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution pub fn launch_convolution_forward( strategy: &Strategy, input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, ) -> Result, ConvSetupError> { if options.groups != 1 { return Err(ConvSetupError::Groups(options.groups)); } let out_dtype = input.dtype; let rank = input.meta.shape().num_dims(); let batch_size = input.meta.shape()[0]; let dim_c = rank - 1; let shape = &input.meta.shape()[1..dim_c]; let out_channels = weight.meta.shape()[0]; let weight_shape = &weight.meta.shape()[1..dim_c]; let mut out_shape = calculate_conv_output_sizes( weight_shape, &options.stride, &options.padding, &options.dilation, shape, ); out_shape.insert(0, batch_size); out_shape.push(out_channels); let out = empty_device_dtype( input.client.clone(), input.device.clone(), out_shape.into(), out_dtype, ); let bias = bias.map(|bias| { let dtype = bias.dtype; MatmulInputBinding::Normal(bias.binding(), dtype.into()) }); let client = input.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { lhs: input.dtype.into(), rhs: weight.dtype.into(), out: out_dtype.into(), }); let input_dtype = input.dtype; let weight_dtype = weight.dtype; let input = MatmulInputBinding::new(input.binding(), input_dtype.into()); let weight = MatmulInputBinding::new(weight.binding(), weight_dtype.into()); forward::launch_ref::( strategy, &client, input, weight, bias, out.clone().binding(), ConvolutionArgs { stride: options.stride, padding: options.padding, dilation: options.dilation, }, dtypes, )?; Ok(out) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/mod.rs ================================================ pub mod launch; pub use launch::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/forward/mod.rs ================================================ pub mod implicit_gemm; #[cfg(feature = "autotune")] pub mod tune; #[cfg(feature = "autotune")] pub(crate) use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/forward/tune.rs ================================================ use burn_backend::ops::ConvOptions; use cubecl::{ ir::StorageType, tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner}, }; use cubek::convolution::AcceleratedTileKind; use crate::{ CubeAutotuneKey, CubeRuntime, CubeTuneId, kernel::conv::{ConvAutotuneKey, conv_direct, conv_im2col_1x1, forward::implicit_gemm::*}, tensor::CubeTensor, }; /// Executes autotune on convolution operations pub fn conv_autotune( input: CubeTensor, weight: CubeTensor, bias: Option>, options: ConvOptions, ) -> CubeTensor { let client = input.client.clone(); static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { TunableSet::new(create_key::, create_conv_input::) .with(Tunable::new("conv_direct", conv_direct::)) .with(Tunable::new("conv_im2col_1x1", conv_im2col_1x1::)) .with(Tunable::new( "simple_sync_cmma", |input, weight, bias, options| { conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_sync_mma", |input, weight, bias, options| { conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_async_cmma", |input, weight, bias, options| { conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_async_mma", |input, weight, bias, options| { conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Mma) }, )) .with(Tunable::new( "simple_tma_cmma", |input, weight, bias, options| { conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Cmma) }, )) .with(Tunable::new( "simple_tma_mma", |input, weight, bias, options| { conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Mma) }, )) }); TUNER.execute( &CubeTuneId::new(&input.client, &input.device), &client, tunables, (input, weight, bias, options), ) } pub fn create_conv_input( _key: &CubeAutotuneKey, input: &CubeTensor, weights: &CubeTensor, bias: &Option>, options: &ConvOptions, ) -> ( CubeTensor, CubeTensor, Option>, ConvOptions, ) { ( input.clone(), weights.clone(), bias.clone(), options.clone(), ) } fn create_key( input: &CubeTensor, weights: &CubeTensor, bias: &Option>, options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = input.dtype; let rank = input.meta.shape().num_dims(); let dim_c = rank - 1; let batch_size = input.meta.shape()[0]; let in_channels = input.meta.shape()[dim_c]; let out_channels = weights.meta.shape()[0]; let kernel_size = weights.meta.shape()[1..dim_c].to_vec(); let in_shape = input.meta.shape()[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) .collect(); let ConvOptions { stride, padding, dilation, groups, } = options.clone(); let lhs_stride_align = if input.meta.strides()[dim_c] == 1 { stride_align(input.meta.strides(), input.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(in_channels).min(lhs_stride_align); let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 { stride_align(weights.meta.strides(), weights.dtype.into()) } else { 0 }; let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align); CubeAutotuneKey::Conv(ConvAutotuneKey::new( kernel_size, stride.to_vec(), padding.to_vec(), dilation.to_vec(), groups, in_channels, out_channels, in_shape, batch_size, bias.is_some(), dtype, lhs_shape_align, lhs_stride_align, rhs_shape_align, rhs_stride_align, )) } /// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's /// repeat number, so it's the largest align that can have performance impacts. const MAX_STRIDE_FACTOR: u32 = 10; /// Defines the non-contiguous stride alignment in terms of powers of two fn stride_align(strides: &[usize], elem: StorageType) -> u8 { let max = MAX_STRIDE_FACTOR; let dim_c = strides.len() - 1; let factor = strides[..dim_c] .iter() .map(|it| (*it * elem.size_bits()) / 8) .map(|it| it.trailing_zeros()) .min() .unwrap_or(max); factor.min(max) as u8 } /// Defines the potential vectorization. fn pow2_factor(axis: usize) -> u8 { axis.trailing_zeros().min(4) as u8 } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/im2col.rs ================================================ use burn_backend::{ DType, ops::{ConvOptions, conv::calculate_conv_output_sizes}, }; use burn_std::{Metadata, Shape}; use core::iter; use cubecl::{ prelude::*, std::tensor::{TensorHandle, into_contiguous_pitched}, }; use cubek::convolution::components::ConvSetupError; use crate::{ CubeRuntime, kernel::{ AddOp, into_contiguous_aligned, launch_binop, matmul::{MatmulStrategy, matmul}, utils::split_dim, }, ops::{reshape, swap_dims}, tensor::CubeTensor, }; #[cfg(not(test))] pub(crate) fn batches_per_run( batch_size: usize, out_shape: usize, plane_size: usize, ) -> Result { use cubek::matmul::definition::MatmulAvailabilityError; let cube_count_per_batch = out_shape.div_ceil(plane_size); let max_cube_count = u16::MAX as usize; let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size); if max_simultaneous == 0 { return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( cube_count_per_batch as u32, 1, 1, )) .into()); } Ok((0..=max_simultaneous) .rev() .find(|per_run| batch_size.is_multiple_of(*per_run)) .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] pub(crate) fn batches_per_run( batch_size: usize, out_shape: usize, plane_size: usize, ) -> Result { Ok(1) } pub fn conv_im2col_1x1( input: CubeTensor, mut weight: CubeTensor, bias: Option>, options: ConvOptions, ) -> Result, ConvSetupError> { if options.groups != 1 { return Err(ConvSetupError::Groups(options.groups)); } let rank = input.meta.num_dims(); let dim_c = rank - 1; let batch_size = input.meta.shape()[0]; let in_channels = input.meta.shape()[dim_c]; let in_shape = &input.meta.shape()[1..dim_c]; let out_channels = weight.meta.shape()[0]; let kernel_shape = &weight.meta.shape()[1..dim_c]; if kernel_shape.iter().any(|s| *s != 1) { return Err(ConvSetupError::Unknown); } let out_shape = calculate_conv_output_sizes( kernel_shape, &options.stride, &options.padding, &options.dilation, in_shape, ); let mut split_m = vec![batch_size]; split_m.extend(out_shape.iter().copied()); if kernel_shape.iter().any(|it| *it != 1) || in_shape != out_shape { return Err(ConvSetupError::Unknown); } let input = reshape_input(input); // [(NHW), C] : [M, K] let dtype = input.dtype; // Efficient permutation that takes the stride required for TMA into account let weight = if weight.meta.strides()[dim_c] != 1 { // Remove kernel dims so padded dim is channels *weight.meta = Metadata::new( [out_channels, in_channels], // [N, K] [weight.meta.strides()[0], weight.meta.strides()[dim_c]], ); // Pitched contiguous to skip running another kernel for TMA into_contiguous_aligned(weight) } else { // Already compatible, skip initial reshape *weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]); weight }; // Permute to N-major, while keeping memory layout K-major. K-major for both sides is the most // efficient for matmul, and allows skipping a contiguous kernel let weight = swap_dims(weight, 0, 1); // [K, N] let out = matmul(input, weight, None, MatmulStrategy::default(), dtype)?; // [M, N] // Skip reshape to avoid potential `into_contiguous`. We're only splitting dims so it's safe. let mut out = split_dim(out, 0, &split_m); // [N, H, W, C] if let Some(bias) = bias { let mut bias_shape = iter::repeat_n(1, rank - 1).collect::>(); bias_shape.push(out_channels); let bias = reshape(bias, bias_shape.into()); out = launch_binop::(out, bias); } Ok(out) } /// Reshapes NHWC input to [(N, H, W), C] fn reshape_input(input: CubeTensor) -> CubeTensor { let rank = input.meta.num_dims(); let dim_c = rank - 1; let dtype = input.dtype; let batch_size = input.meta.shape()[0]; let in_c: usize = input.meta.shape()[dim_c]; let in_shape: Shape = input.meta.shape()[1..dim_c].into(); let mut input = if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) { let (client, device) = (input.client.clone(), input.device.clone()); let contiguous = into_contiguous_pitched(&client, input.binding(), dtype.into()); from_handle(client, device, contiguous, dtype) } else { input }; *input.meta = Metadata::new( [batch_size * in_shape.num_elements(), in_c], // [M, K] [input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]], ); input } fn is_spatial_contiguous(shape: &[usize], strides: &[usize]) -> bool { let rank = shape.len(); let dim_c = rank - 1; // Channel must be contiguous for the [(N, H, W), C] reshape to be valid if strides[dim_c] != 1 { return false; } for i in (1..dim_c).rev() { if strides[i + 1] * shape[i + 1] != strides[i] { return false; } } true } fn from_handle( client: ComputeClient, device: R::Device, handle: TensorHandle, dtype: DType, ) -> CubeTensor { CubeTensor::new( client.clone(), handle.handle, *handle.metadata, device.clone(), dtype, ) } ================================================ FILE: crates/burn-cubecl/src/kernel/conv/mod.rs ================================================ mod backward_data; mod backward_weight; mod base; mod conv_transpose2d; mod conv_transpose3d; mod deform_conv2d; mod deform_conv_transpose2d; mod direct; mod forward; mod im2col; mod tune_key; pub(crate) use backward_data::*; pub(crate) use conv_transpose2d::*; pub(crate) use conv_transpose3d::*; pub(crate) use deform_conv_transpose2d::*; pub(crate) use deform_conv2d::*; pub(crate) use direct::*; pub(crate) use im2col::*; pub use base::*; pub use conv_transpose2d::{ConvTranspose2dStrategy, conv_transpose2d}; pub(crate) use tune_key::*; ================================================ FILE: crates/burn-cubecl/src/kernel/conv/tune_key.rs ================================================ use burn_backend::DType; use cubecl::AutotuneKey; use serde::{Deserialize, Serialize}; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] /// Autotune key representative of matmul versions pub struct ConvAutotuneKey { pub kernel_size: Vec, pub stride: Vec, pub padding: Vec, pub dilation: Vec, pub groups: usize, #[autotune(anchor)] pub in_channels: usize, #[autotune(anchor)] pub out_channels: usize, pub shape: Vec, #[autotune(anchor)] pub batch_size: usize, pub has_bias: bool, pub dtype: DType, pub lhs_shape_align: u8, pub lhs_stride_align: u8, pub rhs_shape_align: u8, pub rhs_stride_align: u8, } #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] /// Autotune key representative of matmul versions pub struct ConvTranspose2dAutotuneKey { pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub padding_out: [usize; 2], pub dilation: [usize; 2], pub groups: usize, #[autotune(anchor)] pub in_channels: usize, #[autotune(anchor)] pub out_channels: usize, #[autotune(anchor)] pub height: usize, #[autotune(anchor)] pub width: usize, #[autotune(anchor)] pub batch_size: usize, pub has_bias: bool, pub dtype: DType, } ================================================ FILE: crates/burn-cubecl/src/kernel/cross.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use cubecl::std::tensor::layout::linear::LinearView; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked, address_type = "dynamic")] fn cross_kernel( lhs: &LinearView, rhs: &LinearView, output: &mut LinearView, #[define(E)] _dtype: StorageType, ) { // Each thread processes one 3-element vector let vector_idx = ABSOLUTE_POS; let base_pos = vector_idx * 3; if !output.is_in_bounds(base_pos) { terminate!(); } // Extract vectors let a0 = lhs[base_pos]; let a1 = lhs[base_pos + 1]; let a2 = lhs[base_pos + 2]; let b0 = rhs[base_pos]; let b1 = rhs[base_pos + 1]; let b2 = rhs[base_pos + 2]; // Compute cross product: a × b let x = a1 * b2 - a2 * b1; let y = a2 * b0 - a0 * b2; let z = a0 * b1 - a1 * b0; // Store result output[base_pos] = x; output[base_pos + 1] = y; output[base_pos + 2] = z; } pub(crate) fn cross( lhs: CubeTensor, rhs: CubeTensor, dim: usize, ) -> CubeTensor { let ndims = lhs.meta.num_dims(); // Validate that the cross dimension has size 3 if lhs.meta.shape()[dim] != 3 || rhs.meta.shape()[dim] != 3 { panic!( "Cross product requires dimension {} to have size 3, but got {} and {}", dim, lhs.meta.shape()[dim], rhs.meta.shape()[dim] ); } // For now, only support cross on the last dimension if dim != ndims - 1 { unimplemented!( "Cross product on non-last dimension not yet implemented for CubeCL backend" ); } let output_shape = broadcast_shape(&[&lhs, &rhs]); let output = empty_device_dtype( lhs.client.clone(), lhs.device.clone(), output_shape.clone(), lhs.dtype, ); // Number of vectors to process let num_vectors = output_shape.num_elements() / 3; let cube_dim = CubeDim::new(&lhs.client, num_vectors); let cube_count = calculate_cube_count_elemwise(&lhs.client, num_vectors, cube_dim); let dtype = lhs.dtype; unsafe { cross_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(lhs, rhs, output), lhs.into_linear_view_like(&output), rhs.into_linear_view_like(&output), output.clone().into_linear_view(), dtype.into(), ); }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/grid_sample/base.rs ================================================ use cubecl::prelude::*; use crate::{CubeRuntime, tensor::CubeTensor}; use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; use super::bilinear::grid_sample_bilinear_launch; /// Grid sample operation supporting bilinear interpolation pub fn grid_sample( input: CubeTensor, grid: CubeTensor, options: GridSampleOptions, ) -> CubeTensor { match options.mode { InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options), _ => panic!( "Unsupported grid_sample interpolation mode: {:?}", options.mode ), } } /// Compile-time padding mode for kernel specialization #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum PaddingMode { /// Fill with zeros for out-of-bounds coordinates. Zeros, /// Clamp coordinates to the border (use nearest edge value). Border, /// Reflect coordinates at the boundary. Reflection, } impl From for PaddingMode { fn from(mode: GridSamplePaddingMode) -> Self { match mode { GridSamplePaddingMode::Zeros => PaddingMode::Zeros, GridSamplePaddingMode::Border => PaddingMode::Border, GridSamplePaddingMode::Reflection => PaddingMode::Reflection, } } } /// Fetch value based on padding mode (dispatch to appropriate handler) #[cube] pub(crate) fn fetch_value( input: &Tensor, base: usize, stride_h: usize, stride_w: usize, y: i32, x: i32, h: i32, w: i32, #[comptime] padding_mode: PaddingMode, ) -> F { match padding_mode { PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w), PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w), PaddingMode::Reflection => { fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w) } } } /// Fetch value with zeros padding (return 0 for out-of-bounds). #[cube] pub(crate) fn fetch_with_zeros( input: &Tensor, base: usize, stride_h: usize, stride_w: usize, y: i32, x: i32, h: i32, w: i32, ) -> F { let in_bounds = x >= 0 && x < w && y >= 0 && y < h; let x_clamped = clamp(x, 0, w - 1) as usize; let y_clamped = clamp(y, 0, h - 1) as usize; let idx = base + y_clamped * stride_h + x_clamped * stride_w; select(in_bounds, input[idx], F::new(0.0)) } /// Fetch value with border padding (clamp to edge). #[cube] pub(crate) fn fetch_with_border( input: &Tensor, base: usize, stride_h: usize, stride_w: usize, y: i32, x: i32, h: i32, w: i32, ) -> F { let x_clamped = clamp(x, 0, w - 1) as usize; let y_clamped = clamp(y, 0, h - 1) as usize; let idx = base + y_clamped * stride_h + x_clamped * stride_w; input[idx] } /// Fetch value with reflection padding. /// Assumes float reflection was applied to center, so indices are at most 2 steps out of bounds. #[cube] pub(crate) fn fetch_with_reflection( input: &Tensor, base: usize, stride_h: usize, stride_w: usize, y: i32, x: i32, h: i32, w: i32, ) -> F { let x_reflected = reflect_coord_bounded(x, w); let y_reflected = reflect_coord_bounded(y, h); let idx = base + y_reflected * stride_h + x_reflected * stride_w; input[idx] } /// Reflect an integer index that may be out of bounds. /// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear). #[cube] fn reflect_coord_bounded(idx: i32, size: i32) -> usize { let max_idx = size - 1; let neg_reflected = -idx - 1; let pos_reflected = 2 * max_idx + 1 - idx; let result = select( idx < 0, neg_reflected, select(idx > max_idx, pos_reflected, idx), ); clamp(result, 0, max_idx) as usize } /// Reflect a float coordinate into the valid sampling range. #[cube] pub(crate) fn reflect_coord(coord: F, size: u32, #[comptime] align_corners: bool) -> F { let size_f = F::cast_from(size); if align_corners { reflect_float_impl::(coord, F::new(0.0), size_f - F::new(1.0)) } else { reflect_float_impl::(coord, F::new(-0.5), size_f - F::new(0.5)) } } /// Reflect a float coordinate into [min_val, max_val] using a triangle wave pattern. #[cube] fn reflect_float_impl(coord: F, min_val: F, max_val: F) -> F { let span = max_val - min_val; let is_valid = span > F::new(0.0); let safe_span = select(is_valid, span, F::new(1.0)); // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val let period = safe_span * F::new(2.0); let x = (coord - min_val).abs(); let x_mod = x - (x / period).floor() * period; let reflected = safe_span - (x_mod - safe_span).abs() + min_val; select(is_valid, reflected, min_val) } ================================================ FILE: crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs ================================================ use cubecl::std::FastDivmod; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ CubeRuntime, kernel::utils::address_type, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{Shape, ops::GridSampleOptions}; use super::base::{PaddingMode, fetch_value, reflect_coord}; /// Grid sample with bilinear interpolation. /// /// Each thread processes all channels for one spatial output position: /// 1. Reading (x, y) coordinates from the grid tensor (once per spatial position) /// 2. Converting normalized [-1, 1] coords to pixel coordinates (once) /// 3. For each channel: fetch 4 corner values, interpolate, and write output #[cube(launch, address_type = "dynamic")] fn grid_sample_bilinear_kernel( input: &Tensor, // [N, C, H_in, W_in] grid: &Tensor, // [N, H_out, W_out, 2] output: &mut Tensor, // [N, C, H_out, W_out] shape_spatial: Sequence>, // [N, H_out, W_out] for thread decomposition #[comptime] align_corners: bool, #[comptime] pad_mode: PaddingMode, #[define(F)] _dtype: StorageType, ) { // Thread index maps to spatial position (n, h_out, w_out) only let spatial_idx = ABSOLUTE_POS; let num_spatial = output.shape(0) * output.shape(2) * output.shape(3); if spatial_idx >= num_spatial { terminate!(); } // Decompose spatial index into (n, h_out, w_out) let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx); let (n, h_out) = shape_spatial[1].div_mod(rem); let channels = input.shape(1) as u32; let h_in = input.shape(2) as u32; let w_in = input.shape(3) as u32; // Read grid coordinates once per spatial position let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2); let gx = grid[grid_offset]; // x coordinate in [-1, 1] let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1] // Convert normalized coordinates to pixel coordinates let (px, py) = if align_corners { let px = (gx + F::new(1.0)) * F::cast_from((w_in - 1) as f32) / F::new(2.0); let py = (gy + F::new(1.0)) * F::cast_from((h_in - 1) as f32) / F::new(2.0); (px, py) } else { let px = (gx + F::new(1.0)) * F::cast_from(w_in as f32) / F::new(2.0) - F::new(0.5); let py = (gy + F::new(1.0)) * F::cast_from(h_in as f32) / F::new(2.0) - F::new(0.5); (px, py) }; // For reflection padding, reflect the coordinate into the valid sampling range. // This ensures integer indices are at most 1 step out of bounds. let (px, py) = if comptime!(pad_mode == PaddingMode::Reflection) { let px = reflect_coord::(px, w_in, align_corners); let py = reflect_coord::(py, h_in, align_corners); (px, py) } else { (px, py) }; // Compute floor and ceil indices let x0_f = px.floor(); let y0_f = py.floor(); let x1_f = x0_f + F::new(1.0); let y1_f = y0_f + F::new(1.0); // Compute interpolation weights let wx = px - x0_f; let wy = py - y0_f; let wx_ = F::new(1.0) - wx; let wy_ = F::new(1.0) - wy; // Convert to integers for indexing let x0 = i32::cast_from(x0_f); let y0 = i32::cast_from(y0_f); let x1 = i32::cast_from(x1_f); let y1 = i32::cast_from(y1_f); let w_in = w_in as i32; let h_in = h_in as i32; // Pre-compute strides let stride_n = input.stride(0); let stride_c = input.stride(1); let stride_h = input.stride(2); let stride_w = input.stride(3); let out_stride_n = output.stride(0); let out_stride_c = output.stride(1); let out_stride_h = output.stride(2); let out_stride_w = output.stride(3); // Base offsets for this spatial position let in_base_n = n * stride_n; let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w; // Loop over all channels - grid coords and weights are reused for c in 0..channels { let in_base = in_base_n + c as usize * stride_c; let v00 = fetch_value( input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode, ); let v01 = fetch_value( input, in_base, stride_h, stride_w, y1, x0, h_in, w_in, pad_mode, ); let v10 = fetch_value( input, in_base, stride_h, stride_w, y0, x1, h_in, w_in, pad_mode, ); let v11 = fetch_value( input, in_base, stride_h, stride_w, y1, x1, h_in, w_in, pad_mode, ); // Bilinear interpolation let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11; let out_idx = out_base_spatial + c as usize * out_stride_c; output[out_idx] = result; } } /// Launch the grid sample bilinear kernel pub(crate) fn grid_sample_bilinear_launch( input: CubeTensor, grid: CubeTensor, options: GridSampleOptions, ) -> CubeTensor { let [batch_size, channels, _h_in, _w_in] = input.meta.shape().dims(); let [_n, h_out, w_out, two] = grid.meta.shape().dims(); assert_eq!(two, 2, "Grid last dimension must be 2"); // Create output tensor [N, C, H_out, W_out] let output_shape = Shape::new([batch_size, channels, h_out, w_out]); let output = empty_device_dtype( input.client.clone(), input.device.clone(), output_shape, input.dtype, ); // Spatial threading: one thread per (n, h_out, w_out) let spatial_shape = Shape::new([batch_size, h_out, w_out]); let num_spatial = spatial_shape.num_elements(); let mut shape_spatial = SequenceArg::new(); for dim in spatial_shape.iter() { shape_spatial.push(*dim); } let cube_dim = CubeDim::new(&input.client, num_spatial); let cube_count = calculate_cube_count_elemwise(&input.client, num_spatial, cube_dim); let padding_mode: PaddingMode = options.padding_mode.into(); let dtype = input.dtype; grid_sample_bilinear_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, grid, output), input.into_tensor_arg(), grid.into_tensor_arg(), output.clone().into_tensor_arg(), shape_spatial, options.align_corners, padding_mode, dtype.into(), ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/grid_sample/mod.rs ================================================ mod base; mod bilinear; pub use base::*; ================================================ FILE: crates/burn-cubecl/src/kernel/index/flip.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{DType, TensorMetadata}; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; #[cube(launch_unchecked, address_type = "dynamic")] fn flip_kernel( input: &Tensor, output: &mut LinearView, in_shape: Sequence>, indices: Sequence, #[define(E, Bool)] _dtypes: [StorageType; 2], ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let rank = in_shape.len().comptime(); let mut offset = ABSOLUTE_POS; let mut offset_input = 0; #[unroll] for i in 0..rank { let dim = rank - i - 1; let shape = input.shape(dim); let (rem, offset_local) = in_shape[dim].div_mod(offset); offset = rem; let flip = indices.index(dim).get::() == Bool::from_int(1); let offset_local = select(flip, shape - offset_local - 1, offset_local); offset_input += offset_local * input.stride(dim); } output[ABSOLUTE_POS] = input[offset_input]; } pub(crate) fn flip( tensor: CubeTensor, indices: &[usize], dtype_bool: DType, ) -> CubeTensor { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), tensor.dtype, ); flip_on_output(tensor, output, indices, dtype_bool) } pub(crate) fn flip_on_output( tensor: CubeTensor, output: CubeTensor, indices: &[usize], dtype_bool: DType, ) -> CubeTensor { let dtype_input = tensor.dtype; let ndims = tensor.meta.num_dims(); let mut indices_sequence = SequenceArg::::new(); for i in 0..ndims { indices_sequence.push({ let val = indices.contains(&i) as u8; InputScalar::new(val, dtype_bool) }); } let num_elements = output.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, num_elements); let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elements, cube_dim); let shape = shape_divmod(&tensor); unsafe { flip_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(tensor, output), tensor.into_tensor_arg(), output.clone().into_linear_view(), shape, indices_sequence, [dtype_input.into(), dtype_bool.into()], ) } output } ================================================ FILE: crates/burn-cubecl/src/kernel/index/gather.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_strides, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor}; use cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod}; use cubecl::{CubeDim, std::tensor::layout::linear::LinearView}; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked, address_type = "dynamic")] fn gather_kernel( input: &Tensor, indices: &LinearView, output: &mut LinearView, in_strides: Sequence, // zeroed out for broadcast dims and `dim` out_shape: Sequence>, dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { if !indices.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let mut offset = index_offset_contiguous_fastdivmod( ABSOLUTE_POS, &out_shape, &in_strides, input.vector_size(), ); offset += usize::cast_from(indices[ABSOLUTE_POS]) * input.stride(dim); output[ABSOLUTE_POS] = input[offset]; } pub(crate) fn gather( dim: usize, tensor: CubeTensor, indices: CubeTensor, ) -> CubeTensor { let shape_output = indices.shape(); let total_elem = shape_output.num_elements(); let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), shape_output, tensor.dtype, ); let cube_dim = CubeDim::new(&tensor.client, total_elem); let cube_count = calculate_cube_count_elemwise(&tensor.client, total_elem, cube_dim); let mut in_strides = broadcast_strides(&output, &tensor); in_strides.values[dim] = 0; // Zero `dim` to exclude it from the indexing let (dtype, indices_dtype) = (tensor.dtype, indices.dtype); unsafe { gather_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(tensor, indices, output), tensor.into_tensor_arg(), indices.into_linear_view(), output.clone().into_linear_view(), in_strides, shape_divmod(&output), dim, [dtype.into(), indices_dtype.into()], ) } output } ================================================ FILE: crates/burn-cubecl/src/kernel/index/mod.rs ================================================ mod flip; mod gather; mod repeat_dim; mod scatter; mod select; mod select_assign; mod slice; mod slice_assign; pub(crate) use flip::*; pub(crate) use repeat_dim::*; pub(crate) use select::*; pub(crate) use select_assign::*; pub use slice::*; pub(crate) use slice_assign::*; pub(crate) use gather::*; pub(crate) use scatter::*; ================================================ FILE: crates/burn-cubecl/src/kernel/index/repeat_dim.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod}; #[cube(launch_unchecked, address_type = "dynamic")] fn repeat_dim_kernel( input: &Tensor, output: &mut Tensor, out_shape: Sequence>, in_shape: FastDivmod, #[comptime] dim: usize, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let rank = out_shape.len().comptime(); let mut pos = ABSOLUTE_POS; let mut offset_input = 0; let mut offset_output = 0; #[unroll] for i in 0..rank { let i = rank - i - 1; let (rem, mut local_pos) = out_shape[i].div_mod(pos); pos = rem; offset_output += local_pos * output.stride(i); if i == dim { local_pos = in_shape.modulo(local_pos); } offset_input += local_pos * input.stride(i); } output[offset_output] = input[offset_input]; } pub(crate) fn repeat_dim( mut input: CubeTensor, dim: usize, times: usize, ) -> CubeTensor { if input.meta.shape()[dim] == 1 { input.meta.strides[dim] = 0; input.meta.shape = input.meta.shape.clone().repeat(dim, times).unwrap(); return input; } let shape = input.meta.shape.clone().repeat(dim, times).unwrap(); // Create output handle let output = empty_device_dtype( input.client.clone(), input.device.clone(), shape, input.dtype, ); let working_units = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); let shape_arg = input.meta.shape()[dim]; unsafe { repeat_dim_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(input, output), input.into_tensor_arg(), output.clone().into_tensor_arg(), shape_divmod(&output), shape_arg, dim, output.dtype.into(), ) }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/index/scatter.rs ================================================ use crate::{ CubeRuntime, kernel::{ AddOp, BinaryOp, BinaryOpFamily, OrOp, utils::{address_type, shape_divmod}, }, tensor::CubeTensor, }; use cubecl::{CubeDim, calculate_cube_count_elemwise}; use cubecl::{prelude::*, std::FastDivmod}; #[cube(launch_unchecked, address_type = "dynamic")] fn scatter_kernel( input: &mut Tensor, indices: &Tensor, value: &Tensor, in_shape: Sequence>, #[comptime] dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { let rank = in_shape.len().comptime(); let stride_input = input.stride(dim); let stride_value = value.stride(dim); let stride_indices = indices.stride(dim); let shape_value = value.shape(dim); let mut offset = ABSOLUTE_POS; let mut offset_input = 0; let mut offset_indices = 0; let mut offset_value = 0; let mut num_elems = 1; #[unroll] for i in 0..rank { let i = rank - i - 1; if i != dim { let shape_input_loop = input.shape(i); let (rem, local_pos) = in_shape[i].div_mod(offset); offset = rem; offset_input += local_pos * input.stride(i); offset_indices += local_pos * indices.stride(i); offset_value += local_pos * value.stride(i); num_elems *= shape_input_loop; } } let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { terminate!(); } for i in 0..shape_value { let value_idx = (stride_value * i) + offset_value; let index_idx = (stride_indices * i) + offset_indices; let value = value[value_idx]; let index = usize::cast_from(indices[index_idx]); let input_idx = (stride_input * index) + offset_input; let value = Op::BinaryOp::>::execute( Vector::cast_from(input[input_idx]), Vector::cast_from(value), ); input[input_idx] = value[0]; } } pub(crate) fn scatter( dim: usize, tensor: CubeTensor, indices: CubeTensor, value: CubeTensor, is_bool: bool, ) -> CubeTensor { let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() { true => tensor, false => tensor.copy(), }; let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim]; let working_units = num_elems; let cube_dim = CubeDim::new(&indices.client, working_units); let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim); let launch = match is_bool { true => scatter_kernel::launch_unchecked::, false => scatter_kernel::launch_unchecked::, }; let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype); unsafe { launch( &tensor.client.clone(), cube_count, cube_dim, address_type!(tensor, indices, value), tensor.clone().into_tensor_arg(), indices.into_tensor_arg(), value.into_tensor_arg(), shape_divmod(&tensor), dim, [tensor_dtype.into(), indices_dtype.into()], ) } tensor } ================================================ FILE: crates/burn-cubecl/src/kernel/index/select.rs ================================================ use crate::{CubeRuntime, kernel::utils::address_type, tensor::CubeTensor}; use crate::{kernel::utils::shape_divmod, ops::numeric::empty_device_dtype}; use burn_backend::TensorMetadata; use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView}; use cubecl::{prelude::*, std::FastDivmod}; #[cube(launch_unchecked, address_type = "dynamic")] fn select_kernel( input: &Tensor, indices: &LinearView, output: &mut LinearView, out_shape: Sequence>, dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { if ABSOLUTE_POS >= output.shape() { terminate!(); } let rank = out_shape.len().comptime(); let mut offset = ABSOLUTE_POS; let mut offset_input = 0; #[unroll] for i in 0..rank { let i = rank - i - 1; let (rem, offset_local) = out_shape[i].div_mod(offset); offset = rem; let offset_local = cubecl::prelude::select( i == dim, usize::cast_from(indices[offset_local]), offset_local, ); offset_input += offset_local * input.stride(i); } output[ABSOLUTE_POS] = input[offset_input]; } pub(crate) fn select( tensor: CubeTensor, dim: usize, indices: CubeTensor, ) -> CubeTensor { let mut shape_output = tensor.shape(); shape_output[dim] = indices.meta.shape()[0]; let total_elem = shape_output.num_elements(); let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), shape_output, tensor.dtype, ); let working_units = total_elem; let cube_dim = CubeDim::new(&indices.client, working_units); let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim); let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype); unsafe { select_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(tensor, indices, output), tensor.into_tensor_arg(), indices.into_linear_view(), output.clone().into_linear_view(), shape_divmod(&output), dim, [tensor_dtype.into(), indices_dtype.into()], ) }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/index/select_assign.rs ================================================ use crate::kernel::{ AddOp, BinaryOp, BinaryOpFamily, OrOp, utils::{address_type, shape_divmod}, }; use crate::{CubeRuntime, tensor::CubeTensor}; use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView}; use cubecl::{prelude::*, std::FastDivmod}; #[cube(launch_unchecked, address_type = "dynamic")] fn select_assign_kernel( tensor: &mut Tensor, indices: &LinearView, value: &Tensor, value_shape: Sequence>, num_elems: usize, #[comptime] dim: usize, #[define(F, I)] _dtypes: [StorageType; 2], ) { if ABSOLUTE_POS >= num_elems { terminate!(); } let rank = value_shape.len().comptime(); let mut offset = ABSOLUTE_POS; let mut offset_tensor = 0; let mut offset_value = 0; // Calculate offsets and num_elems #[unroll] for i in 0..rank { let i = rank - i - 1; if i != dim { let (rem, local_pos) = value_shape[i].div_mod(offset); offset = rem; offset_tensor += local_pos * tensor.stride(i); offset_value += local_pos * value.stride(i); } } let strides_tensor_dim = tensor.stride(dim); let strides_value_dim = value.stride(dim); // Main operation for i in 0..value.shape(dim) { let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_value = i * strides_value_dim + offset_value; let value = Op::BinaryOp::>::execute( Vector::cast_from(tensor[index_tensor]), Vector::cast_from(value[index_value]), ); tensor[index_tensor] = F::cast_from(value); } } pub(crate) fn select_assign( tensor: CubeTensor, dim: usize, indices: CubeTensor, value: CubeTensor, is_bool: bool, ) -> CubeTensor { let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() { true => tensor, false => tensor.copy(), }; let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim]; let working_units = num_elems; let cube_dim = CubeDim::new(&indices.client, working_units); let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim); let launch = match is_bool { true => select_assign_kernel::launch_unchecked::, false => select_assign_kernel::launch_unchecked::, }; let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype); let shape = shape_divmod(&value); unsafe { launch( &tensor.client, cube_count, cube_dim, address_type!(tensor, indices, value), tensor.clone().into_tensor_arg(), indices.into_linear_view(), value.into_tensor_arg(), shape, num_elems, dim, [tensor_dtype.into(), indices_dtype.into()], ) }; tensor } ================================================ FILE: crates/burn-cubecl/src/kernel/index/slice.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, shape_divmod}, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::{Slice, TensorMetadata}; use burn_std::{Metadata, SliceOps}; use cubecl::{ calculate_cube_count_elemwise, intrinsic, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; use std::ops::Range; /// Slice a jit tensor with a set of ranges pub fn slice(tensor: CubeTensor, indices: &[Range]) -> CubeTensor { let mut dims = tensor.shape(); let mut offset_start = 0u64; let mut offset_end = 0u64; for i in 0..indices.len() { offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64; offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64; dims[i] = indices[i].end - indices[i].start; } let offset_start = offset_start * tensor.dtype.size() as u64; let offset_end = offset_end * tensor.dtype.size() as u64; let memory_offset_alignment = tensor.client.properties().memory.alignment; if offset_start.is_multiple_of(memory_offset_alignment) && offset_end.is_multiple_of(memory_offset_alignment) { CubeTensor::new( tensor.client.clone(), tensor .handle .clone() .offset_start(offset_start) .offset_end(offset_end), Metadata::new(dims, tensor.meta.strides.clone()), tensor.device.clone(), tensor.dtype, ) } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), dims, tensor.dtype, ); slice_on_output(tensor, output, indices) } } #[cube(launch_unchecked, address_type = "dynamic")] fn slice_kernel( input: &Tensor, output: &mut LinearView, out_shape: Sequence>, indices: Sequence, #[define(E)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let rank = comptime![out_shape.len()]; let mut offset_output = ABSOLUTE_POS; let mut offset_input = 0; #[unroll] for i in 0..rank { // Iterate in reverse to use divmod let dim = rank - i - 1; let range_start = indices[dim]; let (rem, offset_local) = out_shape[dim].div_mod(offset_output); offset_output = rem; let offset_local = offset_local + range_start; offset_input += offset_local * input.stride(dim); } output[ABSOLUTE_POS] = input[offset_input]; } pub(crate) fn slice_on_output( tensor: CubeTensor, output: CubeTensor, indices: &[Range], ) -> CubeTensor { let ndims = tensor.meta.num_dims(); let mut indices_sequence = SequenceArg::::new(); for i in 0..ndims { let start = indices.get(i).map(|index| index.start).unwrap_or(0); indices_sequence.push(start); } let working_units = output.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { slice_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(tensor, output), tensor.into_tensor_arg(), output.clone().into_linear_view(), shape_divmod(&output), indices_sequence, dtype.into(), ) }; output } /// Kernel for slicing with steps #[cube(launch_unchecked, address_type = "dynamic")] fn slice_with_steps_kernel( input: &Tensor, output: &mut LinearView, out_shape: Sequence>, starts: Sequence, ends: Sequence, steps: Sequence, #[define(E)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let rank = comptime![out_shape.len()]; let mut output_offset = ABSOLUTE_POS; let mut input_offset = 0; // Calculate the input offset based on output position and slice info #[unroll] for i in 0..rank { // Iterate in reverse to use divmod let dim = rank - i - 1; let start = starts[dim]; let end = ends[dim]; let step = steps[dim]; let (rem, output_idx) = out_shape[dim].div_mod(output_offset); output_offset = rem; let input_idx = if step > 0 { // Forward stepping start + output_idx * (step as usize) } else { // Backward stepping - start from end-1 let abs_step = (-step) as usize; let end_minus_1 = end - 1; end_minus_1 - output_idx * abs_step }; input_offset += input_idx * input.stride(dim); } output[ABSOLUTE_POS] = input[input_offset]; } /// Slice a tensor with steps pub fn slice_with_steps(tensor: CubeTensor, slices: &[Slice]) -> CubeTensor { // Check if all steps are 1 - if so, use the optimized regular slice let all_steps_one = slices.iter().all(|info| info.step == 1); if all_steps_one { // Convert Slice to Range for step=1 let simple_ranges: Vec> = slices .iter() .enumerate() .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); return slice(tensor, &simple_ranges); } // Calculate output shape let shape_output = tensor.shape().slice(slices).unwrap(); // Create output tensor let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), shape_output.clone(), tensor.dtype, ); // Prepare three separate sequences for kernel let mut starts = SequenceArg::::new(); let mut ends = SequenceArg::::new(); let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { let range = slice.to_range(tensor.meta.shape()[dim]); starts.push(range.start); ends.push(range.end); steps.push(slice.step as i32); } // Pad with default values if needed to match tensor dimensions for dim in slices.len()..tensor.meta.num_dims() { starts.push(0); ends.push(tensor.meta.shape[dim]); steps.push(1); } // Launch kernel let working_units = shape_output.num_elements(); let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { slice_with_steps_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(tensor, output), tensor.into_tensor_arg(), output.clone().into_linear_view(), shape_divmod(&output), starts, ends, steps, dtype.into(), ); } output } /// This is annoying and we need to find a way to do this automatically at some point #[allow(unused)] #[cube] fn unwrap(value: u32) -> comptime_type!(u32) { intrinsic!(|_| value.constant().unwrap().as_u32()) } ================================================ FILE: crates/burn-cubecl/src/kernel/index/slice_assign.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, shape_divmod}, tensor::CubeTensor, }; use cubecl::{ calculate_cube_count_elemwise, intrinsic, prelude::*, std::{FastDivmod, tensor::layout::linear::LinearView}, }; #[cube(launch_unchecked, address_type = "dynamic")] fn slice_assign_kernel( input: &mut Tensor>, value: &LinearView>, slice_shape: Sequence>, slice_offsets: Sequence, #[define(E)] _dtype: StorageType, ) { if !value.is_in_bounds(ABSOLUTE_POS) { terminate!() } let rank = comptime!(slice_shape.len()); let line_size = input.vector_size(); let mut offset_remainder = ABSOLUTE_POS * line_size; let mut offset_input = 0; #[allow(clippy::explicit_counter_loop)] #[unroll] for i in 0..rank { let dim = rank - i - 1; let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder); let range_start = slice_offsets[dim]; let offset_local_input = offset_local + range_start; offset_input += offset_local_input * input.stride(dim); offset_remainder = rem; } // Value tensor is accessed linearly since it's a LinearView input[offset_input / line_size] = value[ABSOLUTE_POS]; } /// Kernel for slice assign with steps #[cube(launch_unchecked, address_type = "dynamic")] fn slice_assign_with_steps_kernel( input: &mut Tensor, value: &LinearView, value_shape: Sequence>, starts: Sequence, ends: Sequence, steps: Sequence, #[define(E)] _dtype: StorageType, ) { if !value.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let rank = comptime![value_shape.len()]; let mut value_offset = ABSOLUTE_POS; let mut input_offset = 0; // Calculate the input offset based on value position and slice info #[unroll] for i in 0..rank { // Iterate in reverse to use divmod let dim = rank - i - 1; let start = starts[dim]; let end = ends[dim]; let step = steps[dim]; let (rem, value_idx) = value_shape[dim].div_mod(value_offset); value_offset = rem; let input_idx = if step > 0 { // Forward stepping start + value_idx * (step as usize) } else if step < 0 { // Backward stepping - start from end-1 // For negative steps, we iterate backwards through the selected indices let abs_step = (-step) as usize; let end_minus_1 = end - 1; end_minus_1 - value_idx * abs_step } else { // step == 0, shouldn't happen value_idx }; input_offset += input_idx * input.stride(dim); } input[input_offset] = value[ABSOLUTE_POS]; } pub(crate) fn slice_assign( tensor: CubeTensor, indices: &[burn_backend::Slice], value: CubeTensor, ) -> CubeTensor { // Check if any slice has non-unit step let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0); if has_non_unit_step { // Use slice_assign_with_steps return slice_assign_with_steps(tensor, indices, value); } let client = tensor.client.clone(); let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() { true => tensor, false => tensor.copy(), }; let ndims = tensor.meta.num_dims(); let vector_size = if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1 { let last = indices .get(ndims - 1) .cloned() .unwrap_or(burn_backend::Slice { start: 0, end: Some(tensor.meta.shape()[ndims - 1] as isize), step: 1, }); let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize); let shape = (end - last.start) as usize; let offset = last.start as usize; client .io_optimized_vector_sizes(tensor.dtype.size()) .filter(|&it| { shape.is_multiple_of(it) && strides_compatible(tensor.meta.strides(), it) && strides_compatible(value.meta.strides(), it) && offset.is_multiple_of(it) }) .max() .unwrap_or(1) } else { 1 }; let mut shape = SequenceArg::>::new(); let mut offsets = SequenceArg::::new(); for i in 0..ndims { let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice { start: 0, end: Some(tensor.meta.shape()[i] as isize), step: 1, }); let start = slice.start as usize; let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize); let length = (end - slice.start) as usize; shape.push(length); offsets.push(start); } let working_units = value.meta.num_elements() / vector_size; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); unsafe { slice_assign_kernel::launch_unchecked( &tensor.client, cube_count, cube_dim, address_type!(tensor, value), vector_size, tensor.clone().into_tensor_arg(), value.into_linear_view(), shape, offsets, tensor.dtype.into(), ) }; tensor } /// Slice assign with steps support /// /// This function handles slice assignment with arbitrary step values, including negative steps. /// It follows NumPy/PyTorch semantics where values[i] is assigned to selected_indices[i]. /// /// For example, with s![0..6;-1] which selects indices [5,4,3,2,1,0]: /// - values[0] goes to index 5 /// - values[1] goes to index 4 /// - etc. pub(crate) fn slice_assign_with_steps( tensor: CubeTensor, slices: &[burn_backend::Slice], value: CubeTensor, ) -> CubeTensor { let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() { true => tensor, false => tensor.copy(), }; // Prepare sequences for kernel let mut starts = SequenceArg::::new(); let mut ends = SequenceArg::::new(); let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { let range = slice.to_range(tensor.meta.shape()[dim]); starts.push(range.start); ends.push(range.end); steps.push(slice.step as i32); } // Pad with default values if needed to match tensor dimensions for dim in slices.len()..tensor.meta.num_dims() { starts.push(0); ends.push(tensor.meta.shape[dim]); steps.push(1); } // Launch kernel let working_units = value.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let shape = shape_divmod(&value); unsafe { slice_assign_with_steps_kernel::launch_unchecked( &tensor.client, cube_count, cube_dim, address_type!(tensor, value), tensor.clone().into_tensor_arg(), value.into_linear_view(), shape, starts, ends, steps, tensor.dtype.into(), ); } tensor } fn strides_compatible(strides: &[usize], vec: usize) -> bool { strides .iter() .all(|stride| *stride % vec == 0 || *stride == 1) } /// Helper function for unwrap #[allow(unused)] #[cube] fn unwrap(value: u32) -> comptime_type!(u32) { intrinsic!(|_| value.constant().unwrap().as_u32()) } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/base.rs ================================================ use crate::{ CubeRuntime, kernel::into_contiguous, ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw}, tensor::CubeTensor, }; use burn_backend::{ Shape, TensorMetadata, ops::{InterpolateMode, InterpolateOptions}, }; use super::{ bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch, lanczos3::interpolate_lanczos3_launch, nearest::interpolate_nearest_launch, nearest_backward::interpolate_nearest_backward_launch, }; /// Interpolate operation /// /// Supports nearest, bilinear, bicubic and lanczos3 modes pub fn interpolate( input: CubeTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> CubeTensor { let [batch_size, channels, _, _] = input.meta.shape().dims(); let [out_height, out_width] = output_size; let input = into_contiguous(permute_nchw_to_nhwc(input)); let shape_out = Shape::new([batch_size, out_height, out_width, channels]); let output = empty_device_dtype( input.client.clone(), input.device.clone(), shape_out, input.dtype, ); let align_corners = options.align_corners; let output = match options.mode { InterpolateMode::Nearest => interpolate_nearest_launch(input, output), InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners), InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners), InterpolateMode::Lanczos3 => interpolate_lanczos3_launch(input, output, align_corners), }; permute_nhwc_to_nchw(output) } /// Backward interpolate operation /// /// Note: only nearest mode is supported pub fn interpolate_backward( input: CubeTensor, out_grad: CubeTensor, _output_size: [usize; 2], options: InterpolateOptions, ) -> CubeTensor { let input = permute_nchw_to_nhwc(input); let out_grad = permute_nchw_to_nhwc(out_grad); let output_shape = input.shape(); let output = empty_device_dtype( input.client.clone(), input.device.clone(), output_shape, input.dtype, ); let output = match options.mode { InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output), InterpolateMode::Bilinear => { panic!("bilinear interpolation backward is not supported by JIT backend") } InterpolateMode::Bicubic => { panic!("bicubic interpolation backward is not supported by JIT backend") } InterpolateMode::Lanczos3 => { panic!("lanczos3 interpolation backward is not supported by JIT backend") } }; permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/bicubic.rs ================================================ use cubecl::std::{ FastDivmod, tensor::layout::{linear::LinearLayout, *}, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ CubeRuntime, kernel::utils::{address_type, linear_layout, shape_divmod}, ops::max_vector_size, tensor::CubeTensor, }; #[cube(launch, address_type = "dynamic")] fn interpolate_bicubic_kernel( input: &Tensor>, output: &mut Tensor>, shape_out: Sequence>, out_layout: LinearLayout, #[comptime] align_corners: bool, #[define(F)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let vector_size = input.vector_size(); let out_idx = out_layout.to_source_pos(ABSOLUTE_POS); let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size); let (rem, x) = shape_out[2].div_mod(rem); let (b, y) = shape_out[1].div_mod(rem); let input_height = input.shape(1) - 1; let input_height_f = input_height as f32; let frac = if align_corners { let output_height = clamp_min(output.shape(1) - 1, 1) as f32; (y * input_height) as f32 / output_height } else { let in_size = (input_height + 1) as f32; let out_size = output.shape(1) as f32; (y as f32 + 0.5) * (in_size / out_size) - 0.5 }; let y_in_f = frac.floor(); let yw = Vector::new(F::cast_from(frac - y_in_f)); // Clamp indices in float space to handle negative coordinates from half_pixel let y0 = clamp(y_in_f - 1.0, 0.0, input_height_f) as usize; let y1 = clamp(y_in_f, 0.0, input_height_f) as usize; let y2 = clamp(y_in_f + 1.0, 0.0, input_height_f) as usize; let y3 = clamp(y_in_f + 2.0, 0.0, input_height_f) as usize; let input_width = input.shape(2) - 1; let input_width_f = input_width as f32; let frac = if align_corners { let output_width = clamp_min(output.shape(2) - 1, 1) as f32; (x * input_width) as f32 / output_width } else { let in_size = (input_width + 1) as f32; let out_size = output.shape(2) as f32; (x as f32 + 0.5) * (in_size / out_size) - 0.5 }; let x_in_f = frac.floor(); let xw = Vector::new(F::cast_from(frac - x_in_f)); // Clamp indices in float space to handle negative coordinates from half_pixel let x0 = clamp(x_in_f - 1.0, 0.0, input_width_f) as usize; let x1 = clamp(x_in_f, 0.0, input_width_f) as usize; let x2 = clamp(x_in_f + 1.0, 0.0, input_width_f) as usize; let x3 = clamp(x_in_f + 2.0, 0.0, input_width_f) as usize; let index_base = b * input.stride(0) + c * input.stride(3); let in_stride_y = input.stride(1); let in_stride_x = input.stride(2); let y0_stride = y0 * in_stride_y; let y1_stride = y1 * in_stride_y; let y2_stride = y2 * in_stride_y; let y3_stride = y3 * in_stride_y; let x0_stride = x0 * in_stride_x; let x1_stride = x1 * in_stride_x; let x2_stride = x2 * in_stride_x; let x3_stride = x3 * in_stride_x; let inp_0 = input[(index_base + y0_stride + x0_stride) / vector_size]; let inp_1 = input[(index_base + y0_stride + x1_stride) / vector_size]; let inp_2 = input[(index_base + y0_stride + x2_stride) / vector_size]; let inp_3 = input[(index_base + y0_stride + x3_stride) / vector_size]; let coefficients0 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw); let inp_0 = input[(index_base + y1_stride + x0_stride) / vector_size]; let inp_1 = input[(index_base + y1_stride + x1_stride) / vector_size]; let inp_2 = input[(index_base + y1_stride + x2_stride) / vector_size]; let inp_3 = input[(index_base + y1_stride + x3_stride) / vector_size]; let coefficients1 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw); let inp_0 = input[(index_base + y2_stride + x0_stride) / vector_size]; let inp_1 = input[(index_base + y2_stride + x1_stride) / vector_size]; let inp_2 = input[(index_base + y2_stride + x2_stride) / vector_size]; let inp_3 = input[(index_base + y2_stride + x3_stride) / vector_size]; let coefficients2 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw); let inp_0 = input[(index_base + y3_stride + x0_stride) / vector_size]; let inp_1 = input[(index_base + y3_stride + x1_stride) / vector_size]; let inp_2 = input[(index_base + y3_stride + x2_stride) / vector_size]; let inp_3 = input[(index_base + y3_stride + x3_stride) / vector_size]; let coefficients3 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw); let val = cubic_interp_1d( coefficients0, coefficients1, coefficients2, coefficients3, yw, ); output[out_idx] = val; } #[cube] fn cubic_interp_1d( x0: Vector, x1: Vector, x2: Vector, x3: Vector, t: Vector, ) -> Vector { let a = float(-0.75); let coeffs0 = cubic_convolution_2(t + float(1.0), a); let coeffs1 = cubic_convolution_1(t, a); let coeffs2 = cubic_convolution_1(float(1.0) - t, a); let coeffs3 = cubic_convolution_2(float(2.0) - t, a); x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3 } #[cube] fn cubic_convolution_1(x: Vector, a: Vector) -> Vector { let conv = (a + float(2.0)) * x; let tmp = a + float(3.0); (conv - tmp) * x * x + float(1.0) } #[cube] fn cubic_convolution_2(x: Vector, a: Vector) -> Vector { let conv = a * x; let conv = (conv - float(5.0) * a) * x; let tmp = float(8.0) * a; let conv = (conv + tmp) * x; conv - float(4.0) * a } #[cube] fn float(#[comptime] v: f32) -> Vector { Vector::new(F::new(v)) } pub(crate) fn interpolate_bicubic_launch( input: CubeTensor, output: CubeTensor, align_corners: bool, ) -> CubeTensor { let vector_size = max_vector_size(&input); let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, vector_size); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); interpolate_bicubic_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_tensor_arg(), output.clone().into_tensor_arg(), out_shape, out_layout, align_corners, output.dtype.into(), ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/bilinear.rs ================================================ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubecl::{ num_traits::Zero, std::{ FastDivmod, tensor::layout::{linear::LinearLayout, *}, }, }; use crate::{ CubeRuntime, kernel::utils::{address_type, linear_layout, shape_divmod}, ops::max_vector_size, tensor::CubeTensor, }; #[cube(launch, address_type = "dynamic")] fn interpolate_bilinear_kernel( input: &Tensor>, output: &mut Tensor>, shape_out: Sequence>, out_layout: LinearLayout, #[comptime] align_corners: bool, #[define(F)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let vector_size = input.vector_size(); let out_idx = out_layout.to_source_pos(ABSOLUTE_POS); let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size); let (rem, x) = shape_out[2].div_mod(rem); let (b, y) = shape_out[1].div_mod(rem); let frac = if align_corners { let numerator = (input.shape(1) - 1) as f32; let denominator = clamp_min(output.shape(1) - 1, 1) as f32; y as f32 * (numerator / denominator) } else { let in_size = input.shape(1) as f32; let out_size = output.shape(1) as f32; clamp( (y as f32 + 0.5) * (in_size / out_size) - 0.5, 0.0, in_size - 1.0, ) }; let v0 = frac.floor(); let v1 = frac.ceil(); let yw = F::cast_from(frac - v0); let yw_ = Vector::new(F::one() - yw); let yw = Vector::new(yw); let y0_ok = v0 >= 0.0; let y0 = v0 as usize; let y1 = v1 as usize; let frac = if align_corners { let numerator = (input.shape(2) - 1) as f32; let denominator = clamp_min(output.shape(2) - 1, 1) as f32; x as f32 * (numerator / denominator) } else { let in_size = input.shape(2) as f32; let out_size = output.shape(2) as f32; clamp( (x as f32 + 0.5) * (in_size / out_size) - 0.5, 0.0, in_size - 1.0, ) }; let v0 = frac.floor(); let v1 = frac.ceil(); let xw = F::cast_from(frac - v0); let xw_ = Vector::new(F::one() - xw); let xw = Vector::new(xw); let x0_ok = v0 >= 0.0; let x0 = v0 as usize; let x1 = v1 as usize; let index_base = b * input.stride(0) + c * input.stride(3); let in_stride_y = input.stride(1); let in_stride_x = input.stride(2); let y0_stride = y0 * in_stride_y; let y1_stride = y1 * in_stride_y; let x0_stride = x0 * in_stride_x; let x1_stride = x1 * in_stride_x; let height = input.shape(1); let width = input.shape(2); let y1_ok = y1 < height; let x1_ok = x1 < width; let zero = Vector::zero(); let p_a = select( x0_ok && y0_ok, input[(index_base + y0_stride + x0_stride) / vector_size] * xw_ * yw_, zero, ); let p_b = select( x1_ok && y0_ok, input[(index_base + y0_stride + x1_stride) / vector_size] * xw * yw_, zero, ); let p_c = select( x0_ok && y1_ok, input[(index_base + y1_stride + x0_stride) / vector_size] * xw_ * yw, zero, ); let p_d = select( x1_ok && y1_ok, input[(index_base + y1_stride + x1_stride) / vector_size] * xw * yw, zero, ); output[out_idx] = p_a + p_b + p_c + p_d; } pub(crate) fn interpolate_bilinear_launch( input: CubeTensor, output: CubeTensor, align_corners: bool, ) -> CubeTensor { let vector_size = max_vector_size(&input); let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, vector_size); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); interpolate_bilinear_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_tensor_arg(), output.clone().into_tensor_arg(), out_shape, out_layout, align_corners, output.dtype.into(), ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/lanczos3.rs ================================================ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubecl::{ num_traits::Zero, std::{ FastDivmod, tensor::layout::{linear::LinearLayout, *}, }, }; use crate::{ CubeRuntime, kernel::utils::{address_type, linear_layout, shape_divmod}, ops::max_vector_size, tensor::CubeTensor, }; #[cube(launch, address_type = "dynamic")] fn interpolate_lanczos3_kernel( input: &Tensor>, output: &mut Tensor>, shape_out: Sequence>, out_layout: LinearLayout, #[comptime] align_corners: bool, #[define(F)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let vector_size = input.vector_size(); let out_idx = out_layout.to_source_pos(ABSOLUTE_POS); let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size); let (rem, x) = shape_out[2].div_mod(rem); let (b, y) = shape_out[1].div_mod(rem); let input_height = input.shape(1) - 1; let input_height_f = input_height as f32; let y_frac = if align_corners { let output_height = clamp_min(output.shape(1) - 1, 1) as f32; (y * input_height) as f32 / output_height } else { let in_size = (input_height + 1) as f32; let out_size = output.shape(1) as f32; (y as f32 + 0.5) * (in_size / out_size) - 0.5 }; let y0 = f32::floor(y_frac); let input_width = input.shape(2) - 1; let input_width_f = input_width as f32; let x_frac = if align_corners { let output_width = clamp_min(output.shape(2) - 1, 1) as f32; (x * input_width) as f32 / output_width } else { let in_size = (input_width + 1) as f32; let out_size = output.shape(2) as f32; (x as f32 + 0.5) * (in_size / out_size) - 0.5 }; let x0 = f32::floor(x_frac); let index_base = b * input.stride(0) + c * input.stride(3); let in_stride_y = input.stride(1); let in_stride_x = input.stride(2); let mut result = Vector::zero(); let mut weight_sum = 0.0f32; // 6-tap separable Lanczos3 filter: ky in -2..=3, kx in -2..=3 // Skip out-of-bounds positions instead of clamping (matches TF/JAX/PIL) #[unroll] for ky in -2..4i32 { let y_pos = y0 + ky as f32; if y_pos >= 0.0 && y_pos <= input_height_f { let y_idx = y_pos as usize; let wy = lanczos3_weight(y_frac - y_pos); #[unroll] for kx in -2..4i32 { let x_pos = x0 + kx as f32; if x_pos >= 0.0 && x_pos <= input_width_f { let x_idx = x_pos as usize; let wx = lanczos3_weight(x_frac - x_pos); let wt = wy * wx; let idx = index_base + y_idx * in_stride_y + x_idx * in_stride_x; let pixel = input[idx / vector_size]; let w = Vector::new(F::cast_from(wt)); result += pixel * w; weight_sum += wt; } } } } if weight_sum != 0.0 { let inv_w = Vector::new(F::cast_from(1.0 / weight_sum)); result *= inv_w; } output[out_idx] = result; } #[cube] fn lanczos3_weight(x: f32) -> f32 { let abs_x = f32::abs(x); let mut result = 0.0f32; if abs_x < 1e-7 { result = 1.0; } else if abs_x < 3.0 { let pi = core::f32::consts::PI; let pi_x = pi * x; let pi_x_over_3 = pi_x / 3.0; result = (f32::sin(pi_x) * f32::sin(pi_x_over_3)) / (pi_x * pi_x_over_3); } result } pub(crate) fn interpolate_lanczos3_launch( input: CubeTensor, output: CubeTensor, align_corners: bool, ) -> CubeTensor { let vector_size = max_vector_size(&input); let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, vector_size); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); interpolate_lanczos3_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_tensor_arg(), output.clone().into_tensor_arg(), out_shape, out_layout, align_corners, output.dtype.into(), ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/mod.rs ================================================ mod base; mod bicubic; mod bilinear; mod lanczos3; mod nearest; mod nearest_backward; pub use base::*; ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/nearest.rs ================================================ use cubecl::std::{ FastDivmod, tensor::layout::{linear::LinearLayout, *}, }; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ CubeRuntime, kernel::utils::{address_type, linear_layout, shape_divmod}, ops::max_vector_size, tensor::CubeTensor, }; #[cube(launch_unchecked, address_type = "dynamic")] fn interpolate_nearest_kernel( input: &Tensor>, output: &mut Tensor>, shape_out: Sequence>, out_layout: LinearLayout, #[define(F)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let vector_size = input.vector_size(); let out_idx = out_layout.to_source_pos(ABSOLUTE_POS); let out_pos = ABSOLUTE_POS * vector_size; let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32); let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32); let (rem, c) = shape_out[3].div_mod(out_pos); let (rem, x) = shape_out[2].div_mod(rem); let (b, y) = shape_out[1].div_mod(rem); let y = y as f32 * (h_in / h_out); let x = x as f32 * (w_in / w_out); let in_idx = b * input.stride(0) + y as usize * input.stride(1) + x as usize * input.stride(2) + c * input.stride(3); output[out_idx] = input[in_idx / vector_size]; } pub(crate) fn interpolate_nearest_launch( input: CubeTensor, output: CubeTensor, ) -> CubeTensor { let client = input.client.clone(); let vector_size = max_vector_size(&input); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); let shape_out = shape_divmod(&output); let out_layout = linear_layout(&output, vector_size); unsafe { interpolate_nearest_kernel::launch_unchecked( &client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_tensor_arg(), output.clone().into_tensor_arg(), shape_out, out_layout, output.dtype.into(), ) }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs ================================================ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubecl::{ num_traits::Zero, std::{ FastDivmod, tensor::layout::{linear::LinearLayout, *}, }, }; use crate::{ CubeRuntime, kernel::utils::{address_type, linear_layout, shape_divmod}, ops::max_vector_size, tensor::CubeTensor, }; #[cube(launch_unchecked, address_type = "dynamic")] fn interpolate_nearest_backward_kernel( grad: &Tensor>, output: &mut Tensor>, shape_out: Sequence>, out_layout: LinearLayout, #[define(F)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { terminate!(); } let vector_size = grad.vector_size(); let out_idx = out_layout.to_source_pos(ABSOLUTE_POS); let out_h = output.shape(1); let out_w = output.shape(2); let grad_h = grad.shape(1); let grad_w = grad.shape(2); let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size); let (rem, out_x) = shape_out[2].div_mod(rem); let (b, out_y) = shape_out[1].div_mod(rem); let grad_y_start = start_index::(out_y, grad_h, out_h); let grad_y_end = end_index::(out_y, grad_h, out_h); let grad_x_start = start_index::(out_x, grad_w, out_w); let grad_x_end = end_index::(out_x, grad_w, out_w); let index_grad_base = b * grad.stride(0) + c * grad.stride(3); let mut sum = Vector::zero(); for grad_y in grad_y_start..grad_y_end { for grad_x in grad_x_start..grad_x_end { let index_grad = index_grad_base + grad_y * grad.stride(1) + grad_x * grad.stride(2); sum += grad[index_grad]; } } output[out_idx] = sum; } #[cube] fn start_index(input_index: usize, output_size: usize, input_size: usize) -> usize { let numerator = F::cast_from(input_index * output_size); let div = (numerator / F::cast_from(input_size)).ceil(); usize::cast_from(div) } #[cube] fn end_index(input_index: usize, output_size: usize, input_size: usize) -> usize { let numerator = F::cast_from((input_index + 1) * output_size); let div = (numerator / F::cast_from(input_size)).ceil(); let index = usize::cast_from(div); clamp_max(index, output_size) } pub(crate) fn interpolate_nearest_backward_launch( out_grad: CubeTensor, output: CubeTensor, ) -> CubeTensor { let vector_size = max_vector_size(&out_grad); let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, vector_size); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&out_grad.client, working_units); let cube_count = calculate_cube_count_elemwise(&out_grad.client, working_units, cube_dim); unsafe { interpolate_nearest_backward_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(out_grad, output), vector_size, out_grad.into_tensor_arg(), output.clone().into_tensor_arg(), out_shape, out_layout, output.dtype.into(), ) }; output } ================================================ FILE: crates/burn-cubecl/src/kernel/mask/base.rs ================================================ use burn_backend::DType; use cubecl::prelude::InputScalar; use super::{MaskFillStrategy, mask_where::MaskWhereStrategy}; use crate::{CubeRuntime, tensor::CubeTensor}; /// Execute the mask fill kernel. pub(crate) fn mask_fill_auto( tensor: CubeTensor, mask: CubeTensor, value: InputScalar, dtype_bool: DType, ) -> CubeTensor { let strategy = if tensor.can_mut() && tensor.is_nonoverlapping() { MaskFillStrategy::Inplace } else { MaskFillStrategy::Readonly }; super::mask_fill(tensor, mask, value, strategy, dtype_bool) } /// Execute the mask where kernel. pub(crate) fn mask_where_auto( tensor: CubeTensor, mask: CubeTensor, value: CubeTensor, dtype_bool: DType, ) -> CubeTensor { let strategy = if tensor.can_mut_broadcast(&value) { MaskWhereStrategy::InplaceLhs } else if value.can_mut_broadcast(&tensor) { MaskWhereStrategy::InplaceRhs } else { MaskWhereStrategy::Readonly }; super::mask_where(tensor, mask, value, strategy, dtype_bool) } ================================================ FILE: crates/burn-cubecl/src/kernel/mask/mask_fill.rs ================================================ use burn_backend::{DType, TensorMetadata}; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; use crate::{ CubeRuntime, kernel::utils::address_type, ops::{max_vector_size_many, numeric::empty_device_dtype}, tensor::CubeTensor, }; #[cube(launch_unchecked, address_type = "dynamic")] fn mask_fill_kernel( input: &LinearView>, mask: &LinearView>, output: &mut LinearView, ReadWrite>, value: InputScalar, #[define(T, B)] _dtypes: [StorageType; 2], ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let mask = Vector::cast_from(mask[ABSOLUTE_POS]); let input = input[ABSOLUTE_POS]; let value = Vector::new(value.get::()); output[ABSOLUTE_POS] = select_many(mask, value, input); } #[derive(Clone, Copy, Debug)] /// Define how to run the mask fill kernel. /// /// # Notes /// /// All assertions should be done before choosing the strategy. pub enum MaskFillStrategy { /// Don't mutate any input. Readonly, /// Reuse the input tensor inplace. Inplace, } /// Execute the mask fill kernel with the given strategy. pub fn mask_fill( input: CubeTensor, mask: CubeTensor, value: InputScalar, strategy: MaskFillStrategy, dtype_bool: DType, ) -> CubeTensor { let ndims = input.meta.num_dims(); let output = match strategy { MaskFillStrategy::Readonly => empty_device_dtype( input.client.clone(), input.device.clone(), input.shape(), input.dtype, ), MaskFillStrategy::Inplace => input.clone(), }; let vector_size = max_vector_size_many(&[&input, &mask], ndims - 1); let working_units = input.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); let out_arg = match strategy { MaskFillStrategy::Readonly => output.clone().into_linear_view(), MaskFillStrategy::Inplace => output.as_linear_view_alias(0), }; let at = address_type!(input, mask, output); let mask = mask.into_linear_view_like(&input); unsafe { mask_fill_kernel::launch_unchecked( &output.client, cube_count, cube_dim, at, vector_size, input.into_linear_view(), mask, out_arg, value, [output.dtype.into(), dtype_bool.into()], ); } output } ================================================ FILE: crates/burn-cubecl/src/kernel/mask/mask_where.rs ================================================ use burn_backend::DType; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; use crate::{ CubeRuntime, kernel::utils::{address_type, broadcast_shape}, ops::{max_vector_size_many, numeric::empty_device_dtype}, tensor::CubeTensor, }; #[cube(launch, address_type = "dynamic")] fn mask_where_kernel( input: &LinearView>, value: &LinearView>, mask: &LinearView>, output: &mut LinearView, ReadWrite>, #[define(T, B)] _dtypes: [StorageType; 2], ) { let pos = ABSOLUTE_POS; if !output.is_in_bounds(pos) { terminate!(); } output[pos] = select_many(Vector::cast_from(mask[pos]), value[pos], input[pos]); } #[derive(Clone, Copy, Debug)] /// Define how to run the mask where kernel. /// /// # Notes /// /// All assertions should be done before choosing the strategy. pub enum MaskWhereStrategy { /// Don't mutate any input. Readonly, /// Reuse the lhs tensor inplace. InplaceLhs, /// Reuse the rhs tensor inplace. InplaceRhs, } /// Execute the mask where kernel with the given strategy. pub fn mask_where( input: CubeTensor, mask: CubeTensor, value: CubeTensor, strategy: MaskWhereStrategy, dtype_bool: DType, ) -> CubeTensor { let vector_size = max_vector_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1); let working_units = input.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); let out_shape = broadcast_shape(&[&input, &mask, &value]); let output = match strategy { MaskWhereStrategy::Readonly => empty_device_dtype( input.client.clone(), input.device.clone(), out_shape, input.dtype, ), MaskWhereStrategy::InplaceLhs => input.clone(), MaskWhereStrategy::InplaceRhs => value.clone(), }; let out = match strategy { MaskWhereStrategy::Readonly => output.clone().into_linear_view(), MaskWhereStrategy::InplaceLhs => output.as_linear_view_alias(0), MaskWhereStrategy::InplaceRhs => output.as_linear_view_alias(1), }; mask_where_kernel::launch( &output.client, cube_count, cube_dim, address_type!(input, value, mask, output), vector_size, input.into_linear_view_like(&output), value.into_linear_view_like(&output), mask.into_linear_view_like(&output), out, [output.dtype.into(), dtype_bool.into()], ); output } ================================================ FILE: crates/burn-cubecl/src/kernel/mask/mod.rs ================================================ mod base; mod mask_fill; mod mask_where; pub(crate) use base::*; pub use mask_fill::*; pub use mask_where::*; ================================================ FILE: crates/burn-cubecl/src/kernel/matmul/base.rs ================================================ use super::init_matmul_output; use crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor}; use burn_backend::{DType, QTensorPrimitive}; use burn_std::QuantLevel; use cubek::matmul::{ definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError}, launch::{MatmulInputBinding, Strategy}, }; #[cfg(feature = "autotune")] use super::matmul_autotune; /// The strategy to be used when launching a matmul kernel. pub enum MatmulStrategy { #[cfg(feature = "autotune")] /// Using autotune to choose the best kernel based on runtime information. Autotune, /// Cube implementation of matmul. Cube, } impl Default for MatmulStrategy { fn default() -> Self { // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] return MatmulStrategy::Autotune; #[cfg(not(feature = "autotune"))] MatmulStrategy::Cube } } /// Launch a matmul kernel using the given strategy. pub fn matmul( lhs: CubeTensor, rhs: CubeTensor, out: Option>, strategy: MatmulStrategy, out_dtype: DType, ) -> Result, MatmulSetupError> { match strategy { MatmulStrategy::Cube => { let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype)); launch_matmul(&Default::default(), lhs, rhs, out.clone())?; Ok(out) } #[cfg(feature = "autotune")] MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)), } } pub(crate) fn launch_matmul_naive( strategy: &Strategy, mut lhs: CubeTensor, mut rhs: CubeTensor, out: CubeTensor, ) -> Result<(), MatmulSetupError> { // Naive has very specific layout requirements for block scaled tensors, so we need to manually // dequantize if it fails to launch normally. This is because naive is assumed to always work. if lhs.qparams.is_some() || rhs.qparams.is_some() { match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) { Err(_) => { if lhs.qparams.is_some() { lhs = dequantize(lhs, out.dtype); } if rhs.qparams.is_some() { rhs = dequantize(rhs, out.dtype); } launch_matmul(strategy, lhs, rhs, out) } Ok(_) => Ok(()), } } else { launch_matmul(strategy, lhs, rhs, out) } } pub(crate) fn launch_matmul( strategy: &Strategy, lhs: CubeTensor, mut rhs: CubeTensor, out: CubeTensor, ) -> Result<(), MatmulSetupError> { let client = &out.client; let lhs_quant_handles = lhs.quantized_handles(); let out_dtype: DType = out.dtype; let (lhs_dtype, lhs_handle) = match lhs_quant_handles { None => { let lhs_dtype = lhs.dtype; ( lhs_dtype, MatmulInputBinding::new(lhs.binding(), lhs_dtype.into()), ) } Some((data, scale)) => { let scheme = *lhs.scheme(); let data_dtype = data.dtype; let scale_dtype = scale.dtype; ( out_dtype, MatmulInputBinding::quantized( data.binding(), scale.binding(), lhs.meta.shape().clone(), scheme, data_dtype.into(), scale_dtype.into(), ), ) } }; let rhs_quant_handles = rhs.quantized_handles(); let (rhs_dtype, rhs_handle) = match rhs_quant_handles { None => ( lhs_dtype, MatmulInputBinding::new(rhs.binding(), lhs_dtype.into()), ), Some((data, scale)) => { // Extremely hacky fix to ensure naive can run in every case if matches!(strategy, Strategy::Naive) && matches!(rhs.scheme().level, QuantLevel::Block(_)) { rhs = dequantize(rhs.clone(), lhs_dtype); let rhs_dtype = rhs.dtype; ( lhs_dtype, MatmulInputBinding::new(rhs.binding(), rhs_dtype.into()), ) } else { let scheme = *rhs.scheme(); let data_dtype = data.dtype; let scale_dtype = scale.dtype; ( out_dtype, MatmulInputBinding::quantized( data.binding(), scale.binding(), rhs.meta.shape().clone(), scheme, data_dtype.into(), scale_dtype.into(), ), ) } } }; let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems { lhs: lhs_dtype.into(), rhs: rhs_dtype.into(), out: out_dtype.into(), }); cubek::matmul::launch::launch_ref( strategy, client, lhs_handle, rhs_handle, out.clone().binding(), &mut dtypes, )?; Ok(()) } ================================================ FILE: crates/burn-cubecl/src/kernel/matmul/mod.rs ================================================ mod base; mod tune; /// Contains utilities for matmul operation pub mod utils; pub use base::*; #[cfg(feature = "autotune")] pub use tune::*; pub use utils::*; ================================================ FILE: crates/burn-cubecl/src/kernel/matmul/tune/base.rs ================================================ use crate::{ CubeRuntime, CubeTuneId, kernel::matmul::{launch_matmul, launch_matmul_naive, utils::init_matmul_output}, tensor::CubeTensor, }; use burn_backend::DType; use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}; use cubek::matmul::{ definition::MatmulKind, launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering}, routines::{ BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs, double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs, simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs, }, }; fn matmul_input_gen( _key: &MatmulAutotuneKey, lhs: &CubeTensor, rhs: &CubeTensor, out: &CubeTensor, ) -> (CubeTensor, CubeTensor, CubeTensor) { (lhs.clone(), rhs.clone(), out.copy()) } /// Executes autotune on matmul operations pub fn matmul_autotune( lhs: CubeTensor, rhs: CubeTensor, out: Option>, out_dtype: DType, ) -> CubeTensor { let output = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype)); let client = lhs.client.clone(); static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 3; const PRIORITY_HIGH: i8 = 2; const PRIORITY_MEDIUM: i8 = 1; const PRIORITY_MIN: i8 = 0; const PRIORITY_NEVER: i8 = -1; let cmma = TuneGroup::::new("cmma", |key| { if matches!( key.analysis.kind, MatmulKind::General // Those variants are just because the unit alternatives aren't very good yet. | MatmulKind::VecMat | MatmulKind::MatVec ) { PRIORITY_HIGH } else { PRIORITY_MEDIUM } }); let mma = TuneGroup::::new("mma", |key| { if matches!( key.analysis.kind, // General is usually bad, but I think shapes like 16x8196 would be classed as // general and are very good with MMA // Should highly degenerated matrices that aren't VecMat have their own class? MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec ) { PRIORITY_HIGH } else { PRIORITY_MEDIUM } }); let unit = TuneGroup::::new("unit", |key| { if !matches!(key.analysis.kind, MatmulKind::General) || matches!(key.analysis.scale_global, MatmulGlobalScale::Small) { PRIORITY_HIGH } else { PRIORITY_MIN } }); let tma = TuneGroup::::new("tma", |key| { // For large matmul, we set the max priority to TMA kernels, higher than any other // matmuls, since they are the best kernels no matter what. // // But only when all axis are large. let max_axis = usize::max(key.definition.m, key.definition.n); let max_axis = usize::max(key.definition.k, max_axis); let min_axis = usize::min(key.definition.m, key.definition.n); let min_axis = usize::min(key.definition.k, min_axis); let skewed_factor = max_axis / min_axis; let priority_max = if matches!(key.analysis.kind, MatmulKind::General) && matches!(key.analysis.scale_global, MatmulGlobalScale::Large) && skewed_factor < 4 { PRIORITY_MAX } else { PRIORITY_HIGH }; if key.definition.lhs_stride_factor >= 4 && key.definition.rhs_stride_factor >= 4 { priority_max } else { PRIORITY_NEVER } }); fn double_buffering_priority(key: &MatmulAutotuneKey, max: i8, min: i8) -> i8 { if should_tune_double_buffering(false, key) { max } else { min } } let mut set = TunableSet::new(create_key::, matmul_input_gen::); // First entry should always work, since it is considered the fallback. set = set.with( Tunable::new("matmul_naive", |lhs, rhs, out| { launch_matmul_naive::(&Strategy::Naive, lhs, rhs, out) .map_err(|err| std::format!("{err:?}")) }) .group(&unit, |key| { if matches!(key.analysis.scale_global, MatmulGlobalScale::Small) || matches!(key.analysis.kind, MatmulKind::InnerProduct) { PRIORITY_MAX } else { PRIORITY_MIN } }), ); // Unit VecMat for (strategy, double_buf) in [ ( Strategy::SimpleVecMat(BlueprintStrategy::Inferred(().into())), false, ), ( Strategy::DoubleVecMat(BlueprintStrategy::Inferred(().into())), true, ), ] { set = set.with( Tunable::new(strategy.to_string(), move |lhs, rhs, out| { launch_matmul::(&strategy, lhs, rhs, out) .map_err(|err| std::format!("{err:?}")) }) .group(&unit, move |key| match double_buf { false => PRIORITY_MAX, true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH), }), ); } // Unit matmuls for tile_size in [ TileSizeSelection::MaxTileSize, TileSizeSelection::MinTileSize, ] { for (strategy, double_buf) in [ ( Strategy::SimpleUnit(BlueprintStrategy::Inferred(SimpleUnitSelectionArgs { tile_size, })), false, ), ( Strategy::DoubleUnit(BlueprintStrategy::Inferred(DoubleUnitSelectionArgs { tile_size, })), true, ), ] { set = set.with( Tunable::new(strategy.to_string(), move |lhs, rhs, out| { launch_matmul::(&strategy, lhs, rhs, out) .map_err(|err| format!("{err:?}")) }) .group(&unit, move |key| match double_buf { false => PRIORITY_MAX, true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH), }), ) } } // Accelerated matmuls for (strategy, double_buf, group_extra, tile_group) in [ ( Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: false, })), false, None, &cmma, ), ( Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: false, })), false, None, &mma, ), ( Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: true, })), false, None, &cmma, ), ( Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: true, })), false, None, &mma, ), ( Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs { partition_k: Some(2), row_count: Some(4), rows_per_plane: Some(2), })), true, None, &cmma, ), ( Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs { partition_k: Some(2), row_count: Some(4), rows_per_plane: Some(2), })), true, None, &mma, ), ( Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs { partition_k: Some(2), row_count: Some(8), rows_per_plane: Some(2), })), true, None, &cmma, ), ( Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs { partition_k: Some(2), row_count: Some(8), rows_per_plane: Some(2), })), true, None, &mma, ), ( Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized: false, })), true, None, &cmma, ), ( Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized: false, })), true, None, &mma, ), ( Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized: true, })), true, None, &cmma, ), ( Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized: true, })), true, None, &mma, ), ( Strategy::SpecializedCyclicCmma(BlueprintStrategy::Inferred(().into())), true, None, &cmma, ), ( Strategy::SpecializedCyclicMma(BlueprintStrategy::Inferred(().into())), true, None, &mma, ), ( Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: false, })), false, Some(&tma), &cmma, ), ( Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: false, })), false, Some(&tma), &mma, ), ( Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: true, })), false, Some(&tma), &cmma, ), ( Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs { multi_rows: true, })), false, Some(&tma), &mma, ), ( Strategy::SpecializedTmaCmma(BlueprintStrategy::Inferred(().into())), true, Some(&tma), &cmma, ), ( Strategy::SpecializedTmaMma(BlueprintStrategy::Inferred(().into())), true, Some(&tma), &mma, ), ] { let priority_within_group = |key: &MatmulAutotuneKey, double_buf: bool| match double_buf { false => PRIORITY_MAX, true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH), }; let mut tunable = Tunable::new(strategy.to_string(), move |lhs, rhs, out| { launch_matmul::(&strategy, lhs, rhs, out).map_err(|err| format!("{err:?}")) }); // tile group tunable = tunable.group(tile_group, move |key| { priority_within_group(key, double_buf) }); // extra group if let Some(group) = group_extra { tunable = tunable.group(group, move |key| priority_within_group(key, double_buf)); } set = set.with(tunable); } set }); TUNER.execute( &CubeTuneId::new(&lhs.client, &lhs.device), &client, tunables, (lhs, rhs, output.clone()), ); output } fn create_key( lhs: &CubeTensor, rhs: &CubeTensor, out: &CubeTensor, ) -> MatmulAutotuneKey { MatmulAutotuneKey::generate( &lhs.client, lhs.meta.shape(), rhs.meta.shape(), lhs.meta.strides(), rhs.meta.strides(), lhs.dtype.into(), rhs.dtype.into(), out.dtype.into(), lhs.try_scheme(), rhs.try_scheme(), ) } ================================================ FILE: crates/burn-cubecl/src/kernel/matmul/tune/mod.rs ================================================ #[cfg(feature = "autotune")] mod base; #[cfg(feature = "autotune")] pub use base::matmul_autotune; ================================================ FILE: crates/burn-cubecl/src/kernel/matmul/utils.rs ================================================ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::{DType, calculate_matmul_output}; /// Creates an empty output tensor with matmul output shape pub fn init_matmul_output( lhs: &CubeTensor, rhs: &CubeTensor, dtype: DType, ) -> CubeTensor { empty_device_dtype( lhs.client.clone(), lhs.device.clone(), calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(), dtype, ) } ================================================ FILE: crates/burn-cubecl/src/kernel/mod.rs ================================================ mod binary; mod binary_float; mod binary_int; mod cast; mod clamp; mod comparison; mod contiguous; mod cross; mod index; mod mask; mod unary_float; mod unary_int; mod unary_numeric; pub(crate) use binary::*; pub(crate) use binary_float::*; pub(crate) use binary_int::*; pub use cast::*; pub use contiguous::*; pub(crate) use cross::*; pub use mask::*; pub(crate) use unary_float::*; pub(crate) use unary_int::*; pub(crate) use unary_numeric::*; pub use crate::cubecl::prelude::KernelMetadata; /// Attention kernels pub mod attention; /// Convolution kernels pub mod conv; /// Grid sampling kernels pub mod grid_sample; /// Interpolation kernels pub mod interpolate; /// Matmul kernels pub mod matmul; /// Pooling kernels pub mod pool; /// Pseudo-random number generator kernels pub mod prng; /// Quantization operations pub mod quantization; /// Reduction algorithms pub mod reduce; pub(crate) use clamp::*; pub(crate) use comparison::*; pub use index::*; pub(crate) mod utils; ================================================ FILE: crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs ================================================ use crate::{ CubeRuntime, kernel::{ into_contiguous_aligned, pool::pool2d::{Position, view4d}, utils::{address_type, decompose_linear, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::Shape; use cubecl::{ calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::{FastDivmod, tensor::View}, }; #[cube(launch, address_type = "dynamic")] fn adaptive_avg_pool2d_direct( input: &Tensor>, output: &mut View, Position, ReadWrite>, out_shape: Sequence>, working_units: usize, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= working_units { terminate!(); } let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape); let [b, oh, ow, c] = *pos else { unreachable!() }; let (_, out_h, out_w, _) = output.shape(); let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2)); let (in_h, in_w) = (input.shape(1), input.shape(2)); let ih_start = start_index(oh, out_h, in_h); let ih_end = end_index(oh, out_h, in_h); let iw_start = start_index(ow, out_w, in_w); let iw_end = end_index(ow, out_w, in_w); let mut sum = Vector::zero(); let index_input_base = b * input.stride(0) + c * input.stride(3); for ih in ih_start..ih_end { let index_input_2 = ih * in_stride_h; for iw in iw_start..iw_end { let index_input_3 = iw * in_stride_w; let index_input = index_input_base + index_input_2 + index_input_3; sum += input[index_input / input.vector_size()]; } } let num_ih = ih_end - ih_start; let num_iw = iw_end - iw_start; output[(b, oh, ow, c)] = sum / Vector::cast_from(num_ih * num_iw); } #[cube] fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { (output_size_index * input_size) / output_size } #[cube] fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { let index = (output_size_index + 1) * input_size; let index = index.div_ceil(output_size); if input_size < index { input_size } else { index } } pub(crate) fn adaptive_avg_pool2d( input: CubeTensor, output_size: [usize; 2], ) -> CubeTensor { let [batch_size, channels, _, _] = input.meta.shape().dims(); let input = into_contiguous_aligned(permute_nchw_to_nhwc(input)); let vector_size = max_vector_size(&input); let output_shape = Shape::new([batch_size, output_size[0], output_size[1], channels]); let num_elems: usize = output_shape.num_elements(); let output = empty_device_dtype( input.client.clone(), input.device.clone(), output_shape, input.dtype, ); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); adaptive_avg_pool2d_direct::launch( &output.client, cube_count, cube_dim, address_type!(input, output), vector_size, input.into_tensor_arg(), view4d(output.clone(), vector_size), shape_divmod(&output), working_units, output.dtype.into(), ); permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs ================================================ use crate::{ CubeRuntime, kernel::{ into_contiguous_aligned, pool::pool2d::{Position, view4d}, utils::{address_type, decompose_linear, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::Shape; use cubecl::{ calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::{FastDivmod, tensor::View}, }; #[cube(launch, address_type = "dynamic")] fn adaptive_avg_pool2d_backward_direct( grad: &Tensor>, output: &mut View, Position, ReadWrite>, out_shape: Sequence>, working_units: usize, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= working_units { terminate!(); } let (_, out_h, out_w, _) = output.shape(); let (grad_stride_h, grad_stride_w) = (grad.stride(1), grad.stride(2)); let (grad_h, grad_w) = (grad.shape(1), grad.shape(2)); let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape); let [b, ih, iw, c] = *pos else { unreachable!() }; let oh_start = start_index(ih, out_h, grad_h); let oh_end = end_index(ih, out_h, grad_h); let ow_start = start_index(iw, out_w, grad_w); let ow_end = end_index(iw, out_w, grad_w); let mut grad_acc = Vector::zero(); let index_base = b * grad.stride(0) + (c * grad.stride(3)); for oh in oh_start..oh_end { let ih_start = start_index(oh, grad_h, out_h); let ih_end = end_index(oh, grad_h, out_h); if ih >= ih_start && ih < ih_end { for ow in ow_start..ow_end { let iw_start = start_index(ow, grad_w, out_w); let iw_end = end_index(ow, grad_w, out_w); if iw >= iw_start && iw < iw_end { let num_ih = ih_end - ih_start; let num_iw = iw_end - iw_start; let index = index_base + (oh * grad_stride_h) + (ow * grad_stride_w); grad_acc += grad[index / grad.vector_size()] / Vector::cast_from(num_iw * num_ih); } } } } output[(b, ih, iw, c)] = grad_acc; } #[cube] fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { (output_size_index * input_size) / output_size } #[cube] fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { let index = (output_size_index + 1) * input_size; let index = index.div_ceil(output_size); if input_size < index { input_size } else { index } } pub(crate) fn adaptive_avg_pool2d_backward( x: CubeTensor, out_grad: CubeTensor, ) -> CubeTensor { let [batches, channels, height, width] = x.meta.shape().dims(); let out_grad = into_contiguous_aligned(permute_nchw_to_nhwc(out_grad)); let vector_size = max_vector_size(&out_grad); let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); let num_elems = output.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); adaptive_avg_pool2d_backward_direct::launch( &output.client, cube_count, cube_dim, address_type!(out_grad, output), vector_size, out_grad.into_tensor_arg(), view4d(output.clone(), vector_size), shape_divmod(&output), working_units, output.dtype.into(), ); permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs ================================================ use super::pool2d::{ Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct, }; use crate::{ CubeRuntime, kernel::{ into_contiguous_aligned, pool::pool2d::{Position, view4d}, utils::{address_type, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::{Shape, ops::conv::calculate_pool_output_size}; use cubecl::{CubeDim, calculate_cube_count_elemwise, num_traits::Zero}; use cubecl::{prelude::*, std::tensor::View}; struct AvgPoolStrategy; impl Pool2dDirectStrategyFamily for AvgPoolStrategy { type Indices = (); type Config = AvgPoolStrategyConfig; type Pool2d = Self; } #[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)] pub struct AvgPoolStrategyConfig { count_include_pad: bool, /// Total padded height (input_height + 2 * padding_0) padded_h: u32, /// Total padded width (input_width + 2 * padding_1) padded_w: u32, } #[cube] impl Pool2dDirectStrategy for AvgPoolStrategy { type Accumulator = (Vector, u32); type Config = AvgPoolStrategyConfig; type Indices = (); fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator { let sum = Vector::zero(); // Count will be set dynamically: either by accumulate (count_include_pad=false) // or by set_padded_count (count_include_pad=true) let count = 0u32; (sum, count) } fn accumulate( #[comptime] config: &Self::Config, accumulator: &mut Self::Accumulator, _index: usize, result: Vector, ) { let (sum, count) = accumulator; // Only count valid positions when count_include_pad=false if comptime![!config.count_include_pad] { *count += 1; } *sum += result; } fn count_position( #[comptime] config: &Self::Config, accumulator: &mut Self::Accumulator, ih: u32, iw: u32, ) { // When count_include_pad=true, count positions within padded bounds // (excludes ceil_mode extensions beyond the padded input) if comptime![config.count_include_pad] && ih < config.padded_h && iw < config.padded_w { let (_sum, count) = accumulator; *count += 1; } } fn store( #[comptime] _config: &Self::Config, position: Position, output: &mut View, Position, ReadWrite>, _output_indices: &mut (), accumulator: Self::Accumulator, ) { let (sum, count) = accumulator; output[position] = sum / Vector::cast_from(count); } } pub(crate) fn avg_pool2d( x: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> CubeTensor { let [batch_size, channels, in_h, in_w] = x.meta.shape().dims(); let dilation = 1; let size_0 = calculate_pool_output_size( kernel_size[0], stride[0], padding[0], dilation, in_h, ceil_mode, ); let size_1 = calculate_pool_output_size( kernel_size[1], stride[1], padding[1], dilation, in_w, ceil_mode, ); // Padded dimensions (for count_include_pad with ceil_mode) let padded_0 = in_h + 2 * padding[0]; let padded_1 = in_w + 2 * padding[1]; let x = into_contiguous_aligned(permute_nchw_to_nhwc(x)); let vector_size = max_vector_size(&x); let shape_out = Shape::new([batch_size, size_0, size_1, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); pool2d_direct::launch::( &output.client, cube_count, cube_dim, address_type!(x, output), vector_size, x.into_tensor_arg(), view4d(output.clone(), vector_size), (), shape_divmod(&output), working_units, Pool2dDirectArgsLaunch::new( stride[0] as u32, stride[1] as u32, dilation as u32, dilation as u32, padding[0] as u32, padding[1] as u32, ), (kernel_size[0] as u32, kernel_size[1] as u32), AvgPoolStrategyConfig { count_include_pad, padded_h: padded_0 as u32, padded_w: padded_1 as u32, }, output.dtype.into(), ); permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs ================================================ use crate::{ CubeRuntime, kernel::{ pool::pool2d::{Position, view4d}, utils::{address_type, decompose_linear, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::Shape; use cubecl::{ calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::{FastDivmod, tensor::View}, }; #[derive(CubeLaunch, CubeType)] pub(crate) struct PoolBackwardArgs { pub stride_0: i32, pub stride_1: i32, pub dilation_0: i32, pub dilation_1: i32, pub padding_0: i32, pub padding_1: i32, } #[cube(launch_unchecked, address_type = "dynamic")] fn avg_pool2d_backward_kernel( grad: &Tensor>, output: &mut View, Position, ReadWrite>, out_shape: Sequence>, working_units: usize, args: &PoolBackwardArgs, #[comptime] kernel_size_0: i32, #[comptime] kernel_size_1: i32, #[comptime] count_include_pad: bool, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= working_units { terminate!(); } let vector_size = grad.vector_size(); let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape); let [batch, ih, iw, channel] = *pos else { unreachable!() }; let mut grad_acc = Vector::zero(); let (oh_start, oh_end, ow_start, ow_end) = loop_ranges( ih as i32, iw as i32, grad.shape(1) as u32, grad.shape(2) as u32, args, kernel_size_0, kernel_size_1, ); let padding_0 = args.padding_0 as u32; let padding_1 = args.padding_1 as u32; let stride_0 = args.stride_0 as u32; let stride_1 = args.stride_1 as u32; let kernel_size_0 = comptime![kernel_size_0 as u32]; let kernel_size_1 = comptime![kernel_size_1 as u32]; let index_base = batch * grad.stride(0) + channel * grad.stride(3); let border_bottom = output.shape().1 as u32 + padding_0; let border_right = output.shape().2 as u32 + padding_1; let begin_h = ih as u32 + padding_0; let begin_w = iw as u32 + padding_1; for oh in oh_start..oh_end { let ih_start = oh * stride_0; let ih_end = clamp_max(ih_start + kernel_size_0, border_bottom); let ih_start = clamp_min(ih_start, padding_0); if begin_h >= ih_start && (ih as u32) < ih_end { for ow in ow_start..ow_end { let index = index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2); let iw_start = ow * stride_1; let iw_end = clamp_max(iw_start + kernel_size_1, border_right); let iw_start = clamp_min(iw_start, padding_1); if begin_w >= iw_start && (iw as u32) < iw_end { if count_include_pad { grad_acc += grad[index / vector_size] / Vector::cast_from(kernel_size_0 * kernel_size_1); } else { let ih_diff = ih_end - ih_start; let iw_diff = iw_end - iw_start; let count = Vector::cast_from(ih_diff * iw_diff); grad_acc += grad[index / vector_size] / count; } } } } } output[(batch, ih, iw, channel)] = grad_acc; } #[cube] fn loop_ranges( ih: i32, iw: i32, grad_h: u32, grad_w: u32, args: &PoolBackwardArgs, #[comptime] kernel_size_0: i32, #[comptime] kernel_size_1: i32, ) -> (u32, u32, u32, u32) { let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0; let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1; let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32; let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32; let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1; let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1; (oh_start, oh_end, ow_start, ow_end) } pub(crate) fn avg_pool2d_backward( x: CubeTensor, grad: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, _ceil_mode: bool, ) -> CubeTensor { let [batches, channels, height, width] = x.meta.shape().dims(); let grad = permute_nchw_to_nhwc(grad); let vector_size = if x.meta.strides()[3] == grad.meta.strides()[3] { max_vector_size(&x) } else { 1 }; let dilation = 1; let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); unsafe { avg_pool2d_backward_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(grad, output), vector_size, grad.into_tensor_arg(), view4d(output.clone(), vector_size), shape_divmod(&output), working_units, PoolBackwardArgsLaunch::new( stride[0] as i32, stride[1] as i32, dilation, dilation, padding[0] as i32, padding[1] as i32, ), kernel_size[0] as i32, kernel_size[1] as i32, count_include_pad, output.dtype.into(), ) }; permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/max_pool2d.rs ================================================ use super::pool2d::{ Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct, }; use crate::{ CubeRuntime, kernel::{ into_contiguous_aligned, pool::pool2d::{Position, view4d}, utils::{address_type, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::{DType, Shape, ops::conv::calculate_pool_output_size}; use cubecl::{ CubeDim, calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::tensor::View, }; struct MaxPoolStrategy; struct MaxPoolWithIndicesStrategy; impl Pool2dDirectStrategyFamily for MaxPoolStrategy { type Indices = (); type Config = (); type Pool2d = Self; } impl Pool2dDirectStrategyFamily for MaxPoolWithIndicesStrategy { type Indices = View, Position, ReadWrite>; type Config = (); type Pool2d = Self; } #[cube] impl Pool2dDirectStrategy for MaxPoolStrategy { type Accumulator = Vector; type Config = (); type Indices = (); fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator { Vector::new(T::min_value()) } fn accumulate( #[comptime] _config: &Self::Config, accumulator: &mut Self::Accumulator, _index: VectorSize, result: Vector, ) { *accumulator = max(*accumulator, result); } fn count_position( #[comptime] _config: &Self::Config, _accumulator: &mut Self::Accumulator, _ih: u32, _iw: u32, ) { } fn store( #[comptime] _config: &Self::Config, position: Position, output: &mut View, Position, ReadWrite>, _output_indices: &mut (), accumulator: Self::Accumulator, ) { output[position] = accumulator; } } #[cube] impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy { type Accumulator = (Vector, Vector); type Config = (); type Indices = View, Position, ReadWrite>; fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator { let val = Vector::new(T::min_value()); let idx = Vector::zero(); (val, idx) } fn accumulate( #[comptime] _config: &Self::Config, accumulator: &mut Self::Accumulator, index: usize, result: Vector, ) { let indices = Vector::cast_from(index); accumulator.1 = select_many(result.greater_than(accumulator.0), indices, accumulator.1); accumulator.0 = max(result, accumulator.0); } fn count_position( #[comptime] _config: &Self::Config, _accumulator: &mut Self::Accumulator, _ih: u32, _iw: u32, ) { } fn store( #[comptime] _config: &Self::Config, position: Position, output: &mut View, Position, ReadWrite>, output_indices: &mut View, Position, ReadWrite>, accumulator: Self::Accumulator, ) { output[position] = accumulator.0; output_indices[position] = accumulator.1; } } pub(crate) fn max_pool2d( x: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> CubeTensor { let [batch_size, channels, height, width] = x.meta.shape().dims(); let size_0 = calculate_pool_output_size( kernel_size[0], stride[0], padding[0], dilation[0], height, ceil_mode, ); let size_1 = calculate_pool_output_size( kernel_size[1], stride[1], padding[1], dilation[1], width, ceil_mode, ); let x = into_contiguous_aligned(permute_nchw_to_nhwc(x)); let vector_size = max_vector_size(&x); let shape_out = Shape::new([batch_size, size_0, size_1, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); pool2d_direct::launch::( &output.client, cube_count, cube_dim, address_type!(x, output), vector_size, x.into_tensor_arg(), view4d(output.clone(), vector_size), (), shape_divmod(&output), working_units, Pool2dDirectArgsLaunch::new( stride[0] as u32, stride[1] as u32, dilation[0] as u32, dilation[1] as u32, padding[0] as u32, padding[1] as u32, ), (kernel_size[0] as u32, kernel_size[1] as u32), (), output.dtype.into(), ); permute_nhwc_to_nchw(output) } pub(crate) fn max_pool2d_with_indices( x: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, dtype_indices: DType, ) -> (CubeTensor, CubeTensor) { let [batch_size, channels, size_0, size_1] = x.meta.shape().dims(); let size_0 = calculate_pool_output_size( kernel_size[0], stride[0], padding[0], dilation[0], size_0, ceil_mode, ); let size_1 = calculate_pool_output_size( kernel_size[1], stride[1], padding[1], dilation[1], size_1, ceil_mode, ); let x = into_contiguous_aligned(permute_nchw_to_nhwc(x)); let vector_size = max_vector_size(&x); let shape_out = Shape::new([batch_size, size_0, size_1, channels]); let output = empty_device_dtype( x.client.clone(), x.device.clone(), shape_out.clone(), x.dtype, ); let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); pool2d_direct::launch::( &output.client, cube_count, cube_dim, address_type!(x, output, indices), vector_size, x.into_tensor_arg(), view4d(output.clone(), vector_size), view4d(indices.clone(), vector_size), shape_divmod(&output), working_units, Pool2dDirectArgsLaunch::new( stride[0] as u32, stride[1] as u32, dilation[0] as u32, dilation[1] as u32, padding[0] as u32, padding[1] as u32, ), (kernel_size[0] as u32, kernel_size[1] as u32), (), output.dtype.into(), ); let output = permute_nhwc_to_nchw(output); let indices = permute_nhwc_to_nchw(indices); (output, indices) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs ================================================ use crate::{ CubeRuntime, kernel::{ into_contiguous_aligned, utils::{address_type, decompose_linear, shape_divmod}, }, ops::{ max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw, }, tensor::CubeTensor, }; use burn_backend::Shape; use cubecl::{calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::FastDivmod}; use super::{PoolBackwardArgs, PoolBackwardArgsLaunch}; #[cube(launch_unchecked, address_type = "dynamic")] fn max_pool2d_with_indices_backward_kernel( grad: &Tensor>, indices: &Tensor>, output: &mut Tensor>, out_shape: Sequence>, working_units: usize, args: &PoolBackwardArgs, #[comptime] kernel_size_0: i32, #[comptime] kernel_size_1: i32, #[define(E, I)] _dtypes: [StorageType; 2], ) { if ABSOLUTE_POS >= working_units { terminate!(); } let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape); let [batch, ih, iw, channel] = *pos else { unreachable!() }; let vector_size = grad.vector_size(); let index_current = ih * output.shape(2) + iw; let (oh_start, oh_end, ow_start, ow_end) = loop_ranges( ih as i32, iw as i32, grad.shape(1) as u32, grad.shape(2) as u32, args, kernel_size_0, kernel_size_1, ); let mut grad_acc = Vector::zero(); let grad_idx_base = batch * grad.stride(0) + channel * grad.stride(3); let ind_idx_base = batch * indices.stride(0) + channel * indices.stride(3); for oh in oh_start..oh_end { for ow in ow_start..ow_end { let grad_index = grad_idx_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2); let indices_index = ind_idx_base + oh as usize * indices.stride(1) + ow as usize * indices.stride(2); let index_max = Vector::::cast_from(indices[indices_index / vector_size]); grad_acc += select_many( index_max.equal(Vector::cast_from(index_current)), grad[grad_index / vector_size], Vector::zero(), ); } } let index_output = batch * output.stride(0) + ih * output.stride(1) + iw * output.stride(2) + channel * output.stride(3); output[index_output / output.vector_size()] = grad_acc; } #[cube] fn loop_ranges( ih: i32, iw: i32, grad_h: u32, grad_w: u32, args: &PoolBackwardArgs, #[comptime] kernel_size_0: i32, #[comptime] kernel_size_1: i32, ) -> (u32, u32, u32, u32) { let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0; let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1; let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32; let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32; let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1; let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1; (oh_start, oh_end, ow_start, ow_end) } #[allow(clippy::too_many_arguments)] pub(crate) fn max_pool2d_with_indices_backward( x: CubeTensor, grad: CubeTensor, indices: CubeTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], _ceil_mode: bool, ) -> CubeTensor { let [batches, channels, height, width] = x.meta.shape().dims(); let grad = into_contiguous_aligned(permute_nchw_to_nhwc(grad)); let indices = into_contiguous_aligned(permute_nchw_to_nhwc(indices)); let vector_size = if grad.meta.strides()[3] == indices.meta.strides()[3] { max_vector_size(&grad) } else { 1 }; let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); let working_units = output.meta.num_elements() / vector_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); let indices_dtype = indices.dtype; let x_dtype = x.dtype; unsafe { max_pool2d_with_indices_backward_kernel::launch_unchecked( &output.client, cube_count, cube_dim, address_type!(grad, indices, output), vector_size, grad.into_tensor_arg(), indices.into_tensor_arg(), output.clone().into_tensor_arg(), shape_divmod(&output), working_units, PoolBackwardArgsLaunch::new( stride[0] as i32, stride[1] as i32, dilation[0] as i32, dilation[1] as i32, padding[0] as i32, padding[1] as i32, ), kernel_size[0] as i32, kernel_size[1] as i32, [x_dtype.into(), indices_dtype.into()], ) }; permute_nhwc_to_nchw(output) } ================================================ FILE: crates/burn-cubecl/src/kernel/pool/mod.rs ================================================ mod adaptive_avg_pool2d; mod adaptive_avg_pool2d_backward; mod avg_pool2d; mod avg_pool2d_backward; mod max_pool2d; mod max_pool2d_backward; pub(super) mod pool2d; pub(crate) use adaptive_avg_pool2d::*; pub(crate) use adaptive_avg_pool2d_backward::*; pub(crate) use avg_pool2d::*; pub(crate) use avg_pool2d_backward::*; pub(crate) use max_pool2d::*; pub(crate) use max_pool2d_backward::*; ================================================ FILE: crates/burn-cubecl/src/kernel/pool/pool2d.rs ================================================ use core::hash::Hash; use cubecl::{ prelude::*, std::{ FastDivmod, tensor::{ View, launch::ViewArg, layout::fixed_dim::{FixedDimLayout, FixedDimLayoutLaunch}, }, }, }; use crate::{CubeRuntime, kernel::utils::decompose_linear, tensor::CubeTensor}; pub trait Pool2dDirectStrategyFamily: Send + Sync + 'static { type Indices: LaunchArg; type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq; type Pool2d: Pool2dDirectStrategy>; } pub(super) type Position = (usize, usize, usize, usize); #[cube] pub(crate) trait Pool2dDirectStrategy: Send + Sync + 'static { type Accumulator: CubeType; type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq; type Indices: LaunchArg; fn initialize(#[comptime] config: &Self::Config) -> Self::Accumulator; fn accumulate( #[comptime] config: &Self::Config, accumulator: &mut Self::Accumulator, index: usize, result: Vector, ); /// Count a position within the kernel window (for avg_pool count_include_pad). /// Called for each position in the kernel window with the current ih/iw coordinates. /// Only avg_pool uses this; max_pool implements as no-op. fn count_position( #[comptime] config: &Self::Config, accumulator: &mut Self::Accumulator, ih: u32, iw: u32, ); fn store( #[comptime] config: &Self::Config, position: Position, output: &mut View, Position, ReadWrite>, output_indices: &mut Self::Indices, accumulator: Self::Accumulator, ); } #[derive(CubeLaunch, CubeType)] pub struct Pool2dDirectArgs { pub strides_0: u32, pub strides_1: u32, pub dilation_0: u32, pub dilation_1: u32, pub padding_0: u32, pub padding_1: u32, } #[cube(launch, address_type = "dynamic")] pub fn pool2d_direct( input: &Tensor>, output: &mut View, Position, ReadWrite>, indices: &mut S::Indices, out_shape: Sequence>, working_units: usize, args: &Pool2dDirectArgs, #[comptime] kernel_size: (u32, u32), #[comptime] config: &S::Config, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= working_units { terminate!(); } let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape); let [b, oh, ow, c] = *pos else { unreachable!() }; let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2)); let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32); let mut accumulator = S::Pool2d::::initialize(config); let in_b_off = b * input.stride(0); let in_c_off = c * input.stride(3); let border_bottom = in_h + args.padding_0; let border_right = in_w + args.padding_1; for kh in 0..kernel_size.0 { let ih = oh as u32 * args.strides_0 + kh * args.dilation_0; let within_padding_h = ih >= args.padding_0 && ih < border_bottom; for kw in 0..kernel_size.1 { let iw = ow as u32 * args.strides_1 + kw * args.dilation_1; let within_padding_w = iw >= args.padding_1 && iw < border_right; // Let strategy handle position counting (only used by avg_pool) S::Pool2d::::count_position(config, &mut accumulator, ih, iw); // Only accumulate values from valid input positions if within_padding_h && within_padding_w { let ih_pad = ih - args.padding_0; let iw_pad = iw - args.padding_1; let in_h_off = ih_pad as usize * in_stride_h; let in_w_off = iw_pad as usize * in_stride_w; let index_input = in_b_off + in_c_off + in_h_off + in_w_off; S::Pool2d::::accumulate( config, &mut accumulator, ih_pad as usize * in_w as usize + iw_pad as usize, input[index_input / input.vector_size()], ); } } } S::Pool2d::::store(config, (b, oh, ow, c), output, indices, accumulator); } pub(super) fn view4d( tensor: CubeTensor, vector_size: VectorSize, ) -> ViewArg { let shape = tensor.meta.shape(); let shape = (shape[0], shape[1], shape[2], shape[3]); let binding = tensor.binding(); let layout = FixedDimLayoutLaunch::::from_shape_handle_unchecked( &binding, shape, vector_size, ); let buffer = binding.into_tensor_arg(); ViewArg::new_tensor::>(buffer, layout) } ================================================ FILE: crates/burn-cubecl/src/kernel/prng/bernoulli.rs ================================================ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::{DType, Shape}; /// Pseudo-random generator with bernoulli distribution pub fn random_bernoulli( shape: Shape, device: &R::Device, probability: f32, dtype: DType, ) -> CubeTensor { let client = R::client(device); let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype); cubek::random::random_bernoulli(&client, probability, output.clone().binding(), dtype.into()) .expect("Kernel to never fail"); output } ================================================ FILE: crates/burn-cubecl/src/kernel/prng/mod.rs ================================================ mod bernoulli; mod normal; mod uniform; pub use bernoulli::*; pub use normal::*; pub use uniform::*; ================================================ FILE: crates/burn-cubecl/src/kernel/prng/normal.rs ================================================ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::{DType, Shape}; /// Pseudo-random generator with uniform distribution pub fn random_normal( shape: Shape, device: &R::Device, mean: f32, std: f32, dtype: DType, ) -> CubeTensor { let client = R::client(device); let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype); cubek::random::random_normal(&client, mean, std, output.clone().binding(), dtype.into()) .expect("Kernel to never fail"); output } ================================================ FILE: crates/burn-cubecl/src/kernel/prng/uniform.rs ================================================ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::{DType, Shape, TensorMetadata}; /// Pseudo-random generator with uniform distribution pub fn random_uniform( shape: Shape, device: &R::Device, lower_bound: f32, upper_bound: f32, dtype: DType, ) -> CubeTensor { let client = R::client(device); let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype); cubek::random::random_uniform( &client, lower_bound, upper_bound, output.clone().binding(), dtype.into(), ) .expect("Kernel to never fail"); output } /// Pseudo-random generator for uniform distribution, based on /// another tensor. pub fn random_like_uniform( tensor: &CubeTensor, lower_bound: f32, upper_bound: f32, dtype: DType, ) -> CubeTensor { random_uniform( tensor.shape(), &tensor.device, lower_bound, upper_bound, dtype, ) } ================================================ FILE: crates/burn-cubecl/src/kernel/quantization/dequantize.rs ================================================ use crate::tensor::CubeTensor; use crate::{CubeRuntime, ops::numeric::empty_device_dtype}; use burn_backend::{DType, TensorMetadata}; /// Convert the tensor back to a higher precision data type. pub fn dequantize(tensor: CubeTensor, dtype: DType) -> CubeTensor where R: CubeRuntime, { let scheme = match tensor.dtype { DType::QFloat(scheme) => scheme, _ => return tensor, }; let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), dtype, ); let (values, params) = tensor.quantized_handles().unwrap(); cubek::quantization::dequantize::launch_ref( &output.client, values.binding(), output.clone().binding(), params.binding(), &scheme, dtype.into(), ) .expect("Kernel to never fail"); output } ================================================ FILE: crates/burn-cubecl/src/kernel/quantization/mod.rs ================================================ mod dequantize; mod quantize; pub use dequantize::*; pub use quantize::*; ================================================ FILE: crates/burn-cubecl/src/kernel/quantization/quantize.rs ================================================ use crate::CubeRuntime; use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor}; use burn_backend::{TensorMetadata, quantization::QuantScheme}; /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( tensor: CubeTensor, scheme: &QuantScheme, scale: CubeTensor, ) -> CubeTensor where R: CubeRuntime, { let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device); let (out_values, out_params) = output.clone().quantized_handles().unwrap(); let dtype = tensor.dtype; cubek::quantization::quantize::launch_ref( &output.client, tensor.binding(), out_values.binding(), scale.binding(), out_params.binding(), scheme, dtype.into(), ) .expect("Kernel to never fail"); output } ================================================ FILE: crates/burn-cubecl/src/kernel/reduce/base.rs ================================================ #[cfg(feature = "autotune")] use super::{autotune_reduce, autotune_sum}; use crate::{ CubeRuntime, ops::numeric::{empty_device_contiguous_dtype, zeros_client}, tensor::CubeTensor, }; use burn_backend::{DType, TensorMetadata}; use burn_std::Metadata; use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType}; use cubek::reduce::{ ReduceDtypes, ReduceError, ReduceStrategy, components::instructions::ReduceOperationConfig, launch::{RoutineStrategy, VectorizationStrategy}, routines::{BlueprintStrategy, unit::UnitStrategy}, shared_sum, }; use serde::{Deserialize, Serialize}; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] /// Autotune key representative of sum versions pub struct SumAutotuneKey { /// The type of the tensor dtype: burn_backend::DType, /// The anchored length of the tensor #[autotune(anchor)] length: usize, } /// Check if the client supports atomic add for the given element type. fn supports_atomic_add(client: &ComputeClient, dtype: DType) -> bool { client .properties() .type_usage(StorageType::Atomic(dtype.into())) .contains(TypeUsage::AtomicAdd) } /// [Sum](sum) with fallback when `client` doesn't support atomic add for the type `E`. pub fn sum_fallback( tensor: CubeTensor, mut strategy: SumStrategy, ) -> Result, ReduceError> { // Early check before creating output and fallback if matches!(strategy, SumStrategy::OneShot(_)) && !supports_atomic_add(&tensor.client, tensor.dtype) { strategy = SumStrategy::Chained(Default::default()); } sum(tensor, strategy) } /// Specialize reduce function to compute the sum of all elements of the `input` tensor and return /// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`. /// /// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction. /// /// Return an error if the `client` doesn't support atomic add for the type `E`. pub fn sum( tensor: CubeTensor, strategy: SumStrategy, ) -> Result, ReduceError> { let client = tensor.client.clone(); let device = tensor.device.clone(); match strategy { SumStrategy::OneShot(cube_count) => { let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype); let dtype = tensor.dtype; shared_sum::( &client, tensor.binding(), output.clone().binding(), cube_count, dtype.into(), )?; Ok(output) } SumStrategy::Chained(strategy) => { reduce::(tensor, None, strategy, ReduceOperationConfig::Sum) } #[cfg(feature = "autotune")] SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), } } /// Select a strategy to perform a sum. pub enum SumStrategy { /// Run a single kernel with many cubes working in parallel to sum all elements. /// The provided value is the number of elements summed per unit (up-to-rounding ) OneShot(u32), /// Use multiple kernels Chained(KernelReduceStrategy), /// Use autotune to find the best cube count given the hardware and the input. #[cfg(feature = "autotune")] Autotune, } impl Default for SumStrategy { fn default() -> Self { #[cfg(feature = "autotune")] return Self::Autotune; #[cfg(not(feature = "autotune"))] return Self::OneShot(4); } } /// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. pub fn reduce( mut tensor: CubeTensor, output_dtype: Option, strategy: KernelReduceStrategy, config: ReduceOperationConfig, ) -> Result, cubek::reduce::ReduceError> { // In practice, it looks like starting by the axis with the smallest shape // and going in increasing order lead to the fastest calculation. let sorted_axis = argsort(tensor.meta.shape()); for axis in sorted_axis { tensor = reduce_dim::(tensor, output_dtype, axis, strategy.clone(), config)?; } // reshape to scalar tensor *tensor.meta = Metadata::new([1], [1]); Ok(tensor) } fn argsort(shape: &[usize]) -> Vec { let mut indices = (0..shape.len()).collect::>(); indices.sort_by_key(|&i| &shape[i]); indices } /// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. /// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. pub fn reduce_dim( input: CubeTensor, output_dtype: Option, dim: usize, strategy: KernelReduceStrategy, config: ReduceOperationConfig, ) -> Result, cubek::reduce::ReduceError> { debug_assert!( !matches!( config, ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin ) || output_dtype.is_some(), "The `output_dtype` has to be `Some` only when the `config` is `ArgMax` or `ArgMin`. " ); let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into)); let client = input.client.clone(); let output = init_reduce_output::(&input, dim, &dtypes).ok_or( cubek::reduce::ReduceError::InvalidAxis { axis: dim, rank: input.meta.num_dims(), }, )?; let result = match strategy { KernelReduceStrategy::Unspecified => cubek::reduce::reduce::( &client, input.binding(), output.clone().binding(), dim, ReduceStrategy { routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)), vectorization: VectorizationStrategy { parallel_output_vectorization: false, }, }, config, dtypes, ), KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::( &client, input.binding(), output.clone().binding(), dim, strategy, config, dtypes, ), #[cfg(feature = "autotune")] KernelReduceStrategy::Autotune => { autotune_reduce::(&client, input, output.clone(), dim, config, dtypes); Ok(()) } }; result.map(|_| output) } /// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input` /// or return `None` if `axis` is out-of-bound. pub fn init_reduce_output( input: &CubeTensor, dim: usize, dtypes: &ReduceDtypes, ) -> Option> { (dim < input.meta.num_dims()).then(|| { let mut shape_out = input.shape(); shape_out[dim] = 1; empty_device_contiguous_dtype( input.client.clone(), input.device.clone(), shape_out, dtypes.output.elem_type().into(), ) }) } /// Select a strategy to perform a reduction. #[derive(Clone, Debug)] pub enum KernelReduceStrategy { /// Use a best-effort strategy based on the hardware capacity. /// This differs from Autotune as it doesn't try and compare many strategies to select the best. Unspecified, /// Fix the exact strategy for the reduction. Specific(cubek::reduce::launch::ReduceStrategy), /// Use autotune to find the best strategy given the hardware and the inputs. #[cfg(feature = "autotune")] Autotune, } impl Default for KernelReduceStrategy { fn default() -> Self { #[cfg(feature = "autotune")] return Self::Autotune; #[cfg(not(feature = "autotune"))] return Self::Unspecified; } } ================================================ FILE: crates/burn-cubecl/src/kernel/reduce/mod.rs ================================================ mod base; #[cfg(feature = "autotune")] mod tune; pub use base::*; #[cfg(feature = "autotune")] pub use tune::*; ================================================ FILE: crates/burn-cubecl/src/kernel/reduce/tune.rs ================================================ #![allow(missing_docs)] use super::SumAutotuneKey; use crate::{CubeAutotuneKey, CubeRuntime, CubeTuneId, tensor::CubeTensor}; use cubecl::{ client::ComputeClient, tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}, }; use cubek::reduce::{ ReduceDtypes, ReduceStrategy, components::instructions::ReduceOperationConfig, launch::{RoutineStrategy, VectorizationStrategy, tune_key::ReduceAutotuneKey}, routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy}, }; /// Executes autotune on reduce operations. pub fn autotune_reduce( client: &ComputeClient, input: CubeTensor, output: CubeTensor, axis: usize, config: ReduceOperationConfig, dtypes: ReduceDtypes, ) { use reduce_ops::*; static TUNER: LocalTuner = local_tuner!("reduce-dim"); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 2; const PRIORITY_MIN: i8 = 1; const PRIORITY_SKIP: i8 = -1; let mut set = TunableSet::new(create_key::, reduce_input_gen::); let default_group = TuneGroup::::new("default_reduce", |_key| PRIORITY_MAX); let vectorized_parallel_group = TuneGroup::::new("vectorized_parallel_reduce", |key| { if key.axis_is_contiguous { PRIORITY_MAX } else { // We disable the tunable with the setting [vector_size.parallel_output_vectorization] // when the reduce isn't parallel, since it would duplicate tunables. PRIORITY_SKIP } }); enum ReduceProps { GreatWithLowReduceCount, GreatWithHighReduceCount, Balanced, } for (vectorization, vector_size_ident) in [ ( VectorizationStrategy { parallel_output_vectorization: true, }, "_vectorized_parallel_reduce", ), ( VectorizationStrategy { parallel_output_vectorization: false, }, "", ), ] { for (name, routine, props) in [ ( "unit", RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)), ReduceProps::GreatWithHighReduceCount, ), ( "plane", RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy { independent: true, })), ReduceProps::Balanced, ), ( "cube", RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy { use_planes: true, })), ReduceProps::GreatWithLowReduceCount, ), ] { let name = format!("{name}{vector_size_ident}"); let mut tunable = Tunable::new( name, move |(input, output, axis, config, dtypes): ( CubeTensor, CubeTensor, usize, ReduceOperationConfig, ReduceDtypes, )| { let strategy = ReduceStrategy { routine: routine.clone(), vectorization, }; cubek::reduce::reduce::( &output.client, input.binding(), output.clone().binding(), axis, strategy, config, dtypes, ) .map_err(|e| format!("{e}")) }, ); if vectorization.parallel_output_vectorization { tunable = tunable.group(&vectorized_parallel_group, |_| PRIORITY_MAX); } tunable = tunable.group(&default_group, move |key| match props { ReduceProps::GreatWithLowReduceCount => { if key.vector_count < 128 { PRIORITY_MAX } else { // When you have a high level of vector to reduce, it is normally // better to use another routine. PRIORITY_MIN } } ReduceProps::GreatWithHighReduceCount => { if key.vector_count > 64 { PRIORITY_MAX } else { // Bellow 64 it is normally better to use another routine PRIORITY_MIN } } ReduceProps::Balanced => PRIORITY_MAX, }); set = set.with(tunable); } } set }); TUNER.execute( &CubeTuneId::new(&input.client, &input.device), client, tunables, (input, output, axis, config, dtypes), ); } pub(crate) fn create_key( input: &CubeTensor, output: &CubeTensor, axis: &usize, _config: &ReduceOperationConfig, dtypes: &ReduceDtypes, ) -> ReduceAutotuneKey { let elem_input = input.dtype.into(); let elem_output = output.dtype.into(); let elem_acc = dtypes.accumulation.elem_type(); ReduceAutotuneKey::generate( elem_input, elem_output, elem_acc, input.meta.shape(), input.meta.strides()[*axis] == 1, *axis, ) } mod reduce_ops { #![allow(missing_docs)] use cubek::reduce::ReduceDtypes; use super::*; pub(crate) fn reduce_input_gen( _key: &ReduceAutotuneKey, input: &CubeTensor, output: &CubeTensor, dim: &usize, config: &ReduceOperationConfig, dtypes: &ReduceDtypes, ) -> ( CubeTensor, CubeTensor, usize, ReduceOperationConfig, ReduceDtypes, ) { (input.clone(), output.copy(), *dim, *config, *dtypes) } } /// Executes autotune on reduce operations. #[cfg(feature = "autotune")] pub fn autotune_sum( client: &ComputeClient, input: CubeTensor, ) -> CubeTensor { use sum_ops::*; static TUNER: LocalTuner = local_tuner!("autotune-sum"); let tunables = TUNER.init(|| { TunableSet::new(create_key_sum::, sum_input_gen::) .with(Tunable::new("sum_chained", sum_chained::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) .with(Tunable::new("sum_one_shot", sum_one_shot::)) }); TUNER.execute( &CubeTuneId::new(&input.client, &input.device), client, tunables, input, ) } pub(crate) fn create_key_sum(input: &CubeTensor) -> CubeAutotuneKey { CubeAutotuneKey::Sum(SumAutotuneKey::generate(input)) } impl SumAutotuneKey { #[allow(unused)] pub(crate) fn generate(input: &CubeTensor) -> Self { let dtype = input.dtype; let length = input.meta.num_elements(); Self::new(dtype, length) } } mod sum_ops { #![allow(missing_docs)] use crate::ops::numeric::zeros_client; use super::*; pub(crate) fn sum_input_gen( _key: &CubeAutotuneKey, input: &CubeTensor, ) -> CubeTensor { input.clone() } pub(crate) fn sum_one_shot( input: CubeTensor, ) -> Result, String> { let client = input.client.clone(); let device = input.device.clone(); let output = zeros_client(client.clone(), device, [1].into(), input.dtype); let dtype = input.dtype; cubek::reduce::shared_sum::( &output.client, input.binding(), output.clone().binding(), C, dtype.into(), ) .map_err(|e| e.to_string()) .map(|_| output) } #[cfg(feature = "autotune")] pub(crate) fn sum_chained( input: CubeTensor, ) -> Result, String> { crate::kernel::reduce::reduce::( input, None, crate::kernel::reduce::KernelReduceStrategy::Autotune, cubek::reduce::components::instructions::ReduceOperationConfig::Sum, ) .map_err(|e| e.to_string()) } } ================================================ FILE: crates/burn-cubecl/src/kernel/unary_float.rs ================================================ use crate::{ CubeRuntime, kernel::utils::address_type, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync { type Options: LaunchArg; type Unary: FloatUnaryOp; } #[cube] pub(crate) trait FloatUnaryOp: 'static + Send + Sync { type Options: LaunchArg; fn execute(input: Vector, options: &Self::Options) -> Vector; } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn unary_float( input: &LinearView>, output: &mut LinearView, ReadWrite>, options: &O::Options, #[define(F)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = O::Unary::::execute(input[ABSOLUTE_POS], options); } pub(crate) fn launch_unary_float(tensor: CubeTensor, args: Args) -> CubeTensor where // Magic fix for lifetime, the closure is supposed to capture everything required to create the // argument. for<'a> Args: FnOnce(&'a ()) -> RuntimeArg, R: CubeRuntime, O: FloatUnaryOpFamily, { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { if tensor.can_mut() && tensor.is_nonoverlapping() { unary_float::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), tensor.as_linear_view_alias(0), args(&()), dtype.into(), ); tensor } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), tensor.dtype, ); unary_float::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), output.clone().into_linear_view(), args(&()), dtype.into(), ); output } } } /// Use comptime enum to implement all unary operations that don't have any input argument in the /// kernel definition. pub(crate) mod unary_basic { use cubecl::num_traits::{One, Zero}; use super::*; pub(crate) fn launch(tensor: CubeTensor, args: Args) -> CubeTensor where R: CubeRuntime, for<'a> Args: FnOnce(&'a ()) -> BasicFloatUnaryKind, { launch_unary_float::(tensor, |input| { BasicFloatUnaryOptionsLaunch::new(args(input)) }) } #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub enum BasicFloatUnaryKind { Exp, Log, Log1p, Sqrt, Abs, Sign, ArcCos, ArcCosh, ArcSin, ArcSinh, ArcTan, ArcTanh, Cos, Cosh, Sin, Sinh, Tan, Tanh, Round, Floor, Ceil, Trunc, Erf, Recip, } #[derive(CubeLaunch, CubeType)] struct BasicFloatUnaryOptions { #[cube(comptime)] kind: BasicFloatUnaryKind, } struct BasicFloatUnary; #[cube] impl FloatUnaryOp for BasicFloatUnary { type Options = BasicFloatUnaryOptions; fn execute(input: Vector, options: &Self::Options) -> Vector { match comptime![options.kind] { BasicFloatUnaryKind::Exp => Vector::exp(input), BasicFloatUnaryKind::Log => Vector::ln(input), BasicFloatUnaryKind::Log1p => Vector::log1p(input), BasicFloatUnaryKind::Sqrt => Vector::sqrt(input), BasicFloatUnaryKind::Abs => Vector::abs(input), BasicFloatUnaryKind::Sign => { let zero = Vector::zero(); let one = Vector::one(); let minus_one = Vector::new(F::new(-1.0)); let is_positive = input.greater_than(zero); let is_negative = input.less_than(zero); let sign = select_many(is_negative, minus_one, zero); select_many(is_positive, one, sign) } BasicFloatUnaryKind::Cos => Vector::cos(input), BasicFloatUnaryKind::Sin => Vector::sin(input), BasicFloatUnaryKind::Tan => Vector::tan(input), BasicFloatUnaryKind::Cosh => Vector::cosh(input), BasicFloatUnaryKind::Sinh => Vector::sinh(input), BasicFloatUnaryKind::Tanh => Vector::tanh(input), BasicFloatUnaryKind::Round => Vector::round(input), BasicFloatUnaryKind::Floor => Vector::floor(input), BasicFloatUnaryKind::Ceil => Vector::ceil(input), BasicFloatUnaryKind::Trunc => Vector::trunc(input), BasicFloatUnaryKind::Erf => Vector::erf(input), BasicFloatUnaryKind::Recip => Vector::recip(input), BasicFloatUnaryKind::ArcCos => Vector::acos(input), BasicFloatUnaryKind::ArcCosh => Vector::acosh(input), BasicFloatUnaryKind::ArcSin => Vector::asin(input), BasicFloatUnaryKind::ArcSinh => Vector::asinh(input), BasicFloatUnaryKind::ArcTan => Vector::atan(input), BasicFloatUnaryKind::ArcTanh => Vector::atanh(input), } } } impl FloatUnaryOpFamily for BasicFloatUnary { type Options = BasicFloatUnaryOptions; type Unary = Self; } } ================================================ FILE: crates/burn-cubecl/src/kernel/unary_int.rs ================================================ use crate::{ CubeRuntime, kernel::utils::address_type, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { type Options: LaunchArg; type Unary: IntUnaryOp; } #[cube] pub(crate) trait IntUnaryOp: 'static + Send + Sync { type Options: LaunchArg; fn execute(input: Vector, options: &Self::Options) -> Vector; } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn unary_int( input: &LinearView>, output: &mut LinearView, ReadWrite>, options: &O::Options, #[define(I)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = O::Unary::::execute(input[ABSOLUTE_POS], options); } pub(crate) fn launch_unary_int(tensor: CubeTensor, args: Args) -> CubeTensor where for<'a> Args: FnOnce(&'a ()) -> RuntimeArg, R: CubeRuntime, O: IntUnaryOpFamily, { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { if tensor.can_mut() && tensor.is_nonoverlapping() { unary_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), tensor.as_linear_view_alias(0), args(&()), dtype.into(), ); tensor } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), tensor.dtype, ); unary_int::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), output.clone().into_linear_view(), args(&()), dtype.into(), ); output } } } pub(crate) mod unary_basic_int { use cubecl::num_traits::{One, Zero}; use super::*; pub(crate) fn launch(tensor: CubeTensor, args: Args) -> CubeTensor where R: CubeRuntime, for<'a> Args: FnOnce(&'a ()) -> BasicIntUnaryKind, { launch_unary_int::(tensor, |input| { BasicIntUnaryOptionsLaunch::new(args(input)) }) } #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub enum BasicIntUnaryKind { BitwiseNot, Sign, } #[derive(CubeLaunch, CubeType)] struct BasicIntUnaryOptions { #[cube(comptime)] kind: BasicIntUnaryKind, } struct BasicIntUnary; #[cube] impl IntUnaryOp for BasicIntUnary { type Options = BasicIntUnaryOptions; fn execute(input: Vector, options: &Self::Options) -> Vector { match comptime![options.kind] { BasicIntUnaryKind::BitwiseNot => !input, BasicIntUnaryKind::Sign => { let zero = Vector::zero(); let one = Vector::one(); let minus_one = Vector::new(I::new(-1)); let is_positive = input.greater_than(zero); let is_negative = input.less_than(zero); let sign = select_many(is_negative, minus_one, zero); select_many(is_positive, one, sign) } } } } impl IntUnaryOpFamily for BasicIntUnary { type Options = BasicIntUnaryOptions; type Unary = Self; } } ================================================ FILE: crates/burn-cubecl/src/kernel/unary_numeric.rs ================================================ use crate::{ CubeRuntime, kernel::utils::address_type, ops::{max_vector_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync { type Options: LaunchArg; type Unary: NumericUnaryOp; } #[cube] pub(crate) trait NumericUnaryOp: 'static + Send + Sync { type Options: LaunchArg; fn execute(input: Vector, options: &Self::Options) -> Vector; } #[cube(launch_unchecked, address_type = "dynamic")] pub(crate) fn unary_numeric( input: &LinearView>, output: &mut LinearView, ReadWrite>, options: &O::Options, #[define(T)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } output[ABSOLUTE_POS] = O::Unary::::execute(input[ABSOLUTE_POS], options); } pub(crate) fn launch_unary_numeric(tensor: CubeTensor, args: Args) -> CubeTensor where // Magic fix for lifetime, the closure is supposed to capture everything required to create the // argument. for<'a> Args: FnOnce(&'a ()) -> RuntimeArg, R: CubeRuntime, O: NumericUnaryOpFamily, { let vector_size = max_vector_size(&tensor); let client = tensor.client.clone(); let num_elems = tensor.meta.num_elements(); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); let dtype = tensor.dtype; unsafe { if tensor.can_mut() && tensor.is_nonoverlapping() { unary_numeric::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor), vector_size, tensor.clone().into_linear_view(), tensor.as_linear_view_alias(0), args(&()), dtype.into(), ); tensor } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), tensor.shape(), tensor.dtype, ); unary_numeric::launch_unchecked::( &client, cube_count, cube_dim, address_type!(tensor, output), vector_size, tensor.into_linear_view(), output.clone().into_linear_view(), args(&()), dtype.into(), ); output } } } ================================================ FILE: crates/burn-cubecl/src/kernel/utils.rs ================================================ use burn_backend::Shape; use cubecl::prelude::SequenceArg; use cubecl::{ ir::{UIntKind, VectorSize}, prelude::*, std::{ FastDivmod, FastDivmodInt, tensor::layout::linear::{LinearLayoutLaunch, LinearViewLayoutLaunch}, }, }; use crate::{CubeRuntime, tensor::CubeTensor}; pub fn shape_divmod(tensor: &CubeTensor) -> SequenceArg> { let mut arg = SequenceArg::new(); for dim in tensor.meta.shape().iter() { arg.push(*dim); } arg } pub fn linear_layout( tensor: &CubeTensor, vector_size: VectorSize, ) -> LinearLayoutLaunch { LinearLayoutLaunch::from_shape_strides( tensor.meta.shape().clone(), tensor.meta.strides().clone(), // Don't care about type size, only vector size Type::new(UIntKind::U32.into()).with_vector_size(vector_size), LinearViewLayoutLaunch::new(), ) } pub fn split_dim( mut tensor: CubeTensor, dim: usize, shape: &[usize], ) -> CubeTensor { let mut stride = tensor.meta.strides()[dim]; tensor.meta.remove(dim); for size in shape.iter().rev() { tensor.meta.insert(dim, *size, stride); stride *= size; } tensor } pub fn broadcast_shape(tensors: &[&CubeTensor]) -> Shape { let rank = tensors[0].meta.num_dims(); debug_assert!( tensors.iter().all(|it| it.meta.num_dims() == rank), "Broadcast tensors must have the same rank" ); let dims = (0..rank).map(|dim| { let max = tensors.iter().map(|it| it.meta.shape()[dim]).max(); let max = max.unwrap_or(1); debug_assert!( tensors .iter() .all(|it| it.meta.shape()[dim] == max || it.meta.shape()[dim] == 1), "Broadcast dims must be size 1" ); max }); Shape::from(dims) } pub fn broadcast_strides( reference: &CubeTensor, tensor: &CubeTensor, ) -> SequenceArg { if reference.meta.shape() != tensor.meta.shape() { tensor .meta .strides() .iter() .zip( tensor .meta .shape() .iter() .zip(reference.meta.shape().iter()), ) .map(|(stride, (shape, ref_shape))| if *shape == *ref_shape { *stride } else { 0 }) .collect() } else { tensor.meta.strides().iter().copied().collect() } } #[cube] pub(crate) fn decompose_linear( pos: I, shape: &Sequence>, ) -> (I, Sequence) { let rank = comptime![shape.len()]; let mut offs = pos; let mut out = Sequence::new(); #[unroll] for i in 0..rank { let dim = comptime![rank - i - 1]; let (rem, offs_local) = shape.index(dim).div_mod(offs); out.push(offs_local); offs = rem; } (offs, out.rev()) } pub(crate) trait RequiredAddrType { fn required_address_type(&self) -> AddressType; } impl RequiredAddrType for CubeTensor { fn required_address_type(&self) -> AddressType { self.required_address_type() } } impl RequiredAddrType for Option> { fn required_address_type(&self) -> AddressType { self.as_ref() .map(|it| it.required_address_type()) .unwrap_or_default() } } macro_rules! address_type { ($($tensor: tt),*) => { [$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*] .into_iter() .max() .unwrap_or_default() }; } pub(crate) use address_type; ================================================ FILE: crates/burn-cubecl/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! Burn JIT Backend #[macro_use] extern crate derive_new; extern crate alloc; /// Utilities for implementing JIT kernels pub mod ops; /// Kernel module pub mod kernel; /// Tensor module. pub mod tensor; /// Elements for JIT backend pub mod element; use cubecl::{CubeTask, Runtime}; pub use element::{BoolElement, CubeElement, FloatElement, IntElement}; mod backend; pub use backend::*; // Re-export cubecl. pub use cubecl; mod tune_key; pub use tune_key::CubeAutotuneKey; #[cfg(any(feature = "fusion", test))] /// Module for interacting with fusion pub mod fusion; #[cfg(feature = "template")] /// Module for compiling custom non-jit kernels pub mod template; /// Just-in-Time runtime extending the [cube runtime](Runtime). pub trait CubeRuntime: Runtime { /// The device that should also implement [burn_backend::backend::DeviceOps]. type CubeDevice: burn_backend::DeviceOps; /// The cube server with the [CubeAutotuneKey]. type CubeServer: cubecl::server::ComputeServer>>; } pub use cubecl::CubeTuneId; ================================================ FILE: crates/burn-cubecl/src/ops/activation.rs ================================================ use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement}; use burn_backend::ops::ActivationOps; impl ActivationOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { } ================================================ FILE: crates/burn-cubecl/src/ops/base.rs ================================================ use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor}; use burn_backend::{ DType, ExecutionError, QTensorPrimitive, Shape, TensorData, quantization::{QuantLevel, QuantStore, params_shape}, }; use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape}; use burn_std::{ Metadata, strides, tensor::{ReshapeAction, contiguous_strides, reshape_action}, }; use cubecl::{ir::VectorSize, server::CopyDescriptor}; use cubecl::{quant::scheme::BlockSize, tensor_vector_size_parallel}; pub(crate) fn from_data(data: TensorData, device: &R::Device) -> CubeTensor { let client = R::client(device); let alloc = client.create_tensor(data.bytes, data.shape.clone(), data.dtype.size()); let shape: Shape = (&data.shape).into(); CubeTensor::new( client, alloc.memory, Metadata::new(shape, alloc.strides), device.clone(), data.dtype, ) } pub(crate) async fn into_data( tensor: CubeTensor, ) -> Result { let tensor = kernel::into_contiguous_aligned(tensor); let elem_size = tensor.elem_size(); let shape = tensor.meta.shape().clone(); let strides = tensor.meta.strides().clone(); let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size); let bytes = tensor .client .read_one_tensor_async(binding) .await .map_err(|err| ExecutionError::WithContext { reason: format!("{err}"), })?; Ok(TensorData::from_bytes( bytes, tensor.meta.shape.clone(), tensor.dtype, )) } /// Read data from a `CubeTensor` synchronously #[allow(unused, reason = "useful for debugging kernels")] pub fn into_data_sync(tensor: CubeTensor) -> TensorData { burn_std::future::block_on(into_data(tensor)).unwrap() } #[cfg_attr( feature = "tracing", tracing::instrument(level = "trace", skip(tensor, device)) )] pub(crate) fn to_device( tensor: CubeTensor, device: &R::Device, ) -> CubeTensor { if &tensor.device == device { return tensor; } let tensor = kernel::into_contiguous_aligned(tensor); let client = R::client(device); tensor.to_client(client, device.clone()) } pub(crate) fn empty( shape: Shape, device: &R::Device, dtype: DType, ) -> CubeTensor { let client = R::client(device); let alloc = client.empty_tensor(shape.clone(), dtype.size()); CubeTensor::new( client, alloc.memory, Metadata::new(shape, alloc.strides), device.clone(), dtype, ) } pub(crate) fn swap_dims( mut tensor: CubeTensor, dim1: usize, dim2: usize, ) -> CubeTensor { tensor.meta.swap(dim1, dim2); if let DType::QFloat(scheme) = tensor.dtype && let QuantLevel::Block(block_size) = scheme.level { let rank = tensor.rank(); let qparams = tensor.qparams.as_mut().unwrap(); let mut block_size = block_size.to_dim_vec(rank); block_size.swap(dim1, dim2); // Truncate unit dims from the start let block_size = BlockSize::new_trim(block_size); if block_size.len() > BlockSize::MAX_DIMS { panic!("Swapped block size would exceed max dims"); } qparams.scales.metadata.swap(dim1, dim2); tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size))) } if let DType::QFloat(scheme) = &mut tensor.dtype && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) = &mut scheme.store { let rank = tensor.meta.num_dims(); if *packed_dim == rank - dim1 - 1 { *packed_dim = rank - dim2 - 1; } else if *packed_dim == rank - dim2 - 1 { *packed_dim = rank - dim1 - 1; } } tensor } /// Permute a tensor's dimensions pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> CubeTensor { tensor.meta.permute(axes).unwrap(); if let DType::QFloat(scheme) = tensor.dtype && let QuantLevel::Block(block_size) = scheme.level { let rank = tensor.rank(); let qparams = tensor.qparams.as_mut().unwrap(); let mut block_size = block_size.to_dim_vec(rank); block_size = axes.iter().map(|i| block_size[*i]).collect(); // Truncate unit dims from the start let block_size = block_size .into_iter() .skip_while(|it| *it == 1) .collect::>(); if block_size.len() > BlockSize::MAX_DIMS { panic!("Swapped block size would exceed max dims"); } qparams.scales.metadata.permute(axes).unwrap(); tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size))) } if let DType::QFloat(scheme) = &mut tensor.dtype && let QuantStore::PackedU32(packed_dim) = &mut scheme.store { let rank = tensor.meta.num_dims(); let new_pos = axes .iter() .position(|axis| *axis == rank - *packed_dim - 1) .unwrap_or(0); *packed_dim = rank - new_pos - 1; } tensor } /// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent pub fn permute_nchw_to_nhwc(tensor: CubeTensor) -> CubeTensor { let rank = tensor.meta.num_dims(); let c_dim = 1; let mut dims = vec![0]; dims.extend(2..rank); dims.push(c_dim); permute(tensor, &dims) } /// Permute a shape's dimensions from NCHW to NHWC, or the N-dimensional equivalent pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape { let rank = shape.num_dims(); let c_dim = 1; let mut dims = vec![0]; dims.extend(2..rank); dims.push(c_dim); shape.permuted(&dims).expect("Shape permute should succeed") } /// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent pub fn permute_nhwc_to_nchw(tensor: CubeTensor) -> CubeTensor { let rank = tensor.meta.num_dims(); let c_dim = rank - 1; let mut dims = vec![0]; dims.push(c_dim); dims.extend(1..c_dim); permute(tensor, &dims) } /// Permute a shape's dimensions from NHWC to NCHW, or the N-dimensional equivalent pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape { let rank = shape.num_dims(); let c_dim = rank - 1; let mut dims = vec![0]; dims.push(c_dim); dims.extend(1..c_dim); shape.permuted(&dims).expect("Shape permute should succeed") } pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) -> CubeTensor { let ndims_in = tensor.meta.shape().num_dims(); let ndims_out = target_shape.num_dims(); // Initialize new strides with zeros let mut new_strides = strides![0usize; ndims_out]; // Calculate the difference in dimensions let dim_diff = ndims_out.saturating_sub(ndims_in); // Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones let mut tensor_dim_iter = tensor.meta.shape().iter().rev(); for i in (0..ndims_out).rev() { if i >= dim_diff { if let Some(&tensor_dim) = tensor_dim_iter.next() { if tensor_dim == target_shape[i] || tensor_dim == 1 { // Copy stride for non-broadcast dimensions or set to 0 for broadcast ones new_strides[i] = if tensor_dim == target_shape[i] { tensor.meta.strides()[i - dim_diff] } else { 0 }; } else { // Error handling: Dimension mismatch for broadcasting panic!( "Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape" ); } } else { // If the input tensor has fewer dimensions, treat missing dimensions as 1 // and set stride to 0 (broadcasting) new_strides[i] = 0; } } else { // For extra dimensions in the target shape, set stride to 0 (broadcasting) new_strides[i] = 0; } } // Extra check to ensure block scales must be properly handled once they're added if tensor.qparams.is_some() { match tensor.scheme().level { QuantLevel::Tensor => {} QuantLevel::Block(_) => todo!(), } } CubeTensor { client: tensor.client.clone(), device: tensor.device.clone(), meta: Box::new(Metadata::new(target_shape, new_strides)), handle: tensor.handle.clone(), dtype: tensor.dtype, qparams: tensor.qparams.clone(), } } /// Reshape a jit tensor to a new shape pub fn reshape(mut tensor: CubeTensor, shape: Shape) -> CubeTensor { let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape); match analysis { ReshapeAction::UpdateStrides { strides } => { *tensor.meta = Metadata::new(shape, strides); return tensor; } ReshapeAction::NoChange => return tensor, ReshapeAction::Recompute => (), } let out = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), shape, tensor.dtype, ); cubecl::std::tensor::copy_into( &out.client, tensor.binding(), out.clone().binding(), out.dtype.into(), ); out } /// Reshape a jit tensor to a new shape pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> CubeTensor { let scheme = *tensor.scheme(); let shape_values = { let rank = shape.num_dims(); let mut shape = shape.clone(); shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants()); shape }; let shape_scales = params_shape(&shape, scheme.level); let (values, scales) = tensor.quantized_handles().unwrap(); let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values); let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales); match (analysis_values, analysis_scales) { ( ReshapeAction::UpdateStrides { strides }, ReshapeAction::UpdateStrides { strides: scales_strides, }, ) => { let qparams = tensor.qparams.as_mut().unwrap(); *tensor.meta = Metadata::new(shape, strides); qparams.scales.metadata = Metadata::new(shape_scales, scales_strides); } (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => { *tensor.meta = Metadata::new(shape, strides); } ( ReshapeAction::NoChange, ReshapeAction::UpdateStrides { strides: scales_strides, }, ) => { let qparams = tensor.qparams.as_mut().unwrap(); qparams.scales.metadata = Metadata::new(shape_scales, scales_strides); } (ReshapeAction::NoChange, ReshapeAction::NoChange) => {} _ => { tensor = kernel::into_contiguous(tensor); *tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values)); let qparams = tensor.qparams.as_mut().unwrap(); let strides = contiguous_strides(&shape_scales); qparams.scales.metadata = Metadata::new(shape_scales, strides); } } tensor } pub(crate) fn max_vector_size(tensor: &CubeTensor) -> VectorSize { tensor_vector_size_parallel( tensor.client.io_optimized_vector_sizes(tensor.dtype.size()), tensor.meta.shape(), tensor.meta.strides(), tensor.meta.num_dims() - 1, ) } pub(crate) fn max_vector_size_many( tensors: &[&CubeTensor], axis: usize, ) -> VectorSize { let vec = tensors .iter() .map(|tensor| { tensor_vector_size_parallel( tensor.client.io_optimized_vector_sizes(tensor.dtype.size()), tensor.meta.shape(), tensor.meta.strides(), axis, ) }) .min(); vec.unwrap_or(0) } /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// The new view will have the unfolded dimension replaced by two dimensions; /// one in the position of the original dimension, with size equal to the number of windows, /// and one appended to the right-most position, with size equal to `size`. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the dimension to unfold. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with the shape ``[pre=..., windows, post=..., size]``. pub fn unfold( tensor: CubeTensor, dim: usize, size: usize, step: usize, ) -> CubeTensor { let shape = calculate_unfold_shape(tensor.shape(), dim, size, step); let d_stride = tensor.meta.strides()[dim]; let mut strides = tensor.meta.strides.clone(); strides[dim] = step * d_stride; strides.push(d_stride); CubeTensor { meta: Box::new(Metadata::new(shape, strides)), client: tensor.client.clone(), handle: tensor.handle.clone(), device: tensor.device.clone(), dtype: tensor.dtype, qparams: tensor.qparams.clone(), } } ================================================ FILE: crates/burn-cubecl/src/ops/bool_tensor.rs ================================================ use crate::{ CubeBackend, CubeRuntime, FloatElement, IntElement, element::{BoolElement, bool_dtype}, kernel::{self, AndOp, OrOp}, }; use burn_backend::{ ExecutionError, Slice, ops::BoolTensorOps, tensor::{BoolTensor, Device, FloatTensor, IntTensor}, }; use burn_backend::{Scalar, Shape, TensorData}; use burn_std::{BoolStore, DType}; use cubecl::prelude::InputScalar; use std::ops::Range; use super::{expand, numeric, permute, unfold}; impl BoolTensorOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { super::empty(shape, device, bool_dtype::()) } fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor { numeric::zeros(device.clone(), shape, bool_dtype::()) } fn bool_ones(shape: Shape, device: &Device) -> BoolTensor { numeric::ones(device.clone(), shape, bool_dtype::()) } async fn bool_into_data(tensor: BoolTensor) -> Result { super::into_data(tensor).await } fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let bool_dtype = bool_dtype::(); // TODO: remove once backends no longer rely on generics for default elem types let data = match (data.dtype, bool_dtype) { (DType::U8, DType::Bool(BoolStore::U8)) | (DType::U32, DType::Bool(BoolStore::U32)) => { // No-op, but change dtype to bool w/ storage type data.convert_dtype(bool_dtype) } (DType::U8, DType::U8) | (DType::U32, DType::U32) => data, other => unimplemented!("Unsupported dtype for `bool_from_data` {other:?}"), }; super::from_data(data, device) } fn bool_into_int(tensor: BoolTensor) -> IntTensor { kernel::bool_cast::(tensor) } fn bool_device(tensor: &BoolTensor) -> Device { tensor.device.clone() } fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { super::to_device(tensor, device) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { super::reshape(tensor, shape) } fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor { // Check if all steps are 1 let all_steps_one = slices.iter().all(|info| info.step == 1); if all_steps_one { // Use optimized slice for step=1 let simple_ranges: Vec> = slices .iter() .enumerate() .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) } else { // Use slice with steps kernel kernel::slice_with_steps(tensor, slices) } } fn bool_slice_assign( tensor: BoolTensor, ranges: &[Slice], value: BoolTensor, ) -> BoolTensor { kernel::slice_assign(tensor, ranges, value) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { kernel::equal(lhs, rhs, bool_dtype::()) } fn bool_not(tensor: BoolTensor) -> BoolTensor { kernel::equal_elem( tensor, InputScalar::new(BT::false_val(), bool_dtype::()), bool_dtype::(), ) } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { kernel::launch_binop::(lhs, rhs) } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { kernel::launch_binop::(lhs, rhs) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { kernel::bool_cast::(tensor) } fn bool_swap_dims(mut tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { tensor.meta.swap(dim1, dim2); tensor } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { kernel::repeat_dim(tensor, dim, times) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { permute(tensor, axes) } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { expand(tensor, shape) } fn bool_select( tensor: BoolTensor, dim: usize, indices: IntTensor, ) -> BoolTensor { kernel::select(tensor, dim, indices) } fn bool_select_or( tensor: BoolTensor, dim: usize, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { kernel::select_assign(tensor, dim, indices, value, true) } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { kernel::flip(tensor, axes, bool_dtype::()) } fn bool_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { unfold(tensor, dim, size, step) } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { kernel::mask_where_auto(tensor, mask, value, bool_dtype::()) } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { let dtype = tensor.dtype; kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype) } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { kernel::gather(dim, tensor, indices) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { kernel::scatter(dim, tensor, indices, value, true) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype) } } ================================================ FILE: crates/burn-cubecl/src/ops/int_tensor.rs ================================================ use self::unary_basic_int::BasicIntUnaryKind; use super::{expand, numeric, permute, unfold}; use crate::element::bool_dtype; use crate::kernel::{ BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int, }; use crate::{ CubeBackend, CubeRuntime, FloatElement, IntElement, kernel::{ self, matmul::{MatmulStrategy, matmul}, }, }; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, }; use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps}; use burn_backend::{Distribution, ElementConversion, Shape, TensorData}; use burn_backend::{ExecutionError, Scalar}; use cubecl::frontend::Numeric; use cubecl::prelude::*; use cubek::reduce::components::instructions::ReduceOperationConfig; use std::ops::Range; impl IntTensorOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { fn int_empty(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let dtype = dtype.into(); super::empty(shape, device, dtype) } async fn int_into_data(tensor: IntTensor) -> Result { super::into_data(tensor).await } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { match data.dtype { DType::I64 | DType::I32 | DType::I16 | DType::I8 | DType::U64 | DType::U32 | DType::U16 | DType::U8 => super::from_data(data, device), _ => unimplemented!("Unsupported dtype for `int_from_data`"), } } fn int_device(tensor: &IntTensor) -> Device { tensor.device.clone() } fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { super::to_device(tensor, device) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { super::reshape(tensor, shape) } fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor { // Check if all steps are 1 let all_steps_one = slices.iter().all(|info| info.step == 1); if all_steps_one { // Use optimized slice for step=1 let simple_ranges: Vec> = slices .iter() .enumerate() .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) } else { // Use slice with steps kernel kernel::slice_with_steps(tensor, slices) } } fn int_slice_assign( tensor: IntTensor, ranges: &[Slice], value: IntTensor, ) -> IntTensor { kernel::slice_assign(tensor, ranges, value) } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let dtype = lhs.dtype; matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap() } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> IntTensor { kernel::mask_where_auto(tensor, mask, value, bool_dtype::()) } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> IntTensor { let dtype = tensor.dtype; kernel::mask_fill_auto( tensor, mask, InputScalar::new(value, dtype), bool_dtype::(), ) } fn int_gather( dim: usize, tensor: IntTensor, indices: IntTensor, ) -> IntTensor { kernel::gather(dim, tensor, indices) } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { kernel::scatter(dim, tensor, indices, value, false) } fn int_select( tensor: IntTensor, dim: usize, indices: IntTensor, ) -> IntTensor { kernel::select(tensor, dim, indices) } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { kernel::select_assign(tensor, dim, indices, value, false) } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { kernel::equal(lhs, rhs, bool_dtype::()) } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { kernel::greater(lhs, rhs, bool_dtype::()) } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { kernel::greater_equal(lhs, rhs, bool_dtype::()) } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { kernel::lower(lhs, rhs, bool_dtype::()) } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { kernel::lower_equal(lhs, rhs, bool_dtype::()) } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::add(lhs, rhs) } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::add_scalar(lhs, InputScalar::new(rhs, dtype)) } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::sub(lhs, rhs) } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype)) } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::mul(lhs, rhs) } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype)) } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::div(lhs, rhs) } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::div_scalar(lhs, InputScalar::new(rhs, dtype)) } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::remainder(lhs, rhs) } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype)) } fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let dtype = dtype.into(); numeric::zeros(device.clone(), shape, dtype) } fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let dtype = dtype.into(); numeric::ones(device.clone(), shape, dtype) } fn int_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: IntDType, ) -> IntTensor { let dtype: DType = dtype.into(); let client = R::client(device); numeric::full_device_dtype( client, shape, device.clone(), InputScalar::new(fill_value, dtype), dtype, ) } fn int_sum(tensor: IntTensor) -> IntTensor { reduce::sum_fallback(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Sum, ) .unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { reduce::reduce( tensor, None, Default::default(), ReduceOperationConfig::Prod, ) .unwrap() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Prod, ) .unwrap() } fn int_max(tensor: IntTensor) -> IntTensor { reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap() } fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Max, ) .unwrap() } fn int_max_abs(tensor: IntTensor) -> IntTensor { reduce::reduce( tensor, None, Default::default(), ReduceOperationConfig::MaxAbs, ) .unwrap() } fn int_max_abs_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::MaxAbs, ) .unwrap() } fn int_min(tensor: IntTensor) -> IntTensor { reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap() } fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Min, ) .unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Mean, ) .unwrap() } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { numeric::cumsum(tensor, dim) } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { numeric::cumprod(tensor, dim) } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { numeric::cummin(tensor, dim) } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { numeric::cummax(tensor, dim) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { let dtype = tensor.dtype; reduce::reduce_dim( tensor, Some(dtype), dim, Default::default(), ReduceOperationConfig::ArgMax, ) .unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { let dtype = tensor.dtype; reduce::reduce_dim( tensor, Some(dtype), dim, Default::default(), ReduceOperationConfig::ArgMin, ) .unwrap() } fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { let dtype = tensor.dtype; kernel::clamp( tensor, InputScalar::new(min, dtype), InputScalar::new(max, dtype), ) } fn int_abs(tensor: IntTensor) -> IntTensor { struct Abs; #[cube] impl NumericUnaryOp for Abs { type Options = (); fn execute(input: Vector, _options: &Self::Options) -> Vector { Vector::abs(input) } } impl NumericUnaryOpFamily for Abs { type Options = (); type Unary = Self; } launch_unary_numeric::(tensor, |_| ()) } fn int_sign(tensor: IntTensor) -> IntTensor { unary_basic_int::launch::(tensor, |_| BasicIntUnaryKind::Sign) } fn int_into_float(tensor: IntTensor) -> FloatTensor { kernel::cast(tensor, F::dtype()) } fn int_swap_dims(mut tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { tensor.meta.swap(dim1, dim2); tensor } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { kernel::repeat_dim(tensor, dim, times) } fn int_random( shape: Shape, distribution: Distribution, device: &Device, ) -> IntTensor { let dtype = IntElem::::dtype(); match distribution { Distribution::Default => random_uniform(shape, device, 0., 255., dtype), Distribution::Uniform(low, high) => { random_uniform(shape, device, low.elem(), high.elem(), dtype) } Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype), Distribution::Normal(mean, std) => { random_normal(shape, device, mean.elem(), std.elem(), dtype) } } } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { permute(tensor, axes) } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { expand(tensor, shape) } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { kernel::flip(tensor, axes, bool_dtype::()) } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::bitwise_and(lhs, rhs) } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype)) } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::bitwise_or(lhs, rhs) } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype)) } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { numeric::bitwise_xor(lhs, rhs) } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype)) } fn bitwise_not(tensor: IntTensor) -> IntTensor { unary_basic_int::launch::(tensor, |_| BasicIntUnaryKind::BitwiseNot) } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { launch_binop_int::(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; launch_scalar_binop_int::(lhs, InputScalar::new(rhs, dtype)) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { launch_binop_int::(lhs, rhs) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let dtype = lhs.dtype; launch_scalar_binop_int::(lhs, InputScalar::new(rhs, dtype)) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { kernel::cast(tensor, dtype.into()) } fn int_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { unfold(tensor, dim, size, step) } } ================================================ FILE: crates/burn-cubecl/src/ops/mod.rs ================================================ mod activation; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; pub(crate) mod base; pub use base::*; pub use qtensor::*; /// Numeric utility functions for jit backends pub mod numeric; ================================================ FILE: crates/burn-cubecl/src/ops/module.rs ================================================ use crate::{ CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement, kernel::{self, conv::ConvTranspose2dStrategy}, }; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor}; use burn_backend::{ TensorMetadata, ops::{ AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }, }; impl ModuleOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { kernel::conv::conv_forward::(x, weight, bias, options, Default::default()).unwrap() } fn conv1d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { kernel::conv::conv_data_backward( output_grad, weight, x.shape(), options, Default::default(), ) .unwrap() } fn conv1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { kernel::conv::conv_weight_backward::( x, output_grad, weight.shape(), options, Default::default(), ) .unwrap() } fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { kernel::conv::conv_forward::(x, weight, bias, options, Default::default()).unwrap() } fn conv2d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { kernel::conv::conv_data_backward( output_grad, weight, x.shape(), options, Default::default(), ) .unwrap() } fn conv2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { kernel::conv::conv_weight_backward::( x, output_grad, weight.shape(), options, Default::default(), ) .unwrap() } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap() } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward( x, offset, weight, mask, bias, output_grad, options, ) .unwrap(); DeformConv2dBackward::new(x, o, w, m, b) } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { kernel::conv::conv_forward::(x, weight, bias, options, Default::default()).unwrap() } fn conv3d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { kernel::conv::conv_data_backward( output_grad, weight, x.shape(), options, Default::default(), ) .unwrap() } fn conv3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { kernel::conv::conv_weight_backward::( x, output_grad, weight.shape(), options, Default::default(), ) .unwrap() } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default()) .unwrap() } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { kernel::conv::conv_transpose3d(x, weight, bias, options).expect("Kernel to never fail") } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { kernel::pool::avg_pool2d( x, kernel_size, stride, padding, count_include_pad, ceil_mode, ) } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { kernel::pool::avg_pool2d_backward( x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode, ) } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { let (output, indices) = kernel::pool::max_pool2d_with_indices( x, kernel_size, stride, padding, dilation, ceil_mode, I::dtype(), ); MaxPool2dWithIndices::new(output, indices) } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward { MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward( x, output_grad, indices, kernel_size, stride, padding, dilation, ceil_mode, )) } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { kernel::pool::adaptive_avg_pool2d(x, output_size) } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { kernel::pool::adaptive_avg_pool2d_backward(x, grad) } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { kernel::interpolate::interpolate(x, output_size, options) } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { kernel::interpolate::interpolate_backward(x, grad, output_size, options) } fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> FloatTensor { // Fall back to naive attention for features the flash kernel doesn't support. if attn_bias.is_some() || options.softcap.is_some() || options.scale.is_some() { return burn_backend::ops::attention::attention_fallback::( query, key, value, mask, attn_bias, options, ); } kernel::attention::attention( query, key, value, mask, attn_bias, options, Default::default(), None, ) .expect("Kernel to never fail") } } ================================================ FILE: crates/burn-cubecl/src/ops/numeric.rs ================================================ use crate::{ CubeRuntime, kernel::utils::{address_type, shape_divmod}, }; use crate::{element::CubeElement, tensor::CubeTensor}; use crate::{ kernel::{ AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, }, ops::max_vector_size, }; use burn_backend::{DType, Shape, TensorMetadata}; use burn_std::Metadata; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubecl::{client::ComputeClient, server::MemoryLayout}; use cubecl::{ server::MemoryLayoutDescriptor, std::{FastDivmod, tensor::layout::linear::LinearView}, }; /// Creates a tensor filled with `value` pub fn full( shape: Shape, device: &R::Device, value: E, ) -> CubeTensor { let client = R::client(device); full_client::(client, shape, device.clone(), value) } /// Creates a tensor filled with `value` pub fn full_client( client: ComputeClient, shape: Shape, device: R::Device, value: E, ) -> CubeTensor { let dtype = E::dtype(); full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype) } /// Creates a tensor filled with `value` pub fn full_device_dtype( client: ComputeClient, shape: Shape, device: R::Device, value: InputScalar, dtype: DType, ) -> CubeTensor { let empty = empty_device_dtype(client, device, shape, dtype); #[cube(launch_unchecked, address_type = "dynamic")] pub fn full_kernel( tensor: &mut LinearView, ReadWrite>, value: InputScalar, #[define(C)] _dtype: StorageType, ) { if !tensor.is_in_bounds(ABSOLUTE_POS) { terminate!(); } tensor[ABSOLUTE_POS] = Vector::new(value.get::()); } let num_elems = empty.meta.num_elements(); let vector_size = max_vector_size(&empty); let working_units = num_elems / vector_size as usize; let cube_dim = CubeDim::new(&empty.client, working_units); let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim); unsafe { full_kernel::launch_unchecked( &empty.client, cube_count, cube_dim, address_type!(empty), vector_size, empty.clone().into_linear_view(), value, empty.dtype.into(), ); } empty } /// Creates a tensor filled with zeros pub fn zeros(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor { let client = R::client(&device); full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype) } /// Creates a tensor filled with ones pub fn ones(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor { let client = R::client(&device); full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype) } /// Creates a tensor filled with zeros pub fn zeros_client( client: ComputeClient, device: R::Device, shape: Shape, dtype: DType, ) -> CubeTensor { full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype) } /// Creates a tensor filled with ones pub fn ones_client( client: ComputeClient, device: R::Device, shape: Shape, dtype: DType, ) -> CubeTensor { full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype) } /// Create a tensor with uninitialized memory pub fn empty_device( client: ComputeClient, device: R::Device, shape: Shape, ) -> CubeTensor { let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), size_of::()); CubeTensor::new( client, memory, Metadata::new(shape, strides), device, E::dtype(), ) } /// Create a tensor with uninitialized memory pub fn empty_device_dtype( client: ComputeClient, device: R::Device, shape: Shape, dtype: DType, ) -> CubeTensor { let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), dtype.size()); CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype) } /// Create a contiguous tensor with uninitialized memory pub fn empty_device_contiguous_dtype( client: ComputeClient, device: R::Device, shape: Shape, dtype: DType, ) -> CubeTensor { let descriptor = MemoryLayoutDescriptor::contiguous(shape.clone(), dtype.size()); let MemoryLayout { memory, strides } = client.empty_tensors(vec![descriptor]).remove(0); CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype) } /// Add two tensors pub fn add(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Add a tensor and a scalar pub fn add_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Subtract two tensors pub fn sub(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Subtract a tensor and a scalar pub fn sub_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Multiply two tensors pub fn mul(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Multiply a tensor and a scalar pub fn mul_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Divide two tensors pub fn div(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Divide a tensor by a scalar pub fn div_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Calculate remainder of two tensors pub fn remainder(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Calculate the remainder of a tensor with a scalar pub fn remainder_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop::(lhs, rhs) } /// Calculate the power of two tensors pub fn pow(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop::(lhs, rhs) } /// Bitwise and two tensors pub fn bitwise_and(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise and with a scalar pub fn bitwise_and_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } /// Bitwise or two tensors pub fn bitwise_or(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise or with a scalar pub fn bitwise_or_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } /// Bitwise xor two tensors pub fn bitwise_xor(lhs: CubeTensor, rhs: CubeTensor) -> CubeTensor { launch_binop_int::(lhs, rhs) } /// Bitwise xor with a scalar pub fn bitwise_xor_scalar(lhs: CubeTensor, rhs: InputScalar) -> CubeTensor { launch_scalar_binop_int::(lhs, rhs) } /// Operation family trait for cumulative operations pub(crate) trait CumulativeOpFamily: Send + Sync + 'static { type CumulativeOp: CumulativeOp; } /// Trait for cumulative operations #[cube] pub(crate) trait CumulativeOp: 'static + Send + Sync { /// Execute a cumulative operation fn execute(lhs: C, rhs: C) -> C; /// Get the initial value for the accumulator fn init_value(first_element: C) -> C; } // Operation types struct SumOp; struct ProdOp; struct MaxOp; struct MinOp; // Implement CumulativeOpFamily for each operation impl CumulativeOpFamily for SumOp { type CumulativeOp = Self; } impl CumulativeOpFamily for ProdOp { type CumulativeOp = Self; } impl CumulativeOpFamily for MaxOp { type CumulativeOp = Self; } impl CumulativeOpFamily for MinOp { type CumulativeOp = Self; } // Implement CumulativeOp for each operation type #[cube] impl CumulativeOp for SumOp { fn execute(lhs: N, rhs: N) -> N { lhs + rhs } fn init_value(_first_element: N) -> N { N::zero() } } #[cube] impl CumulativeOp for ProdOp { fn execute(lhs: N, rhs: N) -> N { lhs * rhs } fn init_value(_first_element: N) -> N { N::from_int(1) } } #[cube] impl CumulativeOp for MaxOp { fn execute(lhs: N, rhs: N) -> N { max(lhs, rhs) } fn init_value(first_element: N) -> N { first_element } } #[cube] impl CumulativeOp for MinOp { fn execute(lhs: N, rhs: N) -> N { min(lhs, rhs) } fn init_value(first_element: N) -> N { first_element } } /// Generic cumulative operation kernel /// /// # Limitations /// /// This is a **naive sequential implementation** along the cumulative dimension: /// - Each output element sequentially reads all previous elements along the dimension /// - Computational complexity: O(n^2) memory reads where n is the size of the cumulative dimension /// - **Performance:** Suitable for small tensors or small dimensions. For large tensors, /// performance will degrade significantly compared to an optimized parallel scan algorithm. /// /// # TODO /// /// Implement an efficient GPU-optimized parallel scan algorithm. #[cube(launch_unchecked, address_type = "dynamic")] fn cumulative_kernel( input: &Tensor, output: &mut LinearView, shape: Sequence>, #[comptime] dim: usize, #[define(C)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { terminate!(); } let rank = comptime![shape.len()]; let dim_stride = input.stride(dim); let mut remainder = ABSOLUTE_POS; let mut offset = 0; let mut dim_idx = 0; #[unroll] for i in 0..shape.len() { let i = comptime![rank - i - 1]; let (rem, local_idx) = shape.index(i).div_mod(remainder); remainder = rem; if i == dim { dim_idx = local_idx; } else { offset += local_idx * input.stride(i); } } // Read first element let first_read_idx = offset + dim_idx * dim_stride; let first_elem = input[first_read_idx]; // Initialize accumulator let mut result = O::CumulativeOp::::init_value(first_elem); // Accumulate values for i in 0..=dim_idx { let read_idx = offset + i * dim_stride; result = O::CumulativeOp::::execute(result, input[read_idx]); } output[ABSOLUTE_POS] = result; } /// Compute the cumulative sum along a dimension pub fn cumsum(input: CubeTensor, dim: usize) -> CubeTensor { cumulative_op::(input, dim) } /// Compute the cumulative product along a dimension pub fn cumprod(input: CubeTensor, dim: usize) -> CubeTensor { cumulative_op::(input, dim) } /// Compute the cumulative minimum along a dimension pub fn cummin(input: CubeTensor, dim: usize) -> CubeTensor { cumulative_op::(input, dim) } /// Compute the cumulative maximum along a dimension pub fn cummax(input: CubeTensor, dim: usize) -> CubeTensor { cumulative_op::(input, dim) } /// Generic cumulative operation function fn cumulative_op( input: CubeTensor, dim: usize, ) -> CubeTensor { let client = input.client.clone(); let device = input.device.clone(); let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype); let num_elems = output.meta.num_elements(); let working_units = num_elems; let cube_dim = CubeDim::new(&client, working_units); let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim); let shape = shape_divmod(&input); unsafe { cumulative_kernel::launch_unchecked::( &client, cube_count, cube_dim, address_type!(input, output), input.into_tensor_arg(), output.clone().into_linear_view(), shape, dim, output.dtype.into(), ); } output } ================================================ FILE: crates/burn-cubecl/src/ops/qtensor.rs ================================================ use burn_backend::{ Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorMetadata, TensorPrimitive, ops::QTensorOps, quantization::{ QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue, QuantizationParametersPrimitive, params_shape, }, tensor::{Device, FloatElem, FloatTensor, IntTensor, QuantizedTensor}, }; use burn_std::Metadata; use cubecl::server::{MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutStrategy}; use cubecl::{e2m1x2, quant::scheme::QuantStore}; use crate::{ CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement, kernel::{self, matmul::MatmulStrategy}, tensor::{CubeTensor, QParams}, }; use super::{into_data, permute, swap_dims}; /// Create a quantized tensor with packed values (u32). fn new_qtensor_optimized( data: Bytes, shape: impl Into, scheme: QuantScheme, device: &R::Device, ) -> CubeTensor { new_qtensor(data, shape, scheme, device, MemoryLayoutStrategy::Optimized) } /// Create a quantized tensor with packed values (u32). fn new_qtensor( data: Bytes, shape: impl Into, scheme: QuantScheme, device: &R::Device, kind: MemoryLayoutStrategy, ) -> CubeTensor { new_quantized(shape, scheme, device, Some(data), kind) } /// Create an empty quantized tensor. pub fn empty_qtensor_optimized( shape: impl Into, scheme: QuantScheme, device: &R::Device, ) -> CubeTensor { empty_qtensor(shape, scheme, device, MemoryLayoutStrategy::Optimized) } /// Create an empty quantized tensor. pub fn empty_qtensor( shape: impl Into, scheme: QuantScheme, device: &R::Device, kind: MemoryLayoutStrategy, ) -> CubeTensor { new_quantized(shape, scheme, device, None, kind) } fn new_quantized( shape: impl Into, scheme: QuantScheme, device: &R::Device, data: Option, alloc_kind: MemoryLayoutStrategy, ) -> CubeTensor { let client = R::client(device); let shape: Shape = shape.into(); let mut shape_value: Shape = shape.clone(); let rank = shape.rank(); let shape_last = shape[rank - 1]; let num_quants = scheme.num_quants(); let data_size = match scheme.store { QuantStore::PackedU32(_) => { if !shape_last.is_multiple_of(num_quants) { panic!("Can't store in u32") } shape_value[rank - 1] = shape_last.div_ceil(num_quants); size_of::() } QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => { size_of::() } QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 => { panic!("Can't store native sub-byte values") } }, QuantStore::PackedNative(_) => match scheme.value { QuantValue::E2M1 => size_of::(), other => panic!("{other:?} doesn't support native packing"), }, }; let scales_dtype = match scheme.param { QuantParam::F32 => DType::F32, QuantParam::F16 => DType::F16, QuantParam::BF16 => DType::BF16, // Represented by U8 and reinterpreted in the kernel QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8, }; let scales_shape = params_shape(&shape, scheme.level); let data_desc = MemoryLayoutDescriptor::new(alloc_kind, shape_value.clone(), data_size); let scales_desc = MemoryLayoutDescriptor::new(alloc_kind, scales_shape.clone(), scales_dtype.size()); let mut tensors = match data { Some(data) => { let num_bytes = shape_value.num_elements() * data_size; match data.split(num_bytes) { Ok((bytes_data, bytes_scales)) => client .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]), Err((data, _)) => client.create_tensors_from_slices(vec![ (data_desc, &data[..num_bytes]), (scales_desc, &data[num_bytes..]), ]), } } None => client.empty_tensors(vec![data_desc, scales_desc]), }; let MemoryLayout { memory: scales_handle, strides: scales_strides, } = tensors.remove(1); let MemoryLayout { memory, strides } = tensors.remove(0); let scales = QParamTensor { offset_start: scales_handle.offset_start.unwrap_or(0) as usize, offset_end: scales_handle.offset_end.unwrap_or(0) as usize, metadata: Metadata::new(scales_shape, scales_strides), dtype: scales_dtype, }; let qparams = QParams { scales }; CubeTensor::new_quantized( client, memory, shape, device.clone(), strides, DType::QFloat(scheme), qparams, ) } impl QTensorOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { DType::QFloat(scheme) => match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1, .. } => { // TensorData quantized representation is the same, with multiple quantized values // packed into u32 and quantization parameters appended to the bytes new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device) } }, _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", data.dtype ), } } // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor) fn quantize( tensor: FloatTensor, scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { kernel::quantization::quantize(tensor, scheme, qparams.scales) } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { kernel::quantization::dequantize(tensor, FloatElem::::dtype()) } fn q_device(tensor: &QuantizedTensor) -> Device { tensor.device.clone() } fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { super::to_device(tensor, device) } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { super::q_reshape(tensor, shape) } async fn q_into_data(tensor: QuantizedTensor) -> Result { if tensor.qparams.is_none() { return into_data(tensor).await; } let (shape, dtype) = (tensor.shape(), tensor.dtype); let (values, params) = tensor.quantized_handles().unwrap(); let mut data_values = into_data(values).await?; let data_params = into_data(params).await?; data_values.bytes.extend_from_byte_slice(&data_params.bytes); Ok(TensorData { bytes: data_values.bytes, shape, dtype, }) } fn q_swap_dims( tensor: QuantizedTensor, dim1: usize, dim2: usize, ) -> QuantizedTensor { swap_dims(tensor, dim1, dim2) } fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { permute(tensor, axes) } fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_gather( _dim: usize, _tensor: QuantizedTensor, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_select( _tensor: QuantizedTensor, _dim: usize, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_slice(_tensor: QuantizedTensor, _slices: &[Slice]) -> QuantizedTensor { unimplemented!() } fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { let (propagation, scheme) = match (&lhs, &rhs) { (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()), (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()), _ => unreachable!(), }; // Inherit precision for mixed inputs, default to `FloatElem` for fully quantized. let out_dtype = match (&lhs, &rhs) { (TensorPrimitive::Float(lhs), _) => lhs.dtype, (_, TensorPrimitive::Float(rhs)) => rhs.dtype, _ => F::dtype(), }; let (_lhs_dtype, lhs) = match lhs { TensorPrimitive::Float(lhs) => (lhs.dtype, lhs), TensorPrimitive::QFloat(lhs) => (out_dtype, lhs), }; let (_rhs_dtype, rhs) = match rhs { TensorPrimitive::Float(rhs) => (rhs.dtype, rhs), TensorPrimitive::QFloat(rhs) => (out_dtype, rhs), }; let out = kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap(); match propagation { QuantPropagation::Propagate => { TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme)) } QuantPropagation::Inhibit => TensorPrimitive::Float(out), } } } ================================================ FILE: crates/burn-cubecl/src/ops/tensor.rs ================================================ use super::{expand, numeric, permute, unfold}; use crate::CubeBackend; use crate::element::bool_dtype; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::unary_basic::BasicFloatUnaryKind; use crate::kernel::{ self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic, }; use crate::{CubeRuntime, FloatElement, IntElement}; use crate::{ element::BoolElement, kernel::matmul::{MatmulStrategy, matmul}, }; use burn_backend::ops::GridSampleOptions; use burn_backend::tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use burn_backend::{Backend, ExecutionError, Scalar}; use burn_backend::{DType, ElementConversion, FloatDType, Slice}; use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps}; use cubecl::prelude::*; use cubek::reduce::components::instructions::ReduceOperationConfig; use std::ops::Range; impl FloatTensorOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(data), fields(?data.shape, ?data.dtype) ))] fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { match data.dtype { DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device), _ => unimplemented!("Unsupported dtype for `float_from_data`"), } } fn float_random( shape: Shape, distribution: Distribution, device: &Device, ) -> FloatTensor { let dtype = FloatElem::::dtype(); match distribution { Distribution::Default => random_uniform(shape, device, 0., 1., dtype), Distribution::Uniform(low, high) => { random_uniform(shape, device, low.elem(), high.elem(), dtype) } Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype), Distribution::Normal(mean, std) => { random_normal(shape, device, mean.elem(), std.elem(), dtype) } } } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype) ))] async fn float_into_data(tensor: FloatTensor) -> Result { super::into_data(tensor).await } fn float_device(tensor: &FloatTensor) -> Device { tensor.device.clone() } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype) ))] fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { super::to_device(tensor, device) } fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let dtype = dtype.into(); super::empty(shape, device, dtype) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::add(lhs, rhs) } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let dtype = lhs.dtype; numeric::add_scalar(lhs, InputScalar::new(rhs, dtype)) } fn float_zeros(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let dtype = dtype.into(); numeric::zeros(device.clone(), shape, dtype) } fn float_full( shape: Shape, fill_value: Scalar, device: &R::Device, dtype: FloatDType, ) -> FloatTensor { let dtype: DType = dtype.into(); let client = R::client(device); numeric::full_device_dtype( client, shape, device.clone(), InputScalar::new(fill_value, dtype), dtype, ) } fn float_ones(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let dtype = dtype.into(); numeric::ones(device.clone(), shape, dtype) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::sub(lhs, rhs) } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let dtype = lhs.dtype; numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype)) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::mul(lhs, rhs) } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let dtype = lhs.dtype; numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype)) } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::div(lhs, rhs) } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let dtype = lhs.dtype; numeric::div_scalar(lhs, InputScalar::new(rhs, dtype)) } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::remainder(lhs, rhs) } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let dtype = lhs.dtype; numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype)) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let dtype = lhs.dtype; matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap() } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { kernel::cross(lhs, rhs, dim) } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { super::swap_dims(tensor, dim1, dim2) } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { super::reshape(tensor, shape) } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { kernel::gather(dim, tensor, indices) } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { kernel::scatter(dim, tensor, indices, value, false) } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { kernel::select(tensor, dim, indices) } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { kernel::select_assign(tensor, dim, indices, value, false) } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { // Check if all steps are 1 let all_steps_one = slices.iter().all(|info| info.step == 1); if all_steps_one { // Use optimized slice for step=1 let simple_ranges: Vec> = slices .iter() .enumerate() .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) } else { // Use slice with steps kernel kernel::slice_with_steps(tensor, slices) } } fn float_slice_assign( tensor: FloatTensor, ranges: &[Slice], value: FloatTensor, ) -> FloatTensor { kernel::slice_assign(tensor, ranges, value) } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { kernel::mask_where_auto(tensor, mask, value, bool_dtype::()) } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { let dtype = tensor.dtype; kernel::mask_fill_auto( tensor, mask, InputScalar::new(value, dtype), bool_dtype::(), ) } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { kernel::equal(lhs, rhs, bool_dtype::()) } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { kernel::greater(lhs, rhs, bool_dtype::()) } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { kernel::greater_equal(lhs, rhs, bool_dtype::()) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { kernel::lower(lhs, rhs, bool_dtype::()) } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { kernel::lower_equal(lhs, rhs, bool_dtype::()) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let dtype = lhs.dtype; kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::()) } fn float_sum(tensor: FloatTensor) -> FloatTensor { reduce::sum_fallback(tensor, Default::default()).unwrap() } fn float_max(tensor: FloatTensor) -> FloatTensor { reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap() } fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Max, ) .unwrap() } fn float_min(tensor: FloatTensor) -> FloatTensor { reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap() } fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Min, ) .unwrap() } fn float_max_abs(tensor: FloatTensor) -> FloatTensor { reduce::reduce( tensor, None, Default::default(), ReduceOperationConfig::MaxAbs, ) .unwrap() } fn float_max_abs_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::MaxAbs, ) .unwrap() } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Sum, ) .unwrap() } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Mean, ) .unwrap() } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { numeric::cumsum(tensor, dim) } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { numeric::cumprod(tensor, dim) } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { numeric::cummin(tensor, dim) } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { numeric::cummax(tensor, dim) } fn float_prod(tensor: FloatTensor) -> FloatTensor { reduce::reduce( tensor, None, Default::default(), ReduceOperationConfig::Prod, ) .unwrap() } fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce::reduce_dim( tensor, None, dim, Default::default(), ReduceOperationConfig::Prod, ) .unwrap() } fn float_exp(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Exp) } fn float_log(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Log) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Log1p) } fn float_powf_scalar_impl(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { struct Powf; #[cube] impl FloatUnaryOp for Powf { type Options = InputScalar; fn execute(input: Vector, options: &Self::Options) -> Vector { Vector::powf(input, Vector::new(options.get::())) } } impl FloatUnaryOpFamily for Powf { type Options = InputScalar; type Unary = Self; } let dtype = lhs.dtype; launch_unary_float::(lhs, |_| InputScalar::new(rhs, dtype)) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Sqrt) } fn float_abs(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Abs) } fn float_sign(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Sign) } fn float_cos(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Cos) } fn float_sin(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Sin) } fn float_tan(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Tan) } fn float_cosh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Cosh) } fn float_sinh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Sinh) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Tanh) } fn float_acos(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcCos) } fn float_acosh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcCosh) } fn float_asin(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcSin) } fn float_asinh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcSinh) } fn float_atan(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcTan) } fn float_atanh(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::ArcTanh) } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { crate::kernel::atan2::(lhs, rhs) } fn float_round(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Round) } fn float_floor(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Floor) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Ceil) } fn float_trunc(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Trunc) } fn float_erf(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Erf) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, Some(::IntElem::dtype()), dim, Default::default(), ReduceOperationConfig::ArgMax, ) .unwrap() } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { reduce::reduce_dim( tensor, Some(::IntElem::dtype()), dim, Default::default(), ReduceOperationConfig::ArgMin, ) .unwrap() } fn float_into_int(tensor: FloatTensor) -> IntTensor { kernel::cast(tensor, I::dtype()) } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { let dtype = tensor.dtype; kernel::clamp( tensor, InputScalar::new(min, dtype), InputScalar::new(max, dtype), ) } fn float_recip(tensor: FloatTensor) -> FloatTensor { unary_basic::launch::(tensor, |_| BasicFloatUnaryKind::Recip) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { kernel::repeat_dim(tensor, dim, times) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { numeric::pow(lhs, rhs) } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { permute(tensor, axes) } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { expand(tensor, shape) } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { kernel::flip(tensor, axes, bool_dtype::()) } fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { kernel::cast(tensor, dtype.into()) } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { unfold(tensor, dim, size, step) } fn float_is_nan(tensor: FloatTensor) -> BoolTensor { kernel::is_nan(tensor, bool_dtype::()) } fn float_is_inf(tensor: FloatTensor) -> BoolTensor { kernel::is_inf(tensor, bool_dtype::()) } fn float_grid_sample_2d( tensor: FloatTensor, grid: FloatTensor, options: GridSampleOptions, ) -> FloatTensor { kernel::grid_sample::grid_sample(tensor, grid, options) } } ================================================ FILE: crates/burn-cubecl/src/ops/transaction.rs ================================================ use burn_backend::{ DType, TensorData, backend::ExecutionError, ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData}, }; use burn_std::{Shape, Strides}; use cubecl::server::{CopyDescriptor, Handle}; use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement}; impl TransactionOps for CubeBackend where R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement, { async fn tr_execute( transaction: TransactionPrimitive, ) -> Result { let mut client = None; enum Kind { Float, Int, Bool, } #[derive(new)] struct BindingData { index: usize, kind: Kind, handle: Option, shape: Shape, strides: Strides, dtype: DType, } let mut num_bindings = 0; let mut kinds = Vec::new(); for t in transaction.read_floats.into_iter() { if client.is_none() { client = Some(t.client.clone()); } let t = crate::kernel::into_contiguous_aligned(t); let binding = BindingData::new( num_bindings, Kind::Float, Some(t.handle.clone()), t.meta.shape.clone(), t.meta.strides.clone(), t.dtype, ); kinds.push(binding); num_bindings += 1; } for t in transaction.read_ints.into_iter() { if client.is_none() { client = Some(t.client.clone()); } let t = crate::kernel::into_contiguous_aligned(t); let binding = BindingData::new( num_bindings, Kind::Int, Some(t.handle.clone()), t.meta.shape.clone(), t.meta.strides.clone(), t.dtype, ); kinds.push(binding); num_bindings += 1; } for t in transaction.read_bools.into_iter() { if client.is_none() { client = Some(t.client.clone()); } let t = crate::kernel::into_contiguous_aligned(t); let binding = BindingData::new( num_bindings, Kind::Bool, Some(t.handle.clone()), t.meta.shape.clone(), t.meta.strides.clone(), t.dtype, ); kinds.push(binding); num_bindings += 1; } let client = client.unwrap(); let bindings = kinds .iter_mut() .map(|b| { CopyDescriptor::new( b.handle.take().unwrap().binding(), b.shape.clone(), b.strides.clone(), b.dtype.size(), ) }) .collect(); let mut data: Vec> = client .read_tensor_async(bindings) .await .map_err(|err| ExecutionError::WithContext { reason: format!("{err:?}"), })? .into_iter() .map(Some) .collect::>>(); let mut result = TransactionPrimitiveData::default(); for binding in kinds { let bytes = data.get_mut(binding.index).unwrap().take().unwrap(); let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype); match binding.kind { Kind::Float => { result.read_floats.push(t_data); } Kind::Int => { result.read_ints.push(t_data); } Kind::Bool => { result.read_bools.push(t_data); } } } Ok(result) } } ================================================ FILE: crates/burn-cubecl/src/template/base.rs ================================================ use super::SourceTemplate; use crate::{CubeRuntime, element::CubeElement, tensor::CubeTensor}; use cubecl::{CompilationError, Compiler, CubeTask, prelude::*}; /// Kernel source to create a [source](SourceTemplate) pub trait KernelSource: Send + 'static + Sync { /// Convert to [source](SourceTemplate) fn source(&self) -> SourceTemplate; /// Identifier for the kernel, used for caching kernel compilation. fn id(&self) -> KernelId; } #[derive(new)] /// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask). pub struct SourceKernel { kernel_source: K, cube_dim: CubeDim, } impl CubeTask for SourceKernel { fn compile( &self, _compiler: &mut C, _options: &C::CompilationOptions, _mode: ExecutionMode, _address_type: StorageType, ) -> Result, CompilationError> { let source_template = self.kernel_source.source(); let source = source_template.complete(); Ok(CompiledKernel { entrypoint_name: "main".to_string(), debug_name: Some(core::any::type_name::()), source, cube_dim: self.cube_dim, debug_info: None, repr: None, }) } } impl KernelMetadata for SourceKernel { fn id(&self) -> KernelId { self.kernel_source.id() } fn address_type(&self) -> StorageType { u32::as_type_native_unchecked().storage_type() } } /// Generates kernel source code by replacing some information using templating. #[macro_export] macro_rules! kernel_source { ( $struct:ident, $file:expr ) => { /// Generated kernel from a source file. #[derive(new)] pub struct $struct; impl $struct { fn source(&self) -> $crate::template::SourceTemplate { $crate::template::SourceTemplate::new(include_str!($file)) } } }; } /// Create a vector containing the dimension, strides and shape of tensors. /// /// # Example /// /// With two tensors (lhs, rhs) /// /// | Indexes | Value | /// |:------------------------:|:-----------:| /// | 0..1 | D | /// | 1..D + 1 | lhs strides | /// | (D + 1)..(2 * D + 1) | rhs strides | /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | pub fn build_info(tensors: &[&CubeTensor]) -> Vec { let ndims = tensors[0].meta.num_dims(); let mut info: Vec = vec![0; tensors.len() * 2 * ndims + 1]; info[0] = ndims as u32; let mut current = 1; for tensor in tensors.iter() { for d in 0..ndims { info[current] = tensor.meta.strides()[d] as u32; current += 1; } } for tensor in tensors.iter() { for d in 0..ndims { info[current] = tensor.meta.shape()[d] as u32; current += 1; } } info } ================================================ FILE: crates/burn-cubecl/src/template/mod.rs ================================================ mod base; pub use base::*; mod source; pub use source::*; ================================================ FILE: crates/burn-cubecl/src/template/source.rs ================================================ use std::collections::HashMap; /// Kernel source code abstraction allowing for templating. /// /// The templates can have text placeholders in the form {{ label }}. /// They will be updated with their proper value when `generate` is called. #[derive(Debug)] pub struct SourceTemplate { items: HashMap, templates: Vec, } impl SourceTemplate { /// Create a new source template. pub fn new(template: S) -> Self where S: Into, { Self { items: HashMap::new(), templates: vec![template.into()], } } /// Register the value for a placeholder item. /// /// # Notes /// /// The value can't have placeholders, since it would require recursive templating with /// possibly circular dependencies. If you want to add a value that has some /// placeholders, consider adding a new template to the source using /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can /// register the function call instead. pub fn register(mut self, name: Name, value: Value) -> Self where Name: Into, Value: Into, { self.items.insert(name.into(), value.into()); self } /// Add a new template. pub fn add_template(mut self, template: S) -> Self where S: Into, { self.templates.push(template.into()); self } /// Complete the template and returns the source code. pub fn complete(mut self) -> String { let mut source = self.templates.remove(0); for s in self.templates.into_iter() { source.push_str(&s); } let template = text_placeholder::Template::new(&source); let mut context = HashMap::new(); for (key, value) in self.items.iter() { context.insert(key.as_str(), value.as_str()); } template.fill_with_hashmap(&context) } } ================================================ FILE: crates/burn-cubecl/src/tensor/base.rs ================================================ use crate::CubeRuntime; use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric}; use burn_backend::quantization::QuantScheme; use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata}; use burn_std::{Metadata, strides, tensor::is_contiguous}; use cubecl::server::Handle; use cubecl::std::tensor::TensorHandle; use cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch}; use cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch}; use cubecl::{ prelude::{TensorBinding, *}, std::tensor::layout::linear::LinearViewLayout, }; use std::marker::PhantomData; use super::QParams; /// The basic tensor primitive struct. pub struct CubeTensor { /// Compute client for the [runtime](CubeRuntime). pub client: ComputeClient, /// The buffer where the data are stored. pub handle: Handle, /// The metadata of the tensor. pub meta: Box, /// The device of the tensor. pub device: R::Device, /// The datatype of the tensor. pub dtype: DType, /// Runtime quantization parameters, if applicable pub qparams: Option, } impl From> for TensorHandle { fn from(val: CubeTensor) -> Self { TensorHandle::new( val.handle.clone(), val.meta.shape().clone(), val.meta.strides().clone(), val.dtype, ) } } impl cubecl::tune::AutotuneOutput for CubeTensor { #[cfg(feature = "autotune-checks")] fn check_equivalence(&self, other: Self) { use crate::ops::into_data_sync; use burn_backend::Tolerance; let expected = into_data_sync::(self.clone()); let actual = into_data_sync::(other); expected.assert_approx_eq::(&actual, Tolerance::permissive()); } } // TODO: Needed to cleanup leaves tensor. // // Maybe not needed when fusion is activated, since we have a detector there. // We could rely on basic GC strategy when not using fusion. // // impl Drop for CubeTensor { // fn drop(&mut self) { // todo!() // } // } impl core::fmt::Debug for CubeTensor where R: CubeRuntime, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}", self.meta.shape(), self.device, self.meta.strides(), self.dtype.name(), R::name(&self.client), )) } } impl Clone for CubeTensor where R: CubeRuntime, { fn clone(&self) -> Self { Self { client: self.client.clone(), handle: self.handle.clone(), meta: self.meta.clone(), device: self.device.clone(), dtype: self.dtype, qparams: self.qparams.clone(), } } } impl TensorMetadata for CubeTensor { fn dtype(&self) -> DType { self.dtype } fn shape(&self) -> Shape { self.meta.shape().clone() } fn rank(&self) -> usize { self.meta.rank() } } impl QTensorPrimitive for CubeTensor { fn scheme(&self) -> &QuantScheme { if let DType::QFloat(scheme) = &self.dtype { scheme } else { panic!( "Quantization scheme is not valid for dtype {:?}", self.dtype, ) } } } impl CubeTensor where R: CubeRuntime, { /// Create a new standard tensor pub fn new( client: ComputeClient, handle: Handle, metadata: Metadata, device: R::Device, dtype: DType, ) -> Self { CubeTensor { client, handle, meta: Box::new(metadata), device, dtype, qparams: None, } } /// Create a new tensor with a contiguous memory layout. pub fn new_contiguous( client: ComputeClient, device: R::Device, shape: Shape, handle: Handle, dtype: DType, ) -> Self { let ndims = shape.num_dims(); let mut strides = strides![0; ndims]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { strides[index] = current; current *= val; }); Self { client, handle, meta: Box::new(Metadata::new(shape, strides)), device, dtype, qparams: None, } } /// Change the context of the current tensor and return the newly transferred tensor. pub fn to_client(&self, client: ComputeClient, device: R::Device) -> Self { let desc = self.handle.clone().copy_descriptor( self.meta.shape().clone(), self.meta.strides().clone(), self.elem_size(), ); let handle = self.client.to_client_tensor(desc, &client); Self { client, handle, meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())), device, dtype: self.dtype, qparams: self.qparams.clone(), } } /// Return the reference to a tensor handle. pub fn binding(self) -> TensorBinding { TensorBinding { handle: self.handle.binding(), strides: self.meta.strides, shape: self.meta.shape, runtime: PhantomData, } } /// Returns the element size of this tensor pub fn elem_size(&self) -> usize { self.dtype.size() } /// Return the reference to a tensor argument. pub fn into_tensor_arg(self) -> TensorArg { self.binding().into_tensor_arg() } /// Return the reference to an array argument. pub fn into_array_arg(self) -> ArrayArg { self.into_tensor_arg().into_array_arg() } /// Returns a reference to the aliased tensor argument. pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg { TensorArg::Alias { input_pos, strides: self.meta.strides().clone(), shape: self.meta.shape().clone(), } } /// Return a linear view of this tensor. pub fn into_linear_view(self) -> LinearViewLaunch { let layout = LinearViewLayoutLaunch::new(); let buffer = self.into_tensor_arg(); LinearViewLaunch::new_tensor::(buffer, layout) } /// Return an aliased linear view of this tensor pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch { let layout = LinearViewLayoutLaunch::new(); let buffer = self.as_tensor_alias(input_pos); LinearViewLaunch::new_tensor::(buffer, layout) } /// Return a linear view broadcast to the reference tensor's shape pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch { let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape()); let buffer = self.into_tensor_arg(); LinearViewLaunch::new_tensor::(buffer, layout) } /// Returns the address type required to index this tensor pub fn required_address_type(&self) -> AddressType { match self.try_scheme() { Some(scheme) => { let len = self.handle.size() as usize * 8 / scheme.size_bits_value(); AddressType::from_len(len) } None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()), } } /// Return the `QuantScheme` if present pub fn try_scheme(&self) -> Option<&QuantScheme> { match &self.dtype { DType::QFloat(scheme) => Some(scheme), _ => None, } } pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool { if !self.handle.can_mut() || !self.is_nonoverlapping() { return false; } let ndims = self.meta.num_dims(); for i in 0..ndims { let shape_lhs = self.meta.shape()[i]; let shape_rhs = rhs.meta.shape()[i]; // Output tensor will be different from the mutable tensor. if shape_lhs < shape_rhs { return false; } } true } /// Copy the current tensor. pub fn copy(&self) -> Self { struct Copy; #[cube] impl NumericUnaryOp for Copy { type Options = (); fn execute(input: Vector, _options: &Self::Options) -> Vector { input } } impl NumericUnaryOpFamily for Copy { type Options = (); type Unary = Self; } let tensor = self.clone(); launch_unary_numeric::(tensor, |_| ()) } /// Check if the tensor is safe to mutate. pub fn can_mut(&self) -> bool { self.handle.can_mut() } /// Assert that both tensors are on the same device. pub fn assert_is_on_same_device(&self, other: &Self) { if self.device != other.device { panic!( "Both tensors should be on the same device {:?} != {:?}", self.device, other.device ); } } /// Check if the current tensor is contiguous. /// /// A tensor is contiguous if the elements are stored in memory /// if the strides in non-increasing order and the /// strides at position k is equal to the product of the shapes /// at all positions greater than k. However, all axes with a shape of 1 are ignored. pub fn is_contiguous(&self) -> bool { is_contiguous(self.meta.shape(), self.meta.strides()) } /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory /// regions within the shape). pub fn is_contiguous_buffer(&self) -> bool { self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize } /// Checks if the tensor is non-overlapping (can be safely written to). pub fn is_nonoverlapping(&self) -> bool { let shape = self.meta.shape(); let strides = self.meta.strides(); if strides.contains(&0) { return false; } let rank = self.rank(); if rank > 1 { let mut dims = shape.iter().zip(strides.iter()).collect::>(); dims.sort_by_key(|(_, stride)| **stride); let mut max_offset = 0; for (shape, stride) in dims.into_iter() { if *stride <= max_offset && *shape != 1 { return false; } max_offset += (*shape - 1) * *stride; } } true } } #[cfg(test)] mod tests { use super::*; #[test] fn is_contiguous_non_increasing() { assert!(is_contiguous(&[3, 1], &[1, 1])); } #[test] fn is_contiguous_basic() { assert!(is_contiguous(&[32, 32], &[32, 1])); } #[test] fn is_contiguous_permuted() { assert!(!is_contiguous(&[32, 32], &[1, 32])); } #[test] fn is_contiguous_slice() { assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1])); } #[test] fn is_contiguous_4d_positive() { assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1])); } #[test] fn is_contiguous_4d_negative() { assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1])); } /// Based on a bug encountered in interpolate_1d #[test] fn is_contiguous_4d_unit_shape() { assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8])); } } ================================================ FILE: crates/burn-cubecl/src/tensor/mod.rs ================================================ mod base; mod quantization; pub use base::*; pub use quantization::*; ================================================ FILE: crates/burn-cubecl/src/tensor/quantization.rs ================================================ use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor}; use burn_std::{Metadata, Strides}; use cubecl::quant::scheme::{QuantStore, QuantValue}; use cubecl::{client::ComputeClient, server::Handle}; use crate::CubeRuntime; use super::CubeTensor; /// Runtime parameters for quantization. Can be used to construct a scales handle from the base /// tensor handle. pub type QParams = burn_backend::quantization::QParams; impl CubeTensor { /// Create a new quantized tensor pub fn new_quantized( client: ComputeClient, handle: Handle, shape: Shape, device: R::Device, strides: Strides, dtype: DType, qparams: QParams, ) -> Self { CubeTensor { client, handle, meta: Box::new(Metadata::new(shape, strides)), device, dtype, qparams: Some(qparams), } } /// Returns the two tensors: (values, params) for a quantized tensor. /// For the values, native types that aren't supported as a normal `DType` will be returned /// as an unsigned integer tensor representing the bits. Should be reconstructed using `from_bits` /// in kernels. pub fn quantized_handles(&self) -> Option<(CubeTensor, CubeTensor)> { let params = self.scales()?; let scheme = match self.dtype { DType::QFloat(sc) => sc, _ => return None, }; let values = match scheme.store { QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S => CubeTensor { client: self.client.clone(), handle: self.handle.clone(), meta: self.meta.clone(), device: self.device.clone(), dtype: DType::I8, qparams: None, }, QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor { client: self.client.clone(), handle: self.handle.clone(), meta: self.meta.clone(), device: self.device.clone(), dtype: DType::U8, qparams: None, }, QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 => { panic!("Can't store native sub-byte values") } }, QuantStore::PackedU32(packed_dim) => { let packed_dim = self.rank() - packed_dim - 1; let mut shape = self.shape(); shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants()); CubeTensor { client: self.client.clone(), handle: self.handle.clone(), meta: Box::new(Metadata::new(shape, self.meta.strides.clone())), device: self.device.clone(), dtype: DType::U32, qparams: None, } } QuantStore::PackedNative(packed_dim) => match scheme.value { QuantValue::E2M1 => { let packed_dim = self.rank() - packed_dim - 1; let mut shape = self.shape(); shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants()); CubeTensor { client: self.client.clone(), handle: self.handle.clone(), meta: Box::new(Metadata::new(shape, self.meta.strides.clone())), device: self.device.clone(), dtype: DType::U8, qparams: None, } } other => panic!("{other:?} doesn't support native packing"), }, }; Some((values, params)) } /// Construct a separate tensor for the quantization scales, if present pub fn scales(&self) -> Option> { let qparams = self.qparams.as_ref()?; let mut handle = self.handle.clone(); handle.offset_start = Some(qparams.scales.offset_start as u64); handle.offset_end = Some(qparams.scales.offset_end as u64); Some(CubeTensor::new( self.client.clone(), handle, qparams.scales.metadata.clone(), self.device.clone(), qparams.scales.dtype, )) } } ================================================ FILE: crates/burn-cubecl/src/tune_key.rs ================================================ use crate::kernel::{ conv::{ConvAutotuneKey, ConvTranspose2dAutotuneKey}, reduce::SumAutotuneKey, }; use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; use std::fmt::Display; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)] /// Key for all autotune-enabled operations pub enum CubeAutotuneKey { /// Key for sum operations Sum(SumAutotuneKey), /// Key for convolution operations Conv(ConvAutotuneKey), /// Key for transpose convolution operations ConvTranspose(ConvTranspose2dAutotuneKey), } impl Display for CubeAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { CubeAutotuneKey::Sum(reduce_key) => std::fmt::Debug::fmt(&reduce_key, f), CubeAutotuneKey::Conv(conv_key) => std::fmt::Debug::fmt(&conv_key, f), CubeAutotuneKey::ConvTranspose(conv_key) => std::fmt::Debug::fmt(&conv_key, f), } } } impl AutotuneKey for CubeAutotuneKey {} ================================================ FILE: crates/burn-cubecl-fusion/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Provide optimizations that can be used with cubecl based backends." documentation = "https://docs.rs/burn-cubecl-fusion" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu"] license.workspace = true name = "burn-cubecl-fusion" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl-fusion" version.workspace = true [lints] workspace = true [features] default = ["autotune", "std", "cubecl/default", "burn-fusion/default"] autotune = [] autotune-checks = ["cubecl/autotune-checks", "burn-backend", "half"] doc = ["default"] std = ["cubecl/std", "burn-backend?/std", "burn-fusion/std"] tracing = [ "cubecl/tracing", "burn-std/tracing", "burn-backend/tracing", "burn-fusion/tracing", ] [dependencies] burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", default-features = false } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", features = [ "cubecl", ] } cubecl = { workspace = true } cubek = { workspace = true, features = [ "matmul", "reduce", "quantization", "stdlib", ] } half = { workspace = true, optional = true } # Only for `TensorData` with autotune-checks burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false, optional = true } derive-new = { workspace = true } serde = { workspace = true } [dev-dependencies] cubecl = { workspace = true, features = ["test-runtime"] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-cubecl-fusion/README.md ================================================ # Burn CubeCl Fusion Provide optimizations that can be used with [cubecl](../burn-cubecl) based backends. ================================================ FILE: crates/burn-cubecl-fusion/src/base.rs ================================================ use burn_fusion::stream::Context; use burn_std::{DType, Shape, Strides, quantization::QParamTensor, strides}; use cubecl::quant::scheme::{QuantParam, QuantScheme}; use cubecl::{ Runtime, client::ComputeClient, ir::AddressType, prelude::{TensorArg, TensorBinding}, }; use std::marker::PhantomData; /// Defines a fallback operation when fusion isn't possible. pub trait FallbackOperation: Send + Sync { /// Executes the fallback procedure. fn run(&self, context: &mut Context<'_, CubeFusionHandle>); } /// Runtime parameters for quantization. Can be used to construct a scales handle from the base /// tensor handle. pub type QParams = burn_std::quantization::QParams; /// Handle to be used when fusing operations. pub struct CubeFusionHandle { /// Compute client for jit. pub client: ComputeClient, /// The buffer where the data are stored. pub handle: cubecl::server::Handle, /// The device of the current tensor. pub device: R::Device, /// The element type of the tensor. pub dtype: DType, /// The strides of the tensor. pub strides: Strides, /// Quantization runtime parameters, if applicable pub qparams: Option, } impl core::fmt::Debug for CubeFusionHandle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "CubeFusionHandle {{ device: {:?}, runtime: {}}}", self.device, R::name(&self.client), )) } } impl Clone for CubeFusionHandle { fn clone(&self) -> Self { Self { client: self.client.clone(), handle: self.handle.clone(), device: self.device.clone(), strides: self.strides.clone(), dtype: self.dtype, qparams: self.qparams.clone(), } } } unsafe impl Send for CubeFusionHandle {} unsafe impl Sync for CubeFusionHandle {} impl CubeFusionHandle { /// Return the reference to a tensor handle. pub fn binding(self, shape: Shape) -> TensorBinding { TensorBinding { handle: self.handle.binding(), strides: self.strides.clone(), shape, runtime: PhantomData, } } pub fn required_address_type(&self) -> AddressType { match self.dtype { DType::QFloat(scheme) => { let len = self.handle.size() as usize * 8 / scheme.size_bits_value(); AddressType::from_len(len) } _ => AddressType::from_len(self.handle.size() as usize / self.dtype.size()), } } /// Return the reference to a tensor argument. pub fn into_tensor_arg(self, shape: Shape) -> TensorArg { let handle = self.binding(shape); handle.into_tensor_arg() } /// Construct a separate tensor for the quantization scales, if present pub fn params(&self, scheme: QuantScheme) -> Option { let qparams = self.qparams.as_ref()?; let mut handle = self.handle.clone(); handle.offset_start = Some(qparams.scales.offset_start as u64); handle.offset_end = Some(qparams.scales.offset_end as u64); Some(Self { client: self.client.clone(), handle, device: self.device.clone(), dtype: match scheme.param { QuantParam::F32 => DType::F32, QuantParam::F16 => DType::F16, QuantParam::BF16 => DType::BF16, QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!("Not yet supported"), }, strides: qparams.scales.metadata.strides().clone(), qparams: None, }) } } pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides { let mut strides = strides![0; shape.len()]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { strides[index] = current; current *= val; }); strides } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/base.rs ================================================ use cubecl::{define_scalar, define_size}; define_scalar!(pub DynElem); define_size!(pub DynSize); ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/io.rs ================================================ //! This module declares input-output primitives to read and write values during kernel expansion. use crate::engine::codegen::{DynElem, DynSize}; use super::{ir::*, tensor::GlobalTensor}; use burn_std::quantization::QuantScheme; use cubecl::quant::scheme::QuantLevel; use cubecl::{ intrinsic, ir::{ManagedVariable, Variable}, prelude::*, std::{FastDivmod, tensor::View}, }; use cubek::quantization::layout::{BlockScaledLayout, PerTensorLayout, ScalesLayout}; use serde::{Deserialize, Serialize}; /// Define how a tensor might be transformed at runtime. #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum Transform { /// A reshape operation has been registered on a tensor. /// /// This enum entry contains a sequence of [arguments](FuseArg) that points to global scalars representing the /// new shape for the current tensor. Reshape(Vec), /// Two axes have been swapped on a tensor. /// /// The enum entry contains those two axes. SwapDims(usize, usize), } /// Reads the value from the [arg](FuseArg) and cast it to the generic cube primitive. /// /// # Notes /// /// The [global arguments](GlobalArgs) for both inputs and outputs as well as the /// [local arguments](LocalArgs) need to be passed to this function. /// /// This is because the [argument](FuseArg) might point to a global input, output or local variable /// created during kernel expansion. #[cube] pub fn read( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, ) -> Vector { set_polyfill_typed::, DynElem, DynSize>(); match arg { FuseArg::Input(pos, _precision, layout) => { let global = inputs.tensors.index(pos); let vector_size = global.tensor.vector_size(); if comptime![!global.broadcasted && vector_size != config.width] { read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None) } else { read_input(inputs, locals, pos, ref_pos, layout, config, None) } } FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => { Vector::cast_from(outputs.variables.read(key)) } FuseArg::Output(pos, _precision, layout) => { read_output(inputs, outputs, locals, pos, ref_pos, layout, config) } FuseArg::BlockLocal { pos, ty } => match comptime![ty] { FuseType::F64 => Vector::cast_from(locals.l_f64.find(pos)), FuseType::F32 | FuseType::Flex32 => Vector::cast_from(locals.l_f32.find(pos)), FuseType::F16 => Vector::cast_from(locals.l_f16.find(pos)), FuseType::BF16 => Vector::cast_from(locals.l_bf16.find(pos)), FuseType::U64 => Vector::cast_from(locals.l_u64.find(pos)), FuseType::U32 => Vector::cast_from(locals.l_u32.find(pos)), FuseType::U16 => Vector::cast_from(locals.l_u16.find(pos)), FuseType::U8 => Vector::cast_from(locals.l_u8.find(pos)), FuseType::I64 => Vector::cast_from(locals.l_i64.find(pos)), FuseType::I32 => Vector::cast_from(locals.l_i32.find(pos)), FuseType::I16 => Vector::cast_from(locals.l_i16.find(pos)), FuseType::I8 => Vector::cast_from(locals.l_i8.find(pos)), }, FuseArg::Scalar(..) => { let scalar = read_scalar::(inputs, arg); Vector::new(scalar) } FuseArg::ScalarShape(_) => { let scalar = read_scalar_shape(inputs, arg); Vector::cast_from(scalar) } FuseArg::Literal(val, _precision) => Vector::new(from_const_int::(val)), FuseArg::InputReshaped { original, shape, broadcasted, } => match comptime![original.as_ref().clone()] { FuseArg::Input(pos, _precision, layout) => { let global = inputs.tensors.index(pos); let vector_size = global.tensor.vector_size(); if comptime![!broadcasted && vector_size != config.width] { read_input_aligned( inputs, locals, pos, ref_pos, layout, config, comptime![Some(Transform::Reshape(shape))], ) } else { read_input( inputs, locals, pos, ref_pos, layout, config, comptime![Some(Transform::Reshape(shape))], ) } } _ => comptime![panic!("Only input can be reshaped")], }, FuseArg::InputSwapDims { original, dims, broadcasted, } => match comptime![original.as_ref().clone()] { FuseArg::Input(pos, _precision, layout) => { let global = inputs.tensors.index(pos); let vector_size = global.tensor.vector_size(); if comptime![!broadcasted && vector_size != config.width] { read_input_aligned( inputs, locals, pos, ref_pos, layout, config, comptime![Some(Transform::SwapDims(dims.0, dims.1))], ) } else { read_input( inputs, locals, pos, ref_pos, layout, config, comptime![Some(Transform::SwapDims(dims.0, dims.1))], ) } } _ => comptime![panic!("Only input can be swapped dims")], }, } } /// Computes the offset for the current global tensor with a quantized layout. /// /// The offset can be used to fetch the correct data from the quantized tensor as if it was in a /// linear contiguous format. #[cube] fn index_offset_with_quant_layout( tensor: &GlobalTensor, locals: &LocalArgs, index: usize, #[comptime] rank: usize, #[comptime] scheme: QuantScheme, ) -> usize { let (start, end) = (0, rank - 1); let num_quants = scheme.num_quants(); let offset_ref = index * locals.ref_vector_size; let mut offset = 0; #[unroll] for i in start..end { let ogwl = offset_ref / locals.ref_strides[i]; offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i); } // Handle packed representation in last dim let ogwl = offset_ref / locals.ref_strides[end]; let shape_last = tensor.tensor.shape(end).div_ceil(num_quants); let stride_last = tensor.tensor.stride(end); offset += (ogwl.div_ceil(num_quants)) % shape_last * stride_last; offset / tensor.tensor.vector_size() } /// Reads a global quantized tensor at the given position. /// /// # Notes /// /// The values returned in the [Vector] are not dequantized. #[cube] pub fn read_quantized( inputs: &GlobalArgs, locals: &LocalArgs, ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, #[comptime] scheme: QuantScheme, ) -> Vector { match arg { FuseArg::Input(pos, _precision, _layout) => { set_polyfill_typed::, DynElem, DynSize>(); let global = inputs.tensors.index(pos); let offset = index_offset_with_quant_layout(global, locals, ref_pos, config.rank, scheme); let val = global.tensor[offset]; Vector::cast_from(val) } _ => panic!("Not supported"), } } /// Reads a global scalar. #[cube] pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> C { match arg { FuseArg::Scalar(pos, _precision) => { let scalar = inputs.scalars.index(pos); scalar.get::() } _ => comptime![panic!("Not a scalar")], } } /// Reads a global scalar that is used as a reshape position. #[cube] pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize { match arg { FuseArg::ScalarShape(pos) => inputs.reshapes[pos], _ => comptime![panic!("Not a scalar shape")], } } /// Reads an input tensor. #[cube] pub fn read_input( inputs: &GlobalArgs, locals: &LocalArgs, #[comptime] pos: usize, ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> Vector { set_polyfill_typed::, DynElem, DynSize>(); let tensor = inputs.tensors.index(pos); let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, transform), }; Vector::cast_from(tensor.tensor[offset]) } /// Returns a slice of data in the asked precision of the input tensor at the given position. #[cube] pub fn read_input_window( inputs: &GlobalArgs, #[comptime] pos: usize, start: usize, end: usize, ) -> Slice { set_polyfill_typed::(); let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.slice(start, end); slice.downcast() } /// Returns the input as a slice. #[cube] pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice { set_polyfill_typed::(); let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.to_slice(); slice.downcast() } /// Returns the input tensor as a quantized scale view. #[cube] pub fn input_as_scales_view( inputs: &GlobalArgs, #[comptime] pos: usize, #[comptime] tensor_pos: usize, #[comptime] level: QuantLevel, #[comptime] config: &FuseBlockConfig, ) -> View { set_polyfill_typed::, DynElem, DynSize>(); let tensor = inputs.tensors.index(tensor_pos); let scales = inputs.tensors.index(pos); let tensor_len = tensor.tensor.len(); let rank = config.rank; let layout = match level { QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)), QuantLevel::Block(block_size) => { let block_size = comptime![block_size.to_dim_vec(rank)]; let mut tensor_shape = Sequence::new(); let mut scales_strides = Sequence::new(); #[unroll] for i in 0..rank { tensor_shape.push(FastDivmod::new_Fallback(tensor.tensor.shape(i))); scales_strides.push(scales.tensor.stride(i)); } let vector_size = scales.tensor.vector_size(); let layout = BlockScaledLayout::new( tensor_shape, tensor_len, scales_strides, block_size, vector_size, ); ScalesLayout::new_BlockScaled(layout) } }; View::new::, usize>(&scales.tensor.to_slice().downcast(), layout) } /// Reads the input tensor aligned. #[cube] pub fn read_input_aligned( inputs: &GlobalArgs, locals: &LocalArgs, #[comptime] pos: usize, ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> Vector { let mut result = Vector::::empty(); let tensor = inputs.tensors.index(pos); match transform.clone() { Some(Transform::Reshape(shape)) => { // Very brute force, not really efficient, but not easy to optimize and not a very // frequent workflow. let ref_pos = ref_pos * config.width; #[unroll] for i in 0..config.width { let index = reshaped_index( inputs, locals, ref_pos + i, config.rank, comptime![shape.clone()], ); let index = reshaped_index_to_original_index(&tensor.tensor, index, config.rank); result[i] = C::cast_from(tensor.tensor[index][0]) } } Some(Transform::SwapDims(dim1, dim2)) => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))]; let stride = tensor.tensor.stride(i); #[unroll] for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } } None => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); let stride = tensor.tensor.stride(config.rank - 1); #[unroll] for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } } } result } /// Computes the offset of the given [GlobalTensor] at on the reference position with a linear /// layout. #[cube] pub fn get_offset_aligned( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> usize { match layout { LayoutInfo::SameAsRef | LayoutInfo::IsRef => { (ref_pos * locals.ref_vector_size) / tensor.tensor.vector_size() } LayoutInfo::Unknown => get_offset( inputs, locals, tensor, ref_pos, None, config, comptime!(transform.clone()), ), } } /// Reads an output tensor. #[cube] pub fn read_output( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] pos: usize, ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, ) -> Vector { let tensor = outputs.tensors.index(pos); let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, None), }; Vector::cast_from(tensor.tensor[offset]) } #[cube] /// Write the given value at the [arg](Arg) position. pub fn write( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, ref_pos: usize, value: Vector, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, ) { set_polyfill_typed::, DynElem, DynSize>(); match arg { FuseArg::Output(pos, _, layout) => { let tensor = outputs.tensors.index(pos); let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, LayoutInfo::Unknown => { get_offset(inputs, locals, tensor, ref_pos, None, config, None) } }; let tensor = outputs.tensors.index_mut(pos); let value = Vector::cast_from(value); tensor.tensor[offset] = value; } FuseArg::BlockLocal { .. } => write_scalar::(locals, value, arg), FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => { outputs.variables.write(key, Vector::cast_from(value)) } _ => comptime![panic!("Can't write into inputs and scalars")], } } #[cube] /// Write the given value at the [arg](Arg) position. pub fn write_scalar( locals: &mut LocalArgs, value: Vector, #[comptime] arg: FuseArg, ) { match arg { FuseArg::BlockLocal { pos, ty } => match comptime![ty] { FuseType::F64 => locals.l_f64.insert(pos, Vector::cast_from(value)), FuseType::F32 | FuseType::Flex32 => locals.l_f32.insert(pos, Vector::cast_from(value)), FuseType::F16 => locals.l_f16.insert(pos, Vector::cast_from(value)), FuseType::BF16 => locals.l_bf16.insert(pos, Vector::cast_from(value)), FuseType::U64 => locals.l_u64.insert(pos, Vector::cast_from(value)), FuseType::U32 => locals.l_u32.insert(pos, Vector::cast_from(value)), FuseType::U16 => locals.l_u16.insert(pos, Vector::cast_from(value)), FuseType::U8 => locals.l_u8.insert(pos, Vector::cast_from(value)), FuseType::I64 => locals.l_i64.insert(pos, Vector::cast_from(value)), FuseType::I32 => locals.l_i32.insert(pos, Vector::cast_from(value)), FuseType::I16 => locals.l_i16.insert(pos, Vector::cast_from(value)), FuseType::I8 => locals.l_i8.insert(pos, Vector::cast_from(value)), }, _ => comptime![panic!("Can't write into something else than scalars")], } } #[cube] pub(crate) fn global_offset( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, index: usize, #[comptime] arg: FuseArg, #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, ) -> usize { match arg { FuseArg::Input(pos, _precision, _layout) => { let tensor = inputs.tensors.index(pos); get_offset(inputs, locals, tensor, index, range, config, None) } FuseArg::Output(pos, _precision, _layout) => { let tensor = outputs.tensors.index(pos); get_offset(inputs, locals, tensor, index, range, config, None) } _ => panic!("Only input and output tensors have global offset."), } } #[cube] fn get_offset( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, ref_pos: usize, #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> usize { index_offset_with_layout( inputs, tensor, locals, ref_pos, range, config.rank, transform, ) } #[cube] /// Gets the vector size for a global tensor. pub fn global_vector_size( global: &GlobalArgs, #[comptime] pos: usize, ) -> comptime_type!(VectorSize) { let tensor = global.tensors.index(pos); tensor.tensor.vector_size() } #[cube] /// Gets the rank for a global tensor. pub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.rank() } #[cube] /// Gets the length for a global tensor. pub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.len() } #[cube] /// Gets the buffer length for a global tensor. pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.buffer_len() } #[cube] /// Gets the reference tensor length. pub fn ref_len( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, ) -> usize { match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_len(inputs, index), FuseArg::Output(index, _, _) => global_len(outputs, index), _ => panic!("Invalid concrete ref layout."), }, RefLayout::Virtual(..) => num_elements(locals, config), } } #[cube] /// Gets the reference buffer tensor length. pub fn ref_buffer_len( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, ) -> usize { match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_buffer_len(inputs, index), FuseArg::Output(index, _, _) => global_buffer_len(outputs, index), _ => panic!("Invalid concrete ref layout."), }, RefLayout::Virtual(VirtualLayout::SwapDims(arg, ..)) => match arg { FuseArg::Input(index, _, _) => global_buffer_len(inputs, index), FuseArg::Output(index, _, _) => global_buffer_len(outputs, index), _ => panic!("Invalid concrete ref layout."), }, RefLayout::Virtual(VirtualLayout::Reshaped { .. }) => num_elements(locals, config), RefLayout::Virtual(VirtualLayout::Shape(..)) => num_elements(locals, config), RefLayout::Virtual(VirtualLayout::Runtime { .. }) => num_elements(locals, config), } } #[cube] /// Gets the reference number of elements. pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize { let mut length = 1; for i in 0..config.rank { length *= locals.ref_shape[i]; } length } #[cube] /// Gets the reference axis shape. pub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize { locals.ref_shape[axis] } #[cube] /// Gets the reference axis stride. pub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize { locals.ref_strides[axis] } #[cube] /// Gets the reference vector size. pub fn ref_vector_size(locals: &LocalArgs) -> comptime_type!(VectorSize) { comptime![locals.ref_vector_size] } #[cube] /// Gets the given tensor axis shape. pub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.shape(axis) } #[cube] /// Gets the given tensor axis stride. pub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.stride(dim) } #[cube] fn index_offset_with_layout( inputs: &GlobalArgs, tensor: &GlobalTensor, locals: &LocalArgs, index: usize, #[comptime] range: Option<(usize, usize)>, #[comptime] rank: usize, #[comptime] transform: Option, ) -> usize { match comptime![transform.clone()] { Some(Transform::Reshape(shape)) => { comptime![assert!( range.is_none(), "Can't get a range on a reshaped tensor." )]; let index = index * locals.ref_vector_size; let index = reshaped_index(inputs, locals, index, rank, shape); reshaped_index_to_original_index(&tensor.tensor, index, rank) } Some(Transform::SwapDims(dim1, dim2)) => { let (start, end) = comptime! {match range { Some(range) => range, None => (0, rank), }}; let offset_ref = index * locals.ref_vector_size; let mut offset = 0; #[unroll] for i in start..end { let index = comptime![swap_dims_transform(i, (dim1, dim2))]; let ogwl = offset_ref / locals.ref_strides[i]; offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index); } offset / tensor.tensor.vector_size() } None => { let (start, end) = comptime! {match range { Some(range) => range, None => (0, rank), }}; let offset_ref = index * locals.ref_vector_size; let mut offset = 0; #[unroll] for i in start..end { let ogwl = offset_ref / locals.ref_strides[i]; offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i); } offset / tensor.tensor.vector_size() } } } pub(crate) fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize { if i == dims.0 { dims.1 } else if i == dims.1 { dims.0 } else { i } } #[cube] #[allow(clippy::clone_on_copy)] /// The index the input tensor would be at if it was contiguous. fn reshaped_index( inputs: &GlobalArgs, locals: &LocalArgs, index: usize, #[comptime] rank: usize, #[comptime] shape: Vec, ) -> usize { let mut offset = 0; let mut stride_curr = 1; #[unroll] for r in 0..rank { let i = reverse_index(rank, r).comptime(); let arg = shape[i].clone(); let shape_i = read_scalar_shape(inputs, arg); let ogwl = index / locals.ref_strides[i]; offset += ogwl % shape_i * stride_curr; stride_curr *= shape_i; } offset } #[allow(unreachable_code)] #[cube] #[allow(clippy::clone_on_copy)] fn reshaped_index_to_original_index( original: &Tensor>, index_reshaped: usize, #[comptime] rank: usize, ) -> usize { let mut remaining = index_reshaped; let mut offset = 0; #[unroll] for r in 0..rank { let i = reverse_index(rank, r); let shape = original.shape(i); let stride = original.stride(i); let coordinate = remaining % shape; remaining /= shape; offset += coordinate * stride; } offset / original.vector_size() } #[cube] #[allow(unused_variables)] pub(crate) fn reverse_index( #[comptime] rank: usize, #[comptime] iter: usize, ) -> comptime_type!(usize) { rank - iter - 1 } /// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. #[allow(unused_variables)] #[cube] fn from_const_int(#[comptime] value: usize) -> C { intrinsic!(|scope| { ManagedVariable::Plain(Variable::constant(value.into(), C::as_type(scope))).into() }) } #[cube] #[allow(clippy::extra_unused_type_parameters)] pub(crate) fn set_polyfill_typed() { intrinsic!(|scope| { let elem_type = C::as_type(scope); set_polyfill::expand::(scope, elem_type); }) } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/ir.rs ================================================ use super::tensor::GlobalTensor; use crate::engine::codegen::{DynElem, DynSize}; use burn_std::{ BoolStore, DType, Shape, Strides, bf16, f16, quantization::{QuantScheme, QuantStore, QuantValue}, strides, }; use core::fmt::Display; use cubecl::{ ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind}, prelude::*, }; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to a [fuse operation](FuseOp). pub enum FuseArg { /// A readonly input tensor. Input(usize, FuseType, LayoutInfo), /// A readwrite output tensor. Output(usize, FuseType, LayoutInfo), /// A temporary local variable within a single [block](FuseBlockConfig). BlockLocal { /// The position of the current variable relative to all local variables within a single block. pos: usize, /// The type of the current variable. ty: FuseType, }, /// A variable shared between multiple [block](FuseBlockConfig) that must have a compatible /// scope. MultiBlockLocal(MultiBlockPos, FuseType), /// A variable shared between multiple [blocks](FuseBlockConfig) within a global accessible /// scope. MultiBlockGlobal(MultiBlockPos, FuseType), /// A global scalar. Scalar(usize, FuseType), /// A global scalar used in a reshape operation. /// /// This is not a scalar defined by a user for computation, but a scalar defined as part of /// a reshape operation. ScalarShape(usize), /// Only constant that can be encoded into an u32 can be used as literal. Literal(usize, FuseType), /// A readonly input tensor that is reshaped. InputReshaped { original: Box, shape: Vec, broadcasted: bool, }, /// A readonly input tensor with swapped dimensions. InputSwapDims { original: Box, dims: (usize, usize), broadcasted: bool, }, } /// Metadata of a variable shared between blocks. #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct MultiBlockPos { /// The block position in all blocks included in a fused trace. pub block_pos: usize, /// The [FuseArg::BlockLocal] position in the block where the variable is first initialized. pub block_local_pos: usize, } #[derive( CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, )] /// Layout information. pub enum LayoutInfo { /// The layout if the same as the reference. SameAsRef, /// The reference layout. IsRef, /// The layout if unknown. Unknown, } impl FuseArg { pub fn precision(&self) -> FuseType { *match self { FuseArg::Input(_, p, _) => p, FuseArg::BlockLocal { ty, .. } => ty, FuseArg::MultiBlockLocal(_, p) => p, FuseArg::MultiBlockGlobal(_, p) => p, FuseArg::Output(_, p, _) => p, FuseArg::Scalar(_, p) => p, FuseArg::Literal(_, p) => p, FuseArg::ScalarShape(_) => return FuseType::U32, FuseArg::InputReshaped { original, .. } => return original.precision(), FuseArg::InputSwapDims { original, .. } => return original.precision(), } } } impl CubeType for FuseArg { type ExpandType = Self; } impl IntoMut for FuseArg { fn into_mut(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for FuseArg { fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } impl CubeDebug for FuseArg {} #[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Operations that can be executed and fused automatically using a fuse-on-read and/or /// fuse-on-write strategy. pub enum FuseOp { Add(BinaryFuseArgs), Sub(BinaryFuseArgs), Mul(BinaryFuseArgs), Div(BinaryFuseArgs), Powf(BinaryFuseArgs), Abs(UnaryFuseArgs), Exp(UnaryFuseArgs), Log(UnaryFuseArgs), Log1p(UnaryFuseArgs), Cos(UnaryFuseArgs), Sin(UnaryFuseArgs), Tanh(UnaryFuseArgs), Erf(UnaryFuseArgs), Sqrt(UnaryFuseArgs), Recip(UnaryFuseArgs), Assign(UnaryFuseArgs), Equal(BinaryFuseArgs), Lower(BinaryFuseArgs), Greater(BinaryFuseArgs), LowerEqual(BinaryFuseArgs), Rem(BinaryFuseArgs), GreaterEqual(BinaryFuseArgs), Clamp { input: FuseArg, min: FuseArg, max: FuseArg, out: FuseArg, }, ConditionalAssign { cond: FuseArg, lhs: FuseArg, rhs: FuseArg, out: FuseArg, }, Gather { input: FuseArg, indices: FuseArg, output: FuseArg, dim: usize, }, Select { input: FuseArg, indices: FuseArg, output: FuseArg, dim: usize, }, Dequantize { values: FuseArg, params: FuseArg, output: FuseArg, scheme: QuantSchemeFuse, }, } impl Display for FuseOp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { FuseOp::Add(args) => write!(f, "{} = {} + {}", args.out, args.lhs, args.rhs), FuseOp::Sub(args) => write!(f, "{} = {} - {}", args.out, args.lhs, args.rhs), FuseOp::Mul(args) => write!(f, "{} = {} * {}", args.out, args.lhs, args.rhs), FuseOp::Div(args) => write!(f, "{} = {} / {}", args.out, args.lhs, args.rhs), FuseOp::Powf(args) => write!(f, "{} = powf({}, {})", args.out, args.lhs, args.rhs), FuseOp::Abs(args) => write!(f, "{} = abs({})", args.out, args.input), FuseOp::Exp(args) => write!(f, "{} = exp({})", args.out, args.input), FuseOp::Log(args) => write!(f, "{} = log({})", args.out, args.input), FuseOp::Log1p(args) => write!(f, "{} = log1p({})", args.out, args.input), FuseOp::Cos(args) => write!(f, "{} = cos({})", args.out, args.input), FuseOp::Sin(args) => write!(f, "{} = sin({})", args.out, args.input), FuseOp::Tanh(args) => write!(f, "{} = tanh({})", args.out, args.input), FuseOp::Erf(args) => write!(f, "{} = erf({})", args.out, args.input), FuseOp::Sqrt(args) => write!(f, "{} = sqrt({})", args.out, args.input), FuseOp::Recip(args) => write!(f, "{} = recip({})", args.out, args.input), FuseOp::Assign(args) => write!(f, "{} = {}", args.out, args.input), FuseOp::Equal(args) => write!(f, "{} = {} == {}", args.out, args.lhs, args.rhs), FuseOp::Lower(args) => write!(f, "{} = {} < {}", args.out, args.lhs, args.rhs), FuseOp::Greater(args) => write!(f, "{} = {} > {}", args.out, args.lhs, args.rhs), FuseOp::LowerEqual(args) => write!(f, "{} = {} <= {}", args.out, args.lhs, args.rhs), FuseOp::Rem(args) => write!(f, "{} = {} % {}", args.out, args.lhs, args.rhs), FuseOp::GreaterEqual(args) => write!(f, "{} = {} >= {}", args.out, args.lhs, args.rhs), FuseOp::Clamp { input, min, max, out, } => write!(f, "{} = clamp({}, min={}, max={})", out, input, min, max), FuseOp::ConditionalAssign { cond, lhs, rhs, out, } => write!( f, "{} = select(cond={}, lhs={}, rhs={})", out, cond, lhs, rhs ), FuseOp::Gather { input, indices, output, dim, } => write!( f, "{} = gather(input={}, indices={}, dim={})", output, input, indices, dim ), FuseOp::Select { input, indices, output, dim, } => write!( f, "{} = select(input={}, indices={}, dim={})", output, input, indices, dim ), FuseOp::Dequantize { values, params, output, scheme: _, } => write!( f, "{} = dequantize(values={}, params={})", output, values, params ), } } } #[derive( CubeType, CubeLaunch, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, )] pub struct QuantSchemeFuse { #[cube(comptime)] pub(crate) scheme: QuantScheme, } impl FuseOp { /// Element type used for the computation. pub(crate) fn cmp_elem(&self) -> ElemType { match self { FuseOp::Add(op) => op.lhs.precision().into_elem(), FuseOp::Sub(op) => op.lhs.precision().into_elem(), FuseOp::Mul(op) => op.lhs.precision().into_elem(), FuseOp::Div(op) => op.lhs.precision().into_elem(), FuseOp::Powf(op) => op.lhs.precision().into_elem(), FuseOp::Abs(op) => op.out.precision().into_elem(), FuseOp::Exp(op) => op.out.precision().into_elem(), FuseOp::Log(op) => op.out.precision().into_elem(), FuseOp::Log1p(op) => op.out.precision().into_elem(), FuseOp::Cos(op) => op.out.precision().into_elem(), FuseOp::Sin(op) => op.out.precision().into_elem(), FuseOp::Tanh(op) => op.out.precision().into_elem(), FuseOp::Erf(op) => op.out.precision().into_elem(), FuseOp::Recip(op) => op.out.precision().into_elem(), FuseOp::Sqrt(op) => op.out.precision().into_elem(), FuseOp::Assign(op) => op.out.precision().into_elem(), FuseOp::Equal(op) => op.lhs.precision().into_elem(), FuseOp::Lower(op) => op.lhs.precision().into_elem(), FuseOp::Greater(op) => op.lhs.precision().into_elem(), FuseOp::LowerEqual(op) => op.lhs.precision().into_elem(), FuseOp::GreaterEqual(op) => op.lhs.precision().into_elem(), FuseOp::ConditionalAssign { out, .. } => out.precision().into_elem(), FuseOp::Gather { output, .. } => output.precision().into_elem(), FuseOp::Select { output, .. } => output.precision().into_elem(), FuseOp::Dequantize { output, .. } => output.precision().into_elem(), FuseOp::Rem(op) => op.out.precision().into_elem(), FuseOp::Clamp { out, .. } => out.precision().into_elem(), } } pub(crate) fn cmp_storage_ty(&self) -> StorageType { self.cmp_elem().into() } } #[derive(CubeType, CubeLaunch, Default, Clone)] /// Global arguments that are used for fusing [element wise operations](ElemTypewiseOp). pub struct GlobalArgs { /// Tensors that are stored in global memory. pub tensors: Sequence, /// Scalars that are stored in global memory. pub scalars: Sequence, /// To be used to perform reshape inside a fused kernel. pub reshapes: Sequence, /// When there are no metadata as a reference layout, we provide runtime shape/strides in this /// sequence instead. pub runtime_layouts: Sequence, /// Variables shared between blocks. pub variables: MultiBlockVariables, } impl GlobalArgsLaunch { pub fn required_address_type(&self) -> AddressType { self.tensors .values .iter() .map(|it| it.address_type) .max() .unwrap_or_default() } } /// Variables shared between blocks. #[derive(CubeType, Default, Clone)] pub struct MultiBlockVariables { variables: Registry>>>, } #[cube] impl MultiBlockVariables { /// Initializes the variable with the given key and vector size. /// /// # Notes /// /// The type of [`NumericExpand`] must be set before calling this function. pub fn init(&mut self, #[comptime] key: MultiBlockPos) { let mut registers = Registry::< usize, Registry>>, >::find_or_default::(&mut self.variables, key.block_pos); let cell = RuntimeCell::new(Vector::empty()); registers.insert(key.block_local_pos, cell); } /// Read the variable using the provided key. /// /// # Notes /// /// The variable must be initialized. pub fn read(&self, #[comptime] key: MultiBlockPos) -> Vector { let registers = self.variables.find(key.block_pos); let cell = registers.find(key.block_local_pos); cell.read() } /// Write to the variable using the provided key and value. /// /// # Notes /// /// The variable must be initialized. pub fn write(&mut self, #[comptime] key: MultiBlockPos, value: Vector) { let registers = self.variables.find(key.block_pos); // Try find for local(visibility) registers. let cell = registers.find(key.block_local_pos); cell.store(value); } } // Because we only create it DURING compilation, not as a real launch arg. unsafe impl Send for MultiBlockVariables {} unsafe impl Sync for MultiBlockVariables {} impl LaunchArg for MultiBlockVariables { type RuntimeArg = (); type CompilationArg = (); fn compilation_arg(_runtime_arg: &Self::RuntimeArg) -> Self::CompilationArg {} fn register(_arg: Self::RuntimeArg, _launcher: &mut KernelLauncher) {} fn expand( _arg: &Self::CompilationArg, _builder: &mut KernelBuilder, ) -> ::ExpandType { MultiBlockVariablesExpand { variables: Default::default(), } } } impl Default for GlobalArgsLaunch { fn default() -> Self { Self { tensors: Default::default(), scalars: Default::default(), reshapes: Default::default(), variables: Default::default(), runtime_layouts: Default::default(), _phantom_runtime: std::marker::PhantomData, } } } impl core::fmt::Debug for GlobalArgsLaunch { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "({:?})", self.tensors.values) } } impl GlobalArgsLaunch { /// Get the shape of the given [argument](Arg). /// /// # Panics /// /// If the argument doesn't have an handle. pub fn shape(&self, arg: &FuseArg) -> Shape { match self.resolve_arg(arg) { TensorArg::Handle { handle, .. } => handle.shape.clone(), TensorArg::Alias { .. } => panic!("Unsupported yet"), } } /// Shape used by the reference tensor. pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape { match ref_layout { RefLayout::Concrete(arg) => self.shape(arg), RefLayout::Virtual(layout) => match layout { VirtualLayout::SwapDims(original, dims) => { let mut shape = self.shape(original); shape.swap(dims.0, dims.1); shape } VirtualLayout::Reshaped { reshape_pos, .. } => { let start = *reshape_pos * rank; let end = start + rank; self.reshapes.values[start..end].iter().copied().collect() } VirtualLayout::Shape(original, _) => self.shape(original), VirtualLayout::Runtime { pos } => { let start = (*pos * 2) * rank; let end = start + rank; self.runtime_layouts.values[start..end] .iter() .copied() .collect() } }, } } /// Get the strides of the given [argument](Arg). /// /// # Panics /// /// If the argument doesn't have an handle. pub fn strides(&self, arg: &FuseArg) -> Strides { match self.resolve_arg(arg) { TensorArg::Handle { handle, .. } => handle.strides.clone(), TensorArg::Alias { .. } => panic!("Unsupported yet"), } } pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides { match ref_layout { RefLayout::Concrete(arg) => self.strides(arg), // When not concrete, we operate on the contiguous layout. _ => { let shape = self.shape_ref(ref_layout, rank); let mut strides = strides![0; shape.len()]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { strides[index] = current; current *= val; }); strides } } } /// Get the vector size of the given [argument](Arg). /// /// # Panics /// /// If the argument doesn't have an handle. pub fn vector_size(&self, arg: &FuseArg) -> VectorSize { match arg { FuseArg::Input(pos, _, _) => self.tensors.values[*pos].ty.vector_size(), FuseArg::Output(pos, _, _) => self.tensors.values[*pos].ty.vector_size(), other => panic!("Arg not found: {other:?}"), } } /// Resolve the [argument](Arg) to a [tensor argument](TensorArg). /// /// # Panics /// /// If the argument isn't a global input or output tensor. pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg { match arg { FuseArg::Input(pos, _, _) => &self.tensors.values[*pos].tensor, FuseArg::Output(pos, _, _) => &self.tensors.values[*pos].tensor, other => panic!("Arg not found: {other:?}"), } } } #[derive(CubeType, Clone)] /// Keep track of all local variables that are used as argument in fused /// [element wise operations](ElemwiseOp). pub struct LocalArgs { pub l_f64: Registry>, pub l_f32: Registry>, pub l_f16: Registry>, pub l_bf16: Registry>, pub l_i64: Registry>, pub l_i32: Registry>, pub l_i16: Registry>, pub l_i8: Registry>, pub l_u64: Registry>, pub l_u32: Registry>, pub l_u16: Registry>, pub l_u8: Registry>, pub ref_shape: Slice, pub ref_strides: Slice, #[cube(comptime)] pub ref_vector_size: VectorSize, } #[cube] impl LocalArgs { /// Creates a new [LocalArgs] container. pub fn new( ref_shape: Slice, ref_strides: Slice, #[comptime] ref_vector_size: VectorSize, ) -> LocalArgs { LocalArgs { l_f64: Registry::>::new(), l_f32: Registry::>::new(), l_f16: Registry::>::new(), l_bf16: Registry::>::new(), l_i64: Registry::>::new(), l_i32: Registry::>::new(), l_i16: Registry::>::new(), l_i8: Registry::>::new(), l_u64: Registry::>::new(), l_u32: Registry::>::new(), l_u16: Registry::>::new(), l_u8: Registry::>::new(), ref_shape, ref_strides, ref_vector_size, } } } #[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Unary [element wise operation](ElemwiseOp) arguments. pub struct UnaryFuseArgs { pub input: FuseArg, pub out: FuseArg, } #[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Binary [element wise operation](ElemwiseOp) arguments. pub struct BinaryFuseArgs { pub lhs: FuseArg, pub rhs: FuseArg, pub out: FuseArg, } #[derive( CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, )] /// Precisions supported by [element wise operations](ElemwiseOp). /// /// This is a custom type instead of [ElemType] so it can implement [CubeType] /// and restricts the supported types for fusion. pub enum FuseType { F64, F32, Flex32, F16, BF16, I64, I32, I16, I8, U64, U32, U16, U8, } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Configuration that encapsulates all comptime information necessary for element wise fusion. pub struct FuseBlockConfig { pub rank: usize, pub ref_layout: RefLayout, pub ops: Vec, pub width: VectorSize, } impl FuseBlockConfig { pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) { for op in self.ops.iter() { op.multi_block_variables(registers); } } } impl FuseArg { pub fn multi_block_variable(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) { match self { FuseArg::MultiBlockGlobal(arg, fuse_type) // TODO: we need to init the multi-block local, but at some point we could avoid // that for performance (easier for the underlying compiler). | FuseArg::MultiBlockLocal(arg, fuse_type) => { registers.push((arg.clone(), fuse_type.into_storage_type())) } _ => {} }; } } impl FuseOp { pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) { match self { FuseOp::Add(binary_fuse_args) | FuseOp::Sub(binary_fuse_args) | FuseOp::Mul(binary_fuse_args) | FuseOp::Div(binary_fuse_args) | FuseOp::Powf(binary_fuse_args) | FuseOp::Equal(binary_fuse_args) | FuseOp::Lower(binary_fuse_args) | FuseOp::Greater(binary_fuse_args) | FuseOp::LowerEqual(binary_fuse_args) | FuseOp::Rem(binary_fuse_args) | FuseOp::GreaterEqual(binary_fuse_args) => { binary_fuse_args.lhs.multi_block_variable(registers); binary_fuse_args.rhs.multi_block_variable(registers); binary_fuse_args.out.multi_block_variable(registers); } FuseOp::Abs(unary_fuse_args) | FuseOp::Exp(unary_fuse_args) | FuseOp::Log(unary_fuse_args) | FuseOp::Log1p(unary_fuse_args) | FuseOp::Cos(unary_fuse_args) | FuseOp::Sin(unary_fuse_args) | FuseOp::Tanh(unary_fuse_args) | FuseOp::Erf(unary_fuse_args) | FuseOp::Sqrt(unary_fuse_args) | FuseOp::Recip(unary_fuse_args) | FuseOp::Assign(unary_fuse_args) => { unary_fuse_args.input.multi_block_variable(registers); unary_fuse_args.out.multi_block_variable(registers); } FuseOp::Clamp { input, min, max, out, } => { input.multi_block_variable(registers); min.multi_block_variable(registers); max.multi_block_variable(registers); out.multi_block_variable(registers); } FuseOp::ConditionalAssign { cond, lhs, rhs, out, } => { cond.multi_block_variable(registers); lhs.multi_block_variable(registers); rhs.multi_block_variable(registers); out.multi_block_variable(registers); } FuseOp::Gather { input, indices, output, dim: _, } => { input.multi_block_variable(registers); indices.multi_block_variable(registers); output.multi_block_variable(registers); } FuseOp::Select { input, indices, output, dim: _, } => { input.multi_block_variable(registers); indices.multi_block_variable(registers); output.multi_block_variable(registers); } FuseOp::Dequantize { values, params, output, scheme: _, } => { values.multi_block_variable(registers); params.multi_block_variable(registers); output.multi_block_variable(registers); } } } } #[cube] /// Initializes block variables, both globals and locals. pub fn multi_block_variables_init( #[comptime] block: &FuseBlockConfig, variables: &mut MultiBlockVariables, ) { let output = comptime! { let mut output = Vec::<(MultiBlockPos, StorageType)>::new(); block.multi_block_variables(&mut output); output }; #[unroll] for i in 0..comptime!(output.len()) { let (key, dtype) = comptime!(output.get(i).unwrap().clone()); set_polyfill::(comptime![Type::new(dtype).with_vector_size(block.width)]); variables.init(key); } } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// A reference layout determines how a fuse execution will access elements in tensors. /// /// It can either follow the same layout as a concrete tensor, or follow a virtual layout. pub enum RefLayout { Concrete(FuseArg), Virtual(VirtualLayout), } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// A virtual layout is always contiguous and retrieves its shape from either a reshaped tensor or a /// tensor with swap dimensions. pub enum VirtualLayout { /// Virtual tensor with the provided shape id and contiguous strides. Reshaped { reshape_pos: usize, vector_size: VectorSize, }, /// Virtual tensor with the same shape as the given input, but with swap dims and contiguous /// strides. SwapDims(FuseArg, (usize, usize)), /// Virtual tensor with the same shape as the given input, but with contiguous strides. Shape(FuseArg, usize), /// We don't have access to global metadata, they are passed as runtime values. Runtime { pos: usize }, } impl FuseArg { /// Adds layout information. /// /// It's going to impact how the input or output is read and written to. pub fn add_layout_info(&mut self, layout: LayoutInfo) { match self { FuseArg::Input(_, _, old) => { *old = layout; } FuseArg::Output(_, _, old) => { *old = layout; } _ => {} } } } impl RegistryQuery for FuseArg {} impl From for FuseType { fn from(value: ElemType) -> Self { match value { ElemType::Float(kind) => match kind { FloatKind::F16 => Self::F16, FloatKind::BF16 => Self::BF16, FloatKind::F32 => Self::F32, FloatKind::Flex32 => Self::Flex32, _ => panic!("Unsupported precision for fusion: {value}"), }, ElemType::Int(kind) => match kind { IntKind::I64 => Self::I64, IntKind::I32 => Self::I32, IntKind::I16 => Self::I16, IntKind::I8 => Self::I8, }, ElemType::UInt(kind) => match kind { UIntKind::U64 => Self::U64, UIntKind::U32 => Self::U32, UIntKind::U16 => Self::U16, UIntKind::U8 => Self::U8, }, ElemType::Bool => panic!("Bool should be encoded as u8 or u32"), } } } impl From for FuseType { fn from(value: StorageType) -> Self { value.elem_type().into() } } impl FuseType { /// Converts the [fused element type](FuseType) into the [cubecl element type](ElemType). pub fn into_elem(self) -> ElemType { match self { FuseType::F32 => ElemType::Float(FloatKind::F32), FuseType::Flex32 => ElemType::Float(FloatKind::Flex32), FuseType::F16 => ElemType::Float(FloatKind::F16), FuseType::BF16 => ElemType::Float(FloatKind::BF16), FuseType::I64 => ElemType::Int(IntKind::I64), FuseType::I32 => ElemType::Int(IntKind::I32), FuseType::I16 => ElemType::Int(IntKind::I16), FuseType::I8 => ElemType::Int(IntKind::I8), FuseType::U64 => ElemType::UInt(UIntKind::U64), FuseType::U32 => ElemType::UInt(UIntKind::U32), FuseType::U16 => ElemType::UInt(UIntKind::U16), FuseType::U8 => ElemType::UInt(UIntKind::U8), FuseType::F64 => ElemType::Float(FloatKind::F64), } } /// Convert the [fused element type](FuseType) into the [cubecl storage type](StorageType). pub fn into_storage_type(self) -> StorageType { self.into_elem().into() } /// Convert the [fused element type](FuseType) into the [cubecl type](Type) pub fn into_type(self, vector_size: VectorSize) -> Type { Type::new(self.into_storage_type()).with_vector_size(vector_size) } } impl From for FuseType { fn from(value: DType) -> Self { match value { DType::F32 => Self::F32, DType::Flex32 => Self::Flex32, DType::F16 => Self::F16, DType::BF16 => Self::BF16, DType::I64 => Self::I64, DType::I32 => Self::I32, DType::I16 => Self::I16, DType::I8 => Self::I8, DType::U64 => Self::U64, DType::U32 => Self::U32, DType::U16 => Self::U16, DType::U8 => Self::U8, DType::Bool(BoolStore::Native) => unimplemented!("Bool should be U8 or U32"), DType::Bool(BoolStore::U8) => Self::U8, DType::Bool(BoolStore::U32) => Self::U32, DType::F64 => Self::F64, DType::QFloat(scheme) => match scheme.store { QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S => Self::I8, QuantValue::E4M3 | QuantValue::E5M2 => { unimplemented!("Unsupported precision for fusion") } QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 => { panic!("Can't store native sub-byte values") } }, QuantStore::PackedU32(_) => Self::U32, QuantStore::PackedNative(_) => match scheme.value { QuantValue::E2M1 => unimplemented!("Unsupported precision for fusion"), other => panic!("{other:?} doesn't support native packing"), }, }, } } } impl Display for FuseArg { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { FuseArg::Input(pos, ..) => write!(f, "input({pos})"), FuseArg::Output(pos, ..) => write!(f, "output({pos})"), FuseArg::BlockLocal { pos, ty } => write!(f, "local({pos}, {ty:?})"), FuseArg::MultiBlockLocal(mbp, ..) => write!(f, "{mbp}"), FuseArg::MultiBlockGlobal(mbp, ..) => write!(f, "global_{mbp}"), FuseArg::Scalar(pos, ..) => write!(f, "scalar({pos})"), FuseArg::ScalarShape(pos) => write!(f, "scalar_shape({pos})"), FuseArg::Literal(val, ..) => write!(f, "literal_{val}"), FuseArg::InputReshaped { original, .. } => write!(f, "input_reshaped_{original}"), FuseArg::InputSwapDims { original, .. } => write!(f, "input_swap_dims_{original}"), } } } impl Display for MultiBlockPos { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "block_local({}-{})", self.block_pos, self.block_local_pos ) } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs ================================================ use super::{io::*, ir::*}; use burn_std::quantization::{QuantScheme, QuantStore, QuantValue}; use cubecl::{ ir::{ElemType, FloatKind, StorageType, UIntKind}, prelude::*, }; use cubek::quantization::{dequantize::dequantize_symmetric_packed_value_at, scheme::QuantMode}; #[cube] /// Fuse element-wise operations at the given write position. /// /// # Arguments /// /// - `inputs`: Contains all readonly global kernel arguments. /// - `outputs`: Contains all readwrite global kernel arguments. /// - `locals`: Contains all local variables defined during kernel expansion. /// - `write_pos`: The logical position the values are written to. /// - `write_values`: The explicit values to write at the given position. /// - `write_args`: The arguments associated to the `writes_values`. /// - `config`: The current [fuse block configuration](FuseBlockConfig). /// /// # Notes /// /// The function will start by writing `write_values`. pub fn fuse_on_write( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, write_values: Registry>, #[comptime] write_args: Vec, #[comptime] config: &FuseBlockConfig, ) { comment!("Fuse on write begin"); // Write the values given as arguments. #[unroll] for i in 0..write_args.len() { let arg = comptime![write_args.get(i).unwrap().clone()]; let val = write_values.find(arg.clone()); write::(inputs, outputs, locals, write_pos, val, arg, config); } fuse(inputs, outputs, locals, write_pos, config); comment!("Fuse on write end"); } #[cube] /// Fuse element-wise operations at the given read position. /// /// # Arguments /// /// - `inputs`: Contains all readonly global kernel arguments. /// - `outputs`: Contains all readwrite global kernel arguments. /// - `locals`: Contains all local variables defined during kernel expansion. /// - `read_pos`: The logical position the values are read from. /// - `read_args`: The arguments associated to the `read_pos`. /// - `config`: The current [fuse block configuration](FuseBlockConfig). /// /// # Returns /// /// - A sequence of values associated to the given `read_args`. pub fn fuse_on_read( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, read_pos: usize, #[comptime] read_args: Sequence, #[comptime] config: &FuseBlockConfig, ) -> Sequence> { comment!("Fuse on read begin"); fuse(inputs, outputs, locals, read_pos, config); let mut output = Sequence::new(); #[unroll] for i in 0..read_args.len() { let arg = comptime![read_args.index(i).clone()]; let value = read::(inputs, outputs, locals, read_pos, arg, config); output.push(value); } comment!("Fuse on read end"); output } #[cube] /// Initializes [LocalArgs] given the input and output [arguments](GlobalArgs) with the [FuseBlockConfig]. /// /// # Notes /// /// The goal is to resolve and cache the reference shape and strides, as it is used in many /// different function during kernel expansion. pub fn init_locals( inputs: &GlobalArgs, outputs: &mut GlobalArgs, #[comptime] config: &FuseBlockConfig, ) -> LocalArgs { comment!("Init locals begin"); let mut ref_shape = Array::new(config.rank); let mut ref_strides = Array::new(config.rank); let locals = match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, ..) => { let layout = inputs.tensors.index(index); #[unroll] for i in 0..config.rank { ref_shape[i] = layout.tensor.shape(i); ref_strides[i] = layout.tensor.stride(i); } LocalArgs::new( ref_shape.to_slice(), ref_strides.to_slice(), layout.tensor.vector_size(), ) } FuseArg::Output(index, ..) => { let layout = outputs.tensors.index(index); #[unroll] for i in 0..config.rank { ref_shape[i] = layout.tensor.shape(i); ref_strides[i] = layout.tensor.stride(i); } LocalArgs::new( ref_shape.to_slice(), ref_strides.to_slice(), layout.tensor.vector_size(), ) } _ => comptime![panic!("Invalid concrete ref layout.")], }, RefLayout::Virtual(layout) => match layout { VirtualLayout::SwapDims(original, dims) => { let layout = match original.clone() { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), FuseArg::Output(pos, ..) => outputs.tensors.index(pos), _ => comptime![panic!("Unsupported")], }; let mut stride_curr = 1; #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { let reverse = reverse_index(config.rank, i); let swap = comptime![swap_dims_transform(reverse, dims)]; let shape = layout.tensor.shape(swap.clone()); ref_shape[reverse] = shape; ref_strides[reverse] = stride_curr; stride_curr *= ref_shape[comptime![reverse]]; } LocalArgs::new( ref_shape.to_slice(), ref_strides.to_slice(), layout.tensor.vector_size(), ) } VirtualLayout::Reshaped { reshape_pos, vector_size, } => { let mut stride_curr = 1; let start = reshape_pos * config.rank; #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { let reverse = reverse_index(config.rank, i); let arg = comptime![FuseArg::ScalarShape(start + reverse)]; let shape = read_scalar_shape(inputs, arg.clone()); ref_shape[comptime![reverse]] = shape; ref_strides[comptime![reverse]] = stride_curr; stride_curr *= ref_shape[comptime![reverse]]; } LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), vector_size) } VirtualLayout::Runtime { pos } => { let start_shape = (pos * 2) * config.rank; let start_strides = start_shape + config.rank; #[unroll] for i in 0..config.rank { let shape_index = start_shape + i; let strides_index = start_strides + i; ref_shape[i] = *inputs.runtime_layouts.index(shape_index); ref_strides[i] = *inputs.runtime_layouts.index(strides_index); } LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), config.width) } VirtualLayout::Shape(original, vector_size) => { let layout = match original.clone() { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), FuseArg::Output(pos, ..) => outputs.tensors.index(pos), _ => comptime![panic!("Unsupported")], }; let mut stride_curr = 1; #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { let reverse = reverse_index(config.rank, i); let shape = layout.tensor.shape(reverse); ref_shape[comptime![reverse]] = shape; ref_strides[comptime![reverse]] = stride_curr; stride_curr *= ref_shape[comptime![reverse]]; } LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), vector_size) } }, }; comment!("Init locals end"); locals } #[cube] /// Expands all [operations](FuseOp) registered in the [block config](FuseBlockConfig]. fn fuse( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, pos: usize, #[comptime] config: &FuseBlockConfig, ) { #[unroll] for index in 0..config.ops.len() { let op = config.ops[index].clone(); let define!(E) = op.cmp_storage_ty(); let size!(N) = config.width; match op { FuseOp::Add(op) => add::(inputs, outputs, locals, pos, op, config), FuseOp::Div(op) => div::(inputs, outputs, locals, pos, op, config), FuseOp::Sub(op) => sub::(inputs, outputs, locals, pos, op, config), FuseOp::Mul(op) => mul::(inputs, outputs, locals, pos, op, config), FuseOp::Powf(op) => powf::(inputs, outputs, locals, pos, op, config), FuseOp::Erf(op) => erf::(inputs, outputs, locals, pos, op, config), FuseOp::Sqrt(op) => sqrt::(inputs, outputs, locals, pos, op, config), FuseOp::Abs(op) => abs::(inputs, outputs, locals, pos, op, config), FuseOp::Log(op) => log::(inputs, outputs, locals, pos, op, config), FuseOp::Log1p(op) => log1p::(inputs, outputs, locals, pos, op, config), FuseOp::Recip(op) => recip::(inputs, outputs, locals, pos, op, config), FuseOp::Assign(op) => assign::(inputs, outputs, locals, pos, op, config), FuseOp::Exp(op) => exp::(inputs, outputs, locals, pos, op, config), FuseOp::Cos(op) => cos::(inputs, outputs, locals, pos, op, config), FuseOp::Sin(op) => sin::(inputs, outputs, locals, pos, op, config), FuseOp::Tanh(op) => tanh::(inputs, outputs, locals, pos, op, config), FuseOp::Equal(op) => equal::(inputs, outputs, locals, pos, op, config), FuseOp::Greater(op) => greater::(inputs, outputs, locals, pos, op, config), FuseOp::GreaterEqual(op) => { greater_equal::(inputs, outputs, locals, pos, op, config) } FuseOp::Lower(op) => lower::(inputs, outputs, locals, pos, op, config), FuseOp::LowerEqual(op) => lower_equal::(inputs, outputs, locals, pos, op, config), FuseOp::ConditionalAssign { cond, lhs, rhs, out, } => conditional_assign::( inputs, outputs, locals, pos, cond, lhs, rhs, out, config, ), FuseOp::Gather { input, indices, output, dim, } => gather::( inputs, outputs, locals, pos, dim, input, indices, output, config, ), FuseOp::Select { input, indices, output, dim, } => select_indices::( inputs, outputs, locals, pos, dim, input, indices, output, config, ), FuseOp::Dequantize { values, params, output, scheme, } => dequantize::( inputs, outputs, locals, pos, values, params, output, scheme.scheme, config, ), FuseOp::Rem(op) => rem::(inputs, outputs, locals, pos, op, config), FuseOp::Clamp { input, min, max, out, } => clamp::(inputs, outputs, locals, pos, input, min, max, out, config), } } } macro_rules! binary_op { ($ident:ident, $op:tt) => { #[cube] fn $ident( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); let result = lhs $op rhs; write::(inputs, outputs, locals, write_pos, result, op.out, config); } }; } macro_rules! binary_func { ($ident:ident, $func:expr, $c:tt) => { #[cube] fn $ident( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); let result = $func(lhs, rhs); write::(inputs, outputs, locals, write_pos, result, op.out, config); } }; } macro_rules! comparison_op { ($ident:ident, $op:tt) => { #[cube] fn $ident( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { let lhs = read::(inputs, outputs, &locals, write_pos, op.lhs, config); let rhs = read::(inputs, outputs, &locals, write_pos, op.rhs, config); let result = Vector::new(lhs $op rhs); write::(inputs, outputs, locals, write_pos, result, op.out, config); } }; } macro_rules! unary_func { ($ident:ident, $func:expr, $c:tt) => { #[cube] fn $ident( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] op: UnaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { let input = read::(inputs, outputs, &locals, write_pos, op.input, config); let result = $func(input); write::(inputs, outputs, locals, write_pos, result, op.out, config); } }; } #[cube] fn assign( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] op: UnaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { let input = read::(inputs, outputs, locals, write_pos, op.input, config); write::(inputs, outputs, locals, write_pos, input, op.out, config); } #[cube] fn gather( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] dim: usize, #[comptime] input: FuseArg, #[comptime] indices: FuseArg, #[comptime] output: FuseArg, #[comptime] config: &FuseBlockConfig, ) { let vector_size = locals.ref_vector_size; let pos_input = comptime! { match input { FuseArg::Input(pos, ..) => pos, _ => panic!("Input tensor isn't an input"), } }; let pos_indices = comptime! { match indices { FuseArg::Input(pos, ..) => pos, _ => panic!("Indices tensor isn't an input"), } }; let stride_input_dim = global_stride(inputs, dim, pos_input); let mut index = 0; let mut result = Vector::::empty(); if comptime![dim > 0] { let index_before = global_offset( inputs, outputs, locals, write_pos, input.clone(), comptime![Some((0, dim))], config, ); index += index_before; } if comptime![dim + 1 < config.rank] { let index_after = global_offset( inputs, outputs, locals, write_pos, input, comptime![Some((dim + 1, config.rank))], config, ); index += index_after; } let index_offset = global_offset( inputs, outputs, locals, write_pos, indices, comptime![Some((0, config.rank))], config, ); if comptime![dim == config.rank - 1] { // Per-element indexing (along the dimension) #[unroll] for i in 0..vector_size { let offset = read_input::>( inputs, locals, pos_indices, index_offset + i, LayoutInfo::IsRef, config, None, ); let input = read_input::>( inputs, locals, pos_input, index + (offset[0] as usize * stride_input_dim), LayoutInfo::IsRef, config, None, ); result[i] = input[0]; } } else { // Shared index for whole vector let stride_input_vector = global_stride(inputs, config.rank - 1, pos_input); let offset = read_input::>( inputs, locals, pos_indices, index_offset, LayoutInfo::IsRef, config, None, ); index += offset[0] as usize * stride_input_dim; #[unroll] for i in 0..vector_size { let input = read_input::>( inputs, locals, pos_input, index + i * stride_input_vector, LayoutInfo::IsRef, config, None, ); result[i] = input[0]; } } write::(inputs, outputs, locals, write_pos, result, output, config); } #[cube] fn select_indices( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] dim: usize, #[comptime] input: FuseArg, #[comptime] indices: FuseArg, #[comptime] output: FuseArg, #[comptime] config: &FuseBlockConfig, ) { let (vector_size_ref, stride_dim_ref, shape_dim_ref) = ( locals.ref_vector_size, locals.ref_strides[dim], locals.ref_shape[dim], ); let pos_input = comptime! { match input { FuseArg::Input(pos, ..) => pos, _ => panic!("Input tensor isn't an input"), } }; let pos_indices = match indices { FuseArg::Input(pos, ..) => pos, _ => panic!("Indices tensor isn't an input"), }; let stride_input_dim = global_stride(inputs, dim, pos_input); let mut index = 0; let mut result = Vector::empty(); if comptime![dim != config.rank - 1] { // In this scenario the select is actually broadcasted along the axis we're working on. // // Therefore the same indices are used to fetch multiple entries in the input tensor. if comptime![dim > 0] { let index_before = global_offset( inputs, outputs, locals, write_pos, input.clone(), comptime![Some((0, dim))], config, ); index += index_before; } if comptime![dim + 1 < config.rank] { let index_after = global_offset( inputs, outputs, locals, write_pos, input.clone(), comptime![Some((dim + 1, config.rank))], config, ); index += index_after; } let stride_input_vector = global_stride(inputs, comptime![config.rank - 1], pos_input); let write_pos_input = write_pos * vector_size_ref; let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref; let offset_dim = read_input::>( inputs, locals, pos_indices, coordinate_dim, LayoutInfo::IsRef, config, None, ); index += offset_dim[0] as usize * stride_input_dim; #[unroll] for i in 0..vector_size_ref { let input = read_input::>( inputs, locals, pos_input, index + i * stride_input_vector, LayoutInfo::IsRef, config, None, ); result[i] = input[0]; } } else { // In this scenario the select is actually performed on the last dimension we're working on. // // Therefore we need to fetch multiple indices that correspond to different entries in the // input tensor. if comptime![dim > 0] { let index_before = global_offset( inputs, outputs, locals, write_pos, input.clone(), comptime![Some((0, dim))], config, ); index += index_before; } if comptime![dim + 1 < config.rank] { let index_after = global_offset( inputs, outputs, locals, write_pos, input, comptime![Some((dim + 1, config.rank))], config, ); index += index_after; } let write_pos_indices = write_pos * vector_size_ref; #[unroll] for i in 0..vector_size_ref { let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref; let offset_dim = read_input::>( inputs, locals, pos_indices, coordinate_dim, LayoutInfo::IsRef, config, None, ); let input = read_input::>( inputs, locals, pos_input, index + (offset_dim[0] as usize * stride_input_dim), LayoutInfo::IsRef, config, None, ); result[i] = input[0]; } } write::(inputs, outputs, locals, write_pos, result, output, config); } #[cube] fn conditional_assign( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] cond: FuseArg, #[comptime] lhs: FuseArg, #[comptime] rhs: FuseArg, #[comptime] out: FuseArg, #[comptime] config: &FuseBlockConfig, ) { let cond = read::(inputs, outputs, locals, write_pos, cond, config); let lhs = read::(inputs, outputs, locals, write_pos, lhs, config); let rhs = read::(inputs, outputs, locals, write_pos, rhs, config); let result = select_many(cond, lhs, rhs); write::(inputs, outputs, locals, write_pos, result, out, config); } #[cube] fn clamp( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] input: FuseArg, #[comptime] min: FuseArg, #[comptime] max: FuseArg, #[comptime] out: FuseArg, #[comptime] config: &FuseBlockConfig, ) { let input = read::(inputs, outputs, locals, write_pos, input, config); let min = read::(inputs, outputs, locals, write_pos, min, config); let max = read::(inputs, outputs, locals, write_pos, max, config); let result = cubecl::prelude::clamp(input, min, max); write::(inputs, outputs, locals, write_pos, result, out, config); } #[cube] #[allow(clippy::explicit_counter_loop)] fn dequantize( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, write_pos: usize, #[comptime] input: FuseArg, #[comptime] scales: FuseArg, #[comptime] output: FuseArg, #[comptime] scheme: QuantScheme, #[comptime] config: &FuseBlockConfig, ) { comptime!(assert_eq!( scheme.mode, QuantMode::Symmetric, "Only symmetric quantization mode is supported." )); let quant_ty = comptime![match scheme.store { QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S => StorageType::Scalar(ElemType::UInt(UIntKind::U8)), QuantValue::E4M3 => StorageType::Scalar(ElemType::Float(FloatKind::E4M3)), QuantValue::E5M2 => StorageType::Scalar(ElemType::Float(FloatKind::E5M2)), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 => unreachable!("Can't store native sub-byte values"), }, QuantStore::PackedU32(_) => ElemType::UInt(UIntKind::U32).into(), QuantStore::PackedNative(_) => match scheme.value { QuantValue::E2M1 => StorageType::Packed(ElemType::Float(FloatKind::E4M3), 2), other => panic!("{other:?} doesn't support native packing"), }, }]; let param_ty = comptime![match scheme.param { cubecl::quant::scheme::QuantParam::F32 => StorageType::Scalar(ElemType::Float(FloatKind::F32)), cubecl::quant::scheme::QuantParam::F16 => StorageType::Scalar(ElemType::Float(FloatKind::F16)), cubecl::quant::scheme::QuantParam::BF16 => StorageType::Scalar(ElemType::Float(FloatKind::BF16)), cubecl::quant::scheme::QuantParam::UE8M0 => StorageType::Scalar(ElemType::Float(FloatKind::UE8M0)), cubecl::quant::scheme::QuantParam::UE4M3 => StorageType::Scalar(ElemType::Float(FloatKind::E4M3)), }]; let q_vector_size = N::value().comptime() / scheme.num_quants(); let define!(QStoreType) = quant_ty; let size!(QStoreSize) = q_vector_size; let define!(QParamType) = param_ty; let tensor_pos = comptime!(match input { FuseArg::Input(pos, _, _) => pos, _ => panic!("Not supported"), }); let pos = comptime!(match scales { FuseArg::Input(pos, ..) => pos, _ => unreachable!(""), }); let input = read_quantized::(inputs, locals, write_pos, input, config, scheme); let num_quants = scheme.num_quants(); let scales = input_as_scales_view::>(inputs, pos, tensor_pos, scheme.level, config); let result = dequantize_symmetric_packed_value_at::( write_pos * num_quants, input, &scales, scheme, ); let vector = if comptime!(q_vector_size == 1) { result[0] } else { let mut vector = Vector::empty(); #[unroll] for i in 0..q_vector_size { let value = result[i]; #[unroll] for j in 0..num_quants { let index = i * num_quants + j; vector[index] = value[j]; } } vector }; write::(inputs, outputs, locals, write_pos, vector, output, config); } binary_op!(add, +); binary_op!(mul, *); binary_op!(div, /); binary_op!(sub, -); comparison_op!(equal, ==); comparison_op!(greater, >); comparison_op!(greater_equal, >=); comparison_op!(lower, <); comparison_op!(lower_equal, <=); binary_func!(powf, Vector::::powf, Float); binary_func!(rem, Vector::::rem, Float); unary_func!(exp, Vector::::exp, Float); unary_func!(log, Vector::::ln, Float); unary_func!(log1p, Vector::::log1p, Float); unary_func!(sqrt, Vector::::sqrt, Float); unary_func!(cos, Vector::::cos, Float); unary_func!(sin, Vector::::sin, Float); unary_func!(tanh, Vector::::tanh, Float); unary_func!(erf, Vector::::erf, Float); unary_func!(recip, Vector::::recip, Float); unary_func!(abs, Vector::::abs, Numeric); ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/mod.rs ================================================ pub(crate) mod io; pub(crate) mod ir; pub(crate) mod kernel; pub(crate) mod tensor; pub(crate) mod view; mod base; pub(crate) use base::*; ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/tensor.rs ================================================ use crate::engine::codegen::{DynElem, DynSize}; use cubecl::{ir::Type, prelude::*}; use serde::{Deserialize, Serialize}; use std::hash::Hash; /// Represents a global tensor with the given [element type](ElemType). /// /// # Warning /// /// The `tensor` field type [Vector>] must be set using polyfill before /// use. #[derive(CubeType, Clone)] pub struct GlobalTensor { /// The global tensor type. pub tensor: Tensor>, /// The element type of the tensor. #[cube(comptime)] pub ty: Type, /// Whether the current tensor is logically broadcasted. #[cube(comptime)] pub broadcasted: bool, } // Everything below is to implement [LaunchArg]. #[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)] pub struct GlobalTensorCompilationArg { tensor: TensorCompilationArg, ty: Type, broadcasted: bool, } #[derive(new, Debug)] pub struct GlobalTensorArg { pub tensor: > as LaunchArg>::RuntimeArg, pub ty: Type, pub broadcasted: bool, pub address_type: AddressType, } impl LaunchArg for GlobalTensor { type RuntimeArg = GlobalTensorArg; type CompilationArg = GlobalTensorCompilationArg; fn compilation_arg(runtime_arg: &Self::RuntimeArg) -> Self::CompilationArg { let tensor = > as LaunchArg>::compilation_arg(&runtime_arg.tensor); GlobalTensorCompilationArg { tensor, ty: runtime_arg.ty, broadcasted: runtime_arg.broadcasted, } } fn register(arg: Self::RuntimeArg, launcher: &mut KernelLauncher) { launcher.register_tensor(arg.tensor, arg.ty); } fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> GlobalTensorExpand { let tensor = builder.input_tensor(arg.ty); GlobalTensorExpand { tensor: tensor.into(), ty: arg.ty, broadcasted: arg.broadcasted, } } fn expand_output( arg: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> GlobalTensorExpand { let tensor = match arg.tensor.inplace { Some(id) => builder.inplace_output(id), None => builder.output_tensor(arg.ty), }; GlobalTensorExpand { tensor: tensor.into(), ty: arg.ty, broadcasted: arg.broadcasted, } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/codegen/view.rs ================================================ use crate::engine::codegen::{DynElem, DynSize, io::set_polyfill_typed}; use super::{ io::{ Transform, global_buffer_len, global_vector_size, input_as_slice, read_input, read_input_window, ref_buffer_len, ref_len, }, ir::{FuseArg, FuseBlockConfig, GlobalArgs, LayoutInfo, LocalArgs}, kernel::fuse_on_write, }; use cubecl::{ CubeType, io::read_masked, ir::StorageType, prelude::{barrier::BarrierExpand, *}, std::tensor::{ ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand, layout::Coords1d, }, }; #[allow(dead_code, reason = "only used in expand")] #[derive(CubeType)] pub struct GlobalInput { inputs: GlobalArgs, locals: LocalArgs, #[cube(comptime)] pos: usize, #[cube(comptime)] ty: StorageType, #[cube(comptime)] layout: LayoutInfo, #[cube(comptime)] config: FuseBlockConfig, #[cube(comptime)] transform: Option, } #[cube] impl GlobalInput { pub fn new( inputs: &GlobalArgs, locals: &LocalArgs, #[comptime] arg: FuseArg, #[comptime] config: FuseBlockConfig, #[comptime] transform: Option, ) -> GlobalInput { let (pos, ty, layout) = comptime![match arg { FuseArg::Input(pos, prec, layout) => (pos, prec.into_storage_type(), layout), _ => unreachable!("Must be concrete input"), }]; GlobalInput { inputs: inputs.clone(), locals: locals.clone(), pos, ty, layout, config, transform, } } } impl ViewOperations for GlobalInput {} impl ViewOperationsExpand for GlobalInputExpand { #[allow(clippy::too_many_arguments)] fn __expand_read_method( &self, scope: &mut Scope, pos: NativeExpand, ) -> ::ExpandType { ViewOperationsExpand::::__expand_read_unchecked_method(self, scope, pos) } #[allow(clippy::too_many_arguments)] fn __expand_read_checked_method( &self, scope: &mut Scope, pos: NativeExpand, ) -> ::ExpandType { let zero = E::__expand_cast_from(scope, 0.into()); ViewOperationsExpand::::__expand_read_masked_method(self, scope, pos, zero) } #[allow(clippy::too_many_arguments)] fn __expand_read_masked_method( &self, scope: &mut Scope, pos: NativeExpand, value: ::ExpandType, ) -> ::ExpandType { let in_bounds = ViewOperationsExpand::::__expand_is_in_bounds_method( self, scope, pos.clone(), ); set_polyfill_typed::expand::(scope); let slice = input_as_slice::expand(scope, self.inputs.clone(), self.pos); read_masked::expand::(scope, in_bounds, slice, pos, value) } #[allow(clippy::too_many_arguments)] fn __expand_read_unchecked_method( &self, scope: &mut Scope, pos: NativeExpand, ) -> ::ExpandType { set_polyfill_typed::expand::(scope); let value = read_input::expand::( scope, self.inputs.clone(), self.locals.clone(), self.pos, pos, self.layout, self.config.clone(), self.transform.clone(), ); E::__expand_cast_from(scope, value) } #[allow(clippy::too_many_arguments)] fn __expand_to_linear_slice_method( &self, scope: &mut Scope, pos: NativeExpand, end: NativeExpand, ) -> SliceExpand { set_polyfill_typed::expand::(scope); let end = add::expand(scope, end.clone(), 1.into()); read_input_window::expand(scope, self.inputs.clone(), self.pos, pos, end) } #[allow(clippy::too_many_arguments)] fn __expand_tensor_map_load_method( &self, _scope: &mut Scope, _barrier: BarrierExpand, _shared_memory: SliceExpand, _pos: NativeExpand, ) { panic!("Not a tensor map") } #[allow(clippy::too_many_arguments)] fn __expand_shape_method(&self, scope: &mut Scope) -> NativeExpand { global_buffer_len::expand(scope, self.inputs.clone(), self.pos) } #[allow(clippy::too_many_arguments)] fn __expand_is_in_bounds_method( &self, scope: &mut Scope, pos: NativeExpand, ) -> NativeExpand { let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos); lt::expand(scope, pos, buffer_len) } } impl Vectorized for GlobalInput {} impl VectorizedExpand for GlobalInputExpand { fn vector_size(&self) -> VectorSize { let mut temp_scope = Scope::root(false); global_vector_size::expand(&mut temp_scope, self.inputs.clone(), self.pos) } } #[allow(dead_code, reason = "only used in expand")] #[derive(CubeType)] pub struct FusedOutput { inputs: GlobalArgs, outputs: GlobalArgs, locals: LocalArgs, arg: FuseArg, #[cube(comptime)] config: FuseBlockConfig, } #[cube] impl FusedOutput { pub fn new( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, arg: FuseArg, #[comptime] config: FuseBlockConfig, ) -> Self { FusedOutput { inputs: inputs.clone(), outputs: outputs.clone(), locals: locals.clone(), arg, config, } } } impl ViewOperations for FusedOutput {} impl ViewOperationsExpand for FusedOutputExpand { #[allow(clippy::too_many_arguments)] fn __expand_read_method( &self, _scope: &mut Scope, _pos: NativeExpand, ) -> ::ExpandType { todo!() } #[allow(clippy::too_many_arguments)] fn __expand_read_checked_method( &self, _scope: &mut Scope, _pos: NativeExpand, ) -> ::ExpandType { todo!() } #[allow(clippy::too_many_arguments)] fn __expand_read_masked_method( &self, _scope: &mut Scope, _pos: NativeExpand, _value: ::ExpandType, ) -> ::ExpandType { todo!() } #[allow(clippy::too_many_arguments)] fn __expand_read_unchecked_method( &self, _scope: &mut Scope, _pos: NativeExpand, ) -> ::ExpandType { todo!() } #[allow(clippy::too_many_arguments)] fn __expand_to_linear_slice_method( &self, _scope: &mut Scope, _pos: NativeExpand, _size: NativeExpand, ) -> SliceExpand { todo!() } #[allow(clippy::too_many_arguments)] fn __expand_tensor_map_load_method( &self, _scope: &mut Scope, _barrier: BarrierExpand, _shared_memory: SliceExpand, _pos: NativeExpand, ) { panic!("Not a tensor map") } #[allow(clippy::too_many_arguments)] fn __expand_shape_method(&self, scope: &mut Scope) -> NativeExpand { ref_len::expand( scope, self.inputs.clone(), self.outputs.clone(), self.locals.clone(), self.config.clone(), ) } #[allow(clippy::too_many_arguments)] fn __expand_is_in_bounds_method( &self, scope: &mut Scope, pos: NativeExpand, ) -> NativeExpand { let buffer_len = ref_buffer_len::expand( scope, self.inputs.clone(), self.outputs.clone(), self.locals.clone(), self.config.clone(), ); lt::expand(scope, pos, buffer_len) } } impl ViewOperationsMut for FusedOutput {} impl ViewOperationsMutExpand for FusedOutputExpand { #[allow(clippy::too_many_arguments)] fn __expand_write_method( &self, scope: &mut Scope, pos: NativeExpand, value: ::ExpandType, ) { let values = Registry::>::__expand_new(scope); let mut args = comptime![Vec::::new()]; let value = Vector::__expand_cast_from(scope, value); values .clone() .__expand_insert_method(scope, comptime![self.arg.clone()], value); comptime![args.push(self.arg.clone())]; fuse_on_write::expand( scope, self.inputs.clone(), self.outputs.clone(), self.locals.clone(), pos, values, args, self.config.clone(), ); } #[allow(clippy::too_many_arguments)] fn __expand_write_checked_method( &self, scope: &mut Scope, pos: NativeExpand, value: ::ExpandType, ) { let in_bounds = ViewOperationsExpand::::__expand_is_in_bounds_method( self, scope, pos.clone(), ); if_expand(scope, in_bounds, |scope| { ViewOperationsMutExpand::::__expand_write_method(self, scope, pos, value); }) } #[allow(clippy::too_many_arguments)] fn __expand_to_linear_slice_mut_method( &self, _scope: &mut Scope, _pos: NativeExpand, _size: NativeExpand, ) -> SliceExpand { todo!("Not yet supported") } #[allow(clippy::too_many_arguments)] fn __expand_tensor_map_store_method( &self, _scope: &mut Scope, _shared_memory: SliceExpand, _pos: NativeExpand, ) { panic!("Not a tensor map") } } impl Vectorized for FusedOutput {} impl VectorizedExpand for FusedOutputExpand { fn vector_size(&self) -> VectorSize { self.locals.ref_vector_size } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/fuser.rs ================================================ use super::{ codegen::ir::{BinaryFuseArgs, FuseArg, FuseOp, UnaryFuseArgs}, settings::FuseSettings, trace::{FuseTrace, TraceFuser, block::QuantInput}, }; use crate::engine::{codegen::ir::QuantSchemeFuse, scoring::Scoring}; use burn_fusion::{FuserProperties, FuserStatus, OperationFuser}; use burn_ir::{ BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr, TensorIr, UnaryOpIr, }; use burn_std::{DType, Shape}; use cubecl::ir::ElemType; /// The base operation fuser that can be used to fuse [all supported fuse operations](FuseOp). /// /// /// This fuser doesn't create a ready-to-execute kernel, but rather generates a /// [trace](FuseTrace) that be used with a [runner](super::trace::TraceRunner). /// /// Since this fuser supports fusing multiple blocks, you can fuse any compute-bound operations /// with the combination of fuse-on-read and fuse-on-write strategy. /// /// # Notes /// /// It is responsible to translate [OperationIr] into [FuseOp] and it uses the [TraceFuser] /// to actually fuse the [FuseOp] when possible. #[derive(Debug, Clone)] pub(crate) struct TraceOperationFuser { fuser: TryTraceFuser, scoring: Scoring, pub(crate) settings: FuseSettings, pub(crate) current_output_shape: Shape, status: FuserStatus, pub(crate) num_ops: usize, pub(crate) num_views: usize, pub(crate) max_bindings: u32, } impl TraceOperationFuser { /// Checks if the [operation](OperationIr) can be fused with the current fuser. pub(crate) fn can_fuse(&self, op: &OperationIr) -> bool { let len_previous = self.len(); let mut fuser_cloned = self.clone(); fuser_cloned.fuse(op); let len_after = fuser_cloned.len(); len_after > len_previous } } impl OperationFuser for TraceOperationFuser { fn fuse(&mut self, op: &OperationIr) { if let FuserStatus::Closed = self.status { return; } match op { OperationIr::Drop(tensor) => { if self.num_ops == 0 { self.status = FuserStatus::Closed; return; } self.fuser.fuser.fuse_dropped(tensor); } OperationIr::BaseFloat(ops) => { if !self.fuse_base(ops) { self.status = FuserStatus::Closed; return; } } OperationIr::BaseInt(ops) => { if !self.fuse_base(ops) { self.status = FuserStatus::Closed; return; } } OperationIr::Float(_dtype, ops) => { if !self.fuse_float(ops) { self.status = FuserStatus::Closed; return; } } OperationIr::NumericFloat(_dtype, ops) => { if !self.fuse_numeric(ops) { self.status = FuserStatus::Closed; return; } } OperationIr::NumericInt(_dtype, ops) => { if !self.fuse_numeric(ops) { self.status = FuserStatus::Closed; return; } } OperationIr::BaseBool(ops) => { if !self.fuse_base(ops) { self.status = FuserStatus::Closed; return; } } _ => { self.status = FuserStatus::Closed; return; } }; self.status = FuserStatus::Open; self.scoring.register(op); self.num_ops += 1; } fn finish(&mut self) -> FuseTrace { self.fuser.finish(self.current_output_shape.clone()) } fn len(&self) -> usize { self.num_ops } fn reset(&mut self) { self.num_ops = 0; self.scoring.reset(); self.num_views = 0; self.status = FuserStatus::Open; self.fuser = TryTraceFuser::new(self.max_bindings, self.settings); self.current_output_shape = Shape::new([]); } fn status(&self) -> FuserStatus { self.status } fn properties(&self) -> FuserProperties { let ready = self.num_ops > 0; let score = self .scoring .evaluate(&self.fuser.clone().finish(self.current_output_shape.clone())); FuserProperties { ready, score } } fn clone_dyn(&self) -> Box> { Box::new(self.clone()) } } impl TraceOperationFuser { /// Creates a new fuser. pub fn new(max_bindings: u32, settings: FuseSettings) -> Self { Self { fuser: TryTraceFuser::new(max_bindings, settings), settings, scoring: Scoring::default(), num_ops: 0, num_views: 0, max_bindings, current_output_shape: Shape::new([]), status: FuserStatus::Open, } } /// Closes the fuser. pub fn close(&mut self) { self.status = FuserStatus::Closed; } /// Declares an input tensor argument where the kernel is responsible to load. /// /// # Returns /// /// - The argument that maps to the tensor to be used during kernel expansion. pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg { self.fuser.fuser.input_unhandled(tensor) } /// Declares an input quantized tensor argument where the kernel is responsible to load. /// /// # Returns /// /// None if it's not possible to fuse a quantized tensor. Otherwise: /// /// - The argument that maps to the tensor values to be used during kernel expansion. /// - The argument that maps to the tensor params to be used during kernel expansion. pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> { self.fuser.fuser.input_quantized_unhandled(tensor) } /// Declares an output tensor argument where the kernel is responsible to write values. /// /// # Notes /// /// Normally you don't have to declare outputs explicitly before they are going to be /// fused based on the operations [fused](Self::fuse). /// /// # Returns /// /// - The argument that maps to the tensor to be used during kernel expansion. pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg { if self.current_output_shape.is_empty() { self.current_output_shape = tensor.shape.clone(); } else if self.current_output_shape.iter().sum::() < tensor.shape.iter().sum() { // The larguest shape win. self.current_output_shape = tensor.shape.clone(); } self.fuser.fuser.output_unhandled(tensor) } /// Closes the previous block and declares a new one. /// /// # Arguments /// /// - arguments: Tensors that are logical outputs of the current block and inputs of the following blocks. /// - settings: [FuseSettings] to be used by the next block. /// /// # Returns /// /// None if it's impossible to create a next block with the given arguments. Otherwise, the /// corresponding [arguments](Arg) to the given tensors are returned. pub fn next_block( &mut self, arguments: [&TensorIr; N], settings: FuseSettings, global: bool, ) -> [FuseArg; N] { let block_pos = self.fuser.fuser.num_previous_blocks(); let current_output_shape = core::mem::replace(&mut self.current_output_shape, Shape::new([])); self.fuser.fuser.next_block(current_output_shape, settings); self.settings = settings; self.status = FuserStatus::Open; arguments.map(|arg| self.fuser.fuser.block_local_input(arg, block_pos, global)) } /// Tag the [tensor](TensorIr) as received from a previous block. /// /// This will avoid reading the input again and instead use le local version when possible. pub fn block_local_input(&mut self, tensor: &TensorIr, block_pos: usize, global: bool) { self.fuser .fuser .block_local_input(tensor, block_pos, global); } fn fuse_base(&mut self, ops: &BaseOperationIr) -> bool { match ops { BaseOperationIr::Equal(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out }) }), BaseOperationIr::EqualElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out }) }), BaseOperationIr::Cast(desc) => { self.fuse_unary_op(&desc.input, &desc.out, |input, out| { FuseOp::Assign(UnaryFuseArgs { input, out }) }) } BaseOperationIr::SwapDims(desc) => { if !self.output_is_compatible(&desc.out) { return false; } if self.fuser.fuse(|fuser| { fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?; Some(()) }) { self.num_views += 1; true } else { false } } BaseOperationIr::Reshape(desc) => { if desc.input.shape == desc.out.shape { return self.fuse_unary_op(&desc.input, &desc.out, |input, out| { FuseOp::Assign(UnaryFuseArgs { input, out }) }); } if desc.input.shape.rank() > desc.out.shape.rank() { // Not yet supported. return false; } if !self.output_is_compatible(&desc.out) { return false; } if self.fuser.fuse(|fuser| { fuser.input_reshaped(&desc.input, &desc.out)?; Some(()) }) { self.num_views += 1; true } else { false } } BaseOperationIr::Ones(desc) => { if !self.output_is_compatible(&desc.out) { return false; } let elem: ElemType = desc.out.dtype.into(); let precision = elem.into(); let input = FuseArg::Literal(1, precision); self.fuser.fuse(|fuser| { let out = fuser.output(&desc.out)?; fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out })); Some(()) }) } BaseOperationIr::Zeros(desc) => { if !self.output_is_compatible(&desc.out) { return false; } let elem: ElemType = desc.out.dtype.into(); let precision = elem.into(); let input = FuseArg::Literal(0, precision); self.fuser.fuse(|fuser| { let out = fuser.output(&desc.out)?; fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out })); Some(()) }) } BaseOperationIr::Gather(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let input = build.input_indexed(&desc.tensor)?; let indices = build.input_indexed(&desc.indices)?; let output = build.output(&desc.out)?; build.fuse_operation(FuseOp::Gather { input, indices, output, dim: desc.dim, }); Some(()) }) } BaseOperationIr::Select(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let input = build.input_indexed(&desc.tensor)?; let indices = build.input_indexed(&desc.indices)?; let output = build.output(&desc.out)?; build.fuse_operation(FuseOp::Select { input, indices, output, dim: desc.dim, }); Some(()) }) } BaseOperationIr::MaskWhere(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let cond = build.input(&desc.mask)?; let rhs = build.input(&desc.tensor)?; let lhs = build.input(&desc.value)?; let out = build.output(&desc.out)?; build.fuse_operation(FuseOp::ConditionalAssign { cond, lhs, rhs, out, }); Some(()) }) } BaseOperationIr::MaskFill(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let cond = build.input(&desc.mask)?; let lhs = build.scalar(&desc.value, desc.out.dtype); let rhs = build.input(&desc.tensor)?; let out = build.output(&desc.out)?; build.fuse_operation(FuseOp::ConditionalAssign { cond, lhs, rhs, out, }); Some(()) }) } _ => false, } } fn fuse_float(&mut self, ops: &FloatOperationIr) -> bool { match ops { FloatOperationIr::Exp(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Exp(UnaryFuseArgs { input, out })) } FloatOperationIr::Log(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Log(UnaryFuseArgs { input, out })) } FloatOperationIr::Powf(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out }) }), FloatOperationIr::Log1p(desc) => self.fuse_unary_ops(desc, |input, out| { FuseOp::Log1p(UnaryFuseArgs { input, out }) }), FloatOperationIr::Cos(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Cos(UnaryFuseArgs { input, out })) } FloatOperationIr::Sin(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Sin(UnaryFuseArgs { input, out })) } FloatOperationIr::PowfScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out }) }), FloatOperationIr::Tanh(desc) => self.fuse_unary_ops(desc, |input, out| { FuseOp::Tanh(UnaryFuseArgs { input, out }) }), FloatOperationIr::Erf(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Erf(UnaryFuseArgs { input, out })) } FloatOperationIr::Sqrt(desc) => self.fuse_unary_ops(desc, |input, out| { FuseOp::Sqrt(UnaryFuseArgs { input, out }) }), FloatOperationIr::Recip(desc) => self.fuse_unary_ops(desc, |input, out| { FuseOp::Recip(UnaryFuseArgs { input, out }) }), FloatOperationIr::Dequantize(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let qinput = build.input_quantized(&desc.input)?; let out = build.output(&desc.out)?; match qinput { QuantInput::AlreadyDequantized { local } => { build.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input: local, out, })); } QuantInput::Quantized { values, params } => { build.fuse_operation(FuseOp::Dequantize { values, params, output: out, scheme: match desc.input.dtype { DType::QFloat(scheme) => QuantSchemeFuse { scheme }, _ => unreachable!("Should be a quant tensor."), }, }); } } Some(()) }) } _ => false, } } fn fuse_numeric(&mut self, op: &NumericOperationIr) -> bool { match op { NumericOperationIr::Add(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Add(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::AddScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Add(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Sub(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::SubScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Mul(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::MulScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Div(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Div(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::DivScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Div(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Abs(desc) => { self.fuse_unary_ops(desc, |input, out| FuseOp::Abs(UnaryFuseArgs { input, out })) } NumericOperationIr::Lower(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::LowerElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Greater(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::GreaterElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::LowerEqual(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::LowerEqualElem(desc) => self .fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::GreaterEqual(desc) => self .fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::GreaterEqualElem(desc) => self .fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Full(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let input = build.scalar(&desc.value, desc.out.dtype); let out = build.output(&desc.out)?; build.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out })); Some(()) }) } NumericOperationIr::Rem(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| { FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::RemScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| { FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out }) }), NumericOperationIr::Clamp(desc) => { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let input = build.input(&desc.tensor)?; let min = build.scalar(&desc.min, desc.out.dtype); let max = build.scalar(&desc.max, desc.out.dtype); let out = build.output(&desc.out)?; build.fuse_operation(FuseOp::Clamp { input, min, max, out, }); Some(()) }) } _ => false, } } fn fuse_binary_ops(&mut self, desc: &BinaryOpIr, func: Func) -> bool where Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp, { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let lhs = build.input(&desc.lhs)?; let rhs = build.input(&desc.rhs)?; let out = build.output(&desc.out)?; build.fuse_operation(func(lhs, rhs, out)); Some(()) }) } fn fuse_unary_ops(&mut self, desc: &UnaryOpIr, func: Func) -> bool where Func: Fn(FuseArg, FuseArg) -> FuseOp, { self.fuse_unary_op(&desc.input, &desc.out, func) } fn fuse_unary_op(&mut self, input: &TensorIr, out: &TensorIr, func: Func) -> bool where Func: Fn(FuseArg, FuseArg) -> FuseOp, { if !self.output_is_compatible(out) { return false; } self.fuser.fuse(|build| { let input = build.input(input)?; let out = build.output(out)?; build.fuse_operation(func(input, out)); Some(()) }) } fn fuse_scalar_ops(&mut self, desc: &ScalarOpIr, func: Func) -> bool where Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp, { if !self.output_is_compatible(&desc.out) { return false; } self.fuser.fuse(|build| { let elem = desc.lhs.dtype; let lhs = build.input(&desc.lhs)?; let rhs = build.scalar(&desc.rhs, elem); let out = build.output(&desc.out)?; build.fuse_operation(func(lhs, rhs, out)); Some(()) }) } fn output_is_compatible(&mut self, out: &TensorIr) -> bool { if self.current_output_shape.is_empty() { self.current_output_shape.clone_from(&out.shape); return true; } let rank = self.current_output_shape.len(); // Rank should be equal. if rank != out.shape.num_dims() { return false; } let mut updated = self.current_output_shape.clone(); let mut should_update = false; #[allow(clippy::needless_range_loop)] for i in 0..rank { let curr = self.current_output_shape[i]; let new = out.shape[i]; if curr == new { continue; } // Broadcast not enabled. if !self.settings.broadcast { return false; } // Broadcasted on new dim. if new == 0 { continue; } // Broadcasted on curr dim - update reference output shape. if curr == 0 && self.settings.output_shape_updates { should_update = true; updated[i] = new; continue; } return false; } if should_update { // For now forced to have exact shape. if updated != out.shape { return false; } self.current_output_shape.clone_from_slice(&out.shape); } true } } #[derive(Debug, Clone)] /// Builder wrapper to limit the number of bindings in generated kernels. struct TryTraceFuser { fuser: TraceFuser, max_bindings: u32, max_ops: u32, added_ops: bool, } impl TryTraceFuser { fn new(max_bindings: u32, settings: FuseSettings) -> Self { Self { fuser: TraceFuser::new(settings), max_bindings, // A good default, avoid errors with for loops over only memory // bound operations. max_ops: 64, added_ops: false, } } fn fuse(&mut self, add_ops: impl FnOnce(&mut TraceFuser) -> Option<()>) -> bool { if self.fuser.num_ops_fused() > self.max_ops { return false; } // Always allow the first operation to be added. if !self.added_ops { self.added_ops = true; if add_ops(&mut self.fuser).is_none() { return false; } return true; } let mut cloned = self.fuser.clone(); if add_ops(&mut cloned).is_none() { return false; } if cloned.estimate_bindings() > self.max_bindings { return false; } self.fuser = cloned; true } fn finish(&mut self, shape: Shape) -> FuseTrace { self.fuser.finish(shape) } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/base.rs ================================================ use crate::{ CubeFusionHandle, engine::{ launch::{ HandleInput, HandleOutput, LaunchPlan, executor::LaunchPlanExecutor, input::InputPlanner, output::OutputPlanner, runner::TraceRunner, vectorization::VectorizationPlanner, }, trace::{FuseTrace, TraceError, TuneOutput}, }, }; use burn_fusion::stream::Context; use cubecl::{Runtime, client::ComputeClient}; use std::marker::PhantomData; /// The launcher is responsible to launch a fused kernel using the [TraceRunner] and a [FuseTrace]. /// /// TODO: We can reuse the same launcher between runs and avoid a lot of allocation, by simply /// resetting the state. pub struct FuseTraceLauncher<'a, R: Runtime, Runner: TraceRunner> { trace: &'a FuseTrace, runner: &'a Runner, _runtime: PhantomData, } impl<'a, R: Runtime, Runner: TraceRunner> FuseTraceLauncher<'a, R, Runner> { /// Creates a new launcher. pub fn new(trace: &'a FuseTrace, runner: &'a Runner) -> Self { Self { trace, runner, _runtime: PhantomData, } } /// Launches the fuse kernel on the given device modifying the context. pub fn launch( &self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, CubeFusionHandle>, ) -> Result, TraceError> { let mut plan = LaunchPlan::new(&self.trace.blocks); InputPlanner::new(&self.trace.resources, &self.trace.blocks).run(context, &mut plan); OutputPlanner::new(&self.trace.resources, &self.trace.blocks) .run(client, device, context, &mut plan); VectorizationPlanner::new(&self.trace.resources, &self.trace.blocks).run( client, self.runner, context, &mut plan, ); match LaunchPlanExecutor::new(&self.trace.resources, &self.trace.blocks).execute::<_>( client, self.runner, context, plan, ) { Err(err) => { self.rollback(context, err.handles_input, err.handles_output); Err(err.error) } Ok(val) => Ok(val), } } fn rollback( &self, context: &mut Context<'_, CubeFusionHandle>, handle_inputs: Vec>, handle_outputs: Vec>, ) { for input in handle_inputs { match input { HandleInput::Normal(input) => { context .handles .register_handle(input.global_ir.id, input.handle_rollback()); } HandleInput::QuantValues(input) => { context .handles .register_handle(input.global_ir.id, input.handle); } HandleInput::QuantParams(_) => { // The scales are part of the quant data handle. } }; } for output in handle_outputs { if let HandleOutput::Owned { global_id, handle, .. } = output { context.handles.register_handle(global_id, handle); } } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/executor.rs ================================================ use super::{HandleInput, HandleOutput, LaunchPlan, ReferenceSelection}; use crate::engine::launch::runner::TraceRunner; use crate::engine::trace::{FuseResources, TensorView, TraceError, TuneOutput, block::FuseBlock}; use crate::{ CubeFusionHandle, engine::{ codegen::ir::{ FuseBlockConfig, FuseOp, FuseType, GlobalArgsLaunch, RefLayout, VirtualLayout, }, codegen::tensor::GlobalTensorArg, }, }; use burn_fusion::stream::{Context, ScalarId}; use burn_ir::ScalarIr; use cubecl::{ Runtime, client::ComputeClient, ir::{AddressType, Type}, prelude::{InputScalar, TensorArg}, }; use std::marker::PhantomData; /// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). pub struct LaunchPlanExecutor<'a, R: Runtime> { resources: &'a FuseResources, blocks: &'a Vec, _r: PhantomData, } #[derive(new, Debug)] pub struct ExecutionError> { pub error: TraceError, pub handles_input: Vec>, pub handles_output: Vec>, } impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { pub fn new(resources: &'a FuseResources, blocks: &'a Vec) -> Self { Self { resources, blocks, _r: PhantomData, } } pub fn execute>( self, client: &ComputeClient, runner: &Runner, context: &mut Context<'_, CubeFusionHandle>, plan: LaunchPlan<'a, R>, ) -> Result, ExecutionError> { let mut num_writes = 0; for b in plan.blocks.iter() { for writes in b.writes.values() { num_writes += writes.len(); } } #[cfg(feature = "autotune-checks")] let mut tune_output = TuneOutput::Checked { handles: std::collections::HashMap::new(), }; #[cfg(not(feature = "autotune-checks"))] let mut tune_output = TuneOutput::UnChecked(PhantomData); if num_writes == 0 { // Nothing to write, can skip execution. return Ok(tune_output); } let mut inputs = GlobalArgsLaunch::default(); let mut outputs = GlobalArgsLaunch::default(); register_inputs(plan.handle_inputs.clone(), &mut inputs); register_scalars( self.resources.scalars.iter(), self.resources.views.iter(), context, &mut inputs, ); register_outputs::(plan.handle_outputs.clone(), &mut outputs, &mut tune_output); for layout in plan.runtime_layouts { for s in layout.shape.iter() { inputs.runtime_layouts.push(*s); } for s in layout.strides.iter() { inputs.runtime_layouts.push(*s); } } let mut configs = Vec::with_capacity(plan.blocks.len()); for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) { let reference = match block_plan.reference { ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout), ReferenceSelection::VirtualShape { original, .. } => { RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width)) } ReferenceSelection::SwapDims { original, dims } => { RefLayout::Virtual(VirtualLayout::SwapDims(original, dims)) } ReferenceSelection::Reshaped { reshape_pos } => { RefLayout::Virtual(VirtualLayout::Reshaped { reshape_pos, vector_size: block_plan.width, }) } ReferenceSelection::Runtime { pos } => { RefLayout::Virtual(VirtualLayout::Runtime { pos }) } ReferenceSelection::Searching => { return Err(ExecutionError::new( TraceError::ReferenceNotFound, plan.handle_inputs, plan.handle_outputs, )); } }; let mut ops = Vec::::new(); for read_ops in block_plan.reads.into_values() { for op in read_ops { ops.push(op); } } for op in block.ops.iter() { ops.push(op.clone()); } for opsw in block_plan.writes.into_values() { for op in opsw { ops.push(op); } } let config = FuseBlockConfig { rank: plan.rank, ref_layout: reference, ops, width: block_plan.width, }; configs.push(config); } Runner::run(runner, client, inputs, outputs, &configs).map_err(|err| { ExecutionError::new( TraceError::RunnerError(err), plan.handle_inputs, plan.handle_outputs, ) })?; Ok(tune_output) } } fn register_inputs( handle_inputs: Vec>, inputs: &mut GlobalArgsLaunch, ) { for hi in handle_inputs { match hi { HandleInput::Normal(hi) => { let at = hi.handle.required_address_type(); let arg = hi.handle.into_tensor_arg(hi.global_ir.shape.clone()); inputs.tensors.push(GlobalTensorArg::new( arg, hi.precision.into_type(hi.vector_size), hi.broadcated, at, )); } HandleInput::QuantValues(hi) => { let at = hi.handle.required_address_type(); let arg = hi.handle.into_tensor_arg(hi.global_ir.shape.clone()); inputs.tensors.push(GlobalTensorArg::new( arg, hi.precision.into_type(hi.vector_size), false, at, )); } HandleInput::QuantParams(hi) => { let at = hi.handle.required_address_type(); let arg = hi.handle.into_tensor_arg(hi.shape.clone()); inputs.tensors.push(GlobalTensorArg::new( arg, hi.precision.into_type(1), false, at, )); } } } } fn register_outputs( handle_outputs: Vec>, outputs: &mut GlobalArgsLaunch, #[allow(unused_variables)] tune_output: &mut TuneOutput, ) { for item in handle_outputs { match item { HandleOutput::Alias { input_pos, precision, global_shape, strides, #[cfg(feature = "autotune-checks")] debug_info, } => { outputs.tensors.push(GlobalTensorArg::new( TensorArg::Alias { input_pos, strides, shape: global_shape, }, precision.into_type(1), false, AddressType::default(), )); #[cfg(feature = "autotune-checks")] if let TuneOutput::Checked { handles, .. } = tune_output { handles.insert( debug_info.relative_id, (debug_info.global_shape.clone(), debug_info.handle.clone()), ); } } HandleOutput::Owned { precision, handle, global_shape, vectorization: vector_size, #[cfg(feature = "autotune-checks")] relative_id, .. } => { let at = handle.required_address_type(); let arg = handle.into_tensor_arg(global_shape.clone()); let elem = precision.into_elem(); let ty = Type::new(elem.into()).with_vector_size(vector_size); #[cfg(feature = "autotune-checks")] if let TuneOutput::Checked { handles, .. } = tune_output { handles.insert(*relative_id, (global_shape.clone(), handle.clone())); } outputs .tensors .push(GlobalTensorArg::new(arg, ty, false, at)); } } } } fn register_scalars<'h, R: Runtime>( scalars: impl Iterator, views: impl DoubleEndedIterator, context: &mut Context<'_, CubeFusionHandle>, inputs: &mut GlobalArgsLaunch, ) { for (precision, id) in scalars { let dtype = precision.into_storage_type(); match context.scalars.get(&ScalarId { value: *id }) { Some(scalar) => match scalar { ScalarIr::Float(val) => inputs.scalars.push(InputScalar::new(*val, dtype)), ScalarIr::Int(val) => inputs.scalars.push(InputScalar::new(*val, dtype)), ScalarIr::UInt(val) => inputs.scalars.push(InputScalar::new(*val, dtype)), ScalarIr::Bool(val) => inputs.scalars.push(InputScalar::new(*val as u8, dtype)), }, None => panic!("Scalar ID not found"), } } for relative in views { if let TensorView::Reshape { reshaped, .. } = relative { let global = context.tensors.get(reshaped).unwrap(); for shape in global.shape.iter() { inputs.reshapes.push(*shape); } } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/input.rs ================================================ use super::{BlockPlan, HandleInput, InputReference}; use super::{LaunchPlan, NormalHandleInput, PotentialInplace}; use crate::CubeFusionHandle; use crate::engine::launch::{QuantParamsHandleInput, QuantValuesHandleInput}; use crate::engine::trace::block::FuseBlock; use crate::engine::trace::{FuseResources, RegisterTensor, TensorView}; use burn_fusion::stream::Context; use burn_ir::{TensorIr, TensorStatus}; use burn_std::quantization::params_shape; use cubecl::Runtime; use std::marker::PhantomData; /// Fetch and register [input handles](HandleInput). Also identifies potential inputs that /// can be used inplace and/or as the [reference layout](super::super::ir::RefLayout). pub struct InputPlanner<'a, R: Runtime> { resources: &'a FuseResources, blocks: &'a Vec, _r: PhantomData, } impl<'a, R: Runtime> InputPlanner<'a, R> { pub fn new(resources: &'a FuseResources, blocks: &'a Vec) -> Self { Self { resources, blocks, _r: PhantomData, } } pub fn run(self, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>) { for (pos, input) in self.resources.inputs.iter().enumerate() { match input { RegisterTensor::Normal(tensor_relative, precision) => { let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let handle = context .handles .get_handle(&tensor_global.id, &TensorStatus::ReadOnly); if let TensorStatus::ReadWrite = tensor_relative.status { plan.cleared.push(tensor_global.id); } let mut new_strides = handle.strides.clone(); self.analyze(plan, pos, tensor_relative, &handle); if tensor_global.shape.rank() < plan.rank { let num_elem: usize = tensor_global.shape.iter().product(); for _ in 0..(plan.rank - tensor_global.shape.rank()) { tensor_global.shape.insert(0, 1); new_strides.insert(0, num_elem); } } plan.handle_inputs .push(HandleInput::Normal(NormalHandleInput::new( tensor_global, tensor_relative, *precision, handle, new_strides, ))); } RegisterTensor::QuantValues(tensor_relative) => { let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let handle = context .handles .get_handle(&tensor_global.id, &TensorStatus::ReadOnly); let scheme = match tensor_relative.dtype { burn_std::DType::QFloat(scheme) => scheme, _ => unreachable!("Can't have quant data without QFloat"), }; let params = handle.params(scheme).unwrap(); let precision = tensor_relative.dtype.into(); let precision_scales = params.dtype.into(); let global_shape = tensor_global.shape.clone(); let shape_params = params_shape(&global_shape, scheme.level); plan.handle_inputs .push(HandleInput::QuantValues(QuantValuesHandleInput { relative_id: tensor_relative.id, global_ir: tensor_global, precision, handle, vector_size: 1, })); plan.handle_inputs .push(HandleInput::QuantParams(QuantParamsHandleInput { precision: precision_scales, handle: params, shape: shape_params, })); } RegisterTensor::QuantParams(_) => { // It is registered at the same time as quant data. // The order is important and the index in the vector as well, so that's why we // have QuantParams. } } } } fn analyze( &self, plan: &mut LaunchPlan<'a, R>, pos: usize, tensor_relative: &'a TensorIr, handle: &CubeFusionHandle, ) { if !self .resources .inputs_unhandled .contains(&tensor_relative.id) { let mut is_a_view = false; // For each view we try to see if it's not possible to set it as a reference input. for view in self.resources.views.iter() { for (block_plan, block) in plan.blocks.iter_mut().zip(self.blocks) { is_a_view = is_a_view || Self::analyze_view(pos, tensor_relative, block, block_plan, view); } } if !is_a_view { self.analyze_normal(plan, pos, tensor_relative, handle); } } } /// Analyzes if the given tensor can be used inplace in one of the block. fn analyze_normal( &self, plan: &mut LaunchPlan<'a, R>, pos: usize, tensor_relative: &'a TensorIr, handle: &CubeFusionHandle, ) { enum BlockInplaceSelection { Notinit, /// The block reads the input, and therefore can use it for inplace. Selected(usize), /// The same input is used in multiple blocks. Unavailable, } let mut block_inplace_selection = BlockInplaceSelection::Notinit; for (idx, block) in plan.blocks.iter().enumerate() { if block.reads.contains_key(&tensor_relative.id) { match block_inplace_selection { BlockInplaceSelection::Notinit => { block_inplace_selection = BlockInplaceSelection::Selected(idx); } BlockInplaceSelection::Selected(_) => { block_inplace_selection = BlockInplaceSelection::Unavailable; } BlockInplaceSelection::Unavailable => {} } } } if let BlockInplaceSelection::Selected(idx) = block_inplace_selection { if self.blocks[idx].shape_ref != tensor_relative.shape { return; } let block_plan = &mut plan.blocks[idx]; if tensor_relative.status == TensorStatus::ReadWrite { if self.blocks[idx].settings.inplace && handle.handle.can_mut() { block_plan.potential_inplaces.push(PotentialInplace { input_pos: pos, tensor_relative, strides: handle.strides.clone(), }); } // Inplace tensors are normally really good as the reference layout, since // it's normally better to be based on writes rather than on reads. block_plan.potential_reference_input = Some(InputReference::Normal { input_pos: pos }); } else { block_plan.potential_reference_input = Some(InputReference::Normal { input_pos: pos }); } } } /// Analyzes if the given tensor is also the view provided, and check if it can be used as the reference layout /// for the given block. fn analyze_view( pos: usize, tensor_relative: &'a TensorIr, block: &FuseBlock, block_plan: &mut BlockPlan<'a>, view: &TensorView, ) -> bool { match view { TensorView::Reshape { reshaped, original, reshape_pos, shape_relative, } => { if original == &tensor_relative.id || reshaped == &tensor_relative.id { if block_plan.potential_reference_input.is_none() && shape_relative == &block.shape_ref { block_plan.potential_reference_input = Some(InputReference::Reshaped { reshape_pos: *reshape_pos, }); } return true; } } TensorView::SwapDims { swapped, original, dims, .. } => { if swapped == &tensor_relative.id { return true; } if original == &tensor_relative.id { let shape = tensor_relative .shape .clone() .swapped(dims.0, dims.1) .unwrap(); if block_plan.potential_reference_input.is_none() && shape == block.shape_ref { block_plan.potential_reference_input = Some(InputReference::SwapDims { original_pos: pos, dims: *dims, }); } return true; } } }; false } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/mod.rs ================================================ pub(crate) mod executor; pub(crate) mod input; pub(crate) mod output; pub(crate) mod runner; pub(crate) mod vectorization; pub(crate) mod plan; pub use plan::*; mod base; pub use base::*; ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/output.rs ================================================ use super::{ super::codegen::ir::FuseType, BlockPlan, HandleOutput, InputReference, LaunchPlan, NormalHandleInput, ReferenceSelection, }; use crate::{ CubeFusionHandle, engine::{ codegen::ir::{FuseArg, FuseOp, LayoutInfo}, launch::HandleInput, settings::RefLayoutSetting, trace::{FuseResources, RegisterTensor, RuntimeLayout, TensorView, block::FuseBlock}, }, strides_dyn_rank, }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; use burn_std::Shape; use burn_std::{ Strides, tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action}, }; use cubecl::{Runtime, client::ComputeClient, ir::StorageType}; /// Create or reuse handles for the outputs. /// /// It is also responsible to select the reference tensor. pub struct OutputPlanner<'a, R: Runtime> { resources: &'a FuseResources, outputs_sorted: Vec>, handles: Vec>>, globals: Vec>, blocks: &'a Vec, } #[derive(Debug)] struct OutputSorted<'a> { pos_original: usize, precision: FuseType, tensor_relative: &'a TensorIr, } #[derive(Debug)] enum OutputKind { Normal, Inplace { /// The position in the potential inplace vector input_pos: usize, }, Transform(TensorView), } impl<'a, R: Runtime> OutputPlanner<'a, R> { pub fn new(resources: &'a FuseResources, blocks: &'a Vec) -> Self { let mut outputs_sorted: Vec<_> = resources .outputs .iter() .enumerate() .filter_map(|(pos, entry)| match entry { RegisterTensor::Normal(ir, p) => Some((pos, ir, p)), RegisterTensor::QuantValues(_) => None, RegisterTensor::QuantParams(_) => None, }) .map(|(pos, tensor, precision)| OutputSorted { pos_original: pos, precision: *precision, tensor_relative: tensor, }) .collect(); outputs_sorted.sort_by(|a, b| { let a_val: usize = a.tensor_relative.shape.iter().sum(); let b_val: usize = b.tensor_relative.shape.iter().sum(); b_val.cmp(&a_val) }); let mut handles = Vec::with_capacity(resources.outputs.len()); let mut globals = Vec::with_capacity(resources.outputs.len()); for _ in 0..resources.outputs.len() { handles.push(None); globals.push(None); } Self { resources, outputs_sorted, handles, globals, blocks, } } pub fn run( mut self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, ) { // So that we can borrow self during the iteration. let mut outputs = Vec::new(); core::mem::swap(&mut outputs, &mut self.outputs_sorted); for output in outputs.into_iter() { let tensor_global = context .tensors .get(&output.tensor_relative.id) .unwrap() .clone(); let strides = strides_dyn_rank(&tensor_global.shape); let (kind, block_idx) = self.output_kind(plan, &tensor_global, &output, &strides); match kind { OutputKind::Inplace { input_pos } => { self.inplace_output( context, plan, output, tensor_global, strides, input_pos, block_idx, ); } OutputKind::Normal => { self.normal_output( client, device, context, plan, output, tensor_global, strides, block_idx, ); } OutputKind::Transform(TensorView::Reshape { original, .. }) => { self.reshaped_output( client, device, context, plan, output, tensor_global, strides, original, block_idx, ); } OutputKind::Transform(TensorView::SwapDims { original, dims, .. }) => { self.swapped_dims_output( client, device, context, plan, output, tensor_global, original, dims, block_idx, ); } } } for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) { plan.handle_outputs.push(handle.unwrap()); plan.global_outputs.push(global.unwrap()); } for i in 0..plan.blocks.len() { if !plan.blocks[i].reference.is_found() { match self.blocks[i].settings.ref_layout { RefLayoutSetting::SameAsBlock { block_pos } => { plan.blocks[i].reference = plan.blocks[block_pos as usize].reference.clone(); } _ => { let new_runtime = Self::select_reference_from_inputs( &self.blocks[i], &mut plan.blocks[i], &plan.handle_inputs, ); if let Some(shape) = new_runtime { let pos = plan.runtime_layouts.len(); let mut shape_global = shape.clone(); for (i, s) in shape.iter().enumerate() { shape_global[i] = *context.shapes_relative2global.get(s).unwrap(); } let strides = strides_dyn_rank(&shape_global); plan.blocks[i].reference = ReferenceSelection::Runtime { pos }; plan.runtime_layouts.push(RuntimeLayout { shape: shape_global, strides, }); } } }; } else { Self::add_layout_info_inputs(&mut plan.blocks[i], &plan.handle_inputs); } } // Make sure dropped are correctly executed. for id in self.resources.dropped.iter() { if let Some(tensor_global) = context.tensors.get(id) { context.handles.remove_handle(tensor_global.id); } } for id in plan.cleared.drain(..) { context.handles.remove_handle(id); } } fn select_reference_from_inputs( block: &FuseBlock, block_plan: &mut BlockPlan<'_>, handle_inputs: &[HandleInput], ) -> Option { if let Some(input_ref) = block_plan.potential_reference_input.take() { match input_ref { InputReference::Normal { input_pos } => { let reference = handle_inputs .get(input_pos) .unwrap() .as_normal() .expect("Quant can't be used as inplace"); let set_ref_as_concrete = |block: &mut BlockPlan<'_>| { block.reference = ReferenceSelection::Concrete { layout: FuseArg::Input( input_pos, reference.precision, LayoutInfo::IsRef, ), shape: reference.global_ir.shape.clone(), strides: reference.handle.strides.clone(), }; }; let set_ref_as_virtual = |block: &mut BlockPlan<'_>| { block.reference = ReferenceSelection::VirtualShape { original: FuseArg::Input( input_pos, reference.precision, LayoutInfo::Unknown, ), shape: reference.global_ir.shape.clone(), strides: contiguous_strides(&reference.global_ir.shape), }; }; match block.settings.ref_layout { RefLayoutSetting::Any => set_ref_as_concrete(block_plan), RefLayoutSetting::SameAsBlock { .. } => { // Skip set ref. } RefLayoutSetting::OnlyContiguous => { if is_contiguous(&reference.global_ir.shape, &reference.handle.strides) { set_ref_as_concrete(block_plan) } else { set_ref_as_virtual(block_plan) } } } Self::add_layout_info_inputs(block_plan, handle_inputs); } InputReference::SwapDims { original_pos, dims } => { let reference = handle_inputs .get(original_pos) .unwrap() .as_normal() .expect("Quant can't be used in swap dims operation"); block_plan.reference = ReferenceSelection::SwapDims { original: FuseArg::Input( original_pos, reference.precision, LayoutInfo::Unknown, ), dims, }; } InputReference::Reshaped { reshape_pos } => { block_plan.reference = ReferenceSelection::Reshaped { reshape_pos }; } }; None } else { Some(block.shape_ref.clone()) } } fn add_layout_info_inputs(block: &mut BlockPlan<'_>, handle_inputs: &[HandleInput]) { for hi in handle_inputs.iter().filter_map(|h| match h { HandleInput::Normal(input) => Some(input), _ => None, }) { let (strides, shape) = match &block.reference { ReferenceSelection::Concrete { strides, shape, .. } | ReferenceSelection::VirtualShape { strides, shape, .. } => (strides, shape), _ => continue, }; if strides == &hi.handle.strides && shape == &hi.global_ir.shape && let Some(ops) = block.reads.get_mut(&hi.relative_id) { for op in ops.iter_mut() { if let FuseOp::Assign(op) = op { op.input.add_layout_info(LayoutInfo::SameAsRef); } } } } } fn output_kind( &self, plan: &mut LaunchPlan<'a, R>, tensor_global: &TensorIr, output: &OutputSorted, strides: &[usize], ) -> (OutputKind, usize) { let mut block_idx = None; for (i, block) in plan.blocks.iter().enumerate() { if block.writes.contains_key(&output.tensor_relative.id) { block_idx = Some(i); break; } } let block_idx = block_idx.unwrap(); if let Some(transform) = self.resources.views.iter().find(|v| match v { TensorView::Reshape { reshaped, .. } => reshaped == &output.tensor_relative.id, TensorView::SwapDims { swapped, .. } => swapped == &output.tensor_relative.id, }) { return (OutputKind::Transform(transform.clone()), block_idx); } let block = &plan.blocks[block_idx]; let kind = block .potential_inplaces .iter() .enumerate() .find(|(_pos, pi)| { pi.tensor_relative.dtype == tensor_global.dtype && pi.tensor_relative.shape == output.tensor_relative.shape && &*pi.strides == strides && block.reference.compatible_strides_for_inplace(strides) }) .map(|(pos, _)| OutputKind::Inplace { input_pos: pos }) .unwrap_or(OutputKind::Normal); (kind, block_idx) } #[allow(clippy::too_many_arguments)] fn inplace_output( &mut self, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, strides: Strides, input_index: usize, block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; let potential_inplace = block.potential_inplaces.remove(input_index); let handle_input = match plan.handle_inputs.get(potential_inplace.input_pos).unwrap() { HandleInput::Normal(handle) => handle, _ => { unreachable!("Quant tensor handle can't be used inplace yet.") } }; if !block.reference.is_found() && !matches!( self.blocks[block_idx].settings.ref_layout, RefLayoutSetting::SameAsBlock { .. } ) { let index_input = self .resources .inputs .get_index(potential_inplace.tensor_relative.id) .unwrap(); block.reference = ReferenceSelection::Concrete { layout: FuseArg::Input(index_input, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: handle_input.handle.strides.clone(), }; if let Some(ops) = block.reads.get_mut(&handle_input.relative_id) { for op in ops.iter_mut() { if let FuseOp::Assign(op) = op { op.input.add_layout_info(LayoutInfo::IsRef); break; }; } } if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) { for op in ops { if let FuseOp::Assign(op) = op { op.out.add_layout_info(LayoutInfo::IsRef); break; } } }; } else { // Already validated, necessary for correctness. if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) { for op in ops { if let FuseOp::Assign(op) = op { op.out.add_layout_info(LayoutInfo::SameAsRef); break; } } }; } context .handles .register_handle(tensor_global.id, handle_input.handle.clone()); self.handles[output.pos_original] = Some(HandleOutput::Alias { input_pos: potential_inplace.input_pos, precision: output.precision, global_shape: tensor_global.shape.clone(), strides, #[cfg(feature = "autotune-checks")] debug_info: super::HandleOutputAliasDebugInfo { relative_id: output.tensor_relative.id, handle: handle_input.handle.clone(), global_shape: tensor_global.shape.dims.clone(), }, }); self.globals[output.pos_original] = Some(tensor_global); } #[allow(clippy::too_many_arguments)] fn normal_output( &mut self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, strides: Strides, block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; if !block.reference.is_found() && self.blocks[block_idx].shape_ref == output.tensor_relative.shape && !matches!( self.blocks[block_idx].settings.ref_layout, RefLayoutSetting::SameAsBlock { .. } ) { block.reference = ReferenceSelection::Concrete { layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: strides.clone(), }; // Sometimes outputs that are manually handled don't have any write registered. if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) { for op in ops { if let FuseOp::Assign(op) = op { op.out.add_layout_info(LayoutInfo::IsRef); break; } } }; } else if let ReferenceSelection::Concrete { shape: ref_shape, strides: ref_strides, .. } = &block.reference && ref_strides == &strides && ref_shape == &tensor_global.shape && let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) { for op in ops { if let FuseOp::Assign(op) = op { op.out.add_layout_info(LayoutInfo::SameAsRef); break; } } }; let dtype = tensor_global.dtype; let size = tensor_global.shape.iter().product::() * StorageType::from(dtype).size(); let handle = CubeFusionHandle { client: client.clone(), handle: client.empty(size), device: device.clone(), strides, dtype, qparams: None, }; plan.rank = usize::max(tensor_global.shape.rank(), plan.rank); context .handles .register_handle(tensor_global.id, handle.clone()); self.handles[output.pos_original] = Some(HandleOutput::Owned { precision: output.precision, handle, global_shape: tensor_global.shape.clone(), global_id: tensor_global.id, relative_id: output.tensor_relative.id, vectorization: 1, }); self.globals[output.pos_original] = Some(tensor_global); } #[allow(clippy::too_many_arguments)] fn reshaped_output( &mut self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, strides: Strides, original: TensorId, block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original); let dtype = tensor_global.dtype; let action = reshape_action( &original_handle.global_ir.shape, &original_handle.handle.strides, &tensor_global.shape, ); let update = match action { ReshapeAction::UpdateStrides { strides } => Some(strides), ReshapeAction::NoChange => Some(original_handle.handle.strides.clone()), ReshapeAction::Recompute => None, }; match update { Some(strides) => { // We modify the metadata instead. remove_concrete_write(block, output.tensor_relative.id, output.pos_original); let handle = CubeFusionHandle { client: client.clone(), handle: original_handle.handle.handle.clone(), device: device.clone(), strides, dtype, qparams: original_handle.handle.qparams.clone(), }; context .handles .register_handle(tensor_global.id, handle.clone()); // IT will never be access, just a way to keep the original position working. self.handles[output.pos_original] = Some(HandleOutput::Alias { input_pos: pos_input, precision: output.precision, global_shape: tensor_global.shape.clone(), strides: handle.strides.clone(), #[cfg(feature = "autotune-checks")] debug_info: super::HandleOutputAliasDebugInfo { relative_id: output.tensor_relative.id, handle: handle.clone(), global_shape: tensor_global.shape.dims.clone(), }, }); self.globals[output.pos_original] = Some(tensor_global); } None => { self.normal_output( client, device, context, plan, output, tensor_global, strides, block_idx, ); } } } #[allow(clippy::too_many_arguments)] fn swapped_dims_output( &mut self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, original: TensorId, dims: (usize, usize), block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original); let dtype = tensor_global.dtype; // TODO: Check if we can also remove the read, if we have a dead partial graph. // // We modify the metadata instead. remove_concrete_write(block, output.tensor_relative.id, output.pos_original); let strides = original_handle.handle.strides.clone(); let mut handle = CubeFusionHandle { client: client.clone(), handle: original_handle.handle.handle.clone(), device: device.clone(), strides, dtype, qparams: original_handle.handle.qparams.clone(), }; handle.strides.swap(dims.0, dims.1); context .handles .register_handle(tensor_global.id, handle.clone()); // IT will never be access, just a way to keep the original position working. self.handles[output.pos_original] = Some(HandleOutput::Alias { input_pos: pos_input, precision: output.precision, global_shape: tensor_global.shape.clone(), strides: handle.strides.clone(), #[cfg(feature = "autotune-checks")] debug_info: super::HandleOutputAliasDebugInfo { relative_id: output.tensor_relative.id, handle: handle.clone(), global_shape: tensor_global.shape.dims.clone(), }, }); self.globals[output.pos_original] = Some(tensor_global); } fn find_child_input( handle_inputs: &[HandleInput], original: TensorId, ) -> (usize, &NormalHandleInput) { handle_inputs .iter() .enumerate() .find_map(|(pi, handle)| match handle { HandleInput::Normal(handle) => match handle.relative_id == original { true => Some((pi, handle)), false => None, }, _ => None, // Quant tensor can't be reshaped. }) .unwrap() } } fn remove_concrete_write(block: &mut BlockPlan, id: TensorId, output_pos: usize) { let ops = block.writes.remove(&id); if let Some(ops) = ops { let mut keep = Vec::with_capacity(ops.len()); for op in ops { if let FuseOp::Assign(args) = &op { if let FuseArg::Output(pos, ..) = args.out { if pos != output_pos { keep.push(op); } } else { keep.push(op); } } } block.writes.insert(id, keep); } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/plan.rs ================================================ use crate::{ CubeFusionHandle, engine::{ codegen::ir::{FuseArg, FuseOp, FuseType}, launch::vectorization::Vect, trace::{RuntimeLayout, block::FuseBlock}, }, }; use burn_ir::{TensorId, TensorIr}; use burn_std::{Shape, Strides}; use cubecl::{Runtime, ir::VectorSize}; use std::collections::BTreeMap; /// The `LaunchPlan` is responsible for aggregating all runtime information required /// to dispatch a fused kernel. /// /// It maps abstract IR tensors to memory handles, manages vectorization /// strategies, and tracks layout transformations. #[derive(Debug)] pub struct LaunchPlan<'a, R: Runtime> { /// The IR representation of tensors that are results of the fusion. pub global_outputs: Vec, /// Memory handles and metadata for all input tensors. pub handle_inputs: Vec>, /// Memory handles and metadata for all output tensors, including aliased inputs. pub handle_outputs: Vec>, /// The rank across all tensors in the plan. /// /// Smaller tensors are unsqueezed during launch. pub rank: usize, /// Detailed planning for each individual computation block within the fusion. pub blocks: Vec>, /// Mapping of tensor IDs to their specific vectorization factors. pub vectorizations: BTreeMap, /// Tensors that can be cleared or deallocated after this plan executes. pub cleared: Vec, /// Metadata for shapes and strides passed from the host when they cannot be /// inferred from input tensors (e.g., complex deep fusions). pub runtime_layouts: Vec, } /// Information regarding the execution of a specific block of operations within a fusion. #[derive(Debug)] pub struct BlockPlan<'a> { /// List of inputs that are candidates for in-place memory reuse within this block. pub potential_inplaces: Vec>, /// The input tensor chosen to define the iteration space, if any. pub potential_reference_input: Option, /// How the master layout is determined for this block. pub reference: ReferenceSelection, /// Mapping of tensor IDs to the read operations performed on them. pub reads: BTreeMap>, /// Mapping of tensor IDs to the write operations performed on them. pub writes: BTreeMap>, /// The width for the operations in this block. pub width: VectorSize, } /// Metadata for an input tensor being used as a reference for a block's layout. #[derive(Debug)] pub enum InputReference { /// Standard input at the specified position. Normal { input_pos: usize }, /// Input that has an axis swapped. SwapDims { original_pos: usize, dims: (usize, usize), }, /// Input that has been reshaped. Reshaped { reshape_pos: usize }, } /// Strategies for selecting the reference layout of a fused block. /// /// The reference layout determines how global indices are mapped to tensor coordinates. #[derive(Clone, Debug)] pub enum ReferenceSelection { /// The engine is still calculating the optimal reference. Searching, /// Layout from a normal tensor. Concrete { layout: FuseArg, shape: Shape, strides: Strides, }, /// Layout from a swapped dim tensor. SwapDims { original: FuseArg, dims: (usize, usize), }, /// Layout from a reshaped tensor. Reshaped { reshape_pos: usize }, /// Layout that has the shape of an input, but not its strides. VirtualShape { original: FuseArg, shape: Shape, strides: Strides, }, /// The layout is provided dynamically by the host at runtime. Runtime { pos: usize }, } impl LaunchPlan<'_, R> { /// Creates a new `LaunchPlan` from a slice of fusion blocks. /// /// Initializes blocks with default "Searching" references and calculates /// the initial max rank. pub fn new(fuse_blocks: &[FuseBlock]) -> Self { let mut rank = 0; let mut blocks = Vec::with_capacity(fuse_blocks.len()); for b in fuse_blocks.iter() { rank = usize::max(b.shape_ref.len(), rank); let block = BlockPlan { reference: ReferenceSelection::Searching, reads: b.reads.clone(), writes: b.writes.clone(), width: 0, potential_inplaces: Vec::new(), potential_reference_input: None, }; blocks.push(block); } LaunchPlan { global_outputs: Vec::new(), handle_inputs: Vec::new(), handle_outputs: Vec::new(), rank, blocks, vectorizations: Default::default(), cleared: Default::default(), runtime_layouts: Default::default(), } } } /// Debugging information for aliased handles when `autotune-checks` is enabled. #[cfg(feature = "autotune-checks")] #[derive(Debug)] pub struct HandleOutputAliasDebugInfo { pub handle: CubeFusionHandle, pub relative_id: TensorId, pub global_shape: Shape, } /// Represents the output of a fused kernel execution. #[derive(Debug, Clone)] #[allow(clippy::large_enum_variant)] pub enum HandleOutput { /// An output that reuses the memory of an input tensor (In-place). Alias { /// Index of the input handle being aliased. input_pos: usize, /// Data type precision. precision: FuseType, global_shape: Shape, strides: Strides, #[cfg(feature = "autotune-checks")] debug_info: HandleOutputAliasDebugInfo, }, /// An output that requires a newly allocated memory buffer. Owned { global_id: TensorId, relative_id: TensorId, precision: FuseType, handle: CubeFusionHandle, global_shape: Shape, vectorization: VectorSize, }, } /// A standard input handle with associated layout and vectorization metadata. #[derive(Debug, Clone)] pub struct NormalHandleInput { pub relative_id: TensorId, pub global_ir: TensorIr, pub precision: FuseType, pub handle: CubeFusionHandle, pub vector_size: VectorSize, pub broadcated: bool, /// Stores the original strides of the handle for restoration during plan rollback. pub orig_strides: Strides, } /// An input handle containing values for a quantized tensor. #[derive(Debug, Clone)] pub struct QuantValuesHandleInput { pub relative_id: TensorId, pub global_ir: TensorIr, pub precision: FuseType, pub handle: CubeFusionHandle, pub vector_size: VectorSize, } /// An input handle containing parameters (scales/offsets) for quantization. #[derive(Debug, Clone)] pub struct QuantParamsHandleInput { pub precision: FuseType, pub handle: CubeFusionHandle, pub shape: Shape, } /// Different types of inputs that can be passed to a fused kernel. #[derive(Debug, Clone)] pub enum HandleInput { Normal(NormalHandleInput), QuantValues(QuantValuesHandleInput), QuantParams(QuantParamsHandleInput), } impl HandleInput { /// Returns a reference to the inner `NormalHandleInput` if the variant matches. pub fn as_normal(&self) -> Option<&NormalHandleInput> { match self { HandleInput::Normal(normal) => Some(normal), _ => None, } } } impl NormalHandleInput { /// Creates a new `NormalHandleInput` tracking original strides. pub fn new( tensor_global: TensorIr, tensor_relative: &TensorIr, precision: FuseType, mut handle: CubeFusionHandle, mut strides: Strides, ) -> Self { // Swap current handle strides with provided strides to track the original state for rollback. core::mem::swap(&mut handle.strides, &mut strides); Self { precision, handle, relative_id: tensor_relative.id, global_ir: tensor_global, vector_size: 1, broadcated: false, orig_strides: strides, } } /// Restores the handle's original strides and returns the handle. /// /// Used when a plan is invalidated or needs to be rolled back. pub fn handle_rollback(mut self) -> CubeFusionHandle { core::mem::swap(&mut self.handle.strides, &mut self.orig_strides); self.handle } } /// A candidate for in-place optimization. #[derive(Debug)] pub struct PotentialInplace<'a> { /// Position of the input handle in the `handle_inputs` vector. pub input_pos: usize, /// Reference to the IR of the relative tensor. pub tensor_relative: &'a TensorIr, /// Current strides of the potential in-place candidate. pub strides: Strides, } impl ReferenceSelection { pub fn is_found(&self) -> bool { !matches!(self, Self::Searching) } pub fn compatible_strides_for_inplace(&self, strides_inplace: &[usize]) -> bool { match self { ReferenceSelection::Concrete { strides, .. } => &**strides == strides_inplace, _ => false, } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/runner.rs ================================================ use super::super::codegen::ir::{FuseBlockConfig, GlobalArgsLaunch}; use crate::{ CubeFusionHandle, engine::launch::{ LaunchPlan, vectorization::{Vect, vectorization_default}, }, }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; use cubecl::prelude::*; use std::collections::{BTreeMap, HashMap}; /// A trace runner is responsible for determining the vectorization factor as well as launching /// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) /// with provided [fuse block configs](FuseBlockConfig). pub trait TraceRunner: Vectorization { /// The error that might happen while running the trace. type Error; /// Run the trace with the given inputs and outputs. /// /// There is one [fuse config](FuseBlockConfig) for each [block](super::block::FuseBlock) registered /// in the [optimization builder](burn_fusion::OptimizationBuilder). fn run<'a>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, configs: &'a [FuseBlockConfig], ) -> Result<(), Self::Error>; } pub enum VectorizationHandle<'a, R: Runtime> { NormalInput(&'a CubeFusionHandle, &'a TensorIr), QuantValues(&'a CubeFusionHandle, &'a TensorIr), QuantParams, } impl<'a, R: Runtime> VectorizationHandle<'a, R> { /// Returns if the current vectorization handle is from the given tensor id. pub fn is_from_tensor(&self, id: TensorId) -> bool { match self { VectorizationHandle::NormalInput(_, tensor_ir) => tensor_ir.id == id, VectorizationHandle::QuantValues(_, tensor_ir) => tensor_ir.id == id, VectorizationHandle::QuantParams => false, } } } #[derive(Default)] pub struct VectorizationAxis { axis: HashMap, } impl VectorizationAxis { pub fn get usize>(&self, id: TensorId, default: F) -> usize { self.axis.get(&id).copied().unwrap_or_else(default) } pub fn insert(&mut self, id: TensorId, axis: usize) { self.axis.insert(id, axis); } } pub trait Vectorization { /// Returns the vectorization options. fn axis(&self, _plan: &LaunchPlan<'_, R>) -> VectorizationAxis { VectorizationAxis::default() } /// The vectorization factor for all inputs and outputs. #[allow(clippy::too_many_arguments)] fn vectorization<'a>( &self, _context: &Context<'_, CubeFusionHandle>, vectorizations: &mut BTreeMap, inputs: impl Iterator>, outputs: impl Iterator, reshaped: impl Iterator, swapped: impl Iterator, vector_sizes: &[VectorSize], max: VectorSize, axis: VectorizationAxis, ) { vectorization_default( vectorizations, inputs, outputs, reshaped, swapped, vector_sizes, &Default::default(), max, &axis, ) } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs ================================================ use crate::{ CubeFusionHandle, engine::launch::runner::{VectorizationAxis, VectorizationHandle}, }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; use cubecl::{Runtime, ir::VectorSize}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; #[derive(Debug, Clone, Copy)] pub enum Vect { Broadcasted, Aligned(VectorSize), } impl Vect { pub fn vector_size(&self) -> VectorSize { match self { Vect::Broadcasted => 1, Vect::Aligned(val) => *val, } } pub fn is_broadcast(&self) -> bool { matches!(self, Vect::Broadcasted) } } #[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct VectorSizeOverrides { state: Option>>, default: Option>, } #[allow(unused)] impl VectorSizeOverrides { pub fn overrides(&mut self, tensor_id: &TensorId, vector_sizes: Vec) { let map = match &mut self.state { Some(val) => val, None => { self.state = Some(BTreeMap::new()); self.state.as_mut().unwrap() } }; map.insert(*tensor_id, vector_sizes); } pub fn overrides_default(&mut self, vector_sizes: Vec) { self.default = Some(vector_sizes); } pub fn mapping(&self, context: &Context<'_, CubeFusionHandle>) -> Self { match &self.state { Some(state) => { let mut state_new = BTreeMap::new(); for (k, v) in state.iter() { let global = context.tensors.get(k).unwrap(); state_new.insert(global.id, v.clone()); } Self { state: Some(state_new), default: self.default.clone(), } } None => Self { state: None, default: self.default.clone(), }, } } pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec> { let map = match &self.state { Some(val) => val, None => match &self.default { Some(val) => return Some(val), None => return None, }, }; match map.get(tensor_id) { Some(val) => Some(val), None => match &self.default { Some(val) => Some(val), None => None, }, } } } #[allow(clippy::too_many_arguments)] pub(crate) fn vectorization_default<'a, R: Runtime>( vectorizations: &mut BTreeMap, inputs: impl Iterator>, outputs: impl Iterator, reshaped: impl Iterator, swapped: impl Iterator, vector_sizes: &[VectorSize], overrides: &VectorSizeOverrides, max: VectorSize, axis: &VectorizationAxis, ) { let swapped: Vec<_> = swapped.collect(); for input in inputs { if let Some((s, o, mr, dims)) = swapped .iter() .find(|(_s, o, _mr, _dims)| input.is_from_tensor(o.id)) { let (handle, id) = match input { VectorizationHandle::NormalInput(handle, tensor_ir) => (handle, &tensor_ir.id), VectorizationHandle::QuantValues(..) => panic!("Can't be swapped"), VectorizationHandle::QuantParams => panic!("Can't be swapped"), }; let val = vectorization_swapped( handle, s, o, *mr, dims, max, axis, vector_sizes, overrides.tensor(id), ); multi_reads_vectorization_update(vectorizations, o.id, val); } else { match input { VectorizationHandle::NormalInput(handle, tensor_ir) => { let val = vectorization_input( handle, tensor_ir, axis, vector_sizes, overrides.tensor(&tensor_ir.id), ); vectorizations.insert(tensor_ir.id, val); } VectorizationHandle::QuantValues(handle, tensor_ir) => { let val = vectorization_input( handle, tensor_ir, axis, vector_sizes, overrides.tensor(&tensor_ir.id), ); let num_quants = match tensor_ir.dtype { burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(), _ => panic!(""), }; let val = match val { Vect::Broadcasted => Vect::Aligned(1), Vect::Aligned(val) => Vect::Aligned(val.div_ceil(num_quants)), }; vectorizations.insert(tensor_ir.id, val); } VectorizationHandle::QuantParams => { // Doesn't have vectorization for now. } }; } } for (reshaped, original, multi_reads) in reshaped { let val = vectorization_reshape( reshaped, original, multi_reads, axis, vector_sizes, max, overrides.tensor(&original.id), ); multi_reads_vectorization_update(vectorizations, original.id, val); } for tensor in outputs { let val = vectorization_output( tensor, axis, vector_sizes, max, overrides.tensor(&tensor.id), ); vectorizations.insert(tensor.id, val); } } fn multi_reads_vectorization_update( vectorizations: &mut BTreeMap, original: TensorId, vect: Vect, ) { if let Some(ori_vect) = vectorizations.get(&original).cloned() { match ori_vect { Vect::Broadcasted => { // keep the original as is. } Vect::Aligned(ori) => match vect { Vect::Broadcasted => { vectorizations.insert(original, Vect::Aligned(1)); } Vect::Aligned(new) => { let val = if new != ori { 1 } else { new }; vectorizations.insert(original, Vect::Aligned(val)); } }, }; } else { vectorizations.insert(original, vect); } } // The default version uses the last dimension as vectorization axis and assumes a // perpendicular contiguous vector. fn vectorization_input( handle: &CubeFusionHandle, desc: &TensorIr, axis: &VectorizationAxis, vector_sizes: &[VectorSize], overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(desc.id, || handle.strides.len() - 1); let shape_axis = desc.shape[axis]; if shape_axis == 1 { return Vect::Broadcasted; } // Last dimension strides should be 1, otherwise vecX won't be contiguous. if handle.strides[axis] != 1 { return Vect::Aligned(1); } let inner = |s: VectorSize| { // The last dimension should be a multiple of the vector size or broadcated. if shape_axis.is_multiple_of(s) { return Some(Vect::Aligned(s)); } None }; match overrides { Some(vals) => { for s in vals { if let Some(val) = inner(*s) { return val; } } } None => { for s in vector_sizes { if let Some(val) = inner(*s) { return val; } } } } Vect::Aligned(1) } fn vectorization_output( desc: &TensorIr, axis: &VectorizationAxis, vector_sizes: &[VectorSize], max: VectorSize, overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(desc.id, || desc.shape.rank() - 1); let inner = |s: VectorSize| { // The dimension should be a multiple of the vector size. if desc.shape[axis].is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } None }; match overrides { Some(val) => { for s in val { if let Some(val) = inner(*s) { return val; } } } None => { for s in vector_sizes { if let Some(val) = inner(*s) { return val; } } } } Vect::Aligned(1) } fn vectorization_reshape( reshaped: &TensorIr, original: &TensorIr, multi_reads: bool, axis: &VectorizationAxis, vector_sizes: &[VectorSize], max: VectorSize, overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1); let reshape_shape_axis = reshaped.shape[axis]; if !multi_reads && reshape_shape_axis == 1 { return Vect::Broadcasted; } // If the axis is not the last dim, didn't think of it, return Aligned(1) to be sure. if axis != reshaped.shape.rank() - 1 { return Vect::Aligned(1); } let original_shape_axis = original.shape[original.shape.rank() - 1]; if original_shape_axis != reshape_shape_axis { return Vect::Aligned(1); } let inner = |s: VectorSize| { if !multi_reads { // The last dimension should be a multiple of the vector size or broadcated. if reshape_shape_axis.is_multiple_of(s) && s <= max { Some(Vect::Aligned(s)) } else { None } } else { // Since the original tensor must share the same vectorization factor as the // reshaped tensor, they must have compatible shapes when both are access // independently. if reshape_shape_axis.is_multiple_of(s) && original_shape_axis.is_multiple_of(s) && s <= max { Some(Vect::Aligned(s)) } else { None } } }; match overrides { Some(val) => { for i in val { if let Some(vect) = inner(*i) { return vect; } } } None => { for s in vector_sizes { if let Some(vect) = inner(*s) { return vect; } } } } Vect::Aligned(1) } #[allow(clippy::too_many_arguments)] fn vectorization_swapped( handle: &CubeFusionHandle, swapped: &TensorIr, original: &TensorIr, multi_reads: bool, dims: &(usize, usize), max: VectorSize, axis: &VectorizationAxis, vector_sizes: &[VectorSize], overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(swapped.id, || swapped.shape.rank() - 1); let swapped_axis = swapped.shape[axis]; let shape_axis = original.shape[axis]; let axis_index = axis; let dim_index = if dims.0 == axis_index { dims.1 } else if dims.1 == axis_index { dims.0 } else { axis_index }; // Last dimension strides should be 1, otherwise vecX won't be contiguous. if multi_reads { if handle.strides[axis_index] != 1 { return Vect::Aligned(1); } if handle.strides[dim_index] != 1 { return Vect::Aligned(1); } } else if handle.strides[dim_index] != 1 { return Vect::Aligned(1); } if !multi_reads && swapped_axis == 1 { return Vect::Broadcasted; } let inner = |s: VectorSize| { // The last dimension should be a multiple of the vector size or broadcated. if multi_reads { if swapped_axis.is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } } else if swapped_axis.is_multiple_of(s) && shape_axis.is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } None }; match overrides { Some(val) => { for s in val { if let Some(val) = inner(*s) { return val; } } } None => { for s in vector_sizes { if let Some(val) = inner(*s) { return val; } } } } Vect::Aligned(1) } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/vectorization/mod.rs ================================================ mod base; mod planner; pub use base::*; pub use planner::*; ================================================ FILE: crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs ================================================ use super::{ super::{BlockPlan, HandleOutput, LaunchPlan}, Vect, }; use crate::{ CubeFusionHandle, engine::{ launch::{ HandleInput, runner::{Vectorization, VectorizationHandle}, }, settings::VectorizationSetting, trace::{FuseResources, TensorView, block::FuseBlock}, }, }; use burn_fusion::stream::Context; use burn_ir::TensorId; use cubecl::{ Runtime, client::ComputeClient, ir::{ElemType, StorageType, UIntKind}, }; use cubecl::{ ir::VectorSize, quant::scheme::{QuantScheme, QuantStore, QuantValue}, }; use std::marker::PhantomData; /// Select the best vectorization factor for each tensor handle. pub struct VectorizationPlanner<'a, R: Runtime> { resources: &'a FuseResources, blocks: &'a Vec, _r: PhantomData, } impl<'a, R: Runtime> VectorizationPlanner<'a, R> { pub fn new(resources: &'a FuseResources, blocks: &'a Vec) -> Self { Self { resources, blocks, _r: PhantomData, } } pub fn run>( self, client: &ComputeClient, runner: &Runner, context: &Context<'_, CubeFusionHandle>, plan: &mut LaunchPlan<'a, R>, ) { let has_multiple_read = |tensor: &TensorId| { let mut read_count = 0; for block in plan.blocks.iter() { read_count += block.reads.get(tensor).map(|a| a.len()).unwrap_or(0); } read_count > 1 }; let tensors_reshaped = self.resources.views.iter().filter_map(|view| match view { TensorView::Reshape { reshaped, original, .. } => Some(( context.tensors.get(reshaped).unwrap(), context.tensors.get(original).unwrap(), has_multiple_read(original), )), TensorView::SwapDims { .. } => None, }); let tensors_swapped = self.resources.views.iter().filter_map(|view| match view { TensorView::SwapDims { swapped, original, dims, .. } => Some(( context.tensors.get(swapped).unwrap(), context.tensors.get(original).unwrap(), has_multiple_read(original), dims, )), TensorView::Reshape { .. } => None, }); let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8); let mut quants_vector_sizes: Option> = None; for input in plan.handle_inputs.iter() { let elem: StorageType = match input { HandleInput::Normal(h) => h.global_ir.dtype.into(), HandleInput::QuantValues(handle) => match handle.global_ir.dtype { burn_std::DType::QFloat(scheme) => { vector_sizes_quants(client, &mut quants_vector_sizes, scheme); continue; } _ => panic!("Unable to retrieve the scheme for quantized values."), }, HandleInput::QuantParams(..) => continue, }; let elem_size = elem.size(); if ref_elem.1 >= elem_size { ref_elem = (elem, elem_size); } } for r in plan.global_outputs.iter() { let elem: StorageType = r.dtype.into(); let elem_size = elem.size(); if ref_elem.1 >= elem_size { ref_elem = (elem, elem_size); } } let filtered = plan .handle_inputs .iter() .map(|item| { item.as_normal() // Filter out indexed resources. .map(|item| !self.resources.indexed.contains_key(&item.relative_id)) .unwrap_or(true) }) .collect::>(); let vector_sizes = match quants_vector_sizes { // Quantization normally triggers higher vectorization than anything else, no need to // compare to ref elem. Some(vector_sizes) => vector_sizes, None => client .io_optimized_vector_sizes(ref_elem.0.size()) .collect::>(), }; let vectorization_axis = runner.axis(plan); runner.vectorization( context, &mut plan.vectorizations, plan.handle_inputs .iter() .enumerate() .filter_map(|(i, item)| { if filtered[i] { Some(match item { HandleInput::Normal(h) => { VectorizationHandle::NormalInput(&h.handle, &h.global_ir) } HandleInput::QuantValues(h) => { VectorizationHandle::QuantValues(&h.handle, &h.global_ir) } HandleInput::QuantParams(_) => VectorizationHandle::QuantParams, }) } else { None } }), plan.global_outputs.iter(), tensors_reshaped, tensors_swapped, &vector_sizes, u8::MAX as usize, vectorization_axis, ); for tensor in self.resources.indexed.keys() { let global = context.tensors.get(tensor).unwrap(); plan.vectorizations.insert(global.id, Vect::Aligned(1)); } let mut block_vectorization = Vec::with_capacity(self.blocks.len()); for _ in 0..self.blocks.len() { block_vectorization.push(Vec::new()); } for (input_pos, handle) in plan.handle_inputs.iter_mut().enumerate() { let (global_ir, relative_id) = match handle { HandleInput::Normal(h) => (&h.global_ir, &h.relative_id), HandleInput::QuantValues(h) => (&h.global_ir, &h.relative_id), HandleInput::QuantParams(_) => continue, }; let (vect, br) = match plan.vectorizations.get(&global_ir.id) { Some(v) => (v.vector_size(), v.is_broadcast()), None => panic!("No vectorization factor found for {:?}", global_ir.id), }; for (block_pos, block_plan) in plan.blocks.iter().enumerate() { if block_plan.reads.contains_key(relative_id) { block_vectorization[block_pos].push(BlockVectorization { action: VectorizationAction::Input(input_pos), potential: vect, broadcasted: br, }); } } } for (output_pos, handle) in plan.handle_outputs.iter().enumerate() { if let HandleOutput::Owned { global_id, relative_id, .. } = handle { for (block_pos, block_plan) in plan.blocks.iter().enumerate() { if block_plan.writes.contains_key(relative_id) { let vectorization = plan.vectorizations.get(global_id).unwrap().vector_size(); block_vectorization[block_pos].push(BlockVectorization { action: VectorizationAction::Output(output_pos), potential: vectorization, broadcasted: false, }); } } } } let mut previous_widths = Vec::with_capacity(block_vectorization.len()); // Unhandled inputs might not get included in any fused blocks for now. // // So we ensure they are vectorized by setting their vectorization before we set the // vectorizations in blocks. // // Unhandled Outputs are correctly vectorized, so this is only necessary for inputs. for input in self.resources.inputs_unhandled.iter() { let pos = self .resources .inputs .get_index(*input) .unwrap_or_else(|| self.resources.inputs.get_index_quant(*input).unwrap()); let input_global = context.tensors.get(input).unwrap(); match plan.vectorizations.get(&input_global.id).unwrap() { Vect::Aligned(vect) => { let handle = &mut plan.handle_inputs[pos]; match handle { HandleInput::Normal(handle) => { handle.vector_size = *vect; } HandleInput::QuantValues(handle) => { handle.vector_size = *vect; } HandleInput::QuantParams(_) => {} } } Vect::Broadcasted => {} } } for ((tmp, block_plan), block) in block_vectorization .into_iter() .zip(plan.blocks.iter_mut()) .zip(self.blocks) { match block.settings.vectorization { VectorizationSetting::Activated => { apply_vectorization_block( tmp, &mut plan.handle_inputs, &mut plan.handle_outputs, block_plan, u8::MAX as usize, ); } VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos } => { apply_vectorization_block( tmp, &mut plan.handle_inputs, &mut plan.handle_outputs, block_plan, previous_widths[block_pos], ); if block_plan.width == 0 { block_plan.width = previous_widths[block_pos]; } } VectorizationSetting::EqualThanPreviousBlock { block_pos } => { apply_vectorization_block( tmp, &mut plan.handle_inputs, &mut plan.handle_outputs, block_plan, previous_widths[block_pos], ); // Enforces the width. block_plan.width = previous_widths[block_pos]; } VectorizationSetting::Deactivated => { apply_vectorization_block( tmp, &mut plan.handle_inputs, &mut plan.handle_outputs, block_plan, 1, ); block_plan.width = 1; } } // When only virtual inputs/outputs are present for a block, we need to set a width. if block_plan.width == 0 { if let Some(w) = previous_widths.last() { block_plan.width = *w; } else { block_plan.width = 1; } } previous_widths.push(block_plan.width); } } } #[derive(Debug)] enum VectorizationAction { Input(usize), Output(usize), } #[derive(Debug)] struct BlockVectorization { action: VectorizationAction, potential: VectorSize, broadcasted: bool, } fn apply_vectorization_block( block_vectorization: Vec, inputs: &mut [HandleInput], outputs: &mut [HandleOutput], block_plan: &mut BlockPlan, max: VectorSize, ) { for item in block_vectorization { match item.action { VectorizationAction::Input(pos) => { let (vect, br) = if item.potential <= max { (item.potential, item.broadcasted) } else { (1, false) }; match &mut inputs[pos] { HandleInput::Normal(input) => { input.vector_size = vect; input.broadcated = br; } HandleInput::QuantValues(input) => { input.vector_size = vect; } HandleInput::QuantParams(_) => { // Not vectorized } } if block_plan.width < vect { block_plan.width = vect; } } VectorizationAction::Output(pos) => { if let HandleOutput::Owned { vectorization, .. } = &mut outputs[pos] { let vect = if item.potential <= max { item.potential } else { 1 }; *vectorization = vect; if block_plan.width < vect { block_plan.width = vect; } } } } } } fn vector_sizes_quants( client: &ComputeClient, quants_vector_sizes: &mut Option>, scheme: QuantScheme, ) { match scheme.store { QuantStore::Native => match scheme.value { // Type sizes are the same so just treat fp8/fp4x2 as i8 QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => { let vector_sizes = client .io_optimized_vector_sizes(size_of::()) .collect::>(); match &quants_vector_sizes { Some(sizes) => { if sizes[0] < vector_sizes[0] { *quants_vector_sizes = Some(vector_sizes); } } None => { *quants_vector_sizes = Some(vector_sizes); } } } QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => { unreachable!("Can't store native sub-byte values") } }, QuantStore::PackedU32(_) => { let mut vector_sizes = client .io_optimized_vector_sizes(size_of::()) .collect::>(); for val in vector_sizes.iter_mut() { *val *= scheme.num_quants(); } match &quants_vector_sizes { Some(sizes) => { if sizes[0] < vector_sizes[0] { let mut min = *vector_sizes.last().unwrap(); while min > 1 { min /= 2; vector_sizes.push(min); } *quants_vector_sizes = Some(vector_sizes); } } None => { *quants_vector_sizes = Some(vector_sizes); } } } QuantStore::PackedNative(_) => { panic!("Not yet supported") } }; } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/mod.rs ================================================ pub(crate) mod codegen; pub(crate) mod fuser; pub(crate) mod launch; pub(crate) mod scoring; pub(crate) mod settings; pub mod trace; ================================================ FILE: crates/burn-cubecl-fusion/src/engine/scoring.rs ================================================ use crate::engine::{ codegen::ir::{FuseArg, FuseOp, UnaryFuseArgs}, trace::FuseTrace, }; use burn_ir::OperationIr; #[derive(Debug, Clone, Default)] /// Tracks and evaluates the efficiency of operation fusion. pub struct Scoring { num_writes: usize, num_reads: usize, num_ops: usize, } impl Scoring { /// Resets the internal O counters. pub fn reset(&mut self) { self.num_writes = 0; self.num_reads = 0; self.num_ops = 0; } /// Registers an unfused operation to the score, counting its total potential I/O. pub fn register(&mut self, op: &OperationIr) { self.num_writes += op.outputs().count(); self.num_reads += op.inputs().count(); self.num_ops += 1; } /// Evaluates the efficiency of a fused trace by comparing its actual I/O /// against the registered unfused I/O. Returns the number of saved I/O operations. pub fn evaluate(&self, trace: &FuseTrace) -> u64 { let mut num_reads_fused = 0; let mut num_writes_fused = 0; let mut num_penalty = 0; for b in trace.blocks.iter() { // Count reads in block for (_, ops) in b.reads.iter() { let result = self.count_fused_io(ops, |args| &args.input); num_reads_fused += result.0; num_penalty += result.1; } // Count writes in block for (_, ops) in b.writes.iter() { let result = self.count_fused_io(ops, |args| &args.out); num_writes_fused += result.0; num_penalty += result.1; } } self.calculate_score(num_reads_fused, num_writes_fused, num_penalty) } fn calculate_score(&self, reads_fused: usize, writes_fused: usize, num_penalty: usize) -> u64 { // Those could be tweaked eventually. const FACTOR_IO: u64 = 100; const FACTOR_LAUNCH: u64 = 10; const FACTOR_PENALTY: u64 = 50; let num_fused = reads_fused + writes_fused; let num_unfused = self.num_reads + self.num_writes; let score_io = match num_fused >= num_unfused { true => 0, false => (num_unfused - num_fused) as u64 * FACTOR_IO, }; // We minus 1 since at least one kernel launch is necessary. let score_launch = self.num_ops.saturating_sub(1) as u64 * FACTOR_LAUNCH; let score_penalty = num_penalty as u64 * FACTOR_PENALTY; (score_io + score_launch).saturating_sub(score_penalty) } fn count_fused_io(&self, ops: &[FuseOp], arg_extractor: F) -> (usize, usize) where F: Fn(&UnaryFuseArgs) -> &FuseArg, { let mut num_io = 0; let mut penalty = 0; for op in ops.iter() { let FuseOp::Assign(args) = op else { unreachable!() }; let count_normal = matches!( arg_extractor(args), FuseArg::Input(..) | FuseArg::Output(..) ) as usize; let count_view = matches!( arg_extractor(args), FuseArg::InputReshaped { .. } | FuseArg::InputSwapDims { .. } ) as usize; num_io += count_normal + count_view; penalty += count_view; } (num_io, penalty) } } #[cfg(test)] #[allow(clippy::field_reassign_with_default)] mod tests { use super::*; #[test] fn test_scoring_io_savings() { let mut scoring = Scoring::default(); scoring.num_reads = 2; scoring.num_writes = 2; scoring.num_ops = 2; let score = scoring.calculate_score(1, 1, 0); assert_eq!(score, 210); } #[test] fn test_scoring_with_penalties() { let mut scoring = Scoring::default(); scoring.num_reads = 2; scoring.num_writes = 2; scoring.num_ops = 2; let score = scoring.calculate_score(1, 1, 1); assert_eq!(score, 160); } #[test] fn test_penalty_outweighs_benefit() { let mut scoring = Scoring::default(); scoring.num_reads = 1; scoring.num_writes = 1; scoring.num_ops = 2; let score = scoring.calculate_score(1, 1, 1); assert_eq!(score, 0); } #[test] fn test_scoring_no_ops() { let scoring = Scoring::default(); let score = scoring.calculate_score(0, 0, 0); assert_eq!(score, 0); } #[test] fn test_reset() { let mut scoring = Scoring { num_writes: 10, num_reads: 10, num_ops: 10, }; scoring.reset(); assert_eq!(scoring.num_writes, 0); assert_eq!(scoring.num_reads, 0); assert_eq!(scoring.num_ops, 0); } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/settings.rs ================================================ use serde::{Deserialize, Serialize}; /// Controls which operations can be fused. #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct FuseSettings { /// Enables broadcasting of shapes. pub broadcast: bool, /// Enables output shape updates. /// /// When broadcast is enabled, the output shape can become bigger after a fusion, /// therefore an update is needed. pub output_shape_updates: bool, /// Enables the reuse of input buffers. pub inplace: bool, /// Whether vectorization is enabled. pub vectorization: VectorizationSetting, /// How [reference layout](super::ir::RefLayout) selection is done. pub ref_layout: RefLayoutSetting, } impl Default for FuseSettings { fn default() -> Self { Self { broadcast: true, output_shape_updates: true, inplace: true, vectorization: VectorizationSetting::Activated, ref_layout: RefLayoutSetting::Any, } } } #[derive(Clone, Copy, Debug, Serialize, Deserialize)] /// How vectorization is handled during fusion. pub enum VectorizationSetting { /// The biggest vector_size possible will be used. Activated, /// Equivalent to using vector_size of one. Deactivated, /// This is a good setting when a block processes values calculated from a previous block. SmallerOrEqualThanPreviousBlock { block_pos: usize }, /// This is a good setting when a block processes values calculated from a previous block. EqualThanPreviousBlock { block_pos: usize }, } #[derive(Clone, Copy, Debug, Serialize, Deserialize)] /// Influence how the [reference layout](super::ir::RefLayout) selection is done. pub enum RefLayoutSetting { /// Any reference layout is allowed. Any, /// Only contiguous reference layout is allowed. /// /// Note that forcing a contiguous reference layout might reduce the opportunity of inplace /// fusion. OnlyContiguous, SameAsBlock { block_pos: u32, }, } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/trace/base.rs ================================================ use crate::engine::{ codegen::ir::{FuseArg, FuseType}, trace::block::FuseBlock, }; use burn_ir::{TensorId, TensorIr}; use burn_std::{Shape, Strides}; use cubecl::prelude::*; use serde::{Deserialize, Serialize}; use std::{ collections::{BTreeMap, HashSet}, marker::PhantomData, }; #[cfg(feature = "autotune-checks")] use crate::CubeFusionHandle; #[cfg(feature = "autotune-checks")] use burn_backend::TensorData; #[cfg(feature = "autotune-checks")] use std::collections::HashMap; #[derive(Clone, Serialize, Deserialize, Debug)] /// A trace contains all [blocks](FuseBlock) and the [resources](FuseResources) used by the /// kernel. pub struct FuseTrace { pub blocks: Vec, pub resources: FuseResources, } impl core::fmt::Display for FuseTrace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "FuseTrace")?; for b in self.blocks.iter() { writeln!(f, " - Block shape={:?}", b.shape_ref)?; for (tensor, ops) in b.reads.iter() { for op in ops.iter() { writeln!(f, " - {op} <== {tensor}")?; } } for op in b.ops.iter() { writeln!(f, " - {op}")?; } for (tensor, ops) in b.writes.iter() { for op in ops.iter() { writeln!(f, " - {op} <== {tensor}")?; } } } Ok(()) } } pub enum TuneOutput { UnChecked(PhantomData), #[cfg(feature = "autotune-checks")] Checked { handles: HashMap, CubeFusionHandle)>, }, } impl TuneOutput { #[allow(unused_variables)] pub fn merge(self, other: Self) -> Self { let mut result = self; match &mut result { TuneOutput::UnChecked(..) => {} #[cfg(feature = "autotune-checks")] TuneOutput::Checked { handles } => match other { TuneOutput::UnChecked(..) => {} TuneOutput::Checked { handles: o } => { for (k, v) in o.into_iter() { handles.insert(k, v); } } }, } result } } impl cubecl::tune::AutotuneOutput for TuneOutput { #[cfg(feature = "autotune-checks")] fn check_equivalence(&self, other: Self) { use burn_backend::Tolerance; use burn_std::DType; if let ( TuneOutput::Checked { handles: handles_ref, }, TuneOutput::Checked { handles }, ) = (self, &other) { let mut num_checked = 0; let mut num_handles = 0; for (id, (shape, handle)) in handles_ref.iter() { num_handles += 1; if let Some((shape_other, other)) = handles.get(id) { use burn_std::is_contiguous; use cubecl::std::tensor::into_contiguous_ref; let current_handle = if !is_contiguous(&shape, &handle.strides) { into_contiguous_ref::( &handle.client, &handle.as_handle_ref(&shape), handle.dtype.into(), ) .unwrap() .handle } else { handle.handle.clone() }; let other_handle = if !is_contiguous(&shape, &other.strides) { into_contiguous_ref::( &other.client, &other.as_handle_ref(&shape), other.dtype.into(), ) .unwrap() .handle } else { other.handle.clone() }; let data_ref = handle.client.read_one(current_handle); let data_other = other.client.read_one(other_handle); let data_ref = TensorData::from_bytes(data_ref, shape.clone(), handle.dtype); let data_other = TensorData::from_bytes(data_other, shape_other.clone(), handle.dtype); match handle.dtype { DType::F64 => { data_ref.assert_approx_eq::(&data_other, Tolerance::permissive()) } DType::F32 => { data_ref.assert_approx_eq::(&data_other, Tolerance::permissive()) } DType::F16 => data_ref .assert_approx_eq::(&data_other, Tolerance::permissive()), DType::BF16 => data_ref .assert_approx_eq::(&data_other, Tolerance::permissive()), _ => data_ref.assert_eq(&data_other, true), } num_checked += 1; } else { // Debug info for the tests. println!("No tensor found for {id:?}=>{shape:?}"); } } // At least one check is needed per output when there is an output. // // Some optimizations might write more outputs than needed, so it might be fined if // the number of handles is different, but at least one is required. // // An optimization might not create outputs if its dead code detection is triggered, // therefore avoiding useless computation. if num_handles > 0 { assert!(num_checked >= 1); } } } } #[derive(Clone, Serialize, Deserialize, Debug, Default)] /// Declare all resources used by the kernel, and potentially multiple [blocks](FuseBlock). /// /// # Notes /// /// Each block can't contain their own resources, since they are shared between blocks. The /// vectorization factor of one input tensor must be the same for all blocks. pub struct FuseResources { pub outputs: RegisteredTensors, pub inputs: RegisteredTensors, pub scalars: Vec<(FuseType, u64)>, // TODO: Making put a map of global registers. pub views: Vec, pub indexed: BTreeMap, pub inputs_unhandled: Vec, pub outputs_unhandled: Vec, pub num_reshaped: usize, /// Necessary to remove some entries from the context. pub dropped: HashSet, /// We know during fusion that we have to have those buffers has global. /// The pos here can be interpreted as GLOBAL pos where the output pos are locals. pub buffers: RegisteredTensors, /// Global registers available everywhere. /// /// TODO: Not all registers should be globals. pub registers: BTreeMap, } #[derive(Clone, Serialize, Deserialize, Debug)] pub struct RuntimeLayout { pub shape: Shape, pub strides: Strides, } impl Default for RuntimeLayout { fn default() -> Self { Self { shape: Shape::new([]), strides: Strides::new(&[]), } } } #[derive(Debug)] pub enum TraceError { ReferenceNotFound, RunnerError(Err), } #[derive(Clone, Serialize, Deserialize, Debug)] pub enum TensorView { Reshape { reshaped: TensorId, original: TensorId, reshape_pos: usize, shape_relative: Shape, }, SwapDims { swapped: TensorId, original: TensorId, dims: (usize, usize), }, } #[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct RegisteredTensors { tensors: Vec, } #[derive(Clone, Serialize, Deserialize, Debug)] pub enum RegisterTensor { Normal(TensorIr, FuseType), QuantValues(TensorIr), QuantParams(TensorId), } impl RegisterTensor { pub fn as_normal_tensor(&self) -> Option<(&TensorIr, &FuseType)> { match self { RegisterTensor::Normal(tensor_ir, precision) => Some((tensor_ir, precision)), RegisterTensor::QuantValues(_) => None, RegisterTensor::QuantParams(_) => None, } } } impl RegisteredTensors { /// Iterate over all the registered tensors. pub fn iter(&self) -> impl Iterator { self.tensors.iter() } /// Consumes and iterate over all the registered tensors. pub fn into_iter(self) -> impl Iterator { self.tensors.into_iter() } /// Returns the number of tensors registered. pub fn len(&self) -> usize { self.tensors.len() } /// Retrieve the [tensor id](TensorId) at the given index. pub fn get_id(&self, index: usize) -> Option { self.tensors.get(index).map(|entry| match entry { RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id, RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id, RegisterTensor::QuantParams(tensor_id) => *tensor_id, }) } /// Doesn't return quantized tensor. pub fn get_index(&self, tensor_id: TensorId) -> Option { self.tensors .iter() .enumerate() .find(|(_pos, entry)| match entry { RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id, RegisterTensor::QuantValues(_) => false, RegisterTensor::QuantParams(_) => false, }) .map(|(pos, _)| pos) } /// Get the index of a quantized tensor. pub fn get_index_quant(&self, tensor_id: TensorId) -> Option { self.tensors .iter() .enumerate() .find(|(_pos, entry)| match entry { RegisterTensor::Normal(..) => false, RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id, RegisterTensor::QuantParams(_) => false, }) .map(|(pos, _)| pos) } /// Doesn't return quantized tensor. pub fn get(&self, tensor_id: TensorId) -> Option<(&TensorIr, &FuseType)> { self.tensors .iter() .find(|entry| match entry { RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id, RegisterTensor::QuantValues(_) => false, RegisterTensor::QuantParams(_) => false, }) .and_then(|entry| match entry { RegisterTensor::Normal(tensor_ir, fuse_precision) => { Some((tensor_ir, fuse_precision)) } RegisterTensor::QuantValues(_) => None, RegisterTensor::QuantParams(_) => None, }) } /// Insert a quantized tensor. /// /// It will return the positions for both the value tensor and param tensor. pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) { if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val { RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor, _ => false, }) { let values = old.0; let params = values + 1; return (values, params); } let params = RegisterTensor::QuantParams(tensor.id); let values = RegisterTensor::QuantValues(tensor); let pos_values = self.len(); self.tensors.push(values); let pos_params = self.len(); self.tensors.push(params); (pos_values, pos_params) } /// Insert a normal tensor with the given [precision](FusePrecision) in the current block. pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize { if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val { RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id, _ => false, }) { return old.0; } let value = RegisterTensor::Normal(tensor, precision); let pos = self.len(); self.tensors.push(value); pos } /// Update the already registered tensor with the given [tensor ir](TensorIr). /// /// # Notes /// /// This function only works with normal tensors, not quantized tensors. pub fn update(&mut self, tensor: &TensorIr) { if let Some(entry) = self.tensors.iter_mut().find(|entry| match entry { RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id, _ => false, }) && let RegisterTensor::Normal(tensor_ir, _) = entry { tensor_ir.status = tensor.status } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/trace/block.rs ================================================ use super::{FuseResources, RegisteredTensors, TensorView}; use crate::engine::{ codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo, MultiBlockPos, UnaryFuseArgs}, settings::FuseSettings, }; use burn_ir::{TensorId, TensorIr, TensorStatus}; use burn_std::{DType, Shape, quantization::QuantParam}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, btree_map::Entry}; #[derive(Clone, Serialize, Deserialize, Debug)] /// A block containing all [operations](FuseOp) as well as reads and writes for each tensor along /// with the [fusion settings](FuseSettings). pub struct FuseBlock { /// Contains the [fusion settings](FuseSettings) associated to the current block. pub settings: FuseSettings, /// Contains all the [operations](FuseOp) registered in the current block. pub ops: Vec, /// The reference shape of the current block. pub shape_ref: Shape, /// Contains all tensor inputs of the current block except for manually handled tensors. /// /// # Notes /// /// Some reads might not have read operations registered, such as dequantization, but it's /// important to be registered here for vectorization. Input tensors that are not /// registered here must be vectorized manually. pub reads: BTreeMap>, /// Contains all tensor outputs of the current block except for manually handled tensors. /// We can have multiple writes when the same variable is reused after in another block. pub writes: BTreeMap>, } #[derive(Clone, Debug)] /// It is responsible to build a [trace](FuseBlock). pub struct FuseBlockBuilder { pub settings: FuseSettings, locals: LocalVariablePool, pub ops: Vec, reads: BTreeMap>, // Only for global registers. writes: BTreeMap>, // Output declared in this block alone. outputs: RegisteredTensors, pub outputs_unhandled: Vec, pub local_inputs: BTreeMap, /// The reference shape used by this block. pub shape_ref: Shape, } #[derive(Debug)] /// How a quantized input can be read. pub enum QuantInput { /// If already dequantized, we cache the dequantization and returns the local variable /// corresponding to the float value. AlreadyDequantized { local: FuseArg }, /// Otherwise we return the information necessary to dequantize the tensor. Quantized { values: FuseArg, params: FuseArg }, } impl FuseBlockBuilder { pub fn new(settings: FuseSettings) -> Self { Self { settings, locals: Default::default(), ops: Default::default(), reads: Default::default(), writes: Default::default(), outputs: Default::default(), outputs_unhandled: Default::default(), local_inputs: Default::default(), shape_ref: Shape::new([]), } } /// Register an output tensor. pub fn output(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option { if resources.indexed.contains_key(&tensor.id) { return None; } if matches!(tensor.dtype, DType::QFloat(..)) { return None; } let precision = tensor.dtype.into(); let out = match self.locals.get(precision, tensor.id) { Some(local) => local, None => { let out = self.locals.create(precision, tensor.id); self.outputs.insert(precision, tensor.clone()); resources.outputs.insert(precision, tensor.clone()); out } }; Some(out) } /// Register an input tensor. pub fn multi_block_variable( &mut self, block_pos: usize, tensor: &TensorIr, global: bool, ) -> Option { let precision = tensor.dtype.into(); if let Some(val) = self.local_inputs.get(&tensor.id) { return Some(val.clone()); } let val = match self.locals.get(precision, tensor.id) { Some(val) => val, None => { return None; } }; let arg = if global { FuseArg::MultiBlockGlobal( MultiBlockPos { block_pos, block_local_pos: self.writes.len(), }, val.precision(), ) } else { FuseArg::MultiBlockLocal( MultiBlockPos { block_pos, block_local_pos: self.writes.len(), }, val.precision(), ) }; let ops = match self.writes.get_mut(&tensor.id) { Some(ops) => ops, None => { self.writes.insert(tensor.id, Vec::new()); self.writes.get_mut(&tensor.id).unwrap() } }; ops.push(FuseOp::Assign(UnaryFuseArgs { input: val, out: arg.clone(), })); Some(arg) } /// Register an input tensor. pub fn input(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option { if resources.indexed.contains_key(&tensor.id) { return None; } if matches!(tensor.dtype, DType::QFloat(..)) { return None; } let precision = tensor.dtype.into(); if let Some(val) = self.local_inputs.get(&tensor.id) { return Some(val.clone()); } let arg = match self.locals.get(precision, tensor.id) { Some(local) => { resources.inputs.update(tensor); local } None => { let input = if resources.outputs.get_index(tensor.id).is_some() { if let Some(val) = resources.registers.get(&tensor.id) { return Some(val.clone()); }; let pos = resources.buffers.insert(precision, tensor.clone()); FuseArg::Output(pos, precision, LayoutInfo::Unknown) } else { let pos = resources.inputs.insert(precision, tensor.clone()); FuseArg::Input(pos, precision, LayoutInfo::Unknown) }; let out = self.locals.create(precision, tensor.id); let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) { e.insert(Vec::with_capacity(1)); self.reads.get_mut(&tensor.id).unwrap() } else { self.reads.get_mut(&tensor.id).unwrap() }; reads.push(FuseOp::Assign(UnaryFuseArgs { input, out: out.clone(), })); out } }; Some(arg) } /// Register an input quantized tensor. pub fn input_quant( &mut self, tensor: &TensorIr, resources: &mut FuseResources, ) -> Option { if resources.indexed.contains_key(&tensor.id) { return None; } let precision = tensor.dtype.into(); let precision_scales = match tensor.dtype { DType::QFloat(scheme) => match scheme.param { QuantParam::F32 => FuseType::F32, QuantParam::F16 => FuseType::F16, QuantParam::BF16 => FuseType::BF16, QuantParam::UE8M0 | QuantParam::UE4M3 => { unimplemented!("Unsupported fuse precision"); } }, _ => return None, }; let arg = match self.locals.get(precision, tensor.id) { Some(local) => { resources.inputs.update(tensor); QuantInput::AlreadyDequantized { local } } None => { let (new_input, q_index) = resources.inputs.insert_quant(tensor.clone()); let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown); let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown); // Important to flag that there is a read, even if no operation is registered. if let Entry::Vacant(e) = self.reads.entry(tensor.id) { e.insert(Vec::new()); }; QuantInput::Quantized { values: input, params: scales, } } }; Some(arg) } /// Register an input with swapped dims. pub fn input_swap_dims( &mut self, tensor: &TensorIr, output: &TensorIr, dims: (usize, usize), resources: &mut FuseResources, ) -> Option { if matches!(tensor.dtype, DType::QFloat(..)) { return None; } let precision = tensor.dtype.into(); let input_index = match self.locals.get(precision, tensor.id) { Some(_) => { // Can't fused an already fused input. if resources.outputs.get(tensor.id).is_some() { return None; } match resources.inputs.get_index(tensor.id) { Some(index) => { resources.inputs.update(tensor); index } None => { return None; } } } None => resources.inputs.insert(precision, tensor.clone()), }; let out = self.output(output, resources)?; let original = FuseArg::Input(input_index, precision, LayoutInfo::Unknown); let broadcasted = output.shape[output.shape.rank() - 1] == 0; resources.views.push(TensorView::SwapDims { swapped: output.id, original: tensor.id, dims, }); let input = FuseArg::InputSwapDims { original: Box::new(original), dims, broadcasted, }; let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) { e.insert(Vec::with_capacity(1)); self.reads.get_mut(&tensor.id).unwrap() } else { self.reads.get_mut(&tensor.id).unwrap() }; reads.push(FuseOp::Assign(UnaryFuseArgs { input, out: out.clone(), })); Some(out) } /// Register an input that is reshaped. pub fn input_reshaped( &mut self, tensor: &TensorIr, output: &TensorIr, resources: &mut FuseResources, ) -> Option { if matches!(tensor.dtype, DType::QFloat(..)) { return None; } let precision = tensor.dtype.into(); let input_index = match self.locals.get(precision, tensor.id) { Some(_) => { // Can't fused an already fused input. if resources.outputs.get(tensor.id).is_some() { return None; } match resources.inputs.get_index(tensor.id) { Some(index) => { resources.inputs.update(tensor); index } None => { return None; } } } None => resources.inputs.insert(precision, tensor.clone()), }; let out = self.output(output, resources)?; let original = FuseArg::Input(input_index, precision, LayoutInfo::Unknown); let mut shape = Vec::new(); let index = resources.num_reshaped; resources.num_reshaped += 1; let rank = output.shape.rank(); for i in 0..output.shape.rank() { let id = index * rank + i; shape.push(FuseArg::ScalarShape(id)); } resources.views.push(TensorView::Reshape { reshaped: output.id, original: tensor.id, reshape_pos: index, shape_relative: output.shape.clone(), }); let input = FuseArg::InputReshaped { original: Box::new(original), shape, broadcasted: output.shape[rank - 1] == 0, }; let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) { e.insert(Vec::with_capacity(1)); self.reads.get_mut(&tensor.id).unwrap() } else { self.reads.get_mut(&tensor.id).unwrap() }; reads.push(FuseOp::Assign(UnaryFuseArgs { input, out: out.clone(), })); Some(out) } /// Build into a fuse block. pub fn build( &self, resources: &FuseResources, outputs: &mut RegisteredTensors, buffers: &mut Vec, ) -> FuseBlock { let ops = self.ops.clone(); let reads = self.reads.clone(); let tensor_writes = self.tensor_writes(resources, buffers); let mut writes = self.writes.clone(); for (tensor, precision) in tensor_writes .iter() .filter_map(|entry| entry.as_normal_tensor()) { if let Some(local) = self.locals.get_any_precision(tensor.id) { let out_index = outputs.insert(*precision, tensor.clone()); let ops = match writes.get_mut(&tensor.id) { Some(ops) => ops, None => { writes.insert(tensor.id, Vec::new()); writes.get_mut(&tensor.id).unwrap() } }; ops.push(FuseOp::Assign(UnaryFuseArgs { input: local, out: FuseArg::Output(out_index, *precision, LayoutInfo::Unknown), })); } } FuseBlock { settings: self.settings, ops, shape_ref: self.shape_ref.clone(), reads, writes, } } /// Return the tensor that needs to be written to. /// /// # Notes /// /// The buffers vector passed as input is only to track the intermediary buffer writes needed /// during execution. pub fn tensor_writes( &self, resources: &FuseResources, buffers: &mut Vec, ) -> RegisteredTensors { let mut result = RegisteredTensors::default(); // All tensors where their latest representation is not read write should be written to since they // are going to be used after the fused kernel by other operations. for output in self.outputs.iter() { if let Some((tensor, _precision)) = output.as_normal_tensor() { // We get the latest representation from the resources, not just this block. if let Some((tensor, precision)) = resources.outputs.get(tensor.id) { if !matches!(tensor.status, TensorStatus::ReadWrite) { result.insert(*precision, tensor.clone()); } else if resources.buffers.get(tensor.id).is_some() && !buffers.contains(&tensor.id) { result.insert(*precision, tensor.clone()); // We make sure we don't write multiple time in the same buffer, only the // earliest possible. buffers.push(tensor.id); } } } } result } } #[derive(Default, Clone, Debug)] pub struct LocalVariablePool { values: BTreeMap>, } impl LocalVariablePool { fn get(&self, precision: FuseType, tensor_id: TensorId) -> Option { if let Some(indexes) = self.values.get(&precision) && let Some(index) = indexes.get(&tensor_id) { return Some(FuseArg::BlockLocal { pos: *index, ty: precision, }); } None } fn get_any_precision(&self, tensor_id: TensorId) -> Option { for (precision, indexes) in self.values.iter() { if let Some(index) = indexes.get(&tensor_id) { return Some(FuseArg::BlockLocal { pos: *index, ty: *precision, }); } } None } fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg { if let Some(indexes) = self.values.get_mut(&precision) { let new_index = indexes.len(); indexes.insert(tensor_id, new_index); return FuseArg::BlockLocal { pos: new_index, ty: precision, }; } let new_index = 0; self.values .insert(precision, BTreeMap::from_iter([(tensor_id, new_index)])); FuseArg::BlockLocal { pos: new_index, ty: precision, } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/trace/fuser.rs ================================================ use super::{ super::{ codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo}, settings::FuseSettings, }, FuseResources, block::FuseBlockBuilder, }; use super::{FuseTrace, RegisteredTensors}; use crate::engine::trace::block::QuantInput; use burn_fusion::stream::ScalarId; use burn_ir::{ScalarIr, TensorIr}; use burn_std::{DType, Shape}; use cubecl::quant::scheme::QuantParam; #[derive(Clone, Debug)] /// It is responsible to create a [trace](FuseTrace) composed of multiple [blocks](super::block::FuseBlock). /// /// It mostly handles the [resources](KernelResources) needed by the generated fused kernel, and /// delegates most of the work to the [block builder](FuseBlockBuilder). pub struct TraceFuser { settings: FuseSettings, // The tensors returned by the block that don't need to be written to global memory. block_current: FuseBlockBuilder, blocks_previous: Vec, resources: FuseResources, } impl TraceFuser { /// Create a new trace builder with the given bool precision and [fuse settings](FuseSettings). pub fn new(settings: FuseSettings) -> Self { Self { settings, block_current: FuseBlockBuilder::new(settings), blocks_previous: Default::default(), resources: Default::default(), } } /// Get the number of blocks that are closed. pub fn num_previous_blocks(&self) -> usize { self.blocks_previous.len() } /// Tag a tensor as dropped. pub fn fuse_dropped(&mut self, tensor: &TensorIr) { self.resources.outputs.update(tensor); self.resources.inputs.update(tensor); self.resources.dropped.insert(tensor.id); } /// Register an operation. pub fn fuse_operation(&mut self, op: FuseOp) { self.block_current.ops.push(op); } /// The number of operations fused. pub fn num_ops_fused(&self) -> u32 { let mut num_ops_fused = 0; for block in self.blocks_previous.iter() { num_ops_fused += block.ops.len(); } num_ops_fused += self.block_current.ops.len(); num_ops_fused as u32 } /// Close the current block with the given reference shape and creates a new one with new [fusion settings](FuseSettings). pub fn next_block(&mut self, shape_ref: Shape, settings: FuseSettings) { let mut block_new = FuseBlockBuilder::new(settings); core::mem::swap(&mut self.block_current, &mut block_new); block_new.shape_ref = shape_ref; self.blocks_previous.push(block_new); self.settings = settings; } // Estimate how many bindings are in use right now. This can return more than the actual number // but should never return less. pub fn estimate_bindings(&self) -> u32 { let mut buffers = Vec::new(); let mut estimation = 1; // Metadata takes one. // We assume we are not going to write multiple times in the same output buffer. for b in self.blocks_previous.iter() { estimation += b.tensor_writes(&self.resources, &mut buffers).len() as u32; } estimation += self .block_current .tensor_writes(&self.resources, &mut buffers) .len() as u32; estimation += self.resources.inputs.len() as u32; // One buffer per scalar type for now. estimation += self.resources.scalars.len() as u32; estimation } /// Tag the [tensor](TensorIr) as received from a previous block. /// /// This will avoid reading the input again and instead use le local version when possible. pub fn block_local_input( &mut self, tensor: &TensorIr, block_pos: usize, global: bool, ) -> FuseArg { let block = &mut self.blocks_previous[block_pos]; let src_arg = match block.multi_block_variable(block_pos, tensor, global) { Some(val) => val, None => { // We try to read the input if not present. block.input(tensor, &mut self.resources); block .multi_block_variable(block_pos, tensor, global) .unwrap() } }; self.resources.outputs.update(tensor); if global { self.resources.registers.insert(tensor.id, src_arg.clone()); } self.block_current .local_inputs .insert(tensor.id, src_arg.clone()); src_arg } /// Register an output tensor that won't be automatically synced into global memory. /// /// It is therefore the responsibility of the operation to write the result to given tensor. pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg { let arg = self .output(tensor) .expect("Can't add a new output that is already used in an index operation"); self.resources.outputs_unhandled.push(arg.clone()); self.block_current.outputs_unhandled.push(arg.clone()); arg } /// Register an input tensor that won't be automatically read into a local variable. /// /// It is therefore the responsibility of the operation to read the given tensor. pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg { if self.resources.indexed.contains_key(&tensor.id) { panic!("Can't add a new input that is already used in an index operation"); } self.resources.outputs.update(tensor); let precision = tensor.dtype.into(); let new_input = self.resources.inputs.insert(precision, tensor.clone()); let arg = FuseArg::Input(new_input, precision, LayoutInfo::Unknown); self.resources.inputs_unhandled.push(tensor.id); arg } /// Register an input tensor. pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> { if self.resources.indexed.contains_key(&tensor.id) { panic!("Can't add a new input that is already used in an index operation"); } self.resources.outputs.update(tensor); let precision = tensor.dtype.into(); let precision_scales = match tensor.dtype { DType::QFloat(scheme) => match scheme.param { QuantParam::F32 => FuseType::F32, QuantParam::F16 => FuseType::F16, QuantParam::BF16 => FuseType::BF16, QuantParam::UE8M0 | QuantParam::UE4M3 => { unimplemented!("Unsupported fuse precision"); } }, _ => return None, }; let (new_input, q_index) = self.resources.inputs.insert_quant(tensor.clone()); let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown); let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown); self.resources.inputs_unhandled.push(tensor.id); Some((input, scales)) } /// Register an input tensor. pub fn input(&mut self, tensor: &TensorIr) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; } self.resources.outputs.update(tensor); self.block_current.input(tensor, &mut self.resources) } /// Register an input tensor. pub fn input_quantized(&mut self, tensor: &TensorIr) -> Option { self.resources.outputs.update(tensor); self.block_current.input_quant(tensor, &mut self.resources) } /// Register an output tensor. pub fn output(&mut self, tensor: &TensorIr) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; } self.block_current.output(tensor, &mut self.resources) } /// Register an input that will be accessed using custom indexing with no vectorization. pub fn input_indexed(&mut self, tensor: &TensorIr) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; } if let Some(val) = self.resources.indexed.get(&tensor.id) { self.resources.outputs.update(tensor); return Some(val.clone()); }; if self.resources.inputs.get(tensor.id).is_some() { return None; } if self.resources.outputs.get(tensor.id).is_some() { return None; } let input = self.input_unhandled(tensor); self.resources.indexed.insert(tensor.id, input.clone()); Some(input) } /// Register an input with swapped dims. pub fn input_swap_dims( &mut self, tensor: &TensorIr, output: &TensorIr, dims: (usize, usize), ) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; } self.resources.outputs.update(tensor); self.block_current .input_swap_dims(tensor, output, dims, &mut self.resources) } /// Register an input that is reshaped. pub fn input_reshaped(&mut self, tensor: &TensorIr, output: &TensorIr) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; } self.resources.outputs.update(tensor); self.block_current .input_reshaped(tensor, output, &mut self.resources) } /// Register a scalar value. pub fn scalar(&mut self, elem: &ScalarIr, dtype: DType) -> FuseArg { let precision = dtype.into(); let id = if let ScalarIr::UInt(value) = elem { ScalarId { value: *value } } else { unreachable!() // should always be u64 }; let new_index = self.resources.scalars.len(); self.resources.scalars.push((precision, id.value)); FuseArg::Scalar(new_index, precision) } /// Finish fusing and returns the created trace. pub fn finish(&mut self, shape_ref: Shape) -> FuseTrace { let mut resources = self.resources.clone(); let mut outputs = RegisteredTensors::default(); let mut buffers = Vec::new(); for tensor in resources.buffers.iter() { let (tensor, ty) = tensor.as_normal_tensor().unwrap(); outputs.insert(*ty, tensor.clone()); } let mut blocks = Vec::new(); let mut register_block = |block: &FuseBlockBuilder| { let block = block.build(&self.resources, &mut outputs, &mut buffers); blocks.push(block); }; for block in self.blocks_previous.iter() { register_block(block); } self.block_current.shape_ref = shape_ref; register_block(&self.block_current); // We update the output tensors registered to be the ones that are written to in global // memory. resources.outputs = outputs; FuseTrace { blocks, resources } } } ================================================ FILE: crates/burn-cubecl-fusion/src/engine/trace/mod.rs ================================================ pub(crate) mod block; mod base; mod fuser; pub use base::*; pub use fuser::*; ================================================ FILE: crates/burn-cubecl-fusion/src/lib.rs ================================================ #[macro_use] extern crate derive_new; pub mod optim; mod base; pub(crate) mod engine; pub(crate) mod tune; pub use base::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/base.rs ================================================ use crate::optim::{ elemwise::{ElemwiseOptimization, ElemwiseOptimizationState}, matmul::{MatmulOptimization, MatmulOptimizationState}, reduce::{ReduceOptimization, ReduceOptimizationState}, reduce_broadcasted::{ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationState}, }; use cubecl::Runtime; use serde::{Deserialize, Serialize}; /// Fusion optimization type for cubecl. /// /// More optimization variants should be added here. #[allow(clippy::large_enum_variant)] pub enum CubeOptimization { ElementWise(ElemwiseOptimization), Matmul(MatmulOptimization), Reduce(ReduceOptimization), ReduceBroadcasted(ReduceBroadcastedOptimization), } impl core::fmt::Debug for CubeOptimization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let value = self.to_opt_state(); f.write_fmt(format_args!("{value:?}")) } } impl CubeOptimization { /// Serializes the current optimization to its state. pub fn to_opt_state(&self) -> CubeOptimizationState { match self { Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()), Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()), Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()), Self::ReduceBroadcasted(value) => { CubeOptimizationState::ReduceBroadcasted(value.to_state()) } } } } impl burn_fusion::NumOperations for CubeOptimization { fn len(&self) -> usize { match self { Self::ElementWise(op) => op.num_ops_fused(), Self::Matmul(op) => op.num_ops_fused(), Self::Reduce(op) => op.num_ops_fused(), Self::ReduceBroadcasted(op) => op.num_ops_fused(), } } } /// Fusion optimization state type for cubecl. /// /// More optimization variants should be added here. #[allow(clippy::large_enum_variant)] #[derive(Serialize, Deserialize, Debug)] pub enum CubeOptimizationState { ElementWise(ElemwiseOptimizationState), Matmul(MatmulOptimizationState), Reduce(ReduceOptimizationState), ReduceBroadcasted(ReduceBroadcastedOptimizationState), } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/elemwise/fuser.rs ================================================ use super::optimization::ElemwiseOptimization; use crate::{ engine::{ fuser::TraceOperationFuser, settings::{FuseSettings, RefLayoutSetting, VectorizationSetting}, }, optim::CubeOptimization, }; use burn_fusion::OperationFuser; use burn_std::Shape; use cubecl::Runtime; /// Fuses element wise operations. pub struct ElementWiseFuser { fuser: TraceOperationFuser, device: R::Device, } impl Clone for ElementWiseFuser { fn clone(&self) -> Self { Self { fuser: self.fuser.clone(), device: self.device.clone(), } } } impl ElementWiseFuser { pub fn shape_id(&self) -> Shape { self.fuser.current_output_shape.clone() } pub fn new(device: R::Device) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware.max_bindings; Self { fuser: TraceOperationFuser::new( max_bindings, FuseSettings { broadcast: true, output_shape_updates: true, inplace: true, vectorization: VectorizationSetting::Activated, ref_layout: RefLayoutSetting::Any, }, ), device, } } } impl OperationFuser> for ElementWiseFuser { fn fuse(&mut self, operation: &burn_ir::OperationIr) { self.fuser.fuse(operation); } fn finish(&mut self) -> CubeOptimization { let client = R::client(&self.device); let trace = self.fuser.finish(); let elementwise = ElemwiseOptimization::new(trace, client, self.device.clone(), self.len()); CubeOptimization::ElementWise(elementwise) } fn reset(&mut self) { self.fuser.reset() } fn status(&self) -> burn_fusion::FuserStatus { self.fuser.status() } fn properties(&self) -> burn_fusion::FuserProperties { self.fuser.properties() } fn len(&self) -> usize { self.fuser.len() } fn clone_dyn(&self) -> Box>> { Box::new(self.clone()) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/elemwise/mod.rs ================================================ mod fuser; mod optimization; pub use fuser::*; pub use optimization::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs ================================================ use crate::{ CubeFusionHandle, engine::{ codegen::{ DynSize, io::ref_len, ir::{ FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsLaunch, RefLayout, multi_block_variables_init, }, kernel::{fuse_on_write, init_locals}, }, launch::{ FuseTraceLauncher, runner::{TraceRunner, Vectorization}, }, trace::FuseTrace, }, }; use burn_fusion::stream::Context; use cubecl::{CubeDim, calculate_cube_count_elemwise, client::ComputeClient, prelude::*}; use serde::{Deserialize, Serialize}; #[derive(new)] /// Fuse element wise operations into a single kernel. pub struct ElemwiseOptimization { pub(crate) trace: FuseTrace, client: ComputeClient, device: R::Device, len: usize, } #[derive(Serialize, Deserialize, Debug)] /// State for the [elemwise optimization](ElemwiseOptimization). pub struct ElemwiseOptimizationState { trace: FuseTrace, len: usize, } impl ElemwiseOptimization { /// Execute the optimization. pub fn execute(&self, context: &mut Context<'_, CubeFusionHandle>) { let launcher = FuseTraceLauncher::new(&self.trace, &ElemwiseRunner); match launcher.launch(&self.client, &self.device, context) { Ok(_) => (), Err(err) => { panic!("{err:?} - {:?}", self.trace); } } } /// Number of element wise operations fused. pub fn num_ops_fused(&self) -> usize { self.len } /// Create an optimization from its [state](ElemwiseOptimizationState). pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self { Self { trace: state.trace, len: state.len, client: R::client(device), device: device.clone(), } } /// Convert the optimization to its [state](ElemwiseOptimizationState). pub fn to_state(&self) -> ElemwiseOptimizationState { ElemwiseOptimizationState { trace: self.trace.clone(), len: self.len, } } } pub struct ElemwiseRunner; impl Vectorization for ElemwiseRunner {} impl TraceRunner for ElemwiseRunner { type Error = LaunchError; // No error possible fn run<'a>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, configs: &[FuseBlockConfig], ) -> Result<(), Self::Error> { let config = &configs[0]; let shape = match &config.ref_layout { RefLayout::Concrete(arg) => match arg { FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank), FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank), _ => panic!("Invalid concreate ref layout"), }, RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank), }; let working_units = shape.iter().product::() / config.width; let cube_dim = CubeDim::new(client, working_units); let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim); let address_type = inputs .required_address_type() .max(outputs.required_address_type()); unsafe { elemwise_fuse::launch_unchecked( client, cube_count, cube_dim, address_type, inputs, outputs, config.clone(), ); }; Ok(()) } } #[cube(launch_unchecked, address_type = "dynamic")] fn elemwise_fuse( inputs: &GlobalArgs, outputs: &mut GlobalArgs, #[comptime] config: &FuseBlockConfig, ) { // We write no values for this fusion. let values = Registry::>::new(); let args = comptime![Vec::::new()]; let pos = ABSOLUTE_POS; multi_block_variables_init(config, &mut outputs.variables); let mut locals = init_locals(inputs, outputs, config); let length = ref_len(inputs, outputs, &locals, config); if pos < length { fuse_on_write::(inputs, outputs, &mut locals, pos, values, args, config) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/matmul/args.rs ================================================ use crate::engine::codegen::{ io::ref_vector_size, ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, LocalArgs, multi_block_variables_init}, kernel::init_locals, view::{FusedOutput, GlobalInput, GlobalInputExpand}, }; use cubecl::{ intrinsic, prelude::*, quant::scheme::{QuantLevel, QuantScheme}, std::{ FastDivmod, quant::{ RunWithQuantType, view::{QuantizedView, run_with_quant_type}, }, tensor::{ View, ViewExpand, layout::{Coords1d, Coords2d, VirtualLayout}, }, }, }; use cubek::{ matmul::{ components::global::memory::{ BatchLayout, BlockScaledLayout, GlobalLayout, GlobalLayoutConfig, GlobalLayoutExpand, GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout, }, launch::{BatchedCoords, MatmulArgs}, }, std::MatrixLayout, }; use serde::{Deserialize, Serialize}; use std::marker::PhantomData; #[derive(Clone)] pub struct FusedMatmulArgs; #[derive(CubeLaunch, CubeType)] pub struct FusedMatmulInput { global: GlobalArgs, #[cube(comptime)] config: FuseBlockConfig, #[cube(comptime)] a: MatmulArg, #[cube(comptime)] b: MatmulArg, #[cube(comptime)] c: Option, #[cube(comptime)] out: FuseArg, } #[cube] impl MatmulArgs for FusedMatmulArgs { type Output = GlobalArgs; type Input = FusedMatmulInput; type State = FusedMatmulState; type Config = (); fn init_state( inputs: &Self::Input, outputs: &mut Self::Output, _config: (), #[comptime] lhs_layout_config: GlobalLayoutConfig, #[comptime] rhs_layout_config: GlobalLayoutConfig, #[comptime] out_layout_config: GlobalLayoutConfig, ) -> Self::State { multi_block_variables_init(&inputs.config, &mut outputs.variables); let mut locals = init_locals(&inputs.global, outputs, &inputs.config); let rank = comptime![inputs.config.rank]; let mut batch_shape = Sequence::new(); let mut batch_strides_out = Sequence::new(); #[unroll] for i in 0..rank - 2 { batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32)); batch_strides_out.push(locals.ref_strides[i]); } let batch_lhs = input_batch_layout( &inputs.global, &batch_shape, comptime![inputs.a.clone()], comptime![inputs.config.clone()], ); let batch_rhs = input_batch_layout( &inputs.global, &batch_shape, comptime![inputs.b.clone()], comptime![inputs.config.clone()], ); let batch_acc = match comptime![inputs.c.clone()] { Some(c) => ComptimeOption::Some(input_batch_layout( &inputs.global, &batch_shape, comptime![c], comptime![inputs.config.clone()], )), None => ComptimeOption::new_None(), }; let batch_out = BatchLayout::new(batch_strides_out, batch_shape.clone()); FusedMatmulState::new( inputs, outputs, &mut locals, batch_lhs, batch_rhs, batch_acc, VirtualLayout::new::(batch_out), batch_shape, &inputs.config, lhs_layout_config, rhs_layout_config, out_layout_config, ) } fn view_lhs( state: &Self::State, ) -> View { global_view( &state.inputs, &state.locals, &state.batch_shape, comptime![state.a.clone()], comptime![state.config.clone()], state.lhs_layout_config, ) } fn batch_lhs( state: &Self::State, batch: usize, ) -> usize { state.a_batch.to_source_pos(batch) } fn view_rhs( state: &Self::State, ) -> View { global_view( &state.inputs, &state.locals, &state.batch_shape, comptime![state.b.clone()], comptime![state.config.clone()], comptime![state.rhs_layout_config], ) } fn batch_rhs( state: &Self::State, batch: usize, ) -> usize { state.b_batch.to_source_pos(batch) } fn view_acc( state: &Self::State, ) -> ComptimeOption> { match comptime![state.c.clone()] { Some(c) => { let view = global_view( &state.inputs, &state.locals, &state.batch_shape, c, comptime![state.config.clone()], comptime![state.out_layout_config], ); ComptimeOption::Some(view) } None => ComptimeOption::new_None(), } } fn batch_acc( state: &Self::State, batch: usize, ) -> usize { #[comptime] match state.c_batch { ComptimeOption::Some(c_batch) => c_batch.to_source_pos(batch), ComptimeOption::None => batch, } } fn view_out( state: &mut Self::State, ) -> View { let rank = comptime![state.config.rank]; let shape_row = state.locals.ref_shape[rank - 2] as u32; let shape_col = state.locals.ref_shape[rank - 1] as u32; let stride_row = state.locals.ref_strides[rank - 2]; let stride_col = state.locals.ref_strides[rank - 1]; let layout = GlobalLayout::new( VirtualLayout::new::(NoopLayout::new()), shape_row, shape_col, stride_row, stride_col, ref_vector_size(&state.locals), 1u32, state.out_layout_config, ); let mut buffer = FusedOutput::new( &state.inputs, &mut state.outputs, &mut state.locals, comptime![state.out.clone()], comptime![state.config.clone()], ); View::new_mut::(&mut buffer, layout) } fn batch_out( state: &Self::State, batch: usize, ) -> usize { state.out_batch.to_source_pos(batch) } fn runtime_config( _state: &Self::State, ) { } } #[cube] #[allow(clippy::missing_transmute_annotations)] fn global_view( inputs: &GlobalArgs, locals: &LocalArgs, batch_shape: &Sequence>, #[comptime] arg: MatmulArg, #[comptime] config: FuseBlockConfig, #[comptime] layout_config: GlobalLayoutConfig, ) -> View { let rank = comptime![config.rank]; let data = comptime![arg.data().clone()]; let data_tensor = match comptime![data.clone()] { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), _ => panic!("Input must be concrete"), }; let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32; let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32; let mut packing = comptime![1]; if arg.scheme().is_some() { let scheme = arg.scheme().unwrap(); let num_quants = scheme.num_quants() as u32; comptime![packing = num_quants]; match comptime![layout_config.matrix_layout] { MatrixLayout::RowMajor => shape_col *= num_quants, MatrixLayout::ColMajor => shape_row *= num_quants, }; } let shape = (shape_row, shape_col); // Noop for normal inputs because batch offset is cached, quantized uses logical batches let batch_layout = match comptime![arg.clone()] { MatmulArg::Normal(_) => VirtualLayout::new::(NoopLayout::new()), MatmulArg::Quantized { data, .. } => { let data_arg = comptime![MatmulArg::Normal(data)]; input_batch_layout(inputs, batch_shape, data_arg, comptime![config.clone()]) } }; let data_layout = global_layout( inputs, shape, batch_layout, arg.data().clone(), config.clone(), data_tensor.tensor.vector_size(), layout_config, packing, ); let data_buf = GlobalInput::new(inputs, locals, data, comptime![config.clone()], None); match comptime![arg.clone()] { MatmulArg::Normal(_) => View::new::(&data_buf, data_layout), MatmulArg::Quantized { scales, scheme, .. } => { let scales_layout = match comptime![scheme.level] { QuantLevel::Tensor => GlobalScaleLayout::new_PerTensor(shape), QuantLevel::Block(block_size) => { let block_size = comptime![block_size.as_dim::<2>()]; let scales_arg = comptime![MatmulArg::Normal(scales.clone())]; let batch_layout = input_batch_layout( inputs, batch_shape, scales_arg, comptime![config.clone()], ); let scales_layout = global_layout( inputs, shape, batch_layout, comptime![scales.clone()], comptime![config.clone()], 1usize, layout_config, 1u32, ); GlobalScaleLayout::new_BlockScaled(BlockScaledLayout::new( shape, scales_layout, comptime![(block_size[0] as u32, block_size[1] as u32)], )) } }; let scales_buf = GlobalInput::new(inputs, locals, scales, config, None); // Redefine because of `Numeric` bound, kinda hacky but I can't figure out a way to // assert `Vector::Scalar: Numeric` let define!(T) = storage_type_of::(); let view = create_quant_view_dynamic::( data_buf, data_layout, scales_buf, scales_layout, scheme, ); // Safety: should be fine since `Vector` is guaranteed equal to `E` comptime![unsafe { core::mem::transmute(view) }] } } } #[cube] fn input_batch_layout( inputs: &GlobalArgs, batch_shape: &Sequence>, #[comptime] arg: MatmulArg, #[comptime] config: FuseBlockConfig, ) -> VirtualLayout { let rank = comptime![config.rank]; match comptime![arg.clone()] { MatmulArg::Normal(arg) => { let data_tensor = match comptime![arg.clone()] { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), _ => panic!("Input must be concrete"), }; let mut batch_strides = Sequence::new(); #[unroll] for i in 0..rank - 2 { let shape = data_tensor.tensor.shape(i); let stride = select(shape == 1, 0, data_tensor.tensor.stride(i)); batch_strides.push(stride); } VirtualLayout::new::(BatchLayout::new(batch_strides, batch_shape.clone())) } MatmulArg::Quantized { .. } => VirtualLayout::new::(NoopLayout::new()), } } #[cube] fn global_layout( inputs: &GlobalArgs, shape: Coords2d, batch_layout: VirtualLayout, #[comptime] arg: FuseArg, #[comptime] config: FuseBlockConfig, #[comptime] vector_size: VectorSize, #[comptime] layout_config: GlobalLayoutConfig, #[comptime] packing: u32, ) -> GlobalLayout { let rank = comptime![config.rank]; let data_tensor = match comptime![arg.clone()] { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), _ => panic!("Input must be concrete"), }; let (shape_row, shape_col) = shape; let stride_row = data_tensor.tensor.stride(rank - 2); let stride_col = data_tensor.tensor.stride(rank - 1); GlobalLayout::new( batch_layout, shape_row, shape_col, stride_row, stride_col, vector_size, packing, layout_config, ) } struct CreateQuantView<'a, E: Numeric, N: Size> { scope: &'a mut Scope, data_buf: GlobalInputExpand, data_layout: GlobalLayoutExpand, scales_buf: GlobalInputExpand, scales_layout: GlobalScaleLayoutExpand, scheme: QuantScheme, _ty: PhantomData<(E, N)>, } impl<'a, E: Numeric, N: Size> RunWithQuantType for CreateQuantView<'a, E, N> { type Output = ViewExpand, BatchedCoords>; fn execute(self) -> Self::Output { create_quant_view::expand::( self.scope, self.data_buf, self.data_layout, self.scales_buf, self.scales_layout, self.scheme, ) } } #[cube] #[allow(unused)] fn create_quant_view_dynamic( data_buf: GlobalInput, data_layout: GlobalLayout, scales_buf: GlobalInput, scales_layout: GlobalScaleLayout, #[comptime] scheme: QuantScheme, ) -> View, BatchedCoords> { intrinsic!(|scope| { let func = CreateQuantView { scope, data_buf, data_layout, scales_buf, scales_layout, scheme, _ty: PhantomData, }; run_with_quant_type(func, scheme) }) } #[cube] fn create_quant_view( data_buf: GlobalInput, data_layout: GlobalLayout, scales_buf: GlobalInput, scales_layout: GlobalScaleLayout, #[comptime] scheme: QuantScheme, ) -> View, BatchedCoords> { let size!(NQ) = N::value().comptime() / scheme.num_quants(); let data_view: View, BatchedCoords> = View::new::(&data_buf, data_layout); let scales_view: View = View::new::(&scales_buf, scales_layout); QuantizedView::new(data_view, scales_view, scheme).view() } #[derive(CubeType)] pub struct FusedMatmulState { inputs: GlobalArgs, outputs: GlobalArgs, locals: LocalArgs, a_batch: VirtualLayout, b_batch: VirtualLayout, c_batch: ComptimeOption>, out_batch: VirtualLayout, #[cube(comptime)] config: FuseBlockConfig, #[cube(comptime)] a: MatmulArg, #[cube(comptime)] b: MatmulArg, #[cube(comptime)] c: Option, #[cube(comptime)] out: FuseArg, #[cube(comptime)] lhs_layout_config: GlobalLayoutConfig, #[cube(comptime)] rhs_layout_config: GlobalLayoutConfig, #[cube(comptime)] out_layout_config: GlobalLayoutConfig, batch_shape: Sequence>, } #[cube] impl FusedMatmulState { #[allow(clippy::too_many_arguments)] pub fn new( inputs: &FusedMatmulInput, outputs: &mut GlobalArgs, locals: &mut LocalArgs, a_batch: VirtualLayout, b_batch: VirtualLayout, c_batch: ComptimeOption>, out_batch: VirtualLayout, batch_shape: Sequence>, #[comptime] config: &FuseBlockConfig, #[comptime] lhs_layout_config: GlobalLayoutConfig, #[comptime] rhs_layout_config: GlobalLayoutConfig, #[comptime] out_layout_config: GlobalLayoutConfig, ) -> FusedMatmulState { FusedMatmulState { inputs: inputs.global.clone(), outputs: outputs.clone(), config: comptime![config.clone()], locals: locals.clone(), a_batch, b_batch, c_batch, out_batch, a: comptime![inputs.a.clone()], b: comptime![inputs.b.clone()], c: comptime![inputs.c.clone()], out: comptime![inputs.out.clone()], lhs_layout_config, rhs_layout_config, out_layout_config, batch_shape, } } } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to a matmul operation. pub enum MatmulArg { Normal(FuseArg), Quantized { data: FuseArg, scales: FuseArg, precision: FuseType, scheme: QuantScheme, }, } impl MatmulArg { pub fn data(&self) -> &FuseArg { match self { MatmulArg::Normal(arg) => arg, MatmulArg::Quantized { data, .. } => data, } } pub fn scheme(&self) -> Option<&QuantScheme> { match self { MatmulArg::Normal(_) => None, MatmulArg::Quantized { scheme, .. } => Some(scheme), } } pub fn precision(&self) -> FuseType { match self { MatmulArg::Normal(arg) => arg.precision(), MatmulArg::Quantized { precision, .. } => *precision, } } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/matmul/fuser.rs ================================================ use super::optimization::{FusedMatmul, MatmulOptimization}; use crate::{ engine::{fuser::TraceOperationFuser, settings::FuseSettings}, optim::CubeOptimization, optim::matmul::args::MatmulArg, }; use burn_fusion::{FuserStatus, OperationFuser}; use burn_ir::{FloatOperationIr, OperationIr}; use burn_std::DType; use cubecl::Runtime; /// Fused element wise operations that are normally memory bound. pub struct MatmulFuser { fuser: TraceOperationFuser, fuser_fallback: TraceOperationFuser, device: R::Device, matmul: Option, } impl Clone for MatmulFuser { fn clone(&self) -> Self { Self { fuser: self.fuser.clone(), fuser_fallback: self.fuser_fallback.clone(), device: self.device.clone(), matmul: self.matmul.clone(), } } } impl MatmulFuser { pub fn new(device: R::Device) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware.max_bindings; let settings_matmul = FuseSettings { output_shape_updates: false, ..Default::default() }; let settings_fallback = FuseSettings::default(); Self { fuser: TraceOperationFuser::new(max_bindings, settings_matmul), fuser_fallback: TraceOperationFuser::new(max_bindings, settings_fallback), device, matmul: None, } } } impl OperationFuser> for MatmulFuser { fn fuse(&mut self, operation: &OperationIr) { if let FuserStatus::Closed = self.fuser.status() { return; } if self.matmul.is_none() { if let OperationIr::Float(_, FloatOperationIr::Matmul(op)) = operation { // Precision shouldn't be hardcoded but I don't know how to get float precision of the backend let lhs = match op.lhs.dtype { DType::QFloat(scheme) => { let (data, scales) = self.fuser.input_quantized_unhandled(&op.lhs).unwrap(); MatmulArg::Quantized { data, scales, precision: op.out.dtype.into(), scheme, } } _ => MatmulArg::Normal(self.fuser.input_unhandled(&op.lhs)), }; let rhs = match op.rhs.dtype { DType::QFloat(scheme) => { let (data, scales) = self.fuser.input_quantized_unhandled(&op.rhs).unwrap(); MatmulArg::Quantized { data, scales, precision: op.out.dtype.into(), scheme, } } _ => MatmulArg::Normal(self.fuser.input_unhandled(&op.rhs)), }; let out = self.fuser.output_unhandled(&op.out); self.matmul = Some(FusedMatmul::new( lhs, rhs, out, op.clone().into(), Default::default(), )); } else { self.fuser.close(); self.fuser_fallback.close(); } } else { let can_register = self.fuser.can_fuse(operation) && self.fuser_fallback.can_fuse(operation); match can_register { true => { self.fuser.fuse(operation); self.fuser_fallback.fuse(operation); } false => { self.fuser.close(); self.fuser_fallback.close(); } }; } } fn finish(&mut self) -> CubeOptimization { let client = R::client(&self.device); let trace = self.fuser.finish(); let trace_fallback = self.fuser_fallback.finish(); let matmul = MatmulOptimization::new( trace, trace_fallback, client, self.device.clone(), self.len(), self.matmul.as_ref().unwrap().clone(), ); CubeOptimization::Matmul(matmul) } fn reset(&mut self) { self.fuser.reset(); self.fuser_fallback.reset(); self.matmul = None; } fn status(&self) -> burn_fusion::FuserStatus { self.fuser.status() } fn properties(&self) -> burn_fusion::FuserProperties { self.fuser.properties() } fn len(&self) -> usize { // Matmul operation isn't registered in the fuser self.fuser.len() + 1 } fn clone_dyn(&self) -> Box>> { Box::new(self.clone()) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/matmul/mod.rs ================================================ mod fuser; mod optimization; pub(crate) mod args; pub(crate) mod tune; pub use fuser::*; pub use optimization::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs ================================================ use super::args::FusedMatmulInputLaunch; #[cfg(feature = "autotune")] use super::tune::fused_matmul_autotune; use crate::{ CubeFusionHandle, FallbackOperation, engine::{ codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout}, launch::{ FuseTraceLauncher, HandleInput, LaunchPlan, runner::{TraceRunner, Vectorization, VectorizationAxis}, }, trace::{FuseTrace, TraceError, TuneOutput}, }, optim::{ elemwise::ElemwiseRunner, matmul::args::{FusedMatmulArgs, MatmulArg}, }, }; use burn_fusion::stream::Context; use burn_ir::BinaryOpIr; use cubecl::{ client::ComputeClient, prelude::*, std::tensor::{MatrixBatchLayout, matrix_batch_layout}, }; use cubek::{ matmul::{ components::tile::{cmma::CmmaMatmul, mma::MmaMatmul}, definition::{ MatmulElems, MatmulGlobalElems, MatmulProblem, MatmulSetupError, MatmulVectorSizes, }, launch::launch_kernel_virtual, routines::{ BlueprintStrategy, Routine, double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs}, double_unit::DoubleUnitAlgorithm, ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs}, simple::{SimpleAlgorithm, SimpleArgs}, simple_unit::SimpleUnitAlgorithm, vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm}, }, }, std::MatrixLayout, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; /// Fuse matmul operation followed by elemwise operations into a single kernel. pub struct MatmulOptimization { pub(crate) info: Arc>, } pub struct MatmulOptimizationTuneArg { pub(crate) info: Arc>, pub(crate) fallback: Box>, } pub(crate) struct MatmulOptimizationInfo { trace: FuseTrace, trace_fallback: FuseTrace, pub(crate) client: ComputeClient, pub(crate) device: R::Device, pub(crate) len: usize, pub(crate) matmul: FusedMatmul, } #[derive(Serialize, Deserialize, Debug)] /// State for the [matrix optimization](MatmulOptimizationState). pub struct MatmulOptimizationState { trace: FuseTrace, trace_fallback: FuseTrace, matmul: FusedMatmul, len: usize, } impl MatmulOptimizationInfo { /// Returns the number of output buffers added by fusion. pub fn num_output_buffers(&self) -> usize { self.trace_fallback.resources.outputs.len() } /// Number of operations fused. pub fn num_ops_fused(&self) -> usize { self.len } } impl MatmulOptimizationTuneArg { pub(crate) fn execute_fused( &self, context: &mut Context<'_, CubeFusionHandle>, selector: FusedMatmulSelector, ) -> Result, TraceError> { let launch = FusedMatmulLaunch::new(&self.info.matmul, selector); let launcher = FuseTraceLauncher::new(&self.info.trace, &launch); launcher.launch(&self.info.client, &self.info.device, context) } pub fn execute_fallback( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> TuneOutput { self.fallback.run(context); #[cfg(feature = "autotune-checks")] let mut output = TuneOutput::Checked { handles: Default::default(), }; #[cfg(not(feature = "autotune-checks"))] let output = TuneOutput::UnChecked(core::marker::PhantomData); #[cfg(feature = "autotune-checks")] if let TuneOutput::Checked { handles } = &mut output { let out_desc = context.tensors.get(&self.info.matmul.op.out.id).unwrap(); let handle_out = context .handles .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly); handles.insert( self.info.matmul.op.out.id, (out_desc.shape.dims.clone(), handle_out.clone()), ); } let launcher = FuseTraceLauncher::new(&self.info.trace_fallback, &ElemwiseRunner); let output_write = launcher .launch(&self.info.client, &self.info.device, context) .unwrap(); output.merge(output_write) } } impl MatmulOptimization { pub fn new( trace: FuseTrace, trace_fallback: FuseTrace, client: ComputeClient, device: R::Device, len: usize, matmul: FusedMatmul, ) -> Self { let info = MatmulOptimizationInfo { trace, trace_fallback, client, device, len, matmul, }; Self { info: Arc::new(info), } } /// Execute the optimization. pub fn execute( &mut self, context: &mut Context<'_, CubeFusionHandle>, fallback: impl FnOnce(usize) -> Box>, ) { // The index of the fallback matmul is always 0. let fallback = fallback(0); let arg = MatmulOptimizationTuneArg { info: self.info.clone(), fallback, }; #[cfg(feature = "autotune")] fused_matmul_autotune::(arg, context); #[cfg(not(feature = "autotune"))] if arg .execute_fused(context, FusedMatmulSelector::default()) .is_err() { arg.execute_fallback(context); } } /// Number of operations fused. pub fn num_ops_fused(&self) -> usize { self.info.num_ops_fused() } /// Create an optimization from its [state](MatmulOptimizationState). pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self { let info = MatmulOptimizationInfo { trace: state.trace, trace_fallback: state.trace_fallback, len: state.len, client: R::client(device), device: device.clone(), matmul: state.matmul.clone(), }; Self { info: Arc::new(info), } } /// Convert the optimization to its [state](MatmulOptimizationState). pub fn to_state(&self) -> MatmulOptimizationState { MatmulOptimizationState { trace: self.info.trace.clone(), trace_fallback: self.info.trace_fallback.clone(), matmul: self.info.matmul.clone(), len: self.info.len, } } } #[derive(Clone, Copy, Serialize, Deserialize, Debug)] pub enum FusedMatmulSelector { Simple { multi_rows: bool, tile_matmul: AcceleratedTileKind, }, DoubleBuffering { specialized: bool, tile_matmul: AcceleratedTileKind, }, OrderedDoubleBuffering { tile_matmul: AcceleratedTileKind, }, SimpleVecMat, DoubleVecMat, SimpleUnit, DoubleUnit, } impl FusedMatmulSelector { /// Not efficient, but only called once when initializing the tunables. pub fn name(&self) -> String { let name = match self { FusedMatmulSelector::Simple { multi_rows, tile_matmul, } => match multi_rows { false => format!("simple_{tile_matmul:?}"), true => format!("simple_multirows_{tile_matmul:?}"), }, FusedMatmulSelector::DoubleBuffering { specialized, tile_matmul, } => match specialized { false => format!("double_buffering_{tile_matmul:?}"), true => format!("double_buffering_specialized_{tile_matmul:?}"), }, FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => { format!("double_buffering_ordered_{tile_matmul:?}").to_lowercase() } FusedMatmulSelector::SimpleVecMat => "simple_vec_mat".into(), FusedMatmulSelector::DoubleVecMat => "double_buffering_vec_mat".into(), FusedMatmulSelector::SimpleUnit => "simple_unit".into(), FusedMatmulSelector::DoubleUnit => "double_buffering_unit".into(), }; format!("fused_{name}") } } impl Default for FusedMatmulSelector { fn default() -> Self { FusedMatmulSelector::Simple { multi_rows: false, tile_matmul: AcceleratedTileKind::Cmma, } } } #[derive(new, Clone, Serialize, Deserialize, Debug)] pub struct FusedMatmul { pub(crate) lhs: MatmulArg, pub(crate) rhs: MatmulArg, out: FuseArg, pub(crate) op: BinaryOpIr, pub(crate) selector: FusedMatmulSelector, } #[derive(new)] pub struct FusedMatmulLaunch<'a> { pub(crate) matmul: &'a FusedMatmul, pub(crate) selector: FusedMatmulSelector, } #[derive(Debug)] pub enum FusedMatmulError { LaunchError(MatmulSetupError), InvalidInput(&'static str), } impl From for FusedMatmulError { fn from(value: MatmulSetupError) -> Self { Self::LaunchError(value) } } impl<'a, R: Runtime> Vectorization for FusedMatmulLaunch<'a> { fn axis(&self, plan: &LaunchPlan<'_, R>) -> VectorizationAxis { let lhs_id = self.matmul.op.lhs.id; let rhs_id = self.matmul.op.rhs.id; let mut tensor_lhs = None; let mut tensor_rhs = None; for input in plan.handle_inputs.iter() { match input { HandleInput::Normal(input) => { if input.relative_id == lhs_id { tensor_lhs = Some((input.global_ir.id, &input.handle.strides)); } if input.relative_id == rhs_id { tensor_rhs = Some((input.global_ir.id, &input.handle.strides)); } } HandleInput::QuantValues(input) => { if input.relative_id == lhs_id { tensor_lhs = Some((input.global_ir.id, &input.handle.strides)); } if input.relative_id == rhs_id { tensor_rhs = Some((input.global_ir.id, &input.handle.strides)); } } HandleInput::QuantParams(_) => {} } } let (lhs_id_global, lhs_strides) = tensor_lhs.unwrap(); let (rhs_id_global, rhs_strides) = tensor_rhs.unwrap(); let mut axis = VectorizationAxis::default(); if let MatrixBatchLayout::MildlyPermuted { transposed, .. } = matrix_batch_layout(lhs_strides, self.matmul.lhs.scheme()) && transposed { axis.insert(lhs_id_global, lhs_strides.len() - 2); } if let MatrixBatchLayout::MildlyPermuted { transposed, .. } = matrix_batch_layout(rhs_strides, self.matmul.rhs.scheme()) && transposed { axis.insert(rhs_id_global, rhs_strides.len() - 2); } axis } } impl TraceRunner for FusedMatmulLaunch<'_> { type Error = FusedMatmulError; fn run<'a>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, configs: &'a [FuseBlockConfig], ) -> Result<(), FusedMatmulError> { let global_elems = MatmulGlobalElems { lhs: self.matmul.lhs.precision().into_storage_type(), rhs: self.matmul.rhs.precision().into_storage_type(), out: self.matmul.out.precision().into_storage_type(), }; let dtypes = MatmulElems::from_globals(&global_elems); self.matmul_fused(client, inputs, outputs, &configs[0], dtypes) } } #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] /// Which tile matmul to use for accelerated algorithms pub enum AcceleratedTileKind { #[default] Cmma, Mma, } macro_rules! with_tile_kind { ($kind: expr, $T: ident, $launch: expr) => { match $kind { AcceleratedTileKind::Cmma => { type $T = CmmaMatmul; ($launch)() } AcceleratedTileKind::Mma => { type $T = MmaMatmul; ($launch)() } } }; } impl FusedMatmulLaunch<'_> { fn matmul_fused<'a, R: Runtime>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, config: &'a FuseBlockConfig, dtypes: MatmulElems, ) -> Result<(), FusedMatmulError> { let lhs_shape = inputs.shape(self.matmul.lhs.data()); let rhs_shape = inputs.shape(self.matmul.rhs.data()); let out_shape = outputs.shape_ref(&config.ref_layout, config.rank); let lhs_strides = inputs.strides(self.matmul.lhs.data()); let lhs_scheme = self.matmul.lhs.scheme(); let rhs_strides = inputs.strides(self.matmul.rhs.data()); let rhs_scheme = self.matmul.rhs.scheme(); if matrix_batch_layout(&lhs_strides, lhs_scheme) == MatrixBatchLayout::HighlyPermuted { return Err(FusedMatmulError::InvalidInput( "Lhs needs to be contiguous, but can't when fusing.", )); } if matrix_batch_layout(&rhs_strides, rhs_scheme) == MatrixBatchLayout::HighlyPermuted { return Err(FusedMatmulError::InvalidInput( "Rhs needs to be contiguous, but can't when fusing.", )); } let mut vector_sizes = MatmulVectorSizes { lhs: inputs.vector_size(self.matmul.lhs.data()), rhs: inputs.vector_size(self.matmul.rhs.data()), out: match &config.ref_layout { RefLayout::Concrete(arg) => match arg { FuseArg::Input(..) => inputs.vector_size(arg), FuseArg::Output(..) => outputs.vector_size(arg), _ => panic!("Invalid ref layout"), }, RefLayout::Virtual(_) => 1, }, }; let address_type = inputs .required_address_type() .max(outputs.required_address_type()); if vector_sizes.out == 1 && (vector_sizes.lhs > 1 || vector_sizes.rhs > 1) { return Err(FusedMatmulError::InvalidInput( "Output vector size of 1 removes the gain from fusion", )); } if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs { vector_sizes.lhs *= scheme.num_quants(); } if let MatmulArg::Quantized { scheme, .. } = self.matmul.rhs { vector_sizes.rhs *= scheme.num_quants(); } let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape); let problem = MatmulProblem::from_shapes_and_strides( lhs_shape, rhs_shape, out_shape, lhs_strides, rhs_strides, out_strides, dtypes.as_global_elems(), address_type, self.matmul.lhs.scheme(), self.matmul.rhs.scheme(), )?; match self.selector { FusedMatmulSelector::Simple { multi_rows, tile_matmul, } => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::< R, SimpleAlgorithm, >( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &BlueprintStrategy::Inferred(SimpleArgs { multi_rows }), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), }), FusedMatmulSelector::DoubleBuffering { specialized, tile_matmul, } => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::< R, CyclicDoubleBufferingAlgorithm, >( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized }), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), }), FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => { let row_count = match self.matmul.lhs.precision() { FuseType::F16 | FuseType::BF16 => 8, _ => 4, }; with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::< R, OrderedDoubleBufferingAlgorithm, >( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &BlueprintStrategy::Inferred(OrderedSelectionArgs { row_count: Some(row_count), rows_per_plane: Some(2), partition_k: Some(2), }), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), }) } FusedMatmulSelector::SimpleUnit => { match launch_inner_fix_dtype::( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &Default::default(), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } FusedMatmulSelector::DoubleUnit => { match launch_inner_fix_dtype::( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &Default::default(), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } FusedMatmulSelector::SimpleVecMat => { match launch_inner_fix_dtype::( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &Default::default(), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } FusedMatmulSelector::DoubleVecMat => { match launch_inner_fix_dtype::( client, FusedMatmulInputLaunch::new( inputs, config.clone(), self.matmul.lhs.clone(), self.matmul.rhs.clone(), None, self.matmul.out.clone(), ), outputs, problem, vector_sizes, &Default::default(), ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), } } } } } fn launch_inner_fix_dtype>( client: &ComputeClient, input: FusedMatmulInputLaunch, output: GlobalArgsLaunch, problem: MatmulProblem, vector_sizes: MatmulVectorSizes, blueprint_strategy: &BlueprintStrategy<(), A>, ) -> Result<(), MatmulSetupError> { launch_kernel_virtual::( client, input, output, (), problem, vector_sizes, blueprint_strategy, ) } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/matmul/tune.rs ================================================ use super::optimization::MatmulOptimizationTuneArg; use crate::{ CubeFusionHandle, engine::trace::TuneOutput, optim::matmul::{AcceleratedTileKind, FusedMatmulSelector}, tune::{TuneContext, TuneInput}, }; use burn_fusion::stream::Context; use cubecl::{ AutotuneKey, CubeTuneId, Runtime, tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}, }; use cubek::matmul::{ definition::MatmulKind, launch::{MatmulAutotuneKey, MatmulGlobalScale, should_tune_double_buffering}, }; use serde::{Deserialize, Serialize}; #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] pub struct FusedMatmulAutotuneKey { matmul_key: MatmulAutotuneKey, #[autotune(anchor)] num_out_buffers: usize, #[autotune(anchor)] num_ops: usize, } /// Executes autotune on matmul operations pub fn fused_matmul_autotune( optimization: MatmulOptimizationTuneArg, context: &mut Context>, ) { static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 3; const PRIORITY_HIGH: i8 = 2; const PRIORITY_MEDIUM: i8 = 1; const PRIORITY_MIN: i8 = 0; let cmma = TuneGroup::::new("cmma", |key| { if matches!( key.matmul_key.analysis.kind, MatmulKind::General // Those variants are just because the unit alternatives aren't very good yet. | MatmulKind::VecMat | MatmulKind::MatVec ) { PRIORITY_MAX } else { PRIORITY_MEDIUM } }); let mma = TuneGroup::::new("mma", |key| { if matches!( key.matmul_key.analysis.kind, // General is usually bad, but I think shapes like 16x8196 would be classed as // general and are very good with MMA // Should highly degenerated matrices that aren't VecMat have their own class? MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec ) { PRIORITY_MAX } else { PRIORITY_MEDIUM } }); let odd = TuneGroup::::new("odd", |key| { if key.matmul_key.definition.lhs_pow2_factor == 0 || key.matmul_key.definition.rhs_pow2_factor == 0 { PRIORITY_MAX } else { PRIORITY_MIN } }); let unit = TuneGroup::::new("unit", |key| { if !matches!(key.matmul_key.analysis.kind, MatmulKind::General) || matches!( key.matmul_key.analysis.scale_global, MatmulGlobalScale::Small ) { PRIORITY_MAX } else { PRIORITY_MIN } }); fn double_buffering_priority(key: &FusedMatmulAutotuneKey, max: i8, min: i8) -> i8 { if should_tune_double_buffering(key.num_out_buffers > 1, &key.matmul_key) { max } else { min } } let mut set = TunableSet::new(create_key::, input_gen::) .with(Tunable::new("fused_matmul_fallback", tune_fallback::)); // First one should always work. // Unit matmuls for (selector, double_buf) in [ (FusedMatmulSelector::SimpleUnit, false), (FusedMatmulSelector::DoubleUnit, true), (FusedMatmulSelector::SimpleVecMat, false), (FusedMatmulSelector::DoubleVecMat, true), ] { set = set.with( Tunable::new(selector.name(), move |input| { tune_fused::(input, selector) }) .group(&unit, move |key| match double_buf { true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH), false => PRIORITY_MAX, }), ); } // Accelerated matmuls for (tile_matmul, group) in [ (AcceleratedTileKind::Cmma, &cmma), (AcceleratedTileKind::Mma, &mma), ] { for (selector, double_buf, extra_group) in [ ( FusedMatmulSelector::Simple { multi_rows: false, tile_matmul, }, false, None, ), ( FusedMatmulSelector::Simple { multi_rows: true, tile_matmul, }, false, None, ), ( FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul }, true, None, ), ( FusedMatmulSelector::DoubleBuffering { specialized: false, tile_matmul, }, true, None, ), ( FusedMatmulSelector::DoubleBuffering { specialized: true, tile_matmul, }, true, Some(&odd), ), ] { let mut tunable = Tunable::new(selector.name(), move |input| { tune_fused::(input, selector) }) .group(group, move |key| match double_buf { true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH), false => PRIORITY_MAX, }); if let Some(group) = extra_group { tunable = tunable.group(group, |_| PRIORITY_MAX); } set = set.with(tunable); } } set }); TUNER.execute( &CubeTuneId::new(&optimization.info.client, &optimization.info.device), &optimization.info.client.clone(), tunables, TuneInput::new(context, optimization), ); } pub(crate) fn create_key( input: &TuneInput>, ) -> FusedMatmulAutotuneKey { let opt = input.optimization(); let context = match input.context() { TuneContext::Original(context) => context, TuneContext::Fork(_) => panic!("Not supported when generating key"), }; let lhs = context.tensors.get(&opt.info.matmul.op.lhs.id).unwrap(); let rhs = context.tensors.get(&opt.info.matmul.op.rhs.id).unwrap(); let out = context.tensors.get(&opt.info.matmul.op.out.id).unwrap(); let lhs_strides = context .handles .get_handle(&lhs.id, &burn_ir::TensorStatus::ReadOnly) .strides .clone(); let rhs_strides = context .handles .get_handle(&rhs.id, &burn_ir::TensorStatus::ReadOnly) .strides .clone(); let key = MatmulAutotuneKey::generate( &opt.info.client, &lhs.shape, &rhs.shape, &lhs_strides, &rhs_strides, lhs.dtype.into(), rhs.dtype.into(), out.dtype.into(), opt.info.matmul.lhs.scheme(), opt.info.matmul.rhs.scheme(), ); FusedMatmulAutotuneKey::new(key, opt.info.num_output_buffers(), opt.info.num_ops_fused()) } fn input_gen( _key: &FusedMatmulAutotuneKey, input: &TuneInput>, ) -> TuneInput> { input.clone() } fn tune_fused( input: TuneInput>, selector: FusedMatmulSelector, ) -> Result, String> { let optimization = input.optimization(); let context = input.context(); match context { TuneContext::Original(context) => match optimization.execute_fused(context, selector) { Ok(out) => Ok(out), Err(_) => { return tune_fallback::(input); } }, TuneContext::Fork(mut context_owned) => { optimization.execute_fused(&mut context_owned.as_context(), selector) } } .map_err(|e| format!("{e:?}")) } fn tune_fallback( input: TuneInput>, ) -> Result, String> { let optimization = input.optimization(); let context = input.context(); Ok(match context { TuneContext::Original(context) => optimization.execute_fallback(context), TuneContext::Fork(mut context_owned) => { optimization.execute_fallback(&mut context_owned.as_context()) } }) } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/mod.rs ================================================ pub mod elemwise; pub mod matmul; pub mod reduce; pub mod reduce_broadcasted; mod base; pub use base::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce/args.rs ================================================ use crate::engine::codegen::{ io::{ref_buffer_len, ref_len, ref_shape, ref_stride, ref_vector_size}, ir::{FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsExpand, LocalArgs, LocalArgsExpand}, kernel::{fuse_on_read, fuse_on_write, init_locals}, }; use cubecl::prelude::*; use cubek::reduce::components::args::{ReduceArgs, ReduceDType}; #[derive(Clone)] pub struct FusedReduceArgs; #[derive(CubeType, CubeLaunch)] pub struct FusedReduceInput { pub global: GlobalArgs, #[cube(comptime)] pub config: FuseBlockConfig, #[cube(comptime)] pub arg: FuseArg, } #[derive(CubeType, CubeLaunch)] pub struct FusedReduceOutput { pub global: GlobalArgs, #[cube(comptime)] pub config: FuseBlockConfig, #[cube(comptime)] pub arg: FuseArg, } pub struct FusedReduceState { inputs: *const GlobalArgs, outputs: *mut GlobalArgs, locals_on_read: *mut LocalArgs, locals_on_write: *mut LocalArgs, config_on_read: FuseBlockConfig, config_on_write: FuseBlockConfig, // TODO: Should be a list when multiple blocks are there. input: FuseArg, out: FuseArg, } #[derive(Clone)] pub struct FusedReduceStateExpand { inputs: GlobalArgsExpand, outputs: GlobalArgsExpand, locals_on_read: LocalArgsExpand, locals_on_write: LocalArgsExpand, config_on_read: FuseBlockConfig, config_on_write: FuseBlockConfig, input: FuseArg, out: FuseArg, } #[cube] impl ReduceArgs for FusedReduceArgs { type Input = FusedReduceInput; type Output = FusedReduceOutput; type State = FusedReduceState; fn init_state( input: &Self::Input, output: &mut Self::Output, ) -> Self::State

{ let mut locals_read = init_locals(&input.global, &mut output.global, &input.config); let mut locals_write = init_locals(&input.global, &mut output.global, &output.config); // TODO Add stuff from previous blocks to the local of each block. FusedReduceState::new(input, output, &mut locals_read, &mut locals_write) } fn read_input( state: &Self::State

, index: usize, ) -> Vector { let value = fuse_on_read::( unsafe { &(*state.inputs) }, unsafe { &mut (*state.outputs) }, unsafe { &mut (*state.locals_on_read) }, index, comptime! { let mut sequence = Sequence::new(); // TODO: Register local arguments from previous blocks. sequence.push(state.input.clone()); sequence }, &state.config_on_read, )[0]; value } fn read_output( _state: &Self::State

, _index: usize, ) -> Vector { Vector::empty() } fn write_output( state: &mut Self::State

, index: usize, value: Vector, ) { let mut values = Registry::>::new(); let mut args = comptime![Vec::::new()]; values.insert(comptime![state.out.clone()], value); comptime![args.push(state.out.clone())]; fuse_on_write( unsafe { &(*state.inputs) }, unsafe { &mut (*state.outputs) }, unsafe { &mut (*state.locals_on_write) }, index, values, args, &state.config_on_write, ); } fn len_input(state: &Self::State

) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, unsafe { &(*state.locals_on_read) }, &state.config_on_read, ) } fn len_output(state: &Self::State

) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, unsafe { &(*state.locals_on_write) }, &state.config_on_write, ) } fn buffer_len_input(state: &Self::State

) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, unsafe { &(*state.locals_on_read) }, &state.config_on_read, ) } fn buffer_len_output(state: &Self::State

) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, unsafe { &(*state.locals_on_write) }, &state.config_on_write, ) } fn rank_input(state: &Self::State

) -> usize { state.config_on_read.rank.runtime() } fn rank_output(state: &Self::State

) -> usize { state.config_on_write.rank.runtime() } fn shape_input(state: &Self::State

, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_read) }, dim) } fn shape_output(state: &Self::State

, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_write) }, dim) } fn stride_input(state: &Self::State

, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_read) }, dim) } fn stride_output(state: &Self::State

, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_write) }, dim) } fn vector_size_input(state: &Self::State

) -> comptime_type!(VectorSize) { ref_vector_size(unsafe { &(*state.locals_on_read) }) } fn vector_size_output(state: &Self::State

) -> comptime_type!(VectorSize) { ref_vector_size(unsafe { &(*state.locals_on_write) }) } } #[cube] impl FusedReduceState { pub fn new( inputs: &FusedReduceInput, outputs: &mut FusedReduceOutput, locals_on_read: &mut LocalArgs, locals_on_write: &mut LocalArgs, ) -> FusedReduceState { FusedReduceState { inputs: &inputs.global, outputs: &mut outputs.global, locals_on_read, locals_on_write, config_on_read: comptime![inputs.config.clone()], config_on_write: comptime![outputs.config.clone()], input: comptime![inputs.arg.clone()], out: comptime![outputs.arg.clone()], } } } impl CubeType for FusedReduceState { type ExpandType = FusedReduceStateExpand; } impl IntoMut for FusedReduceStateExpand { fn into_mut(self, _context: &mut Scope) -> Self { self } } impl CubeDebug for FusedReduceStateExpand {} ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce/fuser.rs ================================================ use super::{ ReduceSettings, optimization::{FusedReduce, ReduceInstruction, ReduceOptimization}, }; use crate::{ engine::{ codegen::ir::FuseType, fuser::TraceOperationFuser, settings::{FuseSettings, RefLayoutSetting, VectorizationSetting}, }, optim::CubeOptimization, }; use burn_fusion::{FuserStatus, OperationFuser}; use burn_ir::{NumericOperationIr, OperationIr, ReduceDimOpIr}; use burn_std::Shape; use cubecl::Runtime; /// Fuses element wise operations around a reduce operation. pub struct ReduceFuser { pub(crate) fuser: TraceOperationFuser, pub(crate) fuser_read_fallback: TraceOperationFuser, fuser_write_fallback: TraceOperationFuser, settings_write: FuseSettings, pub(crate) device: R::Device, pub(crate) reduce: Option, settings: ReduceSettings, } impl Clone for ReduceFuser { fn clone(&self) -> Self { Self { fuser: self.fuser.clone(), fuser_read_fallback: self.fuser_read_fallback.clone(), fuser_write_fallback: self.fuser_write_fallback.clone(), settings_write: self.settings_write, device: self.device.clone(), reduce: self.reduce.clone(), settings: self.settings, } } } #[derive(Debug)] pub enum ReduceFuserInfo { FusedReduce { shape_input_id: Shape, axis: usize }, FusedElemwise { shape_id: Shape }, } impl ReduceFuser { pub fn new(device: R::Device, settings: ReduceSettings) -> Self { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware.max_bindings; let settings_read = FuseSettings { // Inplace would work, but not when we have a concrete output to write too. inplace: true, ref_layout: RefLayoutSetting::OnlyContiguous, broadcast: false, output_shape_updates: true, vectorization: VectorizationSetting::Activated, }; let settings_write = FuseSettings { inplace: false, output_shape_updates: false, vectorization: VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos: 0 }, broadcast: false, ref_layout: RefLayoutSetting::OnlyContiguous, }; let settings_fallback = FuseSettings::default(); Self { fuser: TraceOperationFuser::new(max_bindings, settings_read), fuser_read_fallback: TraceOperationFuser::new(max_bindings, settings_fallback), fuser_write_fallback: TraceOperationFuser::new(max_bindings, settings_fallback), settings_write, device, reduce: None, settings, } } pub fn reduce_info(&self) -> ReduceFuserInfo { match &self.reduce { Some(reduce) => { let shape_input_id = reduce.op.input.shape.clone(); let axis = reduce.axis; ReduceFuserInfo::FusedReduce { shape_input_id, axis, } } None => { let shape_id = self.fuser_read_fallback.current_output_shape.clone(); ReduceFuserInfo::FusedElemwise { shape_id } } } } fn on_reduce(&mut self, op: &ReduceDimOpIr, inst: ReduceInstruction) { // TODO: Fix: we need to have fuse-on-read with an identity block. // // if self.fuser.num_ops == 0 && false { // self.fuser.current_output_shape = op.input.shape.dims.clone(); // } else if self.fuser.current_output_shape != op.input.shape.dims { if self.fuser.current_output_shape != op.input.shape { self.fuser.close(); self.fuser_read_fallback.close(); return; } let [input] = self .fuser .next_block([&op.input], self.settings_write, false); let output = self.fuser.output_unhandled(&op.out); let axis = op.axis; let fuse_on_write_activated = match self.settings { ReduceSettings::Always => true, // We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise // vectorization is impossible. Only [VectorizationMode::Perpendicular] supports vectorization. // // We could still fuse some output operations, but it would probably lead to worse performance. ReduceSettings::OnlyParallel => axis != op.input.shape.rank() - 1, ReduceSettings::Never => false, }; if !fuse_on_write_activated { self.fuser.close(); } let acc = match inst { ReduceInstruction::Mean | ReduceInstruction::Prod | ReduceInstruction::Sum => { match input.precision() { FuseType::F16 | FuseType::BF16 => FuseType::F32, FuseType::I16 | FuseType::I8 => FuseType::I32, FuseType::U16 | FuseType::U8 => FuseType::U32, _ => input.precision(), } } _ => input.precision(), }; self.reduce = Some(FusedReduce { input, output, acc, axis, op: op.clone(), use_planes: false, shared: false, inst, }); self.fuser_read_fallback.close(); } fn on_elemwise_read(&mut self, operation: &OperationIr) { let can_register = self.fuser.can_fuse(operation) && self.fuser_read_fallback.can_fuse(operation); match can_register { true => { self.fuser.fuse(operation); self.fuser_read_fallback.fuse(operation); } false => { self.fuser.close(); self.fuser_read_fallback.close(); } }; } fn on_elemwise_write(&mut self, operation: &OperationIr) { let can_register = self.fuser.can_fuse(operation) && self.fuser_write_fallback.can_fuse(operation); match can_register { true => { self.fuser.fuse(operation); self.fuser_write_fallback.fuse(operation); } false => { self.fuser.close(); self.fuser_write_fallback.close(); } }; } } impl OperationFuser> for ReduceFuser { fn fuse(&mut self, operation: &OperationIr) { if let FuserStatus::Closed = self.fuser.status() { return; } if self.reduce.is_none() { if let OperationIr::NumericFloat(_, op) = operation { match op { NumericOperationIr::SumDim(op) => { self.on_reduce(op, ReduceInstruction::Sum); } NumericOperationIr::MeanDim(op) => { self.on_reduce(op, ReduceInstruction::Mean); } NumericOperationIr::ProdDim(op) => { self.on_reduce(op, ReduceInstruction::Prod); } NumericOperationIr::ArgMax(op) => { self.on_reduce(op, ReduceInstruction::ArgMax); } NumericOperationIr::ArgMin(op) => { self.on_reduce(op, ReduceInstruction::ArgMin); } NumericOperationIr::MinDim(op) => { self.on_reduce(op, ReduceInstruction::Min); } NumericOperationIr::MaxDim(op) => { self.on_reduce(op, ReduceInstruction::Max); } NumericOperationIr::MaxAbsDim(op) => { self.on_reduce(op, ReduceInstruction::MaxAbs); } _ => { self.on_elemwise_read(operation); } }; } else if let OperationIr::NumericInt(_, op) = operation { match op { NumericOperationIr::SumDim(op) => { self.on_reduce(op, ReduceInstruction::Sum); } NumericOperationIr::MeanDim(op) => { self.on_reduce(op, ReduceInstruction::Mean); } NumericOperationIr::ProdDim(op) => { self.on_reduce(op, ReduceInstruction::Prod); } NumericOperationIr::ArgMax(op) => { self.on_reduce(op, ReduceInstruction::ArgMax); } NumericOperationIr::ArgMin(op) => { self.on_reduce(op, ReduceInstruction::ArgMin); } NumericOperationIr::MinDim(op) => { self.on_reduce(op, ReduceInstruction::Min); } NumericOperationIr::MaxDim(op) => { self.on_reduce(op, ReduceInstruction::Max); } NumericOperationIr::MaxAbsDim(op) => { self.on_reduce(op, ReduceInstruction::MaxAbs); } _ => { self.on_elemwise_read(operation); } }; } else { self.on_elemwise_read(operation); } } else { self.on_elemwise_write(operation); } } fn finish(&mut self) -> CubeOptimization { let client = R::client(&self.device); let trace = self.fuser.finish(); let trace_read_fallback = self.fuser_read_fallback.finish(); let trace_write_fallback = self.fuser_write_fallback.finish(); let fuse_reduce = self.reduce.as_ref().unwrap(); let reduce = ReduceOptimization::new( trace, trace_read_fallback, trace_write_fallback, client, self.device.clone(), self.len(), self.fuser_read_fallback.len(), fuse_reduce.clone(), self.settings, ); CubeOptimization::Reduce(reduce) } fn reset(&mut self) { self.fuser.reset(); self.fuser_read_fallback.reset(); self.fuser_write_fallback.reset(); self.reduce = None; } fn status(&self) -> burn_fusion::FuserStatus { self.fuser.status() } fn properties(&self) -> burn_fusion::FuserProperties { let mut properties = self.fuser.properties(); properties.ready = self.reduce.is_some(); properties } fn len(&self) -> usize { self.fuser.len() + if self.reduce.is_some() { 1 } else { 0 } } fn clone_dyn(&self) -> Box>> { Box::new(self.clone()) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce/mod.rs ================================================ mod fuser; mod optimization; pub(crate) mod args; pub(crate) mod tune; pub use fuser::*; pub use optimization::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs ================================================ use super::args::{ FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch, }; #[cfg(feature = "autotune")] use super::tune::fused_reduce_autotune; use crate::{ CubeFusionHandle, FallbackOperation, engine::{ codegen::ir::{ FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout, multi_block_variables_init, }, launch::{ FuseTraceLauncher, runner::{TraceRunner, Vectorization}, }, trace::{FuseTrace, TraceError, TuneOutput}, }, optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs}, }; use burn_fusion::stream::Context; use burn_ir::ReduceDimOpIr; use burn_std::DType; use cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*}; use cubek::reduce::{ ReduceDtypes, ReduceError, VectorizationMode, components::instructions::ReduceOperationConfig, init_tensors, launch::{RoutineStrategy, reduce_kernel_virtual}, routines::{ ReduceBlueprint, ReduceLaunchSettings, ReduceProblem, ReduceVectorSettings, Routine, cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine, }, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; #[cfg(not(feature = "autotune"))] use cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy}; pub struct ReduceOptimization { pub(crate) info: Arc>, } pub(crate) struct ReduceOptimizationInfo { pub(crate) trace: FuseTrace, trace_read_fallback: FuseTrace, trace_write_fallback: FuseTrace, pub(crate) client: ComputeClient, pub(crate) device: R::Device, pub(crate) len: usize, pub(crate) len_read: usize, pub(crate) reduce: FusedReduce, settings: ReduceSettings, } impl ReduceOptimizationInfo { pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self { let client = R::client(device); Self { trace: state.trace, trace_read_fallback: state.trace_read_fallback, trace_write_fallback: state.trace_write_fallback, client, device: device.clone(), len: state.len, len_read: state.len_read, reduce: state.reduce, settings: state.settings, } } pub fn to_state(&self) -> ReduceOptimizationState { ReduceOptimizationState { trace: self.trace.clone(), trace_read_fallback: self.trace_read_fallback.clone(), trace_write_fallback: self.trace_write_fallback.clone(), len: self.len, len_read: self.len_read, reduce: self.reduce.clone(), settings: self.settings, } } } #[derive(Serialize, Deserialize, Copy, Clone)] pub enum ReduceSettings { Always, /// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise /// vectorization is impossible. Only [VectorizationMode::Perpendicular] supports vectorization. /// /// We could still fuse some output operations, but it would probably lead to worse performance. OnlyParallel, Never, } pub(crate) struct ReduceOptimizationTuneArg { pub(crate) info: Arc>, pub(crate) fallback: Arc>>, } impl Clone for ReduceOptimizationTuneArg { fn clone(&self) -> Self { Self { info: self.info.clone(), fallback: self.fallback.clone(), } } } #[derive(Clone, Copy, Serialize, Deserialize, Debug)] pub enum ReduceInstruction { ArgMax, ArgMin, Mean, Prod, Sum, Max, Min, MaxAbs, } pub trait ReduceFallbackFn: Send + Sync { fn run(&self, context: &mut Context<'_, CubeFusionHandle>); } #[derive(Serialize, Deserialize)] pub struct ReduceOptimizationState { pub(crate) trace: FuseTrace, pub(crate) trace_read_fallback: FuseTrace, pub(crate) trace_write_fallback: FuseTrace, pub(crate) reduce: FusedReduce, pub(crate) len: usize, pub(crate) len_read: usize, pub(crate) settings: ReduceSettings, } impl core::fmt::Debug for ReduceOptimizationState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "{{ len_read: {}, len_total: {} }}", self.len_read, self.len )) } } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct FusedReduce { pub(crate) input: FuseArg, pub(crate) output: FuseArg, pub(crate) acc: FuseType, pub(crate) axis: usize, pub(crate) op: ReduceDimOpIr, pub(crate) use_planes: bool, pub(crate) shared: bool, pub(crate) inst: ReduceInstruction, } #[derive(new)] pub struct FusedReduceLaunch<'a> { reduce: &'a FusedReduce, strategy: RoutineStrategy, } #[derive(Debug)] pub enum FusedReduceError { Reduce(ReduceError), InvalidSelection(Box<&'static str>), InvalidInput, } impl From for FusedReduceError { fn from(value: ReduceError) -> Self { Self::Reduce(value) } } impl ReduceOptimizationTuneArg { pub fn execute_fused( &self, context: &mut Context<'_, CubeFusionHandle>, strategy: RoutineStrategy, ) -> Result, TraceError> { let launch = FusedReduceLaunch::new(&self.info.reduce, strategy); let launcher = FuseTraceLauncher::new(&self.info.trace, &launch); launcher.launch(&self.info.client, &self.info.device, context) } pub fn execute_fallback( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> TuneOutput { let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner); #[allow(unused_mut)] // It is used when `autotune-checks` is activated. let mut output_read = launcher .launch(&self.info.client, &self.info.device, context) .unwrap(); self.fallback.run(context); #[cfg(feature = "autotune-checks")] if let TuneOutput::Checked { handles } = &mut output_read { let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap(); let handle_out = context .handles .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly); handles.insert( self.info.reduce.op.out.id, (out_desc.shape.dims.clone(), handle_out.clone()), ); } let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner); let output_write = launcher .launch(&self.info.client, &self.info.device, context) .unwrap(); output_read.merge(output_write) } } #[allow(clippy::too_many_arguments)] impl ReduceOptimization { pub fn new( trace: FuseTrace, trace_read_fallback: FuseTrace, trace_write_fallback: FuseTrace, client: ComputeClient, device: R::Device, len: usize, len_read: usize, reduce: FusedReduce, settings: ReduceSettings, ) -> Self { let info = ReduceOptimizationInfo { trace, trace_read_fallback, trace_write_fallback, client, device, len, len_read, reduce, settings, }; Self { info: Arc::new(info), } } /// Execute the optimization. pub fn execute( &mut self, context: &mut Context<'_, CubeFusionHandle>, fallback: impl FnOnce(usize) -> Box>, ) { // The index of the fallback reduce is the number of ops fused as read. let fallback = fallback(self.info.len_read); let arg = ReduceOptimizationTuneArg { info: self.info.clone(), fallback: Arc::new(fallback), }; #[cfg(feature = "autotune")] fused_reduce_autotune::(arg, context); #[cfg(not(feature = "autotune"))] if arg .execute_fused( context, RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)), ) .is_err() { arg.execute_fallback(context); } } pub fn num_output_buffers(&self) -> usize { self.info.trace_read_fallback.resources.outputs.len() } pub fn to_state(&self) -> ReduceOptimizationState { ReduceOptimizationState { trace: self.info.trace.clone(), trace_read_fallback: self.info.trace_read_fallback.clone(), trace_write_fallback: self.info.trace_write_fallback.clone(), reduce: self.info.reduce.clone(), len: self.info.len, len_read: self.info.len_read, settings: self.info.settings, } } pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self { let client = R::client(device); let info = ReduceOptimizationInfo { trace: state.trace, trace_read_fallback: state.trace_read_fallback, trace_write_fallback: state.trace_write_fallback, reduce: state.reduce, len: state.len, len_read: state.len_read, client, device: device.clone(), settings: state.settings, }; Self { info: Arc::new(info), } } /// Returns the number of output buffers added by fusion. pub fn num_ops_fused(&self) -> usize { self.info.len } } // TODO: Implement better vectorization here. impl Vectorization for FusedReduceLaunch<'_> {} impl TraceRunner for FusedReduceLaunch<'_> { type Error = FusedReduceError; fn run<'a>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, configs: &'a [FuseBlockConfig], ) -> Result<(), FusedReduceError> { let [config_read, config_write] = [&configs[0], &configs[1]]; let shape = match &config_read.ref_layout { RefLayout::Concrete(FuseArg::Output(..)) => { outputs.shape_ref(&config_read.ref_layout, config_read.rank) } _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank), }; let reduce_count: usize = shape .iter() .enumerate() .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s }) .product(); let vectorization_mode = match self.reduce.axis == config_read.rank - 1 { true => VectorizationMode::Parallel, false => VectorizationMode::Perpendicular, }; let address_type = inputs .required_address_type() .max(outputs.required_address_type()); let settings = ReduceVectorSettings { vectorization_mode, vector_size_input: config_read.width, vector_size_output: config_write.width, }; let problem = ReduceProblem { vector_size: shape[self.reduce.axis], vector_count: reduce_count, axis: self.reduce.axis, dtypes: ReduceDtypes { input: self.reduce.op.input.dtype.into(), output: self.reduce.op.out.dtype.into(), accumulation: self.reduce.acc.into_elem().into(), }, address_type, }; let (blueprint, settings) = match self.strategy.clone() { RoutineStrategy::Unit(strategy) => { let routine = UnitRoutine; routine.prepare(client, problem, settings, strategy)? } RoutineStrategy::Plane(strategy) => { let routine = PlaneRoutine; routine.prepare(client, problem, settings, strategy)? } RoutineStrategy::Cube(strategy) => { let routine = CubeRoutine; routine.prepare(client, problem, settings, strategy)? } }; let kwargs = ReduceKwArgs { client, inputs, outputs, axis: self.reduce.axis, config_fuse_read: config_read.clone(), config_fuse_write: config_write.clone(), input: self.reduce.input.clone(), output: self.reduce.output.clone(), blueprint, settings, }; let result = launch_reduce_mixed_precision( kwargs, self.reduce.inst, self.reduce.op.input.dtype, self.reduce.op.out.dtype, DType::from(self.reduce.acc.into_elem()), ); match result { Ok(_) => Ok(()), Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))), } } } struct ReduceKwArgs<'b, Run: Runtime> { client: &'b ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, axis: usize, blueprint: ReduceBlueprint, settings: ReduceLaunchSettings, config_fuse_read: FuseBlockConfig, config_fuse_write: FuseBlockConfig, input: FuseArg, output: FuseArg, } fn launch_reduce_mixed_precision( kwargs: ReduceKwArgs<'_, Run>, instruction: ReduceInstruction, dtype_input: DType, dtype_output: DType, dtype_acc: DType, ) -> Result<(), LaunchError> { let config = match instruction { ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax, ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin, ReduceInstruction::Prod => ReduceOperationConfig::Prod, ReduceInstruction::Mean => ReduceOperationConfig::Mean, ReduceInstruction::Sum => ReduceOperationConfig::Sum, ReduceInstruction::Max => ReduceOperationConfig::Max, ReduceInstruction::Min => ReduceOperationConfig::Min, ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs, }; launch_reduce::(kwargs, config, dtype_input, dtype_output, dtype_acc) } fn launch_reduce( kwargs: ReduceKwArgs<'_, Run>, inst: ReduceOperationConfig, dtype_input: DType, dtype_output: DType, dtype_acc: DType, ) -> Result<(), LaunchError> { unsafe { reduce_kernel_fused::launch_unchecked::( kwargs.client, kwargs.settings.cube_count, kwargs.settings.cube_dim, kwargs.settings.address_type, kwargs.config_fuse_read.width, kwargs.config_fuse_write.width, FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input), FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output), kwargs.axis, kwargs.blueprint, inst, dtype_input.into(), dtype_output.into(), dtype_acc.into(), ) }; Ok(()) } #[cube(launch_unchecked, address_type = "dynamic")] pub fn reduce_kernel_fused( input: &FusedReduceInput, output: &mut FusedReduceOutput, axis_reduce: usize, #[comptime] blueprint: ReduceBlueprint, #[comptime] config: ReduceOperationConfig, #[define(In)] _input_dtype: StorageType, #[define(Out)] _output_dtype: StorageType, #[define(Acc)] _acc_dtype: StorageType, ) { multi_block_variables_init(&input.config, &mut output.global.variables); multi_block_variables_init(&output.config, &mut output.global.variables); let (input, mut output) = init_tensors::(input, output); reduce_kernel_virtual::( &input, &mut output, axis_reduce, blueprint, config, ); } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce/tune.rs ================================================ use super::optimization::ReduceOptimizationTuneArg; use crate::{ CubeFusionHandle, engine::trace::TuneOutput, tune::{TuneContext, TuneInput}, }; use burn_fusion::stream::Context; use cubecl::{ AutotuneKey, CubeTuneId, Runtime, tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}, }; use cubek::reduce::{ launch::{RoutineStrategy, tune_key::ReduceAutotuneKey}, routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy}, }; use serde::{Deserialize, Serialize}; /// Autotune key for standard fused reduction operations. /// /// Records metadata about the fusion graph (IO and ops) alongside /// the core reduction parameters to ensure stable kernel selection. #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] pub struct FusedReduceAutotuneKey { reduce_key: ReduceAutotuneKey, #[autotune(anchor)] fuse_num_reads: usize, #[autotune(anchor)] fuse_num_writes: usize, #[autotune(anchor)] fuse_num_ops: usize, } /// Executes autotuning for fused reduction operations. /// /// This tuner evaluates different hardware-specific strategies (Plane, Cube, Unit) /// and assigns priorities based on the `vector_count` of the reduction. pub fn fused_reduce_autotune( arg: ReduceOptimizationTuneArg, context: &mut Context>, ) { static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 2; const PRIORITY_MIN: i8 = 1; let mut set = TunableSet::new(create_key::, input_gen::); let group = TuneGroup::::new("fused_reduce", |_key| PRIORITY_MAX); // Fallback implementation for robustness. set = set.with(Tunable::new("fused_reduce_fallback", tune_fallback::)); // Define properties to categorize hardware strategies. enum ReduceProps { GreatWithLowReduceCount, GreatWithHighReduceCount, Balanced, } let strategies = [ ( "fused_unit", RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)), ReduceProps::GreatWithHighReduceCount, ), ( "fused_plane", RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy { independent: true, })), ReduceProps::Balanced, ), ( "fused_cube", RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy { // Two steps reduction doesn't work with fuse-on-write, we can't activate plane // when using the cube algo. use_planes: false, })), ReduceProps::GreatWithLowReduceCount, ), ]; for (name, strategy, props) in strategies { let tunable = Tunable::new(name, move |input| tune_reduce::(input, &strategy)) .group(&group, move |key| match props { ReduceProps::GreatWithLowReduceCount => { if key.reduce_key.vector_count < 128 { PRIORITY_MAX } else { PRIORITY_MIN } } ReduceProps::GreatWithHighReduceCount => { if key.reduce_key.vector_count > 64 { PRIORITY_MAX } else { PRIORITY_MIN } } ReduceProps::Balanced => PRIORITY_MAX, }); set = set.with(tunable); } set }); TUNER.execute( &CubeTuneId::new(&arg.info.client, &arg.info.device), &arg.info.client.clone(), tunables, TuneInput::new(context, arg), ); } /// Creates the autotune key by extracting tensor metadata and fusion block statistics. pub(crate) fn create_key( input: &TuneInput>, ) -> FusedReduceAutotuneKey { let opt = input.optimization(); let context = match input.context() { TuneContext::Original(context) => context, TuneContext::Fork(_) => panic!("Forked context not supported for key generation"), }; let input_tensor = context.tensors.get(&opt.info.reduce.op.input.id).unwrap(); let out_tensor = context.tensors.get(&opt.info.reduce.op.out.id).unwrap(); let acc = opt.info.reduce.acc.into_elem(); let key = ReduceAutotuneKey::generate( input_tensor.dtype.into(), out_tensor.dtype.into(), acc, &input_tensor.shape, opt.info.reduce.axis == input_tensor.shape.rank() - 1, opt.info.reduce.axis, ); // Assume the fusion contains at least a read and a write block. let read_block = &opt.info.trace.blocks[0]; let write_block = &opt.info.trace.blocks[1]; FusedReduceAutotuneKey::new( key, read_block.reads.len() + write_block.reads.len(), read_block.writes.len() + write_block.writes.len(), read_block.ops.len() + write_block.ops.len(), ) } /// Identity generator for tuning inputs. fn input_gen( _key: &FusedReduceAutotuneKey, input: &TuneInput>, ) -> TuneInput> { input.clone() } /// Executes a fused reduction optimization. fn tune_reduce( input: TuneInput>, strategy: &RoutineStrategy, ) -> Result, String> { let optimization = input.optimization(); match input.context() { TuneContext::Original(context) => optimization.execute_fused(context, strategy.clone()), TuneContext::Fork(mut context_owned) => { optimization.execute_fused(&mut context_owned.as_context(), strategy.clone()) } } .map_err(|e| format!("{e:?}")) } /// Executes the fallback path for a reduction optimization. fn tune_fallback( input: TuneInput>, ) -> Result, String> { let optimization = input.optimization(); match input.context() { TuneContext::Original(context) => optimization.execute_fallback(context), TuneContext::Fork(mut context_owned) => { optimization.execute_fallback(&mut context_owned.as_context()) } }; Ok(TuneOutput::UnChecked(std::marker::PhantomData)) } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/base.rs ================================================ use crate::optim::{ CubeOptimization, reduce::{ReduceFuser, ReduceFuserInfo, ReduceSettings}, reduce_broadcasted::{ ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationInfo, fuser::{ block::{ReduceBlockFuser, ReduceBlockFusionAnalysis, ReduceBroadcastedStatus}, full::ReduceBroadcastedFullFuser, full_analyzer::FullFuserAnalyzer, }, }, }; use burn_fusion::{FuserProperties, FuserStatus, OperationFuser}; use burn_ir::OperationIr; use cubecl::Runtime; use std::sync::Arc; /// Fuses element wise operations around a reduce operation. pub struct ReduceBroadcastedFuser { blocks: Vec>, fuser_default: ReduceFuser, num_ops: usize, state: ReduceBroadcastedStatus, max_bindings: u32, } impl Clone for ReduceBroadcastedFuser { fn clone(&self) -> Self { Self { blocks: self.blocks.clone(), fuser_default: self.fuser_default.clone(), num_ops: self.num_ops, state: self.state.clone(), max_bindings: self.max_bindings, } } } impl ReduceBroadcastedFuser { pub fn new(device: R::Device) -> Self { let fuser = ReduceFuser::new(device, ReduceSettings::Always); let max_bindings = fuser.fuser.max_bindings; let block = ReduceBlockFuser::new(fuser.clone()); Self { blocks: vec![block], fuser_default: fuser, num_ops: 0, state: ReduceBroadcastedStatus::Starting, max_bindings, } } } impl OperationFuser> for ReduceBroadcastedFuser { fn fuse(&mut self, operation: &OperationIr) { if matches!( &self.state, ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort ) { return; } let block = self.blocks.last_mut().unwrap(); let analyze = block.analyze(operation, &self.state, &self.fuser_default); let info = match analyze { ReduceBlockFusionAnalysis::Accept => { block.fuse(operation); self.num_ops += 1; block.fuser.reduce_info() } ReduceBlockFusionAnalysis::Refuse => { self.state = ReduceBroadcastedStatus::Closed; return; } ReduceBlockFusionAnalysis::NewBlockRequired => { let info = block.fuser.reduce_info(); let mut block = ReduceBlockFuser::new(self.fuser_default.clone()); block.fuse(operation); self.num_ops += 1; self.blocks.push(block); info } }; match info { ReduceFuserInfo::FusedReduce { shape_input_id, axis, } => { // Only support last axis for now. if axis != shape_input_id.len() - 1 { self.state = ReduceBroadcastedStatus::Abort; } else { self.state = ReduceBroadcastedStatus::Init { shape_id: shape_input_id, axis, }; } } ReduceFuserInfo::FusedElemwise { .. } => {} } } fn finish(&mut self) -> CubeOptimization { let analyzer = FullFuserAnalyzer::new(&self.blocks); let mut full = ReduceBroadcastedFullFuser::new(self.max_bindings, analyzer); let mut num_ops = 0; let fallbacks = self .blocks .iter_mut() .map(|block| block.finish(&mut num_ops, &mut full)) .collect::>(); let broadcasted = Arc::new(full.finish()); let info = Arc::new(ReduceBroadcastedOptimizationInfo { fallbacks, broadcasted, }); CubeOptimization::ReduceBroadcasted(ReduceBroadcastedOptimization { info, num_ops }) } fn reset(&mut self) { let block = ReduceBlockFuser::new(self.fuser_default.clone()); self.blocks = vec![block]; self.num_ops = 0; self.state = ReduceBroadcastedStatus::Starting; } fn status(&self) -> FuserStatus { match self.state { ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort => { return FuserStatus::Closed; } _ => {} }; let fuser = self.blocks.last().unwrap(); fuser.fuser.status() } fn properties(&self) -> FuserProperties { let ready = match self.state { ReduceBroadcastedStatus::Starting | ReduceBroadcastedStatus::Abort => false, ReduceBroadcastedStatus::Closed => { if self.blocks.len() == 1 { !self.blocks[0].is_elemwise() } else { true } } _ => true, }; let mut props = FuserProperties { score: 0, ready }; for block in self.blocks.iter() { let p = block.properties(); props.score += p.score; props.ready = p.ready && props.ready; } props } fn len(&self) -> usize { self.num_ops } fn clone_dyn(&self) -> Box>> { Box::new(self.clone()) } } #[cfg(test)] mod tests { use burn_ir::{ BaseOperationIr, BinaryOpIr, CreationOpIr, ReduceDimOpIr, TensorId, TensorIr, TensorStatus, }; use burn_std::{DType, Shape}; use super::*; type Run = cubecl::TestRuntime; #[test] fn reduce_broadcast_workflow_1() { let device: ::Device = Default::default(); let mut fuser = ReduceBroadcastedFuser::::new(device); let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite); let (tensor2_out, tensor2) = tensor(1, &[1, 0], TensorStatus::ReadWrite); fuser.fuse(&OperationIr::BaseFloat(BaseOperationIr::Ones( CreationOpIr { out: tensor1_out }, ))); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr { input: tensor1, out: tensor2_out, axis: 1, }), )); let status = fuser.status(); assert_eq!(2, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); // An existing tensor let (_tensor3_out, tensor3) = tensor(2, &[1, 0], TensorStatus::ReadWrite); // A new tensor let (tensor4_out, tensor4) = tensor(3, &[1, 0], TensorStatus::ReadWrite); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::Add(BinaryOpIr { lhs: tensor2, rhs: tensor3, out: tensor4_out, }), )); let status = fuser.status(); assert_eq!(3, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); // An existing tensor let (_tensor5_out, tensor5) = tensor(4, &[1, 2], TensorStatus::ReadWrite); // A new tensor let (tensor6_out, tensor6) = tensor(5, &[1, 2], TensorStatus::ReadWrite); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::Add(BinaryOpIr { lhs: tensor4, rhs: tensor5, out: tensor6_out, }), )); let status = fuser.status(); assert_eq!(4, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); let (tensor7_out, _tensor7) = tensor(6, &[1, 0], TensorStatus::ReadWrite); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr { input: tensor6, out: tensor7_out, axis: 1, }), )); assert_eq!(5, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); let _optimization = fuser.finish(); } #[test] fn reduce_broadcast_workflow_2() { let device: ::Device = Default::default(); let mut fuser = ReduceBroadcastedFuser::::new(device); let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite); // An existing tensor let (_tensor2_out, mut tensor2) = tensor(2, &[1, 2], TensorStatus::ReadOnly); let (tensor3_out, tensor3) = tensor(3, &[1, 2], TensorStatus::ReadWrite); // First reduce output let (tensor4_out, tensor4) = tensor(1, &[1, 0], TensorStatus::ReadWrite); fuser.fuse(&OperationIr::BaseFloat(BaseOperationIr::Ones( CreationOpIr { out: tensor1_out }, ))); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::Add(BinaryOpIr { lhs: tensor1, rhs: tensor2.clone(), out: tensor3_out, }), )); fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr { input: tensor3, out: tensor4_out, axis: 1, }), )); let status = fuser.status(); assert_eq!(3, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); // A new tensor let (tensor5_out, _tensor5) = tensor(5, &[1, 2], TensorStatus::ReadWrite); // Last time we use tensor2. tensor2.status = TensorStatus::ReadWrite; fuser.fuse(&OperationIr::NumericFloat( DType::F32, burn_ir::NumericOperationIr::Add(BinaryOpIr { lhs: tensor4, rhs: tensor2, out: tensor5_out, }), )); let status = fuser.status(); assert_eq!(4, fuser.len()); assert_eq!(status, FuserStatus::Open); assert!(fuser.properties().ready,); let _optimization = fuser.finish(); } fn tensor(id: u64, shape: &[usize], status: TensorStatus) -> (TensorIr, TensorIr) { let tensor = TensorIr { id: TensorId::new(id), shape: Shape::from(shape), status: TensorStatus::NotInit, dtype: DType::F32, }; let mut tensor_init = tensor.clone(); tensor_init.status = status; (tensor, tensor_init) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/block.rs ================================================ use crate::optim::{ CubeOptimization, elemwise::ElemwiseOptimization, reduce::{FusedReduce, ReduceFuser, ReduceFuserInfo}, reduce_broadcasted::{ReduceBlockOptimInfo, fuser::full::ReduceBroadcastedFullFuser}, }; use burn_fusion::{FuserProperties, OperationFuser}; use burn_ir::OperationIr; use burn_std::Shape; use cubecl::Runtime; use std::sync::Arc; /// Responsible for fusing a single reduce block or elementwise block. /// /// When the block kind is reduce, it supports fuse-on-read and fuse-on-write fusion. /// Broadcasting isn't supported; another block should handle it instead. pub struct ReduceBlockFuser { /// We use [ReduceFuser] for both elementwise and reduce blocks, keeping only the /// fuse-on-read trace if the block is tagged as elementwise. /// /// # Notes /// /// A single elementwise block can only exist at the end of a full [ReduceBlockFuser], /// otherwise the optimization will be included in the reduce fusion block. pub fuser: ReduceFuser, pub(crate) ops: Vec, pub(crate) kind: ReduceBlockKind, } /// The current state of the fusion process. #[derive(Debug, Clone)] pub enum ReduceBroadcastedStatus { /// Fusion is starting; no reduction has been fused yet. Starting, /// Fusion is initialized with at least one reduce operation. /// /// # Notes /// /// Subsequent reduce operations must be compatible with the previous reduction to fuse. Init { shape_id: Shape, axis: usize }, /// No more operations can be fused. Closed, /// Invalid axis. Abort, } /// The [ReduceBlockFuser] capacity to accept an [OperationIr]. #[derive(Clone, Copy, Debug)] pub enum ReduceBlockFusionAnalysis { /// The operation can be fused; call [ReduceBlockFuser::fuse()]. Accept, /// The operation cannot be fused; the optimization should close. Refuse, /// The operation can be fused, but requires a new block. NewBlockRequired, } impl ReduceBlockFuser { /// Creates a new block. pub fn new(fuser: ReduceFuser) -> Self { Self { fuser: fuser.clone(), ops: Vec::new(), kind: ReduceBlockKind::Elemwise, } } /// Returns true if this is an elementwise fuser. pub fn is_elemwise(&self) -> bool { matches!(self.kind, ReduceBlockKind::Elemwise) } /// Analyzes if fusion is possible within this block. pub fn analyze( &self, op: &OperationIr, status: &ReduceBroadcastedStatus, default_node: &ReduceFuser, ) -> ReduceBlockFusionAnalysis { let mut fuser_try = self.fuser.clone(); let before = fuser_try.len(); fuser_try.fuse(op); let after = fuser_try.len(); if after > before { return ReduceBlockFusionAnalysis::Accept; } // Can't create a new block if the previous one was not a reduction. if self.fuser.reduce.is_none() { return ReduceBlockFusionAnalysis::Refuse; } let mut fuser_try = default_node.clone(); let before = fuser_try.len(); fuser_try.fuse(op); let after = fuser_try.len(); if after > before { let info = fuser_try.reduce_info(); return match (info, status) { ( ReduceFuserInfo::FusedReduce { shape_input_id, axis, }, ReduceBroadcastedStatus::Init { shape_id, axis: axis_init, }, ) => { if shape_id == &shape_input_id && axis_init == &axis { ReduceBlockFusionAnalysis::NewBlockRequired } else { ReduceBlockFusionAnalysis::Refuse } } ( ReduceFuserInfo::FusedElemwise { shape_id }, ReduceBroadcastedStatus::Init { shape_id: shape_init, .. }, ) => { if &shape_id == shape_init { ReduceBlockFusionAnalysis::NewBlockRequired } else { ReduceBlockFusionAnalysis::Refuse } } _ => ReduceBlockFusionAnalysis::Refuse, }; } ReduceBlockFusionAnalysis::Refuse } /// Fuses an operation within this block. /// /// # Warning /// /// Ensure [Self::analyze()] is called before this function to confirm the operation is accepted. pub fn fuse(&mut self, op: &OperationIr) { self.fuser.fuse(op); self.ops.push(op.clone()); // Update the kind if a reduction is introduced to an elementwise block. if let (Some(reduce), ReduceBlockKind::Elemwise) = (&self.fuser.reduce, &self.kind) { self.kind = ReduceBlockKind::Reduce { ops_index: self.ops.len() - 1, reduce: Box::new(reduce.clone()), }; } } /// Computes the fuser properties. pub fn properties(&self) -> FuserProperties { let mut properties = self.fuser.properties(); if let ReduceBlockKind::Elemwise = &self.kind { // Elementwise traces are always ready to run. properties.ready = true; } properties } pub fn finish( &mut self, num_ops: &mut usize, full: &mut ReduceBroadcastedFullFuser, ) -> ReduceBlockOptimInfo { full.register(self); match &self.kind { ReduceBlockKind::Elemwise => { let len = self.fuser.fuser_read_fallback.len(); let device = self.fuser.device.clone(); *num_ops += len; let trace = self.fuser.fuser_read_fallback.finish(); let client = R::client(&device); let elementwise = ElemwiseOptimization::new(trace, client, device, len); ReduceBlockOptimInfo::Elemwise(Arc::new(elementwise)) } ReduceBlockKind::Reduce { .. } => { *num_ops += self.fuser.len(); let optim = self.fuser.finish(); let info = match optim { CubeOptimization::Reduce(optim) => optim.info, _ => unreachable!("Expected Reduce optimization"), }; ReduceBlockOptimInfo::Reduce(info) } } } } #[derive(Clone, Debug)] pub enum ReduceBlockKind { Elemwise, Reduce { ops_index: usize, reduce: Box, }, } impl Clone for ReduceBlockFuser { fn clone(&self) -> Self { Self { fuser: self.fuser.clone(), ops: self.ops.clone(), kind: self.kind.clone(), } } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/full.rs ================================================ use crate::{ engine::{ fuser::TraceOperationFuser, settings::{FuseSettings, RefLayoutSetting, VectorizationSetting}, }, optim::{ reduce::{FusedReduce, ReduceInstruction}, reduce_broadcasted::{ ReduceBroadcastedInfo, fuser::{ block::{ReduceBlockFuser, ReduceBlockKind}, full_analyzer::FullFuserAnalyzer, }, launch::ReduceBroadcastedFuseBlock, }, }, }; use burn_fusion::OperationFuser; use cubecl::Runtime; use cubek::reduce::components::instructions::ReduceOperationConfig; /// Responsible for fusing a single trace for all operations involved in this optimization. pub struct ReduceBroadcastedFullFuser { pub(crate) fuser: TraceOperationFuser, analyzer: FullFuserAnalyzer, blocks: Vec, settings_read: FuseSettings, settings_write: FuseSettings, } impl ReduceBroadcastedFullFuser { /// Creates a new fuser with the given settings. pub fn new(max_bindings: u32, analyzer: FullFuserAnalyzer) -> Self { let settings_read = FuseSettings { output_shape_updates: true, broadcast: true, inplace: false, ref_layout: RefLayoutSetting::OnlyContiguous, vectorization: VectorizationSetting::Activated, }; let settings_write = FuseSettings { output_shape_updates: false, inplace: false, broadcast: false, ref_layout: RefLayoutSetting::OnlyContiguous, // Deactivated for now, but would be cool to support vectorization of the output. vectorization: VectorizationSetting::Deactivated, }; let fuser = TraceOperationFuser::new(max_bindings, settings_read); Self { fuser, blocks: Vec::new(), settings_write, settings_read, analyzer, } } /// Finishes fusing all blocks. pub fn finish(mut self) -> ReduceBroadcastedInfo { let mut reduce_axis = 0; let mut blocks = Vec::new(); for block in self.blocks.iter() { match block { ReduceBlockKind::Elemwise => {} ReduceBlockKind::Reduce { reduce, .. } => { let config = match reduce.inst { ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax, ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin, ReduceInstruction::Prod => ReduceOperationConfig::Prod, ReduceInstruction::Mean => ReduceOperationConfig::Mean, ReduceInstruction::Sum => ReduceOperationConfig::Sum, ReduceInstruction::Max => ReduceOperationConfig::Max, ReduceInstruction::Min => ReduceOperationConfig::Min, ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs, }; let block = ReduceBroadcastedFuseBlock { op: config, input: reduce.input.clone(), output: reduce.output.clone(), }; reduce_axis = reduce.axis; blocks.push(block); } } } let trace = self.fuser.finish(); ReduceBroadcastedInfo { blocks, trace, reduce_axis, } } /// Registers a [ReduceBlockFuser] to build the trace. pub fn register(&mut self, block: &ReduceBlockFuser) { // Helper to close previous blocks if necessary if !self.fuser.is_empty() { let mut settings = self.settings_read; settings.vectorization = VectorizationSetting::EqualThanPreviousBlock { block_pos: 0 }; settings.ref_layout = RefLayoutSetting::SameAsBlock { block_pos: 0 }; self.fuser.next_block([], settings, false); let analysis = self.analyzer.retrieve_next(); for (tensor, block_pos) in analysis.inputs { self.fuser.block_local_input(&tensor, block_pos, false); } } match &block.kind { ReduceBlockKind::Elemwise => { for op in &block.ops { self.fuser.fuse(op); } self.blocks.push(ReduceBlockKind::Elemwise); } ReduceBlockKind::Reduce { ops_index, reduce } => { for op in &block.ops[0..*ops_index] { self.fuser.fuse(op); } let [input] = self .fuser .next_block([&reduce.op.input], self.settings_write, false); let output = self.fuser.output_unhandled(&reduce.op.out); let analysis = self.analyzer.retrieve_next(); // Can be broadcasted so the generated buffer can be global. for (tensor, block_pos) in analysis.inputs { self.fuser.block_local_input(&tensor, block_pos, false); } let fused_reduce = FusedReduce { input, output, acc: reduce.acc, axis: reduce.axis, op: reduce.op.clone(), use_planes: reduce.use_planes, shared: reduce.shared, inst: reduce.inst, }; self.blocks.push(ReduceBlockKind::Reduce { ops_index: *ops_index, reduce: Box::new(fused_reduce), }); for op in &block.ops[*ops_index + 1..block.ops.len()] { self.fuser.fuse(op); } } } } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/full_analyzer.rs ================================================ use super::block::ReduceBlockKind; use crate::optim::reduce_broadcasted::fuser::block::ReduceBlockFuser; use burn_ir::{TensorId, TensorIr}; use cubecl::Runtime; use std::collections::BTreeMap; #[derive(Debug)] pub struct FullFuserAnalyzer { // We need to know the block id of which we can reuse the read local input. analyses: Vec>, } impl FullFuserAnalyzer { pub fn new(blocks: &[ReduceBlockFuser]) -> Self { let mut state = AnalysisState::default(); for block in blocks.iter() { for (pos, op) in block.ops.iter().enumerate() { let potential_from_previous_blocks = op.inputs(); let potential_to_next_blocks = op.outputs(); match &block.kind { ReduceBlockKind::Elemwise => { state.register( potential_from_previous_blocks, potential_to_next_blocks, BlockKind::Full, ); } ReduceBlockKind::Reduce { ops_index, .. } => { if pos < *ops_index { state.register( potential_from_previous_blocks, potential_to_next_blocks, BlockKind::Full, ); } else if pos > *ops_index { state.register( potential_from_previous_blocks, potential_to_next_blocks, BlockKind::Single, ); } else { state.next_block(); } } } } state.next_block(); } // First one is never called. state.analyses.remove(0); Self { analyses: state.analyses, } } pub fn retrieve_next(&mut self) -> FullFuserAnalysis { let inputs = self.analyses.remove(0); FullFuserAnalysis { inputs } } } #[derive(Debug)] pub struct FullFuserAnalysis { /// The tensor received from a previous block. pub inputs: Vec<(TensorIr, usize)>, } #[derive(Default)] struct AnalysisState { /// That pool contains tensors that are available in the fuse-on-write part of a reduce, not /// broadcasted. available_from_previous_single: BTreeMap, /// That pool contains tensors that are available in the fuse-on-read of a reduce and the /// element-wise broadcasted part available_from_previous_full: BTreeMap, block_data: Vec<(TensorIr, usize)>, analyses: Vec>, current_full: Vec, current_single: Vec, } enum BlockKind { Full, Single, } impl AnalysisState { fn next_block(&mut self) { let block_pos = self.analyses.len(); let data = core::mem::take(&mut self.block_data); self.analyses.push(data); // Makes the current tensor reads available for the next block. for p in self.current_single.drain(..) { // We need to keep the earliest block position. self.available_from_previous_single .entry(p.id) .or_insert(block_pos); } for p in self.current_full.drain(..) { // We need to keep the earliest block position. self.available_from_previous_full .entry(p.id) .or_insert(block_pos); } } fn register<'a>( &mut self, potential_from_previous_blocks: impl Iterator, potential_to_next_blocks: impl Iterator, kind: BlockKind, ) { match kind { BlockKind::Full => { for potential in potential_from_previous_blocks { // We can't since it's not in the same scope. // // TODO: Find a way to merge multiple reduce loops. // // if let Some(block_pos) = self.available_from_previous_full.get(&potential.id) { // self.block_data.push((potential.clone(), *block_pos)); // } // We can since it's a broadcast. if let Some(block_pos) = self.available_from_previous_single.get(&potential.id) { self.block_data.push((potential.clone(), *block_pos)); } // Can reuse the read. self.current_full.push(potential.clone()); } for p in potential_to_next_blocks { self.current_full.push(p.clone()); } } BlockKind::Single => { for potential in potential_from_previous_blocks { if let Some(block_pos) = self.available_from_previous_single.get(&potential.id) { self.block_data.push((potential.clone(), *block_pos)); } // Can reuse the read. self.current_single.push(potential.clone()); } for p in potential_to_next_blocks { self.current_single.push(p.clone()); } } } } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/mod.rs ================================================ mod base; mod block; mod full; mod full_analyzer; pub use base::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/launch.rs ================================================ use crate::{ engine::{ codegen::ir::{FuseArg, FuseBlockConfig, GlobalArgsLaunch, RefLayout}, launch::runner::{TraceRunner, Vectorization}, }, optim::reduce_broadcasted::unit::{ ElemwiseFuseBlockLaunch, ReduceFuseBlockLaunch, reduce_kernel_broadcasted, }, }; use cubecl::{ Runtime, ir::{ElemType, FloatKind, StorageType}, prelude::*, server::LaunchError, }; use cubek::reduce::{ ReduceDtypes, VectorizationMode, components::instructions::ReduceOperationConfig, launch::RoutineStrategy, routines::{ BlueprintStrategy, GlobalReduceBlueprint, ReduceProblem, ReduceVectorSettings, Routine, unit::{UnitRoutine, UnitStrategy}, }, }; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ReduceBroadcastedFuseBlock { pub(crate) op: ReduceOperationConfig, pub(crate) input: FuseArg, pub(crate) output: FuseArg, } #[derive(new)] pub struct FusedReduceBroadcastedLaunch<'a> { blocks: &'a Vec, reduce_axis: usize, // TODO: Support multiple strategies. _strategy: RoutineStrategy, } impl Vectorization for FusedReduceBroadcastedLaunch<'_> {} impl TraceRunner for FusedReduceBroadcastedLaunch<'_> { type Error = LaunchError; fn run<'a>( &'a self, client: &'a ComputeClient, inputs: GlobalArgsLaunch, outputs: GlobalArgsLaunch, configs: &'a [FuseBlockConfig], ) -> Result<(), Self::Error> { let routine = UnitRoutine; let first_config = &configs[0]; let shape = match &first_config.ref_layout { RefLayout::Concrete(FuseArg::Output(..)) => { outputs.shape_ref(&first_config.ref_layout, first_config.rank) } _ => inputs.shape_ref(&first_config.ref_layout, first_config.rank), }; let vector_size = shape[self.reduce_axis]; let vector_count = shape.iter().product::() / vector_size; let address_type = inputs .required_address_type() .max(outputs.required_address_type()); let (blueprint, settings) = routine .prepare::( client, ReduceProblem { vector_size, vector_count, axis: self.reduce_axis, dtypes: ReduceDtypes { input: StorageType::Scalar(ElemType::Float(FloatKind::F32)), output: StorageType::Scalar(ElemType::Float(FloatKind::F32)), accumulation: StorageType::Scalar(ElemType::Float(FloatKind::F32)), }, address_type, }, ReduceVectorSettings { vectorization_mode: VectorizationMode::Parallel, vector_size_input: first_config.width, vector_size_output: 1, }, BlueprintStrategy::Inferred(UnitStrategy), ) .unwrap(); assert_eq!(blueprint.vectorization_mode, VectorizationMode::Parallel); let mut blocks = SequenceArg::new(); let mut index = 0; for block in self.blocks { let arg = ReduceFuseBlockLaunch::new( block.op, configs[index].clone(), configs[index + 1].clone(), block.input.clone(), block.output.clone(), match blueprint.global { GlobalReduceBlueprint::Unit(bpt) => bpt, _ => panic!(), }, ); index += 2; blocks.push(arg); } let block_end = match configs.len() > index { true => ComptimeOptionArgs::Some(ElemwiseFuseBlockLaunch::new( configs.last().cloned().unwrap(), )), false => ComptimeOptionArgs::None, }; // TODO: Ensure parallel is selected. unsafe { reduce_kernel_broadcasted::launch_unchecked::( client, settings.cube_count, settings.cube_dim, settings.address_type, inputs, outputs, self.reduce_axis, blocks, block_end, ); } Ok(()) } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/mod.rs ================================================ mod fuser; mod optimization; pub(crate) mod launch; pub(crate) mod tune; pub(crate) mod unit; pub use fuser::*; pub use optimization::*; ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/optimization.rs ================================================ #[cfg(feature = "autotune")] use crate::optim::reduce::tune::fused_reduce_autotune; use crate::{ CubeFusionHandle, FallbackOperation, engine::{ launch::FuseTraceLauncher, trace::{FuseTrace, TraceError, TuneOutput}, }, optim::{ elemwise::{ElemwiseOptimization, ElemwiseOptimizationState}, reduce::{ReduceOptimizationInfo, ReduceOptimizationState, ReduceOptimizationTuneArg}, reduce_broadcasted::{ launch::{FusedReduceBroadcastedLaunch, ReduceBroadcastedFuseBlock}, tune::fused_broadcasted_reduce_autotune, }, }, }; use burn_fusion::stream::Context; use cubecl::{Runtime, prelude::*}; use cubek::reduce::launch::RoutineStrategy; use serde::{Deserialize, Serialize}; use std::sync::Arc; pub struct ReduceBroadcastedOptimization { pub(crate) info: Arc>, pub(crate) num_ops: usize, } pub(crate) struct ReduceBroadcastedOptimizationInfo { pub(crate) fallbacks: Vec>, pub(crate) broadcasted: Arc, } #[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct ReduceBroadcastedInfo { pub(crate) blocks: Vec, pub(crate) trace: FuseTrace, pub(crate) reduce_axis: usize, } pub(crate) enum ReduceBlockOptimInfo { Reduce(Arc>), Elemwise(Arc>), } impl ReduceBlockOptimInfo { pub fn from_state(device: &R::Device, state: ReduceBlockState) -> Self { match state { ReduceBlockState::Reduce(state) => { Self::Reduce(Arc::new(ReduceOptimizationInfo::from_state(device, state))) } ReduceBlockState::Elemwise(state) => { Self::Elemwise(Arc::new(ElemwiseOptimization::from_state(device, state))) } } } pub fn to_state(&self) -> ReduceBlockState { match self { Self::Reduce(info) => ReduceBlockState::Reduce(info.to_state()), Self::Elemwise(info) => ReduceBlockState::Elemwise(info.to_state()), } } } pub(crate) struct ReduceBroadcastedOptimizationTuneArg { pub(crate) fallbacks: Vec>, pub(crate) broadcasted: Arc, pub(crate) client: ComputeClient, pub(crate) device: R::Device, } pub(crate) enum ReduceBlockOptimArg { Reduce(ReduceOptimizationTuneArg), Elemwise(Arc>), } impl ReduceBlockOptimArg { pub fn execute_fallback( &self, context: &mut Context<'_, CubeFusionHandle>, ) -> Option> { match self { ReduceBlockOptimArg::Reduce(reduce) => { #[cfg(feature = "autotune")] { fused_reduce_autotune::(reduce.clone(), context); None } #[cfg(not(feature = "autotune"))] Some(reduce.execute_fallback(context)) } ReduceBlockOptimArg::Elemwise(elem) => { elem.execute(context); None } } } } #[derive(Serialize, Deserialize, Debug)] pub struct ReduceBroadcastedOptimizationState { fallbacks: Vec, broadcasted: ReduceBroadcastedInfo, num_ops: usize, } #[derive(Serialize, Deserialize, Debug)] #[allow(clippy::large_enum_variant)] // Only for serialization. pub enum ReduceBlockState { Reduce(ReduceOptimizationState), Elemwise(ElemwiseOptimizationState), } impl ReduceBroadcastedOptimizationTuneArg { pub fn execute_fused( &self, context: &mut Context<'_, CubeFusionHandle>, strategy: RoutineStrategy, ) -> Result, TraceError> { let launch = FusedReduceBroadcastedLaunch::new( &self.broadcasted.blocks, self.broadcasted.reduce_axis, strategy, ); let launcher = FuseTraceLauncher::new(&self.broadcasted.trace, &launch); launcher .launch(&self.client, &self.device, context) .map_err(|err| TraceError::RunnerError(format!("{:?}", err))) } pub fn execute_fallback(&self, context: &mut Context<'_, CubeFusionHandle>) { for fallback in self.fallbacks.iter() { fallback.execute_fallback(context); } } } #[allow(clippy::too_many_arguments)] impl ReduceBroadcastedOptimization { /// Execute the optimization. pub fn execute( &mut self, context: &mut Context<'_, CubeFusionHandle>, fallback: impl Fn(usize) -> Box>, ) { let mut current_index = 0; let mut client = None; let mut device = None; let fallbacks = self .info .fallbacks .iter() .map(|info| { match info { ReduceBlockOptimInfo::Reduce(info) => { // The index of the fallback reduce is the number of ops fused as read. let fallback = fallback(current_index + info.len_read); client = Some(info.client.clone()); device = Some(info.device.clone()); let arg = ReduceOptimizationTuneArg { info: info.clone(), fallback: Arc::new(fallback), }; current_index += info.len; ReduceBlockOptimArg::Reduce(arg) } ReduceBlockOptimInfo::Elemwise(op) => ReduceBlockOptimArg::Elemwise(op.clone()), } }) .collect(); let arg = ReduceBroadcastedOptimizationTuneArg { fallbacks, client: client.unwrap(), device: device.unwrap(), broadcasted: self.info.broadcasted.clone(), }; #[cfg(feature = "autotune")] fused_broadcasted_reduce_autotune::(arg, context); #[cfg(not(feature = "autotune"))] arg.execute_fallback(context); } pub fn to_state(&self) -> ReduceBroadcastedOptimizationState { ReduceBroadcastedOptimizationState { fallbacks: self .info .fallbacks .iter() .map(|info| info.to_state()) .collect(), broadcasted: self.info.broadcasted.as_ref().clone(), num_ops: self.num_ops, } } pub fn from_state(device: &R::Device, state: ReduceBroadcastedOptimizationState) -> Self { Self { info: Arc::new(ReduceBroadcastedOptimizationInfo { fallbacks: state .fallbacks .into_iter() .map(|state| ReduceBlockOptimInfo::from_state(device, state)) .collect(), broadcasted: Arc::new(state.broadcasted), }), num_ops: state.num_ops, } } /// Returns the number of output buffers added by fusion. pub fn num_ops_fused(&self) -> usize { self.num_ops } } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/tune.rs ================================================ use super::optimization::ReduceBroadcastedOptimizationTuneArg; use crate::{ CubeFusionHandle, engine::trace::TuneOutput, optim::{reduce::ReduceOptimizationInfo, reduce_broadcasted::ReduceBlockOptimArg}, tune::{TuneContext, TuneInput}, }; use burn_fusion::stream::Context; use cubecl::{ AutotuneKey, CubeTuneId, Runtime, tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}, }; use cubek::reduce::{ launch::{RoutineStrategy, tune_key::ReduceAutotuneKey}, routines::{BlueprintStrategy, unit::UnitStrategy}, }; use serde::{Deserialize, Serialize}; /// Autotune key for fused broadcasted reduction operations. /// /// Captures the characteristics of the fusion (reads, writes, ops) to ensure /// the best kernel is selected for specific fused graph shapes. #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] pub struct FusedBroadcastedReduceAutotuneKey { reduce_key: ReduceAutotuneKey, #[autotune(anchor)] fuse_num_reads: usize, #[autotune(anchor)] fuse_num_writes: usize, #[autotune(anchor)] fuse_num_ops: usize, fuse_num_blocks: usize, } /// Executes the autotuning process for fused reduction operations. /// /// This function initializes a local tuner and attempts multiple strategies /// (fallback vs. unit strategy) to find the most efficient execution path. pub fn fused_broadcasted_reduce_autotune( arg: ReduceBroadcastedOptimizationTuneArg, context: &mut Context>, ) { static TUNER: LocalTuner = local_tuner!(); let tunables = TUNER.init(|| { const PRIORITY_MAX: i8 = 2; let mut set = TunableSet::new(create_key::, input_gen::); let group = TuneGroup::::new( "fused_reduce_broadcasted", |_key| PRIORITY_MAX, ); // Standard fallback implementation - guaranteed to work. set = set.with(Tunable::new( "fused_reduce_broadcasted_fallback", tune_fallback::, )); // Specialized unit strategy for fused reductions. set = set.with( Tunable::new("fused_reduce_broadcasted_unit", move |input| { tune_reduce::( input, &RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)), ) }) .group(&group, |_| PRIORITY_MAX), ); set }); TUNER.execute( &CubeTuneId::new(&arg.client, &arg.device), &arg.client.clone(), tunables, TuneInput::new(context, arg), ); } /// Generates the autotune key based on the current optimization context and trace blocks. pub(crate) fn create_key( input: &TuneInput>, ) -> FusedBroadcastedReduceAutotuneKey { let opt = input.optimization(); let context = match input.context() { TuneContext::Original(context) => context, TuneContext::Fork(_) => unreachable!("Forked context not supported for key generation"), }; // The fusion must start with a reduction block to be valid here. let info = match &opt.fallbacks[0] { ReduceBlockOptimArg::Reduce(reduce) => &reduce.info, ReduceBlockOptimArg::Elemwise(_) => { unreachable!("Fusion must start with a reduction block") } }; let key = generate_reduce_autotune_key(info, context); // Sum up complexity metrics across all blocks in the fused trace. let (mut num_reads, mut num_writes, mut num_ops) = (0, 0, 0); for block in opt.broadcasted.trace.blocks.iter() { num_reads += block.reads.len(); num_writes += block.writes.len(); num_ops += block.ops.len(); } FusedBroadcastedReduceAutotuneKey::new( key, num_reads, num_writes, num_ops, info.trace.blocks.len(), ) } /// Helper to generate the base reduction key (shapes, types, axes). fn generate_reduce_autotune_key( info: &ReduceOptimizationInfo, context: &Context>, ) -> ReduceAutotuneKey { let input = context.tensors.get(&info.reduce.op.input.id).unwrap(); let out = context.tensors.get(&info.reduce.op.out.id).unwrap(); let acc = info.reduce.acc.into_elem(); ReduceAutotuneKey::generate( input.dtype.into(), out.dtype.into(), acc, &input.shape, info.reduce.axis == input.shape.rank() - 1, // Is it the last dimension? info.reduce.axis, ) } /// Simple input generator that clones the input for the tuner. fn input_gen( _key: &FusedBroadcastedReduceAutotuneKey, input: &TuneInput>, ) -> TuneInput> { input.clone() } /// Executes a fused reduction using a specific routine strategy. fn tune_reduce( input: TuneInput>, strategy: &RoutineStrategy, ) -> Result, String> { let optimization = input.optimization(); match input.context() { TuneContext::Original(context) => optimization.execute_fused(context, strategy.clone()), TuneContext::Fork(mut context_owned) => { optimization.execute_fused(&mut context_owned.as_context(), strategy.clone()) } } .map_err(|e| format!("{e:?}")) } /// Executes the fallback implementation for the reduction. fn tune_fallback( input: TuneInput>, ) -> Result, String> { let optimization = input.optimization(); match input.context() { TuneContext::Original(context) => optimization.execute_fallback(context), TuneContext::Fork(mut context_owned) => { optimization.execute_fallback(&mut context_owned.as_context()) } }; // Fallback is often used as a baseline, returning unchecked output. Ok(TuneOutput::UnChecked(std::marker::PhantomData)) } ================================================ FILE: crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/unit.rs ================================================ use crate::{ engine::codegen::{ ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, multi_block_variables_init}, kernel::{fuse_on_write, init_locals}, }, optim::reduce::args::{FusedReduceArgs, FusedReduceInput, FusedReduceOutput}, }; use cubecl::{Runtime, define_size, prelude::*, std::tensor::r#virtual::VirtualTensor}; use cubek::reduce::{ ReduceInstruction, ReducePrecision, VectorizationMode, components::{ args::NumericLine, global::unit::GlobalFullUnitReduce, instructions::{ReduceOperation, ReduceOperationConfig}, }, init_tensors, routines::UnitReduceBlueprint, }; /// A configuration block for a reduction operation within a fused kernel. /// /// This struct holds all the compile-time information needed to perform a /// reduction, including the operation type (Sum, Max, etc.) and the layout /// configuration for both input and output. #[derive(CubeType, CubeLaunch, Clone)] pub struct ReduceFuseBlock { #[cube(comptime)] op: ReduceOperationConfig, #[cube(comptime)] config_input: FuseBlockConfig, #[cube(comptime)] config_output: FuseBlockConfig, #[cube(comptime)] input: FuseArg, #[cube(comptime)] output: FuseArg, #[cube(comptime)] blueprint: UnitReduceBlueprint, } /// A configuration block for an elementwise operation that follows a reduction. #[derive(CubeType, CubeLaunch, Clone)] pub struct ElemwiseFuseBlock { #[cube(comptime)] config: FuseBlockConfig, } /// The entry point for a broadcasted reduction kernel. /// /// This kernel initializes local variables for multiple reduction blocks and then /// executes the reduction sequence. /// /// # Arguments /// /// * `inputs` - Global arguments containing input tensor handles. /// * `outputs` - Global arguments containing output tensor handles. /// * `reduce_axis` - The dimension along which the reduction is performed. /// * `blocks` - A sequence of reduction operations to execute. /// * `block_end` - An optional elementwise block to execute after reductions are complete. #[cube(launch_unchecked, address_type = "dynamic")] pub fn reduce_kernel_broadcasted( inputs: &GlobalArgs, outputs: &mut GlobalArgs, reduce_axis: usize, blocks: Sequence, block_end: ComptimeOption, ) { #[unroll] for i in 0..blocks.len() { let block = blocks.index(i); multi_block_variables_init(&block.config_input, &mut outputs.variables); multi_block_variables_init(&block.config_output, &mut outputs.variables); } reduce_many(inputs, outputs, reduce_axis, blocks, block_end); } define_scalar!(In); define_scalar!(Acc); define_scalar!(Out); define_size!(InSize); define_size!(OutSize); /// Configures the precision polyfills for the reduction based on the block's `FuseType`. #[cube] fn set_polyfill_block(block: &ReduceFuseBlock) { let input_precision = comptime!(block.input.precision()); let output_precision = comptime!(block.output.precision()); let acc_precision = comptime!(match input_precision { FuseType::F64 => FuseType::F64, FuseType::F32 => FuseType::F32, FuseType::Flex32 => FuseType::F32, FuseType::F16 => FuseType::F32, FuseType::BF16 => FuseType::F32, FuseType::I64 => FuseType::I64, FuseType::I32 => FuseType::I32, FuseType::I16 => FuseType::I32, FuseType::I8 => FuseType::I32, FuseType::U64 => FuseType::U64, FuseType::U32 => FuseType::U32, FuseType::U16 => FuseType::U32, FuseType::U8 => FuseType::U32, }); set_polyfill::(comptime!( input_precision.into_type(block.config_input.width) )); set_polyfill::(comptime!( output_precision.into_type(block.config_output.width) )); set_polyfill::(comptime!(acc_precision.into_type(block.config_input.width))); } /// Internal logic for executing a sequence of reduction blocks followed by an optional /// trailing elementwise block. #[cube] #[allow(clippy::clone_on_copy)] fn reduce_many( inputs: &GlobalArgs, outputs: &mut GlobalArgs, reduce_axis: usize, blocks: Sequence, block_end: ComptimeOption, ) { let mut axis_size = 0; #[unroll] for i in 0..blocks.len() { let block = blocks.index(i); let input = FusedReduceInput { global: inputs.clone(), config: comptime!(block.config_input.clone()), arg: comptime!(block.input.clone()), }; let global = outputs.clone(); let config = comptime!(block.config_output.clone()); let arg = comptime!(block.output.clone()); let mut output = FusedReduceOutput { global, config, arg, }; set_polyfill_block(block); let (input, mut output) = init_tensors::(&input, &mut output); axis_size = reduce_step::<(In, InSize, Acc), (Out, OutSize), ReduceOperation>( &input, &mut output, reduce_axis, block.op, comptime!(block.blueprint.clone()), ); } #[comptime] if let ComptimeOption::Some(block) = block_end { let global_index = ABSOLUTE_POS; let width = block.config.width; let num_iter = axis_size / width; let size!(N) = width; for i in 0..num_iter { // Register block local inputs. let values = Registry::>::new(); let args = comptime![Vec::::new()]; let index = global_index * num_iter + i; let mut locals = init_locals(inputs, outputs, &block.config); fuse_on_write::( inputs, outputs, &mut locals, index, values, args, &block.config.clone(), ) } } } #[cube] /// Executes a single reduction step using a specified instruction and blueprint. /// /// Returns the size of the axis that was reduced. fn reduce_step>( input: &VirtualTensor, output: &mut VirtualTensor, reduce_axis: usize, #[comptime] config: I::Config, #[comptime] blueprint: UnitReduceBlueprint, ) -> usize { let inst = I::from_config(config); let axis_size = input.shape(reduce_axis); GlobalFullUnitReduce::execute::( input, output, reduce_axis, &inst, VectorizationMode::Parallel, comptime!(blueprint), ); axis_size } ================================================ FILE: crates/burn-cubecl-fusion/src/tune.rs ================================================ use crate::CubeFusionHandle; use burn_fusion::stream::{Context, ContextOwned}; use cubecl::Runtime; use std::sync::Arc; /// Fusion context used when tuning kernels. /// /// Either the original context is returned or a fork of the original. /// The fork is only given when performing autotuning, and not when actually performing the /// operation. pub enum TuneContext<'a, R: Runtime> { Original(&'a mut Context<'a, CubeFusionHandle>), Fork(Box>>), } /// Fusion input wrapper containing the context and the optimization. /// /// # Safety /// /// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions /// are made based on its behavior. pub struct TuneInput { context: UnsafeTuneContext, optimization: Arc, } /// Unsafe wrapper around the context. /// /// # Safety /// /// The wrapper removes the context lifetime. /// /// For it to be correct, the context must not be used after the invocation of the /// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are /// tuned using a cloned version of the input; therefore, a fork of the context will be used to find /// the best kernel to use, which can be async. enum UnsafeTuneContext { Original(*mut Context<'static, CubeFusionHandle>), Fork(Box>>), } unsafe impl Send for UnsafeTuneContext {} unsafe impl Send for TuneInput {} impl TuneInput { /// Create a new autotune input from the [context](Context) and an optimization. pub fn new(context: &mut Context>, optimization: O) -> Self { let context = UnsafeTuneContext::new(context); Self { context, optimization: Arc::new(optimization), } } /// Retrieve the [autotune context](TuneContext) for the current input. pub fn context(&self) -> TuneContext<'static, R> { self.context.get() } /// Retrieve the optimization for the current input. pub fn optimization(&self) -> &O { &self.optimization } } impl UnsafeTuneContext { fn new(context: &mut Context<'_, CubeFusionHandle>) -> Self { let ptr = core::ptr::from_mut(context); // It is necessary for the lifetime. #[allow(clippy::unnecessary_cast)] Self::Original(ptr as *mut Context<'static, _>) } fn get(&self) -> TuneContext<'static, R> { match self { UnsafeTuneContext::Original(ptr) => { TuneContext::Original(unsafe { ptr.as_mut().unwrap() }) } UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())), } } } impl Clone for TuneInput { fn clone(&self) -> Self { Self { context: self.context.clone(), optimization: self.optimization.clone(), } } } impl Clone for UnsafeTuneContext { fn clone(&self) -> Self { let context = match self { UnsafeTuneContext::Original(ptr) => { let context: &mut Context<'static, CubeFusionHandle> = unsafe { ptr.as_mut().unwrap() }; context.fork() } UnsafeTuneContext::Fork(context) => context.fork(), }; UnsafeTuneContext::Fork(Box::new(context)) } } ================================================ FILE: crates/burn-cuda/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "CUDA backend for the Burn framework" documentation = "https://docs.rs/burn-cuda" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu", "cuda"] license.workspace = true name = "burn-cuda" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" version.workspace = true [lints] workspace = true [features] autotune = ["burn-cubecl/autotune"] autotune-checks = ["burn-cubecl/autotune-checks"] default = ["std", "fusion", "autotune", "burn-cubecl/default", "cubecl/default"] doc = ["burn-cubecl/doc"] fusion = ["burn-fusion", "burn-cubecl/fusion"] std = ["burn-cubecl/std", "cubecl/std"] tracing = [ "burn-backend/tracing", "burn-cubecl/tracing", "burn-fusion?/tracing", "cubecl/tracing", ] [dependencies] burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false, features = [ "cubecl-cuda", ] } cubecl = { workspace = true, features = ["cuda"] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-cuda/README.md ================================================ # Burn CUDA Backend [Burn](https://github.com/tracel-ai/burn) CUDA backend [![Current Crates.io Version](https://img.shields.io/crates/v/burn-cuda.svg)](https://crates.io/crates/burn-cuda) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-cuda/blob/master/README.md) This crate provides a CUDA backend for [Burn](https://github.com/tracel-ai/burn) using the [cubecl](https://github.com/tracel-ai/cubecl.git) and [cudarc](https://github.com/coreylowman/cudarc.git) crates. ## Usage Example ```rust #[cfg(feature = "cuda")] mod cuda { use burn_autodiff::Autodiff; use burn_cuda::{Cuda, CudaDevice}; use mnist::training; pub fn run() { let device = CudaDevice::default(); training::run::>>(device); } } ``` ## Dependencies Requires CUDA 12.x to be installed and on the `PATH`. ================================================ FILE: crates/burn-cuda/src/lib.rs ================================================ #![cfg_attr(docsrs, feature(doc_cfg))] extern crate alloc; use burn_cubecl::CubeBackend; pub use cubecl::cuda::CudaDevice; use cubecl::cuda::CudaRuntime; #[cfg(not(feature = "fusion"))] pub type Cuda = CubeBackend; #[cfg(feature = "fusion")] pub type Cuda = burn_fusion::Fusion>; #[cfg(all(test, not(target_os = "macos")))] mod tests { use super::*; use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive}; use burn_cubecl::tensor::CubeTensor; #[test] fn should_support_dtypes() { type B = Cuda; let device = Default::default(); assert!(B::supports_dtype(&device, DType::F32)); assert!(B::supports_dtype(&device, DType::Flex32)); assert!(B::supports_dtype(&device, DType::F16)); assert!(B::supports_dtype(&device, DType::BF16)); assert!(B::supports_dtype(&device, DType::I64)); assert!(B::supports_dtype(&device, DType::I32)); assert!(B::supports_dtype(&device, DType::I16)); assert!(B::supports_dtype(&device, DType::I8)); assert!(B::supports_dtype(&device, DType::U64)); assert!(B::supports_dtype(&device, DType::U32)); assert!(B::supports_dtype(&device, DType::U16)); assert!(B::supports_dtype(&device, DType::U8)); assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native))); assert!(B::supports_dtype( &device, DType::QFloat(CubeTensor::::default_scheme()) )); // Currently not registered in supported types assert!(!B::supports_dtype(&device, DType::F64)); } } ================================================ FILE: crates/burn-dataset/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Library with simple dataset APIs for creating ML data pipelines" documentation = "https://docs.rs/burn-dataset" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-dataset" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dataset" version.workspace = true [lints] workspace = true [features] default = ["sqlite-bundled"] doc = ["default"] tracing = [ "burn-std/tracing", ] audio = ["hound"] builtin-sources = ["vision", "dep:tar", "nlp"] fake = ["dep:fake"] network = ["dep:burn-std"] sqlite = ["__sqlite-shared", "dep:rusqlite"] sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"] vision = ["dep:flate2", "dep:globwalk", "dep:image", "network"] nlp = ["dep:zip", "dep:encoding_rs"] # internal __sqlite-shared = [ "dep:r2d2", "dep:r2d2_sqlite", "dep:serde_rusqlite", "dep:image", "dep:gix-tempfile", ] dataframe = ["dep:polars", "dep:planus"] [dependencies] burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, features = [ "network", ] } csv = { workspace = true } derive-new = { workspace = true } dirs = { workspace = true } fake = { workspace = true, optional = true } flate2 = { workspace = true, optional = true } gix-tempfile = { workspace = true, optional = true } globwalk = { workspace = true, optional = true } hound = { workspace = true, optional = true } image = { workspace = true, optional = true } planus = { workspace = true, optional = true } encoding_rs = { workspace = true, optional = true } polars = { workspace = true, optional = true } r2d2 = { workspace = true, optional = true } r2d2_sqlite = { workspace = true, optional = true } rand = { workspace = true, features = ["std", "sys_rng"] } zip = { workspace = true, optional = true } rmp-serde = { workspace = true } rusqlite = { workspace = true, optional = true } sanitize-filename = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } serde_json = { workspace = true, features = ["std"] } serde_rusqlite = { workspace = true, optional = true } strum = { workspace = true } tar = { workspace = true, optional = true } tempfile = { workspace = true } thiserror = { workspace = true } [dev-dependencies] fake = { workspace = true } rayon = { workspace = true } rstest = { workspace = true } [package.metadata.cargo-udeps.ignore] normal = ["strum", "strum_macros"] [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-dataset/README.md ================================================ # Burn Dataset > [Burn](https://github.com/tracel-ai/burn) dataset library [![Current Crates.io Version](https://img.shields.io/crates/v/burn-dataset.svg)](https://crates.io/crates/burn-dataset) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-dataset/blob/master/README.md) The Burn Dataset library is designed to streamline your machine learning (ML) data pipeline creation process. It offers a variety of dataset implementations, transformation functions, and data sources. ## Feature Flags - `audio` - enables audio dataset (SpeechCommandsDataset). Run the following example to try it out: ```shell cargo run --example speech_commands --features audio ``` ================================================ FILE: crates/burn-dataset/examples/hf_dataset.rs ================================================ use burn_dataset::HuggingfaceDatasetLoader; use burn_dataset::SqliteDataset; use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] struct MnistItemRaw { pub _image_bytes: Vec, pub _label: usize, } fn main() { // There are some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main that contains a script, // In this cases you must enable trusting remote code execution if you want to use it. let _train_ds: SqliteDataset = HuggingfaceDatasetLoader::new("mnist") .with_trust_remote_code(true) .dataset("train") .unwrap(); // However not all dataset requires it https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main let _train_ds: SqliteDataset = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf") .dataset("train") .unwrap(); } ================================================ FILE: crates/burn-dataset/examples/speech_commands.rs ================================================ #[cfg(feature = "audio")] use burn_dataset::{Dataset, audio::SpeechCommandsDataset}; #[cfg(feature = "audio")] fn speech_command() { let index: usize = 4835; let test = SpeechCommandsDataset::test(); let item = test.get(index).unwrap(); println!("Item: {:?}", item); println!("Item Length: {:?}", item.audio_samples.len()); println!("Label: {}", item.label); assert_eq!(test.len(), 4890); assert_eq!(item.label.to_string(), "Yes"); assert_eq!(item.sample_rate, 16000); assert_eq!(item.audio_samples.len(), 16000); } fn main() { #[cfg(feature = "audio")] speech_command() } ================================================ FILE: crates/burn-dataset/src/audio/mod.rs ================================================ mod speech_commands; pub use speech_commands::*; ================================================ FILE: crates/burn-dataset/src/audio/speech_commands.rs ================================================ use crate::{ Dataset, HuggingfaceDatasetLoader, SqliteDataset, transform::{Mapper, MapperDataset}, }; use hound::WavReader; use serde::{Deserialize, Serialize}; use strum::{Display, EnumCount, FromRepr}; type MappedDataset = MapperDataset, ConvertSamples, SpeechItemRaw>; /// Enum representing speech command classes in the Speech Commands dataset. /// Class names are based on the Speech Commands dataset from Huggingface. /// See [speech_commands](https://huggingface.co/datasets/speech_commands) /// for more information. #[allow(missing_docs)] #[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)] pub enum SpeechCommandClass { // Target command words Yes = 0, No = 1, Up = 2, Down = 3, Left = 4, Right = 5, On = 6, Off = 7, Stop = 8, Go = 9, Zero = 10, One = 11, Two = 12, Three = 13, Four = 14, Five = 15, Six = 16, Seven = 17, Eight = 18, Nine = 19, // Non-target words that can be grouped into "Other" Bed = 20, Bird = 21, Cat = 22, Dog = 23, Happy = 24, House = 25, Marvin = 26, Sheila = 27, Tree = 28, Wow = 29, // Commands from v2 dataset, that can be grouped into "Other" Backward = 30, Forward = 31, Follow = 32, Learn = 33, Visual = 34, // Background noise Silence = 35, // Other miscellaneous words Other = 36, } /// Struct containing raw speech data returned from a database. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItemRaw { /// Audio file bytes. pub audio_bytes: Vec, /// Label index. pub label: usize, /// Indicates if the label is unknown. pub is_unknown: bool, } /// Speech item with audio samples and label. /// /// The audio samples are floats in the range [-1.0, 1.0]. /// The sample rate is in Hz. /// The label is the class index (see [SpeechCommandClass]). /// To convert to usize simply use `as usize`. To convert label to string use `.to_string()`. /// /// The original label is also stored in the `label_original` field for debugging and remapping if needed. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItem { /// Audio samples in the range [-1.0, 1.0]. pub audio_samples: Vec, /// The sample rate of the audio. pub sample_rate: usize, /// The label of the audio. pub label: SpeechCommandClass, } /// Speech Commands dataset from Huggingface v0.02. /// See [Speech Commands dataset](https://huggingface.co/datasets/speech_commands). /// /// The data is downloaded from Huggingface and stored in a SQLite database (3.0 GB). /// The dataset contains 99,720 audio samples of 2,607 people saying 35 different words. /// /// NOTE: The most samples are under 1 second long but there are some with pure background noise that /// need splitting into shorter segmants. /// /// The labels are 20 target words, silence and other words. /// /// The dataset is split into 3 parts: /// - train: 84,848 audio files /// - test: 4,890 audio files /// - validation: 9,982 audio files pub struct SpeechCommandsDataset { dataset: MappedDataset, } impl SpeechCommandsDataset { /// Create a new dataset with the given split. pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("speech_commands") .with_subset("v0.02") .dataset(split) .unwrap(); let dataset = MapperDataset::new(dataset, ConvertSamples); Self { dataset } } /// Create a new dataset with the train split. pub fn train() -> Self { Self::new("train") } /// Create a new dataset with the test split. pub fn test() -> Self { Self::new("test") } /// Create a new dataset with the validation split. pub fn validation() -> Self { Self::new("validation") } /// Returns the number of classes in the dataset pub fn num_classes() -> usize { SpeechCommandClass::COUNT } } impl Dataset for SpeechCommandsDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } /// Mapper converting audio bytes into audio samples and the label to enum class. struct ConvertSamples; impl ConvertSamples { /// Convert label to enum class. fn to_speechcommandclass(label: usize) -> SpeechCommandClass { SpeechCommandClass::from_repr(label).unwrap() } /// Convert audio bytes into samples of floats [-1.0, 1.0]. fn to_audiosamples(bytes: &Vec) -> (Vec, usize) { let reader = WavReader::new(bytes.as_slice()).unwrap(); let spec = reader.spec(); // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample). let max_value = (1 << (spec.bits_per_sample - 1)) as f32; // The sample rate of the audio. let sample_rate = spec.sample_rate as usize; // Convert the audio samples to floats [-1.0, 1.0]. let audio_samples: Vec = reader .into_samples::() .filter_map(Result::ok) .map(|sample| sample as f32 / max_value) .collect(); (audio_samples, sample_rate) } } impl Mapper for ConvertSamples { /// Convert audio bytes into samples of floats [-1.0, 1.0] /// and the label to enum class with the target word, other and silence classes. fn map(&self, item: &SpeechItemRaw) -> SpeechItem { let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes); // Convert the label to enum class, with the target words, other and silence classes. let label = Self::to_speechcommandclass(item.label); SpeechItem { audio_samples, sample_rate, label, } } } ================================================ FILE: crates/burn-dataset/src/dataset/base.rs ================================================ use std::sync::Arc; use crate::DatasetIterator; /// The dataset trait defines a basic collection of items with a predefined size. pub trait Dataset: Send + Sync { /// Gets the item at the given index. fn get(&self, index: usize) -> Option; /// Gets the number of items in the dataset. fn len(&self) -> usize; /// Checks if the dataset is empty. fn is_empty(&self) -> bool { self.len() == 0 } /// Returns an iterator over the dataset. fn iter(&self) -> DatasetIterator<'_, I> where Self: Sized, { DatasetIterator::new(self) } } impl Dataset for Arc where D: Dataset, { fn get(&self, index: usize) -> Option { self.as_ref().get(index) } fn len(&self) -> usize { self.as_ref().len() } } impl Dataset for Arc> { fn get(&self, index: usize) -> Option { self.as_ref().get(index) } fn len(&self) -> usize { self.as_ref().len() } } impl Dataset for Box where D: Dataset, { fn get(&self, index: usize) -> Option { self.as_ref().get(index) } fn len(&self) -> usize { self.as_ref().len() } } impl Dataset for Box> { fn get(&self, index: usize) -> Option { self.as_ref().get(index) } fn len(&self) -> usize { self.as_ref().len() } } ================================================ FILE: crates/burn-dataset/src/dataset/dataframe.rs ================================================ use std::marker::PhantomData; use crate::Dataset; use polars::frame::row::Row; use polars::prelude::*; use serde::de::DeserializeSeed; use serde::{ Deserialize, de::{self, DeserializeOwned, Deserializer, SeqAccess, Visitor}, forward_to_deserialize_any, }; /// Error type for DataframeDataset #[derive(thiserror::Error, Debug)] pub enum DataframeDatasetError { /// Error occurred during deserialization or other operations #[error("{0}")] Other(String), } impl de::Error for DataframeDatasetError { fn custom(msg: T) -> Self { DataframeDatasetError::Other(msg.to_string()) } } /// Dataset implementation for Polars DataFrame /// /// This struct provides a way to access data from a Polars DataFrame /// as if it were a Dataset of type I. pub struct DataframeDataset { df: DataFrame, len: usize, column_name_mapping: Vec, phantom: PhantomData, } impl DataframeDataset where I: Clone + Send + Sync + DeserializeOwned, { /// Create a new DataframeDataset from a Polars DataFrame /// /// # Arguments /// /// * `df` - A Polars DataFrame /// /// # Returns /// /// A Result containing the new DataframeDataset or a DataframeDatasetError pub fn new(df: DataFrame) -> Result { let len = df.height(); let field_names = extract_field_names::(); let column_name_mapping = field_names .iter() .map(|name| { df.schema() .try_get_full(name) .expect("Corresponding column should exist in the DataFrame") .0 }) .collect::>(); Ok(DataframeDataset { df, len, column_name_mapping, phantom: PhantomData, }) } } impl Dataset for DataframeDataset where I: Clone + Send + Sync + DeserializeOwned, { /// Get an item from the dataset at the specified index /// /// # Arguments /// /// * `index` - The index of the item to retrieve /// /// # Returns /// /// An Option containing the item if it exists, or None if it doesn't fn get(&self, index: usize) -> Option { let row = self.df.get_row(index).ok()?; let mut deserializer = RowDeserializer::new(&row, &self.column_name_mapping); I::deserialize(&mut deserializer).ok() } /// Get the length of the dataset fn len(&self) -> usize { self.len } /// Check if the dataset is empty fn is_empty(&self) -> bool { self.len == 0 } } /// A deserializer for Polars DataFrame rows struct RowDeserializer<'a> { row: &'a Row<'a>, column_name_mapping: &'a Vec, index: usize, } impl<'a> RowDeserializer<'a> { /// Create a new RowDeserializer /// /// # Arguments /// /// * `row` - A reference to a Polars DataFrame row /// * `column_name_mapping` - A reference to a vector mapping field names to column indices fn new(row: &'a Row, column_name_mapping: &'a Vec) -> RowDeserializer<'a> { RowDeserializer { row, column_name_mapping, index: 0, } } } impl<'de, 'a> Deserializer<'de> for &'a mut RowDeserializer<'a> { type Error = DataframeDatasetError; fn deserialize_any(self, visitor: V) -> Result where V: Visitor<'de>, { let i = self.column_name_mapping[self.index]; let value = &self.row.0[i]; match value { AnyValue::Null => visitor.visit_none(), AnyValue::Boolean(b) => visitor.visit_bool(*b), AnyValue::Int8(i) => visitor.visit_i8(*i), AnyValue::Int16(i) => visitor.visit_i16(*i), AnyValue::Int32(i) => visitor.visit_i32(*i), AnyValue::Int64(i) => visitor.visit_i64(*i), AnyValue::UInt8(i) => visitor.visit_u8(*i), AnyValue::UInt16(i) => visitor.visit_u16(*i), AnyValue::UInt32(i) => visitor.visit_u32(*i), AnyValue::UInt64(i) => visitor.visit_u64(*i), AnyValue::Float32(f) => visitor.visit_f32(*f), AnyValue::Float64(f) => visitor.visit_f64(*f), AnyValue::Date(i) => visitor.visit_i32(*i), AnyValue::String(s) => visitor.visit_string(s.to_string()), AnyValue::Binary(b) => { visitor.visit_seq(de::value::SeqDeserializer::new(b.iter().copied())) } AnyValue::Time(t) => visitor.visit_i64(*t), ty => Err(DataframeDatasetError::Other( format!("Unsupported type: {ty:?}").to_string(), )), } } fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { visitor.visit_seq(self) } forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct map enum identifier ignored_any } } impl<'de, 'a> SeqAccess<'de> for RowDeserializer<'a> { type Error = DataframeDatasetError; fn next_element_seed(&mut self, seed: T) -> Result, DataframeDatasetError> where T: DeserializeSeed<'de>, { if self.index >= self.row.0.len() { return Ok(None); } let mut deserializer = RowDeserializer { row: self.row, column_name_mapping: self.column_name_mapping, index: self.index, }; self.index += 1; seed.deserialize(&mut deserializer).map(Some) } } struct FieldExtractor { fields: Vec<&'static str>, } impl<'de> Deserializer<'de> for &mut FieldExtractor { type Error = de::value::Error; fn deserialize_any(self, _visitor: V) -> core::result::Result where V: Visitor<'de>, { Err(de::Error::custom("Field extractor")) } fn deserialize_struct( self, _name: &'static str, fields: &'static [&'static str], _visitor: V, ) -> core::result::Result where V: Visitor<'de>, { self.fields.extend_from_slice(fields); Err(de::Error::custom("Field extractor")) } forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct map enum identifier ignored_any } } /// Extract field names from a type T that implements Deserialize /// /// # Returns /// /// A vector of field names as static string slices fn extract_field_names<'de, T>() -> Vec<&'static str> where T: Deserialize<'de>, { let mut extractor = FieldExtractor { fields: Vec::new() }; let _ = T::deserialize(&mut extractor); extractor.fields } #[cfg(test)] mod tests { use polars::prelude::*; use serde::Deserialize; use super::*; #[derive(Clone, Debug, Deserialize, PartialEq)] struct TestData { int32: i32, bool: bool, float64: f64, string: String, int16: i16, uint32: u32, uint64: u64, float32: f32, int64: i64, int8: i8, binary: Vec, } fn create_test_dataframe() -> DataFrame { let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]); let s1 = Column::new("bool".into(), &[true, false, true]); let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]); let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]); let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]); let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]); let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]); let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]); let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]); let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]); let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]; let s13 = Column::new("binary".into(), binary_data); DataFrame::new_infer_height(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap() } #[test] fn test_dataframe_dataset_creation() { let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df); assert!(dataset.is_ok()); } #[test] fn test_dataframe_dataset_length() { let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df).unwrap(); assert_eq!(dataset.len(), 3); assert!(!dataset.is_empty()); } #[test] fn test_dataframe_dataset_get() { let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df).unwrap(); let expected_items = vec![ TestData { int32: 1, bool: true, float64: 1.1, string: "Boo".to_string(), int16: 1, uint32: 1, uint64: 1, float32: 1.1, int64: 1, int8: 1, binary: vec![1, 2, 3], }, TestData { int32: 2, bool: false, float64: 2.2, string: "Boo2".to_string(), int16: 2, uint32: 2, uint64: 2, float32: 2.2, int64: 2, int8: 2, binary: vec![4, 5, 6], }, TestData { int32: 3, bool: true, float64: 3.3, string: "Boo3".to_string(), int16: 3, uint32: 3, uint64: 3, float32: 3.3, int64: 3, int8: 3, binary: vec![7, 8, 9], }, ]; for (index, expected_item) in expected_items.iter().enumerate() { let item = dataset.get(index).unwrap(); assert_eq!(&item, expected_item); } } #[test] fn test_dataframe_dataset_out_of_bounds() { let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df).unwrap(); assert!(dataset.get(3).is_none()); } #[test] fn test_dataframe_dataset() { let df = create_test_dataframe(); let dataset: DataframeDataset = DataframeDataset::new(df).unwrap(); assert_eq!(dataset.len(), 3); assert!(!dataset.is_empty()); let item = dataset.get(1).unwrap(); assert_eq!( item, TestData { int32: 2, bool: false, float64: 2.2, string: "Boo2".to_string(), int16: 2, uint32: 2, uint64: 2, float32: 2.2, int64: 2, int8: 2, binary: vec![4, 5, 6], } ); let item = dataset.get(2).unwrap(); assert_eq!( item, TestData { int32: 3, bool: true, float64: 3.3, string: "Boo3".to_string(), int16: 3, uint32: 3, uint64: 3, float32: 3.3, int64: 3, int8: 3, binary: vec![7, 8, 9], } ); } #[test] #[should_panic = "Corresponding column should exist in the DataFrame: SchemaFieldNotFound(ErrString(\"non_existent\"))"] fn test_non_existing_struct_fields() { #[derive(Clone, Debug, Deserialize, PartialEq)] struct PartialTestData { int32: i32, bool: bool, non_existent: String, } let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df); assert!(dataset.is_err()); if let Err(e) = dataset { assert!(matches!(e, DataframeDatasetError::Other(_))); } } #[test] fn test_partial_table() { #[derive(Clone, Debug, Deserialize, PartialEq)] struct PartialTestData { int32: i32, bool: bool, string: String, } let df = create_test_dataframe(); let dataset = DataframeDataset::::new(df).unwrap(); assert_eq!(dataset.len(), 3); assert!(!dataset.is_empty()); let item = dataset.get(1).unwrap(); assert_eq!( item, PartialTestData { int32: 2, bool: false, string: "Boo2".to_string(), } ); let item = dataset.get(2).unwrap(); assert_eq!( item, PartialTestData { int32: 3, bool: true, string: "Boo3".to_string(), } ); } } ================================================ FILE: crates/burn-dataset/src/dataset/fake.rs ================================================ use crate::{Dataset, DatasetIterator, InMemDataset}; use fake::{Dummy, Fake, Faker}; /// Dataset filled with fake items generated from the [fake](fake) crate. pub struct FakeDataset { dataset: InMemDataset, } impl> FakeDataset { /// Create a new fake dataset with the given size. pub fn new(size: usize) -> Self { let mut items = Vec::with_capacity(size); for _ in 0..size { items.push(Faker.fake()); } let dataset = InMemDataset::new(items); Self { dataset } } } impl Dataset for FakeDataset { fn iter(&self) -> DatasetIterator<'_, I> { DatasetIterator::new(self) } fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } fn is_empty(&self) -> bool { self.dataset.is_empty() } } ================================================ FILE: crates/burn-dataset/src/dataset/in_memory.rs ================================================ use std::{ fs::File, io::{BufRead, BufReader}, path::Path, }; use serde::de::DeserializeOwned; use crate::Dataset; /// Dataset where all items are stored in ram. pub struct InMemDataset { items: Vec, } impl InMemDataset { /// Creates a new in memory dataset from the given items. pub fn new(items: Vec) -> Self { InMemDataset { items } } } impl Dataset for InMemDataset where I: Clone + Send + Sync, { fn get(&self, index: usize) -> Option { self.items.get(index).cloned() } fn len(&self) -> usize { self.items.len() } } impl InMemDataset where I: Clone + DeserializeOwned, { /// Create from a dataset. All items are loaded in memory. pub fn from_dataset(dataset: &impl Dataset) -> Self { let items: Vec = dataset.iter().collect(); Self::new(items) } /// Create from a json rows file (one json per line). /// /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html) pub fn from_json_rows>(path: P) -> Result { let file = File::open(path)?; let reader = BufReader::new(file); let mut items = Vec::new(); for line in reader.lines() { let item = serde_json::from_str(line.unwrap().as_str()).unwrap(); items.push(item); } let dataset = Self::new(items); Ok(dataset) } /// Create from a csv file. /// /// The provided `csv::ReaderBuilder` can be configured to fit your csv format. /// /// The supported field types are: String, integer, float, and bool. /// /// See: /// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde) /// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records) pub fn from_csv>( path: P, builder: &csv::ReaderBuilder, ) -> Result { let mut rdr = builder.from_path(path)?; let mut items = Vec::new(); for result in rdr.deserialize() { let item: I = result?; items.push(item); } let dataset = Self::new(items); Ok(dataset) } } #[cfg(test)] mod tests { use super::*; use crate::{SqliteDataset, test_data}; use rstest::{fixture, rstest}; use serde::{Deserialize, Serialize}; const DB_FILE: &str = "tests/data/sqlite-dataset.db"; const JSON_FILE: &str = "tests/data/dataset.json"; const CSV_FILE: &str = "tests/data/dataset.csv"; const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv"; type SqlDs = SqliteDataset; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Sample { column_str: String, column_bytes: Vec, column_int: i64, column_bool: bool, column_float: f64, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct SampleCsv { column_str: String, column_int: i64, column_bool: bool, column_float: f64, } #[fixture] fn train_dataset() -> SqlDs { SqliteDataset::from_db_file(DB_FILE, "train").unwrap() } #[rstest] pub fn from_dataset(train_dataset: SqlDs) { let dataset = InMemDataset::from_dataset(&train_dataset); let non_existing_record_index: usize = 10; let record_index: usize = 0; assert_eq!(train_dataset.get(non_existing_record_index), None); assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1"); } #[test] pub fn from_json_rows() { let dataset = InMemDataset::::from_json_rows(JSON_FILE).unwrap(); let non_existing_record_index: usize = 10; let record_index: usize = 1; assert_eq!(dataset.get(non_existing_record_index), None); assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); assert!(!dataset.get(record_index).unwrap().column_bool); } #[test] pub fn from_csv_rows() { let rdr = csv::ReaderBuilder::new(); let dataset = InMemDataset::::from_csv(CSV_FILE, &rdr).unwrap(); let non_existing_record_index: usize = 10; let record_index: usize = 1; assert_eq!(dataset.get(non_existing_record_index), None); assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); assert_eq!(dataset.get(record_index).unwrap().column_int, 1); assert!(!dataset.get(record_index).unwrap().column_bool); assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); } #[test] pub fn from_csv_rows_fmt() { let mut rdr = csv::ReaderBuilder::new(); let rdr = rdr.delimiter(b' ').has_headers(false); let dataset = InMemDataset::::from_csv(CSV_FMT_FILE, rdr).unwrap(); let non_existing_record_index: usize = 10; let record_index: usize = 1; assert_eq!(dataset.get(non_existing_record_index), None); assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2"); assert_eq!(dataset.get(record_index).unwrap().column_int, 1); assert!(!dataset.get(record_index).unwrap().column_bool); assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0); } #[test] pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() { let items_original = test_data::string_items(); let dataset = InMemDataset::new(items_original.clone()); let items: Vec = dataset.iter().collect(); assert_eq!(items_original, items); } } ================================================ FILE: crates/burn-dataset/src/dataset/iterator.rs ================================================ use crate::dataset::Dataset; use std::iter::Iterator; /// Dataset iterator. pub struct DatasetIterator<'a, I> { current: usize, dataset: &'a dyn Dataset, } impl<'a, I> DatasetIterator<'a, I> { /// Creates a new dataset iterator. pub fn new(dataset: &'a D) -> Self where D: Dataset, { DatasetIterator { current: 0, dataset, } } } impl Iterator for DatasetIterator<'_, I> { type Item = I; fn next(&mut self) -> Option { let item = self.dataset.get(self.current); self.current += 1; item } } ================================================ FILE: crates/burn-dataset/src/dataset/mod.rs ================================================ mod base; mod in_memory; mod iterator; pub use base::*; pub use in_memory::*; pub use iterator::*; #[cfg(any(test, feature = "fake"))] mod fake; #[cfg(any(test, feature = "fake"))] pub use self::fake::*; #[cfg(feature = "dataframe")] mod dataframe; #[cfg(feature = "dataframe")] pub use dataframe::*; #[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))] pub use sqlite::*; #[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))] mod sqlite; ================================================ FILE: crates/burn-dataset/src/dataset/sqlite.rs ================================================ use std::{ collections::HashSet, fs, io, marker::PhantomData, path::{Path, PathBuf}, sync::{Arc, RwLock}, }; use crate::Dataset; use gix_tempfile::{ AutoRemove, ContainingDirectory, Handle, handle::{Writable, persist}, }; use r2d2::{Pool, PooledConnection}; use r2d2_sqlite::{ SqliteConnectionManager, rusqlite::{OpenFlags, OptionalExtension}, }; use sanitize_filename::sanitize; use serde::{Serialize, de::DeserializeOwned}; use serde_rusqlite::{columns_from_statement, from_row_with_columns}; /// Result type for the sqlite dataset. pub type Result = core::result::Result; /// Sqlite dataset error. #[derive(thiserror::Error, Debug)] pub enum SqliteDatasetError { /// IO related error. #[error("IO error: {0}")] Io(#[from] io::Error), /// Sql related error. #[error("Sql error: {0}")] Sql(#[from] serde_rusqlite::rusqlite::Error), /// Serde related error. #[error("Serde error: {0}")] Serde(#[from] rmp_serde::encode::Error), /// The database file already exists error. #[error("Overwrite flag is set to false and the database file already exists: {0}")] FileExists(PathBuf), /// Error when creating the connection pool. #[error("Failed to create connection pool: {0}")] ConnectionPool(#[from] r2d2::Error), /// Error when persisting the temporary database file. #[error("Could not persist the temporary database file: {0}")] PersistDbFile(#[from] persist::Error), /// Any other error. #[error("{0}")] Other(&'static str), } impl From<&'static str> for SqliteDatasetError { fn from(s: &'static str) -> Self { SqliteDatasetError::Other(s) } } /// This struct represents a dataset where all items are stored in an SQLite database. /// Each instance of this struct corresponds to a specific table within the SQLite database, /// and allows for interaction with the data stored in the table in a structured and typed manner. /// /// The SQLite database must contain a table with the same name as the `split` field. This table should /// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id` /// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1. /// /// Table columns can be represented in two ways: /// /// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table /// should match the field names of the `I` struct. The field names can be a subset of column names and /// can be in any order. /// /// For the supported field types, refer to: /// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite) /// - [SQLite data types](https://www.sqlite.org/datatype3.html) /// /// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table /// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields /// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using /// [MessagePack](https://msgpack.org/). /// /// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate /// method to read the data from the table. #[derive(Debug)] pub struct SqliteDataset { db_file: PathBuf, split: String, conn_pool: Pool, columns: Vec, len: usize, select_statement: String, row_serialized: bool, phantom: PhantomData, } impl SqliteDataset { /// Initializes a `SqliteDataset` from a SQLite database file and a split name. pub fn from_db_file>(db_file: P, split: &str) -> Result { // Create a connection pool let conn_pool = create_conn_pool(&db_file, false)?; // Determine how the table is stored let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?; // Create a select statement and save it let select_statement = if row_serialized { format!("select item from {split} where row_id = ?") } else { format!("select * from {split} where row_id = ?") }; // Save the column names and the number of rows let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?; Ok(SqliteDataset { db_file: db_file.as_ref().to_path_buf(), split: split.to_string(), conn_pool, columns, len, select_statement, row_serialized, phantom: PhantomData, }) } /// Returns true if table has two columns: row_id (integer) and item (blob). /// /// This is used to determine if the table is row serialized or not. fn check_if_row_serialized( conn_pool: &Pool, split: &str, ) -> Result { // This struct is used to store the column name and type struct Column { name: String, ty: String, } const COLUMN_NAME: usize = 1; const COLUMN_TYPE: usize = 2; let sql_statement = format!("PRAGMA table_info({split})"); let conn = conn_pool.get()?; let mut stmt = conn.prepare(sql_statement.as_str())?; let column_iter = stmt.query_map([], |row| { Ok(Column { name: row .get::(COLUMN_NAME) .unwrap() .to_lowercase(), ty: row .get::(COLUMN_TYPE) .unwrap() .to_lowercase(), }) })?; let mut columns: Vec = vec![]; for column in column_iter { columns.push(column?); } if columns.len() != 2 { Ok(false) } else { // Check if the column names and types match the expected values Ok(columns[0].name == "row_id" && columns[0].ty == "integer" && columns[1].name == "item" && columns[1].ty == "blob") } } /// Get the database file name. pub fn db_file(&self) -> PathBuf { self.db_file.clone() } /// Get the split name. pub fn split(&self) -> &str { self.split.as_str() } } impl Dataset for SqliteDataset where I: Clone + Send + Sync + DeserializeOwned, { /// Get an item from the dataset. fn get(&self, index: usize) -> Option { // Row ids start with 1 (one) and index starts with 0 (zero) let row_id = index + 1; // Get a connection from the pool let connection = self.conn_pool.get().unwrap(); let mut statement = connection.prepare(self.select_statement.as_str()).unwrap(); if self.row_serialized { // Fetch with a single column `item` and deserialize it with MessagePack statement .query_row([row_id], |row| { // Deserialize item (blob) with MessagePack (rmp-serde) Ok( rmp_serde::from_slice::(row.get_ref(0).unwrap().as_blob().unwrap()) .unwrap(), ) }) .optional() //Converts Error (not found) to None .unwrap() } else { // Fetch a row with multiple columns and deserialize it serde_rusqlite statement .query_row([row_id], |row| { // Deserialize the row with serde_rusqlite Ok(from_row_with_columns::(row, &self.columns).unwrap()) }) .optional() //Converts Error (not found) to None .unwrap() } } /// Return the number of rows in the dataset. fn len(&self) -> usize { self.len } } /// Fetch the column names and the number of rows from the database. fn fetch_columns_and_len( conn_pool: &Pool, select_statement: &str, split: &str, ) -> Result<(Vec, usize)> { // Save the column names let connection = conn_pool.get()?; let statement = connection.prepare(select_statement)?; let columns = columns_from_statement(&statement); // Count the number of rows and save it as len // // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables. // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id, // which corresponds to the number of rows in the table. // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps. // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index. let mut statement = connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?; let len = statement.query_row([], |row| { let len: usize = row.get(0)?; Ok(len) })?; Ok((columns, len)) } /// Helper function to create a connection pool fn create_conn_pool>( db_file: P, write: bool, ) -> Result> { let sqlite_flags = if write { OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE } else { OpenFlags::SQLITE_OPEN_READ_ONLY }; let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags); Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool) } /// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets. /// It consists of an optional name, a database file path, and a base directory for storage. #[derive(Clone, Debug)] pub struct SqliteDatasetStorage { name: Option, db_file: Option, base_dir: Option, } impl SqliteDatasetStorage { /// Creates a new instance of `SqliteDatasetStorage` using a dataset name. /// /// # Arguments /// /// * `name` - A string slice that holds the name of the dataset. pub fn from_name(name: &str) -> Self { SqliteDatasetStorage { name: Some(name.to_string()), db_file: None, base_dir: None, } } /// Creates a new instance of `SqliteDatasetStorage` using a database file path. /// /// # Arguments /// /// * `db_file` - A reference to the Path that represents the database file path. pub fn from_file>(db_file: P) -> Self { SqliteDatasetStorage { name: None, db_file: Some(db_file.as_ref().to_path_buf()), base_dir: None, } } /// Sets the base directory for storing the dataset. /// /// # Arguments /// /// * `base_dir` - A string slice that represents the base directory. pub fn with_base_dir>(mut self, base_dir: P) -> Self { self.base_dir = Some(base_dir.as_ref().to_path_buf()); self } /// Checks if the database file exists in the given path. /// /// # Returns /// /// * A boolean value indicating whether the file exists or not. pub fn exists(&self) -> bool { self.db_file().exists() } /// Fetches the database file path. /// /// # Returns /// /// * A `PathBuf` instance representing the file path. pub fn db_file(&self) -> PathBuf { match &self.db_file { Some(db_file) => db_file.clone(), None => { let name = sanitize(self.name.as_ref().expect("Name is not set")); Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db")) } } } /// Determines the base directory for storing the dataset. /// /// # Arguments /// /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory. /// /// # Returns /// /// * A `PathBuf` instance representing the base directory. pub fn base_dir(base_dir: Option) -> PathBuf { match base_dir { Some(base_dir) => base_dir, None => dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset"), } } /// Provides a writer instance for the SQLite dataset. /// /// # Arguments /// /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. /// /// # Returns /// /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. pub fn writer(&self, overwrite: bool) -> Result> where I: Clone + Send + Sync + Serialize + DeserializeOwned, { SqliteDatasetWriter::new(self.db_file(), overwrite) } /// Provides a reader instance for the SQLite dataset. /// /// # Arguments /// /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test"). /// /// # Returns /// /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise. pub fn reader(&self, split: &str) -> Result> where I: Clone + Send + Sync + Serialize + DeserializeOwned, { if !self.exists() { panic!("The database file does not exist"); } SqliteDataset::from_db_file(self.db_file(), split) } } /// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets. /// It retains the current writer's state and its database connection. /// /// Being thread-safe, this writer can be concurrently used across multiple threads. /// /// Typical applications include: /// /// - Generation of a new dataset /// - Storage of preprocessed data or metadata /// - Enlargement of a dataset's item count post preprocessing #[derive(Debug)] pub struct SqliteDatasetWriter { db_file: PathBuf, db_file_tmp: Option>, splits: Arc>>, overwrite: bool, conn_pool: Option>, is_completed: Arc>, phantom: PhantomData, } impl SqliteDatasetWriter where I: Clone + Send + Sync + Serialize + DeserializeOwned, { /// Creates a new instance of `SqliteDatasetWriter`. /// /// # Arguments /// /// * `db_file` - A reference to the Path that represents the database file path. /// * `overwrite` - A boolean indicating if the existing database file should be overwritten. /// /// # Returns /// /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise. pub fn new>(db_file: P, overwrite: bool) -> Result { let writer = Self { db_file: db_file.as_ref().to_path_buf(), db_file_tmp: None, splits: Arc::new(RwLock::new(HashSet::new())), overwrite, conn_pool: None, is_completed: Arc::new(RwLock::new(false)), phantom: PhantomData, }; writer.init() } /// Initializes the dataset writer by creating the database file, tables, and connection pool. /// /// # Returns /// /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise. fn init(mut self) -> Result { // Remove the db file if it already exists if self.db_file.exists() { if self.overwrite { fs::remove_file(&self.db_file)?; } else { return Err(SqliteDatasetError::FileExists(self.db_file)); } } // Create the database file directory if it does not exist let db_file_dir = self .db_file .parent() .ok_or("Unable to get parent directory")?; if !db_file_dir.exists() { fs::create_dir_all(db_file_dir)?; } // Create a temp database file name as {base_dir}/{name}.db.tmp let mut db_file_tmp = self.db_file.clone(); db_file_tmp.set_extension("db.tmp"); if db_file_tmp.exists() { fs::remove_file(&db_file_tmp)?; } // Create the temp database file and wrap it with a gix_tempfile::Handle // This will ensure that the temp file is deleted when the writer is dropped // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this) gix_tempfile::signal::setup(Default::default()); self.db_file_tmp = Some(gix_tempfile::writable_at( &db_file_tmp, ContainingDirectory::Exists, AutoRemove::Tempfile, )?); let conn_pool = create_conn_pool(db_file_tmp, true)?; self.conn_pool = Some(conn_pool); Ok(self) } /// Serializes and writes an item to the database. The item is written to the table for the /// specified split. If the table does not exist, it is created. If the table exists, the item /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/) /// /// # Arguments /// /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test"). /// * `item` - A reference to the item to be written to the database. /// /// # Returns /// /// * A `Result` containing the index of the inserted row if successful, an error otherwise. pub fn write(&self, split: &str, item: &I) -> Result { // Acquire the read lock (wont't block other reads) let is_completed = self.is_completed.read().unwrap(); // If the writer is completed, return an error if *is_completed { return Err(SqliteDatasetError::Other( "Cannot save to a completed dataset writer", )); } // create the table for the split if it does not exist if !self.splits.read().unwrap().contains(split) { self.create_table(split)?; } // Get a connection from the pool let conn_pool = self.conn_pool.as_ref().unwrap(); let conn = conn_pool.get()?; // Serialize the item using MessagePack let serialized_item = rmp_serde::to_vec(item)?; // Turn off the synchronous and journal mode for speed up // We are sacrificing durability for speed but it's okay because // we always recreate the dataset if it is not completed. pragma_update_with_error_handling(&conn, "synchronous", "OFF")?; pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?; // Insert the serialized item into the database let insert_statement = format!("insert into {split} (item) values (?)"); conn.execute(insert_statement.as_str(), [serialized_item])?; // Get the primary key of the last inserted row and convert to index (row_id-1) let index = (conn.last_insert_rowid() - 1) as usize; Ok(index) } /// Marks the dataset as completed and persists the temporary database file. pub fn set_completed(&mut self) -> Result<()> { let mut is_completed = self.is_completed.write().unwrap(); // Force close the connection pool // This is required on Windows platform where the connection pool prevents // from persisting the db by renaming the temp file. if let Some(pool) = self.conn_pool.take() { std::mem::drop(pool); } // Rename the database file from tmp to db let _file_result = self .db_file_tmp .take() // take ownership of the temporary file and set to None .unwrap() // unwrap the temporary file .persist(&self.db_file)? .ok_or("Unable to persist the database file")?; *is_completed = true; Ok(()) } /// Creates table for the data split. /// /// Note: call is idempotent and thread-safe. /// /// # Arguments /// /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test"). /// /// # Returns /// /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise. /// /// TODO (@antimora): add support creating a table with columns corresponding to the item fields fn create_table(&self, split: &str) -> Result<()> { // Check if the split already exists if self.splits.read().unwrap().contains(split) { return Ok(()); } let conn_pool = self.conn_pool.as_ref().unwrap(); let connection = conn_pool.get()?; let create_table_statement = format!( "create table if not exists {split} (row_id integer primary key autoincrement not \ null, item blob not null)" ); connection.execute(create_table_statement.as_str(), [])?; // Add the split to the splits self.splits.write().unwrap().insert(split.to_string()); Ok(()) } } /// Runs a pragma update and ignores the `ExecuteReturnedResults` error. /// /// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error /// and can be ignored. This function runs the pragma update and ignores the error if it is /// `ExecuteReturnedResults`. fn pragma_update_with_error_handling( conn: &PooledConnection, setting: &str, value: &str, ) -> Result<()> { let result = conn.pragma_update(None, setting, value); if let Err(error) = result && error != rusqlite::Error::ExecuteReturnedResults { return Err(SqliteDatasetError::Sql(error)); } Ok(()) } #[cfg(test)] mod tests { use rayon::prelude::*; use rstest::{fixture, rstest}; use serde::{Deserialize, Serialize}; use tempfile::{NamedTempFile, TempDir, tempdir}; use super::*; type SqlDs = SqliteDataset; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Sample { column_str: String, column_bytes: Vec, column_int: i64, column_bool: bool, column_float: f64, } #[fixture] fn train_dataset() -> SqlDs { SqliteDataset::::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap() } #[rstest] pub fn len(train_dataset: SqlDs) { assert_eq!(train_dataset.len(), 2); } #[rstest] pub fn get_some(train_dataset: SqlDs) { let item = train_dataset.get(0).unwrap(); assert_eq!(item.column_str, "HI1"); assert_eq!(item.column_bytes, vec![55, 231, 159]); assert_eq!(item.column_int, 1); assert!(item.column_bool); assert_eq!(item.column_float, 1.0); } #[rstest] pub fn get_none(train_dataset: SqlDs) { assert_eq!(train_dataset.get(10), None); } #[rstest] pub fn multi_thread(train_dataset: SqlDs) { let indices: Vec = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1]; let results: Vec> = indices.par_iter().map(|&i| train_dataset.get(i)).collect(); let mut match_count = 0; for (_index, result) in indices.iter().zip(results.iter()) { if let Some(_val) = result { match_count += 1 } } assert_eq!(match_count, 5); } #[test] fn sqlite_dataset_storage() { // Test with non-existing file let storage = SqliteDatasetStorage::from_file("non-existing.db"); assert!(!storage.exists()); // Test with non-existing name let storage = SqliteDatasetStorage::from_name("non-existing.db"); assert!(!storage.exists()); // Test with existing file let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db"); assert!(storage.exists()); let result = storage.reader::("train"); assert!(result.is_ok()); let train = result.unwrap(); assert_eq!(train.len(), 2); // Test get writer let temp_file = NamedTempFile::new().unwrap(); let storage = SqliteDatasetStorage::from_file(temp_file.path()); assert!(storage.exists()); let result = storage.writer::(true); assert!(result.is_ok()); } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Complex { column_str: String, column_bytes: Vec, column_int: i64, column_bool: bool, column_float: f64, column_complex: Vec>>, } /// Create a temporary directory. #[fixture] fn tmp_dir() -> TempDir { // Create a TempDir. This object will be automatically // deleted when it goes out of scope. tempdir().unwrap() } type Writer = SqliteDatasetWriter; /// Create a SqliteDatasetWriter with a temporary directory. /// Make sure to return the temporary directory so that it is not deleted. #[fixture] fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) { let temp_dir_str = tmp_dir.path(); let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str); let overwrite = true; let result = storage.writer::(overwrite); assert!(result.is_ok()); let writer = result.unwrap(); (writer, tmp_dir) } #[test] fn test_new() { // Test that the constructor works with overwrite = true let test_path = NamedTempFile::new().unwrap(); let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); assert!(!test_path.path().exists()); // Test that the constructor works with overwrite = false let test_path = NamedTempFile::new().unwrap(); let result = SqliteDatasetWriter::::new(&test_path, false); assert!(result.is_err()); // Test that the constructor works with no existing file let temp = NamedTempFile::new().unwrap(); let test_path = temp.path().to_path_buf(); assert!(temp.close().is_ok()); assert!(!test_path.exists()); let _writer = SqliteDatasetWriter::::new(&test_path, true).unwrap(); assert!(!test_path.exists()); } #[rstest] pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) { // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) let (writer, _tmp_dir) = writer_fixture; assert!(writer.overwrite); assert!(!writer.db_file.exists()); let new_item = Complex { column_str: "HI1".to_string(), column_bytes: vec![1_u8, 2, 3], column_int: 0, column_bool: true, column_float: 1.0, column_complex: vec![vec![vec![[1, 23_u8, 3]]]], }; let index = writer.write("train", &new_item).unwrap(); assert_eq!(index, 0); let mut writer = writer; writer.set_completed().expect("Failed to set completed"); assert!(writer.db_file.exists()); assert!(writer.db_file_tmp.is_none()); let result = writer.write("train", &new_item); // Should fail because the writer is completed assert!(result.is_err()); let dataset = SqliteDataset::::from_db_file(writer.db_file, "train").unwrap(); let fetched_item = dataset.get(0).unwrap(); assert_eq!(fetched_item, new_item); assert_eq!(dataset.len(), 1); } #[rstest] pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) { // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope) let (writer, _tmp_dir) = writer_fixture; let writer = Arc::new(writer); let record_count = 20; let splits = ["train", "test"]; (0..record_count).into_par_iter().for_each(|index: i64| { let thread_id: std::thread::ThreadId = std::thread::current().id(); let sample = Complex { column_str: format!("test_{thread_id:?}_{index}"), column_bytes: vec![index as u8, 2, 3], column_int: index, column_bool: true, column_float: 1.0, column_complex: vec![vec![vec![[1, index as u8, 3]]]], }; // half for train and half for test let split = splits[index as usize % 2]; let _index = writer.write(split, &sample).unwrap(); }); let mut writer = Arc::try_unwrap(writer).unwrap(); writer .set_completed() .expect("Should set completed successfully"); let train = SqliteDataset::::from_db_file(writer.db_file.clone(), "train").unwrap(); let test = SqliteDataset::::from_db_file(writer.db_file, "test").unwrap(); assert_eq!(train.len(), record_count as usize / 2); assert_eq!(test.len(), record_count as usize / 2); } } ================================================ FILE: crates/burn-dataset/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! # Burn Dataset //! //! Burn Dataset is a library for creating and loading datasets. #[macro_use] extern crate derive_new; extern crate alloc; extern crate dirs; /// Sources for datasets. pub mod source; pub mod transform; /// Audio datasets. #[cfg(feature = "audio")] pub mod audio; /// Vision datasets. #[cfg(feature = "vision")] pub mod vision; /// Natural language processing datasets. #[cfg(feature = "nlp")] pub mod nlp; /// Network dataset utilities. #[cfg(feature = "network")] pub mod network { pub use burn_std::network::*; } mod dataset; pub use dataset::*; #[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))] pub use source::huggingface::downloader::*; #[cfg(test)] mod test_data { pub fn string_items() -> Vec { vec![ "1 Item".to_string(), "2 Items".to_string(), "3 Items".to_string(), "4 Items".to_string(), ] } } ================================================ FILE: crates/burn-dataset/src/nlp/ag_news.rs ================================================ //! AG NEWS Dataset Module //! //! This module provides functionality for loading the AG NEWS text classification dataset. //! AG NEWS is a collection of news articles categorized into different topics. //! The dataset is split into training (120,000 articles) and test (7,600 articles) sets. //! //! ## Dataset Details //! - **Classes**: 4 categories (World, Sports, Business, Sci/Tech) //! - **AG NEWS mirror**: [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83) //! - **License**: [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE) //! //! ## Usage Example //! ```rust //! use burn_dataset::nlp::AgNewsDataset; //! //! // Create an AG NEWS dataset accessor //! let dataset = AgNewsDataset::new(); //! //! // Access training and test sets //! let train_dataset = dataset.train(); //! let test_dataset = dataset.test(); //! ``` use std::{path::PathBuf, sync::Mutex}; use flate2::read::GzDecoder; use serde::{Deserialize, Serialize}; use tar::Archive; use crate::InMemDataset; use crate::network::downloader; /// AG NEWS mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83). /// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). const AG_NEWS_URL: &str = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz"; /// Represents an item in the AG NEWS dataset. /// /// Each item contains a label, title, and content of a news article. #[derive(Deserialize, Serialize, Debug, Clone)] pub struct AgNewsItem { /// The category label of the news article. pub label: String, /// The title of the news article. pub title: String, /// The content/body of the news article. pub content: String, } /// AG NEWS dataset accessor. /// /// This struct provides convenient access to the AG NEWS text classification dataset. /// It automatically downloads (if not already downloaded), extracts, and loads the datasets. /// /// The dataset is split into training (120,000 articles) and test (7,600 articles) sets. pub struct AgNewsDataset { agnews_dir: PathBuf, } /// AG NEWS dataset download lock. /// /// This lock ensures that only one thread downloads the AG NEWS dataset at a time. static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(()); impl AgNewsDataset { /// Creates a new AG NEWS dataset accessor. /// /// This will download and extract the dataset if it's not already present. pub fn new() -> Self { Self { agnews_dir: Self::download(), } } /// Downloads and extracts the AG NEWS dataset. /// /// # Returns /// Path to the directory containing the extracted dataset. fn download() -> PathBuf { // Acquire the lock. This will block if another thread already holds the lock. let _lock = DOWNLOAD_LOCK.lock().unwrap(); // Dataset files are stored in the burn-dataset cache directory let cache_dir = dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset"); // AG NEWS dataset directory let agnews_dir = cache_dir.join("ag_news_csv"); // AG NEWS dataset url let url = AG_NEWS_URL; // AG NEWS dataset archive filename let filename = "ag_news_csv.tgz"; // Check for already downloaded content if !agnews_dir.exists() { // Download gzip file let bytes = downloader::download_file_as_bytes(url, filename); // Decode gzip file content and unpack archive let gz_buffer = GzDecoder::new(&bytes[..]); let mut archive = Archive::new(gz_buffer); archive.unpack(cache_dir).unwrap(); } agnews_dir } /// Parses a CSV file into an in-memory dataset. /// /// # Arguments /// * `file_path` - Path to the CSV file to parse. /// /// # Returns /// An `InMemDataset` containing the parsed data. fn parse_csv(file_path: &str) -> InMemDataset { let mut rdr = csv::ReaderBuilder::new(); let rdr = rdr.has_headers(false); InMemDataset::from_csv(file_path, &rdr).expect("Failed to parse CSV file") } /// Gets the training dataset. /// /// # Returns /// An `InMemDataset` instance containing 120,000 training articles. pub fn train(&self) -> InMemDataset { let file_path = self.agnews_dir.join("train.csv"); Self::parse_csv(file_path.to_str().unwrap()) } /// Gets the test dataset. /// /// # Returns /// An `InMemDataset` instance containing 7,600 test articles. pub fn test(&self) -> InMemDataset { let file_path = self.agnews_dir.join("test.csv"); Self::parse_csv(file_path.to_str().unwrap()) } } #[cfg(test)] mod tests { use super::*; use crate::Dataset; // AG NEWS dataset train and test dataset lengths const TRAIN_DATASET_LEN: usize = 120000; const TEST_DATASET_LEN: usize = 7600; #[test] fn test_agnews_download() { let agnews_dir = AgNewsDataset::download(); assert!(agnews_dir.exists()); } #[test] fn test_agnews_len() { let agnews = AgNewsDataset::new(); let train_dataset = agnews.train(); let test_dataset = agnews.test(); assert_eq!(train_dataset.len(), TRAIN_DATASET_LEN); assert_eq!(test_dataset.len(), TEST_DATASET_LEN); } #[test] fn test_agnews_first_and_last_item() { let agnews = AgNewsDataset::new(); // Test the first and the last item in training dataset let train_dataset = agnews.train(); let first_item = train_dataset.get(0).unwrap(); let last_item = train_dataset.get(train_dataset.len() - 1).unwrap(); assert!(compare_item(&first_item, &("3".to_string(), "Wall St. Bears Claw Back Into the Black (Reuters)".to_string(), "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.".to_string()))); assert!(compare_item( &last_item, &( "2".to_string(), "Nets get Carter from Raptors".to_string(), "INDIANAPOLIS -- All-Star Vince Carter was traded by the Toronto Raptors to the New Jersey Nets for Alonzo Mourning, Eric Williams, Aaron Williams, and a pair of first-round draft picks yesterday.".to_string() ) )); // Test the first and the last item in test dataset let test_dataset = agnews.test(); let first_item = test_dataset.get(0).unwrap(); let last_item = test_dataset.get(test_dataset.len() - 1).unwrap(); assert!(compare_item( &first_item, &( "3".to_string(), "Fears for T N pension after talks".to_string(), "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.".to_string() ) )); assert!(compare_item( &last_item, &( "3".to_string(), "EBay gets into rentals".to_string(), "EBay plans to buy the apartment and home rental service Rent.com for \\$415 million, adding to its already exhaustive breadth of offerings.".to_string() ) )); } fn compare_item(item: &AgNewsItem, target: &(String, String, String)) -> bool { item.label == target.0 && item.title == target.1 && item.content == target.2 } } ================================================ FILE: crates/burn-dataset/src/nlp/mod.rs ================================================ #[cfg(feature = "builtin-sources")] mod ag_news; mod text_folder; #[cfg(feature = "builtin-sources")] pub use ag_news::*; pub use text_folder::*; ================================================ FILE: crates/burn-dataset/src/nlp/text_folder.rs ================================================ use crate::transform::{Mapper, MapperDataset}; use crate::{Dataset, InMemDataset}; use encoding_rs::{GB18030, GBK, UTF_8, UTF_16BE, UTF_16LE}; use globwalk::{self, DirEntry}; use std::collections::{HashMap, HashSet}; use std::fs; use std::io::Read; use std::path::{Path, PathBuf}; use thiserror::Error; const SUPPORTED_FILES: [&str; 1] = ["txt"]; /// Text data type. #[derive(Debug, Clone, PartialEq)] pub struct TextData { /// The text content. pub text: String, /// Original text source. pub text_path: String, } /// Text dataset item. #[derive(Debug, Clone, PartialEq)] pub struct TextDatasetItem { /// Text content. pub text: TextData, /// Label for the text. pub label: usize, } /// Raw text dataset item. #[derive(Debug, Clone)] struct TextDatasetItemRaw { /// Text path. text_path: PathBuf, /// Text label. label: String, } impl TextDatasetItemRaw { fn new>(text_path: P, label: String) -> TextDatasetItemRaw { TextDatasetItemRaw { text_path: text_path.as_ref().to_path_buf(), label, } } } struct PathToTextDatasetItem { classes: HashMap, } /// Parse the text content from file with auto-detection of encoding. fn parse_text_content(text_path: &PathBuf) -> String { // Read raw bytes from disk let mut file = fs::File::open(text_path).unwrap(); let mut bytes = Vec::new(); file.read_to_end(&mut bytes).unwrap(); // Try to detect encoding and decode text // First try UTF-8 with BOM if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) && bytes.len() >= 3 { let (result, _, had_errors) = UTF_8.decode(&bytes[3..]); if !had_errors { return result.into_owned(); } } // Try UTF-8 without BOM let (result, _, had_errors) = UTF_8.decode(&bytes); if !had_errors { return result.into_owned(); } // Try UTF-16LE with BOM if bytes.starts_with(&[0xFF, 0xFE]) && bytes.len() >= 2 { let (result, had_errors) = UTF_16LE.decode_with_bom_removal(&bytes[2..]); if !had_errors { return result.into_owned(); } } // Try UTF-16BE with BOM if bytes.starts_with(&[0xFE, 0xFF]) && bytes.len() >= 2 { let (result, had_errors) = UTF_16BE.decode_with_bom_removal(&bytes[2..]); if !had_errors { return result.into_owned(); } } // Try GB18030 encoding let (result, _, had_errors) = GB18030.decode(&bytes); if !had_errors { return result.into_owned(); } // Try GBK encoding let (result, _, had_errors) = GBK.decode(&bytes); if !had_errors { return result.into_owned(); } // Default fallback - use from_utf8_lossy for any remaining cases String::from_utf8_lossy(&bytes).to_string() } impl Mapper for PathToTextDatasetItem { /// Convert a raw text dataset item (path-like) to text content with a target label. fn map(&self, item: &TextDatasetItemRaw) -> TextDatasetItem { let label = *self.classes.get(&item.label).unwrap(); // Load text from disk let text_content = parse_text_content(&item.text_path); let text_data = TextData { text: text_content, text_path: item.text_path.display().to_string(), }; TextDatasetItem { text: text_data, label, } } } /// Error type for [TextFolderDataset](TextFolderDataset). #[derive(Error, Debug)] pub enum TextLoaderError { /// Unknown error. #[error("unknown: `{0}`")] Unknown(String), /// I/O operation error. #[error("I/O error: `{0}`")] IOError(String), /// Invalid file error. #[error("Invalid file extension: `{0}`")] InvalidFileExtensionError(String), /// Encoding error. #[error("Encoding error: `{0}`")] EncodingError(String), } type TextDatasetMapper = MapperDataset, PathToTextDatasetItem, TextDatasetItemRaw>; /// A generic dataset to load texts from disk. pub struct TextFolderDataset { dataset: TextDatasetMapper, } impl Dataset for TextFolderDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } impl TextFolderDataset { /// Create a text classification dataset from the root folder. /// /// # Arguments /// /// * `root` - Dataset root folder. /// /// # Returns /// A new dataset instance. pub fn new_classification>(root: P) -> Result { // New dataset containing any of the supported file types TextFolderDataset::new_classification_with(root, &SUPPORTED_FILES) } /// Create a text classification dataset from the root folder. /// The included texts are filtered based on the provided extensions. /// /// # Arguments /// /// * `root` - Dataset root folder. /// * `extensions` - List of allowed extensions. /// /// # Returns /// A new dataset instance. pub fn new_classification_with(root: P, extensions: &[S]) -> Result where P: AsRef, S: AsRef, { // Glob all texts with extensions let walker = globwalk::GlobWalkerBuilder::from_patterns( root.as_ref(), &[format!( "*.{{{}}}", // "*.{ext1,ext2,ext3} extensions .iter() .map(Self::check_extension) .collect::, _>>()? .join(",") )], ) .follow_links(true) .sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path .build() .map_err(|err| TextLoaderError::Unknown(format!("{err:?}")))? .filter_map(Result::ok); // Get all dataset items let mut items = Vec::new(); let mut classes = HashSet::new(); for text in walker { let text_path = text.path(); // Label name is represented by the parent folder name let label = text_path .parent() .ok_or_else(|| { TextLoaderError::IOError("Could not resolve text parent folder".to_string()) })? .file_name() .ok_or_else(|| { TextLoaderError::IOError( "Could not resolve text parent folder name".to_string(), ) })? .to_string_lossy() .into_owned(); classes.insert(label.clone()); items.push(TextDatasetItemRaw::new(text_path, label)) } // Sort class names let mut classes = classes.into_iter().collect::>(); classes.sort(); Self::with_items(items, &classes) } /// Create a text classification dataset with the specified items. /// /// # Arguments /// /// * `items` - List of dataset items, each item represented by a tuple `(text path, label)`. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. pub fn new_classification_with_items, S: AsRef>( items: Vec<(P, String)>, classes: &[S], ) -> Result { // Parse items and check valid text extension types let items = items .into_iter() .map(|(path, label)| { // Map text path and label let path = path.as_ref(); let label = label; Self::check_extension(&path.extension().unwrap().to_str().unwrap())?; Ok(TextDatasetItemRaw::new(path, label)) }) .collect::, _>>()?; Self::with_items(items, classes) } /// Create a text dataset with the specified items. /// /// # Arguments /// /// * `items` - Raw dataset items. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. fn with_items>( items: Vec, classes: &[S], ) -> Result { // NOTE: right now we don't need to validate the supported text files since // the method is private. We assume it's already validated. let dataset = InMemDataset::new(items); // Class names to index map let classes = classes.iter().map(|c| c.as_ref()).collect::>(); let classes_map: HashMap<_, _> = classes .into_iter() .enumerate() .map(|(idx, cls)| (cls.to_string(), idx)) .collect(); let mapper = PathToTextDatasetItem { classes: classes_map, }; let dataset = MapperDataset::new(dataset, mapper); Ok(Self { dataset }) } /// Check if extension is supported. fn check_extension>(extension: &S) -> Result { let extension = extension.as_ref(); if !SUPPORTED_FILES.contains(&extension) { Err(TextLoaderError::InvalidFileExtensionError( extension.to_string(), )) } else { Ok(extension.to_string()) } } } #[cfg(test)] mod tests { use super::*; use std::path::Path; const TEXT_ROOT: &str = "tests/data/text_folder"; #[test] fn test_text_folder_dataset() { let dataset = TextFolderDataset::new_classification(TEXT_ROOT).unwrap(); // Dataset should have 4 elements (2 positive + 2 negative) assert_eq!(dataset.len(), 4); assert_eq!(dataset.get(4), None); // Check that we have items from both classes let mut found_positive = false; let mut found_negative = false; for i in 0..dataset.len() { let item = dataset.get(i).unwrap(); if item.label == 0 { found_negative = true; // Check that the text content is loaded correctly assert!(!item.text.text.is_empty()); assert!(item.text.text_path.contains("negative")); } else if item.label == 1 { found_positive = true; // Check that the text content is loaded correctly assert!(!item.text.text.is_empty()); assert!(item.text.text_path.contains("positive")); } } // Verify we found items from both classes assert!(found_positive); assert!(found_negative); } #[test] fn test_text_folder_dataset_with_invalid_extension() { // Try to create a dataset with an unsupported extension let result = TextFolderDataset::new_classification_with(TEXT_ROOT, &["invalid"]); assert!(result.is_err()); } #[test] fn test_text_folder_dataset_with_items() { // Create the dataset let root = Path::new(TEXT_ROOT); let items = vec![ ( root.join("positive").join("sample1.txt"), "positive".to_string(), ), ( root.join("negative").join("sample2.txt"), "negative".to_string(), ), ]; let classes = vec!["positive", "negative"]; let dataset = TextFolderDataset::new_classification_with_items(items, &classes).unwrap(); // Dataset should have 2 elements assert_eq!(dataset.len(), 2); assert_eq!(dataset.get(2), None); // Get items let item0 = dataset.get(0).unwrap(); let item1 = dataset.get(1).unwrap(); // Check item0 assert!(compare_item( &item0, &( "This is a positive text sample for testing the text folder dataset functionality." .to_string(), 0 ) )); // Check item1 assert_eq!(item1.label, 1); assert!(item1.text.text_path.contains("negative")); assert!(compare_item( &item1, &( "另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。".to_string(), 1 ) )); } fn compare_item(item: &TextDatasetItem, target: &(String, usize)) -> bool { item.text.text == target.0 && item.label == target.1 } } ================================================ FILE: crates/burn-dataset/src/source/huggingface/downloader.rs ================================================ use std::fs::{self, create_dir_all}; use std::path::{Path, PathBuf}; use std::process::Command; use crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage}; use sanitize_filename::sanitize; use serde::de::DeserializeOwned; use thiserror::Error; const PYTHON_SOURCE: &str = include_str!("importer.py"); #[cfg(not(target_os = "windows"))] const VENV_BIN_PYTHON: &str = "bin/python3"; #[cfg(target_os = "windows")] const VENV_BIN_PYTHON: &str = "Scripts\\python"; /// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader). #[derive(Error, Debug)] pub enum ImporterError { /// Unknown error. #[error("unknown: `{0}`")] Unknown(String), /// Fail to download python dependencies. #[error("fail to download python dependencies: `{0}`")] FailToDownloadPythonDependencies(String), /// Fail to create sqlite dataset. #[error("sqlite dataset: `{0}`")] SqliteDataset(#[from] SqliteDatasetError), /// python3 is not installed. #[error("python3 is not installed")] PythonNotInstalled, /// venv environment is not initialized. #[error("venv environment is not initialized")] VenvNotInitialized, } /// Load a dataset from [huggingface datasets](https://huggingface.co/datasets). /// /// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)). /// /// # Example /// ```no_run /// use burn_dataset::HuggingfaceDatasetLoader; /// use burn_dataset::SqliteDataset; /// use serde::{Deserialize, Serialize}; /// /// #[derive(Deserialize, Debug, Clone)] /// struct MnistItemRaw { /// pub image_bytes: Vec, /// pub label: usize, /// } /// /// let train_ds:SqliteDataset = HuggingfaceDatasetLoader::new("mnist") /// .dataset("train") /// .unwrap(); /// ``` /// /// # Note /// This loader relies on the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) /// to download datasets. This is a Python library, so you must have an existing Python installation. pub struct HuggingfaceDatasetLoader { name: String, subset: Option, base_dir: Option, huggingface_token: Option, huggingface_cache_dir: Option, huggingface_data_dir: Option, trust_remote_code: bool, use_python_venv: bool, } impl HuggingfaceDatasetLoader { /// Create a huggingface dataset loader. pub fn new(name: &str) -> Self { Self { name: name.to_string(), subset: None, base_dir: None, huggingface_token: None, huggingface_cache_dir: None, huggingface_data_dir: None, trust_remote_code: false, use_python_venv: true, } } /// Create a huggingface dataset loader for a subset of the dataset. /// /// The subset name must be one of the subsets listed in the dataset page. /// /// If no subset names are listed, then do not use this method. pub fn with_subset(mut self, subset: &str) -> Self { self.subset = Some(subset.to_string()); self } /// Specify a base directory to store the dataset. /// /// If not specified, the dataset will be stored in the system cache directory under `burn-dataset`. pub fn with_base_dir(mut self, base_dir: &str) -> Self { self.base_dir = Some(base_dir.into()); self } /// Specify a huggingface token to download datasets behind authentication. /// /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens) pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self { self.huggingface_token = Some(huggingface_token.to_string()); self } /// Specify a huggingface cache directory to store the downloaded datasets. /// /// If not specified, the dataset will be stored in the system cache directory under `huggingface/datasets`. pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self { self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string()); self } /// Specify a relative path to a subset of a dataset. This is used in some datasets for the /// manual steps of dataset download process. /// /// Unless you've encountered a ManualDownloadError /// when loading your dataset you probably don't have to worry about this setting. pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self { self.huggingface_data_dir = Some(huggingface_data_dir.to_string()); self } /// Specify whether or not to trust remote code. /// /// If not specified, trust remote code is set to true. pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self { self.trust_remote_code = trust_remote_code; self } /// Specify whether or not to use the burn-dataset Python /// virtualenv for running the importer script. If false, local /// `python3`'s environment is used. /// /// If not specified, the virtualenv is used. pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self { self.use_python_venv = use_python_venv; self } /// Load the dataset. pub fn dataset( self, split: &str, ) -> Result, ImporterError> { let db_file = self.db_file()?; let dataset = SqliteDataset::from_db_file(db_file, split)?; Ok(dataset) } /// Get the path to the sqlite database file. /// /// If the database file does not exist, it will be downloaded and imported. pub fn db_file(self) -> Result { // determine (and create if needed) the base directory let base_dir = SqliteDatasetStorage::base_dir(self.base_dir); if !base_dir.exists() { create_dir_all(&base_dir).expect("Failed to create base directory"); } //sanitize the name and subset let name = sanitize(self.name.as_str()); // create the db file path let db_file_name = if let Some(subset) = self.subset.clone() { format!("{name}-{}.db", sanitize(subset.as_str())) } else { format!("{name}.db") }; let db_file = base_dir.join(db_file_name); // import the dataset if needed if !Path::new(&db_file).exists() { import( self.name, self.subset, db_file.clone(), base_dir, self.huggingface_token, self.huggingface_cache_dir, self.huggingface_data_dir, self.trust_remote_code, self.use_python_venv, )?; } Ok(db_file) } } /// Import a dataset from huggingface. The transformed dataset is stored as sqlite database. #[allow(clippy::too_many_arguments)] fn import( name: String, subset: Option, base_file: PathBuf, base_dir: PathBuf, huggingface_token: Option, huggingface_cache_dir: Option, huggingface_data_dir: Option, trust_remote_code: bool, use_python_venv: bool, ) -> Result<(), ImporterError> { let python_path = if use_python_venv { install_python_deps(&base_dir)? } else { get_python_name()?.into() }; let mut command = Command::new(python_path); command.arg(importer_script_path(&base_dir)); command.arg("--name"); command.arg(name); command.arg("--file"); command.arg(base_file); if let Some(subset) = subset { command.arg("--subset"); command.arg(subset); } if let Some(huggingface_token) = huggingface_token { command.arg("--token"); command.arg(huggingface_token); } if let Some(huggingface_cache_dir) = huggingface_cache_dir { command.arg("--cache_dir"); command.arg(huggingface_cache_dir); } if let Some(huggingface_data_dir) = huggingface_data_dir { command.arg("--data_dir"); command.arg(huggingface_data_dir); } if trust_remote_code { command.arg("--trust_remote_code"); command.arg("True"); } let mut handle = command.spawn().unwrap(); let exit_status = handle .wait() .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?; if !exit_status.success() { return Err(ImporterError::Unknown(format!("{exit_status}"))); } Ok(()) } /// check python --version output is `Python 3.x.x` fn check_python_version_is_3(python: &str) -> bool { let output = Command::new(python).arg("--version").output(); match output { Ok(output) => { if output.status.success() { let version_string = String::from_utf8_lossy(&output.stdout); if let Some(index) = version_string.find(' ') { let version = &version_string[index + 1..]; version.starts_with("3.") } else { false } } else { false } } Err(_error) => false, } } /// get python3 name `python` `python3` or `py` fn get_python_name() -> Result<&'static str, ImporterError> { let python_name_list = ["python3", "python", "py"]; for python_name in python_name_list.iter() { if check_python_version_is_3(python_name) { return Ok(python_name); } } Err(ImporterError::PythonNotInstalled) } fn importer_script_path(base_dir: &Path) -> PathBuf { let path_file = base_dir.join("importer.py"); fs::write(&path_file, PYTHON_SOURCE).expect("Write python dataset downloader"); path_file } fn install_python_deps(base_dir: &Path) -> Result { let venv_dir = base_dir.join("venv"); let venv_python_path = venv_dir.join(VENV_BIN_PYTHON); // If the venv environment is already initialized, skip the initialization. if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { let python_name = get_python_name()?; let mut command = Command::new(python_name); command.args([ "-m", "venv", venv_dir .as_os_str() .to_str() .expect("Path utf8 conversion should not fail"), ]); // Spawn the venv creation process and wait for it to complete. let mut handle = command.spawn().unwrap(); handle.wait().map_err(|err| { ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")) })?; // Check if the venv environment can be used successfully." if !check_python_version_is_3(venv_python_path.to_str().unwrap()) { return Err(ImporterError::VenvNotInitialized); } } let mut ensurepip_cmd = Command::new(&venv_python_path); ensurepip_cmd.args(["-m", "ensurepip", "--upgrade"]); let status = ensurepip_cmd.status().map_err(|err| { ImporterError::FailToDownloadPythonDependencies(format!("failed to run ensurepip: {err}")) })?; if !status.success() { return Err(ImporterError::FailToDownloadPythonDependencies( "ensurepip failed to initialize pip".to_string(), )); } let mut command = Command::new(&venv_python_path); command.args([ "-m", "pip", "--quiet", "install", "pyarrow", "sqlalchemy", "Pillow", "soundfile", "datasets", ]); // Spawn the pip install process and wait for it to complete. let mut handle = command.spawn().unwrap(); handle .wait() .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(" error: {err}")))?; Ok(venv_python_path) } ================================================ FILE: crates/burn-dataset/src/source/huggingface/importer.py ================================================ import argparse import pyarrow as pa from datasets import Audio, Image, load_dataset from sqlalchemy import Column, Integer, Table, create_engine, event, inspect from sqlalchemy.types import LargeBinary def download_and_export( name: str, subset: str, db_file: str, token: str, cache_dir: str, data_dir: str | None, trust_remote_code: bool, ): """ Download a dataset from using HuggingFace dataset and export it to a sqlite database. """ # TODO For media columns (Image and Audio) sometimes when decode=False, # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'} # We should handle this case, but unfortunately we did not come across this case yet to test it. print("*" * 80) print("Starting huggingface dataset download and export") print(f"Dataset Name: {name}") print(f"Subset Name: {subset}") print(f"Sqlite database file: {db_file}") print(f"Trust remote code: {trust_remote_code}") if cache_dir is None: print(f"Custom cache dir: {cache_dir}") print("*" * 80) # Load the dataset dataset_all = load_dataset( name, subset, cache_dir=cache_dir, data_dir=data_dir, use_auth_token=token, trust_remote_code=trust_remote_code, ) print(f"Dataset: {dataset_all}") # Create the database connection descriptor (sqlite) engine = create_engine(f"sqlite:///{db_file}") # Set some sqlite pragmas to speed up the database event.listen(engine, "connect", set_sqlite_pragma) # Add an row_id column to each table as primary key (datasets does not have API for this) event.listen(Table, "before_create", add_pk_column) # Export each split in the dataset for key in dataset_all.keys(): dataset = dataset_all[key] # Disable decoding for audio and image fields dataset = disable_decoding(dataset) # Flatten the dataset dataset = dataset.flatten() # Rename columns to remove dots from the names dataset = rename_columns(dataset) print(f"Saving dataset: {name} - {key}") print(f"Dataset features: {dataset.features}") # Save the dataset to a sqlite database dataset.to_sql( key, # table name engine, # don't save the index, use row_id instead (index is not unique) index=False, dtype=blob_columns(dataset), # save binary columns as blob ) # Print the schema of the database so we can reference the columns in the rust code print_table_info(engine) def disable_decoding(dataset): """ Disable decoding for audio and image fields. The fields will be saved as raw file bytes. """ for k, v in dataset.features.items(): if isinstance(v, Audio): dataset = dataset.cast_column(k, Audio(decode=False)) elif isinstance(v, Image): dataset = dataset.cast_column(k, Image(decode=False)) return dataset def rename_columns(dataset): """ Rename columns to remove dots from the names. Dots appear in the column names because of the flattening. Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores. This way there is an easy name mapping between the rust and sql columns. """ for name in dataset.features.keys(): if "." in name: dataset = dataset.rename_column(name, name.replace(".", "_")) return dataset def blob_columns(dataset): """ Make sure all binary columns are blob columns in the database because `to_sql` exports binary values as TEXT instead of BLOB. """ type_mapping = {} for name, value in dataset.features.items(): if value.pa_type is not None and pa.types.is_binary(value.pa_type): type_mapping[name] = LargeBinary return type_mapping def set_sqlite_pragma(dbapi_connection, connection_record): """ Set some sqlite pragmas to speed up the database """ cursor = dbapi_connection.cursor() cursor.execute("PRAGMA synchronous = OFF") cursor.execute("PRAGMA journal_mode = OFF") cursor.close() def add_pk_column(target, connection, **kw): """ Add an id column to each table. """ target.append_column(Column("row_id", Integer, primary_key=True)) def print_table_info(engine): """ Print the schema of the database so we can reference the columns in the rust code """ print(f"Printing table schema for sqlite3 db ({engine})") inspector = inspect(engine) for table_name in inspector.get_table_names(): print(f"Table: {table_name}") for column in inspector.get_columns(table_name): print(f"Column: {column['name']} - {column['type']}") print("") def parse_args(): parser = argparse.ArgumentParser( description="Huggingface datasets downloader to use with burn-dataset" ) parser.add_argument( "--name", type=str, help="Name of the dataset to download", required=True ) parser.add_argument( "--file", type=str, help="Base file name where the data is saved", required=True ) parser.add_argument( "--subset", type=str, help="Subset name", required=False, default=None ) parser.add_argument( "--token", type=str, help="HuggingFace authentication token", required=False, default=None, ) parser.add_argument( "--cache_dir", type=str, help="Cache directory", required=False, default=None ) parser.add_argument( "--data_dir", type=str, help="Relative path to a specific subset of your dataset", required=False, default=None ) parser.add_argument( "--trust_remote_code", type=bool, help="Trust remote code", required=False, default=None, ) return parser.parse_args() def run(): args = parse_args() download_and_export( args.name, args.subset, args.file, args.token, args.data_dir, args.cache_dir, args.trust_remote_code, ) if __name__ == "__main__": run() ================================================ FILE: crates/burn-dataset/src/source/huggingface/mod.rs ================================================ pub(crate) mod downloader; pub use downloader::*; ================================================ FILE: crates/burn-dataset/src/source/mod.rs ================================================ /// Huggingface source #[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))] pub mod huggingface; ================================================ FILE: crates/burn-dataset/src/transform/composed.rs ================================================ use crate::Dataset; /// Compose multiple datasets together to create a bigger one. #[derive(new)] pub struct ComposedDataset { datasets: Vec, } impl Dataset for ComposedDataset where D: Dataset, I: Clone, { fn get(&self, index: usize) -> Option { let mut current_index = 0; for dataset in self.datasets.iter() { if index < dataset.len() + current_index { return dataset.get(index - current_index); } current_index += dataset.len(); } None } fn len(&self) -> usize { let mut total = 0; for dataset in self.datasets.iter() { total += dataset.len(); } total } } #[cfg(test)] mod tests { use super::*; use crate::FakeDataset; #[test] fn test_composed_dataset() { let dataset1 = FakeDataset::::new(10); let dataset2 = FakeDataset::::new(5); let items1 = dataset1.iter().collect::>(); let items2 = dataset2.iter().collect::>(); let composed = ComposedDataset::new(vec![dataset1, dataset2]); assert_eq!(composed.len(), 15); let expected_items: Vec = items1.iter().chain(items2.iter()).cloned().collect(); let items = composed.iter().collect::>(); assert_eq!(items, expected_items); } } ================================================ FILE: crates/burn-dataset/src/transform/mapper.rs ================================================ use crate::Dataset; use std::marker::PhantomData; /// Basic mapper trait to be used with the [mapper dataset](MapperDataset). pub trait Mapper: Send + Sync { /// Maps an item of type I to an item of type O. fn map(&self, item: &I) -> O; } /// Dataset mapping each element in an inner dataset to another element type lazily. #[derive(new)] pub struct MapperDataset { dataset: D, mapper: M, input: PhantomData, } impl Dataset for MapperDataset where D: Dataset, M: Mapper + Send + Sync, I: Send + Sync, O: Send + Sync, { fn get(&self, index: usize) -> Option { let item = self.dataset.get(index); item.map(|item| self.mapper.map(&item)) } fn len(&self) -> usize { self.dataset.len() } } #[cfg(test)] mod tests { use super::*; use crate::{InMemDataset, test_data}; #[test] pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() { struct StringToFirstChar; impl Mapper for StringToFirstChar { fn map(&self, item: &String) -> String { let mut item = item.clone(); item.truncate(1); item } } let items_original = test_data::string_items(); let dataset = InMemDataset::new(items_original); let dataset = MapperDataset::new(dataset, StringToFirstChar); let items: Vec = dataset.iter().collect(); assert_eq!(vec!["1", "2", "3", "4"], items); } } ================================================ FILE: crates/burn-dataset/src/transform/mod.rs ================================================ //! # Dataset Transformations //! //! This module provides a collection of [`crate::Dataset`] composition wrappers; //! providing composition, subset selection, sampling, random shuffling, and windowing. //! //! * [`ComposedDataset`] - composes a list of datasets. //! * [`PartialDataset`] - selects a contiguous index range subset of a dataset. //! * [`ShuffledDataset`] - a randomly shuffled / mutably shuffle-able dataset; //! a thin wrapper around [`SelectionDataset`]. //! * [`SamplerDataset`] - samples a dataset; support for with/without replacement, //! and under/oversampling. //! * [`SelectionDataset`] - selects a subset of a dataset via indices; support for shuffling. //! * [`WindowsDataset`] - creates a sliding window over a dataset. mod composed; mod mapper; mod options; mod partial; mod sampler; mod selection; mod shuffle; mod window; pub use composed::*; pub use mapper::*; pub use options::*; pub use partial::*; pub use sampler::*; pub use selection::*; pub use shuffle::*; pub use window::*; ================================================ FILE: crates/burn-dataset/src/transform/options.rs ================================================ use rand::SeedableRng; use rand::prelude::StdRng; use rand::rngs::SysRng; /// Defines a source for a `StdRng`. /// /// # Examples /// /// ```rust,no_run /// use rand::rngs::StdRng; /// use rand::SeedableRng; /// use burn_dataset::transform::RngSource; /// /// // Default via `StdRng::from_os_rng()` (`RngSource::Default`) /// let system: RngSource = RngSource::default(); /// /// // From a fixed seed (`RngSource::Seed`) /// let seeded: RngSource = 42.into(); /// /// // From an existing rng (`RngSource::Rng`) /// let rng = StdRng::seed_from_u64(123); /// let with_rng: RngSource = rng.into(); /// /// // Forks the parent RNG to derive an independent, deterministic child RNG. /// // The original `rng` is modified, and the resulting `RngSource` contains /// // a new RNG starting from a unique state. /// let mut rng = StdRng::seed_from_u64(123); /// let forked: RngSource = (&mut rng).into(); /// ``` #[derive(Debug, Default, PartialEq, Eq)] #[allow(clippy::large_enum_variant)] pub enum RngSource { /// Build a new rng from the system. #[default] Default, /// The rng is passed as a seed. Seed(u64), /// The rng is passed as an option. Rng(StdRng), } impl From for StdRng { fn from(source: RngSource) -> Self { match source { RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(), RngSource::Rng(rng) => rng, RngSource::Seed(seed) => StdRng::seed_from_u64(seed), } } } impl From for RngSource { fn from(seed: u64) -> Self { Self::Seed(seed) } } impl From for RngSource { fn from(rng: StdRng) -> Self { Self::Rng(rng) } } /// Derive an independent RNG from a mutable parent RNG. /// /// This advances the parent RNG and creates a new RNG seeded from its output. /// The derived RNG is *not* a clone of the parent's state, but an independent /// stream (equivalent to `SeedableRng::fork`). impl From<&mut StdRng> for RngSource { fn from(rng: &mut StdRng) -> Self { Self::Rng(rng.fork()) } } /// Helper option to describe the size of a wrapper, relative to a wrapped object. #[derive(Debug, Clone, Copy, Default, PartialEq)] pub enum SizeConfig { /// Use the size of the source dataset. #[default] Default, /// Use the size as a ratio of the source dataset size. /// /// Must be >= 0. Ratio(f64), /// Use a fixed size. Fixed(usize), } impl SizeConfig { /// Construct a source which will have the same size as the source dataset. pub fn source() -> Self { Self::Default } /// Resolve the effective size. /// /// ## Arguments /// /// - `source_size`: the size of the source dataset. /// /// ## Returns /// /// The resolved size of the wrapper dataset. pub fn resolve(self, source_size: usize) -> usize { match self { SizeConfig::Default => source_size, SizeConfig::Ratio(ratio) => { assert!(ratio >= 0.0, "Ratio must be positive: {ratio}"); ((source_size as f64) * ratio) as usize } SizeConfig::Fixed(size) => size, } } } impl From for SizeConfig { fn from(size: usize) -> Self { Self::Fixed(size) } } impl From for SizeConfig { fn from(ratio: f64) -> Self { Self::Ratio(ratio) } } #[cfg(test)] mod tests { use super::*; use rand::SeedableRng; #[test] fn test_rng_source_default() { let rng_source: RngSource = Default::default(); assert_eq!(&rng_source, &RngSource::Default); assert_eq!(&rng_source, &RngSource::default()); // Exercise the from_os_rng() call; but we don't know its seed; let _rng: StdRng = rng_source.into(); } #[test] fn test_rng_source_seed() { let rng_source = RngSource::from(42); assert_eq!(&rng_source, &RngSource::Seed(42)); let rng: StdRng = rng_source.into(); let expected = StdRng::seed_from_u64(42); assert_eq!(rng, expected); } #[test] fn test_rng_source_rng() { // From StdRng (owned). { let original = StdRng::seed_from_u64(42); let rng_source = RngSource::from(original); let rng: StdRng = rng_source.into(); // No longer clone, but from <> into should not have advanced the state let original = StdRng::seed_from_u64(42); assert_eq!(rng, original); } // From &mut StdRng (forks parent) { let mut original = StdRng::seed_from_u64(42); let mut rng = StdRng::seed_from_u64(42); let rng_forked = rng.fork(); let rng_source = RngSource::from(&mut original); // Ensure the original was advanced assert_eq!(original, rng); // Ensure the sourced RNG matches the fork let rng: StdRng = rng_source.into(); assert_eq!(rng, rng_forked); } } #[test] fn test_size_config() { assert_eq!(SizeConfig::default(), SizeConfig::Default); assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42)); assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5)); assert_eq!(SizeConfig::source(), SizeConfig::Default); assert_eq!(SizeConfig::source().resolve(50), 50); } } ================================================ FILE: crates/burn-dataset/src/transform/partial.rs ================================================ use crate::Dataset; use std::{marker::PhantomData, sync::Arc}; /// Only use a fraction of an existing dataset lazily. #[derive(new, Clone)] pub struct PartialDataset { dataset: D, start_index: usize, end_index: usize, input: PhantomData, } impl PartialDataset where D: Dataset, { /// Splits a dataset into multiple partial datasets. pub fn split(dataset: D, num: usize) -> Vec, I>> { let dataset = Arc::new(dataset); // cheap cloning. let mut current = 0; let mut datasets = Vec::with_capacity(num); let batch_size = dataset.len() / num; for i in 0..num { let start = current; let mut end = current + batch_size; if i == (num - 1) { end = dataset.len(); } let dataset = PartialDataset::new(dataset.clone(), start, end); current += batch_size; datasets.push(dataset); } datasets } /// Splits a dataset by distributing complete chunks/batches across multiple partial datasets. pub fn split_chunks( dataset: D, num: usize, batch_size: usize, ) -> Vec, I>> { let dataset = Arc::new(dataset); // cheap cloning. let total_items = dataset.len(); // Total number of complete batches let total_batches = total_items.div_ceil(batch_size); let batches_per_split = total_batches / num; let extra_batches = total_batches % num; let mut datasets = Vec::with_capacity(num); let mut current_batch = 0; for i in 0..num { // Extra batches distributed across first splits let split_batches = if i < extra_batches { batches_per_split + 1 } else { batches_per_split }; let start_batch = current_batch; let end_batch = start_batch + split_batches; let start_index = start_batch * batch_size; let end_index = core::cmp::min(end_batch * batch_size, total_items); if start_index < total_items { datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index)); } current_batch = end_batch; } datasets } } impl Dataset for PartialDataset where D: Dataset, I: Clone + Send + Sync, { fn get(&self, index: usize) -> Option { let index = index + self.start_index; if index < self.start_index || index >= self.end_index { return None; } self.dataset.get(index) } fn len(&self) -> usize { usize::min(self.end_index - self.start_index, self.dataset.len()) } } #[cfg(test)] mod tests { use super::*; use crate::FakeDataset; use std::collections::HashSet; #[test] fn test_start_from_beginning() { let dataset_original = FakeDataset::::new(27); let mut items_original_1 = HashSet::new(); let mut items_original_2 = HashSet::new(); let mut items_partial = HashSet::new(); dataset_original.iter().enumerate().for_each(|(i, item)| { match i >= 10 { true => items_original_2.insert(item), false => items_original_1.insert(item), }; }); let dataset_partial = PartialDataset::new(dataset_original, 0, 10); for item in dataset_partial.iter() { items_partial.insert(item); } assert_eq!(dataset_partial.len(), 10); assert_eq!(items_original_1, items_partial); for item in items_original_2 { assert!(!items_partial.contains(&item)); } } #[test] fn test_start_inside() { let dataset_original = FakeDataset::::new(27); let mut items_original_1 = HashSet::new(); let mut items_original_2 = HashSet::new(); let mut items_partial = HashSet::new(); dataset_original.iter().enumerate().for_each(|(i, item)| { match !(10..20).contains(&i) { true => items_original_2.insert(item), false => items_original_1.insert(item), }; }); let dataset_partial = PartialDataset::new(dataset_original, 10, 20); for item in dataset_partial.iter() { items_partial.insert(item); } assert_eq!(dataset_partial.len(), 10); assert_eq!(items_original_1, items_partial); for item in items_original_2 { assert!(!items_partial.contains(&item)); } } #[test] fn test_split_contains_all_items_without_duplicates() { let dataset_original = FakeDataset::::new(27); let mut items_original = Vec::new(); let mut items_partial = Vec::new(); for item in dataset_original.iter() { items_original.push(item); } let dataset_partials = PartialDataset::split(dataset_original, 4); let expected_len = [6, 6, 6, 9]; for (i, dataset) in dataset_partials.iter().enumerate() { assert_eq!(dataset.len(), expected_len[i]); for item in dataset.iter() { items_partial.push(item); } } assert_eq!(items_original, items_partial); } #[test] fn test_split_chunks_contains_all_items_without_duplicates() { let dataset_original = FakeDataset::::new(27); let mut items_original = Vec::new(); let mut items_partial = Vec::new(); for item in dataset_original.iter() { items_original.push(item); } let dataset_partials = PartialDataset::split_chunks(dataset_original, 4, 5); // [(2 * 5), (2 * 5), 5, 2] -> 5 complete chunks + 1 incomplete with 2 remaining items // OTOH, `split(dataset, 4)` would yield [6, 6, 6, 9] -> 4 incomplete chunks + 4 incomplete with [1, 1, 1, 4] let expected_len = [10, 10, 5, 2]; for (i, dataset) in dataset_partials.iter().enumerate() { assert_eq!(dataset.len(), expected_len[i]); for item in dataset.iter() { items_partial.push(item); } } assert_eq!(items_original, items_partial); } } ================================================ FILE: crates/burn-dataset/src/transform/sampler.rs ================================================ use crate::Dataset; use crate::transform::{RngSource, SizeConfig}; use rand::prelude::SliceRandom; use rand::{RngExt, distr::Uniform, rngs::StdRng, seq::IteratorRandom}; use std::{marker::PhantomData, ops::DerefMut, sync::Mutex}; /// Options to configure a [SamplerDataset]. #[derive(Debug, PartialEq)] pub struct SamplerDatasetOptions { /// The sampling mode. pub replace_samples: bool, /// The size source of the wrapper relative to the dataset. pub size_config: SizeConfig, /// The source of the random number generator. pub rng_source: RngSource, } impl Default for SamplerDatasetOptions { fn default() -> Self { Self { replace_samples: true, size_config: SizeConfig::Default, rng_source: RngSource::Default, } } } impl From> for SamplerDatasetOptions where T: Into, { fn from(option: Option) -> Self { match option { Some(option) => option.into(), None => Self::default(), } } } impl From for SamplerDatasetOptions { fn from(size: usize) -> Self { Self::default().with_replacement().with_fixed_size(size) } } impl SamplerDatasetOptions { /// Set the replacement mode. pub fn with_replace_samples(self, replace_samples: bool) -> Self { Self { replace_samples, ..self } } /// Set the replacement mode to WithReplacement. pub fn with_replacement(self) -> Self { self.with_replace_samples(true) } /// Set the replacement mode to WithoutReplacement. pub fn without_replacement(self) -> Self { self.with_replace_samples(false) } /// Set the size source. pub fn with_size(self, source: S) -> Self where S: Into, { Self { size_config: source.into(), ..self } } /// Set the size to the size of the source. pub fn with_source_size(self) -> Self { self.with_size(SizeConfig::Default) } /// Set the size to a fixed size. pub fn with_fixed_size(self, size: usize) -> Self { self.with_size(size) } /// Set the size to be a multiple of the ration and the source size. pub fn with_size_ratio(self, size_ratio: f64) -> Self { self.with_size(size_ratio) } /// Set the `RngSource`. pub fn with_rng(self, rng: R) -> Self where R: Into, { Self { rng_source: rng.into(), ..self } } /// Use the system rng. pub fn with_system_rng(self) -> Self { self.with_rng(RngSource::Default) } /// Use a rng, built from a seed. pub fn with_seed(self, seed: u64) -> Self { self.with_rng(seed) } } /// Sample items from a dataset. /// /// This is a convenient way of modeling a dataset as a probability distribution of a fixed size. /// You have multiple options to instantiate the dataset sampler. /// /// * With replacement (Default): This is the most efficient way of using the sampler because no state is /// required to keep indices that have been selected. /// /// * Without replacement: This has a similar effect to using a /// [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can /// set the dataset to an arbitrary size. Once every item has been used, a new cycle is /// created with a new random suffle. pub struct SamplerDataset { dataset: D, size: usize, state: Mutex, input: PhantomData, } enum SamplerState { WithReplacement(StdRng), WithoutReplacement(StdRng, Vec), } impl SamplerDataset where D: Dataset, I: Send + Sync, { /// Creates a new sampler dataset with replacement. /// /// When the sample size is less than or equal to the source dataset size, /// data will be sampled without replacement from the source dataset in /// a uniformly shuffled order. /// /// When the sample size is greater than the source dataset size, /// the entire source dataset will be sampled once for every multiple /// of the size ratios; with the remaining samples taken without replacement /// uniformly from the source. All samples will be returned uniformly shuffled. /// /// ## Arguments /// /// * `dataset`: the dataset to wrap. /// * `options`: the options to configure the sampler dataset. /// /// ## Examples /// ```rust,ignore /// use burn_dataset::transform::{ /// SamplerDataset, /// SamplerDatasetOptions, /// }; /// /// // Examples below assuming `dataset.len()` = `10`. /// /// // sample size: 5 /// // WithReplacement /// // rng: StdRng::from_os_rng() /// SamplerDataset::new(dataset, 5); /// /// // sample size: 10 (source) /// // WithReplacement /// // rng: StdRng::from_os_rng() /// SamplerDataset::new(dataset, SamplerDatasetOptions::default()); /// /// // sample size: 15 /// // WithoutReplacement /// // rng: StdRng::seed_from_u64(42) /// SamplerDataset::new( /// dataset, /// SamplerDatasetOptions::default() /// .with_size(1.5) /// .without_replacement() /// .with_rng(42), /// ); /// ``` pub fn new(dataset: D, options: O) -> Self where O: Into, { let options = options.into(); let size = options.size_config.resolve(dataset.len()); let rng = options.rng_source.into(); Self { dataset, size, state: Mutex::new(match options.replace_samples { true => SamplerState::WithReplacement(rng), false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)), }), input: PhantomData, } } /// Creates a new sampler dataset with replacement. /// /// # Arguments /// /// - `dataset`: the dataset to wrap. /// - `size`: the effective size of the sampled dataset. pub fn with_replacement(dataset: D, size: usize) -> Self { Self::new( dataset, SamplerDatasetOptions::default() .with_replacement() .with_fixed_size(size), ) } /// Creates a new sampler dataset without replacement. /// /// When the sample size is less than or equal to the source dataset size, /// data will be sampled without replacement from the source dataset in /// a uniformly shuffled order. /// /// When the sample size is greater than the source dataset size, /// the entire source dataset will be sampled once for every multiple /// of the size ratios; with the remaining samples taken without replacement /// uniformly from the source. All samples will be returned uniformly shuffled. /// /// # Arguments /// - `dataset`: the dataset to wrap. /// - `size`: the effective size of the sampled dataset. pub fn without_replacement(dataset: D, size: usize) -> Self { Self::new( dataset, SamplerDatasetOptions::default() .without_replacement() .with_fixed_size(size), ) } /// Determines if the sampler is using the "with replacement" strategy. /// /// # Returns /// - `true`: If the sampler is configured to sample with replacement. /// - `false`: If the sampler is configured to sample without replacement. pub fn is_with_replacement(&self) -> bool { match self.state.lock().unwrap().deref_mut() { SamplerState::WithReplacement(_) => true, SamplerState::WithoutReplacement(_, _) => false, } } fn index(&self) -> usize { match self.state.lock().unwrap().deref_mut() { SamplerState::WithReplacement(rng) => { rng.sample(Uniform::new(0, self.dataset.len()).unwrap()) } SamplerState::WithoutReplacement(rng, indices) => { if indices.is_empty() { // Refill the state. let idx_range = 0..self.dataset.len(); for _ in 0..(self.size / self.dataset.len()) { // No need to `.choose_multiple` here because we're using // the entire source range; and `.choose_multiple` will // not return a random sample anyway. indices.extend(idx_range.clone()) } // From `choose_multiple` documentation: // > Although the elements are selected randomly, the order of elements in // > the buffer is neither stable nor fully random. If random ordering is // > desired, shuffle the result. indices.extend(idx_range.sample(rng, self.size - indices.len())); // The real shuffling is done here. indices.shuffle(rng); } indices.pop().expect("Indices are refilled when empty.") } } } } impl Dataset for SamplerDataset where D: Dataset, I: Send + Sync, { fn get(&self, index: usize) -> Option { if index >= self.size { return None; } self.dataset.get(self.index()) } fn len(&self) -> usize { self.size } } #[cfg(test)] mod tests { #![allow(clippy::bool_assert_comparison)] use super::*; use crate::FakeDataset; use rand::SeedableRng; use std::collections::HashMap; #[test] fn test_samplerdataset_options() { let options = SamplerDatasetOptions::default(); assert_eq!(options.replace_samples, true); assert_eq!(options.size_config, SizeConfig::Default); assert_eq!(options.rng_source, RngSource::Default); // ReplacementMode let options = options.with_replace_samples(false); assert_eq!(options.replace_samples, false); let options = options.with_replacement(); assert_eq!(options.replace_samples, true); let options = options.without_replacement(); assert_eq!(options.replace_samples, false); // SourceSize let options = options.with_size(SizeConfig::Default); assert_eq!(options.size_config, SizeConfig::Default); let options = options.with_source_size(); assert_eq!(options.size_config, SizeConfig::Default); let options = options.with_fixed_size(10); assert_eq!(options.size_config, SizeConfig::Fixed(10)); let options = options.with_size_ratio(1.5); assert_eq!(options.size_config, SizeConfig::Ratio(1.5)); // RngSource let options = options.with_system_rng(); assert_eq!(options.rng_source, RngSource::Default); let options = options.with_seed(42); assert_eq!(options.rng_source, RngSource::Seed(42)); let rng = StdRng::seed_from_u64(9); let options = options.with_rng(rng); assert!(matches!(options.rng_source, RngSource::Rng(_))); } #[test] fn sampler_dataset_constructors_test() { let ds = SamplerDataset::new(FakeDataset::::new(10), 15); assert_eq!(ds.len(), 15); assert_eq!(ds.dataset.len(), 10); assert!(ds.is_with_replacement()); let ds = SamplerDataset::with_replacement(FakeDataset::::new(10), 15); assert_eq!(ds.len(), 15); assert_eq!(ds.dataset.len(), 10); assert!(ds.is_with_replacement()); let ds = SamplerDataset::without_replacement(FakeDataset::::new(10), 15); assert_eq!(ds.len(), 15); assert_eq!(ds.dataset.len(), 10); assert!(!ds.is_with_replacement()); } #[test] fn sampler_dataset_with_replacement_iter() { let factor = 3; let len_original = 10; let dataset_sampler = SamplerDataset::with_replacement( FakeDataset::::new(len_original), len_original * factor, ); let mut total = 0; for _item in dataset_sampler.iter() { total += 1; } assert_eq!(total, factor * len_original); } #[test] fn sampler_dataset_without_replacement_bucket_test() { let factor = 3; let len_original = 10; let dataset_sampler = SamplerDataset::new( FakeDataset::::new(len_original), SamplerDatasetOptions::default() .without_replacement() .with_size_ratio(factor as f64), ); let mut buckets = HashMap::new(); for item in dataset_sampler.iter() { let count = match buckets.get(&item) { Some(count) => count + 1, None => 1, }; buckets.insert(item, count); } let mut total = 0; for count in buckets.into_values() { assert_eq!(count, factor); total += count; } assert_eq!(total, factor * len_original); } #[test] fn sampler_dataset_without_replacement_uniform_order_test() { // This is a reversion test on the indices.shuffle(rng) call in SamplerDataset::index(). let size = 1000; let dataset_sampler = SamplerDataset::without_replacement(FakeDataset::::new(size), size); let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect(); let mean_delta = indices .windows(2) .map(|pair| pair[1].abs_diff(pair[0])) .sum::() as f64 / (size - 1) as f64; let expected = (size + 2) as f64 / 3.0; assert!( (mean_delta - expected).abs() <= 0.25 * expected, "Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}" ); } } ================================================ FILE: crates/burn-dataset/src/transform/selection.rs ================================================ use crate::Dataset; use crate::transform::RngSource; use rand::prelude::SliceRandom; use rand::rngs::StdRng; use std::marker::PhantomData; use std::sync::Arc; /// Generates a vector of indices from 0 to size - 1. /// /// # Arguments /// /// * `size` - The size of the dataset. /// /// # Returns /// /// A vector containing indices from 0 to size - 1. #[inline(always)] pub fn iota(size: usize) -> Vec { (0..size).collect() } /// Generates a shuffled vector of indices up to a size. /// /// # Arguments /// /// * `size` - The size of the dataset to shuffle. /// /// # Returns /// /// A vector of shuffled indices. #[inline(always)] pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec { let mut indices = iota(size); indices.shuffle(rng); indices } /// A dataset that selects a subset of indices from an existing dataset. /// /// Indices may appear multiple times, but they must be within the bounds of the original dataset. #[derive(Clone)] pub struct SelectionDataset where D: Dataset, I: Clone + Send + Sync, { /// The wrapped dataset from which to select indices. pub wrapped: Arc, /// The indices to select from the wrapped dataset. pub indices: Vec, input: PhantomData, } impl SelectionDataset where D: Dataset, I: Clone + Send + Sync, { /// Creates a new selection dataset with the given dataset and indices. /// /// Checks that all indices are within the bounds of the dataset. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// * `indices` - A slice of indices to select from the dataset. /// These indices must be within the bounds of the dataset. /// /// # Panics /// /// Panics if any index is out of bounds for the dataset. pub fn from_indices_checked(dataset: S, indices: Vec) -> Self where S: Into>, { let dataset = dataset.into(); let size = dataset.len(); if let Some(idx) = indices.iter().find(|&i| *i >= size) { panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}"); } Self::from_indices_unchecked(dataset, indices) } /// Creates a new selection dataset with the given dataset and indices without checking bounds. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// * `indices` - A vector of indices to select from the dataset. /// /// # Safety /// /// This function does not check if the indices are within the bounds of the dataset. pub fn from_indices_unchecked(dataset: S, indices: Vec) -> Self where S: Into>, { Self { wrapped: dataset.into(), indices, input: PhantomData, } } /// Creates a new selection dataset that selects all indices from the dataset. /// /// This allocates a 1-to-1 mapping of indices to the dataset size, /// essentially functioning as a no-op selection. This is only useful /// when the dataset will later be shuffled or transformed in place. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// /// # Returns /// /// A new `SelectionDataset` that selects all indices from the dataset. pub fn new_select_all(dataset: S) -> Self where S: Into>, { let dataset = dataset.into(); let size = dataset.len(); Self::from_indices_unchecked(dataset, iota(size)) } /// Creates a new selection dataset with shuffled indices. /// /// Selects every index of the dataset and shuffles them /// with randomness from the provided random number generator. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// * `rng` - A mutable reference to a random number generator. /// /// # Returns /// /// A new `SelectionDataset` with shuffled indices. pub fn new_shuffled(dataset: S, rng_source: R) -> Self where S: Into>, R: Into, { let mut this = Self::new_select_all(dataset); this.shuffle(rng_source); this } /// Shuffles the indices of the dataset using a mutable random number generator. /// /// This method modifies the dataset in place, shuffling the indices. /// /// # Arguments /// /// * `rng` - A mutable reference to a random number generator. pub fn shuffle(&mut self, rng_source: R) where R: Into, { let mut rng: StdRng = rng_source.into().into(); self.indices.shuffle(&mut rng) } /// Creates a new dataset that is a slice of the current selection dataset. /// /// Slices the *selection indices* from ``[start..end]``. /// /// Independent of future shuffles on the parent, but shares the same wrapped dataset. /// /// /// # Arguments /// /// * `start` - The start of the range. /// * `end` - The end of the range (exclusive). // TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg. pub fn slice(&self, start: usize, end: usize) -> Self { Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec()) } /// Split into `num` datasets by slicing the selection indices evenly. /// /// Split is done via `slice`, so the datasets share the same wrapped dataset. /// /// Independent of future shuffles on the parent, but shares the same wrapped dataset. /// /// # Arguments /// /// * `num` - The number of datasets to split into. /// /// # Returns /// /// A vector of `SelectionDataset` instances, each containing a subset of the indices. pub fn split(&self, num: usize) -> Vec { let n = self.indices.len(); let mut current = 0; let mut datasets = Vec::with_capacity(num); let batch_size = n / num; for i in 0..num { let start = current; let mut end = current + batch_size; if i == (num - 1) { end = n; } let dataset = self.slice(start, end); current += batch_size; datasets.push(dataset); } datasets } } impl Dataset for SelectionDataset where D: Dataset, I: Clone + Send + Sync, { fn get(&self, index: usize) -> Option { let index = self.indices.get(index)?; self.wrapped.get(*index) } fn len(&self) -> usize { self.indices.len() } } #[cfg(test)] mod tests { use super::*; use crate::FakeDataset; use rand::SeedableRng; #[test] fn test_iota() { let size = 10; let indices = iota(size); assert_eq!(indices.len(), size); assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); } #[test] fn test_shuffled_indices_same_seed_is_deterministic() { let size = 10; let mut rng1 = StdRng::seed_from_u64(10); // `StdRng` is no longer `Clone`, so its internal state cannot be duplicated. // To test determinism, we must explicitly create a second RNG from the same seed. let mut rng2 = StdRng::seed_from_u64(10); let mut expected = iota(size); expected.shuffle(&mut rng1); let indices = shuffled_indices(size, &mut rng2); assert_eq!(indices, expected); } #[test] fn test_shuffled_indices_forked_rngs_differ() { let size = 10; let mut rng1 = StdRng::seed_from_u64(10); let mut rng2 = rng1.fork(); let mut a = iota(size); let mut b = iota(size); a.shuffle(&mut rng1); b.shuffle(&mut rng2); assert_ne!(a, b); } #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")] #[test] fn test_from_indices_checked_panics() { let source_dataset = FakeDataset::::new(27); let indices: Vec = vec![15, 1, 12, 300]; SelectionDataset::from_indices_checked(source_dataset, indices); } #[test] fn test_checked_selection_dataset() { let source_dataset = FakeDataset::::new(27); let indices: Vec = vec![15, 1, 12, 12]; let expected: Vec = indices .iter() .map(|i| source_dataset.get(*i).unwrap()) .collect(); let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone()); assert_eq!(&selection.indices, &indices); let items = selection.iter().collect::>(); assert_eq!(items, expected); } #[test] fn test_shuffled_dataset() { let dataset = FakeDataset::::new(27); let source_items = dataset.iter().collect::>(); let selection = SelectionDataset::new_shuffled(dataset, 42); let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(42)); assert_eq!(&selection.indices, &indices); assert_eq!(selection.len(), source_items.len()); let expected_items: Vec<_> = indices .iter() .map(|&i| source_items[i].to_string()) .collect(); assert_eq!(&selection.iter().collect::>(), &expected_items); } #[test] fn test_slice() { let dataset = FakeDataset::::new(27); let source_items = dataset.iter().collect::>(); let selection = SelectionDataset::new_select_all(dataset); let start = 5; let end = 15; let sliced_selection = selection.slice(start, end); assert_eq!(sliced_selection.len(), end - start); #[allow(clippy::needless_range_loop)] for i in start..end { assert_eq!( sliced_selection.get(i - start), Some(source_items[i].to_string()) ); } } #[test] fn test_split() { let dataset = FakeDataset::::new(28); let source_items = dataset.iter().collect::>(); let selection = SelectionDataset::new_select_all(dataset); let split_contents: Vec> = selection .split(3) .iter() .map(|d| d.iter().collect::>()) .collect(); assert_eq!( split_contents, vec![ source_items[0..9].to_vec(), source_items[9..18].to_vec(), source_items[18..28].to_vec(), ] ); } } ================================================ FILE: crates/burn-dataset/src/transform/shuffle.rs ================================================ use crate::Dataset; use crate::transform::{RngSource, SelectionDataset}; /// A Shuffled a dataset. /// /// This is a thin wrapper around a [SelectionDataset] which selects and shuffles /// the full indices of the original dataset. /// /// Consider using [SelectionDataset] if you are only interested in /// shuffling mechanisms. /// /// Consider using [sampler dataset](crate::transform::SamplerDataset) if you /// want a probability distribution which is computed lazily. pub struct ShuffledDataset where D: Dataset, I: Clone + Send + Sync, { wrapped: SelectionDataset, } impl ShuffledDataset where D: Dataset, I: Clone + Send + Sync, { /// Creates a new selection dataset with shuffled indices. /// /// This is a thin wrapper around `SelectionDataset::new_shuffled`. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// * `rng_source` - The source of the random number generator. /// /// # Returns /// /// A new `ShuffledDataset`. pub fn new(dataset: D, rng_source: R) -> Self where R: Into, { Self { wrapped: SelectionDataset::new_shuffled(dataset, rng_source), } } /// Creates a new selection dataset with shuffled indices using a fixed seed. /// /// This is a thin wrapper around `SelectionDataset::new_shuffled_with_seed`. /// /// # Arguments /// /// * `dataset` - The original dataset to select from. /// * `seed` - A fixed seed for the random number generator. /// /// # Returns /// /// A new `ShuffledDataset`. #[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead`")] pub fn with_seed(dataset: D, seed: u64) -> Self { Self::new(dataset, seed) } } impl Dataset for ShuffledDataset where D: Dataset, I: Clone + Send + Sync, { fn get(&self, index: usize) -> Option { self.wrapped.get(index) } fn len(&self) -> usize { self.wrapped.len() } } #[cfg(test)] mod tests { use super::*; use crate::FakeDataset; use crate::transform::selection::shuffled_indices; use rand::SeedableRng; use rand::prelude::StdRng; #[test] fn test_shuffled_dataset() { let dataset = FakeDataset::::new(27); let source_items = dataset.iter().collect::>(); let seed = 42; #[allow(deprecated)] let shuffled = ShuffledDataset::with_seed(dataset, seed); let mut rng = StdRng::seed_from_u64(seed); let indices = shuffled_indices(source_items.len(), &mut rng); assert_eq!(shuffled.len(), source_items.len()); let expected_items: Vec<_> = indices .iter() .map(|&i| source_items[i].to_string()) .collect(); assert_eq!(&shuffled.iter().collect::>(), &expected_items); } } ================================================ FILE: crates/burn-dataset/src/transform/window.rs ================================================ use std::{cmp::max, marker::PhantomData, num::NonZeroUsize}; use crate::Dataset; /// Functionality to create a window. pub trait Window { /// Creates a window of a collection. /// /// # Returns /// /// A `Vec` representing the window. fn window(&self, current: usize, size: NonZeroUsize) -> Option>; } impl + ?Sized> Window for T { fn window(&self, current: usize, size: NonZeroUsize) -> Option> { (current..current + size.get()) .map(|x| self.get(x)) .collect() } } /// Functionality to create a `WindowsIterator`. pub trait Windows { /// Creates and returns an iterator over all the windows of length `size`. fn windows(&self, size: usize) -> WindowsIterator<'_, I>; } impl> Windows for T { /// Is empty if the `Dataset` is shorter than `size`. /// /// # Panics /// /// Panics if `size` is 0. /// /// # Examples /// /// ``` /// use crate::burn_dataset::{ /// transform::{Windows, WindowsDataset}, /// Dataset, InMemDataset, /// }; /// /// let items = [1, 2, 3, 4].to_vec(); /// let dataset = InMemDataset::new(items.clone()); /// /// for window in dataset.windows(2) { /// // do sth with window /// } /// ``` fn windows(&self, size: usize) -> WindowsIterator<'_, I> { let size = NonZeroUsize::new(size).expect("window size must be non-zero"); WindowsIterator::new(self, size) } } /// Overlapping windows iterator. pub struct WindowsIterator<'a, I> { /// The size of the windows. pub size: NonZeroUsize, current: usize, dataset: &'a dyn Dataset, } impl<'a, I> WindowsIterator<'a, I> { /// Creates a new `WindowsIterator` instance. The windows overlap. /// Is empty if the input `Dataset` is shorter than `size`. /// /// # Parameters /// /// - `dataset`: The dataset over which windows will be created. /// - `size`: The size of the windows. pub fn new(dataset: &'a dyn Dataset, size: NonZeroUsize) -> Self { WindowsIterator { current: 0, dataset, size, } } } impl Iterator for WindowsIterator<'_, I> { type Item = Vec; fn next(&mut self) -> Option> { self.current += 1; self.dataset.window(self.current - 1, self.size) } } impl Clone for WindowsIterator<'_, I> { fn clone(&self) -> Self { WindowsIterator { size: self.size, dataset: self.dataset, current: self.current, } } } /// Dataset designed to work with overlapping windows of data. pub struct WindowsDataset { /// The size of the windows. pub size: NonZeroUsize, dataset: D, input: PhantomData, } impl WindowsDataset where D: Dataset, { /// Creates a new `WindowsDataset` instance. The windows overlap. /// Is empty if the input `Dataset` is shorter than `size`. /// /// # Parameters /// /// - `dataset`: The dataset over which windows will be created. /// - `size`: The size of the windows. pub fn new(dataset: D, size: usize) -> Self where D:, { let size = NonZeroUsize::new(size).expect("window size must be non-zero"); WindowsDataset:: { size, dataset, input: PhantomData, } } } impl Dataset> for WindowsDataset where D: Dataset, I: Send + Sync, { /// Retrieves a window of items from the dataset. /// /// # Parameters /// /// - `index`: The index of the window. /// /// # Returns /// /// A vector representing the window. fn get(&self, index: usize) -> Option> { self.dataset.window(index, self.size) } /// Retrieves the number of windows in the dataset. /// /// # Returns /// /// A size representing the number of windows. fn len(&self) -> usize { let len = self.dataset.len() as isize - self.size.get() as isize + 1; max(len, 0) as usize } } #[cfg(test)] mod tests { use rstest::rstest; use crate::{ Dataset, InMemDataset, transform::{Windows, WindowsDataset}, }; #[rstest] pub fn windows_should_be_equal_to_vec_windows() { let items = [1, 2, 3, 4, 5].to_vec(); let dataset = InMemDataset::new(items.clone()); let expected = items .windows(3) .map(|x| x.to_vec()) .collect::>>(); let result = dataset.windows(3).collect::>>(); assert_eq!(result, expected); } #[rstest] pub fn windows_dataset_should_be_equal_to_vec_windows() { let items = [1, 2, 3, 4, 5].to_vec(); let dataset = InMemDataset::new(items.clone()); let expected = items .windows(3) .map(|x| x.to_vec()) .collect::>>(); let result = WindowsDataset::new(dataset, 3) .iter() .collect::>>(); assert_eq!(result, expected); } #[rstest] pub fn cloned_iterator_should_be_equal() { let items = [1, 2, 3, 4, 5].to_vec(); let dataset = InMemDataset::new(items.clone()); let original = dataset.windows(4); let cloned = original.clone(); assert!(std::ptr::eq(cloned.dataset, original.dataset)); assert_eq!(cloned.size, original.size); assert_eq!(cloned.current, original.current); } #[rstest] pub fn cloned_iterator_should_be_unaffected() { let items = [1, 2, 3, 4, 5].to_vec(); let dataset = InMemDataset::new(items.clone()); let mut original = dataset.windows(4); let cloned = original.clone(); original.current = 2; assert_ne!(cloned.current, original.current); } #[rstest] #[should_panic(expected = "window size must be non-zero")] pub fn windows_should_panic() { let items = [1, 2].to_vec(); let dataset = InMemDataset::new(items.clone()); dataset.windows(0); } #[rstest] #[should_panic(expected = "window size must be non-zero")] pub fn new_window_dataset_should_panic() { let items = [1, 2].to_vec(); let dataset = InMemDataset::new(items.clone()); WindowsDataset::new(dataset, 0); } #[rstest] pub fn window_dataset_len_should_be_equal() { let dataset = InMemDataset::new([1, 2, 3, 4].to_vec()); let result = WindowsDataset::new(dataset, 2).len(); assert_eq!(result, 3); } #[rstest] pub fn window_iterator_should_be_empty() { let dataset = InMemDataset::new([1, 2].to_vec()); let mut peekable = dataset.windows(4).peekable(); let result = peekable.peek(); assert_eq!(result, None); } #[rstest] pub fn window_dataset_len_should_be_zero() { let dataset = InMemDataset::new([1, 2].to_vec()); let result = WindowsDataset::new(dataset, 4).len(); assert_eq!(result, 0); } #[rstest] pub fn window_dataset_get_should_be_equal() { let dataset = InMemDataset::new([1, 2, 3, 4].to_vec()); let expected = Some([1, 2, 3].to_vec()); let result = WindowsDataset::new(dataset, 3).get(0); assert_eq!(result, expected); } #[rstest] pub fn window_dataset_get_should_be_none() { let dataset = InMemDataset::new([1, 2].to_vec()); let result = WindowsDataset::new(dataset, 4).get(0); assert_eq!(result, None); } } ================================================ FILE: crates/burn-dataset/src/vision/cifar.rs ================================================ //! CIFAR Dataset Module //! //! This module provides functionality for loading the CIFAR-10 and CIFAR-100 image classification datasets. //! CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks in computer vision, //! consisting of 32×32 pixel color images split into training (50,000 images) and test (10,000 images) sets. //! //! ## Dataset Variants //! - **CIFAR-10**: Contains 10 distinct classes (e.g., airplane, automobile, bird, cat) //! - CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44). //! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). //! - **CIFAR-100**: Contains 100 fine-grained classes (e.g., beaver, dolphin, oak tree) //! - CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75). //! - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). //! //! ## Usage Example //! ```rust //! use burn_dataset::vision::CifarDataset; //! use burn_dataset::vision::CifarType; //! //! // Create a CIFAR-10 dataset accessor //! let dataset = CifarDataset::new(CifarType::Cifar10); //! //! // Access training and test sets //! let train_dataset = dataset.train(); //! let test_dataset = dataset.test(); //! ``` //! ```rust //! use burn_dataset::vision::CifarDataset; //! use burn_dataset::vision::CifarType; //! //! // Create a CIFAR-100 dataset accessor //! let dataset = CifarDataset::new(CifarType::Cifar100); //! //! // Access training and test sets //! let train_dataset = dataset.train(); //! let test_dataset = dataset.test(); //! ``` use std::{path::PathBuf, sync::Mutex}; use flate2::read::GzDecoder; use tar::Archive; use crate::network::downloader; use crate::vision::ImageFolderDataset; /// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44). /// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). const CIFAR10_URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz"; /// CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75). /// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). const CIFAR100_URL: &str = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz"; /// Enum representing the types of CIFAR datasets available. /// /// CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks for image classification. /// This enum provides support for the two main CIFAR datasets. #[derive(Debug, Clone, Copy)] #[allow(dead_code)] pub enum CifarType { /// CIFAR-10 dataset containing 10 classes with 60,000 images in total. Cifar10, /// CIFAR-100 dataset containing 100 classes with 60,000 images in total. Cifar100, } /// CIFAR dataset accessor. /// /// This struct provides convenient access to the CIFAR-10 and CIFAR-100 image classification datasets. /// It automatically downloads (if not already downloaded), extracts, and loads the datasets. /// /// All images in CIFAR datasets are 32×32 pixel color images, with 50,000 images in the training set /// and 10,000 images in the test set. /// /// ## Differences between datasets /// - **CIFAR-10**: Contains 10 mutually exclusive classes such as airplane, automobile, bird, cat, etc. /// - **CIFAR-100**: Contains 100 fine-grained classes such as beaver, dolphin, etc. pub struct CifarDataset { cifar_dir: PathBuf, } impl CifarDataset { /// Creates a new CIFAR dataset accessor. /// /// # Arguments /// * `cifar_type` - Specifies whether to use CIFAR-10 or CIFAR-100 dataset pub fn new(cifar_type: CifarType) -> Self { Self { cifar_dir: download(&cifar_type), } } /// Gets the training dataset. /// /// # Returns /// An `ImageFolderDataset` instance containing 50,000 training images pub fn train(&self) -> ImageFolderDataset { ImageFolderDataset::new_classification(self.cifar_dir.join("train")).unwrap() } /// Gets the test dataset. /// /// # Returns /// An `ImageFolderDataset` instance containing 10,000 test images pub fn test(&self) -> ImageFolderDataset { ImageFolderDataset::new_classification(self.cifar_dir.join("test")).unwrap() } } /// CIFAR dataset download lock. /// /// This lock ensures that only one thread downloads the CIFAR dataset at a time. static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(()); fn download(cifar_type: &CifarType) -> PathBuf { // Acquire the lock. This will block if another thread already holds the lock. let _lock = DOWNLOAD_LOCK.lock().unwrap(); // Dataset files are stored in the burn-dataset cache directory let cache_dir = dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset"); // Cifar store directory let cifar_dir = match cifar_type { CifarType::Cifar10 => cache_dir.join("cifar10"), CifarType::Cifar100 => cache_dir.join("cifar100"), }; // Cifar dataset url let url = match cifar_type { CifarType::Cifar10 => CIFAR10_URL, CifarType::Cifar100 => CIFAR100_URL, }; // Cifar dataset archive filename let filename = match cifar_type { CifarType::Cifar10 => "cifar10.tgz", CifarType::Cifar100 => "cifar100.tgz", }; // Check for already downloaded content if !cifar_dir.exists() { // Download gzip file let bytes = downloader::download_file_as_bytes(url, filename); // Decode gzip file content and unpack archive let gz_buffer = GzDecoder::new(&bytes[..]); let mut archive = Archive::new(gz_buffer); archive.unpack(cache_dir).unwrap(); } cifar_dir } #[cfg(test)] mod tests { use super::*; use crate::{Dataset, vision::Annotation}; /// CIFAR dataset length const TRAINDATASET_LEN: usize = 50000; const TESTDATASET_LEN: usize = 10000; /// CIFAR-10 label range const CIFAR10_LABEL_MIN: usize = 0; const CIFAR10_LABEL_MAX: usize = 9; /// CIFAR-100 label range const CIFAR100_LABEL_MIN: usize = 0; const CIFAR100_LABEL_MAX: usize = 99; #[test] fn test_cifar10_download() { let cifar_dir = download(&CifarType::Cifar10); assert!(cifar_dir.exists()); } #[test] fn test_cifar100_download() { let cifar_dir = download(&CifarType::Cifar100); assert!(cifar_dir.exists()); } #[test] fn test_cifar10_len() { let dataset = CifarDataset::new(CifarType::Cifar10); let train_dataset = dataset.train(); let test_dataset = dataset.test(); assert_eq!(train_dataset.len(), TRAINDATASET_LEN); assert_eq!(test_dataset.len(), TESTDATASET_LEN); } #[test] fn test_cifar100_len() { let dataset = CifarDataset::new(CifarType::Cifar100); let train_dataset = dataset.train(); let test_dataset = dataset.test(); assert_eq!(train_dataset.len(), TRAINDATASET_LEN); assert_eq!(test_dataset.len(), TESTDATASET_LEN); } #[test] fn test_cifar10_label_range() { let dataset = CifarDataset::new(CifarType::Cifar10); let test_dataset = dataset.test(); let (min, max) = get_label_range(&test_dataset); assert_eq!(min, CIFAR10_LABEL_MIN); assert_eq!(max, CIFAR10_LABEL_MAX); } #[test] fn test_cifar100_label_range() { let dataset = CifarDataset::new(CifarType::Cifar100); let test_dataset = dataset.test(); let (min, max) = get_label_range(&test_dataset); assert_eq!(min, CIFAR100_LABEL_MIN); assert_eq!(max, CIFAR100_LABEL_MAX); } fn get_label_range(dataset: &ImageFolderDataset) -> (usize, usize) { let labels: Vec<_> = dataset.iter().map(|item| item.annotation).collect(); let mut min = 128; let mut max = 0; for label in labels { let index = match label { Annotation::Label(index) => index, _ => 0, }; if index < min { min = index; } if index > max { max = index; } } (min, max) } } ================================================ FILE: crates/burn-dataset/src/vision/image_folder.rs ================================================ use crate::transform::{Mapper, MapperDataset}; use crate::{Dataset, InMemDataset}; use globwalk::{self, DirEntry}; use image::{self, ColorType}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::{HashMap, HashSet}; use std::fs; use std::path::{Path, PathBuf}; use thiserror::Error; const SUPPORTED_FILES: [&str; 4] = ["bmp", "jpg", "jpeg", "png"]; const BBOX_MIN_NUM_VALUES: usize = 4; /// Image data type. #[derive(Debug, Copy, Clone, PartialEq)] pub enum PixelDepth { /// 8-bit unsigned. U8(u8), /// 16-bit unsigned. U16(u16), /// 32-bit floating point. F32(f32), } impl TryFrom for u8 { type Error = &'static str; fn try_from(value: PixelDepth) -> Result { if let PixelDepth::U8(v) = value { Ok(v) } else { Err("Value is not u8") } } } impl TryFrom for u16 { type Error = &'static str; fn try_from(value: PixelDepth) -> Result { if let PixelDepth::U16(v) = value { Ok(v) } else { Err("Value is not u16") } } } impl TryFrom for f32 { type Error = &'static str; fn try_from(value: PixelDepth) -> Result { if let PixelDepth::F32(v) = value { Ok(v) } else { Err("Value is not f32") } } } /// Annotation type for different tasks. #[derive(Debug, Clone, PartialEq)] pub enum Annotation { /// Image-level label. Label(usize), /// Multiple image-level labels. MultiLabel(Vec), /// Object bounding boxes. BoundingBoxes(Vec), /// Segmentation mask. SegmentationMask(SegmentationMask), } /// Segmentation mask annotation. /// For semantic segmentation, a mask has a single channel (C = 1). /// For instance segmentation, there may be multiple masks per image (C >= 1). #[derive(Debug, Clone, PartialEq)] pub struct SegmentationMask { /// Segmentation mask. pub mask: Vec, } /// Object detection bounding box annotation. #[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] pub struct BoundingBox { /// Coordinates in [x_min, y_min, width, height] format. pub coords: [f32; 4], /// Box class label. pub label: usize, } /// Image dataset item. #[derive(Debug, Clone, PartialEq)] pub struct ImageDatasetItem { /// Image as a vector with a valid image type. pub image: Vec, /// Original source image width. pub image_width: usize, /// Original source image height. pub image_height: usize, /// Annotation for the image. pub annotation: Annotation, /// Original image source. pub image_path: String, } /// Raw annotation types. #[derive(Deserialize, Serialize, Debug, Clone)] enum AnnotationRaw { Label(String), MultiLabel(Vec), BoundingBoxes(Vec), SegmentationMask(PathBuf), } #[derive(Deserialize, Serialize, Debug, Clone)] struct ImageDatasetItemRaw { /// Image path. image_path: PathBuf, /// Image annotation. annotation: AnnotationRaw, } impl ImageDatasetItemRaw { fn new>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw { ImageDatasetItemRaw { image_path: image_path.as_ref().to_path_buf(), annotation, } } } struct PathToImageDatasetItem { classes: HashMap, } fn segmentation_mask_to_vec_usize(mask_path: &PathBuf) -> Vec { // Load image from disk let image = image::open(mask_path).unwrap(); // Image as Vec // if rgb8 or rgb16, keep only the first channel assuming all channels are the same match image.color() { ColorType::L8 => image.into_luma8().iter().map(|&x| x as usize).collect(), ColorType::L16 => image.into_luma16().iter().map(|&x| x as usize).collect(), ColorType::Rgb8 => image .into_rgb8() .iter() .step_by(3) .map(|&x| x as usize) .collect(), ColorType::Rgb16 => image .into_rgb16() .iter() .step_by(3) .map(|&x| x as usize) .collect(), _ => panic!("Unrecognized image color type"), } } /// Parse the image annotation to the corresponding type. fn parse_image_annotation( annotation: &AnnotationRaw, classes: &HashMap, ) -> Annotation { // TODO: add support for other annotations // - [ ] Object bounding boxes // - [x] Segmentation mask // For now, only image classification labels and segmentation are supported. // Map class string to label id match annotation { AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()), AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel( names .iter() .map(|name| *classes.get(name).unwrap()) .collect(), ), AnnotationRaw::SegmentationMask(mask_path) => { Annotation::SegmentationMask(SegmentationMask { mask: segmentation_mask_to_vec_usize(mask_path), }) } AnnotationRaw::BoundingBoxes(v) => Annotation::BoundingBoxes(v.clone()), } } /// Retrieve all available classes from the COCO JSON fn parse_coco_classes( json: &serde_json::Value, ) -> Result, ImageLoaderError> { let mut classes = HashMap::new(); if let Some(json_classes) = json["categories"].as_array() { for class in json_classes { let id = class["id"] .as_u64() .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class ID".to_string())) .and_then(|v| { usize::try_from(v).map_err(|_| { ImageLoaderError::ParsingError("Class ID out of usize range".to_string()) }) })?; let name = class["name"] .as_str() .filter(|&s| !s.is_empty()) .ok_or_else(|| ImageLoaderError::ParsingError("Invalid class name".to_string()))? .to_string(); classes.insert(name, id); } } if classes.is_empty() { return Err(ImageLoaderError::ParsingError( "No classes found in annotations".to_string(), )); } Ok(classes) } /// Retrieve annotations from COCO JSON fn parse_coco_bbox_annotations( json: &serde_json::Value, ) -> Result, ImageLoaderError> { let mut annotations = HashMap::new(); if let Some(json_annotations) = json["annotations"].as_array() { for annotation in json_annotations { let image_id = annotation["image_id"].as_u64().ok_or_else(|| { ImageLoaderError::ParsingError("Invalid image ID in annotation".into()) })?; let class_id = annotation["category_id"] .as_u64() .ok_or_else(|| { ImageLoaderError::ParsingError("Invalid class ID in annotations".to_string()) }) .and_then(|v| { usize::try_from(v).map_err(|_| { ImageLoaderError::ParsingError( "Class ID in annotations out of usize range".to_string(), ) }) })?; let bbox_coords = annotation["bbox"] .as_array() .ok_or_else(|| ImageLoaderError::ParsingError("missing bbox array".to_string()))? .iter() .map(|v| { v.as_f64() .ok_or_else(|| { ImageLoaderError::ParsingError("invalid bbox value".to_string()) }) .map(|val| val as f32) }) .collect::, _>>()?; if bbox_coords.len() < BBOX_MIN_NUM_VALUES { return Err(ImageLoaderError::ParsingError(format!( "not enough bounding box coordinates in annotation for image {image_id}", ))); } let bbox = BoundingBox { coords: [ bbox_coords[0], bbox_coords[1], bbox_coords[2], bbox_coords[3], ], label: class_id, }; annotations .entry(image_id) .and_modify(|entry| { if let AnnotationRaw::BoundingBoxes(bboxes) = entry { bboxes.push(bbox.clone()); } }) .or_insert_with(|| AnnotationRaw::BoundingBoxes(vec![bbox])); } } if annotations.is_empty() { return Err(ImageLoaderError::ParsingError( "no annotations found".to_string(), )); } Ok(annotations) } /// Retrieve all available images from the COCO JSON fn parse_coco_images>( images_path: &P, mut annotations: HashMap, json: &serde_json::Value, ) -> Result, ImageLoaderError> { let mut images = Vec::new(); if let Some(json_images) = json["images"].as_array() { for image in json_images { let image_id = image["id"].as_u64().ok_or_else(|| { ImageLoaderError::ParsingError("Invalid image ID in image list".to_string()) })?; let file_name = image["file_name"] .as_str() .ok_or_else(|| ImageLoaderError::ParsingError("Invalid image ID".to_string()))? .to_string(); let mut image_path = images_path.as_ref().to_path_buf(); image_path.push(file_name); if !image_path.exists() { return Err(ImageLoaderError::IOError(format!( "Image {} not found", image_path.display() ))); } let annotation = annotations .remove(&image_id) .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new())); images.push(ImageDatasetItemRaw { annotation, image_path, }); } } if images.is_empty() { return Err(ImageLoaderError::ParsingError( "No images found in annotations".to_string(), )); } Ok(images) } impl Mapper for PathToImageDatasetItem { /// Convert a raw image dataset item (path-like) to a 3D image array with a target label. fn map(&self, item: &ImageDatasetItemRaw) -> ImageDatasetItem { let annotation = parse_image_annotation(&item.annotation, &self.classes); // Load image from disk let image = image::open(&item.image_path).unwrap(); // Save image dimensions for manipulation let img_width = image.width() as usize; let img_height = image.height() as usize; // Image as Vec let img_vec = match image.color() { ColorType::L8 => image .into_luma8() .iter() .map(|&x| PixelDepth::U8(x)) .collect(), ColorType::La8 => image .into_luma_alpha8() .iter() .map(|&x| PixelDepth::U8(x)) .collect(), ColorType::L16 => image .into_luma16() .iter() .map(|&x| PixelDepth::U16(x)) .collect(), ColorType::La16 => image .into_luma_alpha16() .iter() .map(|&x| PixelDepth::U16(x)) .collect(), ColorType::Rgb8 => image .into_rgb8() .iter() .map(|&x| PixelDepth::U8(x)) .collect(), ColorType::Rgba8 => image .into_rgba8() .iter() .map(|&x| PixelDepth::U8(x)) .collect(), ColorType::Rgb16 => image .into_rgb16() .iter() .map(|&x| PixelDepth::U16(x)) .collect(), ColorType::Rgba16 => image .into_rgba16() .iter() .map(|&x| PixelDepth::U16(x)) .collect(), ColorType::Rgb32F => image .into_rgb32f() .iter() .map(|&x| PixelDepth::F32(x)) .collect(), ColorType::Rgba32F => image .into_rgba32f() .iter() .map(|&x| PixelDepth::F32(x)) .collect(), _ => panic!("Unrecognized image color type"), }; ImageDatasetItem { image: img_vec, image_width: img_width, image_height: img_height, annotation, image_path: item.image_path.display().to_string(), } } } /// Error type for [ImageFolderDataset](ImageFolderDataset). #[derive(Error, Debug)] pub enum ImageLoaderError { /// Unknown error. #[error("unknown: `{0}`")] Unknown(String), /// I/O operation error. #[error("I/O error: `{0}`")] IOError(String), /// Invalid file error. #[error("Invalid file extension: `{0}`")] InvalidFileExtensionError(String), /// Parsing error. #[error("Parsing error: `{0}`")] ParsingError(String), } type ImageDatasetMapper = MapperDataset, PathToImageDatasetItem, ImageDatasetItemRaw>; /// A generic dataset to load images from disk. pub struct ImageFolderDataset { dataset: ImageDatasetMapper, } impl Dataset for ImageFolderDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } impl ImageFolderDataset { /// Create an image classification dataset from the root folder. /// /// # Arguments /// /// * `root` - Dataset root folder. /// /// # Returns /// A new dataset instance. pub fn new_classification>(root: P) -> Result { // New dataset containing any of the supported file types ImageFolderDataset::new_classification_with(root, &SUPPORTED_FILES) } /// Create an image classification dataset from the root folder. /// The included images are filtered based on the provided extensions. /// /// # Arguments /// /// * `root` - Dataset root folder. /// * `extensions` - List of allowed extensions. /// /// # Returns /// A new dataset instance. pub fn new_classification_with( root: P, extensions: &[S], ) -> Result where P: AsRef, S: AsRef, { // Glob all images with extensions let walker = globwalk::GlobWalkerBuilder::from_patterns( root.as_ref(), &[format!( "*.{{{}}}", // "*.{ext1,ext2,ext3} extensions .iter() .map(Self::check_extension) .collect::, _>>()? .join(",") )], ) .follow_links(true) .sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path .build() .map_err(|err| ImageLoaderError::Unknown(format!("{err:?}")))? .filter_map(Result::ok); // Get all dataset items let mut items = Vec::new(); let mut classes = HashSet::new(); for img in walker { let image_path = img.path(); // Label name is represented by the parent folder name let label = image_path .parent() .ok_or_else(|| { ImageLoaderError::IOError("Could not resolve image parent folder".to_string()) })? .file_name() .ok_or_else(|| { ImageLoaderError::IOError( "Could not resolve image parent folder name".to_string(), ) })? .to_string_lossy() .into_owned(); classes.insert(label.clone()); items.push(ImageDatasetItemRaw::new( image_path, AnnotationRaw::Label(label), )) } // Sort class names let mut classes = classes.into_iter().collect::>(); classes.sort(); Self::with_items(items, &classes) } /// Create an image classification dataset with the specified items. /// /// # Arguments /// /// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. pub fn new_classification_with_items, S: AsRef>( items: Vec<(P, String)>, classes: &[S], ) -> Result { // Parse items and check valid image extension types let items = items .into_iter() .map(|(path, label)| { // Map image path and label let path = path.as_ref(); let label = AnnotationRaw::Label(label); Self::check_extension(&path.extension().unwrap().to_str().unwrap())?; Ok(ImageDatasetItemRaw::new(path, label)) }) .collect::, _>>()?; Self::with_items(items, classes) } /// Create a multi-label image classification dataset with the specified items. /// /// # Arguments /// /// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. pub fn new_multilabel_classification_with_items, S: AsRef>( items: Vec<(P, Vec)>, classes: &[S], ) -> Result { // Parse items and check valid image extension types let items = items .into_iter() .map(|(path, labels)| { // Map image path and multi-label let path = path.as_ref(); let labels = AnnotationRaw::MultiLabel(labels); Self::check_extension(&path.extension().unwrap().to_str().unwrap())?; Ok(ImageDatasetItemRaw::new(path, labels)) }) .collect::, _>>()?; Self::with_items(items, classes) } /// Create an image segmentation dataset with the specified items. /// /// # Arguments /// /// * `items` - List of dataset items, each item represented by a tuple `(image path, annotation path)`. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. pub fn new_segmentation_with_items, S: AsRef>( items: Vec<(P, P)>, classes: &[S], ) -> Result { // Parse items and check valid image extension types let items = items .into_iter() .map(|(image_path, mask_path)| { // Map image path and segmentation mask path let image_path = image_path.as_ref(); let annotation = AnnotationRaw::SegmentationMask(mask_path.as_ref().to_path_buf()); Self::check_extension(&image_path.extension().unwrap().to_str().unwrap())?; Ok(ImageDatasetItemRaw::new(image_path, annotation)) }) .collect::, _>>()?; Self::with_items(items, classes) } /// Create a COCO detection dataset based on the annotations JSON and image directory. /// /// # Arguments /// /// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for /// example instances_train2017.json). /// /// * `images_path` - Path containing the images matching the annotations JSON. /// /// # Returns /// A new dataset instance. pub fn new_coco_detection, I: AsRef>( annotations_json: A, images_path: I, ) -> Result { let file = fs::File::open(annotations_json) .map_err(|e| ImageLoaderError::IOError(format!("Failed to open annotations: {e}")))?; let json: Value = serde_json::from_reader(file).map_err(|e| { ImageLoaderError::ParsingError(format!("Failed to parse annotations: {e}")) })?; let classes = parse_coco_classes(&json)?; let annotations = parse_coco_bbox_annotations(&json)?; let items = parse_coco_images(&images_path, annotations, &json)?; let dataset = InMemDataset::new(items); let mapper = PathToImageDatasetItem { classes }; let dataset = MapperDataset::new(dataset, mapper); Ok(Self { dataset }) } /// Create an image dataset with the specified items. /// /// # Arguments /// /// * `items` - Raw dataset items. /// * `classes` - Dataset class names. /// /// # Returns /// A new dataset instance. fn with_items>( items: Vec, classes: &[S], ) -> Result { // NOTE: right now we don't need to validate the supported image files since // the method is private. We assume it's already validated. let dataset = InMemDataset::new(items); // Class names to index map let classes = classes.iter().map(|c| c.as_ref()).collect::>(); let classes_map: HashMap<_, _> = classes .into_iter() .enumerate() .map(|(idx, cls)| (cls.to_string(), idx)) .collect(); let mapper = PathToImageDatasetItem { classes: classes_map, }; let dataset = MapperDataset::new(dataset, mapper); Ok(Self { dataset }) } /// Check if extension is supported. fn check_extension>(extension: &S) -> Result { let extension = extension.as_ref(); if !SUPPORTED_FILES.contains(&extension) { Err(ImageLoaderError::InvalidFileExtensionError( extension.to_string(), )) } else { Ok(extension.to_string()) } } } #[cfg(test)] mod tests { use super::*; const DATASET_ROOT: &str = "tests/data/image_folder"; const SEGMASK_ROOT: &str = "tests/data/segmask_folder"; const COCO_JSON: &str = "tests/data/dataset_coco.json"; const COCO_IMAGES: &str = "tests/data/image_folder_coco"; #[test] pub fn image_folder_dataset() { let dataset = ImageFolderDataset::new_classification(DATASET_ROOT).unwrap(); // Dataset has 3 elements assert_eq!(dataset.len(), 3); assert_eq!(dataset.get(3), None); // Dataset elements should be: orange (0), red (1), red (1) assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0)); assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1)); } #[test] pub fn image_folder_dataset_filtered() { let dataset = ImageFolderDataset::new_classification_with(DATASET_ROOT, &["jpg"]).unwrap(); // Filtered dataset has 2 elements assert_eq!(dataset.len(), 2); assert_eq!(dataset.get(2), None); // Dataset elements should be: orange (0), red (1) assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0)); assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); } #[test] pub fn image_folder_dataset_with_items_sizes() { let root = Path::new(DATASET_ROOT); let items = vec![ (root.join("orange").join("dot.jpg"), "orange".to_string()), (root.join("red").join("dot.jpg"), "red".to_string()), (root.join("red").join("dot.png"), "red".to_string()), ]; let dataset = ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap(); // Dataset has 3 elements assert_eq!(dataset.len(), 3); assert_eq!(dataset.get(3), None); // Test item sizes assert_eq!( ( dataset.get(0).unwrap().image_width, dataset.get(0).unwrap().image_height ), (1, 1) ); assert_eq!( ( dataset.get(1).unwrap().image_width, dataset.get(1).unwrap().image_height ), (1, 1) ); assert_eq!( ( dataset.get(2).unwrap().image_width, dataset.get(2).unwrap().image_height ), (1, 1) ); } #[test] pub fn image_folder_dataset_with_items() { let root = Path::new(DATASET_ROOT); let items = vec![ (root.join("orange").join("dot.jpg"), "orange".to_string()), (root.join("red").join("dot.jpg"), "red".to_string()), (root.join("red").join("dot.png"), "red".to_string()), ]; let dataset = ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap(); // Dataset has 3 elements assert_eq!(dataset.len(), 3); assert_eq!(dataset.get(3), None); // Dataset elements should be: orange (0), red (1), red (1) assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0)); assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1)); } #[test] pub fn image_folder_dataset_multilabel() { let root = Path::new(DATASET_ROOT); let items = vec![ ( root.join("orange").join("dot.jpg"), vec!["dot".to_string(), "orange".to_string()], ), ( root.join("red").join("dot.jpg"), vec!["dot".to_string(), "red".to_string()], ), ( root.join("red").join("dot.png"), vec!["dot".to_string(), "red".to_string()], ), ]; let dataset = ImageFolderDataset::new_multilabel_classification_with_items( items, &["dot", "orange", "red"], ) .unwrap(); // Dataset has 3 elements assert_eq!(dataset.len(), 3); assert_eq!(dataset.get(3), None); // Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2) assert_eq!( dataset.get(0).unwrap().annotation, Annotation::MultiLabel(vec![0, 1]) ); assert_eq!( dataset.get(1).unwrap().annotation, Annotation::MultiLabel(vec![0, 2]) ); assert_eq!( dataset.get(2).unwrap().annotation, Annotation::MultiLabel(vec![0, 2]) ); } #[test] #[should_panic] pub fn image_folder_dataset_invalid_extension() { // Some invalid file extension let _ = ImageFolderDataset::new_classification_with(DATASET_ROOT, &["ico"]).unwrap(); } #[test] pub fn pixel_depth_try_into_u8() { let val = u8::MAX; let pix: u8 = PixelDepth::U8(val).try_into().unwrap(); assert_eq!(pix, val); } #[test] #[should_panic] pub fn pixel_depth_try_into_u8_invalid() { let _: u8 = PixelDepth::U16(u8::MAX as u16 + 1).try_into().unwrap(); } #[test] pub fn pixel_depth_try_into_u16() { let val = u16::MAX; let pix: u16 = PixelDepth::U16(val).try_into().unwrap(); assert_eq!(pix, val); } #[test] #[should_panic] pub fn pixel_depth_try_into_u16_invalid() { let _: u16 = PixelDepth::F32(u16::MAX as f32).try_into().unwrap(); } #[test] pub fn pixel_depth_try_into_f32() { let val = f32::MAX; let pix: f32 = PixelDepth::F32(val).try_into().unwrap(); assert_eq!(pix, val); } #[test] #[should_panic] pub fn pixel_depth_try_into_f32_invalid() { let _: f32 = PixelDepth::U16(u16::MAX).try_into().unwrap(); } #[test] pub fn parse_image_annotation_label_string() { let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]); let anno = AnnotationRaw::Label("0".to_string()); assert_eq!( parse_image_annotation(&anno, &classes), Annotation::Label(0) ); } #[test] pub fn parse_image_annotation_multilabel_string() { let classes = HashMap::from([ ("0".to_string(), 0_usize), ("1".to_string(), 1_usize), ("2".to_string(), 2_usize), ]); let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]); assert_eq!( parse_image_annotation(&anno, &classes), Annotation::MultiLabel(vec![0, 2]) ); } #[test] pub fn segmask_image_path_to_vec_usize() { let root = Path::new(SEGMASK_ROOT); // checkerboard mask const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, ]; assert_eq!( TEST_CHECKERBOARD_MASK_PATTERN .iter() .map(|&x| x as usize) .collect::>(), segmentation_mask_to_vec_usize(&root.join("annotations").join("mask_checkerboard.png")), ); // random 2 colors mask const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, ]; assert_eq!( TEST_RANDOM2COLORS_MASK_PATTERN .iter() .map(|&x| x as usize) .collect::>(), segmentation_mask_to_vec_usize( &root.join("annotations").join("mask_random_2colors.png") ), ); // random 3 colors mask const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, 3, 3, 2, 1, 2, 2, ]; assert_eq!( TEST_RANDOM3COLORS_MASK_PATTERN .iter() .map(|&x| x as usize) .collect::>(), segmentation_mask_to_vec_usize( &root.join("annotations").join("mask_random_3colors.png") ), ); } #[test] pub fn segmask_folder_dataset() { let root = Path::new(SEGMASK_ROOT); let items = vec![ ( root.join("images").join("image_checkerboard.png"), root.join("annotations").join("mask_checkerboard.png"), ), ( root.join("images").join("image_random_2colors.png"), root.join("annotations").join("mask_random_2colors.png"), ), ( root.join("images").join("image_random_3colors.png"), root.join("annotations").join("mask_random_3colors.png"), ), ]; let dataset = ImageFolderDataset::new_segmentation_with_items( items, &[ "foo", // 0 "bar", // 1 "baz", // 2 "qux", // 3 ], ) .unwrap(); // Dataset has 3 elements; each (image, annotation) is a single item assert_eq!(dataset.len(), 3); assert_eq!(dataset.get(3), None); // checkerboard mask const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [ 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, ]; assert_eq!( dataset.get(0).unwrap().annotation, Annotation::SegmentationMask(SegmentationMask { mask: TEST_CHECKERBOARD_MASK_PATTERN .iter() .map(|&x| x as usize) .collect() }) ); // random 2 colors mask const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [ 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, ]; assert_eq!( dataset.get(1).unwrap().annotation, Annotation::SegmentationMask(SegmentationMask { mask: TEST_RANDOM2COLORS_MASK_PATTERN .iter() .map(|&x| x as usize) .collect() }) ); // random 3 colors mask const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [ 3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3, 3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1, 3, 3, 2, 1, 2, 2, ]; assert_eq!( dataset.get(2).unwrap().annotation, Annotation::SegmentationMask(SegmentationMask { mask: TEST_RANDOM3COLORS_MASK_PATTERN .iter() .map(|&x| x as usize) .collect() }) ); } #[test] pub fn coco_detection_dataset() { let dataset = ImageFolderDataset::new_coco_detection(COCO_JSON, COCO_IMAGES).unwrap(); assert_eq!(dataset.len(), 3); // we have only three images defined assert_eq!(dataset.get(3), None); const TWO_DOTS_AND_TRIANGLE_B1: BoundingBox = BoundingBox { coords: [3.125_172, 18.090_784, 10.960_11, 10.740_027], label: 0, }; const TWO_DOTS_AND_TRIANGLE_B2: BoundingBox = BoundingBox { coords: [3.257_221_5, 3.037_139, 10.563_961, 10.828_06], label: 0, }; const TWO_DOTS_AND_TRIANGLE_B3: BoundingBox = BoundingBox { coords: [15.097_662, 3.389_271, 12.632_737, 11.180_193], label: 1, }; const DOTS_TRIANGLE_B1: BoundingBox = BoundingBox { coords: [3.125_172, 17.914_719, 10.828_06, 11.004_127], label: 0, }; const DOTS_TRIANGLE_B2: BoundingBox = BoundingBox { coords: [15.273_727, 3.301_238, 12.192_573, 11.708_39], label: 1, }; const ONE_DOT_B1: BoundingBox = BoundingBox { coords: [10.079_78, 9.595_598, 10.960_11, 11.356_258], label: 0, }; for item in dataset.iter() { let file_name = Path::new(&item.image_path).file_name().unwrap(); match item.annotation { // check if the number of bounding boxes is correct Annotation::BoundingBoxes(v) => { if file_name == "two_dots_and_triangle.jpg" { assert_eq!(v.len(), 3); assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B1)); assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B2)); assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B3)); } else if file_name == "dot_triangle.jpg" { assert_eq!(v.len(), 2); assert!(v.contains(&DOTS_TRIANGLE_B1)); assert!(v.contains(&DOTS_TRIANGLE_B2)); } else if file_name == "one_dot.jpg" { assert_eq!(v.len(), 1); assert!(v.contains(&ONE_DOT_B1)); } else { panic!("{}", format!("unexpected image name: {}", item.image_path)); } } _ => panic!("unexpected annotation"), } } } } ================================================ FILE: crates/burn-dataset/src/vision/mnist.rs ================================================ use std::fs::{File, create_dir_all}; use std::io::{Read, Seek, SeekFrom}; use std::path::{Path, PathBuf}; use flate2::read::GzDecoder; use serde::{Deserialize, Serialize}; use crate::{ Dataset, InMemDataset, transform::{Mapper, MapperDataset}, }; use crate::network::downloader::download_file_as_bytes; // CVDF mirror of http://yann.lecun.com/exdb/mnist/ const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/"; const TRAIN_IMAGES: &str = "train-images-idx3-ubyte"; const TRAIN_LABELS: &str = "train-labels-idx1-ubyte"; const TEST_IMAGES: &str = "t10k-images-idx3-ubyte"; const TEST_LABELS: &str = "t10k-labels-idx1-ubyte"; const WIDTH: usize = 28; const HEIGHT: usize = 28; /// MNIST item. #[derive(Deserialize, Serialize, Debug, Clone)] pub struct MnistItem { /// Image as a 2D array of floats. pub image: [[f32; WIDTH]; HEIGHT], /// Label of the image. pub label: u8, } #[derive(Deserialize, Debug, Clone)] struct MnistItemRaw { pub image_bytes: Vec, pub label: u8, } struct BytesToImage; impl Mapper for BytesToImage { /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). fn map(&self, item: &MnistItemRaw) -> MnistItem { // Ensure the image dimensions are correct. debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT); // Convert the image to a 2D array of floats. let mut image_array = [[0f32; WIDTH]; HEIGHT]; for (i, pixel) in item.image_bytes.iter().enumerate() { let x = i % WIDTH; let y = i / HEIGHT; image_array[y][x] = *pixel as f32; } MnistItem { image: image_array, label: item.label, } } } type MappedDataset = MapperDataset, BytesToImage, MnistItemRaw>; /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000 /// images per class. There are 60,000 training images and 10,000 test images. /// /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist). pub struct MnistDataset { dataset: MappedDataset, } impl Dataset for MnistDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } impl MnistDataset { /// Creates a new train dataset. pub fn train() -> Self { Self::new("train") } /// Creates a new test dataset. pub fn test() -> Self { Self::new("test") } fn new(split: &str) -> Self { // Download dataset let root = MnistDataset::download(split); // MNIST is tiny so we can load it in-memory // Train images (u8): 28 * 28 * 60000 = 47.04Mb // Test images (u8): 28 * 28 * 10000 = 7.84Mb let images = MnistDataset::read_images(&root, split); let labels = MnistDataset::read_labels(&root, split); // Collect as vector of MnistItemRaw let items: Vec<_> = images .into_iter() .zip(labels) .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label }) .collect(); let dataset = InMemDataset::new(items); let dataset = MapperDataset::new(dataset, BytesToImage); Self { dataset } } /// Download the MNIST dataset files from the web. /// Panics if the download cannot be completed or the content of the file cannot be written to disk. fn download(split: &str) -> PathBuf { // Dataset files are stored in the burn-dataset cache directory let cache_dir = dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset"); let split_dir = cache_dir.join("mnist").join(split); if !split_dir.exists() { create_dir_all(&split_dir).expect("Failed to create base directory"); } // Download split files match split { "train" => { MnistDataset::download_file(TRAIN_IMAGES, &split_dir); MnistDataset::download_file(TRAIN_LABELS, &split_dir); } "test" => { MnistDataset::download_file(TEST_IMAGES, &split_dir); MnistDataset::download_file(TEST_LABELS, &split_dir); } _ => panic!("Invalid split specified {split}"), }; split_dir } /// Download a file from the MNIST dataset URL to the destination directory. /// File download progress is reported with the help of a [progress bar](indicatif). fn download_file>(name: &str, dest_dir: &P) -> PathBuf { // Output file name let file_name = dest_dir.as_ref().join(name); if !file_name.exists() { // Download gzip file let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name); // Create file to write the downloaded content to let mut output_file = File::create(&file_name).unwrap(); // Decode gzip file content and write to disk let mut gz_buffer = GzDecoder::new(&bytes[..]); std::io::copy(&mut gz_buffer, &mut output_file).unwrap(); } file_name } /// Read images at the provided path for the specified split. /// Each image is a vector of bytes. fn read_images>(root: &P, split: &str) -> Vec> { let file_name = if split == "train" { TRAIN_IMAGES } else { TEST_IMAGES }; let file_name = root.as_ref().join(file_name); // Read number of images from 16-byte header metadata let mut f = File::open(file_name).unwrap(); let mut buf = [0u8; 4]; let _ = f.seek(SeekFrom::Start(4)).unwrap(); f.read_exact(&mut buf) .expect("Should be able to read image file header"); let size = u32::from_be_bytes(buf); let mut buf_images: Vec = vec![0u8; WIDTH * HEIGHT * (size as usize)]; let _ = f.seek(SeekFrom::Start(16)).unwrap(); f.read_exact(&mut buf_images) .expect("Should be able to read image file header"); buf_images .chunks(WIDTH * HEIGHT) .map(|chunk| chunk.to_vec()) .collect() } /// Read labels at the provided path for the specified split. fn read_labels>(root: &P, split: &str) -> Vec { let file_name = if split == "train" { TRAIN_LABELS } else { TEST_LABELS }; let file_name = root.as_ref().join(file_name); // Read number of labels from 8-byte header metadata let mut f = File::open(file_name).unwrap(); let mut buf = [0u8; 4]; let _ = f.seek(SeekFrom::Start(4)).unwrap(); f.read_exact(&mut buf) .expect("Should be able to read label file header"); let size = u32::from_be_bytes(buf); let mut buf_labels: Vec = vec![0u8; size as usize]; let _ = f.seek(SeekFrom::Start(8)).unwrap(); f.read_exact(&mut buf_labels) .expect("Should be able to read labels from file"); buf_labels } } ================================================ FILE: crates/burn-dataset/src/vision/mod.rs ================================================ #[cfg(feature = "builtin-sources")] mod cifar; mod image_folder; mod mnist; #[cfg(feature = "builtin-sources")] pub use cifar::*; pub use image_folder::*; pub use mnist::*; ================================================ FILE: crates/burn-dataset/tests/data/dataset-fmt.csv ================================================ HI1 1 true 1.0 HI2 1 false 1.0 ================================================ FILE: crates/burn-dataset/tests/data/dataset.csv ================================================ column_str,column_int,column_bool,column_float HI1,1,true,1.0 HI2,1,false,1.0 ================================================ FILE: crates/burn-dataset/tests/data/dataset.json ================================================ {"column_str":"HI1","column_bytes":[1,2,3,3],"column_int":1,"column_bool":true,"column_float":1.0} {"column_str":"HI2","column_bytes":[1,2,3,3],"column_int":1,"column_bool":false,"column_float":1.0} ================================================ FILE: crates/burn-dataset/tests/data/dataset_coco.json ================================================ { "images": [ { "width": 32, "height": 32, "id": 0, "file_name": "two_dots_and_triangle.jpg" }, { "width": 32, "height": 32, "id": 1, "file_name": "dot_triangle.jpg" }, { "width": 32, "height": 32, "id": 2, "file_name": "one_dot.jpg" } ], "categories": [ { "id": 0, "name": "dot" }, { "id": 1, "name": "triangle" } ], "annotations": [ { "id": 0, "image_id": 0, "category_id": 0, "segmentation": [], "bbox": [ 3.1251719394773056, 18.0907840440165, 10.96011004126548, 10.740027510316379 ], "ignore": 0, "iscrowd": 0, "area": 117.71188335928603 }, { "id": 1, "image_id": 0, "category_id": 0, "segmentation": [], "bbox": [ 3.2572214580467658, 3.0371389270976605, 10.563961485557085, 10.828060522696012 ], "ignore": 0, "iscrowd": 0, "area": 114.38721432504178 }, { "id": 2, "image_id": 0, "category_id": 1, "segmentation": [], "bbox": [ 15.097661623108666, 3.3892709766162312, 12.632737276478679, 11.18019257221458 ], "ignore": 0, "iscrowd": 0, "area": 141.23643546522516 }, { "id": 3, "image_id": 1, "category_id": 0, "segmentation": [], "bbox": [ 3.125171939477304, 17.914718019257222, 10.82806052269601, 11.004126547455297 ], "ignore": 0, "iscrowd": 0, "area": 119.15334825525184 }, { "id": 4, "image_id": 1, "category_id": 1, "segmentation": [], "bbox": [ 15.27372764786794, 3.301237964236589, 12.192572214580478, 11.708390646492433 ], "ignore": 0, "iscrowd": 0, "area": 142.7553984738776 }, { "id": 5, "image_id": 2, "category_id": 0, "segmentation": [], "bbox": [ 10.07977991746905, 9.59559834938102, 10.960110041265464, 11.356258596973863 ], "ignore": 0, "iscrowd": 0, "area": 124.46584387990049 } ], "info": { "year": 2024, "version": "1.0", "description": "", "contributor": "", "url": "", "date_created": "2024-12-11 22:16:31.823494" } } ================================================ FILE: crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt ================================================ 1 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 1 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 1 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 1 2 1 2 1 2 1 2 2 1 2 1 2 1 2 1 ================================================ FILE: crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt ================================================ 1 2 1 1 1 2 1 1 1 2 1 1 1 1 2 1 2 2 2 1 2 1 2 2 2 2 2 2 2 2 1 1 2 2 2 1 2 1 1 1 1 1 2 2 2 2 2 1 2 2 1 2 1 2 1 2 2 1 1 1 1 1 1 1 ================================================ FILE: crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt ================================================ 3 1 3 3 1 1 3 2 3 3 3 3 1 3 2 1 2 2 2 2 1 1 2 2 1 1 1 3 3 3 2 3 2 2 3 2 3 3 1 3 1 3 3 1 1 3 2 1 2 2 2 1 2 1 2 3 3 1 3 3 2 1 2 2 ================================================ FILE: crates/burn-dataset/tests/data/text_folder/negative/sample1.txt ================================================ This is a negative text sample for testing the text folder dataset functionality. ================================================ FILE: crates/burn-dataset/tests/data/text_folder/negative/sample2.txt ================================================ 另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。 ================================================ FILE: crates/burn-dataset/tests/data/text_folder/positive/sample1.txt ================================================ This is a positive text sample for testing the text folder dataset functionality. ================================================ FILE: crates/burn-dataset/tests/data/text_folder/positive/sample2.txt ================================================ 另一个正面文本样本,以确保数据集能够处理同一类别中的多个文件。 ================================================ FILE: crates/burn-derive/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Derive crate for the Burn framework" edition.workspace = true keywords = [] license.workspace = true name = "burn-derive" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-derive" version.workspace = true [lints] workspace = true [lib] proc-macro = true [dependencies] proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } derive-new = { workspace = true } ================================================ FILE: crates/burn-derive/README.md ================================================ # Burn Derive This crate should only be used with [burn](https://github.com/tracel-ai/burn). [![Current Crates.io Version](https://img.shields.io/crates/v/burn-derive.svg)](https://crates.io/crates/burn-derive) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-derive/blob/master/README.md) ================================================ FILE: crates/burn-derive/src/config/analyzer.rs ================================================ use super::ConfigEnumAnalyzer; use crate::config::ConfigStructAnalyzer; use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer}; use proc_macro2::TokenStream; use quote::quote; use syn::{Field, Ident}; pub struct ConfigAnalyzerFactory {} pub trait ConfigAnalyzer { fn gen_new_fn(&self) -> TokenStream { quote! {} } fn gen_builder_fns(&self) -> TokenStream { quote! {} } fn gen_serde_impl(&self) -> TokenStream; fn gen_clone_impl(&self) -> TokenStream; fn gen_display_impl(&self) -> TokenStream; fn gen_config_impl(&self) -> TokenStream; } impl ConfigAnalyzerFactory { pub fn new() -> Self { Self {} } pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box { let name = item.ident.clone(); let config_type = parse_asm(item); match config_type { ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)), ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)), } } fn create_struct_analyzer(&self, name: Ident, fields: Vec) -> ConfigStructAnalyzer { let fields = fields.into_iter().map(FieldTypeAnalyzer::new); let mut fields_required = Vec::new(); let mut fields_option = Vec::new(); let mut fields_default = Vec::new(); for field in fields { let attributes: Vec = field .attributes() .filter(|attr| attr.has_name("config")) .map(|attr| attr.item()) .collect(); if !attributes.is_empty() { let item = attributes.first().unwrap().clone(); fields_default.push((field.clone(), item)); continue; } if field.is_of_type(&["Option"]) { fields_option.push(field.clone()); continue; } fields_required.push(field.clone()); } ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default) } fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer { ConfigEnumAnalyzer::new(name, data) } } enum ConfigType { Struct(Vec), Enum(syn::DataEnum), } fn parse_asm(ast: &syn::DeriveInput) -> ConfigType { match &ast.data { syn::Data::Struct(struct_data) => { ConfigType::Struct(struct_data.fields.clone().into_iter().collect()) } syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()), syn::Data::Union(_) => panic!("Only struct and enum can be derived"), } } ================================================ FILE: crates/burn-derive/src/config/analyzer_enum.rs ================================================ use crate::shared::enum_variant::map_enum_variant; use super::ConfigAnalyzer; use proc_macro2::{Ident, TokenStream}; use quote::quote; pub struct ConfigEnumAnalyzer { name: Ident, data: syn::DataEnum, } impl ConfigEnumAnalyzer { pub fn new(name: Ident, data: syn::DataEnum) -> Self { Self { name, data } } fn serde_enum_ident(&self) -> Ident { Ident::new(&format!("{}Serde", self.name), self.name.span()) } fn gen_serde_enum(&self) -> TokenStream { let enum_name = self.serde_enum_ident(); let data = &self.data.variants; quote! { #[derive(burn::serde::Serialize, burn::serde::Deserialize)] #[serde(crate = "burn::serde")] enum #enum_name { #data } } } fn gen_serialize_fn(&self) -> TokenStream { let enum_name = self.serde_enum_ident(); let variants = self.data.variants.iter().map(|variant| { let variant_name = &variant.ident; let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() }); quote! { Self::#variant_name #inputs => #enum_name::#variant_name #outputs } }); let name = &self.name; quote! { impl burn::serde::Serialize for #name { fn serialize(&self, serializer: S) -> Result where S: burn::serde::Serializer { let serde_state = match self { #(#variants),* }; serde_state.serialize(serializer) } } } } fn gen_deserialize_fn(&self) -> TokenStream { let enum_name = self.serde_enum_ident(); let variants = self.data.variants.iter().map(|variant| { let variant_name = &variant.ident; let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() }); quote! { #enum_name::#variant_name #inputs => Self::#variant_name #outputs } }); let name = &self.name; quote! { impl<'de> burn::serde::Deserialize<'de> for #name { fn deserialize(deserializer: D) -> Result where D: burn::serde::Deserializer<'de> { let serde_state = #enum_name::deserialize(deserializer)?; Ok(match serde_state { #(#variants),* }) } } } } } impl ConfigAnalyzer for ConfigEnumAnalyzer { fn gen_serde_impl(&self) -> TokenStream { let struct_gen = self.gen_serde_enum(); let serialize_gen = self.gen_serialize_fn(); let deserialize_gen = self.gen_deserialize_fn(); quote! { #struct_gen #serialize_gen #deserialize_gen } } fn gen_clone_impl(&self) -> TokenStream { let variants = self.data.variants.iter().map(|variant| { let variant_name = &variant.ident; let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() }); quote! { Self::#variant_name #inputs => Self::#variant_name #outputs } }); let name = &self.name; quote! { impl Clone for #name { fn clone(&self) -> Self { match self { #(#variants),* } } } } } fn gen_display_impl(&self) -> TokenStream { let name = &self.name; quote! { impl core::fmt::Display for #name { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(&burn::config::config_to_json(self)) } } } } fn gen_config_impl(&self) -> TokenStream { let name = &self.name; quote! { impl burn::config::Config for #name { } } } } ================================================ FILE: crates/burn-derive/src/config/analyzer_struct.rs ================================================ use super::ConfigAnalyzer; use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer}; use proc_macro2::{Ident, TokenStream}; use quote::quote; pub struct ConfigStructAnalyzer { name: Ident, fields_required: Vec, fields_option: Vec, fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, } impl ConfigStructAnalyzer { pub fn new( name: Ident, fields_required: Vec, fields_option: Vec, fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>, ) -> Self { Self { name, fields_required, fields_option, fields_default, } } fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream { let name = &self.name; quote! { impl #name { #tokens } } } fn names(&self) -> Vec { let mut names = Vec::new(); for field in self.fields_required.iter() { names.push(field.clone()); } for field in self.fields_option.iter() { names.push(field.clone()); } for (field, _) in self.fields_default.iter() { names.push(field.clone()); } names } fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec { let mut name_types = Vec::new(); for field in names.iter() { let name = field.ident(); let ty = &field.field.ty; name_types.push(quote! { #name: #ty }); } name_types } fn serde_struct_ident(&self) -> Ident { Ident::new(&format!("{}Serde", self.name), self.name.span()) } fn gen_serialize_fn( &self, struct_name: &Ident, struct_gen: &TokenStream, names: &[FieldTypeAnalyzer], ) -> TokenStream { let name = &self.name; let names = names.iter().map(|name| { let name = name.ident(); quote! { #name: self.#name.clone() } }); quote! { impl burn::serde::Serialize for #name { fn serialize(&self, serializer: S) -> Result where S: burn::serde::Serializer { #[derive(burn::serde::Serialize)] #[serde(crate = "burn::serde")] #struct_gen let serde_state = #struct_name { #(#names),* }; serde_state.serialize(serializer) } } } } fn gen_deserialize_fn( &self, struct_name: &Ident, struct_gen: &TokenStream, names: &[FieldTypeAnalyzer], ) -> TokenStream { let name = &self.name; let names = names.iter().map(|name| { let name = name.ident(); quote! { #name: serde_state.#name } }); quote! { impl<'de> burn::serde::Deserialize<'de> for #name { fn deserialize(deserializer: D) -> Result where D: burn::serde::Deserializer<'de> { #[derive(burn::serde::Deserialize)] #[serde(crate = "burn::serde")] #struct_gen let serde_state = #struct_name::deserialize(deserializer)?; Ok(#name { #(#names),* }) } } } } fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream { let struct_name = self.serde_struct_ident(); quote! { struct #struct_name { #(#names),* } } } } impl ConfigAnalyzer for ConfigStructAnalyzer { fn gen_new_fn(&self) -> TokenStream { let mut body = quote! {}; let mut args = Vec::new(); let mut fn_docs = quote! {}; let mut has_field_docs = false; let mut has_required_docs = false; let mut has_option_docs = false; let mut has_default_docs = false; let mut docs_header = |fn_docs: &mut TokenStream, required_docs: bool, option_docs: bool, default_docs: bool| { if !has_field_docs { has_field_docs = true; fn_docs.extend(quote! { #[doc = "# Arguments"] }); } if !has_required_docs && required_docs { fn_docs.extend(quote! { #[doc = "###### Required Arguments"] }); has_required_docs = true; } if !has_option_docs && option_docs { fn_docs.extend(quote! { #[doc = "###### Optional Arguments"] }); has_option_docs = true; } if !has_default_docs && default_docs { fn_docs.extend(quote! { #[doc = "###### Default Arguments"] }); has_default_docs = true; } }; for field in self.fields_required.iter() { let name = field.ident(); let ty = &field.field.ty; let docs = field.docs(); body.extend(quote! { #name: #name, }); args.push(quote! { #name: #ty }); docs_header(&mut fn_docs, true, false, false); let doc_str = format!("###### `{}`\n\n", quote!(#name)); fn_docs.extend(quote! { #[doc = #doc_str] #(#docs)* }); } for field in self.fields_option.iter() { let name = field.ident(); let docs = field.docs(); body.extend(quote! { #name: None, }); docs_header(&mut fn_docs, false, true, false); let default_doc = "- Defaults to `None`"; let doc_str = format!("###### `{}`\n", quote!(#name)); fn_docs.extend(quote! { #[doc = #doc_str] #(#docs)* #[doc = #default_doc] }); } for (field, attribute) in self.fields_default.iter() { let name = field.ident(); let value = &attribute.value; let docs = field.docs(); match value { syn::Lit::Str(value) => { let stream: proc_macro2::TokenStream = value.value().parse().unwrap(); body.extend(quote! { #name: #stream, }); } _ => { body.extend(quote! { #name: #value, }); } }; docs_header(&mut fn_docs, false, false, true); let default_doc = format!("- Defaults to `{}`", quote!(#value)); let doc_str = format!("###### `{}`\n", quote!(#name)); fn_docs.extend(quote! { #[doc = #doc_str] #(#docs)* #[doc = #default_doc] }); } let body = quote! { #[doc = "Create a new instance of the config."] #fn_docs #[allow(clippy::too_many_arguments)] pub fn new( #(#args),* ) -> Self { Self { #body } } }; self.wrap_impl_block(body) } fn gen_builder_fns(&self) -> TokenStream { let mut body = quote! {}; for (field, attribute) in self.fields_default.iter() { let name = field.ident(); let ty = &field.field.ty; let value = &attribute.value; let docs = field.docs(); let default_doc = format!("- Defaults to `{}`", quote!(#value)); let doc_str = format!( "Sets the value for the field [`{}`](Self::{0}).\n\n", quote!(#name) ); let fn_docs = quote! { #[doc = #doc_str] #(#docs)* #[doc = #default_doc] }; let fn_name = Ident::new(&format!("with_{name}"), name.span()); body.extend(quote! { #fn_docs pub fn #fn_name(mut self, #name: #ty) -> Self { self.#name = #name; self } }); } for field in self.fields_option.iter() { let name = field.ident(); let ty = &field.field.ty; let docs = field.docs(); let default_doc = "- Defaults to `None`"; let doc_str = format!( "Sets the value for the field [`{}`](Self::{0}).\n\n", quote!(#name) ); let fn_docs = quote! { #[doc = #doc_str] #(#docs)* #[doc = #default_doc] }; let fn_name = Ident::new(&format!("with_{name}"), name.span()); body.extend(quote! { #fn_docs pub fn #fn_name(mut self, #name: #ty) -> Self { self.#name = #name; self } }); } self.wrap_impl_block(body) } fn gen_serde_impl(&self) -> TokenStream { let names = self.names(); let struct_name = self.serde_struct_ident(); let name_types = self.name_types(&names); let struct_gen = self.gen_serde_struct(&name_types); let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names); let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names); quote! { #serialize_gen #deserialize_gen } } fn gen_clone_impl(&self) -> TokenStream { let name = &self.name; let names = self.names().into_iter().map(|name| { let name = name.ident(); quote! { #name: self.#name.clone() } }); quote! { impl Clone for #name { fn clone(&self) -> Self { Self { #(#names),* } } } } } fn gen_display_impl(&self) -> TokenStream { let name = &self.name; quote! { impl core::fmt::Display for #name { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(&burn::config::config_to_json(self)) } } } } fn gen_config_impl(&self) -> TokenStream { let name = &self.name; quote! { impl burn::config::Config for #name { } } } } ================================================ FILE: crates/burn-derive/src/config/base.rs ================================================ use super::ConfigAnalyzerFactory; use quote::quote; pub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream { let factory = ConfigAnalyzerFactory::new(); let analyzer = factory.create_analyzer(item); let constructor = analyzer.gen_new_fn(); let builders = analyzer.gen_builder_fns(); let serde = analyzer.gen_serde_impl(); let clone = analyzer.gen_clone_impl(); let display = analyzer.gen_display_impl(); let config_impl = analyzer.gen_config_impl(); quote! { #config_impl #constructor #builders #serde #clone #display } .into() } ================================================ FILE: crates/burn-derive/src/config/mod.rs ================================================ mod analyzer; mod analyzer_enum; mod analyzer_struct; mod base; pub(crate) use analyzer::*; pub(crate) use analyzer_enum::*; pub(crate) use analyzer_struct::*; pub(crate) use base::*; ================================================ FILE: crates/burn-derive/src/lib.rs ================================================ #![warn(missing_docs)] //! The derive crate of Burn. #[macro_use] extern crate derive_new; use proc_macro::TokenStream; pub(crate) mod config; pub(crate) mod module; pub(crate) mod record; pub(crate) mod shared; /// Derive macro for the `Module` trait. /// /// # Sub-modules /// /// By default, the macro automatically detects sub-modules and parameters as module types. /// /// Any field not recognized as a module type is assumed to be a non-module /// and is skipped by the module system (not persistent, not visited). /// /// ## Generics /// /// Generic type parameters (e.g., `field: M`) are assumed to be sub-modules by default. /// If a generic field represents some other runtime state or configuration, you can use /// the `#[module(skip)]` attribute to provide a hint. /// /// # Field Attributes /// /// ## `#[module(skip)]` /// /// Explicitly marks a field to be ignored by the module derive. /// /// Skipped fields are not parameters, not modules, and are not persistent. /// This is equivalent to the deprecated `Ignored` wrapper. /// /// ### Requirements /// /// The field must implement: `Debug + Clone + Send`. /// /// # Example /// /// ```ignore /// #[derive(Module, Debug)] /// pub struct MyModule { /// /// A normal parameter. /// weights: Param>, /// /// A field configured at runtime. /// dropout_prob: f64, /// /// A field that is recomputed at runtime. /// cached_mask: Option>, /// /// A field that contains some debug state. /// debug_state: String, /// /// Treated as a module (default for generics). /// inner: M, /// /// Hint required: this generic is NOT a module. /// #[module(skip)] /// other: N, /// } /// ``` #[proc_macro_derive(Module, attributes(module))] pub fn module_derive(input: TokenStream) -> TokenStream { let input = syn::parse(input).unwrap(); module::derive_impl(&input) } /// Derive macro for the record. #[proc_macro_derive(Record)] pub fn record_derive(input: TokenStream) -> TokenStream { let input = syn::parse(input).unwrap(); record::derive_impl(&input) } /// Derive macro for the config. #[proc_macro_derive(Config, attributes(config))] pub fn config_derive(input: TokenStream) -> TokenStream { let item = syn::parse(input).unwrap(); config::derive_impl(&item) } ================================================ FILE: crates/burn-derive/src/module/base.rs ================================================ use super::{ codegen::{generate_module_const, generate_module_standard}, codegen_enum::EnumModuleCodegen, codegen_struct::StructModuleCodegen, }; use proc_macro::TokenStream; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream { let has_backend = ast .generics .type_params() .map(|param| param.ident == "B") .reduce(|accum, is_backend| is_backend || accum) .unwrap_or(false); match &ast.data { syn::Data::Struct(_) => match StructModuleCodegen::from_ast(ast) { Ok(struct_codegen) => { if has_backend { generate_module_standard(ast, struct_codegen) } else { generate_module_const(ast) } } Err(err) => err.to_compile_error(), }, syn::Data::Enum(_data) => match EnumModuleCodegen::from_ast(ast) { Ok(enum_codegen) => { if has_backend { generate_module_standard(ast, enum_codegen) } else { generate_module_const(ast) } } Err(err) => err.to_compile_error(), }, syn::Data::Union(_) => { syn::Error::new_spanned(ast, "Union modules aren't supported").to_compile_error() } } .into() } ================================================ FILE: crates/burn-derive/src/module/codegen.rs ================================================ use super::{display, record::ModuleRecordCodegen}; use crate::{ module::generics::{GenericKind, ModuleGenerics}, shared::generics::GenericsHelper, }; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Attribute, Generics, parse_quote}; /// Basic trait to be implemented for Module generation. pub(crate) trait ModuleCodegen { type RecordCodegen: ModuleRecordCodegen; fn gen_num_params(&self) -> TokenStream; fn gen_visit(&self) -> TokenStream; fn gen_collect_devices(&self) -> TokenStream; fn gen_to_device(&self) -> TokenStream; fn gen_fork(&self) -> TokenStream; fn gen_map(&self) -> TokenStream; fn gen_valid(&self) -> TokenStream; fn gen_from_inner(&self) -> TokenStream; fn gen_into_record(&self) -> TokenStream; fn gen_load_record(&self) -> TokenStream; fn gen_clone(&self) -> TokenStream; fn record_codegen(self) -> Self::RecordCodegen; fn gen_display(&self) -> TokenStream; fn module_generics(&self) -> &ModuleGenerics; } pub(crate) fn generate_module_standard( ast: &syn::DeriveInput, codegen: Codegen, ) -> TokenStream { let name = &ast.ident; let generics = GenericsParser::from_ast(&ast.generics, codegen.module_generics()); let display_fn = display::display_fn(ast); let attributes_fn = codegen.gen_display(); let num_params_fn = codegen.gen_num_params(); let visit = codegen.gen_visit(); let map_mut = codegen.gen_map(); let collect_devices = codegen.gen_collect_devices(); let to_device = codegen.gen_to_device(); let fork = codegen.gen_fork(); let valid_fn = codegen.gen_valid(); let from_inner_fn = codegen.gen_from_inner(); let into_record_fn = codegen.gen_into_record(); let load_record_fn = codegen.gen_load_record(); let clone_fn = codegen.gen_clone(); let record = codegen.record_codegen(); let record_name = Ident::new(format!("{name}Record").as_str(), name.span()); let (record_type, record_generics) = record.gen_record_type(&record_name, &generics.module); let (generics_module, generics_ty_module, generics_where_module) = generics.module.split_for_impl(); let (generics_module_autodiff, generics_ty_module_autodiff, generics_where_module_autodiff) = generics.module_autodiff.split_for_impl(); let (generics_module_has_autodiff, _generics_ty, generics_where_module_has_autodiff) = generics.module_has_autodiff.split_for_impl(); let (_, generics_ty_record, _) = record_generics.split_for_impl(); let generics_ty_inner_module = generics.inner_module_ty; let generics_ty_train_module = generics.train_module_ty; let generics_ty_train_inner_module = generics.train_inner_ty; let mut codegen = quote! { impl #generics_module burn::module::Module for #name #generics_ty_module #generics_where_module { type Record = #record_name #generics_ty_record; #load_record_fn #into_record_fn #num_params_fn #visit #map_mut #collect_devices #to_device #fork } impl #generics_module_autodiff burn::module::AutodiffModule for #name #generics_ty_module_autodiff #generics_where_module_autodiff { type InnerModule=#name; #valid_fn #from_inner_fn } impl #generics_module_has_autodiff burn::module::HasAutodiffModule for #name #generics_where_module_has_autodiff { type TrainModule=#name; } impl #generics_module core::fmt::Display for #name #generics_ty_module #generics_where_module { #display_fn } impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module { #attributes_fn fn num_params(&self) -> usize { burn::module::Module::num_params(self) } } impl #generics_module Clone for #name #generics_ty_module #generics_where_module { #clone_fn } #record_type }; if !has_custom_display(&ast.attrs) { codegen.extend(quote! { impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module { } }); } codegen } // TODO: wait that means nothing is persistent... (empty!) // When there is no backend in the generic parameter, the type is considered as a constant. pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); let backend: syn::Generics = parse_quote! { }; let backend_ad: syn::Generics = parse_quote! { }; let mut generics_module = ast.generics.clone(); let mut generics_module_autodiff = ast.generics.clone(); for param in backend.params.into_iter() { generics_module.params.push(param); } for param in backend_ad.params.into_iter() { generics_module_autodiff.params.push(param); } let (generics_module, _, _) = generics_module.split_for_impl(); let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl(); let display_fn = display::display_fn(ast); let attributes_fn = display::attributes_fn(ast); let mut codegen = quote! { impl #generics_module burn::module::Module for #name #generics_ty #generics_where { burn::empty!(module); } impl #generics_module_ad burn::module::AutodiffModule for #name #generics_ty #generics_where { burn::empty!(ad_module, #name #generics_ty); } impl #generics core::fmt::Display for #name #generics_ty #generics_where { #display_fn } impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where { #attributes_fn } }; if !has_custom_display(&ast.attrs) { codegen.extend(quote! { impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where { } }); } codegen } struct GenericsParser { module: Generics, module_autodiff: Generics, module_has_autodiff: Generics, inner_module_ty: TokenStream, train_module_ty: TokenStream, train_inner_ty: TokenStream, } impl GenericsParser { fn from_ast(generics: &Generics, module_generics: &ModuleGenerics) -> Self { let mut module = GenericsHelper::new(generics.clone()); let mut module_autodiff = GenericsHelper::new(generics.clone()); let mut module_has_autodiff = GenericsHelper::new(generics.clone()); let backend_trait = module.fetch_backend_trait(); module_autodiff.add_predicate(parse_quote! { B: burn::tensor::backend::AutodiffBackend }); module_autodiff.add_predicate(parse_quote! { ::InnerBackend: #backend_trait }); module_has_autodiff.add_predicate(parse_quote! { B: burn::tensor::backend::AutodiffBackend }); module_has_autodiff.add_predicate(parse_quote! { ::InnerBackend: #backend_trait }); let mut generics_names_except_backend = quote! {}; let mut train_generics_names_except_backend = quote! {}; let mut train_inner_generics_names_except_backend = quote! {}; module .types() .into_iter() .filter(|ident| ident != "B") .for_each(|ident| { // By default, require module bound let mut requires_module_bound = true; let mut generic_kind = None; if !module_generics.is_empty() { generic_kind = module_generics.get_generic_kind(&ident); let has_module_bound = matches!(generic_kind, Some(GenericKind::Module)); let is_unbounded = matches!(generic_kind, Some(GenericKind::Plain)); requires_module_bound = has_module_bound || is_unbounded; } if requires_module_bound { module.add_predicate( parse_quote! { #ident: burn::module::Module } ); module.add_predicate( parse_quote! { #ident: burn::module::ModuleDisplay } ); module_autodiff.add_predicate( parse_quote! { #ident: burn::module::AutodiffModule } ); module_autodiff.add_predicate( parse_quote! { <#ident as burn::module::AutodiffModule>::InnerModule: burn::module::Module } ); module_autodiff.add_predicate( parse_quote! { <#ident as burn::module::AutodiffModule>::InnerModule: burn::module::ModuleDisplay } ); generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule>::InnerModule, }); module_autodiff.add_predicate( parse_quote! { #ident: burn::module::ModuleDisplay } ); module_has_autodiff.add_predicate( parse_quote! { #ident: burn::module::Module } ); module_has_autodiff.add_predicate( parse_quote! { #ident: burn::module::ModuleDisplay } ); module_has_autodiff.add_predicate( parse_quote! { #ident: burn::module::HasAutodiffModule } ); module_has_autodiff.add_predicate( parse_quote! { #ident::TrainModule: burn::module::ModuleDisplay } ); train_generics_names_except_backend.extend(quote! { #ident, }); train_inner_generics_names_except_backend.extend(quote! { #ident::TrainModule, }); } else { // Add required bounds to impl if let Some(GenericKind::Skip) = generic_kind { module.add_predicate( parse_quote! { #ident: Clone + core::fmt::Debug + Send } ); module_autodiff.add_predicate( parse_quote! { #ident: Clone + core::fmt::Debug + Send } ); module_has_autodiff.add_predicate( parse_quote! { #ident: Clone + core::fmt::Debug + Send } ); } // Pass through generics_names_except_backend.extend(quote! { #ident, }); train_generics_names_except_backend.extend(quote! { #ident, }); train_inner_generics_names_except_backend.extend(quote! { #ident, }); } }); module.consts().into_iter().for_each(|ident| { generics_names_except_backend.extend(quote! { #ident, }); train_generics_names_except_backend.extend(quote! { #ident, }); train_inner_generics_names_except_backend.extend(quote! { #ident, }); }); Self { module: module.generics, module_autodiff: module_autodiff.generics, module_has_autodiff: module_has_autodiff.generics, inner_module_ty: generics_names_except_backend, train_module_ty: train_generics_names_except_backend, train_inner_ty: train_inner_generics_names_except_backend, } } } fn has_custom_display(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| { attr.path().is_ident("module") && attr .parse_nested_meta(|meta| { if meta.path.is_ident("custom_display") { Ok(()) } else { Err(meta.error("unsupported attribute")) } }) .is_ok() }) } ================================================ FILE: crates/burn-derive/src/module/codegen_enum.rs ================================================ use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen}; use crate::{ module::generics::{ModuleGenerics, parse_module_generics}, shared::enum_variant::{EnumVariant, parse_variants}, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::Visibility; pub(crate) struct EnumModuleCodegen { pub name: Ident, pub variants: Vec, pub vis: Visibility, pub generics: ModuleGenerics, } impl ModuleCodegen for EnumModuleCodegen { type RecordCodegen = EnumModuleRecordCodegen; fn gen_num_params(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|_| { quote! { burn::module::Module::::num_params(module) } }); quote! { fn num_params(&self) -> usize { #match_body } } } fn gen_visit(&self) -> TokenStream { let enum_name = self.name.to_string(); let container_type = format!("Enum:{}", enum_name); let match_body = self.gen_variants_match_fn(|variant_name| { let variant_str = variant_name.to_string(); quote! { { visitor.enter_module(#variant_str, #container_type); burn::module::Module::visit(module, visitor); visitor.exit_module(#variant_str, #container_type); } } }); quote! { fn visit>(&self, visitor: &mut Visitor) { #match_body } } } fn gen_collect_devices(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|_| { quote! { burn::module::Module::::collect_devices(module, devices) } }); quote! { fn collect_devices( &self, devices: burn::module::Devices ) -> burn::module::Devices { #match_body } } } fn gen_to_device(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { Self::#variant(burn::module::Module::::to_device(module, device)) } }); quote! { fn to_device(self, device: &B::Device) -> Self { #match_body } } } fn gen_fork(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { Self::#variant(burn::module::Module::::fork(module, device)) } }); quote! { fn fork(self, device: &B::Device) -> Self { #match_body } } } fn gen_map(&self) -> TokenStream { let enum_name = self.name.to_string(); let container_type = format!("Enum:{}", enum_name); let match_body = self.gen_variants_match_fn(|variant| { let variant_str = variant.to_string(); quote! { { mapper.enter_module(#variant_str, #container_type); let result = burn::module::Module::::map(module, mapper); mapper.exit_module(#variant_str, #container_type); Self::#variant(result) } } }); quote! { fn map>(self, mapper: &mut Mapper) -> Self { #match_body } } } fn gen_valid(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { Self::InnerModule::#variant(burn::module::AutodiffModule::::valid(module)) } }); quote! { fn valid(&self) -> Self::InnerModule { #match_body } } } fn gen_from_inner(&self) -> TokenStream { let match_body = self.gen_variants_match_fn_param("module", "Self::InnerModule::", |variant| { quote! { Self::#variant(burn::module::AutodiffModule::::from_inner(module)) } }); quote! { fn from_inner(module: Self::InnerModule) -> Self { #match_body } } } fn gen_into_record(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { Self::Record::#variant(burn::module::Module::::into_record(module)) } }); quote! { fn into_record(self) -> Self::Record { #match_body } } } fn gen_load_record(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { { let Self::Record::#variant(r) = record else {panic!("Can't parse record from a different variant");}; Self::#variant(burn::module::Module::::load_record(module, r)) } } }); quote! { fn load_record(self, record: Self::Record) -> Self { #match_body } } } fn gen_clone(&self) -> TokenStream { let match_body = self.gen_variants_match_fn(|variant| { quote! { Self::#variant(module.clone()) } }); quote! { fn clone(&self) -> Self { #match_body } } } fn record_codegen(self) -> Self::RecordCodegen { EnumModuleRecordCodegen::new(self.variants, self.vis) } fn module_generics(&self) -> &ModuleGenerics { &self.generics } fn gen_display(&self) -> TokenStream { // Only tuple enum variants with exactly one field are supported let variant_prints = self.variants.iter().map(|variant| { let variant_name = &variant.ident; let field_names = (0..1).map(|i| syn::Ident::new(&format!("_{i}"), proc_macro2::Span::call_site())); let field_prints = field_names.clone().map(|field_name| { quote! { .add(stringify!(#field_name), #field_name) } }); quote! { Self::#variant_name(#(#field_names),*) => { content.set_top_level_type(&stringify!(#variant_name)) #(#field_prints)* .optional() } } }); quote! { fn content(&self, mut content: burn::module::Content) -> Option { match self { #(#variant_prints)* } } } } } impl EnumModuleCodegen { pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result { Ok(Self { name: ast.ident.clone(), variants: parse_variants(ast)?, vis: ast.vis.clone(), generics: parse_module_generics(&ast.generics), }) } /// Generate the enum variants' match arms with the provided function fn gen_variants_match_fn(&self, func: F) -> TokenStream where F: Fn(Ident) -> TokenStream, { self.gen_variants_match_fn_param("self", "Self::", func) } /// Generate a match expression over the given argument (e.g., `self`) /// and using the provided prefix for variants (e.g., `Self::` or `Self::InnerModule::`) fn gen_variants_match_fn_param(&self, arg: &str, prefix: &str, func: F) -> TokenStream where F: Fn(Ident) -> TokenStream, { let match_arms = self.variants.iter().map(|variant| { let name = &variant.ident; let full_variant = syn::parse_str::(&format!("{prefix}{name}")).unwrap(); let arm_pattern = quote! { #full_variant(module) }; let arm_code = func(name.clone()); quote! { #arm_pattern => #arm_code, } }); let arg = Ident::new(arg, Span::call_site()); quote! { match #arg { #(#match_arms)* } } } } ================================================ FILE: crates/burn-derive/src/module/codegen_struct.rs ================================================ use std::collections::HashSet; use crate::module::generics::{ GenericKind, ModuleGenerics, parse_module_generics, parse_ty_generics, }; use super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen}; use proc_macro2::{Ident, TokenStream}; use quote::{ToTokens, quote}; use syn::{Field, Visibility}; pub(crate) struct StructModuleCodegen { pub name: Ident, pub fields: Vec, pub vis: Visibility, pub generics: ModuleGenerics, } impl ModuleCodegen for StructModuleCodegen { type RecordCodegen = StructModuleRecordCodegen; fn gen_num_params(&self) -> TokenStream { let body = self.gen_fields_fn(|name, field_type| { if field_type.is_parameter_module() || field_type.maybe_generic_module() { quote! { num_params += burn::module::Module::::num_params(&self.#name); } } else { quote! {} // other fields have 0 params } }); quote! { fn num_params(&self) -> usize { let mut num_params = 0; #body num_params } } } fn gen_visit(&self) -> TokenStream { let struct_name = self.name.to_string(); let container_type = format!("Struct:{}", struct_name); let body = self.gen_fields_fn(|name, field_type| { if field_type.is_parameter_module() || field_type.maybe_generic_module() { let name_str = name.to_string(); quote! { visitor.enter_module(#name_str, #container_type); burn::module::Module::visit(&self.#name, visitor); visitor.exit_module(#name_str, #container_type); } } else { quote! {} } }); quote! { fn visit>(&self, visitor: &mut Visitor) { #body } } } fn gen_collect_devices(&self) -> TokenStream { let body = self.gen_fields_fn(|name, field_type| { if field_type.is_module || field_type.maybe_generic_module() { quote! { let devices = burn::module::Module::::collect_devices(&self.#name, devices); } } else { quote! {} } }); quote! { fn collect_devices( &self, devices: burn::module::Devices ) -> burn::module::Devices { #body devices } } } fn gen_to_device(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name, field_type| { if field_type.is_module || field_type.maybe_generic_module() { quote! { let #name = burn::module::Module::::to_device(self.#name, device); } } else { quote! { let #name = self.#name; } } }); quote! { fn to_device(self, device: &B::Device) -> Self { #body Self { #(#names),* } } } } fn gen_fork(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name, field_type| { if field_type.is_module || field_type.maybe_generic_module() { quote! { let #name = burn::module::Module::::fork(self.#name, device); } } else { quote! { let #name = self.#name; } } }); quote! { fn fork(self, device: &B::Device) -> Self { #body Self { #(#names),* } } } } fn gen_map(&self) -> TokenStream { let struct_name = self.name.to_string(); let container_type = format!("Struct:{}", struct_name); let (names, body) = self.gen_fields_fn_names(|name, field_type| { if field_type.is_parameter_module() || field_type.maybe_generic_module() { let name_str = name.to_string(); quote! { mapper.enter_module(#name_str, #container_type); let #name = burn::module::Module::::map(self.#name, mapper); mapper.exit_module(#name_str, #container_type); } } else { quote! { let #name = self.#name; } } }); quote! { fn map>(self, mapper: &mut Mapper) -> Self { #body Self { #(#names),* } } } } fn gen_valid(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name, field_type| { if field_type.is_module || field_type.maybe_generic_module() { quote! { let #name = burn::module::AutodiffModule::::valid(&self.#name); } } else { quote! { let #name = self.#name.clone(); } } }); quote! { fn valid(&self) -> Self::InnerModule { #body Self::InnerModule { #(#names),* } } } } fn gen_from_inner(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name, field_type| { if field_type.is_module || field_type.maybe_generic_module() { quote! { let #name = burn::module::AutodiffModule::::from_inner(#name); } } else { quote! { let #name = #name; } } }); let destructure = quote! { let Self::InnerModule { #(#names),* } = module; }; quote! { fn from_inner(module: Self::InnerModule) -> Self { #destructure #body Self { #(#names),* } } } } fn gen_into_record(&self) -> TokenStream { let body = self.gen_fields_fn(|name, field_type| { if field_type.is_persistent_module() || field_type.maybe_generic_module() { quote! { #name: burn::module::Module::::into_record(self.#name), } } else { match field_type.attr { // Default (None) gets skipped None | Some(ModuleFieldAttribute::Skip) => { quote! { #name: burn::module::EmptyRecord::new(), } } } } }); quote! { fn into_record(self) -> Self::Record { Self::Record { #body } } } } fn gen_load_record(&self) -> TokenStream { let body = self.gen_fields_fn(|name, field_type| { if field_type.is_persistent_module() || field_type.maybe_generic_module() { quote! { #name: burn::module::Module::::load_record(self.#name, record.#name), } } else { match field_type.attr { // Default (None) gets skipped None | Some(ModuleFieldAttribute::Skip) => { quote! { #name: self.#name, } } } } }); quote! { fn load_record(self, record: Self::Record) -> Self { Self { #body } } } } fn gen_clone(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name, _field_type| { quote! { let #name = self.#name.clone(); } }); quote! { fn clone(&self) -> Self { #body Self { #(#names),* } } } } fn record_codegen(self) -> Self::RecordCodegen { StructModuleRecordCodegen::new(self.fields, self.vis) } fn module_generics(&self) -> &ModuleGenerics { &self.generics } fn gen_display(&self) -> TokenStream { let struct_name = self.name.to_string(); let field_prints = self.fields.iter().map(|field| { let field_name = field.ident(); if field.field_type.is_module || field.field_type.maybe_generic_module() { // Standard module type, use underlying `ModuleDisplay` impl quote! { .add(stringify!(#field_name), &self.#field_name) } } else { // Not a module, use the debug implementation quote! { .add_debug_attribute(stringify!(#field_name), &self.#field_name) } } }); quote! { fn content(&self, mut content: burn::module::Content) -> Option { content .set_top_level_type(&stringify!(#struct_name)) #(#field_prints)* .optional() } } } } impl StructModuleCodegen { pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result { let mut generics = parse_module_generics(&ast.generics); Ok(Self { name: ast.ident.clone(), fields: parse_module_fields(ast, &mut generics)?, vis: ast.vis.clone(), generics, }) } fn gen_fields_fn_names(&self, func: F) -> (Vec, TokenStream) where F: Fn(Ident, &ModuleFieldType) -> TokenStream, { let mut body = quote! {}; let mut names = Vec::new(); for field in self.fields.iter() { let name = field.ident(); names.push(name.clone()); body.extend(func(name, &field.field_type)); } (names, body) } fn gen_fields_fn(&self, func: F) -> TokenStream where F: Fn(Ident, &ModuleFieldType) -> TokenStream, { let mut body = quote! {}; for field in self.fields.iter() { body.extend(func(field.ident(), &field.field_type)); } body } } #[derive(new)] pub struct ModuleField { pub field: Field, pub field_type: ModuleFieldType, } impl ModuleField { pub fn ident(&self) -> Ident { self.field.ident.clone().unwrap() } } #[derive(Debug)] pub enum ModuleFieldAttribute { Skip, } #[derive(Default, Debug)] pub struct ModuleFieldType { pub is_module: bool, pub attr: Option, pub generic_idents: HashSet, } impl ModuleFieldType { /// Returns true if the field is a module with parameters /// (i.e., a real module that is neither skipped nor constant). pub fn is_parameter_module(&self) -> bool { self.is_module && self.attr.is_none() } /// Returns true for modules that should be persisted, including constants. pub fn is_persistent_module(&self) -> bool { self.is_module && !matches!(self.attr, Some(ModuleFieldAttribute::Skip)) } /// Returns true for generic fields that are assumed to be modules. pub fn maybe_generic_module(&self) -> bool { // We assumed it might be a module generic if the field is not marked // by any attributes (skip or constant) !self.generic_idents.is_empty() && self.attr.is_none() } } pub(crate) fn parse_module_fields( ast: &syn::DeriveInput, generics: &mut ModuleGenerics, ) -> syn::Result> { let mut fields = Vec::new(); match &ast.data { syn::Data::Struct(struct_data) => { for field in struct_data.fields.iter() { let field_type = parse_module_field_type(field, generics)?; fields.push(ModuleField::new(field.clone(), field_type)); } } syn::Data::Enum(_) => panic!("Only struct can be derived"), syn::Data::Union(_) => panic!("Only struct can be derived"), }; Ok(fields) } pub(crate) fn parse_module_field_type( field: &Field, generics: &mut ModuleGenerics, ) -> syn::Result { let mut field_type = ModuleFieldType::default(); // Check for generics let mut has_backend = false; let mut has_module_bound = false; let field_generics = parse_ty_generics(&field.ty, generics) .into_iter() .filter_map(|ident| { if ident == "B" { has_backend = true; None } else { has_module_bound = generics.is_bounded_module(&ident); Some(ident) } }) .collect::>(); // Infer if a field is a module let is_primitive = is_primitive_type(&field.ty); let is_param = is_param_type(&field.ty); let is_tensor = is_tensor_type(&field.ty); let is_module = !is_primitive && (has_module_bound || is_param || is_tensor || has_backend); for attr in &field.attrs { if attr.path().is_ident("module") { attr.parse_nested_meta(|meta| { if meta.path.is_ident("skip") { // Mark field attribute and generic field_type.attr = Some(ModuleFieldAttribute::Skip); for ty in &field_generics { generics.update(ty, GenericKind::Skip); } Ok(()) } else { let path = meta.path.to_token_stream().to_string(); Err(meta.error(format!("Unsupported module attribute: {}", path))) }?; if is_param && field_type.attr.is_some() { Err(meta.error("Fields of type 'Param' should not be marked as 'skip'. Use a 'Tensor' instead.")) } else { Ok(()) } })?; } } field_type.is_module = is_module; field_type.generic_idents = field_generics; Ok(field_type) } fn type_matches_ident(ty: &syn::Type, idents: &[&str]) -> bool { if let syn::Type::Path(type_path) = ty { // Look at the last segment of the path (e.g., 'Param' in 'burn::module::Param') if let Some(segment) = type_path.path.segments.last() { return idents.contains(&segment.ident.to_string().as_str()); } } false } fn is_primitive_type(ty: &syn::Type) -> bool { type_matches_ident( ty, &[ "bool", "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize", "f32", "f64", "String", ], ) } fn is_tensor_type(ty: &syn::Type) -> bool { type_matches_ident(ty, &["Tensor"]) } fn is_param_type(ty: &syn::Type) -> bool { type_matches_ident(ty, &["Param"]) } ================================================ FILE: crates/burn-derive/src/module/display.rs ================================================ use quote::quote; use crate::module::{codegen_struct::parse_module_field_type, generics::parse_module_generics}; // Only used for "const" modules pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { let mut generics = parse_module_generics(&ast.generics); match &ast.data { syn::Data::Struct(data_struct) => { let fields = match &data_struct.fields { syn::Fields::Named(named_fields) => named_fields.named.iter().collect::>(), syn::Fields::Unit => Vec::new(), _ => panic!("attributes_fn only supports structs with named or unit fields"), }; let field_prints = fields.iter().map(|field| { let field_name = &field.ident; let field_type = parse_module_field_type(field, &mut generics).unwrap(); if field_type.is_module || field_type.maybe_generic_module() { // Standard module type, use underlying `ModuleDisplay` impl quote! { .add(stringify!(#field_name), &self.#field_name) } } else { // Not a module, use the debug implementation quote! { .add_debug_attribute(stringify!(#field_name), &self.#field_name) } } }); let struct_name = &ast.ident; quote! { fn content(&self, mut content: burn::module::Content) -> Option { content .set_top_level_type(&stringify!(#struct_name)) #(#field_prints)* .optional() } } } syn::Data::Enum(data_enum) => { let variant_prints = data_enum.variants.iter().map(|variant| { let variant_name = &variant.ident; match &variant.fields { syn::Fields::Unit => { quote! { Self::#variant_name => { content.add_formatted(&stringify!(#variant_name).to_string()) .optional() } } } syn::Fields::Named(named_fields) => { let field_prints = named_fields.named.iter().map(|field| { let field_name = &field.ident; quote! { .add(stringify!(#field_name), &self.#field_name) } }); let field_names = named_fields.named.iter().map(|field| { let field_name = &field.ident; quote! { #field_name } }); quote! { Self::#variant_name { #(#field_names),* } => { content.set_top_level_type(&stringify!(#variant_name)) #(#field_prints)* .optional() } } } syn::Fields::Unnamed(unnamed_fields) => { let field_names = (0..unnamed_fields.unnamed.len()).map(|i| { syn::Ident::new(&format!("_{i}"), proc_macro2::Span::call_site()) }); let field_prints = field_names.clone().map(|field_name| { quote! { .add(stringify!(#field_name), #field_name) } }); quote! { Self::#variant_name(#(#field_names),*) => { content.set_top_level_type(&stringify!(#variant_name)) #(#field_prints)* .optional() } } } } }); quote! { fn content(&self, mut content: burn::module::Content) -> Option { match self { #(#variant_prints)* } } } } _ => panic!("attributes_fn only supports structs and enums"), } } pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream { quote! { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let formatted = burn::module::ModuleDisplay::format(self, Default::default()); write!(f, "{}", formatted) } } } ================================================ FILE: crates/burn-derive/src/module/generics.rs ================================================ use std::collections::{HashMap, HashSet}; use proc_macro2::Ident; use syn::{GenericParam, Generics, Type, TypeParamBound, WherePredicate, visit::Visit}; #[derive(Debug)] pub enum GenericKind { /// A generic with `Module` bound. Module, /// A generic used in a field marked by `#[module(skip)]`. Skip, /// A plain generic that does not fit any of the above conditions. Plain, } #[derive(Debug)] pub struct ModuleGenerics { kinds: HashMap, } impl ModuleGenerics { pub fn is_empty(&self) -> bool { self.kinds.is_empty() } pub fn get_generic_kind(&self, ident: &Ident) -> Option<&GenericKind> { self.kinds.get(ident) } pub fn is_bounded_module(&self, ident: &Ident) -> bool { self.kinds .get(ident) .map(|kind| matches!(kind, GenericKind::Module)) .unwrap_or(false) } pub fn update(&mut self, ident: &Ident, kind: GenericKind) { self.kinds.insert(ident.clone(), kind); } pub fn contains(&self, ident: &Ident) -> bool { self.kinds.contains_key(ident) } } pub fn parse_module_generics(generics: &Generics) -> ModuleGenerics { let mut kinds = HashMap::new(); // Check inline bounds e.g. `M: Module` for param in &generics.params { if let GenericParam::Type(type_param) = param { let ident = &type_param.ident; if ident != "B" { if has_module_bound(&type_param.bounds) { kinds.insert(ident.clone(), GenericKind::Module); } else { kinds.insert(ident.clone(), GenericKind::Plain); } } } } // Check `where` clauses if let Some(where_clause) = &generics.where_clause { for predicate in &where_clause.predicates { if let WherePredicate::Type(pt) = predicate { // We only care if the bounded type is a simple identifier (like 'M') if let Type::Path(p) = &pt.bounded_ty && let Some(ident) = p.path.get_ident() && ident != "B" { if has_module_bound(&pt.bounds) { kinds.insert(ident.clone(), GenericKind::Module); } else { kinds.insert(ident.clone(), GenericKind::Plain); } } } } } ModuleGenerics { kinds } } // TODO: remove special cases for `ident == "B"`, this could be used to check for `Backend` bound. /// Helper to check if a list of bounds contains "Module". fn has_module_bound( bounds: &syn::punctuated::Punctuated, ) -> bool { has_bound(bounds, "Module") } /// Helper to check if a list of bounds contains the specified bound. fn has_bound( bounds: &syn::punctuated::Punctuated, ident: &str, ) -> bool { bounds.iter().any(|bound| { if let TypeParamBound::Trait(trait_bound) = bound && let Some(segment) = trait_bound.path.segments.last() { return segment.ident == ident; } false }) } pub fn parse_ty_generics(ty: &Type, declared: &ModuleGenerics) -> HashSet { struct Collector<'a> { generics: HashSet, declared: &'a ModuleGenerics, } impl<'ast, 'a> Visit<'ast> for Collector<'a> { fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) { if type_path.qself.is_none() && let Some(ident) = type_path.path.get_ident() && (self.declared.contains(ident) || ident == "B") { self.generics.insert(ident.clone()); } syn::visit::visit_type_path(self, type_path); } } let mut collector = Collector { generics: HashSet::new(), declared, }; collector.visit_type(ty); collector.generics } ================================================ FILE: crates/burn-derive/src/module/mod.rs ================================================ pub(crate) mod codegen; pub(crate) mod codegen_enum; pub(crate) mod codegen_struct; pub(crate) mod display; pub(crate) mod generics; pub(crate) mod record; pub(crate) mod record_enum; pub(crate) mod record_struct; mod base; pub(crate) use base::*; ================================================ FILE: crates/burn-derive/src/module/record.rs ================================================ use proc_macro2::{Ident, TokenStream}; use syn::Generics; /// Basic trait to generate a record type based on the Module struct. pub(crate) trait ModuleRecordCodegen { /// Generate the record type (i.e a struct) fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics); } ================================================ FILE: crates/burn-derive/src/module/record_enum.rs ================================================ use crate::shared::enum_variant::EnumVariant; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Generics, Visibility}; use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct EnumModuleRecordCodegen { variants: Vec, vis: Visibility, } impl ModuleRecordCodegen for EnumModuleRecordCodegen { fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics) { let mut variants = quote! {}; let vis = &self.vis; // Capture the Record enum variant types for variant in self.variants.iter() { let ty = &variant.ty; let name = &variant.ident; variants.extend(quote! { /// The module record associative type. #name(<#ty as burn::module::Module>::Record), }); } let (impl_generics, _generics_ty, generics_where) = generics.split_for_impl(); ( quote! { /// The record type for the module. #[derive(burn::record::Record)] #vis enum #record_name #impl_generics #generics_where { #variants } }, generics.clone(), ) } } ================================================ FILE: crates/burn-derive/src/module/record_struct.rs ================================================ use std::collections::HashSet; use crate::module::codegen_struct::{ModuleField, ModuleFieldAttribute}; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Generics, Visibility}; use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct StructModuleRecordCodegen { fields: Vec, vis: Visibility, } impl ModuleRecordCodegen for StructModuleRecordCodegen { fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics) { let mut fields = quote! {}; let vis = &self.vis; let mut used_generics = HashSet::new(); for field in self.fields.iter() { let ty = &field.field.ty; let name = &field.field.ident; if field.field_type.is_persistent_module() || field.field_type.maybe_generic_module() { fields.extend(quote! { /// The module record associative type. #vis #name: <#ty as burn::module::Module>::Record, }); used_generics.extend(&field.field_type.generic_idents); } else { match field.field_type.attr { // Default (None) gets skipped None | Some(ModuleFieldAttribute::Skip) => { fields.extend(quote! { #[allow(missing_docs)] #vis #name: burn::module::EmptyRecord, }); // Do not capture generics from this field since it produces an empty record } } } } let mut filtered_generics = generics.clone(); filtered_generics.params = generics .params .iter() .filter(|param| match param { syn::GenericParam::Type(ty) if ty.ident == "B" => true, syn::GenericParam::Type(ty) => used_generics.contains(&ty.ident), _ => true, }) .cloned() .collect(); if let Some(where_clause) = &mut filtered_generics.where_clause { where_clause.predicates = where_clause .predicates .iter() .filter(|pred| { match pred { syn::WherePredicate::Type(ty) => { // Check if the bounded type is one of our remaining generics if let syn::Type::Path(p) = &ty.bounded_ty && let Some(ident) = p.path.get_ident() { return ident == "B" || used_generics.contains(ident); } true } _ => true, } }) .cloned() .collect(); // Remove the where clause entirely if where_clause.predicates.is_empty() { filtered_generics.where_clause = None; } } let (impl_generics, _generics_ty, generics_where) = filtered_generics.split_for_impl(); ( quote! { /// The record type for the module. #[derive(burn::record::Record)] #vis struct #record_name #impl_generics #generics_where { #fields } }, filtered_generics, ) } } ================================================ FILE: crates/burn-derive/src/record/base.rs ================================================ use super::{ codegen::generate_record, item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen}, }; pub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream { match &ast.data { syn::Data::Struct(_) => generate_record::(ast), syn::Data::Enum(_) => generate_record::(ast), syn::Data::Union(_) => panic!("Union modules aren't supported yet."), } .into() } ================================================ FILE: crates/burn-derive/src/record/codegen.rs ================================================ use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Generics, parse_quote}; use crate::record::item::codegen::RecordItemCodegen; pub(crate) fn generate_record(ast: &syn::DeriveInput) -> TokenStream { let record_gen: syn::Result> = RecordCodegen::from_ast(ast); match record_gen { Ok(record_gen) => { let item_type = record_gen.gen_record_type(); let record_impl = record_gen.gen_impl_record(); quote! { #item_type #record_impl } } Err(err) => err.to_compile_error(), } } pub(crate) struct RecordCodegen { /// Record type info. ty: RecordType, /// Record item code gen. codegen: G, } impl RecordCodegen { /// Generate the record type with the correct generics. pub(crate) fn gen_record_type(&self) -> TokenStream { // Add precision settings type bound let param: syn::Generics = parse_quote! { }; let mut generics = self.ty.generics.clone(); for param in param.params.into_iter() { generics.params.push(param); } // Generate the record item definition self.codegen .gen_item_type(&self.ty.item, &generics, self.ty.has_backend) } /// Generate the implementation for the Record trait. pub(crate) fn gen_impl_record(&self) -> TokenStream { // Capture the record type's generics and bounds in where clauses let item_generics = self.record_item_generics(); let (_, ty_generics_item, _) = item_generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl(); let impl_generics = if let Some(impl_generic) = self.impl_generics() { impl_generic } else { quote! { #impl_generics } }; let name_item = &self.ty.item; let into_item_fn = self.codegen.gen_into_item(name_item); let from_item_fn = self.codegen.gen_from_item(); // Return the generated stream of token trees (i.e., code to be generated) let name = &self.ty.name; quote! { impl #impl_generics burn::record::Record for #name #ty_generics #where_clause { type Item = #name_item #ty_generics_item; #into_item_fn #from_item_fn } } } /// Add backend generic type to the implementation block. fn impl_generics(&self) -> Option { if self.ty.has_backend { return None; } let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; let mut generics = self.ty.generics.clone(); generics.params.push(syn::GenericParam::Type(param)); let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl(); Some(quote! {#impl_generics}) } /// Get the generics attached to the record item type. fn record_item_generics(&self) -> Generics { let param: syn::Generics = parse_quote! { }; let mut generics = self.ty.generics.clone(); for param in param.params.into_iter() { generics.params.push(param); } if !self.ty.has_backend { let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; generics.params.push(syn::GenericParam::Type(param)); } generics } pub(crate) fn from_ast(ast: &syn::DeriveInput) -> syn::Result { Ok(Self { ty: RecordType::from_ast(ast), codegen: G::from_ast(ast)?, }) } } /// Information about a record type. struct RecordType { /// Record type name. name: Ident, /// Record item type name. item: Ident, /// Lifetimes and type parameters attached to the record type declaration. generics: Generics, /// Whether or not the record type should specify a backend generic. has_backend: bool, } impl RecordType { fn from_ast(ast: &syn::DeriveInput) -> Self { let name = ast.ident.clone(); let item = Ident::new(format!("{name}Item").as_str(), name.span()); let has_backend = ast .generics .type_params() .map(|param| param.ident == "B") .reduce(|accum, is_backend| is_backend || accum) .unwrap_or(false); Self { name, item, generics: ast.generics.clone(), has_backend, } } } ================================================ FILE: crates/burn-derive/src/record/item/codegen.rs ================================================ use proc_macro2::{Ident, TokenStream}; use syn::Generics; /// Basic trait to be implemented for record generation. pub(crate) trait RecordItemCodegen { /// Initialize the record item. fn from_ast(ast: &syn::DeriveInput) -> syn::Result where Self: Sized; /// Generate the record item type. fn gen_item_type( &self, item_name: &Ident, generics: &Generics, has_backend: bool, ) -> TokenStream; /// Generate the into_item function. fn gen_into_item(&self, item_name: &Ident) -> TokenStream; /// Generate the from item function. fn gen_from_item(&self) -> TokenStream; } ================================================ FILE: crates/burn-derive/src/record/item/codegen_enum.rs ================================================ use crate::shared::enum_variant::{EnumVariant, parse_variants}; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Generics, Visibility, parse_quote}; use super::codegen::RecordItemCodegen; pub(crate) struct EnumRecordItemCodegen { /// Enum variants. variants: Vec, vis: Visibility, } impl RecordItemCodegen for EnumRecordItemCodegen { fn from_ast(ast: &syn::DeriveInput) -> syn::Result { Ok(Self { variants: parse_variants(ast)?, vis: ast.vis.clone(), }) } fn gen_item_type( &self, item_name: &Ident, generics: &Generics, has_backend: bool, ) -> TokenStream { let mut variants = quote! {}; let mut serde_bounds = quote! {}; let mut clone_bounds = vec![]; let mut clone_match_arms = quote! {}; let vis = &self.vis; // Capture the Record enum variant types and names to transpose them in RecordItem for variant in self.variants.iter() { let ty = &variant.ty; let name = &variant.ident; variants.extend(quote! { /// Variant to be serialized. #name(<#ty as burn::record::Record>::Item), }); // Item types must implement serialization/deserialization serde_bounds.extend(quote! { <#ty as burn::record::Record>::Item: burn::serde::Serialize + burn::serde::de::DeserializeOwned, }); clone_bounds.push(parse_quote! { <#ty as burn::record::Record>::Item: Clone }); clone_match_arms.extend(quote! { Self::#name(inner) => Self::#name(inner.clone()), }); } let serde_bound = serde_bounds.to_string(); // Capture the type's generics and bounds in where clauses let mut generics = generics.clone(); if !has_backend { let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; generics.params.push(syn::GenericParam::Type(param)); } let (generics, type_generics, generics_where) = generics.split_for_impl(); let clone_bounds = generics_where.cloned().map(|mut where_clause| { for predicate in clone_bounds { where_clause.predicates.push(predicate); } where_clause }); let clone_impl = quote! { impl #generics Clone for #item_name #type_generics #clone_bounds { fn clone(&self) -> Self { match self { #clone_match_arms } } } }; // Return the generated stream of token trees (i.e., code to be generated) quote! { /// The record item type for the module. #[derive(burn::serde::Serialize, burn::serde::Deserialize)] #[serde(crate = "burn::serde")] #[serde(bound = #serde_bound)] #vis enum #item_name #generics #generics_where { #variants } #clone_impl } } fn gen_into_item(&self, _item_name: &Ident) -> TokenStream { let mut into_item_match_arms = quote! {}; for variant in self.variants.iter() { let name = &variant.ident; into_item_match_arms.extend(quote! { Self::#name(record) => Self::Item::#name(burn::record::Record::::into_item::(record)), }); } quote! { fn into_item(self) -> Self::Item { match self { #into_item_match_arms } } } } fn gen_from_item(&self) -> TokenStream { let mut from_item_match_arms = quote! {}; for variant in self.variants.iter() { let name = &variant.ident; from_item_match_arms.extend(quote! { Self::Item::#name(item) => Self::#name(burn::record::Record::::from_item::(item, device)), }); } quote! { fn from_item(item: Self::Item, device: &B::Device) -> Self { match item { #from_item_match_arms } } } } } ================================================ FILE: crates/burn-derive/src/record/item/codegen_struct.rs ================================================ use crate::shared::field::{FieldTypeAnalyzer, parse_fields}; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::{Generics, Visibility, parse_quote}; use super::codegen::RecordItemCodegen; pub(crate) struct StructRecordItemCodegen { fields: Vec, vis: Visibility, } impl RecordItemCodegen for StructRecordItemCodegen { fn from_ast(ast: &syn::DeriveInput) -> syn::Result { Ok(Self { fields: parse_fields(ast) .into_iter() .map(FieldTypeAnalyzer::new) .collect(), vis: ast.vis.clone(), }) } fn gen_item_type( &self, item_name: &Ident, generics: &Generics, has_backend: bool, ) -> TokenStream { let mut fields = quote! {}; let mut serde_bounds = quote! {}; let mut clone_bounds = vec![]; let mut clone_delegate = quote! {}; let vis = &self.vis; for field in self.fields.iter() { let ty = &field.field.ty; let name = &field.field.ident; fields.extend(quote! { /// Field to be serialized. pub #name: <#ty as burn::record::Record>::Item, }); serde_bounds.extend(quote! { <#ty as burn::record::Record>::Item: burn::serde::Serialize + burn::serde::de::DeserializeOwned, }); clone_bounds.push(parse_quote! { <#ty as burn::record::Record>::Item: Clone }); clone_delegate.extend(quote! { #name: self.#name.clone(), }); } let serde_bound = serde_bounds.to_string(); let mut generics = generics.clone(); if !has_backend { let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend }; generics.params.push(syn::GenericParam::Type(param)); } let (generics, type_generics, generics_where) = generics.split_for_impl(); let clone_bounds = generics_where.cloned().map(|mut where_clause| { for predicate in clone_bounds { where_clause.predicates.push(predicate); } where_clause }); let clone_impl = quote! { impl #generics Clone for #item_name #type_generics #clone_bounds { fn clone(&self) -> Self { Self { #clone_delegate } } } }; quote! { /// The record item type for the module. #[derive(burn::serde::Serialize, burn::serde::Deserialize)] #[serde(crate = "burn::serde")] #[serde(bound = #serde_bound)] #vis struct #item_name #generics #generics_where { #fields } #clone_impl } } fn gen_into_item(&self, item_name: &Ident) -> TokenStream { let mut body_into_item = quote! {}; for field in self.fields.iter() { let name = &field.field.ident; body_into_item.extend(quote! { #name: burn::record::Record::::into_item::(self.#name), }); } quote! { fn into_item(self) -> Self::Item { #item_name { #body_into_item } } } } fn gen_from_item(&self) -> TokenStream { let mut body_from_item = quote! {}; for field in self.fields.iter() { let name = &field.field.ident; body_from_item.extend(quote! { #name: burn::record::Record::::from_item::(item.#name, device), }); } quote! { fn from_item(item: Self::Item, device: &B::Device) -> Self { Self { #body_from_item } } } } } ================================================ FILE: crates/burn-derive/src/record/item/mod.rs ================================================ pub(crate) mod codegen; pub(crate) mod codegen_enum; pub(crate) mod codegen_struct; ================================================ FILE: crates/burn-derive/src/record/mod.rs ================================================ pub(crate) mod codegen; pub(crate) mod item; mod base; pub(crate) use base::*; ================================================ FILE: crates/burn-derive/src/shared/attribute.rs ================================================ use syn::{Attribute, Meta}; pub struct AttributeAnalyzer { attr: Attribute, } #[derive(Clone)] pub struct AttributeItem { pub value: syn::Lit, } impl AttributeAnalyzer { pub fn new(attr: Attribute) -> Self { Self { attr } } pub fn item(&self) -> AttributeItem { let value = match &self.attr.meta { Meta::List(val) => val.parse_args::().unwrap(), Meta::NameValue(meta) => meta.clone(), Meta::Path(_) => panic!("Path meta unsupported"), }; let lit = match value.value { syn::Expr::Lit(lit) => lit.lit, _ => panic!("Only literal is supported"), }; AttributeItem { value: lit } } pub fn has_name(&self, name: &str) -> bool { Self::path_syn_name(self.attr.path()) == name } fn path_syn_name(path: &syn::Path) -> String { let length = path.segments.len(); let mut name = String::new(); for (i, segment) in path.segments.iter().enumerate() { if i == length - 1 { name += segment.ident.to_string().as_str(); } else { let tmp = segment.ident.to_string() + "::"; name += tmp.as_str(); } } name } } ================================================ FILE: crates/burn-derive/src/shared/enum_variant.rs ================================================ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::{FieldsNamed, Variant}; /// Process a variant of an enum where the output is the result of the given mapper. pub(crate) fn map_enum_variant( variant: &Variant, mapper: Mapper, ) -> (TokenStream, TokenStream) where Mapper: Fn(&Ident) -> TokenStream, { let gen_fields_unnamed = |num: usize| { let mut inputs = Vec::new(); let mut outputs = Vec::new(); for i in 0..num { let arg_name = Ident::new(&format!("arg_{i}"), Span::call_site()); let input = quote! { #arg_name }; let output = mapper(&arg_name); inputs.push(input); outputs.push(output); } (quote! (( #(#inputs),* )), quote! (( #(#outputs),* ))) }; let gen_fields_named = |fields: &FieldsNamed| { let mut inputs = Vec::new(); let mut outputs = Vec::new(); fields.named.iter().for_each(|field| { let ident = field.ident.as_ref().expect("Named field to have a name."); let input = quote! { #ident }; let output = mapper(ident); inputs.push(input); outputs.push(quote! { #ident: #output }); }); (quote! {{ #(#inputs),* }}, quote! {{ #(#outputs),* }}) }; match &variant.fields { syn::Fields::Named(fields) => gen_fields_named(fields), syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()), syn::Fields::Unit => (quote! {}, quote! {}), } } /// An enum variant (simplified). pub(crate) struct EnumVariant { pub ident: syn::Ident, pub ty: syn::Type, } pub(crate) fn parse_variants(ast: &syn::DeriveInput) -> syn::Result> { let enum_data = match &ast.data { syn::Data::Enum(data) => data, _ => { return Err(syn::Error::new_spanned( ast, "Module can only be derived for enums.", )); } }; let mut variants = Vec::new(); for variant in enum_data.variants.iter() { for attr in &variant.attrs { if attr.path().is_ident("module") { Err(syn::Error::new_spanned( variant, "Module attributes are not supported for enum variants.", ))?; } } match &variant.fields { syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { let field = &fields.unnamed[0]; variants.push(EnumVariant { ident: variant.ident.clone(), ty: field.ty.clone(), }); } syn::Fields::Unnamed(_) => { return Err(syn::Error::new_spanned( variant, "Module derive only supports tuple enum variants with exactly one field.", )); } syn::Fields::Named(_) => { return Err(syn::Error::new_spanned( variant, "Module derive does not support struct enum variants.", )); } syn::Fields::Unit => { return Err(syn::Error::new_spanned( variant, "Module derive does not support unit enum variants.", )); } } } Ok(variants) } ================================================ FILE: crates/burn-derive/src/shared/field.rs ================================================ use super::attribute::AttributeAnalyzer; use proc_macro2::Ident; use syn::{Field, Type, TypePath}; #[derive(Clone)] pub struct FieldTypeAnalyzer { pub field: Field, } impl FieldTypeAnalyzer { pub fn new(field: Field) -> Self { FieldTypeAnalyzer { field } } pub fn ident(&self) -> Ident { self.field.ident.clone().unwrap() } pub fn is_of_type(&self, paths: &[&str]) -> bool { match &self.field.ty { syn::Type::Path(path) => { let name = Self::path_name(path); paths.contains(&name.as_str()) } _ => false, } } #[allow(dead_code)] pub fn first_generic_field(&self) -> TypePath { let err = || panic!("Field {} as no generic", self.field.ident.clone().unwrap()); match &self.field.ty { syn::Type::Path(path) => Self::path_generic_argument(path), _ => err(), } } pub fn path_generic_argument(path: &TypePath) -> TypePath { let segment = path.path.segments.last().unwrap(); let err = || panic!("Path segment {} has no generic", segment.ident.clone(),); match &segment.arguments { syn::PathArguments::None => err(), syn::PathArguments::AngleBracketed(param) => { let first_param = param.args.first().unwrap(); if let syn::GenericArgument::Type(Type::Path(path)) = first_param { path.clone() } else { err() } } syn::PathArguments::Parenthesized(_) => err(), } } fn path_name(path: &TypePath) -> String { let length = path.path.segments.len(); let mut name = String::new(); for (i, segment) in path.path.segments.iter().enumerate() { if i == length - 1 { name += segment.ident.to_string().as_str(); } else { let tmp = segment.ident.to_string() + "::"; name += tmp.as_str(); } } name } /// Returns the docs of the field. pub fn docs(&self) -> impl Iterator { self.field .attrs .iter() .filter(|attr| attr.path().is_ident("doc")) } pub fn attributes(&self) -> impl Iterator { self.field .attrs .clone() .into_iter() .map(AttributeAnalyzer::new) } } pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec { let mut fields = Vec::new(); match &ast.data { syn::Data::Struct(struct_data) => { for field in struct_data.fields.iter() { fields.push(field.clone()); } } syn::Data::Enum(_) => panic!("Only struct can be derived"), syn::Data::Union(_) => panic!("Only struct can be derived"), }; fields } ================================================ FILE: crates/burn-derive/src/shared/generics.rs ================================================ use proc_macro2::Ident; use quote::quote; use syn::{Generics, WhereClause, WherePredicate, parse_quote}; #[derive(new)] pub struct GenericsHelper { pub(crate) generics: Generics, } impl GenericsHelper { pub fn add_predicate(&mut self, predicate: WherePredicate) { let where_clause: WhereClause = match &self.generics.where_clause { Some(val) => parse_quote! { #val #predicate, }, None => parse_quote! { where #predicate, }, }; self.generics.where_clause = Some(where_clause); } pub fn consts(&self) -> Vec { self.generics .const_params() .map(|c| c.ident.clone()) .collect() } pub fn types(&self) -> Vec { self.generics .type_params() .map(|tp| tp.ident.clone()) .collect() } pub fn fetch_backend_trait(&self) -> proc_macro2::TokenStream { static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str = "Modules should be generic over a backend. - The generic argument named `B` should have its first trait bound being a backend trait. - The default backend trait is `burn::tensor::backend::Backend`. - Any backend trait is supported."; for param in self.generics.params.iter() { if let syn::GenericParam::Type(ty) = ¶m && ty.ident == "B" { let bound = ty .bounds .first() .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG); return quote! { #bound }; } } panic!("{BACKEND_TRAIT_COMPILATION_ERROR_MSG}"); } } ================================================ FILE: crates/burn-derive/src/shared/mod.rs ================================================ pub(crate) mod attribute; pub(crate) mod enum_variant; pub(crate) mod field; pub(crate) mod generics; ================================================ FILE: crates/burn-dispatch/Cargo.toml ================================================ [package] authors = [ "laggui ", "nathanielsimard ", ] categories = ["science"] description = "Backend dispatch for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-dispatch" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-dispatch" documentation = "https://docs.rs/burn-dispatch" version.workspace = true [lints] workspace = true [features] default = [ "std", "ndarray", "burn-autodiff?/default", "burn-cpu?/default", "burn-cuda?/default", "burn-ndarray?/default", "burn-rocm?/default", "burn-tch?/default", "burn-wgpu?/default", ] doc = ["default"] std = [ "burn-backend/std", "burn-std/std", "burn-autodiff?/std", "burn-cpu?/std", "burn-cuda?/std", "burn-ndarray?/std", "burn-rocm?/std", "burn-tch?/std", "burn-wgpu?/std", ] tracing = [ "burn-autodiff?/tracing", "burn-cpu?/tracing", "burn-cuda?/tracing", "burn-ndarray?/tracing", "burn-rocm?/tracing", "burn-tch?/tracing", "burn-wgpu?/tracing", ] # Backends cuda = ["burn-cuda"] rocm = ["burn-rocm"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] vulkan = ["wgpu", "burn-wgpu/vulkan"] webgpu = ["wgpu", "burn-wgpu/webgpu"] metal = ["wgpu", "burn-wgpu/metal"] wgpu = ["burn-wgpu"] cpu = ["burn-cpu"] autodiff = ["burn-autodiff"] # Backend features autotune = [ "burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-rocm?/autotune", "burn-cpu?/autotune", ] autotune-checks = [ "burn-wgpu?/autotune-checks", "burn-cuda?/autotune-checks", "burn-rocm?/autotune-checks", "burn-cpu?/autotune-checks", ] fusion = [ "burn-wgpu?/fusion", "burn-cuda?/fusion", "burn-rocm?/fusion", "burn-cpu?/fusion", ] [dependencies] burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } # Backends burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-cpu = { path = "../burn-cpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } # Op macros with `.as_$inner_kind()` paste = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-dispatch/README.md ================================================ # Burn Backend Dispatch A multi-backend dispatch that forwards the tensor operations to the appropriate backend. ================================================ FILE: crates/burn-dispatch/build.rs ================================================ fn main() { println!("cargo::rustc-check-cfg=cfg(wgpu_metal)"); println!("cargo::rustc-check-cfg=cfg(wgpu_vulkan)"); println!("cargo::rustc-check-cfg=cfg(wgpu_webgpu)"); // Detect which single wgpu backend is enabled let metal = cfg!(feature = "metal"); let vulkan = cfg!(feature = "vulkan"); let webgpu = cfg!(feature = "webgpu"); let enabled = [(metal, "metal"), (vulkan, "vulkan"), (webgpu, "webgpu")] .iter() .filter(|x| x.0) .map(|x| x.1) .collect::>(); // WGPU features are mutually exclusive, but we don't want to workspace to throw a compile error. // In workspace builds with multiple features, we emit a warning and disable all WGPU backends. if enabled.len() > 1 { println!( "cargo:warning=Only one WGPU backend can be enabled at once. Detected: [{}]. No WGPU backend will be available in this build. This is expected in workspace builds. For production, enable only one of: metal, vulkan, or webgpu.", enabled.join(", ") ); return; } if metal { println!("cargo:rustc-cfg=wgpu_metal"); } if vulkan { println!("cargo:rustc-cfg=wgpu_vulkan"); } if webgpu { println!("cargo:rustc-cfg=wgpu_webgpu"); } } ================================================ FILE: crates/burn-dispatch/src/backend.rs ================================================ use alloc::format; use alloc::string::String; use burn_backend::Backend; use burn_backend::ExecutionError; use burn_std::DType; #[cfg(feature = "autodiff")] use burn_autodiff::grads::Gradients; #[cfg(feature = "autodiff")] use burn_backend::AutodiffBackend; use crate::DispatchTensorKind; use crate::backends::*; use crate::{DispatchDevice, DispatchTensor}; /// The main execution backend in Burn. /// /// [`Dispatch`] acts as a global backend that can manage multiple underlying /// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.). /// It is responsible for: /// - Dispatching tensor operations to the appropriate backend. /// - Managing cross-backend tensor transfers. /// /// Essentially, [`Dispatch`] is the single entry point for executing tensor operations /// in a backend-agnostic way. It allows Burn to provide a unified, global backend /// for users while still leveraging multiple specialized backends under the hood. /// /// # Example /// /// ```ignore /// use burn::Dispatch; /// use burn::DispatchDevice; /// /// // Select the device to execute operations on /// let device = DispatchDevice::Cuda(Default::default()); /// /// // Create a tensor using the global backend /// let t = Tensor::::zeros([128, 128], &device); /// ``` #[derive(Debug, Default, Clone)] pub struct Dispatch; impl Backend for Dispatch { type Device = DispatchDevice; type FloatTensorPrimitive = DispatchTensor; // TODO: either allow default dtype generic or remove associated types entirely? type FloatElem = f32; type IntTensorPrimitive = DispatchTensor; type IntElem = i32; type BoolTensorPrimitive = DispatchTensor; type BoolElem = u8; type QuantizedTensorPrimitive = DispatchTensor; fn name(device: &Self::Device) -> String { let inner = dispatch_device!(device, |device| B::name(device)); format!("dispatch<{inner}>") } fn seed(device: &Self::Device, seed: u64) { dispatch_device!(device, |device| B::seed(device, seed)) } fn sync(device: &Self::Device) -> Result<(), ExecutionError> { dispatch_device!(device, |device| B::sync(device)) } fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { dispatch_device!(device, |device| B::dtype_usage(device, dtype)) } fn ad_enabled(device: &Self::Device) -> bool { match device { #[cfg(feature = "autodiff")] DispatchDevice::Autodiff(_) => true, _ => false, } } } #[cfg(feature = "autodiff")] impl AutodiffBackend for Dispatch { type InnerBackend = Dispatch; type Gradients = Gradients; fn backward(tensor: DispatchTensor) -> Self::Gradients { let DispatchTensor { kind, .. } = tensor; match kind { #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(tensor) => match *tensor { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => tensor.autodiff().backward(), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(), DispatchTensorKind::Autodiff(_) => { panic!("Autodiff should not wrap an autodiff tensor.") } }, _ => panic!("Requires autodiff tensor."), } } fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option { let DispatchTensor { kind, checkpointing, } = tensor; let grad = match &kind { #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => tensor .as_autodiff() .grad(grads) .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))), DispatchTensorKind::Autodiff(_) => { panic!("Autodiff should not wrap an autodiff tensor.") } }, _ => panic!("Requires autodiff tensor."), }; grad.map(|kind| DispatchTensor { kind, checkpointing: *checkpointing, }) } fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option { let DispatchTensor { kind, checkpointing, } = tensor; let grad = match &kind { #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => tensor .as_autodiff() .grad_remove(grads) .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))), DispatchTensorKind::Autodiff(_) => { panic!("Autodiff should not wrap an autodiff tensor.") } }, _ => panic!("Requires autodiff tensor."), }; grad.map(|kind| DispatchTensor { kind, checkpointing: *checkpointing, }) } fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) { let DispatchTensor { kind, checkpointing, } = tensor; let DispatchTensor { kind: grad, checkpointing: grad_ckp, } = grad; debug_assert_eq!(checkpointing, &grad_ckp); match &kind { #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) { #[cfg(feature = "cpu")] (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(feature = "cuda")] (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(wgpu_metal)] (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(feature = "rocm")] (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(wgpu_vulkan)] (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(wgpu_webgpu)] (DispatchTensorKind::WebGpu(tensor), DispatchTensorKind::WebGpu(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } #[cfg(feature = "ndarray")] (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => { tensor.as_autodiff().grad_replace(grads, grad.float()) } (DispatchTensorKind::Autodiff(_), _) => { panic!("Autodiff should not wrap an autodiff tensor.") } (t, g) => panic!( "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}." ), }, _ => panic!("Requires autodiff tensor."), } } fn inner(tensor: DispatchTensor) -> DispatchTensor { let DispatchTensor { kind, checkpointing, } = tensor; let kind = match kind { #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::WebGpu( crate::BackendTensor::Float(tensor.autodiff().primitive), ), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray( crate::BackendTensor::Float(tensor.autodiff().primitive), ), DispatchTensorKind::Autodiff(_) => { panic!("Autodiff should not wrap an autodiff tensor.") } }, _ => panic!("Requires autodiff tensor."), }; DispatchTensor { kind, checkpointing, } } fn int_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } fn bool_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } fn q_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } fn from_inner(tensor: DispatchTensor) -> DispatchTensor { let DispatchTensor { kind, checkpointing, } = tensor; let kind = match kind { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::Metal(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::WebGpu(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new( DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff( Autodiff::>::from_inner(tensor.float()), )), )), DispatchTensorKind::Autodiff(_) => { panic!("Autodiff should not wrap an autodiff tensor.") } }; DispatchTensor { kind, checkpointing, } } fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor { tensor } } impl DispatchTensorKind { pub(crate) fn device(&self) -> DispatchDevice { match self { #[cfg(feature = "cpu")] DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()), #[cfg(feature = "cuda")] DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()), #[cfg(wgpu_metal)] DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()), #[cfg(feature = "rocm")] DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()), #[cfg(wgpu_vulkan)] DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()), #[cfg(wgpu_webgpu)] DispatchTensorKind::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()), #[cfg(feature = "ndarray")] DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()), #[cfg(feature = "tch")] DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()), #[cfg(feature = "autodiff")] DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()), } } } impl DispatchTensor { pub(crate) fn device(&self) -> DispatchDevice { self.kind.device() } } ================================================ FILE: crates/burn-dispatch/src/device.rs ================================================ use burn_backend::{DeviceId, DeviceOps}; use crate::backends::*; /// Represents a device for the [`Dispatch`](crate::Dispatch). /// /// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to. /// /// # Example /// /// ```ignore /// use burn::DispatchDevice; /// /// #[cfg(feature = "cpu")] /// let cpu_device = DispatchDevice::Cpu(Default::default()); /// /// #[cfg(feature = "cuda")] /// let cuda_device = DispatchDevice::Cuda(Default::default()); /// ``` #[derive(Clone, Eq)] pub enum DispatchDevice { /// The [CPU backend](Cpu) device. #[cfg(feature = "cpu")] Cpu(CpuDevice), /// The [CUDA backend](Cuda) device. #[cfg(feature = "cuda")] Cuda(CudaDevice), /// The [Metal backend](Metal) device (via WGPU runtime). #[cfg(wgpu_metal)] Metal(WgpuDevice), /// The [ROCm backend](Rocm) device. #[cfg(feature = "rocm")] Rocm(RocmDevice), /// The [Vulkan backend](Vulkan) device. #[cfg(wgpu_vulkan)] Vulkan(WgpuDevice), /// The [WebGPU backend](WebGpu) device (via WGPU runtime). #[cfg(wgpu_webgpu)] WebGpu(WgpuDevice), /// The [NdArray backend](NdArray) device (CPU-only). #[cfg(feature = "ndarray")] NdArray(NdArrayDevice), /// The [LibTorch backend](LibTorch) device. #[cfg(feature = "tch")] LibTorch(LibTorchDevice), /// The [autodiff enabled backend](Autodiff) device. #[cfg(feature = "autodiff")] Autodiff(AutodiffDevice), } #[cfg(feature = "autodiff")] // This tuple struct mainly restricts users from creating Autodiff(Autodiff) devices. /// A wrapper that enables automatic differentiation for a [`DispatchDevice`]. /// /// Use [`DispatchDevice::autodiff`] to construct this type. #[derive(Debug, Clone, PartialEq, Eq)] pub struct AutodiffDevice { pub(crate) inner: Box, pub(crate) checkpointing: CheckpointingStrategy, } #[cfg(feature = "autodiff")] impl AutodiffDevice { pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self { Self { inner: Box::new(device), checkpointing, } } } #[cfg(feature = "autodiff")] // Useful for match in dispatch macros impl core::ops::Deref for AutodiffDevice { type Target = DispatchDevice; fn deref(&self) -> &Self::Target { &self.inner } } #[cfg(feature = "autodiff")] #[allow(missing_docs)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] /// Checkpointing strategy for autodiff. #[repr(u8)] pub enum CheckpointingStrategy { Balanced, #[default] None, } #[cfg(feature = "autodiff")] pub(crate) fn validate_checkpointing( lhs: crate::CheckpointingStrategy, rhs: crate::CheckpointingStrategy, ) -> crate::CheckpointingStrategy { assert_eq!( lhs, rhs, "Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy." ); lhs } impl core::fmt::Debug for DispatchDevice { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { #[cfg(feature = "cpu")] Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(), #[cfg(feature = "cuda")] Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(), #[cfg(wgpu_metal)] Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(), #[cfg(feature = "rocm")] Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(), #[cfg(wgpu_vulkan)] Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(), #[cfg(wgpu_webgpu)] Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(), #[cfg(feature = "ndarray")] Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(), #[cfg(feature = "tch")] Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(), #[cfg(feature = "autodiff")] // Format without `AutodiffDevice` wrapper Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.inner).finish(), } } } impl Default for DispatchDevice { #[allow(unreachable_code)] fn default() -> Self { // TODO: which priority? #[cfg(feature = "cpu")] return Self::Cpu(CpuDevice); #[cfg(feature = "cuda")] return Self::Cuda(CudaDevice::default()); #[cfg(wgpu_metal)] return Self::Metal(burn_wgpu::WgpuDevice::default()); #[cfg(feature = "rocm")] return Self::Rocm(RocmDevice::default()); #[cfg(wgpu_vulkan)] return Self::Vulkan(burn_wgpu::WgpuDevice::default()); #[cfg(wgpu_webgpu)] return Self::WebGpu(burn_wgpu::WgpuDevice::default()); #[cfg(feature = "ndarray")] return Self::NdArray(NdArrayDevice::default()); #[cfg(feature = "tch")] return Self::LibTorch(LibTorchDevice::default()); } } impl PartialEq for DispatchDevice { fn eq(&self, other: &Self) -> bool { match (self, other) { // If both are Autodiff, compare the inner devices #[cfg(feature = "autodiff")] (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b, // If one is Autodiff, compare it to the raw device #[cfg(feature = "autodiff")] (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b, #[cfg(feature = "autodiff")] (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(), #[cfg(feature = "cpu")] (Self::Cpu(a), Self::Cpu(b)) => a == b, #[cfg(feature = "cuda")] (Self::Cuda(a), Self::Cuda(b)) => a == b, #[cfg(wgpu_metal)] (Self::Metal(a), Self::Metal(b)) => a == b, #[cfg(feature = "rocm")] (Self::Rocm(a), Self::Rocm(b)) => a == b, #[cfg(wgpu_vulkan)] (Self::Vulkan(a), Self::Vulkan(b)) => a == b, #[cfg(wgpu_webgpu)] (Self::WebGpu(a), Self::WebGpu(b)) => a == b, #[cfg(feature = "ndarray")] (Self::NdArray(a), Self::NdArray(b)) => a == b, #[cfg(feature = "tch")] (Self::LibTorch(a), Self::LibTorch(b)) => a == b, #[allow(unreachable_patterns)] (_, _) => false, } } } /// Base multiplier to avoid type_id clashes between backends. /// Limits the number of device types per backend, but this is a sensible limit. const TYPE_ID_BASE: u16 = 10; impl DispatchDevice { #[cfg(feature = "autodiff")] /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled. pub fn autodiff(device: impl Into) -> DispatchDevice { Self::autodiff_checkpointed(device, CheckpointingStrategy::None) } #[cfg(feature = "autodiff")] /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled. pub fn autodiff_checkpointed( device: impl Into, checkpointing: CheckpointingStrategy, ) -> DispatchDevice { let device = device.into(); DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing)) } /// Returns a unique number per variant to encode into type_id. fn backend_id(&self) -> BackendId { match self { #[cfg(feature = "cpu")] Self::Cpu(_) => BackendId::Cpu, #[cfg(feature = "cuda")] Self::Cuda(_) => BackendId::Cuda, #[cfg(wgpu_metal)] Self::Metal(_) => BackendId::Metal, #[cfg(feature = "rocm")] Self::Rocm(_) => BackendId::Rocm, #[cfg(wgpu_vulkan)] Self::Vulkan(_) => BackendId::Vulkan, #[cfg(wgpu_webgpu)] Self::WebGpu(_) => BackendId::WebGpu, #[cfg(feature = "ndarray")] Self::NdArray(_) => BackendId::NdArray, #[cfg(feature = "tch")] Self::LibTorch(_) => BackendId::LibTorch, #[cfg(feature = "autodiff")] Self::Autodiff(device) => device.inner.backend_id(), } } /// Encode variant ID and backend type ID into a unique `type_id`. fn encode_type_id(&self, backend_type_id: u16) -> u16 { u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id } /// Decode an encoded `type_id` into variant ID and backend type ID. fn decode_type_id(type_id: u16) -> (BackendId, u16) { let variant = type_id / TYPE_ID_BASE; let backend_type_id = type_id % TYPE_ID_BASE; ( BackendId::try_from(variant).expect("Unknown DispatchDevice variant"), backend_type_id, ) } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u16)] enum BackendId { #[cfg(feature = "cpu")] Cpu = 0, #[cfg(feature = "cuda")] Cuda = 1, #[cfg(wgpu_metal)] Metal = 2, #[cfg(feature = "rocm")] Rocm = 3, #[cfg(wgpu_vulkan)] Vulkan = 4, #[cfg(wgpu_webgpu)] WebGpu = 5, #[cfg(feature = "ndarray")] NdArray = 6, #[cfg(feature = "tch")] LibTorch = 7, } impl From for u16 { fn from(variant: BackendId) -> Self { variant as u16 } } impl TryFrom for BackendId { type Error = (); fn try_from(value: u16) -> Result { match value { #[cfg(feature = "cpu")] 0 => Ok(Self::Cpu), #[cfg(feature = "cuda")] 1 => Ok(Self::Cuda), #[cfg(wgpu_metal)] 2 => Ok(Self::Metal), #[cfg(feature = "rocm")] 3 => Ok(Self::Rocm), #[cfg(wgpu_vulkan)] 4 => Ok(Self::Vulkan), #[cfg(wgpu_webgpu)] 5 => Ok(Self::WebGpu), #[cfg(feature = "ndarray")] 6 => Ok(Self::NdArray), #[cfg(feature = "tch")] 7 => Ok(Self::LibTorch), _ => Err(()), } } } impl DeviceOps for DispatchDevice { fn inner(&self) -> &Self { match self { #[cfg(feature = "autodiff")] DispatchDevice::Autodiff(device) => &device.inner, device => device, } } } impl burn_std::device::Device for DispatchDevice { fn from_id(mut device_id: DeviceId) -> Self { let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id); device_id.type_id = backend_type_id; match dispatch_id { #[cfg(feature = "cpu")] BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)), #[cfg(feature = "cuda")] BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)), #[cfg(wgpu_metal)] BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)), #[cfg(feature = "rocm")] BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)), #[cfg(wgpu_vulkan)] BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)), #[cfg(wgpu_webgpu)] BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)), #[cfg(feature = "ndarray")] BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)), #[cfg(feature = "tch")] BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)), } } fn to_id(&self) -> DeviceId { let mut device_id = match self { #[cfg(feature = "cpu")] Self::Cpu(device) => device.to_id(), #[cfg(feature = "cuda")] Self::Cuda(device) => device.to_id(), #[cfg(wgpu_metal)] Self::Metal(device) => device.to_id(), #[cfg(feature = "rocm")] Self::Rocm(device) => device.to_id(), #[cfg(wgpu_vulkan)] Self::Vulkan(device) => device.to_id(), #[cfg(wgpu_webgpu)] Self::WebGpu(device) => device.to_id(), #[cfg(feature = "ndarray")] Self::NdArray(device) => device.to_id(), #[cfg(feature = "tch")] Self::LibTorch(device) => device.to_id(), #[cfg(feature = "autodiff")] Self::Autodiff(device) => device.inner.to_id(), }; device_id.type_id = self.encode_type_id(device_id.type_id); device_id } fn device_count(type_id: u16) -> usize { let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id); match dispatch_id { #[cfg(feature = "cpu")] BackendId::Cpu => CpuDevice::device_count(backend_type_id), #[cfg(feature = "cuda")] BackendId::Cuda => CudaDevice::device_count(backend_type_id), #[cfg(wgpu_metal)] BackendId::Metal => WgpuDevice::device_count(backend_type_id), #[cfg(feature = "rocm")] BackendId::Rocm => RocmDevice::device_count(backend_type_id), #[cfg(wgpu_vulkan)] BackendId::Vulkan => WgpuDevice::device_count(backend_type_id), #[cfg(wgpu_webgpu)] BackendId::WebGpu => WgpuDevice::device_count(backend_type_id), #[cfg(feature = "ndarray")] BackendId::NdArray => NdArrayDevice::device_count(backend_type_id), #[cfg(feature = "tch")] BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id), } } } #[cfg(feature = "cpu")] impl From for DispatchDevice { fn from(device: CpuDevice) -> Self { DispatchDevice::Cpu(device) } } #[cfg(feature = "cuda")] impl From for DispatchDevice { fn from(device: CudaDevice) -> Self { DispatchDevice::Cuda(device) } } #[cfg(wgpu_metal)] impl From for DispatchDevice { fn from(device: WgpuDevice) -> Self { DispatchDevice::Metal(device) } } #[cfg(feature = "rocm")] impl From for DispatchDevice { fn from(device: RocmDevice) -> Self { DispatchDevice::Rocm(device) } } #[cfg(wgpu_vulkan)] impl From for DispatchDevice { fn from(device: WgpuDevice) -> Self { DispatchDevice::Vulkan(device) } } #[cfg(wgpu_webgpu)] impl From for DispatchDevice { fn from(device: WgpuDevice) -> Self { DispatchDevice::WebGpu(device) } } #[cfg(feature = "ndarray")] impl From for DispatchDevice { fn from(device: NdArrayDevice) -> Self { DispatchDevice::NdArray(device) } } #[cfg(feature = "tch")] impl From for DispatchDevice { fn from(device: LibTorchDevice) -> Self { DispatchDevice::LibTorch(device) } } #[cfg(feature = "tch")] impl From for DispatchDevice { fn from(device: LibTorchDevice) -> Self { DispatchDevice::LibTorch(device) } } ================================================ FILE: crates/burn-dispatch/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![recursion_limit = "138"] //! Burn multi-backend dispatch. //! //! # Available Backends //! //! The dispatch backend supports the following variants, each enabled via cargo features: //! //! | Backend | Feature | Description | //! |------------|------------|-------------| //! | `Cpu` | `cpu` | Rust CPU backend (MLIR + LLVM) | //! | `Cuda` | `cuda` | NVIDIA CUDA backend | //! | `Metal` | `metal` | Apple Metal backend via `wgpu` (MSL) | //! | `Rocm` | `rocm` | AMD ROCm backend | //! | `Vulkan` | `vulkan` | Vulkan backend via `wgpu` (SPIR-V) | //! | `WebGpu` | `webgpu` | WebGPU backend via `wgpu` (WGSL) | //! | `NdArray` | `ndarray` | Pure Rust CPU backend using `ndarray` | //! | `LibTorch` | `tch` | Libtorch backend via `tch` | //! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) | //! //! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive. //! All other backends can be combined freely. //! //! ## WGPU Backend Exclusivity //! //! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to //! the current automatic compile, which can only select one target at a time. //! //! Enable only **one** of these features in your `Cargo.toml`: //! - `metal` //! - `vulkan` //! - `webgpu` //! //! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU //! backends** to prevent unintended behavior. #[cfg(not(any( feature = "cpu", feature = "cuda", wgpu_metal, feature = "rocm", wgpu_vulkan, wgpu_webgpu, feature = "ndarray", feature = "tch", )))] compile_error!("At least one backend feature must be enabled."); #[macro_use] mod macros; mod backend; mod device; mod ops; mod tensor; pub use backend::*; pub use device::*; pub use tensor::*; extern crate alloc; /// Backends and devices used. pub(crate) mod backends { #[cfg(feature = "autodiff")] pub use burn_autodiff::Autodiff; #[cfg(feature = "cpu")] pub use burn_cpu::{Cpu, CpuDevice}; #[cfg(feature = "cuda")] pub use burn_cuda::{Cuda, CudaDevice}; #[cfg(feature = "rocm")] pub use burn_rocm::{Rocm, RocmDevice}; #[cfg(wgpu_metal)] pub use burn_wgpu::Metal; #[cfg(wgpu_vulkan)] pub use burn_wgpu::Vulkan; #[cfg(wgpu_webgpu)] pub use burn_wgpu::WebGpu; #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))] pub use burn_wgpu::WgpuDevice; #[cfg(feature = "ndarray")] pub use burn_ndarray::{NdArray, NdArrayDevice}; #[cfg(feature = "tch")] pub use burn_tch::{LibTorch, LibTorchDevice}; } ================================================ FILE: crates/burn-dispatch/src/macros.rs ================================================ /// Supplies a list of all supported backends and their corresponding feature flags /// to a callback macro. This centralizes the backend registry. macro_rules! backend_list { ($callback:ident, $($extra:tt)*) => { $callback! { $($extra)*; [Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, wgpu_metal], [Rocm, feature = "rocm"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"] } }; } /// Supplies a matrix of cross-backend combinations. Used for operations where the source and destination backends may differ. macro_rules! backend_matrix { ($callback:ident, $($extra:tt)*) => { $callback! { $($extra)*; [Cpu, feature = "cpu"] => [[Cuda, feature = "cuda"], [Metal, wgpu_metal], [Rocm, feature = "rocm"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [Cuda, feature = "cuda"] => [[Cpu, feature = "cpu"], [Metal, wgpu_metal], [Rocm, feature = "rocm"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [Metal, wgpu_metal] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [Rocm, feature = "rocm"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, wgpu_metal], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [Vulkan, wgpu_vulkan] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [WebGpu, wgpu_webgpu] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [NdArray, feature = "ndarray"], [LibTorch, feature = "tch"]]; [NdArray, feature = "ndarray"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, wgpu_metal], [Rocm, feature = "rocm"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [LibTorch, feature = "tch"]]; [LibTorch, feature = "tch"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, wgpu_metal], [Rocm, feature = "rocm"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = "ndarray"]] } }; } /// Helper to map the runtime strategy to the compile-time Autodiff generic. macro_rules! with_autodiff_backend { ($Backend:ident, $checkpointing:expr, |$B:ident| $body:expr) => { match $checkpointing { $crate::CheckpointingStrategy::Balanced => { type $B = Autodiff< $Backend, burn_autodiff::checkpoint::strategy::BalancedCheckpointing, >; $body } $crate::CheckpointingStrategy::None => { type $B = Autodiff<$Backend, burn_autodiff::checkpoint::strategy::NoCheckpointing>; $body } } }; } /// Match arm generator for `dispatch_device`. /// Maps each backend variant to a block where the specific backend type is bound to `B`. macro_rules! dispatch_device_arms { ( $device:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => { match $device { // Autodiff arm first #[cfg(feature = "autodiff")] $crate::DispatchDevice::Autodiff(inner) => { // Recursively dispatch on inner dispatch_device_arms!( @autodiff &**inner, |$inner| $body; $([$Backend, $cfg]),* ) }, $( #[cfg($cfg)] $crate::DispatchDevice::$Backend($inner) => { type B = $Backend; $body } )* } }; ( @autodiff $device:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => { match $device { $( #[cfg($cfg)] $crate::DispatchDevice::$Backend($inner) => { type B = Autodiff<$Backend>; $body } )* $crate::DispatchDevice::Autodiff(_) => panic!("Autodiff should not wrap an autodiff device.") } }; } /// Dispatches an operation body based on the provided device. macro_rules! dispatch_device { ($device:expr, |$inner:ident| $body:expr) => { backend_list!(dispatch_device_arms, $device, |$inner| $body) }; } /// Match arm generator for `to_device`. /// Handles the logic for same-backend transfers (fast path) and cross-backend /// transfers by generating a grid of all device combinations provided via `backend_matrix`. macro_rules! to_device_arms { ( $kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr; $( [$B1:ident, $src_cfg:meta] => [ $( [$B2:ident, $dst_cfg:meta] ),+ ] );* ) => { match ($tensor.kind, $device) { // --- Same backend to_device --- $( #[cfg($src_cfg)] ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B1(d)) => { $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$B1($crate::BackendTensor::$kind( $B1::::$to_device(tensor.$inner_fn(), d) )), #[cfg(feature = "autodiff")] checkpointing: $tensor.checkpointing, } } )* // --- Cross backend arms --- // This loop generates the grid of combinations $( $( #[cfg(all($src_cfg, $dst_cfg))] ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B2($device_ident)) => { type B1 = $B1; type B2 = $B2; let $inner = tensor.$inner_fn(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$B2( $crate::BackendTensor::$kind($body) ), #[cfg(feature = "autodiff")] checkpointing: $tensor.checkpointing, } } )+ )* #[cfg(feature = "autodiff")] (_, $crate::DispatchDevice::Autodiff(_)) | ($crate::DispatchTensorKind::Autodiff(..), _) => panic!("Operation not marked for autodiff.") } }; } /// Handles tensor movement between devices, supporting both same-backend transfers /// and cross-backend dispatches. macro_rules! to_device { ($kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr) => { backend_matrix!( to_device_arms, $kind, $inner_fn, $tensor, $device, $to_device, |$inner, $device_ident| $body ) }; } /// Match arm generator for `float_to_device`. /// /// Similar to `to_device_arms`, but float tensors are checked for autodiff support. macro_rules! float_to_device_arms { ( $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr; $( [$B1:ident, $src_cfg:meta] => [ $( [$B2:ident, $dst_cfg:meta] ),+ ] );* ) => { match ($tensor.kind, $device) { #[cfg(feature = "autodiff")] ($crate::DispatchTensorKind::Autodiff(kind), $crate::DispatchDevice::Autodiff(device)) => { let ckp = $tensor.checkpointing; float_to_device_arms!( @autodiff *kind, &**device, ckp, $to_device; $([$B1, $src_cfg]);* ) } // --- Same backend to_device --- $( #[cfg($src_cfg)] ($crate::DispatchTensorKind::$B1(kind), $crate::DispatchDevice::$B1(d)) => { $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$B1($crate::BackendTensor::Float( $B1::::$to_device(kind.float(), d) )), #[cfg(feature = "autodiff")] checkpointing: $tensor.checkpointing, } } )* // --- Cross backend arms --- // This loop generates the grid of combinations $( $( #[cfg(all($src_cfg, $dst_cfg))] ($crate::DispatchTensorKind::$B1(kind), $crate::DispatchDevice::$B2($device_ident)) => { type B1 = $B1; type B2 = $B2; let $inner = kind.float(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$B2($crate::BackendTensor::Float($body)), #[cfg(feature = "autodiff")] checkpointing: $tensor.checkpointing, } } )+ )* #[cfg(feature = "autodiff")] ($crate::DispatchTensorKind::Autodiff(..), _) | (_, $crate::DispatchDevice::Autodiff(_)) => panic!("Cannot move between autodiff and non-autodiff instances.") } }; // Autodiff(DispatchTensor) ( @autodiff $tensor:expr, $device:expr, $ckp:expr, $to_device:ident; $( [$B1:ident, $src_cfg:meta] );* ) => {{ match ($tensor, $device) { // --- Same backend to_device --- $( #[cfg($src_cfg)] ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B1(d)) => { let kind = $crate::DispatchTensorKind::Autodiff(Box::new($crate::DispatchTensorKind::$B1($crate::BackendTensor::Autodiff( with_autodiff_backend!($B1, $ckp, |B| { B::$to_device(tensor.autodiff(), d) }) )))); $crate::DispatchTensor {kind, checkpointing: $ckp} } )* (_, _) => unimplemented!("Autodiff tensor cannot be moved between backends.") } }}; } /// Handles float tensor movement between devices (that might support autodiff). macro_rules! float_to_device { ($kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr) => { backend_matrix!( float_to_device_arms, $tensor, $device, $to_device, |$inner, $device_ident| $body ) }; } /// Dispatches a tensor creation operation (e.g., zeros, ones) to the correct backend /// based on the provided device. macro_rules! creation_op { ($kind:ident, $device:expr, |$inner:ident| $body:expr) => { backend_list!(creation_op_arms, $kind, $device, |$inner| $body) }; } /// Match arm generator for `creation_float`. /// /// Similar to `creation_op_arms`, but float tensors are checked for autodiff support. macro_rules! creation_op_arms { ( $kind:ident, $device:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match $device { // Autodiff arm first #[cfg(feature = "autodiff")] $crate::DispatchDevice::Autodiff(inner) => { // Recursively dispatch on inner creation_op_arms!( @autodiff $kind, &**inner, inner.checkpointing, |$inner| $body; $([$Backend, $cfg]),* ) }, $( #[cfg($cfg)] $crate::DispatchDevice::$Backend($inner) => { type B = $Backend; $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend( $crate::BackendTensor::$kind($body) ), // TODO: hmmm should devices also carry the checkpointing all the time? #[cfg(feature = "autodiff")] checkpointing: $crate::CheckpointingStrategy::None, } } )* } }}; ( @autodiff $kind:ident, $device:expr, $ckp:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match $device { $( #[cfg($cfg)] $crate::DispatchDevice::$Backend($inner) => { with_autodiff_backend!($Backend, $ckp, |B| { wrap_float!(@wrap_autodiff $kind, $Backend, $ckp, { $body }) }) } )* $crate::DispatchDevice::Autodiff(_) => panic!("Autodiff should not wrap an autodiff device.") } }}; } /// Wrap the result in the backend tensor kind, handling float -> autodiff. #[cfg(feature = "autodiff")] macro_rules! wrap_float { ( @wrap_autodiff Float, $Backend:ident, $ckp:expr, $expr:expr ) => { $crate::DispatchTensor { kind: $crate::DispatchTensorKind::Autodiff(Box::new( $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Autodiff($expr)), )), checkpointing: $ckp, } }; ( @wrap_autodiff $other:ident, $Backend:ident, $ckp:expr, $expr:expr ) => { $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$other($expr)), checkpointing: $ckp, } }; } /// Match arm generator for `unary_op`. /// Unwraps the inner tensor primitive (e.g., `inner.float()`) and provides the backend type `B` /// for the operation. /// /// When the return kind is provided, the result is wrapped in the corresponding `DispatchTensor` variant. macro_rules! unary_op_arms { ( $kind:ident, $inner_kind:ident, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = $tensor.checkpointing; match $tensor.kind { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { type B = $Backend; let $inner = $inner.$inner_kind(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(..) => panic!("Operation not marked for autodiff.") } }}; // Operations that do not return a tensor kind ( $inner_kind:ident, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match $tensor.kind { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { type B = $Backend; let $inner = $inner.$inner_kind(); $body } )* #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(..) => panic!("Operation not marked for autodiff.") } }}; } /// Backend dispatch for unary operations. /// /// When the return `=> Kind` is not provided, the operation output is not wrapped in a dispatch tensor (e.g., `into_data(..)`) macro_rules! unary_op { ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => { backend_list!(unary_op_arms, $kind, $inner_kind, $tensor, |$inner| { $body }) }; ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => { backend_list!(unary_op_arms, $inner_kind, $tensor, |$inner| { $body }) }; } /// Match arm generator for `unary_float`. /// /// Similar to `unary_op_arms`, but float tensors are checked for autodiff support. macro_rules! unary_float_arms { ( $mode:ident, // `owned` or `ref` $kind:ident, $inner_kind:ident, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = $tensor.checkpointing; match $tensor.kind { #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(inner) => { unary_float_arms!( @autodiff $mode, checkpointing, $kind, { if_mode!($mode, &**inner, *inner) }, |$inner| $body; $([$Backend, $cfg]),* ) }, $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { type B = $Backend; let $inner = unary_float_arms!(@unwrap $mode, $inner, $inner_kind); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend( $crate::BackendTensor::$kind($body) ), #[cfg(feature = "autodiff")] checkpointing, } } )* } }}; // --- Autodiff recursive arm --- ( @autodiff $mode:ident, $ckp:expr, $kind:ident, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match $tensor { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { with_autodiff_backend!($Backend, $ckp, |B| { let $inner = unary_float_arms!(@unwrap_ad $mode, $inner); wrap_float!(@wrap_autodiff $kind, $Backend, $ckp, { $body }) }) } )* $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor.") } }}; // --- Non-wrapping arms (operations not returning a tensor) --- ( $mode:ident, $inner_kind:ident, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = &$tensor.checkpointing; match { if_mode!($mode, &$tensor.kind, $tensor.kind) } { #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(inner) => { unary_float_arms!( @autodiff $mode, checkpointing, { if_mode!($mode, &**inner, *inner) }, |$inner| $body; $([$Backend, $cfg]),* ) }, $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { type B = $Backend; let $inner = unary_float_arms!(@unwrap $mode, $inner, $inner_kind); $body } )* } }}; ( @autodiff $mode:ident, $ckp:expr, $tensor:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match $tensor { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend($inner) => { with_autodiff_backend!($Backend, $ckp, |B| { let $inner = unary_float_arms!(@unwrap_ad $mode, $inner); $body }) } )* $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor.") } }}; // --- Helpers to unwarp the tensor based on owned/ref --- (@unwrap owned, $inner:ident, $inner_kind:ident) => { $inner.$inner_kind() }; (@unwrap ref, $inner:ident, $inner_kind:ident) => { paste::paste! { $inner.[< as_ $inner_kind >]() } }; (@unwrap_ad owned, $inner:ident) => { $inner.autodiff() }; (@unwrap_ad ref, $inner:ident) => { $inner.as_autodiff() }; } #[cfg(feature = "autodiff")] /// Utility to pick a token based on mode macro_rules! if_mode { (ref, $if_ref:expr, $if_owned:expr) => { $if_ref }; (owned, $if_ref:expr, $if_owned:expr) => { $if_owned }; } /// Backend dispatch for float unary operations (that might support autodiff). /// /// When the return `=> Kind` is not provided, the operation output is not wrapped in a dispatch tensor (e.g., `into_data(..)`) macro_rules! unary_float { // Owned with return kind ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => { backend_list!( unary_float_arms, owned, $kind, $inner_kind, $tensor, |$inner| { $body } ) }; // Owned without return kind ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => { backend_list!(unary_float_arms, owned, $inner_kind, $tensor, |$inner| { $body }) }; // Reference without return kind (ref $tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => { backend_list!(unary_float_arms, ref, $inner_kind, $tensor, |$inner| { $body }) }; } /// Match arm generator for `binary_op`. /// Matches two tensors to ensure they share the same backend before unwrapping them for the operation. macro_rules! binary_op_arms { ( $kind:ident, ($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing); match ($lhs.kind, $rhs.kind) { $( #[cfg($cfg)] ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { type B = $Backend; let $lhs_inner = $lhs_inner.$lhs_kind(); let $rhs_inner = $rhs_inner.$rhs_kind(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* #[allow(unreachable_patterns)] (lhs, rhs) => { panic!( "The provided tensors are not on the same backend. Got backends {:?} and {:?}.", lhs, rhs ); } } }}; } /// Backend dispatch for binary operations. /// Automatically verifies that both tensors reside on the same backend. macro_rules! binary_op { (($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr => $kind:ident) => { backend_list!( binary_op_arms, $kind, ($lhs, $lhs_kind), ($rhs, $rhs_kind), |$lhs_inner, $rhs_inner| { $body } ) }; } /// Match arm generator for `binary_float`. /// Matches two tensors to ensure they share the same backend before unwrapping them for the operation. macro_rules! binary_float_arms { // (float, float) binary op ( $kind:ident, ($lhs:expr, float), ($rhs:expr, float), |$lhs_inner:ident, $rhs_inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing); match ($lhs.kind, $rhs.kind) { // Autodiff arms first #[cfg(feature = "autodiff")] ($crate::DispatchTensorKind::Autodiff(lhs_inner), $crate::DispatchTensorKind::Autodiff(rhs_inner)) => { // Recursively dispatch on inner binary_float_arms!( @autodiff $kind, (*lhs_inner, autodiff, checkpointing), (*rhs_inner, autodiff, checkpointing), |$lhs_inner, $rhs_inner| $body; $([$Backend, $cfg]),* ) }, $( #[cfg($cfg)] ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { type B = $Backend; let $lhs_inner = $lhs_inner.float(); let $rhs_inner = $rhs_inner.float(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* #[allow(unreachable_patterns)] (lhs, rhs) => { panic!( "The provided tensors are not on the same backend. Got backends {:?} and {:?}.", lhs, rhs ); } } }}; // (float, any) binary op ( $kind:ident, ($lhs:expr, float), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ #[cfg(feature = "autodiff")] let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing); match ($lhs.kind, $rhs.kind) { $( // Autodiff arms first #[cfg(all(feature = "autodiff", $cfg))] ($crate::DispatchTensorKind::Autodiff(lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { // Match on inner match *lhs_inner { $crate::DispatchTensorKind::$Backend($lhs_inner) => { with_autodiff_backend!($Backend, checkpointing, |B| { let $lhs_inner = $lhs_inner.autodiff(); let $rhs_inner = $rhs_inner.$rhs_kind(); wrap_float!( @wrap_autodiff $kind, $Backend, checkpointing, { $body } ) }) } $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor."), #[allow(unreachable_patterns)] _ => panic!("The provided tensors are not on the same backend.") } }, #[cfg($cfg)] ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { type B = $Backend; let $lhs_inner = $lhs_inner.float(); let $rhs_inner = $rhs_inner.$rhs_kind(); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* #[allow(unreachable_patterns)] (lhs, rhs) => { panic!( "The provided tensors are not on the same backend. Got backends {:?} and {:?}.", lhs, rhs ); } } }}; ( $kind:ident, ($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match ($lhs, $rhs) { $( #[cfg($cfg)] ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { type B = $Backend; let $lhs_inner = $lhs_inner.$lhs_kind(); let $rhs_inner = $rhs_inner.$rhs_kind(); $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)) } )* (lhs, rhs) => { panic!( "The provided tensors are not on the same backend. Got backends {:?} and {:?}.", lhs, rhs ); } } }}; // Autodiff (lhs, rhs) tensors ( @autodiff $kind:ident, ($lhs:expr, $lhs_kind:ident, $ckp_lhs:expr), ($rhs:expr, $rhs_kind:ident, $ckp_rhs:expr), |$lhs_inner:ident, $rhs_inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),* ) => {{ match ($lhs, $rhs) { $( #[cfg($cfg)] ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => { with_autodiff_backend!($Backend, $ckp_lhs, |B| { let $lhs_inner = $lhs_inner.$lhs_kind(); let $rhs_inner = $rhs_inner.$rhs_kind(); wrap_float!( @wrap_autodiff $kind, $Backend, $ckp_lhs, { $body } ) }) } )* #[cfg(feature = "autodiff")] ($crate::DispatchTensorKind::Autodiff(..), _) | (_, $crate::DispatchTensorKind::Autodiff(..)) => panic!("Autodiff should not wrap an autodiff tensor."), #[allow(unreachable_patterns)] (lhs, rhs) => { panic!( "The provided tensors are not on the same backend. Got backends {:?} and {:?}.", lhs, rhs ); } } }}; } /// Backend dispatch for binary operations. /// Automatically verifies that both tensors reside on the same backend. macro_rules! binary_float { (($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr => $kind:ident) => { backend_list!( binary_float_arms, $kind, ($lhs, $lhs_kind), ($rhs, $rhs_kind), |$lhs_inner, $rhs_inner| { $body } ) }; } /// The core logic for a single backend in a `multi_op`. /// Handles the manual unwrapping of required/optional inputs and the /// re-wrapping of multiple required/optional output tensors. macro_rules! multi_op_arm { ( $Backend:ident, $ckp:ident, [ $( ($x:ident, $x_kind:ident) ),+ ], [ $( ($opt_in:ident, $opt_kind:ident) ),* ], [ $( ($out:ident, $out_kind:ident) ),+ ], [ $( $opt_out:ident ),* ], $body:expr ) => {{ type B = $Backend; // Required inputs $( let $x = match $x.kind { $crate::DispatchTensorKind::$Backend(inner) => inner.$x_kind(), #[allow(unreachable_patterns)] _ => panic!("Input tensor {} is on the wrong device", stringify!($x)), }; )+ // Optional inputs $( let $opt_in = $opt_in.map(|o| match o.kind { $crate::DispatchTensorKind::$Backend(inner) => inner.$opt_kind(), #[allow(unreachable_patterns)] _ => panic!("Optional tensor {} is on the wrong device", stringify!($opt_in)), }); )* let ($($out),+, $($opt_out),*) = $body; // Outputs and optional outputs ( $( $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$out_kind($out)), #[cfg(feature = "autodiff")] checkpointing: $ckp, } ),+, $( $opt_out.map(|t| $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Float(t)), #[cfg(feature = "autodiff")] checkpointing: $ckp, } ) ),* ) }}; } #[cfg(feature = "autodiff")] macro_rules! wrap_input_autodiff { ($Backend:ident, $inner:expr, int) => { $inner.int() }; ($Backend:ident, $inner:expr, bool) => { $inner.bool() }; // Float tensors: wrap with autodiff ($Backend:ident, $inner:expr, float) => { $inner.autodiff() }; } #[cfg(feature = "autodiff")] // DispatchTensorKind::Autodiff(DispatchTensorKind::$Backend(BackendTensor::Autodiff())) macro_rules! multi_op_arm_autodiff { ( $Backend:ident, $ckp:ident, [ $( ($x:ident, $x_kind:ident) ),+ ], [ $( ($opt_in:ident, $opt_kind:ident) ),* ], [ $( ($out:ident, $out_kind:ident) ),+ ], [ $( $opt_out:ident ),* ], $body:expr ) => {{ // type B = Autodiff<$Backend>; with_autodiff_backend!($Backend, $ckp, |B| { // Required inputs $( let $x = match $x.kind { $crate::DispatchTensorKind::Autodiff(inner) => { match *inner { $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $x_kind), _ => panic!("Input tensor {} is on the wrong device", stringify!($x)), } }, // Unreachable, except when input is int $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $x_kind), #[allow(unreachable_patterns)] _ => panic!("Input tensor {} is on the wrong device", stringify!($x)), }; )+ // Optional inputs (always assumed to be float / autodiff) $( let $opt_in = $opt_in.map(|o| match o.kind { $crate::DispatchTensorKind::Autodiff(inner) => { match *inner { $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $opt_kind), _ => panic!("Input tensor {} is on the wrong device", stringify!($opt_in)), } }, _ => panic!("Optional tensor {} is on the wrong device", stringify!($opt_in)), }); )* let ($($out),+, $($opt_out),*) = $body; // Outputs and optional outputs ( $( wrap_float!(@wrap_autodiff $out_kind, $Backend, $ckp, $out) ),+, $( $opt_out.map(|t| wrap_float!(@wrap_autodiff Float, $Backend, $ckp, t)) ),* ) }) }}; } /// Helper to extract the first identifier from an input list. /// Used to determine the device/backend for dispatching multi-tensor operations. macro_rules! first_input { ([ ($x:ident, $kind:ident) $(, $rest:tt)* ]) => { $x }; } /// Match arm generator for `multi_op`. /// Determines the backend based on the first input and delegates to `multi_op_arm` /// to handle the repetition-heavy unwrapping and wrapping logic. macro_rules! multi_op_arms_autodiff { ( $inputs:tt, $opt_inputs:tt, $outputs:tt, $opt_outputs:tt, $body:expr; $( [$Backend:ident, $cfg:meta] ),* ) => {{ let first_input = &first_input!($inputs); #[cfg(feature = "autodiff")] let checkpointing = first_input.checkpointing; match &first_input.kind { // Autodiff first #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(inner) => { match **inner { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { multi_op_arm_autodiff!( $Backend, checkpointing, $inputs, $opt_inputs, $outputs, $opt_outputs, $body ) } )* $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor.") } }, $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { multi_op_arm!( $Backend, checkpointing, $inputs, $opt_inputs, $outputs, $opt_outputs, $body ) } )* } }}; } /// Match arm generator for `multi_op`. /// /// Similar to `multi_op_arms`, but skips autodiff checks. macro_rules! multi_op_arms { ( $inputs:tt, $opt_inputs:tt, $outputs:tt, $opt_outputs:tt, $body:expr; $( [$Backend:ident, $cfg:meta] ),* ) => {{ let first_input = &first_input!($inputs); let checkpointing = if cfg!(feature = "autodiff") { first_input.checkpointing } else { $crate::CheckpointingStrategy::None }; match first_input.kind { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { multi_op_arm!( $Backend, checkpointing, $inputs, $opt_inputs, $outputs, $opt_outputs, $body ) } )* #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(..) => panic!("Operation not marked for autodiff.") } }}; } /// High-level macro for complex module operations (e.g., conv2d) and multi-tensor operations. /// Handles variable numbers of required/optional inputs and wraps multiple outputs. /// /// Usage: /// ```ignore /// multi_op!( /// inputs[(x, float), (weight, float)], /// opt_inputs[(bias, float)], /// => Float, /// B::conv2d(x, weight, bias, options) /// ) /// ``` macro_rules! multi_op { // --- Single output shorthands --- // Automatically wraps body in tuple and extracts .0 ( inputs[$( ($x:ident, $kind:ident) ),+], => Float, $body:expr ) => { multi_op!( inputs[$( ($x, $kind) ),+], opt_inputs[], outputs[(out, Float)], opt_outputs[], { ($body,) } ) .0 }; ( inputs[$( ($x:ident, $kind:ident) ),+], opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ], => $out_kind:ident, $body:expr ) => { multi_op!( inputs[$( ($x, $kind) ),+], opt_inputs[ $(($opt_in, $opt_kind)),* ], outputs[(out, $out_kind)], opt_outputs[], { ($body,) } ) .0 }; // Int/Bool op specialization (not marked for autodiff) ( inputs[$( ($x:ident, $kind:ident) ),+], => $out_kind:ident, $body:expr ) => { backend_list!( multi_op_arms, [ $(($x, $kind)),+ ], [], [ (out, $out_kind) ], [], { ($body,) } ).0 }; // --- Required + optional for both inputs and outputs --- ( inputs[ $(($x:ident, $kind:ident)),+ ], opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ], outputs[ $( ($out:ident, $out_kind:ident) ),+ ], opt_outputs[ $($opt_out:ident),* ], $body:expr ) => { backend_list!( multi_op_arms_autodiff, [ $(($x, $kind)),+ ], [ $(($opt_in, $opt_kind)),* ], [ $(($out, $out_kind)),+ ], [ $($opt_out),* ], $body ) }; ( inputs[ $(($x:ident, $kind:ident)),+ ], opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ], outputs[ $($out:ident),+ ], $body:expr ) => { multi_op!( inputs[ $(($x, $kind)),+ ], opt_inputs[ $(($opt_in, $opt_kind)),* ], outputs[ $(($out, Float)),+ ], opt_outputs[], $body ) }; ( inputs[ $(($x:ident, $kind:ident)),+ ], outputs[ $( ($out:ident, $out_kind:ident) ),+ ], $body:expr ) => { multi_op!( inputs[ $(($x, $kind)),+ ], opt_inputs[], outputs[ $(($out, $out_kind)),+ ], opt_outputs[], $body ) }; } /// Unwraps a `Vec` for a known backend. macro_rules! unwrap_vec { ($Backend:ident, $vec:expr, $kind:ident) => { $vec.into_iter() .map(|t| match t.kind { $crate::DispatchTensorKind::$Backend(inner) => inner.$kind(), #[allow(unreachable_patterns)] _ => panic!( "Tensor is on the wrong backend (expected {}).", stringify!($Backend) ), }) .collect::>() }; // Autodiff-wrapped backend (@autodiff $Backend:ident, $vec:expr, $kind:ident) => { $vec.into_iter() .map(|t| match t.kind { $crate::DispatchTensorKind::Autodiff(inner) => match *inner { $crate::DispatchTensorKind::$Backend(inner) => inner.$kind(), _ => panic!( "Autodiff float tensor is on the wrong backend (expected {}).", stringify!($Backend) ), }, _ => panic!( "Expected autodiff-wrapped float tensor for backend {}.", stringify!($Backend) ), }) .collect::>() }; } /// Match arm generator for `vec_op`. macro_rules! vec_op_arms { (Float, $inner_kind:ident, $tensors:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),*) => {{ let first = &$tensors[0]; #[cfg(feature = "autodiff")] let checkpointing = first.checkpointing; match &first.kind { // Autodiff arm first #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(inner) => { // Recursively dispatch on inner match **inner { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { with_autodiff_backend!($Backend, checkpointing, |B| { let $inner = unwrap_vec!(@autodiff $Backend, $tensors, autodiff); wrap_float!( @wrap_autodiff Float, $Backend, checkpointing, { $body } ) }) } )* $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor.") } }, $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { type B = $Backend; let $inner = unwrap_vec!($Backend, $tensors, $inner_kind); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Float($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* } }}; ($kind:ident, $inner_kind:ident, $tensors:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),*) => {{ let first = &$tensors[0]; #[cfg(feature = "autodiff")] let checkpointing = first.checkpointing; match first.kind { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { type B = $Backend; let $inner = unwrap_vec!($Backend, $tensors, $inner_kind); $crate::DispatchTensor { kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)), #[cfg(feature = "autodiff")] checkpointing, } } )* #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(..) => panic!("Operation not marked for autodiff.") } }}; } /// Backend dispatch for operations on multiple inputs (vec). /// Automatically verifies that tensors reside on the first backend. macro_rules! vec_op { ($tensors:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => { backend_list!(vec_op_arms, $kind, $inner_kind, $tensors, |$inner| { $body }) }; } /// Match arm generator for `transaction_op`. macro_rules! transaction_op_arms { ($tx:ident, $first:expr; $([$Backend:ident, $cfg:meta]),*) => {{ match &$first.kind { // Autodiff arm first #[cfg(feature = "autodiff")] $crate::DispatchTensorKind::Autodiff(inner) => { // Recursively dispatch on inner match **inner { $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { type B = $Backend; // Unwrap vec let floats = unwrap_vec!(@autodiff $Backend, $tx.read_floats, autodiff_inner); let ints = unwrap_vec!($Backend, $tx.read_ints, int); let bools = unwrap_vec!($Backend, $tx.read_bools, bool); // Not supported let qfloats = $tx.read_qfloats.into_iter().map(|_t| todo!("Quantization not supported yet")).collect(); B::tr_execute(TransactionPrimitive::new(floats, qfloats, ints, bools)).await } )* $crate::DispatchTensorKind::Autodiff(..) => panic!("Autodiff should not wrap an autodiff tensor.") } }, $( #[cfg($cfg)] $crate::DispatchTensorKind::$Backend(_) => { type B = $Backend; // Unwrap vec let floats = unwrap_vec!($Backend, $tx.read_floats, float); let ints = unwrap_vec!($Backend, $tx.read_ints, int); let bools = unwrap_vec!($Backend, $tx.read_bools, bool); // Not supported let qfloats = $tx.read_qfloats.into_iter().map(|_t| todo!("Quantization not supported yet")).collect(); B::tr_execute(TransactionPrimitive::new(floats, qfloats, ints, bools)).await } )* } }}; } /// Helper to dispatch a transaction based on the first available tensor. macro_rules! transaction_op { ($tx:ident, $first:expr) => { backend_list!(transaction_op_arms, $tx, $first) }; } ================================================ FILE: crates/burn-dispatch/src/ops/activation.rs ================================================ use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor}; use crate::Dispatch; use crate::backends::*; impl ActivationOps for Dispatch { fn leaky_relu(tensor: FloatTensor, negative_slope: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float) } fn relu(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::relu(tensor) => Float) } fn relu_backward(output: FloatTensor, grad: FloatTensor) -> FloatTensor { binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float) } fn gelu(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float) } fn prelu(tensor: FloatTensor, alpha: FloatTensor) -> FloatTensor { binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float) } fn gelu_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor { binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float) } fn sigmoid(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float) } fn sigmoid_backward(output: FloatTensor, grad: FloatTensor) -> FloatTensor { binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float) } fn hard_sigmoid(tensor: FloatTensor, alpha: Scalar, beta: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float) } fn log_sigmoid(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float) } fn log_sigmoid_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor { binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float) } } ================================================ FILE: crates/burn-dispatch/src/ops/bool_tensor.rs ================================================ use burn_backend::{ ExecutionError, Scalar, TensorData, ops::BoolTensorOps, tensor::{BoolTensor, FloatTensor, IntTensor}, }; use burn_std::{Shape, Slice}; use crate::backends::*; use crate::{Dispatch, DispatchDevice}; impl BoolTensorOps for Dispatch { fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor { creation_op!(Bool, device, |device| B::bool_empty(shape, device)) } fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor { creation_op!(Bool, device, |device| B::bool_zeros(shape, device)) } fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor { creation_op!(Bool, device, |device| B::bool_ones(shape, device)) } async fn bool_into_data(tensor: BoolTensor) -> Result { unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await) } fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor { creation_op!(Bool, device, |device| B::bool_from_data(data, device)) } fn bool_into_int(tensor: BoolTensor) -> IntTensor { unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int) } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float) } fn bool_device(tensor: &BoolTensor) -> DispatchDevice { tensor.device() } fn bool_to_device(tensor: BoolTensor, device: &DispatchDevice) -> BoolTensor { to_device!( Bool, bool, tensor, device, bool_to_device, |inner, device| { let data = burn_backend::read_sync(B1::bool_into_data(inner)).expect("Should read data"); B2::bool_from_data(data, device) } ) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool) } fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool) } fn bool_slice_assign( tensor: BoolTensor, slices: &[Slice], value: BoolTensor, ) -> BoolTensor { binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool) } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { multi_op!( inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool, B::bool_mask_where(tensor, mask, value) ) } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool) } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { multi_op!( inputs[(tensor, bool), (indices, int), (value, bool)], => Bool, B::bool_scatter_or(dim, tensor, indices, value) ) } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool) } fn bool_not(tensor: BoolTensor) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool) } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool) } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool) } fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool) } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool) } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool) } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool) } fn bool_unfold( tensor: BoolTensor, dim: usize, size: usize, step: usize, ) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool) } fn bool_select( tensor: BoolTensor, dim: usize, indices: IntTensor, ) -> BoolTensor { binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool) } fn bool_select_or( tensor: BoolTensor, dim: usize, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { multi_op!( inputs[(tensor, bool), (indices, int), (value, bool)], => Bool, B::bool_select_or(tensor, dim, indices, value) ) } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool) } fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool) } fn bool_not_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool) } fn bool_not_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool) } fn bool_xor(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool) } fn bool_transpose(tensor: BoolTensor) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool) } fn bool_any(tensor: BoolTensor) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool) } fn bool_any_dim(tensor: BoolTensor, dim: usize) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool) } fn bool_all(tensor: BoolTensor) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool) } fn bool_all_dim(tensor: BoolTensor, dim: usize) -> BoolTensor { unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool) } async fn bool_argwhere(tensor: BoolTensor) -> IntTensor { unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int) } } ================================================ FILE: crates/burn-dispatch/src/ops/int_tensor.rs ================================================ use burn_backend::{ ExecutionError, Scalar, TensorData, ops::IntTensorOps, tensor::{BoolTensor, FloatTensor, IntTensor}, }; use burn_std::{IntDType, Shape, Slice}; use crate::backends::*; use crate::{Dispatch, DispatchDevice}; impl IntTensorOps for Dispatch { fn int_empty(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor { creation_op!(Int, device, |device| B::int_empty(shape, device, dtype)) } async fn int_into_data(tensor: IntTensor) -> Result { unary_op!(tensor, int, |tensor| B::int_into_data(tensor).await) } fn int_from_data(data: TensorData, device: &DispatchDevice) -> IntTensor { creation_op!(Int, device, |device| B::int_from_data(data, device)) } fn int_device(tensor: &IntTensor) -> DispatchDevice { tensor.device() } fn int_to_device(tensor: IntTensor, device: &DispatchDevice) -> IntTensor { to_device!(Int, int, tensor, device, int_to_device, |inner, device| { let data = burn_backend::read_sync(B1::int_into_data(inner)).expect("Should read data"); B2::int_from_data(data, device) }) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_reshape(tensor, shape) => Int) } fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_slice(tensor, slices) => Int) } fn int_slice_assign( tensor: IntTensor, slices: &[Slice], value: IntTensor, ) -> IntTensor { binary_op!((tensor, int), (value, int), |tensor, value| B::int_slice_assign(tensor, slices, value) => Int) } fn int_into_float(tensor: IntTensor) -> FloatTensor { unary_op!(tensor, int, |tensor| B::int_into_float(tensor) => Float) } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> IntTensor { multi_op!( inputs[(tensor, int), (mask, bool), (value, int)], => Int, B::int_mask_where(tensor, mask, value) ) } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> IntTensor { binary_op!((tensor, int), (mask, bool), |tensor, mask| B::int_mask_fill(tensor, mask, value) => Int) } fn int_gather( dim: usize, tensor: IntTensor, indices: IntTensor, ) -> IntTensor { binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_gather(dim, tensor, indices) => Int) } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { multi_op!( inputs[(tensor, int), (indices, int), (value, int)], => Int, B::int_scatter_add(dim, tensor, indices, value) ) } fn int_select( tensor: IntTensor, dim: usize, indices: IntTensor, ) -> IntTensor { binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_select(tensor, dim, indices) => Int) } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { multi_op!( inputs[(tensor, int), (indices, int), (value, int)], => Int, B::int_select_add(tensor, dim, indices, value) ) } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_equal(lhs, rhs) => Bool) } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_equal_elem(lhs, rhs) => Bool) } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater(lhs, rhs) => Bool) } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_greater_elem(lhs, rhs) => Bool) } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater_equal(lhs, rhs) => Bool) } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_greater_equal_elem(lhs, rhs) => Bool) } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower(lhs, rhs) => Bool) } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_lower_elem(lhs, rhs) => Bool) } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower_equal(lhs, rhs) => Bool) } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_lower_equal_elem(lhs, rhs) => Bool) } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_add(lhs, rhs) => Int) } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_add_scalar(lhs, rhs) => Int) } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_sub(lhs, rhs) => Int) } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_sub_scalar(lhs, rhs) => Int) } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_mul(lhs, rhs) => Int) } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_mul_scalar(lhs, rhs) => Int) } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_div(lhs, rhs) => Int) } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_div_scalar(lhs, rhs) => Int) } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_remainder(lhs, rhs) => Int) } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_remainder_scalar(lhs, rhs) => Int) } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_matmul(lhs, rhs) => Int) } fn int_sum(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_sum(tensor) => Int) } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_sum_dim(tensor, dim) => Int) } fn int_prod(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_prod(tensor) => Int) } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_prod_dim(tensor, dim) => Int) } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_mean_dim(tensor, dim) => Int) } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_cumsum(tensor, dim) => Int) } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_cumprod(tensor, dim) => Int) } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_cummin(tensor, dim) => Int) } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_cummax(tensor, dim) => Int) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_argmax(tensor, dim) => Int) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_argmin(tensor, dim) => Int) } fn int_abs(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_abs(tensor) => Int) } fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_swap_dims(tensor, dim1, dim2) => Int) } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_permute(tensor, axes) => Int) } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_flip(tensor, axes) => Int) } fn int_random( shape: Shape, distribution: burn_backend::Distribution, device: &DispatchDevice, ) -> IntTensor { creation_op!(Int, device, |device| { B::int_random(shape, distribution, device) }) } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_expand(tensor, shape) => Int) } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_and(lhs, rhs) => Int) } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::bitwise_and_scalar(lhs, rhs) => Int) } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_or(lhs, rhs) => Int) } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::bitwise_or_scalar(lhs, rhs) => Int) } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_xor(lhs, rhs) => Int) } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::bitwise_xor_scalar(lhs, rhs) => Int) } fn bitwise_not(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::bitwise_not(tensor) => Int) } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_left_shift(lhs, rhs) => Int) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::bitwise_left_shift_scalar(lhs, rhs) => Int) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_right_shift(lhs, rhs) => Int) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::bitwise_right_shift_scalar(lhs, rhs) => Int) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_cast(tensor, dtype) => Int) } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_unfold(tensor, dim, size, step) => Int) } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_repeat_dim(tensor, dim, times) => Int) } fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { vec_op!(tensors, int, |tensors| B::int_cat(tensors, dim) => Int) } fn int_not_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_not_equal(lhs, rhs) => Bool) } fn int_not_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { unary_op!(lhs, int, |lhs| B::int_not_equal_elem(lhs, rhs) => Bool) } fn int_powi(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powi(lhs, rhs) => Int) } fn int_powi_scalar_impl(lhs: IntTensor, rhs: Scalar) -> IntTensor { unary_op!(lhs, int, |lhs| B::int_powi_scalar_impl(lhs, rhs) => Int) } fn int_clamp_min(tensor: IntTensor, min: Scalar) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_clamp_min(tensor, min) => Int) } fn int_clamp_max(tensor: IntTensor, max: Scalar) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_clamp_max(tensor, max) => Int) } fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_clamp(tensor, min, max) => Int) } fn int_neg(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_neg(tensor) => Int) } fn int_zeros(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor { creation_op!(Int, device, |device| B::int_zeros(shape, device, dtype)) } fn int_ones(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor { creation_op!(Int, device, |device| B::int_ones(shape, device, dtype)) } fn int_full( shape: Shape, fill_value: Scalar, device: &DispatchDevice, dtype: IntDType, ) -> IntTensor { creation_op!(Int, device, |device| B::int_full( shape, fill_value, device, dtype )) } fn int_mean(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_mean(tensor) => Int) } fn int_max(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_max(tensor) => Int) } fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_max_dim(tensor, dim) => Int) } fn int_max_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { multi_op!( inputs[(tensor, int)], outputs[(out, Int), (indices, Int)], B::int_max_dim_with_indices(tensor, dim) ) } fn int_max_abs(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_max_abs(tensor) => Int) } fn int_max_abs_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_max_abs_dim(tensor, dim) => Int) } fn int_min(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_min(tensor) => Int) } fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_min_dim(tensor, dim) => Int) } fn int_min_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { multi_op!( inputs[(tensor, int)], outputs[(out, Int), (indices, Int)], B::int_min_dim_with_indices(tensor, dim) ) } fn int_transpose(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_transpose(tensor) => Int) } fn int_arange_step( range: std::ops::Range, step: usize, device: &DispatchDevice, ) -> IntTensor { creation_op!(Int, device, |device| B::int_arange_step( range, step, device )) } fn int_arange(range: std::ops::Range, device: &DispatchDevice) -> IntTensor { creation_op!(Int, device, |device| B::int_arange(range, device)) } fn int_any(tensor: IntTensor) -> BoolTensor { unary_op!(tensor, int, |tensor| B::int_any(tensor) => Bool) } fn int_any_dim(tensor: IntTensor, dim: usize) -> BoolTensor { unary_op!(tensor, int, |tensor| B::int_any_dim(tensor, dim) => Bool) } fn int_all(tensor: IntTensor) -> BoolTensor { unary_op!(tensor, int, |tensor| B::int_all(tensor) => Bool) } fn int_all_dim(tensor: IntTensor, dim: usize) -> BoolTensor { unary_op!(tensor, int, |tensor| B::int_all_dim(tensor, dim) => Bool) } fn int_sign(tensor: IntTensor) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_sign(tensor) => Int) } fn int_sort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_sort(tensor, dim, descending) => Int) } fn int_sort_with_indices( tensor: IntTensor, dim: usize, descending: bool, ) -> (IntTensor, IntTensor) { multi_op!( inputs[(tensor, int)], outputs[(out, Int), (indices, Int)], B::int_sort_with_indices(tensor, dim, descending) ) } fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { unary_op!(tensor, int, |tensor| B::int_argsort(tensor, dim, descending) => Int) } } ================================================ FILE: crates/burn-dispatch/src/ops/mod.rs ================================================ mod activation; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; ================================================ FILE: crates/burn-dispatch/src/ops/module.rs ================================================ use burn_backend::{ ops::{ DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }, tensor::{FloatTensor, IntTensor}, }; use crate::Dispatch; use crate::backends::*; impl ModuleOps for Dispatch { fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv2d(x, weight, bias, options) ) } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: burn_backend::ops::DeformConvOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (offset, float), (weight, float)], opt_inputs[(mask, float), (bias, float)], => Float, B::deform_conv2d(x, offset, weight, mask, bias, options) ) } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: burn_backend::ops::DeformConvOptions<2>, ) -> DeformConv2dBackward { let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!( inputs[(x, float), (offset, float), (weight, float), (output_grad, float)], opt_inputs[(mask, float), (bias, float)], outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)], opt_outputs[mask_grad, bias_grad], { let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options); (res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad) } ); DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad) } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvOptions<3>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv3d(x, weight, bias, options) ) } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvTransposeOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv_transpose2d(x, weight, bias, options) ) } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvTransposeOptions<3>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv_transpose3d(x, weight, bias, options) ) } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { multi_op!(inputs[(x, float)], => Float, B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode) ) } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { multi_op!( inputs[(x, float), (grad, float)], => Float, B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode) ) } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { multi_op!( inputs[(x, float)], => Float, B::adaptive_avg_pool2d(x, output_size) ) } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (grad, float)], => Float, B::adaptive_avg_pool2d_backward(x, grad) ) } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { multi_op!( inputs[(x, float)], => Float, B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode) ) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { let (out, indices) = multi_op!( inputs[(x, float)], outputs[(out, Float), (indices, Int)], { let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode); (res.output, res.indices) } ); MaxPool2dWithIndices::new(out, indices) } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward { let x_grad = multi_op!( inputs[(x, float), (output_grad, float), (indices, int)], => Float, { let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices); res.x_grad } ); MaxPool2dBackward::new(x_grad) } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: burn_backend::ops::InterpolateOptions, ) -> FloatTensor { multi_op!( inputs[(x, float)], => Float, B::interpolate(x, output_size, options) ) } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: burn_backend::ops::InterpolateOptions, ) -> FloatTensor { multi_op!( inputs[(x, float), (grad, float)], => Float, B::interpolate_backward(x, grad, output_size, options) ) } fn embedding(weights: FloatTensor, indices: IntTensor) -> FloatTensor { multi_op!( inputs[(weights, float), (indices, int)], => Float, B::embedding(weights, indices) ) } fn embedding_backward( weights: FloatTensor, output_grad: FloatTensor, indices: IntTensor, ) -> FloatTensor { multi_op!( inputs[(weights, float), (output_grad, float), (indices, int)], => Float, B::embedding_backward(weights, output_grad, indices) ) } fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvOptions<1>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv1d(x, weight, bias, options) ) } fn conv1d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<1>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv1d_x_backward(x, weight, output_grad, options) ) } fn conv1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<1>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv1d_weight_backward(x, weight, output_grad, options) ) } fn conv1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv1d_bias_backward(x, bias, output_grad) ) } fn conv2d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv2d_x_backward(x, weight, output_grad, options) ) } fn conv2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv2d_weight_backward(x, weight, output_grad, options) ) } fn conv2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv2d_bias_backward(x, bias, output_grad) ) } fn conv3d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<3>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv3d_x_backward(x, weight, output_grad, options) ) } fn conv3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvOptions<3>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv3d_weight_backward(x, weight, output_grad, options) ) } fn conv3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv3d_bias_backward(x, bias, output_grad) ) } fn conv_transpose1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: burn_backend::ops::ConvTransposeOptions<1>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float)], opt_inputs[(bias, float)], => Float, B::conv_transpose1d(x, weight, bias, options) ) } fn conv_transpose1d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<1>, ) -> FloatTensor { multi_op!( inputs[(weight, float), (output_grad, float)], => Float, B::conv_transpose1d_x_backward(weight, output_grad, options) ) } fn conv_transpose1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<1>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv_transpose1d_weight_backward(x, weight, output_grad, options) ) } fn conv_transpose1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv_transpose1d_bias_backward(x, bias, output_grad) ) } fn conv_transpose2d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<2>, ) -> FloatTensor { multi_op!( inputs[(weight, float), (output_grad, float)], => Float, B::conv_transpose2d_x_backward(weight, output_grad, options) ) } fn conv_transpose2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<2>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv_transpose2d_weight_backward(x, weight, output_grad, options) ) } fn conv_transpose2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv_transpose2d_bias_backward(x, bias, output_grad) ) } fn conv_transpose3d_x_backward( weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<3>, ) -> FloatTensor { multi_op!( inputs[(weight, float), (output_grad, float)], => Float, B::conv_transpose3d_x_backward(weight, output_grad, options) ) } fn conv_transpose3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: burn_backend::ops::ConvTransposeOptions<3>, ) -> FloatTensor { multi_op!( inputs[(x, float), (weight, float), (output_grad, float)], => Float, B::conv_transpose3d_weight_backward(x, weight, output_grad, options) ) } fn conv_transpose3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (bias, float), (output_grad, float)], => Float, B::conv_transpose3d_bias_backward(x, bias, output_grad) ) } fn unfold4d( x: FloatTensor, kernel_size: [usize; 2], options: burn_backend::ops::UnfoldOptions, ) -> FloatTensor { multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options)) } fn avg_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { multi_op!(inputs[(x, float)], => Float, B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode) ) } fn avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { multi_op!( inputs[(x, float), (grad, float)], => Float, B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode) ) } fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size)) } fn adaptive_avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(x, float), (grad, float)], => Float, B::adaptive_avg_pool1d_backward(x, grad) ) } fn max_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> FloatTensor { multi_op!(inputs[(x, float)], => Float, B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode)) } fn max_pool1d_with_indices( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { let (out, indices) = multi_op!( inputs[(x, float)], outputs[(out, Float), (indices, Int)], { let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode); (res.output, res.indices) } ); MaxPool1dWithIndices::new(out, indices) } fn max_pool1d_with_indices_backward( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool1dBackward { let x_grad = multi_op!( inputs[(x, float), (output_grad, float), (indices, int)], => Float, { let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices); res.x_grad } ); MaxPool1dBackward::new(x_grad) } fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: burn_backend::ops::AttentionModuleOptions, ) -> FloatTensor { multi_op!( inputs[(query, float), (key, float), (value, float)], opt_inputs[(mask, bool), (attn_bias, float)], => Float, B::attention(query, key, value, mask, attn_bias, options) ) } } ================================================ FILE: crates/burn-dispatch/src/ops/qtensor.rs ================================================ use burn_backend::{ ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive, ops::QTensorOps, quantization::QuantizationParametersPrimitive, tensor::{FloatTensor, IntTensor, QuantizedTensor}, }; use burn_std::{QuantPropagation, Shape, Slice}; use crate::backends::*; use crate::{Dispatch, DispatchDevice}; impl QTensorOps for Dispatch { fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor { creation_op!(Quantized, device, |device| B::q_from_data(data, device)) } fn quantize( tensor: FloatTensor, scheme: &burn_std::QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { binary_op!( (tensor, float), (qparams.scales, float), |tensor, scales| { B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales }) } => Quantized ) } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float) } fn q_device(tensor: &QuantizedTensor) -> DispatchDevice { tensor.device() } fn q_to_device( tensor: QuantizedTensor, device: &DispatchDevice, ) -> QuantizedTensor { to_device!( Quantized, quantized, tensor, device, q_to_device, |inner, device| { let data = burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data"); B2::q_from_data(data, device) } ) } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized) } async fn q_into_data(tensor: QuantizedTensor) -> Result { unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await) } fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized) } fn q_swap_dims( tensor: QuantizedTensor, dim1: usize, dim2: usize, ) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized) } fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized) } fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized) } fn q_select( tensor: QuantizedTensor, dim: usize, indices: IntTensor, ) -> QuantizedTensor { binary_op!( (tensor, quantized), (indices, int), |tensor, indices| B::q_select(tensor, dim, indices) => Quantized ) } fn q_slice(tensor: QuantizedTensor, slices: &[Slice]) -> QuantizedTensor { unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized) } fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { // TODO: this would be much cleaner if we consolidated tensor primitive types match (lhs, rhs) { (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { if matches!(lhs.propagation(), QuantPropagation::Propagate) { let out = binary_op!( (lhs, quantized), (rhs, quantized), |lhs, rhs| { if let TensorPrimitive::QFloat(out) = B::q_matmul( TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs), ) { out } else { unreachable!() } } => Quantized ); TensorPrimitive::QFloat(out) } else { let out = binary_op!( (lhs, quantized), (rhs, quantized), |lhs, rhs| { if let TensorPrimitive::Float(out) = B::q_matmul( TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs), ) { out } else { unreachable!() } } => Float ); TensorPrimitive::Float(out) } } (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => { if matches!(rhs.propagation(), QuantPropagation::Propagate) { let out = binary_op!( (lhs, float), (rhs, quantized), |lhs, rhs| { if let TensorPrimitive::QFloat(out) = B::q_matmul( TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs), ) { out } else { unreachable!() } } => Quantized ); TensorPrimitive::QFloat(out) } else { let out = binary_op!( (lhs, float), (rhs, quantized), |lhs, rhs| { if let TensorPrimitive::Float(out) = B::q_matmul( TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs), ) { out } else { unreachable!() } } => Float ); TensorPrimitive::Float(out) } } (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => { if matches!(lhs.propagation(), QuantPropagation::Propagate) { let out = binary_op!( (lhs, quantized), (rhs, float), |lhs, rhs| { if let TensorPrimitive::QFloat(out) = B::q_matmul( TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs), ) { out } else { unreachable!() } } => Quantized ); TensorPrimitive::QFloat(out) } else { let out = binary_op!( (lhs, quantized), (rhs, float), |lhs, rhs| { if let TensorPrimitive::Float(out) = B::q_matmul( TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs), ) { out } else { unreachable!() } } => Float ); TensorPrimitive::Float(out) } } _ => unreachable!(), } } } ================================================ FILE: crates/burn-dispatch/src/ops/tensor.rs ================================================ use burn_backend::{ ExecutionError, Scalar, TensorData, ops::FloatTensorOps, tensor::{BoolTensor, FloatTensor, IntTensor}, }; use burn_std::{FloatDType, Shape, Slice}; use crate::backends::*; use crate::{Dispatch, DispatchDevice}; // TODO: remove backend default elem type genericsnow that we have per-device defaults // https://github.com/tracel-ai/burn/issues/3642 impl FloatTensorOps for Dispatch { fn float_from_data( data: burn_backend::TensorData, device: &DispatchDevice, ) -> FloatTensor { creation_op!(Float, device, |device| B::float_from_data(data, device)) } fn float_random( shape: Shape, distribution: burn_backend::Distribution, device: &DispatchDevice, ) -> FloatTensor { creation_op!(Float, device, |device| { B::float_random(shape, distribution, device) }) } async fn float_into_data(tensor: FloatTensor) -> Result { unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await) } fn float_device(tensor: &FloatTensor) -> DispatchDevice { tensor.device() } fn float_to_device(tensor: FloatTensor, device: &DispatchDevice) -> FloatTensor { float_to_device!( Float, float, tensor, device, float_to_device, |inner, device| { let data = burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data"); B2::float_from_data(data, device) } ) } fn float_into_int(tensor: FloatTensor) -> IntTensor { unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int) } fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor { creation_op!(Float, device, |device| B::float_empty(shape, device, dtype)) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float) } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float) } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float) } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float) } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float) } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float) } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float) } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float) } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float) } fn float_recip(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float) } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float) } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float) } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float) } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float) } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float) } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(tensor, float), (indices, int), (value, float)], => Float, B::float_scatter_add(dim, tensor, indices, value) ) } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float) } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(tensor, float), (indices, int), (value, float)], => Float, B::float_select_add(tensor, dim, indices, value) ) } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float) } fn float_slice_assign( tensor: FloatTensor, slices: &[Slice], value: FloatTensor, ) -> FloatTensor { binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float) } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { multi_op!( inputs[(tensor, float), (mask, bool), (value, float)], => Float, B::float_mask_where(tensor, mask, value) ) } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float) } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool) } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool) } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool) } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool) } fn float_sum(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float) } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float) } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float) } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float) } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float) } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float) } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float) } fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float) } fn float_exp(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float) } fn float_log(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float) } fn float_powf_scalar_impl(tensor: FloatTensor, value: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float) } fn float_abs(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float) } fn float_cos(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float) } fn float_sin(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float) } fn float_tan(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float) } fn float_cosh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float) } fn float_sinh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float) } fn float_acos(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float) } fn float_acosh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float) } fn float_asin(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float) } fn float_asinh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float) } fn float_atan(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float) } fn float_atanh(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float) } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float) } fn float_round(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float) } fn float_floor(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float) } fn float_trunc(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float) } fn float_erf(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int) } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int) } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float) } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { unary_float!(tensor, float, |tensor| { B::float_unfold(tensor, dim, size, step) } => Float) } fn float_detach(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float) } fn float_set_require_grad(tensor: FloatTensor, require_grad: bool) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float) } fn float_is_require_grad(tensor: &FloatTensor) -> bool { unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor)) } // Default implementation fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor { creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype)) } fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor { creation_op!(Float, device, |device| B::float_ones(shape, device, dtype)) } fn float_full( shape: Shape, fill_value: Scalar, device: &DispatchDevice, dtype: FloatDType, ) -> FloatTensor { creation_op!(Float, device, |device| B::float_full( shape, fill_value, device, dtype )) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float) } fn float_clamp_min(tensor: FloatTensor, min: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float) } fn float_clamp_max(tensor: FloatTensor, max: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float) } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float) } fn float_neg(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float) } fn float_transpose(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float) } fn float_not_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool) } fn float_not_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool) } fn float_prod(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float) } fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float) } fn float_mean(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float) } fn float_powi(lhs: FloatTensor, rhs: IntTensor) -> FloatTensor { binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float) } fn float_powi_scalar_impl(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float) } fn float_powf_scalar(tensor: FloatTensor, value: Scalar) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float) } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float) } fn float_max(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float) } fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float) } fn float_max_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { multi_op!( inputs[(tensor, float)], outputs[(out, Float), (indices, Int)], B::float_max_dim_with_indices(tensor, dim) ) } fn float_min(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float) } fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float) } fn float_min_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { multi_op!( inputs[(tensor, float)], outputs[(out, Float), (indices, Int)], B::float_min_dim_with_indices(tensor, dim) ) } fn float_max_abs(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float) } fn float_max_abs_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float) } fn float_any(tensor: FloatTensor) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool) } fn float_any_dim(tensor: FloatTensor, dim: usize) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool) } fn float_all(tensor: FloatTensor) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool) } fn float_all_dim(tensor: FloatTensor, dim: usize) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool) } fn float_sign(tensor: FloatTensor) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float) } fn float_sort(tensor: FloatTensor, dim: usize, descending: bool) -> FloatTensor { unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float) } fn float_sort_with_indices( tensor: FloatTensor, dim: usize, descending: bool, ) -> (FloatTensor, IntTensor) { multi_op!( inputs[(tensor, float)], outputs[(out, Float), (indices, Int)], B::float_sort_with_indices(tensor, dim, descending) ) } fn float_argsort(tensor: FloatTensor, dim: usize, descending: bool) -> IntTensor { unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int) } fn float_grid_sample_2d( tensor: FloatTensor, grid: FloatTensor, options: burn_backend::ops::GridSampleOptions, ) -> FloatTensor { binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float) } fn float_is_nan(tensor: FloatTensor) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool) } fn float_is_inf(tensor: FloatTensor) -> BoolTensor { unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool) } } ================================================ FILE: crates/burn-dispatch/src/ops/transaction.rs ================================================ use burn_backend::{ ExecutionError, ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData}, }; use crate::Dispatch; use crate::backends::*; impl TransactionOps for Dispatch { async fn tr_execute( transaction: TransactionPrimitive, ) -> Result { let first_tensor = transaction .read_floats .first() .or(transaction.read_ints.first()) .or(transaction.read_bools.first()); match first_tensor { Some(tensor) => { transaction_op!(transaction, tensor) } None => Ok(TransactionPrimitiveData::default()), } } } ================================================ FILE: crates/burn-dispatch/src/tensor.rs ================================================ use burn_backend::{Backend, QTensorPrimitive, TensorMetadata}; #[cfg(feature = "autodiff")] use crate::CheckpointingStrategy; use crate::backends::*; #[cfg(feature = "autodiff")] use burn_backend::tensor::FloatTensor; // TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single // `B::TensorPrimitive` we can simplify this. /// Tensor which points to a backend tensor primitive kind. #[derive(Clone, Debug)] pub enum BackendTensor { /// Float tensor handle. Float(B::FloatTensorPrimitive), /// Int tensor handle. Int(B::IntTensorPrimitive), /// Bool tensor handle. Bool(B::BoolTensorPrimitive), /// Quantized tensor handle. Quantized(B::QuantizedTensorPrimitive), #[cfg(feature = "autodiff")] /// Autodiff float tensor handle. Autodiff(FloatTensor>), } impl BackendTensor { /// Returns the inner float tensor primitive. pub(crate) fn float(self) -> B::FloatTensorPrimitive { match self { BackendTensor::Float(tensor) => tensor, BackendTensor::Int(_) => panic!("Should be float, got int"), BackendTensor::Bool(_) => panic!("Should be float, got bool"), BackendTensor::Quantized(_) => panic!("Should be float, got quantized"), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"), } } /// Returns the inner float tensor primitive. pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive { match self { BackendTensor::Float(tensor) => tensor, BackendTensor::Int(_) => panic!("Should be float, got int"), BackendTensor::Bool(_) => panic!("Should be float, got bool"), BackendTensor::Quantized(_) => panic!("Should be float, got quantized"), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(_) => panic!("Should be float, got autodiff"), } } /// Returns the inner int tensor primitive. pub(crate) fn int(self) -> B::IntTensorPrimitive { match self { BackendTensor::Int(tensor) => tensor, BackendTensor::Float(_) => panic!("Should be int, got float"), BackendTensor::Bool(_) => panic!("Should be int, got bool"), BackendTensor::Quantized(_) => panic!("Should be int, got quantized"), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(_) => panic!("Should be int, got autodiff"), } } /// Returns the inner bool tensor primitive. pub(crate) fn bool(self) -> B::BoolTensorPrimitive { match self { BackendTensor::Bool(tensor) => tensor, BackendTensor::Float(_) => panic!("Should be bool, got float"), BackendTensor::Int(_) => panic!("Should be bool, got int"), BackendTensor::Quantized(_) => panic!("Should be bool, got quantized"), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(_) => panic!("Should be bool, got autodiff"), } } /// Returns the inner quantized tensor primitive. pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive { match self { BackendTensor::Quantized(tensor) => tensor, _ => unreachable!(), } } #[cfg(feature = "autodiff")] /// Returns the inner autodiff tensor primitive. pub(crate) fn autodiff(self) -> FloatTensor> { match self { BackendTensor::Autodiff(tensor) => tensor, // NOTE: this is the panicking code reached in tensor.rs:74:18: _ => unreachable!(), } } #[cfg(feature = "autodiff")] /// Returns the inner autodiff tensor primitive. pub(crate) fn as_autodiff(&self) -> &FloatTensor> { match self { BackendTensor::Autodiff(tensor) => tensor, _ => unreachable!(), } } #[cfg(feature = "autodiff")] /// Returns the inner autodiff tensor primitive. pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive { match self { BackendTensor::Autodiff(tensor) => tensor.primitive, _ => unreachable!(), } } /// Returns the backend device. pub(crate) fn device(&self) -> B::Device { match self { BackendTensor::Float(tensor) => B::float_device(tensor), BackendTensor::Int(tensor) => B::int_device(tensor), BackendTensor::Bool(tensor) => B::bool_device(tensor), BackendTensor::Quantized(tensor) => B::q_device(tensor), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive), } } } impl TensorMetadata for BackendTensor { fn dtype(&self) -> burn_std::DType { match self { BackendTensor::Float(tensor) => tensor.dtype(), BackendTensor::Int(tensor) => tensor.dtype(), BackendTensor::Bool(tensor) => tensor.dtype(), BackendTensor::Quantized(tensor) => tensor.dtype(), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(tensor) => tensor.dtype(), } } fn shape(&self) -> burn_std::Shape { match self { BackendTensor::Float(tensor) => tensor.shape(), BackendTensor::Int(tensor) => tensor.shape(), BackendTensor::Bool(tensor) => tensor.shape(), BackendTensor::Quantized(tensor) => tensor.shape(), #[cfg(feature = "autodiff")] BackendTensor::Autodiff(tensor) => tensor.shape(), } } } impl QTensorPrimitive for BackendTensor { fn scheme(&self) -> &burn_std::QuantScheme { match self { BackendTensor::Quantized(tensor) => tensor.scheme(), _ => panic!( "Quantization scheme is not valid for dtype {:?}", self.dtype(), ), } } } /// A tensor that can dispatch operations to any enabled backend at runtime. /// /// When the `autodiff` feature is enabled, tensors may carry a checkpointing /// strategy used to control gradient computation. This is derived from the /// device used to create the tensor. #[derive(Clone, Debug)] pub struct DispatchTensor { /// Tensor kind primitive. pub(crate) kind: DispatchTensorKind, // Technically more of a device property, but device is not a dispatch tensor field. #[cfg(feature = "autodiff")] pub(crate) checkpointing: CheckpointingStrategy, } /// Internal representation of a [`DispatchTensor`]. /// /// This enum contains the concrete backend tensor for each enabled backend. /// It is not intended to be used directly; instead, it is manipulated by /// the dispatch system to route operations to the correct backend. /// /// Each variant corresponds to a specific backend implementation. #[derive(Clone, Debug)] pub enum DispatchTensorKind { /// The [CPU backend](Cpu) tensor. #[cfg(feature = "cpu")] Cpu(BackendTensor), /// The [CUDA backend](Cuda) tensor. #[cfg(feature = "cuda")] Cuda(BackendTensor), /// The [Metal backend](Metal) tensor. #[cfg(wgpu_metal)] Metal(BackendTensor), /// The [ROCm backend](Rocm) tensor. #[cfg(feature = "rocm")] Rocm(BackendTensor), /// The [Vulkan backend](Vulkan) tensor. #[cfg(wgpu_vulkan)] Vulkan(BackendTensor), /// The [WebGPU backend](WebGpu) tensor. #[cfg(wgpu_webgpu)] WebGpu(BackendTensor), /// The [NdArray backend](NdArray) tensor. #[cfg(feature = "ndarray")] NdArray(BackendTensor), /// The [LibTorch backend](LibTorch) tensor. #[cfg(feature = "tch")] LibTorch(BackendTensor), /// The [autodiff enabled backend](Autodiff) tensor. #[cfg(feature = "autodiff")] Autodiff(Box), } impl TensorMetadata for DispatchTensorKind { fn dtype(&self) -> burn_std::DType { match self { #[cfg(feature = "cpu")] Self::Cpu(tensor) => tensor.dtype(), #[cfg(feature = "cuda")] Self::Cuda(tensor) => tensor.dtype(), #[cfg(wgpu_metal)] Self::Metal(tensor) => tensor.dtype(), #[cfg(feature = "rocm")] Self::Rocm(tensor) => tensor.dtype(), #[cfg(wgpu_vulkan)] Self::Vulkan(tensor) => tensor.dtype(), #[cfg(wgpu_webgpu)] Self::WebGpu(tensor) => tensor.dtype(), #[cfg(feature = "ndarray")] Self::NdArray(tensor) => tensor.dtype(), #[cfg(feature = "tch")] Self::LibTorch(tensor) => tensor.dtype(), #[cfg(feature = "autodiff")] Self::Autodiff(tensor) => tensor.dtype(), } } fn shape(&self) -> burn_std::Shape { match self { #[cfg(feature = "cpu")] Self::Cpu(tensor) => tensor.shape(), #[cfg(feature = "cuda")] Self::Cuda(tensor) => tensor.shape(), #[cfg(wgpu_metal)] Self::Metal(tensor) => tensor.shape(), #[cfg(feature = "rocm")] Self::Rocm(tensor) => tensor.shape(), #[cfg(wgpu_vulkan)] Self::Vulkan(tensor) => tensor.shape(), #[cfg(wgpu_webgpu)] Self::WebGpu(tensor) => tensor.shape(), #[cfg(feature = "ndarray")] Self::NdArray(tensor) => tensor.shape(), #[cfg(feature = "tch")] Self::LibTorch(tensor) => tensor.shape(), #[cfg(feature = "autodiff")] Self::Autodiff(tensor) => tensor.shape(), } } } impl QTensorPrimitive for DispatchTensorKind { fn scheme(&self) -> &burn_std::QuantScheme { match self { #[cfg(feature = "cpu")] Self::Cpu(tensor) => tensor.scheme(), #[cfg(feature = "cuda")] Self::Cuda(tensor) => tensor.scheme(), #[cfg(wgpu_metal)] Self::Metal(tensor) => tensor.scheme(), #[cfg(feature = "rocm")] Self::Rocm(tensor) => tensor.scheme(), #[cfg(wgpu_vulkan)] Self::Vulkan(tensor) => tensor.scheme(), #[cfg(wgpu_webgpu)] Self::WebGpu(tensor) => tensor.scheme(), #[cfg(feature = "ndarray")] Self::NdArray(tensor) => tensor.scheme(), #[cfg(feature = "tch")] Self::LibTorch(tensor) => tensor.scheme(), #[cfg(feature = "autodiff")] Self::Autodiff(tensor) => tensor.scheme(), } } } impl TensorMetadata for DispatchTensor { fn dtype(&self) -> burn_std::DType { self.kind.dtype() } fn shape(&self) -> burn_std::Shape { self.kind.shape() } } impl QTensorPrimitive for DispatchTensor { fn scheme(&self) -> &burn_std::QuantScheme { self.kind.scheme() } } ================================================ FILE: crates/burn-fusion/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Kernel fusion backend decorator for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-fusion" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-fusion" documentation = "https://docs.rs/burn-fusion" version.workspace = true [lints] workspace = true [features] default = ["std", "tracing"] std = ["serde/std", "tracing?/std"] doc = ["default"] memory-checks = ["std"] tracing = [ "dep:tracing", "burn-backend/tracing", "burn-ir/tracing", ] [dependencies] burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2" } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2" } tracing = { workspace = true, optional = true, features = ["attributes"] } hashbrown = { workspace = true } derive-new = { workspace = true } spin = { workspace = true } log = { workspace = true } serde = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-fusion/README.md ================================================ # Burn Fusion A kernel fusion backend decorator for Burn. ================================================ FILE: crates/burn-fusion/src/backend.rs ================================================ use crate::{ FusionTensor, client::GlobalFusionClient, stream::{Context, OrderedExecution}, }; use burn_backend::{ Backend, DType, DeviceOps, ExecutionError, tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor}, }; use burn_ir::{BackendIr, OperationIr, TensorHandle}; use serde::{Serialize, de::DeserializeOwned}; use std::marker::PhantomData; /// Get the client for the given device. pub fn get_client(device: &Device) -> Client { GlobalFusionClient::load(device) } /// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend). #[derive(Clone, Debug, Default)] pub struct Fusion { _backend: PhantomData, } impl Backend for Fusion { type Device = B::Device; type FloatTensorPrimitive = FusionTensor; type FloatElem = B::FloatElem; type IntTensorPrimitive = FusionTensor; type IntElem = B::IntElem; type BoolTensorPrimitive = FusionTensor; type BoolElem = B::BoolElem; type QuantizedTensorPrimitive = FusionTensor; fn name(device: &Self::Device) -> String { format!("fusion<{}>", B::name(device)) } fn seed(device: &B::Device, seed: u64) { let client = GlobalFusionClient::::load(device); client.drain(); B::seed(device, seed); } fn sync(device: &Self::Device) -> Result<(), ExecutionError> { let client = GlobalFusionClient::::load(device); client.drain(); B::sync(device) } fn ad_enabled(_device: &Self::Device) -> bool { false } fn memory_persistent_allocations< Output: Send, Input: Send, Func: Fn(Input) -> Output + Send, >( device: &Self::Device, input: Input, func: Func, ) -> Output { B::memory_persistent_allocations(device, input, func) } fn memory_cleanup(device: &Self::Device) { B::memory_cleanup(device) } fn staging<'a, Iter>(data: Iter, device: &Self::Device) where Iter: Iterator, { B::staging(data, device); } fn supports_dtype(device: &Self::Device, dtype: DType) -> bool { B::supports_dtype(device, dtype) } fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { B::dtype_usage(device, dtype) } } /// The status of a [fuser](OperationFuser). #[derive(Clone, Debug, Copy, PartialEq, Eq)] pub enum FuserStatus { /// No more operations can be fused. Closed, /// More operations can be fused. Open, } /// The properties of a [fuser](OperationFuser). #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct FuserProperties { /// The score of the optimization, higher is better. pub score: u64, /// If the operation is ready to be executed. pub ready: bool, } /// The fusion operation abstraction allows implementations to fuse many /// [tensor operations](OperationIr) into one, improving the performance of the backend. /// /// /// # Notes /// /// The implementations are free to execute the registered operations the way they want to improve /// the speed and efficiency of the computational graph. It doesn't mean that all registered /// operations should be fused, but that another way of executing them is more efficient. /// /// Also, it is important to return (FuserStatus::Closed) when no more registered operation can /// improve the performance. pub trait OperationFuser: Send { /// Register a new [tensor operation](OperationIr). fn fuse(&mut self, operation: &OperationIr); /// Finish the optimization and create a fusion operation. fn finish(&mut self) -> O; /// Reset the state. fn reset(&mut self); /// Return the builder [status](FuserStatus). fn status(&self) -> FuserStatus; /// Return the builder [properties](FuserProperties). fn properties(&self) -> FuserProperties; /// The number of operation fused. fn len(&self) -> usize; /// If no operations are fused. fn is_empty(&self) -> bool { self.len() == 0 } /// Clone the optimization builder. fn clone_dyn(&self) -> Box>; } /// The number of operations contained in the data structure. pub trait NumOperations: core::fmt::Debug { /// The number of registered operations. fn len(&self) -> usize; /// If the current optimization is empty. fn is_empty(&self) -> bool { self.len() == 0 } } /// The optimization created from a [fuser](OperationFuser). pub trait Optimization: Send + NumOperations { /// Execute the optimization. fn execute( &mut self, context: &mut Context<'_, R::FusionHandle>, execution: &OrderedExecution, ); /// Returns the state that can be serialized. fn to_state(&self) -> R::OptimizationState; /// Create the optimization from the state. fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self; } /// Type alias for `::FusionDevice`. pub type FusionDevice = ::FusionDevice; /// Type alias for `::FusionHandle`. pub type FusionHandle = ::FusionHandle; /// Client alias. pub type Client = GlobalFusionClient; /// Trait that defines a runtime that will benefits from fused operations. pub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static { /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. type Optimization: Optimization; /// Handle used to store tensor dynamically. type FusionHandle: Clone + Send; /// Device used by the runtime. type FusionDevice: DeviceOps; /// The list of fusers that will be used to optimize the computational graph. fn fusers(device: Self::FusionDevice) -> Vec>>; } /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation fuser](crate::OperationFuser). pub trait FusionBackend: BackendIr, Device = FusionDevice> { /// The runtime used for this backend. type FusionRuntime: FusionRuntime; /// Cast a float tensor and returns the resulting handle. fn cast_float(tensor: FloatTensor, dtype: DType) -> Self::Handle; /// Pointer to the full precision fusion backend. type FullPrecisionBackend: FusionBackend; } // Fusion implements `BackendIr` to enable router backend usage. impl BackendIr for Fusion { type Handle = FusionTensor; fn float_tensor(handle: TensorHandle) -> FloatTensor { handle.handle } fn int_tensor(handle: TensorHandle) -> IntTensor { handle.handle } fn bool_tensor(handle: TensorHandle) -> BoolTensor { handle.handle } fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { handle.handle } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { tensor } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { tensor } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { tensor } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { tensor } } // TODO: remove once backends no longer rely on generics for default elem types /// Returns the bool element dtype. pub(crate) fn bool_dtype() -> DType { match BT::dtype() { DType::U32 => DType::Bool(burn_backend::BoolStore::U32), DType::U8 => DType::Bool(burn_backend::BoolStore::U8), other => unimplemented!("Invalid bool dtye {other:?}"), } } ================================================ FILE: crates/burn-fusion/src/client.rs ================================================ use crate::{ FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor, stream::{OperationStreams, StreamId, execution::Operation}, }; use burn_backend::{Device, DeviceHandle, DeviceId, DeviceService}; use burn_backend::{TensorData, backend::ExecutionError}; use burn_ir::{OperationIr, TensorId, TensorIr}; use std::sync::{ Arc, atomic::{AtomicU64, Ordering}, }; /// Use a mutex to communicate with the fusion server. pub struct GlobalFusionClient { server: DeviceHandle>, device: FusionDevice, } impl DeviceService for FusionServer { fn init(device_id: DeviceId) -> Self { let device = FusionDevice::::from_id(device_id); FusionServer::new(device) } fn utilities(&self) -> burn_backend::ServerUtilitiesHandle { Arc::new(()) } } impl Clone for GlobalFusionClient where R: FusionRuntime, { fn clone(&self) -> Self { Self { server: self.server.clone(), device: self.device.clone(), } } } impl GlobalFusionClient where R: FusionRuntime + 'static, { /// Loads the client from the given device. pub fn load(device: &FusionDevice) -> Self { Self { device: device.clone(), server: DeviceHandle::new(device.to_id()), } } } static COUNTER: AtomicU64 = AtomicU64::new(0); impl GlobalFusionClient where R: FusionRuntime + 'static, { /// Create a new client for the given [device](FusionRuntime::FusionDevice). pub fn new(device: FusionDevice) -> Self { Self { device: device.clone(), server: DeviceHandle::new(device.to_id()), } } /// Register a new [tensor operation intermediate representation](OperationIr). /// /// Returns the new (uninitialized) output tensor(s) generated by the registered operation. pub fn register( &self, streams: OperationStreams, repr: OperationIr, operation: O, ) -> Vec> where O: Operation + 'static, { // Create output tensors returned by this operation let outputs = repr .outputs() .map(|output| { FusionTensor::new( output.id, output.shape.clone(), output.dtype, self.clone(), StreamId::current(), ) }) .collect(); self.server.submit(move |server| { server.register(streams, repr, Arc::new(operation)); }); outputs } /// Register all lazy computation. pub fn drain(&self) { let id = StreamId::current(); self.server.submit(move |server| server.drain_stream(id)); } /// Create a new (uninitialized) empty tensor handle and returns its corresponding [tensor id](TensorId). pub fn create_empty_handle(&self) -> TensorId { let value = COUNTER.fetch_add(1, Ordering::Relaxed); TensorId::new(value) } /// Get the current device used by all operations handled by this client. pub fn device(&self) -> &FusionDevice { &self.device } /// Create a tensor with the given handle and returns its corresponding [tensor id](TensorId). pub fn register_tensor_handle(&self, handle: FusionHandle) -> TensorId { let id = self.create_empty_handle(); self.server .submit(move |server| server.handles.register_handle(id, handle)); id } /// Read the values contained by a float tensor. pub fn read_tensor_float( self, tensor: TensorIr, stream: StreamId, ) -> impl Future> + Send where B: FusionBackend, { self.server .submit_blocking(move |server| server.read_float::(tensor, stream)) .unwrap() } /// Read the values contained by an int tensor. pub fn read_tensor_int( self, tensor: TensorIr, stream: StreamId, ) -> impl Future> + Send where B: FusionBackend, { self.server .submit_blocking(move |server| server.read_int::(tensor, stream)) .unwrap() } /// Read the values contained by a bool tensor. pub fn read_tensor_bool( self, tensor: TensorIr, stream: StreamId, ) -> impl Future> + Send where B: FusionBackend, { self.server .submit_blocking(move |server| server.read_bool::(tensor, stream)) .unwrap() } /// Read the values contained by a quantized tensor. pub fn read_tensor_quantized( self, tensor: TensorIr, stream: StreamId, ) -> impl Future> + Send where B: FusionBackend, { self.server .submit_blocking(move |server| server.read_quantized::(tensor, stream)) .unwrap() } /// Change the client of the given float tensor. pub fn change_client_float( &self, tensor: TensorIr, client: Self, stream: StreamId, ) -> FusionTensor where B: FusionBackend, { let dtype = tensor.dtype; let client_cloned = client.clone(); let shape = tensor.shape.clone(); let id = self.create_empty_handle(); self.server.submit(move |server| { server.drain_stream(stream); // TODO: We could improve performance here by not requirering blocking. client .server .clone() .submit_blocking_scoped(move |server_other| { server_other.change_server_float::( &tensor, id, stream, &client.device, server, ) }) }); FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current()) } /// Change the client of the given int tensor. pub fn change_client_int( &self, tensor: TensorIr, client: Self, stream: StreamId, ) -> FusionTensor where B: FusionBackend, { let dtype = tensor.dtype; let client_cloned = client.clone(); let shape = tensor.shape.clone(); let id = self.create_empty_handle(); self.server.submit(move |server| { server.drain_stream(stream); // TODO: We could improve performance here by not requirering blocking. client .server .clone() .submit_blocking_scoped(move |server_other| { server_other.change_server_int::(&tensor, id, stream, &client.device, server) }) }); FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current()) } /// Change the client of the given bool tensor. pub fn change_client_bool( &self, tensor: TensorIr, client: Self, stream: StreamId, ) -> FusionTensor where B: FusionBackend, { let dtype = tensor.dtype; let client_cloned = client.clone(); let shape = tensor.shape.clone(); let id = self.create_empty_handle(); self.server.submit(move |server| { server.drain_stream(stream); // TODO: We could improve performance here by not requirering blocking. client .server .clone() .submit_blocking_scoped(move |server_other| { server_other.change_server_bool::( &tensor, id, stream, &client.device, server, ) }) }); FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current()) } /// Change the client of the given quantized tensor. pub fn change_client_quantized( &self, tensor: TensorIr, client: Self, stream: StreamId, ) -> FusionTensor where B: FusionBackend, { let dtype = tensor.dtype; let client_cloned = client.clone(); let shape = tensor.shape.clone(); let id = self.create_empty_handle(); self.server.submit(move |server| { server.drain_stream(stream); // TODO: We could improve performance here by not requirering blocking. client .server .clone() .submit_blocking_scoped(move |server_other| { server_other.change_server_quantized::(&tensor, id, &client.device, server) }) }); FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current()) } /// Resolve the given float tensor to a primitive tensor. pub fn resolve_tensor_float(&self, tensor: FusionTensor) -> B::FloatTensorPrimitive where B: FusionBackend, { self.server .submit_blocking(move |server| { server.drain_stream(tensor.stream); server.resolve_server_float::(&tensor.into_ir()) }) .unwrap() } /// Resolve the given int tensor to a primitive tensor. pub fn resolve_tensor_int(&self, tensor: FusionTensor) -> B::IntTensorPrimitive where B: FusionBackend, { self.server .submit_blocking(move |server| { server.drain_stream(tensor.stream); server.resolve_server_int::(&tensor.into_ir()) }) .unwrap() } /// Resolve the given bool tensor to a primitive tensor. pub fn resolve_tensor_bool(&self, tensor: FusionTensor) -> B::BoolTensorPrimitive where B: FusionBackend, { self.server .submit_blocking(move |server| { server.drain_stream(tensor.stream); server.resolve_server_bool::(&tensor.into_ir()) }) .unwrap() } } ================================================ FILE: crates/burn-fusion/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! # Burn Fusion //! //! This library is a part of the Burn project. It is a standalone crate that //! can be used to perform automatic operation fusion on backends that support it. #[macro_use] extern crate derive_new; /// Client module exposing types to communicate with the fusion server. pub mod client; /// Stream module exposing all tensor operations that can be optimized. pub mod stream; /// Search module for stream optimizations. pub(crate) mod search; mod backend; mod ops; mod server; mod tensor; pub(crate) use server::*; pub use backend::*; pub use ops::NoOp; pub use tensor::*; ================================================ FILE: crates/burn-fusion/src/ops/activation.rs ================================================ use crate::{Fusion, FusionBackend}; use burn_backend::ops::ActivationOps; impl ActivationOps for Fusion {} ================================================ FILE: crates/burn-fusion/src/ops/base.rs ================================================ use crate::{FusionBackend, stream::Operation}; use burn_ir::HandleContainer; use std::marker::PhantomData; /// A no-operation placeholder for the fusion backend. /// /// `NoOp` is an implementation of [`Operation`] that doesn't execute anything. #[derive(new, Clone, Debug)] pub struct NoOp { _b: PhantomData, } impl Operation for NoOp { fn execute(&self, _handles: &mut HandleContainer) {} } ================================================ FILE: crates/burn-fusion/src/ops/binary.rs ================================================ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_ops { ( $name:ident, $ops:expr ) => { #[derive(Debug)] struct $name { desc: BinaryOpIr, _b: PhantomData, } impl $name { fn new(desc: BinaryOpIr) -> Self { Self { desc, _b: PhantomData, } } } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); handles.register_float_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_cmp_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: BinaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); handles.register_bool_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_cmp_ops { ( $name:ident, $ops:expr ) => { #[derive(Debug)] struct $name { desc: BinaryOpIr, _b: PhantomData, } impl $name { fn new(desc: BinaryOpIr) -> Self { Self { desc, _b: PhantomData, } } } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); handles.register_bool_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: BinaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor::(&self.desc.rhs); let output = $ops(lhs, rhs); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } ================================================ FILE: crates/burn-fusion/src/ops/bool_tensor.rs ================================================ use crate::{ Fusion, FusionBackend, bool_dtype, get_client, stream::{OperationStreams, execution::Operation}, }; use burn_backend::{ Element, ExecutionError, Scalar, Shape, Slice, TensorData, ops::BoolTensorOps, tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor}, }; use burn_ir::{ BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr, GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr, }; use std::marker::PhantomData; use super::NoOp; impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { #[derive(new, Debug)] struct EmptyOps { desc: TensorIr, device: Device, } impl Operation for EmptyOps { fn execute(&self, handles: &mut HandleContainer) { let output = B::bool_empty(self.desc.shape.clone(), &self.device); handles.register_bool_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, bool_dtype::(), || { client.create_empty_handle() }); client .register( OperationStreams::default(), OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())), EmptyOps::::new(desc.out, device.clone()), ) .output() } fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor { #[derive(new, Debug)] struct ZerosOps { desc: TensorIr, device: Device, } impl Operation for ZerosOps { fn execute(&self, handles: &mut HandleContainer) { let output = B::bool_zeros(self.desc.shape.clone(), &self.device); handles.register_bool_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, bool_dtype::(), || { client.create_empty_handle() }); client .register( OperationStreams::default(), OperationIr::BaseBool(BaseOperationIr::Zeros(desc.clone())), ZerosOps::::new(desc.out, device.clone()), ) .output() } fn bool_ones(shape: Shape, device: &Device) -> BoolTensor { #[derive(new, Debug)] struct OnesOps { desc: TensorIr, device: Device, } impl Operation for OnesOps { fn execute(&self, handles: &mut HandleContainer) { let output = B::bool_ones(self.desc.shape.clone(), &self.device); handles.register_bool_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, bool_dtype::(), || { client.create_empty_handle() }); client .register( OperationStreams::default(), OperationIr::BaseBool(BaseOperationIr::Ones(desc.clone())), OnesOps::::new(desc.out, device.clone()), ) .output() } async fn bool_into_data(tensor: BoolTensor) -> Result { tensor.bool_into_data::().await } fn bool_from_data(data: burn_backend::TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); let tensor = B::bool_from_data(data, device); let shape = burn_backend::TensorMetadata::shape(&tensor); let handle = B::bool_tensor_handle(tensor); let desc = InitOperationIr::create(shape, bool_dtype::(), || { client.register_tensor_handle(handle) }); client .register( OperationStreams::default(), OperationIr::Init(desc), NoOp::::new(), ) .output() } fn bool_into_int(tensor: BoolTensor) -> IntTensor { #[derive(new, Debug)] struct IntoIntOps { desc: CastOpIr, _b: PhantomData, } impl Operation for IntoIntOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_int(input); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())), IntoIntOps::::new(desc), ) .output() } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { #[derive(new, Debug)] struct IntoFloatOps { desc: CastOpIr, _b: PhantomData, } impl Operation for IntoFloatOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_float(input); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())), IntoFloatOps::::new(desc), ) .output() } fn bool_device(tensor: &BoolTensor) -> Device { tensor.client.device().clone() } fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { let device_original: &B::Device = tensor.client.device(); if device_original == device { return tensor; } let id = tensor.stream; let client_target = get_client::(device); let client_original = tensor.client.clone(); client_original .clone() .change_client_bool::(tensor.into_ir(), client_target, id) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { if tensor.shape == shape { return tensor; } #[derive(new, Debug)] struct ReshapeDimsOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ReshapeDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_reshape(input, self.desc.out.shape.clone()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())), ReshapeDimsOps::::new(desc), ) .output() } fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor { #[derive(new, Debug)] struct SliceOps { desc: SliceOpIr, _b: PhantomData, } impl Operation for SliceOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_slice(tensor, self.desc.ranges.as_slice()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())), SliceOps::::new(desc), ) .output() } fn bool_slice_assign( tensor: BoolTensor, slices: &[Slice], value: BoolTensor, ) -> BoolTensor { #[derive(new, Debug)] struct SliceAssignOps { desc: SliceAssignOpIr, _b: PhantomData, } impl Operation for SliceAssignOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let value = handles.get_bool_tensor::(&self.desc.value); let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &value]); let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())), SliceAssignOps::::new(desc), ) .output() } fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { #[derive(new, Debug)] struct CatOps { desc: CatOpIr, _b: PhantomData, } impl Operation for CatOps { fn execute(&self, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() .map(|tensor| handles.get_bool_tensor::(tensor)) .collect(); let output = B::bool_cat(tensors, self.desc.dim); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs(&tensors); let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())), CatOps::::new(desc), ) .output() } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { #[derive(new, Debug)] struct EqualOps { desc: BinaryOpIr, _b: PhantomData, } impl Operation for EqualOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let rhs = handles.get_bool_tensor::(&self.desc.rhs); let output = B::bool_equal(lhs, rhs); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())), EqualOps::::new(desc), ) .output() } fn bool_not(tensor: BoolTensor) -> BoolTensor { #[derive(new, Debug)] struct NotOps { desc: UnaryOpIr, _b: PhantomData, } impl Operation for NotOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_not(input); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Bool(BoolOperationIr::Not(desc.clone())), NotOps::::new(desc), ) .output() } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { #[derive(new, Debug)] struct AndOps { desc: BinaryOpIr, _b: PhantomData, } impl Operation for AndOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let rhs = handles.get_bool_tensor::(&self.desc.rhs); let output = B::bool_and(lhs, rhs); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Bool(BoolOperationIr::And(desc.clone())), AndOps::::new(desc), ) .output() } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { #[derive(new, Debug)] struct OrOps { desc: BinaryOpIr, _b: PhantomData, } impl Operation for OrOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let rhs = handles.get_bool_tensor::(&self.desc.rhs); let output = B::bool_or(lhs, rhs); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Bool(BoolOperationIr::Or(desc.clone())), OrOps::::new(desc), ) .output() } fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { #[derive(new, Debug)] struct SwapDimsOps { desc: SwapDimsOpIr, _b: PhantomData, } impl Operation for SwapDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())), SwapDimsOps::::new(desc), ) .output() } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { #[derive(new, Debug)] struct PermuteDimsOps { desc: PermuteOpIr, _b: PhantomData, } impl Operation for PermuteDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_permute(input, self.desc.axes.as_slice()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())), PermuteDimsOps::::new(desc), ) .output() } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { #[derive(new, Debug)] struct ExpandOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ExpandOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_expand(input, self.desc.out.shape.clone()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())), ExpandOps::::new(desc), ) .output() } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { #[derive(new, Debug)] struct FlipOps { desc: FlipOpIr, _b: PhantomData, } impl Operation for FlipOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_flip(input, self.desc.axes.as_slice()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())), FlipOps::::new(desc), ) .output() } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { #[derive(new, Debug)] struct RepeatDimOps { desc: RepeatDimOpIr, _b: PhantomData, } impl Operation for RepeatDimOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())), RepeatDimOps::::new(desc), ) .output() } fn bool_unfold( tensor: BoolTensor, dim: usize, size: usize, step: usize, ) -> BoolTensor { #[derive(new, Debug)] struct UnfoldOps { desc: UnfoldOpIr, _b: PhantomData, } impl Operation for UnfoldOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())), UnfoldOps::::new(desc), ) .output() } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { #[derive(new, Debug)] struct MaskWhereOps { desc: MaskWhereOpIr, _b: PhantomData, } impl Operation for MaskWhereOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let value = handles.get_bool_tensor::(&self.desc.value); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::bool_mask_where(tensor, mask, value); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask, &value]); let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc.clone())), MaskWhereOps::::new(desc), ) .output() } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { #[derive(new, Debug)] struct MaskFillOps { desc: MaskFillOpIr, _b: PhantomData, } impl Operation for MaskFillOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::bool_mask_fill(tensor, mask, self.desc.value.into()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask]); let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::MaskFill(desc.clone())), MaskFillOps::::new(desc), ) .output() } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { #[derive(new, Debug)] struct GatherOps { desc: GatherOpIr, _b: PhantomData, } impl Operation for GatherOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::bool_gather(self.desc.dim, tensor, indices); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices]); let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Gather(desc.clone())), GatherOps::::new(desc), ) .output() } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { #[derive(new, Debug)] struct ScatterOps { desc: ScatterOpIr, _b: PhantomData, } impl Operation for ScatterOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let value = handles.get_bool_tensor::(&self.desc.value); let output = B::bool_scatter_or(self.desc.dim, tensor, indices, value); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices, &value]); let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseBool(BaseOperationIr::Scatter(desc.clone())), ScatterOps::::new(desc), ) .output() } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { #[derive(new, Debug)] struct EqualElemOps { desc: ScalarOpIr, _b: PhantomData, } impl Operation for EqualElemOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let output = B::bool_equal_elem(lhs, self.desc.rhs.into()); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseBool(BaseOperationIr::EqualElem(desc.clone())), EqualElemOps::::new(desc), ) .output() } } ================================================ FILE: crates/burn-fusion/src/ops/int_tensor.rs ================================================ use super::NoOp; use crate::{ Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops, bool_dtype, get_client, reduce_int_ops, scalar_int_cmp_ops, scalar_int_ops, stream::{OperationStreams, execution::Operation}, unary_int_ops, }; use burn_backend::{ Distribution, Element, ExecutionError, IntDType, Scalar, Shape, Slice, TensorData, ops::IntTensorOps, tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntElem, IntTensor}, }; use burn_ir::*; use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { #[derive(new, Debug)] struct EmptyOps { desc: TensorIr, device: Device, } impl Operation for EmptyOps { fn execute(&self, handles: &mut HandleContainer) { let output = B::int_empty( self.desc.shape.clone(), &self.device, self.desc.dtype.into(), ); handles.register_int_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())), EmptyOps::::new(desc.out, device.clone()), ) .output() } async fn int_into_data(tensor: IntTensor) -> Result { tensor.int_into_data::().await } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); let dtype = data.dtype; let tensor = B::int_from_data(data, device); let shape = burn_backend::TensorMetadata::shape(&tensor); let handle = B::int_tensor_handle(tensor); let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle)); client .register( OperationStreams::default(), OperationIr::Init(desc), NoOp::::new(), ) .output() } fn int_device(tensor: &IntTensor) -> Device { tensor.client.device().clone() } fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { let device_original: &B::Device = tensor.client.device(); if device_original == device { return tensor; } let id = tensor.stream; let client_target = get_client::(device); let client_original = tensor.client.clone(); client_original .clone() .change_client_int::(tensor.into_ir(), client_target, id) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { if tensor.shape == shape { return tensor; } #[derive(new, Debug)] struct ReshapeDimsOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ReshapeDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_reshape(input, self.desc.out.shape.clone()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())), ReshapeDimsOps::::new(desc), ) .output() } fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor { #[derive(new, Debug)] struct SliceOps { desc: SliceOpIr, _b: PhantomData, } impl Operation for SliceOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_slice(tensor, self.desc.ranges.as_slice()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())), SliceOps::::new(desc), ) .output() } fn int_slice_assign( tensor: IntTensor, slices: &[burn_backend::Slice], value: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct SliceAssignOps { desc: SliceAssignOpIr, _b: PhantomData, } impl Operation for SliceAssignOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &value]); let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())), SliceAssignOps::::new(desc), ) .output() } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(MatmulOps, B::int_matmul); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())), MatmulOps::::new(desc.into()), ) .output() } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct MaskWhereOps { desc: MaskWhereOpIr, _b: PhantomData, } impl Operation for MaskWhereOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor::(&self.desc.value); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_where(tensor, mask, value); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask, &value]); let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc.clone())), MaskWhereOps::::new(desc), ) .output() } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> IntTensor { #[derive(new, Debug)] struct MaskFillOps { desc: MaskFillOpIr, _b: PhantomData, } impl Operation for MaskFillOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::int_mask_fill(tensor, mask, self.desc.value.into()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask]); let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::MaskFill(desc.clone())), MaskFillOps::::new(desc), ) .output() } fn int_gather( dim: usize, tensor: IntTensor, indices: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct GatherOps { desc: GatherOpIr, _b: PhantomData, } impl Operation for GatherOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_gather(self.desc.dim, tensor, indices); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices]); let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Gather(desc.clone())), GatherOps::::new(desc), ) .output() } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct ScatterOps { desc: ScatterOpIr, _b: PhantomData, } impl Operation for ScatterOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_scatter_add(self.desc.dim, tensor, indices, value); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices, &value]); let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Scatter(desc.clone())), ScatterOps::::new(desc), ) .output() } fn int_select( tensor: IntTensor, dim: usize, indices: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct SelectOps { desc: SelectOpIr, _b: PhantomData, } impl Operation for SelectOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::int_select(tensor, self.desc.dim, indices); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices]); let client = tensor.client.clone(); let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Select(desc.clone())), SelectOps::::new(desc), ) .output() } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { #[derive(new, Debug)] struct SelectAssignOps { desc: SelectAssignOpIr, _b: PhantomData, } impl Operation for SelectAssignOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let value = handles.get_int_tensor::(&self.desc.value); let output = B::int_select_add(tensor, self.desc.dim, indices, value); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices, &value]); let client = tensor.client.clone(); let desc = SelectAssignOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc.clone())), SelectAssignOps::::new(desc), ) .output() } fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { #[derive(new, Debug)] struct CatOps { desc: CatOpIr, _b: PhantomData, } impl Operation for CatOps { fn execute(&self, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() .map(|tensor| handles.get_int_tensor::(tensor)) .collect(); let output = B::int_cat(tensors, self.desc.dim); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs(&tensors); let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())), CatOps::::new(desc), ) .output() } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_int_cmp_ops!(EqualOps, B::int_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())), EqualOps::::new(desc), ) .output() } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::EqualElem(desc.clone())), EqualElemOps::::new(desc), ) .output() } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_int_cmp_ops!(GreaterOps, B::int_greater); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Greater(desc.clone())), GreaterOps::::new(desc), ) .output() } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterElem(desc.clone()), ), GreaterElemOps::::new(desc), ) .output() } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterEqual(desc.clone()), ), GreaterEqualOps::::new(desc), ) .output() } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterEqualElem(desc.clone()), ), GreaterEqualElemOps::::new(desc), ) .output() } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_int_cmp_ops!(LowerOps, B::int_lower); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())), LowerOps::::new(desc), ) .output() } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerElem(desc.clone()), ), LowerElemOps::::new(desc), ) .output() } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerEqual(desc.clone()), ), LowerEqualOps::::new(desc), ) .output() } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerEqualElem(desc.clone()), ), LowerEqualElemOps::::new(desc), ) .output() } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(AddOps, B::int_add); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Add(desc.clone())), AddOps::::new(desc), ) .output() } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(AddOps, B::int_add_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::AddScalar(desc.clone()), ), AddOps::::new(desc), ) .output() } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(SubOps, B::int_sub); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sub(desc.clone())), SubOps::::new(desc), ) .output() } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(SubOps, B::int_sub_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::SubScalar(desc.clone()), ), SubOps::::new(desc), ) .output() } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(MulOps, B::int_mul); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mul(desc.clone())), MulOps::::new(desc), ) .output() } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(MulOps, B::int_mul_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MulScalar(desc.clone()), ), MulOps::::new(desc), ) .output() } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(DivOps, B::int_div); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Div(desc.clone())), DivOps::::new(desc), ) .output() } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(DivOps, B::int_div_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::DivScalar(desc.clone()), ), DivOps::::new(desc), ) .output() } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(ModOps, B::int_remainder); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Rem(desc.clone())), ModOps::::new(desc), ) .output() } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(ModOps, B::int_remainder_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::RemScalar(desc.clone()), ), ModOps::::new(desc), ) .output() } fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { #[derive(new, Debug)] struct ZerosOps { desc: TensorIr, device: Device, } impl Operation for ZerosOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.desc.shape.clone(); let output = B::int_zeros(shape, &self.device, self.desc.dtype.into()); handles.register_int_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseInt(BaseOperationIr::Zeros(desc.clone())), ZerosOps::::new(desc.out, device.clone()), ) .output() } fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { #[derive(new, Debug)] struct OnesOps { desc: TensorIr, device: Device, } impl Operation for OnesOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.desc.shape.clone(); let output = B::int_ones(shape, &self.device, self.desc.dtype.into()); handles.register_int_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseInt(BaseOperationIr::Ones(desc.clone())), OnesOps::::new(desc.out, device.clone()), ) .output() } fn int_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: IntDType, ) -> IntTensor { #[derive(new, Debug)] struct FullOps { out: TensorIr, elem: ScalarIr, device: Device, } impl Operation for FullOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.out.shape.clone(); let output = B::int_full(shape, self.elem.into(), &self.device, self.out.dtype.into()); handles.register_int_tensor::(&self.out.id, output); } } let client = get_client::(device); let dtype = dtype.into(); let value = fill_value.into(); let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Full(desc.clone())), FullOps::::new(desc.out, desc.value, device.clone()), ) .output() } fn int_sum(tensor: IntTensor) -> IntTensor { unary_int_ops!(SumOps, B::int_sum, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sum(desc.clone())), SumOps::::new(desc.into()), ) .output() } fn int_sum_dim(tensor: IntTensor, axis: usize) -> IntTensor { reduce_int_ops!(SumDimOps, B::int_sum_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())), SumDimOps::::new(desc), ) .output() } fn int_prod(tensor: IntTensor) -> IntTensor { unary_int_ops!(ProdOps, B::int_prod, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Prod(desc.clone())), ProdOps::::new(desc.into()), ) .output() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(ProdDimOps, B::int_prod_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ProdDim(desc.clone())), ProdDimOps::::new(desc), ) .output() } fn int_mean(tensor: IntTensor) -> IntTensor { unary_int_ops!(MeanOps, B::int_mean, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mean(desc.clone())), MeanOps::::new(desc.into()), ) .output() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(MeanDimOps, B::int_mean_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MeanDim(desc.clone())), MeanDimOps::::new(desc), ) .output() } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { #[derive(new, Debug)] struct CumsumOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumsumOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_cumsum(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())), CumsumOps::::new(desc), ) .output() } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { #[derive(new, Debug)] struct CumprodOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumprodOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_cumprod(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumProd(desc.clone())), CumprodOps::::new(desc), ) .output() } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { #[derive(new, Debug)] struct CumminOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumminOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_cummin(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())), CumminOps::::new(desc), ) .output() } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { #[derive(new, Debug)] struct CummaxOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CummaxOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_cummax(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())), CummaxOps::::new(desc), ) .output() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(ArgMaxOps, B::int_argmax); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMax(desc.clone())), ArgMaxOps::::new(desc), ) .output() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(ArgMinOps, B::int_argmin); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMin(desc.clone())), ArgMinOps::::new(desc), ) .output() } fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { #[derive(new, Debug)] struct ClampOps { desc: ClampOpIr, _b: PhantomData, } impl Operation for ClampOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_clamp(input, self.desc.min.into(), self.desc.max.into()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let min = min.into(); let max = max.into(); let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Clamp(desc.clone())), ClampOps::::new(desc), ) .output() } fn int_abs(tensor: IntTensor) -> IntTensor { unary_int_ops!(AbsOps, B::int_abs); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Abs(desc.clone())), AbsOps::::new(desc), ) .output() } fn int_into_float(tensor: IntTensor) -> FloatTensor { #[derive(new, Debug)] struct IntoFloatOps { desc: CastOpIr, _b: PhantomData, } impl Operation for IntoFloatOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_into_float(input); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())), IntoFloatOps::::new(desc), ) .output() } fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { #[derive(new, Debug)] struct SwapDimsOps { desc: SwapDimsOpIr, _b: PhantomData, } impl Operation for SwapDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())), SwapDimsOps::::new(desc), ) .output() } fn int_max(tensor: IntTensor) -> IntTensor { unary_int_ops!(MaxOps, B::int_max, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Max(desc.clone())), MaxOps::::new(desc.into()), ) .output() } fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(MaxDimOps, B::int_max_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())), MaxDimOps::::new(desc), ) .output() } fn int_max_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new, Debug)] struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesOpIr, _b: PhantomData, } impl Operation for MaxDimWithIndicesOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim); handles.register_int_tensor::(&self.desc.out.id, output); handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let dtype = tensor.dtype; let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())), MaxDimWithIndicesOps::::new(desc), ) .outputs() .into() } fn int_min(tensor: IntTensor) -> IntTensor { unary_int_ops!(MinOps, B::int_min, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Min(desc.clone())), MinOps::::new(desc.into()), ) .output() } fn int_max_abs(tensor: IntTensor) -> IntTensor { unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())), MaxAbsOps::::new(desc.into()), ) .output() } fn int_max_abs_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(MaxAbsDimOps, B::int_max_abs_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MaxAbsDim(desc.clone()), ), MaxAbsDimOps::::new(desc), ) .output() } fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { reduce_int_ops!(MinDimOps, B::int_min_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())), MinDimOps::::new(desc), ) .output() } fn int_min_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { #[derive(new, Debug)] struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesOpIr, _b: PhantomData, } impl Operation for MinDimWithIndicesOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim); handles.register_int_tensor::(&self.desc.out.id, output); handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let dtype = tensor.dtype; let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())), MinDimWithIndicesOps::::new(desc), ) .outputs() .into() } fn int_random( shape: Shape, distribution: Distribution, device: &Device, ) -> IntTensor { #[derive(new, Debug)] struct IntRandomOps { desc: RandomOpIr, device: Device, } impl Operation for IntRandomOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.desc.out.shape.clone(); let output = B::int_random(shape, self.desc.distribution, &self.device); handles.register_int_tensor::(&self.desc.out.id, output); } } let dtype = IntElem::::dtype(); let client = get_client::(device); let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::NumericInt(dtype, NumericOperationIr::IntRandom(desc.clone())), IntRandomOps::::new(desc, device.clone()), ) .output() } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { #[derive(new, Debug)] struct PermuteDimsOps { desc: PermuteOpIr, _b: PhantomData, } impl Operation for PermuteDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_permute(input, self.desc.axes.as_slice()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())), PermuteDimsOps::::new(desc), ) .output() } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { #[derive(new, Debug)] struct ExpandOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ExpandOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_expand(input, self.desc.out.shape.clone()); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())), ExpandOps::::new(desc), ) .output() } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { #[derive(new, Debug)] struct FlipDimsOps { desc: FlipOpIr, _b: PhantomData, } impl Operation for FlipDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let axes = &self.desc.axes; let output = B::int_flip(input, axes); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())), FlipDimsOps::::new(desc), ) .output() } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { #[derive(new, Debug)] struct RepeatDimOps { desc: RepeatDimOpIr, _b: PhantomData, } impl Operation for RepeatDimOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())), RepeatDimOps::::new(desc), ) .output() } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(BitwiseAndOps, B::bitwise_and); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())), BitwiseAndOps::::new(desc), ) .output() } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())), BitwiseAndOps::::new(desc), ) .output() } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(BitwiseOrOps, B::bitwise_or); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())), BitwiseOrOps::::new(desc), ) .output() } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())), BitwiseOrOps::::new(desc), ) .output() } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(BitwiseXorOps, B::bitwise_xor); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())), BitwiseXorOps::::new(desc), ) .output() } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())), BitwiseXorOps::::new(desc), ) .output() } fn bitwise_not(tensor: IntTensor) -> IntTensor { unary_int_ops!(BitwiseNotOps, B::bitwise_not); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())), BitwiseNotOps::::new(desc), ) .output() } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())), BitwiseLeftShiftOps::::new(desc), ) .output() } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())), BitwiseLeftShiftOps::::new(desc), ) .output() } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())), BitwiseRightShiftOps::::new(desc), ) .output() } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())), BitwiseRightShiftOps::::new(desc), ) .output() } fn int_cast(tensor: IntTensor, dtype: burn_backend::IntDType) -> IntTensor { #[derive(new, Debug)] struct CastOps { desc: CastOpIr, dtype: burn_backend::IntDType, _b: PhantomData, } impl Operation for CastOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())), CastOps::::new(desc, dtype), ) .output() } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { #[derive(new, Debug)] struct UnfoldOps { desc: UnfoldOpIr, _b: PhantomData, } impl Operation for UnfoldOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())), UnfoldOps::::new(desc), ) .output() } } ================================================ FILE: crates/burn-fusion/src/ops/mod.rs ================================================ mod activation; mod binary; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; mod unary; mod base; pub use base::NoOp; ================================================ FILE: crates/burn-fusion/src/ops/module.rs ================================================ use crate::{ Fusion, FusionBackend, stream::{OperationStreams, execution::Operation}, }; use burn_backend::{ Element, ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }, tensor::{FloatTensor, IntTensor}, }; use burn_ir::*; use std::marker::PhantomData; macro_rules! make_ops { ($name:ident, $desc:ty, $fn:expr) => { #[derive(new, Debug)] struct $name { desc: $desc, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { #[allow(clippy::redundant_closure_call)] $fn(&self.desc, handles) } } }; } impl ModuleOps> for Fusion { fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr, handles: &mut HandleContainer< B::Handle, >| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv1d(x, weight, bias, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); }); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = Conv1dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())), Conv1dOps::::new(desc), ) .output() } fn conv1d_x_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<1>, ) -> FloatTensor> { make_ops!( Conv1dXBackwardOps, Conv1dXBackwardOpIr, |desc: &Conv1dXBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_x_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv1dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv1dXBackward(desc.clone())), Conv1dXBackwardOps::::new(desc), ) .output() } fn conv1d_weight_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<1>, ) -> FloatTensor> { make_ops!( Conv1dWeightBackwardOps, Conv1dWeightBackwardOpIr, |desc: &Conv1dWeightBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_weight_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv1dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv1dWeightBackward(desc.clone())), Conv1dWeightBackwardOps::::new(desc), ) .output() } fn conv1d_bias_backward( x: FloatTensor>, bias: FloatTensor>, output_grad: FloatTensor>, ) -> FloatTensor> { make_ops!( Conv1dBiasBackwardOps, Conv1dBiasBackwardOpIr, |desc: &Conv1dBiasBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]); let client = x.client.clone(); let desc = Conv1dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(desc.clone())), Conv1dBiasBackwardOps::::new(desc), ) .output() } fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr, handles: &mut HandleContainer< B::Handle, >| { let x = handles.get_float_tensor::(&args.x); let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv2d(x, weight, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); }); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = Conv2dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())), Conv2dOps::::new(desc), ) .output() } fn conv2d_x_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<2>, ) -> FloatTensor> { make_ops!( Conv2dXBackwardOps, Conv2dXBackwardOpIr, |desc: &Conv2dXBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_x_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv2dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv2dXBackward(desc.clone())), Conv2dXBackwardOps::::new(desc), ) .output() } fn conv2d_weight_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<2>, ) -> FloatTensor> { make_ops!( Conv2dWeightBackwardOps, Conv2dWeightBackwardOpIr, |desc: &Conv2dWeightBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_weight_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv2dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv2dWeightBackward(desc.clone())), Conv2dWeightBackwardOps::::new(desc), ) .output() } fn conv2d_bias_backward( x: FloatTensor>, bias: FloatTensor>, output_grad: FloatTensor>, ) -> FloatTensor> { make_ops!( Conv2dBiasBackwardOps, Conv2dBiasBackwardOpIr, |desc: &Conv2dBiasBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]); let client = x.client.clone(); let desc = Conv2dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(desc.clone())), Conv2dBiasBackwardOps::::new(desc), ) .output() } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { make_ops!( DeformConv2dOps, DeformConv2dOpIr, |args: &DeformConv2dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let offset = handles.get_float_tensor::(&args.offset); let weight = handles.get_float_tensor::(&args.weight); let mask = args .mask .as_ref() .map(|mask| handles.get_float_tensor::(mask)); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } if let Some(mask) = mask.as_ref() { streams.tensor(mask) } let client = x.client.clone(); let desc = DeformConv2dOpIr::create( x.into_ir(), offset.into_ir(), weight.into_ir(), mask.map(|mask| mask.into_ir()), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))), DeformConv2dOps::::new(desc), ) .output() } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { make_ops!( DeformConv2dBackwardOps, DeformConv2dBackwardOpIr, |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let offset = handles.get_float_tensor::(&args.offset); let weight = handles.get_float_tensor::(&args.weight); let mask = args .mask .as_ref() .map(|mask| handles.get_float_tensor::(mask)); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output_grad = handles.get_float_tensor::(&args.out_grad); let output = B::deform_conv2d_backward( x, offset, weight, mask, bias, output_grad, args.options.clone().into(), ); handles.register_float_tensor::(&args.input_grad.id, output.x_grad); handles.register_float_tensor::(&args.offset_grad.id, output.offset_grad); handles.register_float_tensor::(&args.weight_grad.id, output.weight_grad); if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) { handles.register_float_tensor::(&field.id, mask_grad); } if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) { handles.register_float_tensor::(&field.id, bias_grad); } } ); let has_bias = bias.is_some(); let has_mask = mask.is_some(); let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]); if let Some(bias) = bias.as_ref() { streams.tensor(bias); } if let Some(mask) = mask.as_ref() { streams.tensor(mask); } let client = x.client.clone(); let desc = DeformConv2dBackwardOpIr::create( x.into_ir(), offset.into_ir(), weight.into_ir(), mask.map(|mask| mask.into_ir()), bias.map(|bias| bias.into_ir()), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); let mut outputs = client .register( streams, OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new( desc.clone(), ))), DeformConv2dBackwardOps::::new(desc), ) .into_iter(); // When the number of outputs is variable, the order is important let input_grad = outputs.next().unwrap(); let offset_grad = outputs.next().unwrap(); let weight_grad = outputs.next().unwrap(); let mask_grad = has_mask.then(|| outputs.next().unwrap()); let bias_grad = has_bias.then(|| outputs.next().unwrap()); DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad) } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr, handles: &mut HandleContainer< B::Handle, >| { let x = handles.get_float_tensor::(&args.x); let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv3d(x, weight, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); }); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = Conv3dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())), Conv3dOps::::new(desc), ) .output() } fn conv3d_x_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<3>, ) -> FloatTensor> { make_ops!( Conv3dXBackwardOps, Conv3dXBackwardOpIr, |desc: &Conv3dXBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_x_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv3dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv3dXBackward(desc.clone())), Conv3dXBackwardOps::::new(desc), ) .output() } fn conv3d_weight_backward( x: FloatTensor>, weight: FloatTensor>, output_grad: FloatTensor>, options: ConvOptions<3>, ) -> FloatTensor> { make_ops!( Conv3dWeightBackwardOps, Conv3dWeightBackwardOpIr, |desc: &Conv3dWeightBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_weight_backward(x, weight, output_grad, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]); let client = x.client.clone(); let desc = Conv3dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv3dWeightBackward(desc.clone())), Conv3dWeightBackwardOps::::new(desc), ) .output() } fn conv3d_bias_backward( x: FloatTensor>, bias: FloatTensor>, output_grad: FloatTensor>, ) -> FloatTensor> { make_ops!( Conv3dBiasBackwardOps, Conv3dBiasBackwardOpIr, |desc: &Conv3dBiasBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]); let client = x.client.clone(); let desc = Conv3dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(desc.clone())), Conv3dBiasBackwardOps::::new(desc), ) .output() } fn conv_transpose1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { make_ops!( ConvTranspose1dOps, ConvTranspose1dOpIr, |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = ConvTranspose1dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())), ConvTranspose1dOps::::new(desc), ) .output() } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { make_ops!( ConvTranspose2dOps, ConvTranspose2dOpIr, |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = ConvTranspose2dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())), ConvTranspose2dOps::::new(desc), ) .output() } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { make_ops!( ConvTranspose3dOps, ConvTranspose3dOpIr, |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let weight = handles.get_float_tensor::(&args.weight); let bias = args .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let mut streams = OperationStreams::with_inputs([&x, &weight]); if let Some(bias) = bias.as_ref() { streams.tensor(bias) } let client = x.client.clone(); let desc = ConvTranspose3dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())), ConvTranspose3dOps::::new(desc), ) .output() } fn avg_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { make_ops!( AvgPool1dOps, AvgPool1dOpIr, |args: &AvgPool1dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool1d( x, args.kernel_size, args.stride, args.padding, args.count_include_pad, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = AvgPool1dOpIr::create( x.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())), AvgPool1dOps::::new(desc), ) .output() } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { make_ops!( AvgPool2dOps, AvgPool2dOpIr, |args: &AvgPool2dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::avg_pool2d( x, args.kernel_size, args.stride, args.padding, args.count_include_pad, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = AvgPool2dOpIr::create( x.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())), AvgPool2dOps::::new(desc), ) .output() } fn avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { make_ops!( AvgPool1dBackwardOps, AvgPool1dBackwardOpIr, |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool1d_backward( x, grad, args.kernel_size, args.stride, args.padding, args.count_include_pad, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &grad]); let client = x.client.clone(); let desc = AvgPool1dBackwardOpIr::create( x.into_ir(), grad.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())), AvgPool1dBackwardOps::::new(desc), ) .output() } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { make_ops!( AvgPool2dBackwardOps, AvgPool2dBackwardOpIr, |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let output = B::avg_pool2d_backward( x, grad, args.kernel_size, args.stride, args.padding, args.count_include_pad, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &grad]); let client = x.client.clone(); let desc = AvgPool2dBackwardOpIr::create( x.into_ir(), grad.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())), AvgPool2dBackwardOps::::new(desc), ) .output() } fn max_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> FloatTensor { make_ops!( MaxPool1dOps, MaxPool1dOpIr, |args: &MaxPool1dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = MaxPool1dOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())), MaxPool1dOps::::new(desc), ) .output() } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { make_ops!( MaxPool2dOps, MaxPool2dOpIr, |args: &MaxPool2dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = MaxPool2dOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())), MaxPool2dOps::::new(desc), ) .output() } fn max_pool1d_with_indices( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { make_ops!( MaxPool1dWithIndicesOps, MaxPool1dWithIndicesOpIr, |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::max_pool1d_with_indices( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output.output); handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = MaxPool1dWithIndicesOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, B::IntElem::dtype(), || client.create_empty_handle(), ); let [out, out_indices] = client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())), MaxPool1dWithIndicesOps::::new(desc), ) .outputs(); MaxPool1dWithIndices::new(out, out_indices) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { make_ops!( MaxPool2dWithIndicesOps, MaxPool2dWithIndicesOpIr, |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::max_pool2d_with_indices( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, ); handles.register_float_tensor::(&args.out.id, output.output); handles.register_int_tensor::(&args.out_indices.id, output.indices); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = MaxPool2dWithIndicesOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, B::IntElem::dtype(), || client.create_empty_handle(), ); let [out, out_indices] = client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())), MaxPool2dWithIndicesOps::::new(desc), ) .outputs(); MaxPool2dWithIndices::new(out, out_indices) } fn max_pool1d_with_indices_backward( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool1dBackward { make_ops!( MaxPool1dWithIndicesBackwardOps, MaxPool1dWithIndicesBackwardOpIr, |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let indices = handles.get_int_tensor::(&args.indices); let output = B::max_pool1d_with_indices_backward( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, grad, indices, ); handles.register_float_tensor::(&args.out.id, output.x_grad); } ); let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]); let client = x.client.clone(); let desc = MaxPool1dWithIndicesBackwardOpIr::create( x.into_ir(), output_grad.into_ir(), indices.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); let out = client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward( desc.clone(), )), MaxPool1dWithIndicesBackwardOps::::new(desc), ) .output(); MaxPool1dBackward::new(out) } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward { make_ops!( MaxPool2dWithIndicesBackwardOps, MaxPool2dWithIndicesBackwardOpIr, |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let indices = handles.get_int_tensor::(&args.indices); let output = B::max_pool2d_with_indices_backward( x, args.kernel_size, args.stride, args.padding, args.dilation, args.ceil_mode, grad, indices, ); handles.register_float_tensor::(&args.out.id, output.x_grad); } ); let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]); let client = x.client.clone(); let desc = MaxPool2dWithIndicesBackwardOpIr::create( x.into_ir(), output_grad.into_ir(), indices.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); let out = client .register( streams, OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward( desc.clone(), )), MaxPool2dWithIndicesBackwardOps::::new(desc), ) .output(); MaxPool2dBackward::new(out) } fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { make_ops!( AdaptiveAvgPool1dOps, AdaptiveAvgPool1dOpIr, |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool1d(x, args.output_size); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || { client.create_empty_handle() }); client .register( streams, OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())), AdaptiveAvgPool1dOps::::new(desc), ) .output() } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { make_ops!( AdaptiveAvgPool2dOps, AdaptiveAvgPool2dOpIr, |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::adaptive_avg_pool2d(x, args.output_size); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || { client.create_empty_handle() }); client .register( streams, OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())), AdaptiveAvgPool2dOps::::new(desc), ) .output() } fn adaptive_avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { make_ops!( AdaptiveAvgPool1dBackwardOps, AdaptiveAvgPool1dBackwardOpIr, |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool1d_backward(x, grad); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &grad]); let client = x.client.clone(); let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())), AdaptiveAvgPool1dBackwardOps::::new(desc), ) .output() } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { make_ops!( AdaptiveAvgPool2dBackwardOps, AdaptiveAvgPool2dBackwardOpIr, |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let output = B::adaptive_avg_pool2d_backward(x, grad); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &grad]); let client = x.client.clone(); let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())), AdaptiveAvgPool2dBackwardOps::::new(desc), ) .output() } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { make_ops!( InterpolateOps, InterpolateOpIr, |args: &InterpolateOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let output = B::interpolate(x, args.output_size, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x]); let client = x.client.clone(); let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())), InterpolateOps::::new(desc), ) .output() } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { make_ops!( InterpolateBackwardOps, InterpolateBackwardOpIr, |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer| { let x = handles.get_float_tensor::(&args.x); let grad = handles.get_float_tensor::(&args.grad); let output = B::interpolate_backward(x, grad, args.output_size, args.options.clone().into()); handles.register_float_tensor::(&args.out.id, output); } ); let streams = OperationStreams::with_inputs([&x, &grad]); let client = x.client.clone(); let desc = InterpolateBackwardOpIr::create( x.into_ir(), grad.into_ir(), output_size, options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())), InterpolateBackwardOps::::new(desc), ) .output() } fn attention( query: FloatTensor>, key: FloatTensor>, value: FloatTensor>, mask: Option>>, attn_bias: Option>>, options: burn_backend::ops::AttentionModuleOptions, ) -> FloatTensor> { make_ops!( AttentionOps, AttentionOpIr, |args: &AttentionOpIr, handles: &mut HandleContainer| { let query = handles.get_float_tensor::(&args.query); let key = handles.get_float_tensor::(&args.key); let value = handles.get_float_tensor::(&args.value); let mask = args.mask.as_ref().map(|m| handles.get_bool_tensor::(m)); let attn_bias = args .attn_bias .as_ref() .map(|ab| handles.get_float_tensor::(ab)); let output = B::attention( query, key, value, mask, attn_bias, args.options.clone().into(), ); handles.register_float_tensor::(&args.out.id, output); } ); let mut streams = OperationStreams::with_inputs([&query, &key, &value]); if let Some(mask) = &mask { streams.tensor(mask); } if let Some(attn_bias) = &attn_bias { streams.tensor(attn_bias); } let client = query.client.clone(); let desc = AttentionOpIr::create( query.into_ir(), key.into_ir(), value.into_ir(), mask.map(|m| m.into_ir()), attn_bias.map(|ab| ab.into_ir()), options.into(), || client.create_empty_handle(), ); client .register( streams, OperationIr::Module(ModuleOperationIr::Attention(desc.clone())), AttentionOps::::new(desc), ) .output() } } ================================================ FILE: crates/burn-fusion/src/ops/qtensor.rs ================================================ use std::marker::PhantomData; use burn_backend::{ DType, Element, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive, ops::QTensorOps, quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive}, tensor::{Device, FloatTensor, IntTensor, QuantizedTensor}, }; use burn_ir::{ BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer, InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr, QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr, }; use crate::{ Fusion, FusionBackend, get_client, stream::{OperationStreams, execution::Operation}, }; use super::NoOp; impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { let client = get_client::(device); let dtype = data.dtype; let tensor = B::q_from_data(data, device); let shape = burn_backend::TensorMetadata::shape(&tensor); let handle = B::quantized_tensor_handle(tensor); let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle)); client .register( OperationStreams::default(), OperationIr::Init(desc), NoOp::::new(), ) .output() } fn quantize( tensor: FloatTensor, scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { #[derive(new, Debug)] struct QuantizeOp { desc: QuantizeOpIr, _b: PhantomData, } impl Operation for QuantizeOp { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let scales = handles.get_float_tensor::(&self.desc.qparams.scales); let qparams = QuantizationParametersPrimitive { scales }; let output = B::quantize(tensor, &self.desc.scheme, qparams); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]); let client = tensor.client.clone(); let qparams = QuantizationParametersIr { scales: qparams.scales.into_ir(), }; let desc = QuantizeOpIr::create(tensor.into_ir(), qparams, *scheme, || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.tensor.dtype, FloatOperationIr::Quantize(desc.clone())), QuantizeOp::::new(desc), ) .output() } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { #[derive(new, Debug)] struct DequantizeOp { desc: DequantizeOpIr, _b: PhantomData, } impl Operation for DequantizeOp { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_quantized_tensor::(&self.desc.input); let output = B::dequantize(tensor); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let dtype = B::FloatElem::dtype(); let desc = DequantizeOpIr::create(tensor.into_ir(), dtype, || client.create_empty_handle()); client .register( streams, OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())), DequantizeOp::::new(desc), ) .output() } fn q_device(tensor: &QuantizedTensor) -> Device { tensor.client.device().clone() } fn q_to_device(tensor: QuantizedTensor, device: &Device) -> QuantizedTensor { let device_original: &B::Device = tensor.client.device(); let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; } let id = tensor.stream; let client_target = get_client::(&device_target); let client_original = tensor.client.clone(); client_original.change_client_quantized::(tensor.into_ir(), client_target, id) } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { if tensor.shape == shape { return tensor; } #[derive(new, Debug)] struct ReshapeDimsOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ReshapeDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_quantized_tensor::(&self.desc.input); let output = B::q_reshape(input, self.desc.out.shape.clone()); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())), ReshapeDimsOps::::new(desc), ) .output() } async fn q_into_data(tensor: QuantizedTensor) -> Result { tensor.q_into_data::().await } fn q_swap_dims( tensor: QuantizedTensor, dim1: usize, dim2: usize, ) -> QuantizedTensor { #[derive(new, Debug)] struct SwapDimsOps { desc: SwapDimsOpIr, _b: PhantomData, } impl Operation for SwapDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_quantized_tensor::(&self.desc.input); let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())), SwapDimsOps::::new(desc), ) .output() } fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { #[derive(new, Debug)] struct PermuteDimsOps { desc: PermuteOpIr, _b: PhantomData, } impl Operation for PermuteDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_quantized_tensor::(&self.desc.input); let output = B::q_permute(input, self.desc.axes.as_slice()); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())), PermuteDimsOps::::new(desc), ) .output() } fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { #[derive(new, Debug)] struct FlipOps { desc: FlipOpIr, _b: PhantomData, } impl Operation for FlipOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_quantized_tensor::(&self.desc.input); let output = B::q_flip(input, &self.desc.axes); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())), FlipOps::::new(desc), ) .output() } fn q_gather( dim: usize, tensor: QuantizedTensor, indices: IntTensor, ) -> QuantizedTensor { #[derive(new, Debug)] struct GatherOps { desc: GatherOpIr, _b: PhantomData, } impl Operation for GatherOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_quantized_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::q_gather(self.desc.dim, tensor, indices); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())), GatherOps::::new(desc), ) .output() } fn q_select( tensor: QuantizedTensor, dim: usize, indices: IntTensor, ) -> QuantizedTensor { #[derive(new, Debug)] struct SelectOps { desc: SelectOpIr, _b: PhantomData, } impl Operation for SelectOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_quantized_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::q_select(tensor, self.desc.dim, indices); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())), SelectOps::::new(desc), ) .output() } fn q_slice(tensor: QuantizedTensor, slices: &[Slice]) -> QuantizedTensor { #[derive(new, Debug)] struct SliceOps { desc: SliceOpIr, _b: PhantomData, } impl Operation for SliceOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_quantized_tensor::(&self.desc.tensor); let output = B::q_slice(tensor, self.desc.ranges.as_slice()); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())), SliceOps::::new(desc), ) .output() } fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { #[derive(new, Debug)] struct ExpandOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ExpandOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_quantized_tensor::(&self.desc.input); let output = B::q_expand(input, self.desc.out.shape.clone()); handles.register_quantized_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())), ExpandOps::::new(desc), ) .output() } fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { #[derive(new, Debug)] struct MatmulOps { desc: MatmulOpIr, lhs_quantized: bool, rhs_quantized: bool, _b: PhantomData, } impl Operation for MatmulOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = match self.lhs_quantized { true => { TensorPrimitive::QFloat(handles.get_quantized_tensor::(&self.desc.lhs)) } false => TensorPrimitive::Float(handles.get_float_tensor::(&self.desc.lhs)), }; let rhs = match self.rhs_quantized { true => { TensorPrimitive::QFloat(handles.get_quantized_tensor::(&self.desc.rhs)) } false => TensorPrimitive::Float(handles.get_float_tensor::(&self.desc.rhs)), }; let output = B::q_matmul(lhs, rhs); match output { TensorPrimitive::Float(output) => { handles.register_float_tensor::(&self.desc.out.id, output); } TensorPrimitive::QFloat(output) => { handles.register_quantized_tensor::(&self.desc.out.id, output); } } } } let mut propagation = QuantPropagation::Inhibit; let mut scheme = QuantScheme::default(); let mut streams = OperationStreams::default(); let mut lhs_quantized = false; let mut rhs_quantized = false; match &lhs { TensorPrimitive::QFloat(lhs) => { propagation = lhs.propagation(); scheme = *lhs.scheme(); lhs_quantized = true; streams.tensor(lhs); } TensorPrimitive::Float(lhs) => { streams.tensor(lhs); } } match &rhs { TensorPrimitive::QFloat(rhs) => { propagation = rhs.propagation(); scheme = *rhs.scheme(); rhs_quantized = true; streams.tensor(rhs); } TensorPrimitive::Float(rhs) => { streams.tensor(rhs); } } let dtype = match propagation { QuantPropagation::Propagate => DType::QFloat(scheme), QuantPropagation::Inhibit => B::FloatElem::dtype(), }; let client = match &lhs { TensorPrimitive::Float(lhs) => lhs.client.clone(), TensorPrimitive::QFloat(lhs) => lhs.client.clone(), }; let lhs = match lhs { TensorPrimitive::Float(lhs) => lhs.into_ir(), TensorPrimitive::QFloat(lhs) => lhs.into_ir(), }; let rhs = match rhs { TensorPrimitive::Float(rhs) => rhs.into_ir(), TensorPrimitive::QFloat(rhs) => rhs.into_ir(), }; let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle()); let out = client .register( streams, OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())), MatmulOps::::new(desc, lhs_quantized, rhs_quantized), ) .output(); match propagation { QuantPropagation::Propagate => TensorPrimitive::QFloat(out), QuantPropagation::Inhibit => TensorPrimitive::Float(out), } } } ================================================ FILE: crates/burn-fusion/src/ops/tensor.rs ================================================ use super::NoOp; use crate::{ Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops, bool_dtype, get_client, reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, stream::{OperationStreams, execution::Operation}, unary_float_ops, }; use burn_backend::{ Distribution, Element, ExecutionError, FloatDType, Scalar, Shape, Slice, TensorData, ops::{FloatTensorOps, GridSampleOptions}, tensor::{BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntTensor}, }; use burn_ir::*; use std::marker::PhantomData; impl FloatTensorOps for Fusion { #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(data), fields(?data.shape, ?data.dtype) ))] fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); let dtype = data.dtype; let tensor = B::float_from_data(data, device); let shape = burn_backend::TensorMetadata::shape(&tensor); let handle = B::float_tensor_handle(tensor); let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle)); client .register( OperationStreams::default(), OperationIr::Init(desc), NoOp::::new(), ) .output() } fn float_random( shape: Shape, distribution: Distribution, device: &Device, ) -> FloatTensor { #[derive(new, Debug)] struct RandomOps { desc: RandomOpIr, device: Device, } impl Operation for RandomOps { fn execute(&self, handles: &mut HandleContainer) { let output: B::FloatTensorPrimitive = B::float_random( self.desc.out.shape.clone(), self.desc.distribution, &self.device, ); handles.register_float_tensor::(&self.desc.out.id, output); } } let dtype = FloatElem::::dtype(); let client = get_client::(device); let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::Float(dtype, FloatOperationIr::Random(desc.clone())), RandomOps::::new(desc, device.clone()), ) .output() } fn float_zeros(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { #[derive(new, Debug)] struct ZerosOps { out: TensorIr, device: Device, } impl Operation for ZerosOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.out.shape.clone(); let output = B::float_zeros(shape, &self.device, self.out.dtype.into()); handles.register_float_tensor::(&self.out.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseFloat(BaseOperationIr::Zeros(desc.clone())), ZerosOps::::new(desc.out, device.clone()), ) .output() } fn float_ones(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { #[derive(new, Debug)] struct OnesOps { out: TensorIr, device: Device, } impl Operation for OnesOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.out.shape.clone(); let output = B::float_ones(shape, &self.device, self.out.dtype.into()); handles.register_float_tensor::(&self.out.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseFloat(BaseOperationIr::Ones(desc.clone())), OnesOps::::new(desc.out, device.clone()), ) .output() } fn float_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: FloatDType, ) -> FloatTensor { #[derive(new, Debug)] struct FullOps { out: TensorIr, elem: ScalarIr, device: Device, } impl Operation for FullOps { fn execute(&self, handles: &mut HandleContainer) { let shape = self.out.shape.clone(); let dtype = self.out.dtype.into(); let output: B::FloatTensorPrimitive = B::float_full(shape, self.elem.into(), &self.device, dtype); handles.register_float_tensor::(&self.out.id, output); } } let dtype = dtype.into(); let client = get_client::(device); let value = fill_value.into(); let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())), FullOps::::new(desc.out, desc.value, device.clone()), ) .output() } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields( from = ?tensor.client.device(), shape = ?tensor.shape, dtype = ?tensor.dtype ) ))] async fn float_into_data(tensor: FloatTensor) -> Result { tensor.into_data::().await } fn float_device(tensor: &FloatTensor) -> Device { tensor.client.device().clone() } #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), fields( from = ?tensor.client.device(), shape = ?tensor.shape, dtype = ?tensor.dtype, ) ))] fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { let device_original: &B::Device = tensor.client.device(); if device_original == device { return tensor; } let id = tensor.stream; let client_target = get_client::(device); let client_original = tensor.client.clone(); client_original .clone() .change_client_float::(tensor.into_ir(), client_target, id) } fn float_into_int(tensor: FloatTensor) -> IntTensor { #[derive(new, Debug)] struct IntoIntOps { desc: CastOpIr, _b: PhantomData, } impl Operation for IntoIntOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_into_int(input); handles.register_int_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.input.dtype, FloatOperationIr::IntoInt(desc.clone())), IntoIntOps::::new(desc), ) .output() } fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { #[derive(new, Debug)] struct EmptyOps { desc: TensorIr, device: Device, } impl Operation for EmptyOps { fn execute(&self, handles: &mut HandleContainer) { let output = B::float_empty( self.desc.shape.clone(), &self.device, self.desc.dtype.into(), ); handles.register_float_tensor::(&self.desc.id, output); } } let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register( OperationStreams::default(), OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())), EmptyOps::::new(desc.out, device.clone()), ) .output() } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(AddOps, B::float_add); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Add(desc.clone())), AddOps::::new(desc), ) .output() } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(AddOps, B::float_add_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::AddScalar(desc.clone()), ), AddOps::::new(desc), ) .output() } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { #[derive(new, Debug)] struct ClampOps { desc: ClampOpIr, _b: PhantomData, } impl Operation for ClampOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_clamp(input, self.desc.min.into(), self.desc.max.into()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let min = min.into(); let max = max.into(); let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.tensor.dtype, NumericOperationIr::Clamp(desc.clone()), ), ClampOps::::new(desc), ) .output() } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(SubOps, B::float_sub); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sub(desc.clone())), SubOps::::new(desc), ) .output() } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(SubOps, B::float_sub_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::SubScalar(desc.clone()), ), SubOps::::new(desc), ) .output() } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(MulOps, B::float_mul); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mul(desc.clone())), MulOps::::new(desc), ) .output() } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(MulOps, B::float_mul_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MulScalar(desc.clone()), ), MulOps::::new(desc), ) .output() } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(DivOps, B::float_div); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Div(desc.clone())), DivOps::::new(desc), ) .output() } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(DivOps, B::float_div_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::DivScalar(desc.clone()), ), DivOps::::new(desc), ) .output() } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(ModOps, B::float_remainder); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Rem(desc.clone())), ModOps::::new(desc), ) .output() } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(ModOps, B::float_remainder_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::RemScalar(desc.clone()), ), ModOps::::new(desc), ) .output() } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(MatmulOps, B::float_matmul); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())), MatmulOps::::new(desc.into()), ) .output() } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { #[derive(new, Debug)] struct CrossOps { desc: CrossOpIr, _b: PhantomData, } impl Operation for CrossOps { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor::(&self.desc.rhs); let output = B::float_cross(lhs, rhs, self.desc.dim); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Cross(desc.clone())), CrossOps::::new(desc), ) .output() } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { #[derive(new, Debug)] struct SwapDimsOps { desc: SwapDimsOpIr, _b: PhantomData, } impl Operation for SwapDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())), SwapDimsOps::::new(desc), ) .output() } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { if tensor.shape == shape { return tensor; } #[derive(new, Debug)] struct ReshapeDimsOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ReshapeDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_reshape(input, self.desc.out.shape.clone()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())), ReshapeDimsOps::::new(desc), ) .output() } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { #[derive(new, Debug)] struct GatherOps { desc: GatherOpIr, _b: PhantomData, } impl Operation for GatherOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_gather(self.desc.dim, tensor, indices); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices]); let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())), GatherOps::::new(desc), ) .output() } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { #[derive(new, Debug)] struct ScatterOps { desc: ScatterOpIr, _b: PhantomData, } impl Operation for ScatterOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_scatter_add(self.desc.dim, tensor, indices, value); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices, &value]); let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Scatter(desc.clone())), ScatterOps::::new(desc), ) .output() } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { #[derive(new, Debug)] struct SelectOps { desc: SelectOpIr, _b: PhantomData, } impl Operation for SelectOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let output = B::float_select(tensor, self.desc.dim, indices); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices]); let client = tensor.client.clone(); let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())), SelectOps::::new(desc), ) .output() } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { #[derive(new, Debug)] struct SelectAssignOps { desc: SelectAssignOpIr, _b: PhantomData, } impl Operation for SelectAssignOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor::(&self.desc.indices); let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_select_add(tensor, self.desc.dim, indices, value); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &indices, &value]); let client = tensor.client.clone(); let desc = SelectAssignOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc.clone())), SelectAssignOps::::new(desc), ) .output() } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { #[derive(new, Debug)] struct SliceOps { desc: SliceOpIr, _b: PhantomData, } impl Operation for SliceOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_slice(tensor, self.desc.ranges.as_slice()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())), SliceOps::::new(desc), ) .output() } fn float_slice_assign( tensor: FloatTensor, slices: &[burn_backend::Slice], value: FloatTensor, ) -> FloatTensor { #[derive(new, Debug)] struct SliceAssignOps { desc: SliceAssignOpIr, _b: PhantomData, } impl Operation for SliceAssignOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor::(&self.desc.value); let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &value]); let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())), SliceAssignOps::::new(desc), ) .output() } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { #[derive(new, Debug)] struct MaskWhereOps { desc: MaskWhereOpIr, _b: PhantomData, } impl Operation for MaskWhereOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor::(&self.desc.value); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_where(tensor, mask, value); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask, &value]); let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc.clone())), MaskWhereOps::::new(desc), ) .output() } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { #[derive(new, Debug)] struct MaskFillOps { desc: MaskFillOpIr, _b: PhantomData, } impl Operation for MaskFillOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor::(&self.desc.mask); let output = B::float_mask_fill(tensor, mask, self.desc.value.into()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &mask]); let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc.clone())), MaskFillOps::::new(desc), ) .output() } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float_cmp_ops!(EqualOps, B::float_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())), EqualOps::::new(desc), ) .output() } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc.clone())), EqualElemOps::::new(desc), ) .output() } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float_cmp_ops!(GreaterOps, B::float_greater); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::Greater(desc.clone()), ), GreaterOps::::new(desc), ) .output() } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterElem(desc.clone()), ), GreaterElemOps::::new(desc), ) .output() } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterEqual(desc.clone()), ), GreaterEqualOps::::new(desc), ) .output() } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterEqualElem(desc.clone()), ), GreaterEqualElemOps::::new(desc), ) .output() } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float_cmp_ops!(LowerOps, B::float_lower); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericFloat(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())), LowerOps::::new(desc), ) .output() } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerElem(desc.clone()), ), LowerElemOps::::new(desc), ) .output() } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), bool_dtype::(), || client.create_empty_handle(), ); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerEqual(desc.clone()), ), LowerEqualOps::::new(desc), ) .output() } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerEqualElem(desc.clone()), ), LowerEqualElemOps::::new(desc), ) .output() } fn float_sum(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SumOps, B::float_sum, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sum(desc.clone())), SumOps::::new(desc.into()), ) .output() } fn float_sum_dim(tensor: FloatTensor, axis: usize) -> FloatTensor { reduce_float_ops!(SumDimOps, B::float_sum_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())), SumDimOps::::new(desc), ) .output() } fn float_prod(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ProdOps, B::float_prod, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Prod(desc.clone())), ProdOps::::new(desc.into()), ) .output() } fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce_float_ops!(ProdDimOps, B::float_prod_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::ProdDim(desc.clone()), ), ProdDimOps::::new(desc), ) .output() } fn float_mean(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MeanOps, B::float_mean, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mean(desc.clone())), MeanOps::::new(desc.into()), ) .output() } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce_float_ops!(MeanDimOps, B::float_mean_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MeanDim(desc.clone()), ), MeanDimOps::::new(desc), ) .output() } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CumsumOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumsumOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_cumsum(input, self.desc.axis); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())), CumsumOps::::new(desc), ) .output() } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CumprodOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumprodOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_cumprod(input, self.desc.axis); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::CumProd(desc.clone()), ), CumprodOps::::new(desc), ) .output() } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CumminOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CumminOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_cummin(input, self.desc.axis); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())), CumminOps::::new(desc), ) .output() } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CummaxOps { desc: DimOpIr, _b: PhantomData, } impl Operation for CummaxOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_cummax(input, self.desc.axis); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())), CummaxOps::::new(desc), ) .output() } fn float_exp(lhs: FloatTensor) -> FloatTensor { unary_float_ops!(ExpOps, B::float_exp); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Exp(desc.clone())), ExpOps::::new(desc), ) .output() } fn float_log(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(LogOps, B::float_log); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Log(desc.clone())), LogOps::::new(desc), ) .output() } fn float_log1p(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(Log1pOps, B::float_log1p); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Log1p(desc.clone())), Log1pOps::::new(desc), ) .output() } fn float_powf_scalar_impl(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { scalar_float_ops!(PowfOps, B::float_powf_scalar); let streams = OperationStreams::with_inputs([&lhs]); let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::PowfScalar(desc.clone())), PowfOps::::new(desc), ) .output() } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SqrtOps, B::float_sqrt); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Sqrt(desc.clone())), SqrtOps::::new(desc), ) .output() } fn float_abs(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(AbsOps, B::float_abs); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Abs(desc.clone())), AbsOps::::new(desc), ) .output() } fn float_cos(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(CosOps, B::float_cos); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Cos(desc.clone())), CosOps::::new(desc), ) .output() } fn float_sin(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SinOps, B::float_sin); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Sin(desc.clone())), SinOps::::new(desc), ) .output() } fn float_tan(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TanOps, B::float_tan); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Tan(desc.clone())), TanOps::::new(desc), ) .output() } fn float_cosh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(CoshOps, B::float_cosh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Cosh(desc.clone())), CoshOps::::new(desc), ) .output() } fn float_sinh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(SinhOps, B::float_sinh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Sinh(desc.clone())), SinhOps::::new(desc), ) .output() } fn float_tanh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TanhOps, B::float_tanh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Tanh(desc.clone())), TanhOps::::new(desc), ) .output() } fn float_acos(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcCosOps, B::float_acos); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCos(desc.clone())), ArcCosOps::::new(desc), ) .output() } fn float_acosh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcCoshOps, B::float_acosh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCosh(desc.clone())), ArcCoshOps::::new(desc), ) .output() } fn float_asin(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcSinOps, B::float_asin); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSin(desc.clone())), ArcSinOps::::new(desc), ) .output() } fn float_asinh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcSinhOps, B::float_asinh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSinh(desc.clone())), ArcSinhOps::::new(desc), ) .output() } fn float_atan(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcTanOps, B::float_atan); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan(desc.clone())), ArcTanOps::::new(desc), ) .output() } fn float_atanh(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(ArcTanhOps, B::float_atanh); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTanh(desc.clone())), ArcTanhOps::::new(desc), ) .output() } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(ArcTan2Ops, B::float_atan2); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan2(desc.clone())), ArcTan2Ops::::new(desc), ) .output() } fn float_recip(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(Recip, B::float_recip); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Recip(desc.clone())), Recip::::new(desc), ) .output() } fn float_erf(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TanhOps, B::float_erf); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Erf(desc.clone())), TanhOps::::new(desc), ) .output() } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { #[derive(new, Debug)] struct CatOps { desc: CatOpIr, _b: PhantomData, } impl Operation for CatOps { fn execute(&self, handles: &mut HandleContainer) { let tensors = self .desc .tensors .iter() .map(|tensor| handles.get_float_tensor::(tensor)) .collect(); let output = B::float_cat(tensors, self.desc.dim); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs(&tensors); let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())), CatOps::::new(desc), ) .output() } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { reduce_float2int_ops!(ArgMaxOps, B::float_argmax); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); // TODO: rename `create_with_dtype` specifically for ARG / indices let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.input.dtype, NumericOperationIr::ArgMax(desc.clone()), ), ArgMaxOps::::new(desc), ) .output() } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { #[derive(new, Debug)] struct RepeatDimOps { desc: RepeatDimOpIr, _b: PhantomData, } impl Operation for RepeatDimOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())), RepeatDimOps::::new(desc), ) .output() } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { reduce_float2int_ops!(ArgMinOps, B::float_argmin); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.input.dtype, NumericOperationIr::ArgMin(desc.clone()), ), ArgMinOps::::new(desc), ) .output() } fn float_max(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MaxOps, B::float_max, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Max(desc.clone())), MaxOps::::new(desc.into()), ) .output() } fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce_float_ops!(MaxDimOps, B::float_max_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())), MaxDimOps::::new(desc), ) .output() } fn float_max_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new, Debug)] struct MaxDimWithIndicesOps { desc: ReduceDimWithIndicesOpIr, _b: PhantomData, } impl Operation for MaxDimWithIndicesOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim); handles.register_float_tensor::(&self.desc.out.id, output); handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.tensor.dtype, NumericOperationIr::MaxDimWithIndices(desc.clone()), ), MaxDimWithIndicesOps::::new(desc), ) .outputs() .into() } fn float_min(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MinOps, B::float_min, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Min(desc.clone())), MinOps::::new(desc.into()), ) .output() } fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce_float_ops!(MinDimOps, B::float_min_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())), MinDimOps::::new(desc), ) .output() } fn float_min_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { #[derive(new, Debug)] struct MinDimWithIndicesOps { desc: ReduceDimWithIndicesOpIr, _b: PhantomData, } impl Operation for MinDimWithIndicesOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim); handles.register_float_tensor::(&self.desc.out.id, output); handles.register_int_tensor::(&self.desc.out_indices.id, indices); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, B::IntElem::dtype(), || { client.create_empty_handle() }); client .register( streams, OperationIr::NumericFloat( desc.tensor.dtype, NumericOperationIr::MinDimWithIndices(desc.clone()), ), MinDimWithIndicesOps::::new(desc), ) .outputs() .into() } fn float_max_abs(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())), MaxAbsOps::::new(desc.into()), ) .output() } fn float_max_abs_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { reduce_float_ops!(MaxAbsDimOps, B::float_max_abs_dim); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register( streams, OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MaxAbsDim(desc.clone()), ), MaxAbsDimOps::::new(desc), ) .output() } // TODO: float_powi w/ burn-cubecl-fusion impl fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { binary_float_ops!(PowOps, B::float_powf); let streams = OperationStreams::with_inputs([&lhs, &rhs]); let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Powf(desc.clone())), PowOps::::new(desc), ) .output() } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { #[derive(new, Debug)] struct PermuteDimsOps { desc: PermuteOpIr, _b: PhantomData, } impl Operation for PermuteDimsOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_permute(input, self.desc.axes.as_slice()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())), PermuteDimsOps::::new(desc), ) .output() } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { #[derive(new, Debug)] struct ExpandOps { desc: ShapeOpIr, _b: PhantomData, } impl Operation for ExpandOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_expand(input, self.desc.out.shape.clone()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())), ExpandOps::::new(desc), ) .output() } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { #[derive(new, Debug)] struct FlipOps { desc: FlipOpIr, _b: PhantomData, } impl Operation for FlipOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_flip(input, &self.desc.axes); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())), FlipOps::::new(desc), ) .output() } fn float_round(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(RoundOps, B::float_round); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Round(desc.clone())), RoundOps::::new(desc), ) .output() } fn float_floor(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(FloorOps, B::float_floor); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Floor(desc.clone())), FloorOps::::new(desc), ) .output() } fn float_ceil(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(CeilOps, B::float_ceil); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Ceil(desc.clone())), CeilOps::::new(desc), ) .output() } fn float_trunc(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(TruncOps, B::float_trunc); let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::Trunc(desc.clone())), TruncOps::::new(desc), ) .output() } fn float_cast(tensor: FloatTensor, dtype: burn_backend::FloatDType) -> FloatTensor { #[derive(new, Debug)] struct CastOps { desc: CastOpIr, dtype: burn_backend::FloatDType, _b: PhantomData, } impl Operation for CastOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())), CastOps::::new(desc, dtype), ) .output() } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { #[derive(new, Debug)] struct UnfoldOps { desc: UnfoldOpIr, _b: PhantomData, } impl Operation for UnfoldOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register( streams, OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())), UnfoldOps::::new(desc), ) .output() } fn float_is_nan(tensor: FloatTensor) -> BoolTensor { #[derive(new, Debug)] struct IsNanOps { desc: UnaryOpIr, _b: PhantomData, } impl Operation for IsNanOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_is_nan(input); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create_comparison(tensor.into_ir(), bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.input.dtype, FloatOperationIr::IsNan(desc.clone())), IsNanOps::::new(desc), ) .output() } fn float_is_inf(tensor: FloatTensor) -> BoolTensor { #[derive(new, Debug)] struct IsInfOps { desc: UnaryOpIr, _b: PhantomData, } impl Operation for IsInfOps { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_is_inf(input); handles.register_bool_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor]); let client = tensor.client.clone(); let desc = UnaryOpIr::create_comparison(tensor.into_ir(), bool_dtype::(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.input.dtype, FloatOperationIr::IsInf(desc.clone())), IsInfOps::::new(desc), ) .output() } fn float_grid_sample_2d( tensor: FloatTensor, grid: FloatTensor, options: GridSampleOptions, ) -> FloatTensor { #[derive(new, Debug)] struct GridSample2dOps { desc: GridSample2dOpIr, _b: PhantomData, } impl Operation for GridSample2dOps { fn execute(&self, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let grid = handles.get_float_tensor::(&self.desc.grid); let output = B::float_grid_sample_2d(tensor, grid, self.desc.options.clone().into()); handles.register_float_tensor::(&self.desc.out.id, output); } } let streams = OperationStreams::with_inputs([&tensor, &grid]); let client = tensor.client.clone(); let desc = GridSample2dOpIr::create(tensor.into_ir(), grid.into_ir(), options.into(), || { client.create_empty_handle() }); client .register( streams, OperationIr::Float(desc.out.dtype, FloatOperationIr::GridSample2d(desc.clone())), GridSample2dOps::::new(desc), ) .output() } } ================================================ FILE: crates/burn-fusion/src/ops/transaction.rs ================================================ use burn_backend::{ backend::ExecutionError, ops::{TransactionOps, TransactionPrimitive}, }; use crate::{Fusion, FusionBackend}; impl TransactionOps> for Fusion { async fn tr_execute( transaction: TransactionPrimitive, ) -> Result { B::tr_execute(TransactionPrimitive::new( transaction .read_floats .into_iter() .map(|t| t.client.clone().resolve_tensor_float::(t)) .collect(), transaction .read_qfloats .into_iter() .map(|_t| todo!("Quantization not supported yet")) .collect(), transaction .read_ints .into_iter() .map(|t| t.client.clone().resolve_tensor_int::(t)) .collect(), transaction .read_bools .into_iter() .map(|t| t.client.clone().resolve_tensor_bool::(t)) .collect(), )) .await } } ================================================ FILE: crates/burn-fusion/src/ops/unary.rs ================================================ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.into()); handles.register_float_tensor::(&self.desc.out.id, output); } } }; ( $name:ident, $ops:expr, noconvert ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); handles.register_float_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_float_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ReduceDimOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input, self.desc.axis); handles.register_float_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_float2int_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ReduceDimOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_int_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ReduceDimOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input, self.desc.axis); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float2int_ops { ( $name:ident, $ops:expr, ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.clone()); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_float_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: UnaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); handles.register_float_tensor::(&self.desc.out.id, output); } } }; ( $name:ident, $ops:expr, reduce ) => { #[derive(new, Debug)] struct $name { desc: UnaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); handles.register_float_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_int_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: UnaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); handles.register_int_tensor::(&self.desc.out.id, output); } } }; ( $name:ident, $ops:expr, reduce ) => { #[derive(new, Debug)] struct $name { desc: UnaryOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_cmp_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.into()); handles.register_bool_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_cmp_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.into()); handles.register_bool_tensor::(&self.desc.out.id, output); } } }; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_ops { ( $name:ident, $ops:expr ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.into()); handles.register_int_tensor::(&self.desc.out.id, output); } } }; ( $name:ident, $ops:expr, noconvert ) => { #[derive(new, Debug)] struct $name { desc: ScalarOpIr, _b: PhantomData, } impl Operation for $name { fn execute(&self, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); handles.register_int_tensor::(&self.desc.out.id, output); } } }; } ================================================ FILE: crates/burn-fusion/src/search/block.rs ================================================ use crate::{FuserStatus, NumOperations, OperationFuser, stream::store::ExecutionStrategy}; use burn_ir::{OperationIr, TensorId, TensorIr}; use std::{collections::HashSet, sync::Arc}; /// A block represents a list of operations, not necessarily in the same order as the execution /// stream. /// /// The start and end position of the relative execution stream are tracked in the block alongside /// the ordering. pub struct Block { builders: Vec>>, operations: Vec, ids: HashSet, ordering: Vec, /// The start position in the relative execution stream. pub start_pos: usize, /// The end position in the relative execution stream. pub end_pos: usize, } /// The result of [registering](Block::register) an [operation](OperationIr). pub enum RegistrationResult { /// If the [operation](OperationIr) is correctly registered. Accepted, /// If the [operation](OperationIr) isn't part of the graph. /// /// In this case the operation isn't registered. NotPartOfTheGraph, } /// The optimization found for a [block](Block). #[derive(Debug, new)] pub struct BlockOptimization { /// The [execution strategy](ExecutionStrategy) to be used to execute the [block](Block). pub strategy: ExecutionStrategy, /// The ordering of each operation in the relative execution stream. pub ordering: Vec, } impl Block { /// Create a new block that will be optimized with the provided [optimization builders](OptimizationBuilder). pub fn new(builders: &[Box>]) -> Self { Self { builders: builders.iter().map(|o| o.clone_dyn()).collect(), operations: Vec::new(), ids: HashSet::new(), ordering: Vec::new(), start_pos: usize::MAX, end_pos: usize::MIN, } } /// Sort the [blocks](Block) based on the start position. pub fn sort(blocks: &mut [Self]) { blocks.sort_by(|a, b| a.start_pos.cmp(&b.start_pos)); } /// Optimize the block. pub fn optimize(mut self) -> BlockOptimization { match find_best_optimization_index(&mut self.builders) { Some(index) => { let opt = self.builders[index].finish(); let opt_len = opt.len(); if opt_len < self.operations.len() { self.ordering.drain(opt_len..); } let strategy = ExecutionStrategy::Optimization { ordering: Arc::new(self.ordering.clone()), opt, }; BlockOptimization::new(strategy, self.ordering) } None => { let strategy = ExecutionStrategy::Operations { ordering: Arc::new(self.ordering.clone()), }; BlockOptimization::new(strategy, self.ordering) } } } /// Returns if the block contains any of the provided [tensors](TensorIr). pub fn contains_tensors(&self, tensors: &[&TensorIr]) -> bool { for node in tensors { if self.ids.contains(&node.id) { return true; } } false } /// Merge the current block with the other one and returns if the operation is successful. /// /// # Warning /// /// This will modify the current block even if the other block isn't correctly merged. pub fn merge(&mut self, other: &Block) -> bool { for (op, pos) in other.operations.iter().zip(&other.ordering) { self.register(op, *pos, true); } // The operation is successful if the current block can still be optimized. self.still_optimizing() } /// Register an [operation](OperationIr) in the current block. /// /// You need to provide the order of the operation as well as a force flag. /// /// When the force flag is true, the builder will always accept the operation, otherwise it /// might refuse it if the operation [isn't part of the graph](RegistrationResult::NotPartOfTheGraph). /// /// Forcing is useful to fuse operations that are part of different graphs, but included /// in the same optimization. pub fn register( &mut self, operation: &OperationIr, order: usize, force: bool, ) -> RegistrationResult { if self.ids.is_empty() { self.register_op(operation, order); return RegistrationResult::Accepted; } let mut contains = false; for node in operation.nodes() { contains = self.ids.contains(&node.id); if contains { break; } } if !contains && !force { return RegistrationResult::NotPartOfTheGraph; } self.register_op(operation, order); RegistrationResult::Accepted } /// If the block can still be optimized further. pub fn still_optimizing(&self) -> bool { let mut num_stopped = 0; for optimization in self.builders.iter() { if let FuserStatus::Closed = optimization.status() { num_stopped += 1 } } num_stopped < self.builders.len() } fn register_op(&mut self, operation: &OperationIr, pos: usize) { self.operations.push(operation.clone()); self.ordering.push(pos); if pos < self.start_pos { self.start_pos = pos; } if pos + 1 > self.end_pos { self.end_pos = pos + 1; } for builder in self.builders.iter_mut() { builder.fuse(operation); } for node in operation.nodes() { self.ids.insert(node.id); } } } impl BlockOptimization { /// Maps the ordering of the current block optimization using the given mapping. pub fn map_ordering(&mut self, mapping: &[usize]) { for i in self.ordering.iter_mut() { *i = mapping[*i]; } self.strategy.map_ordering(mapping); } } impl ExecutionStrategy { /// Maps the ordering of the current execution strategy using the given mapping. pub fn map_ordering(&mut self, mapping: &[usize]) { match self { ExecutionStrategy::Optimization { ordering, .. } => { let mut ordering_mapped = ordering.to_vec(); for o in ordering_mapped.iter_mut() { *o = mapping[*o]; } *ordering = Arc::new(ordering_mapped); } ExecutionStrategy::Operations { ordering } => { let mut ordering_mapped = ordering.to_vec(); for o in ordering_mapped.iter_mut() { *o = mapping[*o]; } *ordering = Arc::new(ordering_mapped); } ExecutionStrategy::Composed(items) => { for item in items.iter_mut() { item.map_ordering(mapping); } } } } } fn find_best_optimization_index( optimizations: &mut [Box>], ) -> Option { let mut best_index = None; let mut best_score = 0; for (i, optimization) in optimizations.iter().enumerate() { let properties = optimization.properties(); // A score of zero is worse than fusing. if properties.ready && properties.score > best_score { best_index = Some(i); best_score = properties.score; } } best_index } impl PartialEq for Block { fn eq(&self, other: &Self) -> bool { // Since the ordering can be seen as operation ids, we can use it to compare // blocks. let mut sorted_a = self.ordering.clone(); let mut sorted_b = other.ordering.clone(); sorted_a.sort(); sorted_b.sort(); sorted_a == sorted_b } } impl core::fmt::Debug for Block { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "Block {{ pos: [{:?}, {:?}; {:?}] }}", self.start_pos, self.end_pos, self.ordering.len(), )) } } impl Clone for Block { fn clone(&self) -> Self { Self { builders: self.builders.iter().map(|b| b.clone_dyn()).collect(), operations: self.operations.clone(), ids: self.ids.clone(), ordering: self.ordering.clone(), start_pos: self.start_pos, end_pos: self.end_pos, } } } ================================================ FILE: crates/burn-fusion/src/search/merging.rs ================================================ use super::Block; use crate::NumOperations; #[derive(Debug, PartialEq)] /// The result of [merging](merge_blocks) [blocks](Block). pub enum MergeBlocksResult { /// All [blocks](Block) merged into one. Full(Block), /// Some [blocks](Block) merged and some failed. Partial { merged: Vec>, failed: Vec>, }, /// All [blocks](Block) failed to merge. Fail, } /// Merge multiple [block](Block) together. /// /// The resulting [blocks](Block) might be sorted if the flag is true, otherwise the order isn't /// guarantee. This is mostly useful for testing. /// /// # Strategy /// /// The merging strategy is in two steps: /// /// 1. The first step is to recursively try to merge adjacent blocks. This has the advantage of /// trying multiple blocks ordering, therefore trying multiple permutation of the blocks. /// However, it has the downside of not trying to merge blocks that are further away in the list /// of blocks. Since trying all combinations possible is exponential, therefore not possible, we /// fallback on the second strategy. /// 2. The second step is to reduce blocks by setting an accumulator block, then sequentially /// trying to merge the remaining blocks. We try some permutations based on the result from /// step1. pub fn merge_blocks(blocks: &[&Block], sorted: bool) -> MergeBlocksResult { if blocks.is_empty() { return MergeBlocksResult::Fail; } if blocks.len() == 1 { return MergeBlocksResult::Full(blocks[0].clone()); } if blocks.len() == 2 { let block0 = blocks[0]; let block1 = blocks[1]; return match merge_two(block0, block1) { Some(result) => MergeBlocksResult::Full(result), None => MergeBlocksResult::Fail, }; } let mut step1 = merge_blocks_step1(blocks); if step1.full.len() == 1 && step1.failed.is_empty() && step1.partial.is_empty() { MergeBlocksResult::Full(step1.full.remove(0)) } else if step1.partial.len() == 1 && step1.failed.is_empty() && step1.full.is_empty() { MergeBlocksResult::Full(step1.partial.remove(0)) } else { let result = merge_blocks_step2(step1); if !sorted { return result; } match result { MergeBlocksResult::Full(block) => MergeBlocksResult::Full(block), MergeBlocksResult::Partial { mut merged, mut failed, } => { Block::sort(&mut merged); Block::sort(&mut failed); MergeBlocksResult::Partial { merged, failed } } MergeBlocksResult::Fail => MergeBlocksResult::Fail, } } } struct MergeBlockStep1 { full: Vec>, partial: Vec>, failed: Vec>, } impl Default for MergeBlockStep1 { fn default() -> Self { Self { full: Default::default(), partial: Default::default(), failed: Default::default(), } } } fn merge_blocks_step1(blocks: &[&Block]) -> MergeBlockStep1 { let step_size = blocks.len() / 2; let num_steps = f32::ceil(blocks.len() as f32 / step_size as f32) as usize; let mut result = MergeBlockStep1::default(); for i in 0..num_steps { let start = i * step_size; let end = usize::min(start + step_size, blocks.len()); match merge_blocks(&blocks[start..end], false) { MergeBlocksResult::Full(block) => { result.full.push(block); } MergeBlocksResult::Partial { mut merged, mut failed, } => { result.partial.append(&mut merged); result.failed.append(&mut failed); } MergeBlocksResult::Fail => { for b in &blocks[start..end] { result.failed.push((*b).clone()); } } } } result } fn merge_blocks_step2(mut step1: MergeBlockStep1) -> MergeBlocksResult { // First let's try to merge partial graphs. if step1.partial.len() > 1 { match merge_accumulator(&step1.partial[0], &step1.partial[1..]) { MergeBlocksResult::Full(block) => { step1.partial = vec![block]; } MergeBlocksResult::Partial { merged, mut failed } => { step1.partial = merged; step1.failed.append(&mut failed); } MergeBlocksResult::Fail => {} } } // Then let's try to merge partial graphs with failed merges. if !step1.failed.is_empty() { step1.partial.append(&mut step1.failed); match merge_accumulator(&step1.partial[0], &step1.partial[1..]) { MergeBlocksResult::Full(block) => { step1.partial = vec![block]; } MergeBlocksResult::Partial { merged, mut failed } => { step1.partial = merged; step1.failed.append(&mut failed); } MergeBlocksResult::Fail => {} } } // Then let's try to merge full graphs. if step1.full.len() > 1 { match merge_accumulator(&step1.full[0], &step1.full[1..]) { MergeBlocksResult::Full(block) => { step1.full = vec![block]; } MergeBlocksResult::Partial { merged, mut failed } => { step1.full = merged; step1.failed.append(&mut failed); } MergeBlocksResult::Fail => {} } } // Then let's try to merge full graphs with failed graphs. if !step1.full.is_empty() { step1.full.append(&mut step1.failed); match merge_accumulator(&step1.full[0], &step1.full[1..]) { MergeBlocksResult::Full(block) => { step1.full = vec![block]; } MergeBlocksResult::Partial { merged, mut failed } => { step1.full = merged; step1.failed.append(&mut failed); } MergeBlocksResult::Fail => {} } } // Then let's try to merge full graphs with partial graphs. if !step1.full.is_empty() || !step1.partial.is_empty() { step1.full.append(&mut step1.partial); match merge_accumulator(&step1.full[0], &step1.full[1..]) { MergeBlocksResult::Full(block) => { step1.full = vec![block]; } MergeBlocksResult::Partial { merged, mut failed } => { step1.full = merged; step1.failed.append(&mut failed); } MergeBlocksResult::Fail => { // We do nothing. } } } if step1.full.is_empty() { MergeBlocksResult::Fail } else if step1.failed.is_empty() { if step1.full.len() == 1 { MergeBlocksResult::Full(step1.full.remove(0)) } else { MergeBlocksResult::Partial { merged: step1.full, failed: vec![], } } } else { MergeBlocksResult::Partial { merged: step1.full, failed: step1.failed, } } } fn merge_accumulator( base: &Block, blocks: &[Block], ) -> MergeBlocksResult { let mut base = base.clone(); let mut merged_failed = Vec::>::new(); let mut merged_success = false; for block in blocks { let mut base_current = base.clone(); match base_current.merge(block) { false => { merged_failed.push((*block).clone()); } true => { merged_success = true; base = base_current; } } } if merged_success { if merged_failed.is_empty() { MergeBlocksResult::Full(base) } else { MergeBlocksResult::Partial { merged: vec![base], failed: merged_failed, } } } else { MergeBlocksResult::Fail } } fn merge_two(a: &Block, b: &Block) -> Option> { let mut base = a.clone(); if base.merge(b) { return Some(base); } let mut base = b.clone(); match base.merge(a) { true => Some(base), false => None, } } #[cfg(test)] mod tests { use super::*; pub use crate::stream::execution::tests::{TestOptimization, TestOptimizationBuilder}; use crate::{ OperationFuser, stream::tests::{operation_1, operation_2, operation_3}, }; #[test] fn test_merge_blocks_no_block() { let actual = merge_blocks::(&[], true); assert_eq!(actual, MergeBlocksResult::Fail); } #[test] fn test_merge_blocks_single() { let builders = builders(); let block = Block::new(&builders); let actual = merge_blocks::(&[&block], true); assert_eq!(actual, MergeBlocksResult::Full(block)); } #[test] fn test_merge_blocks_two_blocks() { let builders = builders(); let mut block1 = Block::new(&builders); let mut block2 = Block::new(&builders); block1.register(&operation_1(), 0, false); block1.register(&operation_1(), 1, false); block2.register(&operation_1(), 2, false); block2.register(&operation_1(), 3, false); let actual = merge_blocks::(&[&block1, &block2], true); let mut expected = Block::new(&builders); expected.register(&operation_1(), 0, false); expected.register(&operation_1(), 1, false); expected.register(&operation_1(), 2, false); expected.register(&operation_1(), 3, false); assert_eq!(actual, MergeBlocksResult::Full(expected)); } #[test] fn test_merge_blocks_three_blocks() { let builders = builders(); let mut block1 = Block::new(&builders); let mut block2 = Block::new(&builders); let mut block3 = Block::new(&builders); block1.register(&operation_1(), 0, false); block2.register(&operation_1(), 1, false); block3.register(&operation_1(), 2, false); let actual = merge_blocks::(&[&block1, &block2, &block3], true); let mut expected = Block::new(&builders); expected.register(&operation_1(), 0, false); expected.register(&operation_1(), 1, false); expected.register(&operation_1(), 2, false); assert_eq!(actual, MergeBlocksResult::Full(expected)); } #[test] fn test_merge_blocks_three_blocks_partial() { let builders = builders(); let mut block1 = Block::new(&builders); let mut block2 = Block::new(&builders); let mut block3 = Block::new(&builders); block1.register(&operation_1(), 0, false); block2.register(&operation_2(), 1, false); block3.register(&operation_1(), 2, false); let actual = merge_blocks::(&[&block1, &block2, &block3], true); let mut expected1 = Block::new(&builders); let mut expected2 = Block::new(&builders); expected1.register(&operation_1(), 0, false); expected1.register(&operation_1(), 2, false); expected2.register(&operation_2(), 1, false); assert_eq!( actual, MergeBlocksResult::Partial { merged: vec![expected1, expected2], failed: vec![] } ); } #[test] fn test_merge_blocks_four_blocks_partial_with_failure() { let builders = builders(); let mut block1 = Block::new(&builders); let mut block2 = Block::new(&builders); let mut block3 = Block::new(&builders); let mut block4 = Block::new(&builders); block1.register(&operation_1(), 0, false); block2.register(&operation_2(), 1, false); block3.register(&operation_1(), 2, false); block4.register(&operation_3(), 3, false); let actual = merge_blocks::(&[&block1, &block2, &block3, &block4], true); let mut expected1 = Block::new(&builders); let mut expected2 = Block::new(&builders); let mut failed = Block::new(&builders); expected1.register(&operation_1(), 0, false); expected1.register(&operation_1(), 2, false); expected2.register(&operation_2(), 1, false); failed.register(&operation_3(), 3, false); assert_eq!( actual, MergeBlocksResult::Partial { merged: vec![expected1], failed: vec![expected2, failed] } ); } #[test] fn test_merge_blocks_five_blocks_partial_with_failure() { let builders = builders(); let mut block1 = Block::new(&builders); let mut block2 = Block::new(&builders); let mut block3 = Block::new(&builders); let mut block4 = Block::new(&builders); let mut block5 = Block::new(&builders); block1.register(&operation_1(), 0, false); block2.register(&operation_2(), 1, false); block3.register(&operation_1(), 2, false); block4.register(&operation_3(), 3, false); block5.register(&operation_2(), 4, false); let actual = merge_blocks::(&[&block1, &block2, &block3, &block4, &block5], true); let mut expected1 = Block::new(&builders); let mut expected2 = Block::new(&builders); let mut failed = Block::new(&builders); expected1.register(&operation_1(), 0, false); expected1.register(&operation_1(), 2, false); expected2.register(&operation_2(), 1, false); expected2.register(&operation_2(), 4, false); failed.register(&operation_3(), 3, false); assert_eq!( actual, MergeBlocksResult::Partial { merged: vec![expected1, expected2], failed: vec![failed] } ); } fn builders() -> Vec>> { let builder_1 = TestOptimizationBuilder::new(0, vec![operation_1(); 10]); let builder_2 = TestOptimizationBuilder::new(1, vec![operation_2(); 10]); vec![Box::new(builder_1), Box::new(builder_2)] } } ================================================ FILE: crates/burn-fusion/src/search/mod.rs ================================================ mod block; mod optimization; pub(super) mod merging; pub(super) use block::*; pub use optimization::*; ================================================ FILE: crates/burn-fusion/src/search/optimization/blocks.rs ================================================ use std::sync::Arc; use crate::{ NumOperations, search::{ Block, BlockOptimization, merging::{MergeBlocksResult, merge_blocks}, }, stream::store::ExecutionStrategy, }; /// Try to optimize a list of [blocks](Block) into a [block optimization](BlockOptimization). /// /// # Notes /// /// What we know here is that every block is independent at that time and can be executed /// in any order. /// /// The contract is that the length of operations executed must include all operations. If we don't /// find an optimization that can be executed with that constraint, we return a /// [BlocksOptimizerResult::WithHoles]. pub struct BlocksOptimizer { blocks: Vec>, resolved: Vec, last_checked: usize, } /// When we can't find a proper optimization for the provided list of [blocks](Block). pub enum BlocksOptimizerResult { /// When an optimization fill the hole stream. Full(BlockOptimization), /// The optimization found with the holes indices. WithHoles { strategies: Vec>>, ordering: Vec, holes: Vec, }, } enum BlockOptimizationStep { Contiguous { strategy: ExecutionStrategy, }, /// Only happen when we fallback on executing a single operation. Operation { strategy: ExecutionStrategy, }, WithHoles { strategy: ExecutionStrategy, holes: Vec, }, Stop, } impl BlocksOptimizer { /// Create a new optimizer with the given blocks. pub fn new(blocks: Vec>) -> Self { let num_ops: usize = blocks.iter().map(|g| g.end_pos).max().unwrap(); Self { blocks, resolved: vec![false; num_ops], last_checked: 0, } } /// Optimizes the blocks. /// /// The strategy is quite simple. We try to merge as much [blocks](Block) together as we can, /// then we iterate over them in order composing optimizations with the remaining blocks, all /// while minimizing fallbacks operations to avoid having holes in the optimization stream. pub fn optimize(mut self) -> BlocksOptimizerResult { self = self.merging_pass(); let mut strategies = Vec::with_capacity(self.blocks.len()); let mut ordering = Vec::new(); let mut blocks = Vec::new(); core::mem::swap(&mut blocks, &mut self.blocks); for block in blocks { match self.optimize_block(block, &mut ordering) { BlockOptimizationStep::Contiguous { strategy } => { strategies.push(Box::new(strategy)); } BlockOptimizationStep::Operation { strategy } => { strategies.push(Box::new(strategy)); break; } BlockOptimizationStep::WithHoles { strategy, holes } => { strategies.push(Box::new(strategy)); return BlocksOptimizerResult::WithHoles { strategies, ordering, holes, }; } BlockOptimizationStep::Stop => { break; } } } let optimization = match strategies.len() > 1 { true => BlockOptimization { strategy: ExecutionStrategy::Composed(strategies), ordering, }, false => BlockOptimization { strategy: *strategies.remove(0), ordering, }, }; BlocksOptimizerResult::Full(optimization) } /// Optimize a single block. fn optimize_block( &mut self, block: Block, ordering: &mut Vec, ) -> BlockOptimizationStep { let last_index = block.end_pos; let mut block_optimization = block.optimize(); let opt_size = block_optimization.ordering.len(); for pos in block_optimization.ordering.iter() { self.update_check(*pos); } if self.last_checked != ordering.len() + opt_size { if !ordering.is_empty() { // Don't include that block and need further exploring. return BlockOptimizationStep::Stop; } return self.optimize_holes(block_optimization, last_index, ordering); } ordering.append(&mut block_optimization.ordering); BlockOptimizationStep::Contiguous { strategy: block_optimization.strategy, } } /// The provided optimization has holes. fn optimize_holes( &mut self, mut optimization: BlockOptimization, last_index: usize, ordering_global: &mut Vec, ) -> BlockOptimizationStep { match optimization.strategy { ExecutionStrategy::Optimization { opt, ordering } => { ordering_global.append(&mut optimization.ordering); let holes = self.find_holes(last_index); if holes.is_empty() { let strategy = ExecutionStrategy::Optimization { opt, ordering }; BlockOptimizationStep::Contiguous { strategy } } else { let strategy = ExecutionStrategy::Optimization { opt, ordering }; BlockOptimizationStep::WithHoles { strategy, holes } } } ExecutionStrategy::Operations { ordering } => { let min = ordering.iter().min().unwrap(); ordering_global.push(*min); let strategy = ExecutionStrategy::Operations { ordering: Arc::new(vec![*min]), }; BlockOptimizationStep::Operation { strategy } } _ => unreachable!(), } } fn update_check(&mut self, pos: usize) { self.resolved[pos] = true; for i in self.last_checked..self.resolved.len() { if self.resolved[i] { self.last_checked += 1; } else { break; } } } fn find_holes(&mut self, last: usize) -> Vec { let mut fallbacks = Vec::new(); for i in self.last_checked..last { if !self.resolved[i] { fallbacks.push(i); self.resolved[i] = true; } self.last_checked += 1; } fallbacks } /// Try to merge blocks together. fn merging_pass(mut self) -> Self { if self.blocks.len() == 1 { return self; } Block::sort(&mut self.blocks); let blocks = self.blocks.iter().collect::>(); match merge_blocks(&blocks, false) { MergeBlocksResult::Full(block) => { self.blocks = vec![block]; } MergeBlocksResult::Partial { mut merged, mut failed, } => { merged.append(&mut failed); self.blocks = merged; Block::sort(&mut self.blocks); } MergeBlocksResult::Fail => {} } self } } ================================================ FILE: crates/burn-fusion/src/search/optimization/mod.rs ================================================ mod blocks; mod stream; pub use stream::*; ================================================ FILE: crates/burn-fusion/src/search/optimization/stream.rs ================================================ use super::blocks::BlocksOptimizer; use crate::{ NumOperations, OperationFuser, search::{ Block, BlockOptimization, RegistrationResult, merging::{MergeBlocksResult, merge_blocks}, optimization::blocks::BlocksOptimizerResult, }, stream::store::ExecutionStrategy, }; use burn_ir::OperationIr; /// Optimize a stream of [operations](OperationIr) using a list of [builders](OptimizationBuilder). pub struct StreamOptimizer { builders: Vec>>, blocks: Vec>, length: usize, stopped: bool, max_blocks: Option, } impl StreamOptimizer { /// Create a new stream optimizer. pub fn new(builders: Vec>>) -> Self { Self { builders, blocks: Vec::new(), length: 0, stopped: false, // Too high and it may breaks the fusion cache always retriggering explorations. max_blocks: Some(5), } } /// Register a new [operation](OperationIr) in the optimizer. /// /// You can use the function [Self::still_optimizing] to know if the operations are actually /// being registered. pub fn register(&mut self, operation: &OperationIr) { if self.stopped { return; } if self.blocks.is_empty() { self.on_new_block(operation); self.length += 1; return; } match self.merge_blocks(operation, false) { MergeBlockStep::Full | MergeBlockStep::NoNeed => {} MergeBlockStep::Fail | MergeBlockStep::Partial => { // With the given operation, blocks are no longer independent. self.stopped = true; return; } } if let Some(max_blocks) = self.max_blocks { if self.register_max_block(operation, max_blocks) { self.length += 1; } else { self.stopped = true; } return; } let added_count = self.register_inner(operation, false); if added_count == 0 { self.on_new_block(operation); } self.length += 1; } /// Optimize the current stream on the given [operations](OperationIr). /// /// # Notes /// /// The operations provided are the same as the ones used in the [register](Self::register) /// method, this simply remove the need for the current type to also keep track of the list of /// operations. pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization { let result = BlocksOptimizer::new(self.blocks.clone()).optimize(); match result { BlocksOptimizerResult::Full(block_optimization) => block_optimization, BlocksOptimizerResult::WithHoles { mut strategies, mut ordering, mut holes, } => { loop { let mut search = self.new_empty_search(); let mut operations_holes = Vec::with_capacity(holes.len()); for index in holes.iter() { let op = &operations[*index]; operations_holes.push(op.clone()); search.register(op); } let mut optimization_of_holes = search.optimize(&operations_holes); optimization_of_holes.map_ordering(&holes); strategies.push(Box::new(optimization_of_holes.strategy)); holes.drain(0..optimization_of_holes.ordering.len()); ordering.append(&mut optimization_of_holes.ordering); if holes.is_empty() { break; } } BlockOptimization::new(ExecutionStrategy::Composed(strategies), ordering) } } } /// Reset the state of the optimizer. pub fn reset(&mut self) { self.builders.iter_mut().for_each(|b| b.reset()); self.length = 0; self.blocks.clear(); self.stopped = false; } /// Returns if some optimizations are still possible within the stream. pub fn still_optimizing(&self) -> bool { if self.stopped { return false; } if self.blocks.is_empty() { return true; } let mut num_stopped = 0; for block in self.blocks.iter() { if !block.still_optimizing() { num_stopped += 1 } } num_stopped < self.blocks.len() } fn register_max_block(&mut self, operation: &OperationIr, max_blocks: usize) -> bool { if max_blocks == 1 { // Register in the single block with a force. self.register_inner(operation, true); return true; } let added_count = self.register_inner(operation, false); if added_count > 0 { return true; } if added_count == 0 && self.blocks.len() < max_blocks { self.on_new_block(operation); return true; } self.merge_blocks(operation, true); if self.blocks.len() >= max_blocks { self.stopped = true; return false; } let added_count = self.register_inner(operation, false); if added_count == 0 { self.on_new_block(operation); } true } fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize { let mut added_count = 0; for block in self.blocks.iter_mut() { match block.register(operation, self.length, force) { RegistrationResult::Accepted => { added_count += 1; } RegistrationResult::NotPartOfTheGraph => {} } } added_count } fn new_empty_search(&self) -> Self { Self::new( self.builders .iter() .map(|b| { let mut b = b.clone_dyn(); b.reset(); b }) .collect(), ) } fn merge_blocks(&mut self, operation: &OperationIr, all: bool) -> MergeBlockStep { let nodes = operation.nodes(); let mut block_merges = Vec::new(); for (i, block) in self.blocks.iter().enumerate() { if all || block.contains_tensors(&nodes) { block_merges.push(i); } } if block_merges.len() <= 1 { return MergeBlockStep::NoNeed; } let blocks_to_merge = self .blocks .iter() .enumerate() .filter_map(|(i, g)| match block_merges.contains(&i) { true => Some(g), false => None, }) .collect::>(); let merged = merge_blocks(&blocks_to_merge, false); let mut clear_blocks = || { let mut indices = block_merges.to_vec(); indices.sort(); for g in indices.into_iter().rev() { self.blocks.remove(g); } }; match merged { MergeBlocksResult::Full(block) => { clear_blocks(); self.blocks.push(block); Block::sort(&mut self.blocks); MergeBlockStep::Full } MergeBlocksResult::Partial { mut merged, mut failed, } => { clear_blocks(); self.blocks.append(&mut merged); self.blocks.append(&mut failed); Block::sort(&mut self.blocks); MergeBlockStep::Partial } MergeBlocksResult::Fail => MergeBlockStep::Fail, } } fn on_new_block(&mut self, operation: &OperationIr) { let mut block = Block::new(&self.builders); block.register(operation, self.length, true); self.blocks.push(block); } } enum MergeBlockStep { Full, Partial, Fail, NoNeed, } ================================================ FILE: crates/burn-fusion/src/server.rs ================================================ use std::sync::Arc; use crate::{ FusionBackend, FusionRuntime, stream::{MultiStream, OperationStreams, StreamId, execution::Operation}, }; use burn_backend::{TensorData, backend::ExecutionError}; use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr}; pub struct FusionServer { streams: MultiStream, pub(crate) handles: HandleContainer, } impl FusionServer where R: FusionRuntime, { pub fn new(device: R::FusionDevice) -> Self { Self { streams: MultiStream::new(device.clone()), handles: HandleContainer::new(), } } pub fn register( &mut self, streams: OperationStreams, repr: OperationIr, operation: Arc>, ) { self.streams .register(streams, repr, operation, &mut self.handles) } pub fn drain_stream(&mut self, id: StreamId) { self.streams.drain(&mut self.handles, id) } pub fn read_float( &mut self, tensor: TensorIr, id: StreamId, ) -> impl Future> + Send + use where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); let tensor_float = self.handles.get_float_tensor::(&tensor); self.streams.mark_read(id, &tensor, &self.handles); B::float_into_data(tensor_float) } pub fn read_int( &mut self, tensor: TensorIr, id: StreamId, ) -> impl Future> + Send + use where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); let tensor_int = self.handles.get_int_tensor::(&tensor); self.streams.mark_read(id, &tensor, &self.handles); B::int_into_data(tensor_int) } pub fn read_bool( &mut self, tensor: TensorIr, id: StreamId, ) -> impl Future> + Send + use where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); let tensor_bool = self.handles.get_bool_tensor::(&tensor); self.streams.mark_read(id, &tensor, &self.handles); B::bool_into_data(tensor_bool) } pub fn read_quantized( &mut self, tensor: TensorIr, id: StreamId, ) -> impl Future> + Send + use where B: FusionBackend, { // Make sure all registered operations are executed. // The underlying backend can still be async. self.drain_stream(id); let tensor_q = self.handles.get_quantized_tensor::(&tensor); self.streams.mark_read(id, &tensor, &self.handles); B::q_into_data(tensor_q) } pub fn change_server_float( &mut self, tensor: &TensorIr, output_id: TensorId, stream_tensor: StreamId, device: &R::FusionDevice, server_device: &mut Self, ) where B: FusionBackend, { let tensor_float = self.handles.get_float_tensor::(tensor); self.streams.mark_read(stream_tensor, tensor, &self.handles); let tensor = B::float_to_device(tensor_float, device); server_device .handles .register_float_tensor::(&output_id, tensor.clone()); } pub fn resolve_server_float(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive where B: FusionBackend, { self.handles.get_float_tensor::(tensor) } pub fn resolve_server_int(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive where B: FusionBackend, { self.handles.get_int_tensor::(tensor) } pub fn resolve_server_bool(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive where B: FusionBackend, { self.handles.get_bool_tensor::(tensor) } pub fn change_server_int( &mut self, tensor: &TensorIr, output_id: TensorId, stream_tensor: StreamId, device: &R::FusionDevice, server_device: &mut Self, ) where B: FusionBackend, { let tensor_int = self.handles.get_int_tensor::(tensor); self.streams.mark_read(stream_tensor, tensor, &self.handles); let tensor = B::int_to_device(tensor_int, device); server_device .handles .register_int_tensor::(&output_id, tensor.clone()); } pub fn change_server_bool( &mut self, tensor: &TensorIr, output_id: TensorId, stream_tensor: StreamId, device: &R::FusionDevice, server_device: &mut Self, ) where B: FusionBackend, { let tensor_bool = self.handles.get_bool_tensor::(tensor); self.streams.mark_read(stream_tensor, tensor, &self.handles); let tensor = B::bool_to_device(tensor_bool, device); server_device .handles .register_bool_tensor::(&output_id, tensor.clone()); } pub fn change_server_quantized( &mut self, tensor: &TensorIr, output_id: TensorId, device: &R::FusionDevice, server_device: &mut Self, ) where B: FusionBackend, { let tensor = self.handles.get_quantized_tensor::(tensor); let tensor = B::q_to_device(tensor, device); server_device .handles .register_quantized_tensor::(&output_id, tensor); } } ================================================ FILE: crates/burn-fusion/src/stream/base.rs ================================================ pub use burn_backend::StreamId; ================================================ FILE: crates/burn-fusion/src/stream/context.rs ================================================ use burn_backend::{Shape, Slice}; use burn_ir::*; use hashbrown::HashMap; /// The context contains the relative graph tensor mapping so that a relative tensor id can be /// mapped to an existing tensor that can be fetched and updated with the /// [handle container](HandleContainer). /// /// It also contains all scalar values, which can change even for the same graph. They are sorted /// in the order in which they appear in the graph. #[allow(clippy::too_many_arguments)] #[derive(new)] pub struct Context<'a, H> { /// The tensor mapping where local tensor id points to the updated tensor representation. pub tensors: &'a mut HashMap, /// Handle container to retrieve tensors based on their representation. pub handles: &'a mut HandleContainer, /// Scalars found in the graph in the order they appeared. pub scalars: &'a mut HashMap, /// Shape mapping from relative shape ids to global (real) shape ids. pub shapes_relative2global: &'a HashMap, } #[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] /// Scalar unique identifier. pub struct ScalarId { /// The value. pub value: u64, } pub(crate) struct OperationConverter { tensors_relative2global: HashMap, tensors_global2relative: HashMap, shapes_global2relative: HashMap, shapes_relative2global: HashMap, scalars: HashMap, } impl Default for OperationConverter { fn default() -> Self { let mut val = Self { tensors_relative2global: Default::default(), tensors_global2relative: Default::default(), shapes_global2relative: Default::default(), shapes_relative2global: Default::default(), scalars: Default::default(), }; // global 1 is always shape id 0. val.shapes_global2relative.insert(1, 0); val.shapes_relative2global.insert(0, 1); val } } /// Fork of a [context](Context) which owns its data. pub struct ContextOwned { tensors: HashMap, handles: HandleContainer, scalars: HashMap, shapes_relative2global: HashMap, } impl ContextOwned { /// Convert into [context](Context). pub fn as_context(&mut self) -> Context<'_, H> { Context { tensors: &mut self.tensors, handles: &mut self.handles, scalars: &mut self.scalars, shapes_relative2global: &self.shapes_relative2global, } } /// Fork the context again. pub fn fork(&self) -> ContextOwned { ContextOwned { tensors: self.tensors.clone(), handles: self.handles.fork(), scalars: self.scalars.clone(), shapes_relative2global: self.shapes_relative2global.clone(), } } } impl Context<'_, H> { /// Fork the context into an [owned context](ContextOwned). pub fn fork(&self) -> ContextOwned { ContextOwned { tensors: self.tensors.clone(), handles: self.handles.fork(), scalars: self.scalars.clone(), shapes_relative2global: self.shapes_relative2global.clone(), } } } pub(crate) trait RelativeOps { /// Convert (usually an [`OperationIr`]) to a relative form. /// /// The id and the shape of tensors will be computed relative to existing /// operations in the queue. We do this because we want to fuse operations /// that have similar shapes, but we do not care about the exact values. /// /// Similar we do not care about the exact ids of the tensor, but about their /// relative ids (how close they are in the operation queue) fn to_relative(&self, converter: &mut OperationConverter) -> Self; } impl OperationConverter { pub(crate) fn context<'a, H>( &'a mut self, handles: &'a mut HandleContainer, ) -> Context<'a, H> { Context { handles, tensors: &mut self.tensors_relative2global, scalars: &mut self.scalars, shapes_relative2global: &self.shapes_relative2global, } } pub(crate) fn clear(&mut self) { self.tensors_relative2global.clear(); self.tensors_global2relative.clear(); self.shapes_global2relative.clear(); self.shapes_relative2global.clear(); // global 1 is always shape id 0. self.shapes_global2relative.insert(1, 0); self.shapes_relative2global.insert(0, 1); self.scalars.clear(); } } impl RelativeOps for OperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { OperationIr::BaseFloat(ops) => OperationIr::BaseFloat(ops.to_relative(converter)), OperationIr::BaseInt(ops) => OperationIr::BaseInt(ops.to_relative(converter)), OperationIr::BaseBool(ops) => OperationIr::BaseBool(ops.to_relative(converter)), OperationIr::NumericFloat(dtype, ops) => { OperationIr::NumericFloat(*dtype, ops.to_relative(converter)) } OperationIr::NumericInt(dtype, ops) => { OperationIr::NumericInt(*dtype, ops.to_relative(converter)) } OperationIr::Bool(ops) => OperationIr::Bool(ops.to_relative(converter)), OperationIr::Int(ops) => OperationIr::Int(ops.to_relative(converter)), OperationIr::Float(dtype, ops) => { OperationIr::Float(*dtype, ops.to_relative(converter)) } OperationIr::Module(ops) => OperationIr::Module(ops.to_relative(converter)), OperationIr::Custom(ops) => OperationIr::Custom(ops.to_relative(converter)), OperationIr::Init(ops) => OperationIr::Init(ops.to_relative(converter)), OperationIr::Drop(tensor) => OperationIr::Drop(tensor.to_relative(converter)), } } } impl RelativeOps for ModuleOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { ModuleOperationIr::Embedding(desc) => ModuleOperationIr::Embedding(EmbeddingOpIr { weights: desc.weights.to_relative(converter), indices: desc.indices.to_relative(converter), out: desc.out.to_relative(converter), }), ModuleOperationIr::EmbeddingBackward(desc) => { ModuleOperationIr::EmbeddingBackward(EmbeddingBackwardOpIr { weights: desc.weights.to_relative(converter), out_grad: desc.out_grad.to_relative(converter), indices: desc.indices.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv1d(desc) => ModuleOperationIr::Conv1d(Conv1dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }), ModuleOperationIr::Conv1dXBackward(desc) => { ModuleOperationIr::Conv1dXBackward(Conv1dXBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv1dWeightBackward(desc) => { ModuleOperationIr::Conv1dWeightBackward(Conv1dWeightBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv1dBiasBackward(desc) => { ModuleOperationIr::Conv1dBiasBackward(Conv1dBiasBackwardOpIr { x: desc.x.to_relative(converter), bias: desc.bias.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv2d(desc) => ModuleOperationIr::Conv2d(Conv2dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }), ModuleOperationIr::Conv2dXBackward(desc) => { ModuleOperationIr::Conv2dXBackward(Conv2dXBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv2dWeightBackward(desc) => { ModuleOperationIr::Conv2dWeightBackward(Conv2dWeightBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv2dBiasBackward(desc) => { ModuleOperationIr::Conv2dBiasBackward(Conv2dBiasBackwardOpIr { x: desc.x.to_relative(converter), bias: desc.bias.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv3d(desc) => ModuleOperationIr::Conv3d(Conv3dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }), ModuleOperationIr::Conv3dXBackward(desc) => { ModuleOperationIr::Conv3dXBackward(Conv3dXBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv3dWeightBackward(desc) => { ModuleOperationIr::Conv3dWeightBackward(Conv3dWeightBackwardOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Conv3dBiasBackward(desc) => { ModuleOperationIr::Conv3dBiasBackward(Conv3dBiasBackwardOpIr { x: desc.x.to_relative(converter), bias: desc.bias.to_relative(converter), output_grad: desc.output_grad.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::DeformableConv2d(desc) => { ModuleOperationIr::DeformableConv2d(Box::new(DeformConv2dOpIr { x: desc.x.to_relative(converter), offset: desc.offset.to_relative(converter), weight: desc.weight.to_relative(converter), mask: desc.mask.as_ref().map(|t| t.to_relative(converter)), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), })) } ModuleOperationIr::DeformableConv2dBackward(desc) => { ModuleOperationIr::DeformableConv2dBackward(Box::new(DeformConv2dBackwardOpIr { x: desc.x.to_relative(converter), offset: desc.offset.to_relative(converter), weight: desc.weight.to_relative(converter), mask: desc.mask.as_ref().map(|t| t.to_relative(converter)), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), out_grad: desc.out_grad.to_relative(converter), options: desc.options.clone(), input_grad: desc.input_grad.to_relative(converter), offset_grad: desc.offset_grad.to_relative(converter), weight_grad: desc.weight_grad.to_relative(converter), mask_grad: desc.mask_grad.as_ref().map(|t| t.to_relative(converter)), bias_grad: desc.bias_grad.as_ref().map(|t| t.to_relative(converter)), })) } ModuleOperationIr::ConvTranspose1d(desc) => { ModuleOperationIr::ConvTranspose1d(ConvTranspose1dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::ConvTranspose2d(desc) => { ModuleOperationIr::ConvTranspose2d(ConvTranspose2dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::ConvTranspose3d(desc) => { ModuleOperationIr::ConvTranspose3d(ConvTranspose3dOpIr { x: desc.x.to_relative(converter), weight: desc.weight.to_relative(converter), bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::AvgPool1d(desc) => ModuleOperationIr::AvgPool1d(AvgPool1dOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, count_include_pad: desc.count_include_pad, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }), ModuleOperationIr::AvgPool2d(desc) => ModuleOperationIr::AvgPool2d(AvgPool2dOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, count_include_pad: desc.count_include_pad, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }), ModuleOperationIr::AvgPool1dBackward(desc) => { ModuleOperationIr::AvgPool1dBackward(AvgPool1dBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, count_include_pad: desc.count_include_pad, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }) } ModuleOperationIr::AvgPool2dBackward(desc) => { ModuleOperationIr::AvgPool2dBackward(AvgPool2dBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, count_include_pad: desc.count_include_pad, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }) } ModuleOperationIr::AdaptiveAvgPool1d(desc) => { ModuleOperationIr::AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr { x: desc.x.to_relative(converter), output_size: desc.output_size, out: desc.out.to_relative(converter), }) } ModuleOperationIr::AdaptiveAvgPool2d(desc) => { ModuleOperationIr::AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr { x: desc.x.to_relative(converter), output_size: desc.output_size, out: desc.out.to_relative(converter), }) } ModuleOperationIr::AdaptiveAvgPool1dBackward(desc) => { ModuleOperationIr::AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::AdaptiveAvgPool2dBackward(desc) => { ModuleOperationIr::AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), out: desc.out.to_relative(converter), }) } ModuleOperationIr::MaxPool1d(desc) => ModuleOperationIr::MaxPool1d(MaxPool1dOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }), ModuleOperationIr::MaxPool1dWithIndices(desc) => { ModuleOperationIr::MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), out_indices: desc.out_indices.to_relative(converter), }) } ModuleOperationIr::MaxPool1dWithIndicesBackward(desc) => { ModuleOperationIr::MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), indices: desc.indices.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }) } ModuleOperationIr::MaxPool2d(desc) => ModuleOperationIr::MaxPool2d(MaxPool2dOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }), ModuleOperationIr::MaxPool2dWithIndices(desc) => { ModuleOperationIr::MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), out_indices: desc.out_indices.to_relative(converter), }) } ModuleOperationIr::MaxPool2dWithIndicesBackward(desc) => { ModuleOperationIr::MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), indices: desc.indices.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, padding: desc.padding, dilation: desc.dilation, ceil_mode: desc.ceil_mode, out: desc.out.to_relative(converter), }) } ModuleOperationIr::Interpolate(desc) => { ModuleOperationIr::Interpolate(InterpolateOpIr { x: desc.x.to_relative(converter), output_size: desc.output_size, options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::InterpolateBackward(desc) => { ModuleOperationIr::InterpolateBackward(InterpolateBackwardOpIr { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), output_size: desc.output_size, options: desc.options.clone(), out: desc.out.to_relative(converter), }) } ModuleOperationIr::Attention(desc) => ModuleOperationIr::Attention(AttentionOpIr { query: desc.query.to_relative(converter), key: desc.key.to_relative(converter), value: desc.value.to_relative(converter), mask: desc.mask.as_ref().map(|m| m.to_relative(converter)), attn_bias: desc.attn_bias.as_ref().map(|ab| ab.to_relative(converter)), options: desc.options.clone(), out: desc.out.to_relative(converter), }), } } } impl RelativeOps for FloatOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { FloatOperationIr::Exp(desc) => FloatOperationIr::Exp(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Log(desc) => FloatOperationIr::Log(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Log1p(desc) => FloatOperationIr::Log1p(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Erf(desc) => FloatOperationIr::Erf(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Powf(desc) => FloatOperationIr::Powf(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::PowfScalar(desc) => FloatOperationIr::PowfScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Sqrt(desc) => FloatOperationIr::Sqrt(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Cos(desc) => FloatOperationIr::Cos(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Sin(desc) => FloatOperationIr::Sin(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Tanh(desc) => FloatOperationIr::Tanh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Tan(desc) => FloatOperationIr::Tan(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Cosh(desc) => FloatOperationIr::Cosh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Sinh(desc) => FloatOperationIr::Sinh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcCos(desc) => FloatOperationIr::ArcCos(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcCosh(desc) => FloatOperationIr::ArcCosh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcSin(desc) => FloatOperationIr::ArcSin(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcSinh(desc) => FloatOperationIr::ArcSinh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcTan(desc) => FloatOperationIr::ArcTan(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcTanh(desc) => FloatOperationIr::ArcTanh(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::ArcTan2(desc) => FloatOperationIr::ArcTan2(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::IntoInt(desc) => FloatOperationIr::IntoInt(CastOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Matmul(desc) => FloatOperationIr::Matmul(MatmulOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Cross(desc) => FloatOperationIr::Cross(CrossOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), dim: desc.dim, }), FloatOperationIr::Random(desc) => FloatOperationIr::Random(RandomOpIr { out: desc.out.to_relative(converter), distribution: desc.distribution, }), FloatOperationIr::Recip(desc) => FloatOperationIr::Recip(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Quantize(desc) => FloatOperationIr::Quantize(QuantizeOpIr { tensor: desc.tensor.to_relative(converter), qparams: QuantizationParametersIr { scales: desc.qparams.scales.to_relative(converter), }, scheme: desc.scheme, out: desc.out.to_relative(converter), }), FloatOperationIr::Dequantize(desc) => FloatOperationIr::Dequantize(DequantizeOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Round(desc) => FloatOperationIr::Round(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Floor(desc) => FloatOperationIr::Floor(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Ceil(desc) => FloatOperationIr::Ceil(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::Trunc(desc) => FloatOperationIr::Ceil(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::IsNan(desc) => FloatOperationIr::IsNan(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::IsInf(desc) => FloatOperationIr::IsInf(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), FloatOperationIr::GridSample2d(desc) => { FloatOperationIr::GridSample2d(GridSample2dOpIr { tensor: desc.tensor.to_relative(converter), grid: desc.grid.to_relative(converter), options: desc.options.clone(), out: desc.out.to_relative(converter), }) } } } } impl RelativeOps for BoolOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { BoolOperationIr::IntoFloat(desc) => BoolOperationIr::IntoFloat(CastOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BoolOperationIr::IntoInt(desc) => BoolOperationIr::IntoInt(CastOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BoolOperationIr::Not(desc) => BoolOperationIr::Not(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BoolOperationIr::And(desc) => BoolOperationIr::And(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), BoolOperationIr::Or(desc) => BoolOperationIr::Or(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), } } } impl RelativeOps for IntOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { IntOperationIr::IntoFloat(desc) => IntOperationIr::IntoFloat(CastOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::Matmul(desc) => IntOperationIr::Matmul(MatmulOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseAnd(desc) => IntOperationIr::BitwiseAnd(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseAndScalar(desc) => { IntOperationIr::BitwiseAndScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs, out: desc.out.to_relative(converter), }) } IntOperationIr::BitwiseOr(desc) => IntOperationIr::BitwiseOr(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseOrScalar(desc) => IntOperationIr::BitwiseOrScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs, out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseXor(desc) => IntOperationIr::BitwiseXor(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseXorScalar(desc) => { IntOperationIr::BitwiseXorScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs, out: desc.out.to_relative(converter), }) } IntOperationIr::BitwiseNot(desc) => IntOperationIr::BitwiseNot(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), IntOperationIr::BitwiseLeftShift(desc) => { IntOperationIr::BitwiseLeftShift(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } IntOperationIr::BitwiseLeftShiftScalar(desc) => { IntOperationIr::BitwiseLeftShiftScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs, out: desc.out.to_relative(converter), }) } IntOperationIr::BitwiseRightShift(desc) => { IntOperationIr::BitwiseRightShift(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } IntOperationIr::BitwiseRightShiftScalar(desc) => { IntOperationIr::BitwiseRightShiftScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs, out: desc.out.to_relative(converter), }) } } } } impl RelativeOps for CustomOpIr { fn to_relative(&self, converter: &mut OperationConverter) -> CustomOpIr { let id = self.id.clone(); CustomOpIr { id, inputs: self .inputs .iter() .map(|x| x.to_relative(converter)) .collect(), outputs: self .outputs .iter() .map(|x| x.to_relative(converter)) .collect(), } } } impl RelativeOps for NumericOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { NumericOperationIr::Add(desc) => NumericOperationIr::Add(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::AddScalar(desc) => NumericOperationIr::AddScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Sub(desc) => NumericOperationIr::Sub(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::SubScalar(desc) => NumericOperationIr::SubScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Div(desc) => NumericOperationIr::Div(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::DivScalar(desc) => NumericOperationIr::DivScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Rem(desc) => NumericOperationIr::Rem(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::RemScalar(desc) => NumericOperationIr::RemScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Mul(desc) => NumericOperationIr::Mul(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::MulScalar(desc) => NumericOperationIr::MulScalar(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Abs(desc) => NumericOperationIr::Abs(UnaryOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Full(desc) => NumericOperationIr::Full(FullOpIr { out: desc.out.to_relative(converter), value: desc.value.to_relative(converter), }), NumericOperationIr::MeanDim(desc) => NumericOperationIr::MeanDim(ReduceDimOpIr { input: desc.input.to_relative(converter), axis: desc.axis, out: desc.out.to_relative(converter), }), NumericOperationIr::Mean(desc) => NumericOperationIr::Mean(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::Sum(desc) => NumericOperationIr::Sum(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::SumDim(desc) => { NumericOperationIr::SumDim(ReduceDimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, // Axis should stay the same. }) } NumericOperationIr::Prod(desc) => NumericOperationIr::Prod(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::ProdDim(desc) => NumericOperationIr::ProdDim(ReduceDimOpIr { input: desc.input.to_relative(converter), axis: desc.axis, out: desc.out.to_relative(converter), }), NumericOperationIr::Greater(desc) => NumericOperationIr::Greater(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::GreaterElem(desc) => NumericOperationIr::GreaterElem(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::GreaterEqual(desc) => { NumericOperationIr::GreaterEqual(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } NumericOperationIr::GreaterEqualElem(desc) => { NumericOperationIr::GreaterEqualElem(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } NumericOperationIr::Lower(desc) => NumericOperationIr::Lower(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::LowerElem(desc) => NumericOperationIr::LowerElem(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::LowerEqual(desc) => NumericOperationIr::LowerEqual(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::LowerEqualElem(desc) => { NumericOperationIr::LowerEqualElem(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } NumericOperationIr::ArgMax(desc) => NumericOperationIr::ArgMax(ReduceDimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, // Axis should stay the same. }), NumericOperationIr::ArgMin(desc) => NumericOperationIr::ArgMin(ReduceDimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, // Axis should stay the same. }), NumericOperationIr::Max(desc) => NumericOperationIr::Max(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::MaxDimWithIndices(desc) => { NumericOperationIr::MaxDimWithIndices(ReduceDimWithIndicesOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, out: desc.out.to_relative(converter), out_indices: desc.out_indices.to_relative(converter), }) } NumericOperationIr::MinDimWithIndices(desc) => { NumericOperationIr::MinDimWithIndices(ReduceDimWithIndicesOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, out: desc.out.to_relative(converter), out_indices: desc.out_indices.to_relative(converter), }) } NumericOperationIr::Min(desc) => NumericOperationIr::Min(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::MaxDim(desc) => NumericOperationIr::MaxDim(ReduceDimOpIr { input: desc.input.to_relative(converter), axis: desc.axis, out: desc.out.to_relative(converter), }), NumericOperationIr::MinDim(desc) => NumericOperationIr::MinDim(ReduceDimOpIr { input: desc.input.to_relative(converter), axis: desc.axis, out: desc.out.to_relative(converter), }), NumericOperationIr::MaxAbs(desc) => NumericOperationIr::MaxAbs(ReduceOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::MaxAbsDim(desc) => NumericOperationIr::MaxAbsDim(ReduceDimOpIr { input: desc.input.to_relative(converter), axis: desc.axis, out: desc.out.to_relative(converter), }), NumericOperationIr::Clamp(desc) => NumericOperationIr::Clamp(ClampOpIr { tensor: desc.tensor.to_relative(converter), min: desc.min.to_relative(converter), max: desc.max.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::IntRandom(desc) => NumericOperationIr::IntRandom(RandomOpIr { out: desc.out.to_relative(converter), distribution: desc.distribution, }), NumericOperationIr::Powi(desc) => NumericOperationIr::Powi(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), NumericOperationIr::CumSum(desc) => NumericOperationIr::CumSum(DimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, }), NumericOperationIr::CumProd(desc) => NumericOperationIr::CumProd(DimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, }), NumericOperationIr::CumMin(desc) => NumericOperationIr::CumMin(DimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, }), NumericOperationIr::CumMax(desc) => NumericOperationIr::CumMax(DimOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axis: desc.axis, }), } } } impl RelativeOps for BaseOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { BaseOperationIr::Reshape(desc) => BaseOperationIr::Reshape(ShapeOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::SwapDims(desc) => BaseOperationIr::SwapDims(SwapDimsOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), dim1: desc.dim1, dim2: desc.dim2, }), BaseOperationIr::Permute(desc) => BaseOperationIr::Permute(PermuteOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axes: desc.axes.clone(), }), BaseOperationIr::Expand(desc) => BaseOperationIr::Expand(ShapeOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::Unfold(desc) => BaseOperationIr::Unfold(UnfoldOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), dim: desc.dim, size: desc.size, step: desc.step, }), BaseOperationIr::Flip(desc) => BaseOperationIr::Flip(FlipOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), axes: desc.axes.clone(), }), BaseOperationIr::Slice(desc) => BaseOperationIr::Slice(SliceOpIr { tensor: desc.tensor.to_relative(converter), ranges: desc.ranges.iter().map(|_info| Slice::from(0..1)).collect(), out: desc.out.to_relative(converter), }), BaseOperationIr::SliceAssign(desc) => BaseOperationIr::SliceAssign(SliceAssignOpIr { tensor: desc.tensor.to_relative(converter), ranges: desc.ranges.iter().map(|_range| Slice::from(0..1)).collect(), value: desc.value.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::Gather(desc) => BaseOperationIr::Gather(GatherOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, indices: desc.indices.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::Scatter(desc) => BaseOperationIr::Scatter(ScatterOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, indices: desc.indices.to_relative(converter), value: desc.value.to_relative(converter), update: desc.update, out: desc.out.to_relative(converter), }), BaseOperationIr::Select(desc) => BaseOperationIr::Select(SelectOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, indices: desc.indices.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::SelectAssign(desc) => { BaseOperationIr::SelectAssign(SelectAssignOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, indices: desc.indices.to_relative(converter), value: desc.value.to_relative(converter), update: desc.update, out: desc.out.to_relative(converter), }) } BaseOperationIr::MaskWhere(desc) => BaseOperationIr::MaskWhere(MaskWhereOpIr { tensor: desc.tensor.to_relative(converter), mask: desc.mask.to_relative(converter), value: desc.value.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::MaskFill(desc) => BaseOperationIr::MaskFill(MaskFillOpIr { tensor: desc.tensor.to_relative(converter), mask: desc.mask.to_relative(converter), value: desc.value.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::Equal(desc) => BaseOperationIr::Equal(BinaryOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::EqualElem(desc) => BaseOperationIr::EqualElem(ScalarOpIr { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::RepeatDim(desc) => BaseOperationIr::RepeatDim(RepeatDimOpIr { tensor: desc.tensor.to_relative(converter), dim: desc.dim, times: desc.times, out: desc.out.to_relative(converter), }), BaseOperationIr::Cat(desc) => BaseOperationIr::Cat(CatOpIr { tensors: desc .tensors .iter() .map(|tensor| tensor.to_relative(converter)) .collect(), dim: desc.dim, out: desc.out.to_relative(converter), }), BaseOperationIr::Cast(desc) => BaseOperationIr::Cast(CastOpIr { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }), BaseOperationIr::Empty(desc) => BaseOperationIr::Empty(desc.to_relative(converter)), BaseOperationIr::Ones(desc) => BaseOperationIr::Ones(desc.to_relative(converter)), BaseOperationIr::Zeros(desc) => BaseOperationIr::Zeros(desc.to_relative(converter)), } } } impl RelativeOps for InitOperationIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { Self { out: self.out.to_relative(converter), } } } impl RelativeOps for CreationOpIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { Self { out: self.out.to_relative(converter), } } } impl RelativeOps for TensorIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { let relative_id = self.id.to_relative(converter); // We can create relative shapes by mapping each shape found to an ID, which is a `usize`. let mut relative_shape = Vec::with_capacity(self.shape.rank()); for dim in self.shape.iter() { if let Some(dim_id) = converter.shapes_global2relative.get(dim) { // We already saw that dim value before, so we retrieve its ID. relative_shape.push(*dim_id); } else { // We never saw this dim value before, therefore we create a new ID. let dim_id = converter.shapes_global2relative.len(); relative_shape.push(dim_id); converter.shapes_global2relative.insert(*dim, dim_id); converter.shapes_relative2global.insert(dim_id, *dim); } } // We create the relative tensor. let relative_tensor = TensorIr { id: relative_id, shape: Shape::from(relative_shape), status: self.status, dtype: self.dtype, }; // We update both mappings. converter .tensors_relative2global .insert(relative_id, self.clone()); converter .tensors_global2relative .insert(self.id, relative_tensor.clone()); relative_tensor } } impl RelativeOps for TensorId { fn to_relative(&self, converter: &mut OperationConverter) -> Self { if let Some(value) = converter.tensors_global2relative.get(self) { // If we already have the same tensor registered, we have to update its value, but not // its id. value.id } else { // We create a new relative id since we never seen this tensor in the graph before. TensorId::new(converter.tensors_relative2global.len() as u64) } } } impl RelativeOps for ScalarIr { fn to_relative(&self, converter: &mut OperationConverter) -> Self { if matches!(self, ScalarIr::Bool(_)) { todo!("Unsupported dtype ({self:?}) for scalar") } let id = ScalarId { value: converter.scalars.len() as u64, }; converter.scalars.insert(id, *self); ScalarIr::UInt(id.value) } } #[cfg(test)] mod tests { use super::*; use burn_backend::DType; use burn_ir::{TensorId, TensorIr, TensorStatus}; #[test] fn tensor_description_to_relative() { let tensor1 = TensorIr { id: TensorId::new(500), shape: Shape::new([512, 32, 2048]), status: TensorStatus::ReadOnly, dtype: DType::F32, }; let tensor2 = TensorIr { id: TensorId::new(501), shape: Shape::new([512, 128, 2048]), status: TensorStatus::ReadOnly, dtype: DType::F32, }; let mut converter = OperationConverter::default(); let tensor1_local = tensor1.to_relative(&mut converter); let tensor2_local = tensor2.to_relative(&mut converter); assert_eq!( tensor1_local, TensorIr { id: TensorId::new(0), shape: Shape::new([1, 2, 3]), status: TensorStatus::ReadOnly, dtype: DType::F32 } ); assert_eq!( tensor2_local, TensorIr { id: TensorId::new(1), shape: Shape::new([1, 4, 3]), status: TensorStatus::ReadOnly, dtype: DType::F32 } ); } #[test] fn scalar_ir_to_relative() { let scalar1 = ScalarIr::Float(1.0); let scalar2 = ScalarIr::UInt(1); let mut converter = OperationConverter::default(); let scalar1_local = scalar1.to_relative(&mut converter); let scalar2_local = scalar2.to_relative(&mut converter); assert_eq!(scalar1_local, ScalarIr::UInt(0)); assert_eq!(scalar2_local, ScalarIr::UInt(1)); } } ================================================ FILE: crates/burn-fusion/src/stream/execution/base.rs ================================================ use burn_ir::HandleContainer; use crate::FusionRuntime; /// The mode in which the execution is done. #[derive(Clone, Copy, Debug)] pub(crate) enum ExecutionMode { Lazy, Sync, } /// General trait to abstract how a single operation is executed. pub trait Operation: Send + Sync + core::fmt::Debug { /// Execute the operation. fn execute(&self, handles: &mut HandleContainer); } ================================================ FILE: crates/burn-fusion/src/stream/execution/explorer.rs ================================================ use burn_ir::OperationIr; use super::ExecutionMode; use crate::{ NumOperations, OperationFuser, search::{BlockOptimization, StreamOptimizer}, }; /// Explore and create new optimization. pub struct Explorer { optimizer: StreamOptimizer, num_deferred: usize, num_explored: usize, is_still_optimizing: bool, } /// The result of an exploration done by the [explorer](Explorer). pub enum ExplorationAction { /// Found a new optimization. Completed(BlockOptimization), /// We should continue exploring before arriving at a conclusion. Continue, } impl Explorer { /// Create a new explorer. pub(crate) fn new(optimizations: Vec>>) -> Self { Self { optimizer: StreamOptimizer::new(optimizations), num_deferred: 0, num_explored: 0, is_still_optimizing: true, } } /// Indicate that a new operation is added. pub(crate) fn on_new_operation(&mut self) { self.num_deferred += 1; } /// If the explorer is up to date. pub(crate) fn is_up_to_date(&self) -> bool { self.num_deferred == 0 } /// Explore the provided operations. pub(crate) fn explore( &mut self, operations: &[OperationIr], mode: ExecutionMode, ) -> ExplorationAction { self.update(operations); // Can only continue exploration when not sync. if let ExecutionMode::Lazy = mode && self.is_still_optimizing { return ExplorationAction::Continue; } let optimization = self.optimizer.optimize(operations); ExplorationAction::Completed(optimization) } /// Reset the state of the explorer to the provided list of operations. pub(crate) fn reset(&mut self, operations: &[OperationIr]) { self.optimizer.reset(); self.num_explored = 0; self.num_deferred = operations.len(); self.is_still_optimizing = true; } /// Register any operations that we had deferred fn update(&mut self, operations: &[OperationIr]) { for i in (0..self.num_deferred).rev() { if !self.is_still_optimizing { break; } let index = operations.len() - 1 - i; let relative = &operations[index]; self.optimizer.register(relative); self.num_explored += 1; self.is_still_optimizing = self.optimizer.still_optimizing(); } self.num_deferred = 0; } } ================================================ FILE: crates/burn-fusion/src/stream/execution/mod.rs ================================================ pub(crate) mod validator; mod base; mod explorer; mod ordering; mod policy; mod processor; pub use base::*; pub use ordering::*; pub(crate) use explorer::*; pub(crate) use policy::*; pub(crate) use processor::*; #[cfg(test)] pub(crate) mod tests; ================================================ FILE: crates/burn-fusion/src/stream/execution/ordering.rs ================================================ use std::sync::Arc; use burn_ir::HandleContainer; use crate::{FusionRuntime, NumOperations, Optimization, stream::Context}; use super::Operation; /// Manage the execution of potentially multiple optimizations and operations out of order. pub struct OrderedExecution { operations: Vec>>, num_executed: usize, ordering: Option>>, } impl OrderedExecution { /// Returns the operation that can be executed without impacting the state of the execution. /// /// This is useful to implement fallback for optimizations. #[allow(clippy::borrowed_box)] pub fn operation_within_optimization(&self, index: usize) -> Arc> { match &self.ordering { Some(val) => { let index = val[index]; self.operations[index].clone() } None => panic!("No ordering provided"), } } pub(crate) fn new(operations: Vec>>) -> Self { Self { operations, num_executed: 0, ordering: None, } } pub(crate) fn finish(mut self) -> (Vec>>, usize) { self.operations.drain(0..self.num_executed); (self.operations, self.num_executed) } pub(crate) fn execute_optimization( &mut self, optimization: &mut R::Optimization, context: &mut Context<'_, R::FusionHandle>, ordering: Arc>, ) { if ordering.len() > self.operations.len() { panic!("Ordering is bigger than operations"); } self.ordering = Some(ordering); let num_drained = optimization.len(); optimization.execute(context, self); self.num_executed += num_drained; } pub(crate) fn execute_operations( &mut self, handles: &mut HandleContainer, ordering: &[usize], ) { self.num_executed += ordering.len(); for id in ordering { let op = &self.operations[*id]; op.execute(handles); } } } ================================================ FILE: crates/burn-fusion/src/stream/execution/policy.rs ================================================ use burn_ir::OperationIr; use super::ExecutionMode; use super::validator::{ ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator, ValidatorState, }; use crate::stream::execution::validator::OperationsValidator; use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery}; use std::marker::PhantomData; /// The policy keeps track of all possible execution plans for the current operations. /// /// # Details /// /// We keep track of each new operation added and invalidate potential execution plans /// when we see a different operation is added. /// /// Therefore, the overhead is very minimal, since the time-complexity of checking for existing /// execution plans scales with the number of concurrent potential plans for the current operations, /// which isn't supposed to be big at any time. pub(crate) struct Policy { /// List of potential execution plans that are compatible with current stream segment candidates: Vec>, /// List of candidate execution plans that have been found; we can still keep searching /// to potentially find a better one. availables: Vec, /// The found execution plan that should be executed, along with the number of operations /// in the plan. found: Option<(ExecutionPlanId, usize)>, /// The number of operations that have been analyzed num_operations: usize, _item_type: PhantomData, } #[derive(new)] struct AvailableItem { id: ExecutionPlanId, size: usize, triggers: Vec, } /// Action to be made depending on the stream. #[derive(PartialEq, Eq, Debug)] pub enum Action { /// Continue exploring using the [builder](crate::OptimizationBuilder). Explore, /// The current policy indicates that an exploration may be possible in the future, so the /// best action is to defer any execution. /// /// Sometimes, it can be a false positive and a new exploration should be built from scratch. /// Therefore it's important to keep the previous operations to rebuild the state if it /// happens. Defer, /// An exploration has been found, and the best action is to execute it! Execute(ExecutionPlanId), } impl Policy { /// Create a new policy. pub(crate) fn new() -> Self { Self { candidates: Vec::new(), availables: Vec::new(), found: None, num_operations: 0, _item_type: PhantomData, } } /// Returns the [action](Action) that should be taken given the state of the policy. pub fn action( &self, store: &ExecutionPlanStore, operations: &[OperationIr], mode: ExecutionMode, ) -> Action { if self.num_operations < operations.len() { panic!( "Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed." ); } if let Some((id, _length)) = self.found { return Action::Execute(id); } match mode { ExecutionMode::Lazy => self.action_lazy(operations), ExecutionMode::Sync => self.action_sync(operations, store), } } /// Update the policy state. pub fn update(&mut self, store: &ExecutionPlanStore, operation: &OperationIr) { // reset the candidates to contain all execution plans starting with the operation. if self.num_operations == 0 { self.candidates = store .find(SearchQuery::PlansStartingWith(operation)) .into_iter() .map(OperationsValidator::new) .collect(); } self.update_candidates(store, operation); self.check_candidates(store); self.update_availables(store, operation); self.check_availables(); self.num_operations += 1; } // Reset the state of the policy. pub fn reset(&mut self) { self.candidates.clear(); self.availables.clear(); self.num_operations = 0; self.found = None; } /// Check which candidates can be removed, and which one can go from /// 'candidate' to 'available' fn check_candidates(&mut self, store: &ExecutionPlanStore) { let mut candidates_to_remove = Vec::new(); for candidate in self.candidates.iter() { match candidate.state { ValidatorState::Found { size } => { let item = store.get_unchecked(candidate.id); let mut triggers = Vec::with_capacity(item.triggers.len()); for (index, trigger) in item.triggers.iter().enumerate() { triggers.push(match trigger { ExecutionTrigger::OnOperations(_) => TriggerValidator::OnOperations { matching: OperationsValidator::new(index), progress: TriggerProgress::NotInit, }, ExecutionTrigger::OnSync => TriggerValidator::OnSync, ExecutionTrigger::Always => TriggerValidator::Always, }); } self.availables .push(AvailableItem::new(candidate.id, size, triggers)); candidates_to_remove.push(candidate.id); } ValidatorState::Invalidated => { candidates_to_remove.push(candidate.id); } ValidatorState::Validating => {} }; } let mut updated_candidates = Vec::new(); core::mem::swap(&mut updated_candidates, &mut self.candidates); self.candidates = updated_candidates .into_iter() .filter(|candidate| !candidates_to_remove.iter().any(|id| id == &candidate.id)) .collect(); } fn check_availables(&mut self) { for available in self.availables.iter() { for trigger in available.triggers.iter() { match trigger { TriggerValidator::OnOperations { matching, progress: _, } => { if let ValidatorState::Found { size: _size_of_trigger, } = matching.state { self.found = Some((available.id, available.size)); return; } } TriggerValidator::Always => { self.found = Some((available.id, available.size)); return; } TriggerValidator::OnSync => { // Does nothing during an update. } } } } } fn update_candidates(&mut self, store: &ExecutionPlanStore, operation: &OperationIr) { let main_store = ExecutionPlanOperationsStore::new(store); self.candidates .iter_mut() .for_each(|candidate| candidate.update(operation, self.num_operations, &main_store)); } fn update_availables(&mut self, store: &ExecutionPlanStore, operation: &OperationIr) { self.availables.iter_mut().for_each(|available| { let store_trigger = TriggerOperationsStore::new(available.id, store); available.triggers.iter_mut().for_each(|trigger| { if let TriggerValidator::OnOperations { matching, progress } = trigger { match progress { TriggerProgress::NotInit => { *progress = TriggerProgress::NumChecked(0); } TriggerProgress::NumChecked(num_check) => { matching.update(operation, *num_check, &store_trigger); *num_check += 1; } } } }); }); } fn action_lazy(&self, operations: &[OperationIr]) -> Action { if !self.candidates.is_empty() { return Action::Defer; } for available in self.availables.iter() { if available.size == operations.len() { return Action::Defer; } for trigger in available.triggers.iter() { if let TriggerValidator::OnOperations { matching, progress: _, } = trigger && let ValidatorState::Validating = matching.state { return Action::Defer; } } } Action::Explore } fn action_sync(&self, operations: &[OperationIr], store: &ExecutionPlanStore) -> Action { for available in self.availables.iter() { if available.size == operations.len() { return Action::Execute(available.id); } } for candidate in self.candidates.iter() { let item = store.get_unchecked(candidate.id); if item.operations.len() == operations.len() { return Action::Execute(candidate.id); } } Action::Explore } } #[cfg(test)] mod tests { use burn_backend::{DType, Shape}; use burn_ir::{FloatOperationIr, TensorId, TensorIr, TensorStatus, UnaryOpIr}; use super::*; use crate::{ search::BlockOptimization, stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger}, }; use std::ops::Range; #[test] fn given_no_optimization_should_explore() { let store = ExecutionPlanStore::default(); let mut policy = Policy::new(); let stream = TestStream::new(3); stream.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(0..3), Action::Explore, ); } #[test] fn given_existing_optimizations_when_sync_should_execute_one_when_available() { let mut store = ExecutionPlanStore::default(); let mut policy = Policy::new(); let stream = TestStream::new(3); let id_1 = store.add(ExecutionPlan { operations: stream.operations[0..2].to_vec(), triggers: Vec::new(), optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()), }); let _id_2 = store.add(ExecutionPlan { operations: stream.operations[0..3].to_vec(), triggers: Vec::new(), optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()), }); stream.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(0..2), Action::Defer, ); let action = policy.action(&store, &stream.operations[0..2], ExecutionMode::Sync); assert_eq!(action, Action::Execute(id_1)); } #[test] fn given_existing_plan_when_found_trigger_should_execute_plan() { let mut store = ExecutionPlanStore::default(); let mut policy = Policy::new(); let stream = TestStream::new(3); let id = store.add(ExecutionPlan { operations: stream.operations[0..2].to_vec(), triggers: stream.operations[2..3] .iter() .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()])) .collect(), optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()), }); stream.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(0..2), Action::Defer, ); stream.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(2..3), Action::Execute(id), ); } #[test] fn should_support_multiple_triggers() { let mut store = ExecutionPlanStore::default(); let mut policy_1 = Policy::new(); let mut policy_2 = Policy::new(); let mut stream_1 = TestStream::new(2); let mut stream_2 = TestStream::new(2); // Create different end operation for each stream. let trigger_id_1 = 5; let trigger_id_2 = 6; stream_1.new_ops(trigger_id_1); stream_2.new_ops(trigger_id_2); let id = store.add(ExecutionPlan { operations: stream_1.operations[0..2].to_vec(), triggers: vec![ ExecutionTrigger::OnOperations(vec![stream_1.operations[2].clone()]), ExecutionTrigger::OnOperations(vec![stream_2.operations[2].clone()]), ], optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()), }); stream_1.assert_updates( &store, &mut policy_1, AssertUpdatesOptions::OperationsIndex(0..2), Action::Defer, ); stream_2.assert_updates( &store, &mut policy_2, AssertUpdatesOptions::OperationsIndex(0..2), Action::Defer, ); stream_1.assert_updates( &store, &mut policy_1, AssertUpdatesOptions::OperationsIndex(2..3), // First trigger. Action::Execute(id), ); stream_2.assert_updates( &store, &mut policy_2, AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger. Action::Execute(id), ); } #[test] fn should_select_right_optimization() { let mut store = ExecutionPlanStore::default(); let mut policy_1 = Policy::new(); let mut policy_2 = Policy::new(); let mut stream_1 = TestStream::new(2); let mut stream_2 = TestStream::new(2); // Create different streams after op 2. stream_1.new_ops(4); stream_1.new_ops(5); stream_2.new_ops(5); stream_2.new_ops(6); let optimization_stream_1 = store.add(ExecutionPlan { operations: stream_1.operations[0..3].to_vec(), triggers: stream_1.operations[3..4] .iter() .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()])) .collect(), optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()), }); let optimization_stream_2 = store.add(ExecutionPlan { operations: stream_2.operations[0..3].to_vec(), triggers: stream_2.operations[3..4] .iter() .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()])) .collect(), optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()), }); assert_ne!(optimization_stream_1, optimization_stream_2); stream_1.assert_updates( &store, &mut policy_1, AssertUpdatesOptions::OperationsIndex(0..3), Action::Defer, ); stream_2.assert_updates( &store, &mut policy_2, AssertUpdatesOptions::OperationsIndex(0..3), Action::Defer, ); stream_1.assert_updates( &store, &mut policy_1, AssertUpdatesOptions::OperationsIndex(3..4), Action::Execute(optimization_stream_1), ); stream_2.assert_updates( &store, &mut policy_2, AssertUpdatesOptions::OperationsIndex(3..4), Action::Execute(optimization_stream_2), ); } #[test] fn should_invalidate_wrong_optimizations() { let mut store = ExecutionPlanStore::default(); let stream_1 = TestStream::new(4); let mut stream_2 = TestStream::new(2); stream_2.new_ops(6); stream_2.new_ops(7); store.add(ExecutionPlan { operations: stream_1.operations[0..3].to_vec(), triggers: stream_1.operations[3..4] .iter() .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()])) .collect(), optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()), }); let mut policy = Policy::new(); // Same path as stream 1 stream_2.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(0..2), Action::Defer, ); // But is different. stream_2.assert_updates( &store, &mut policy, AssertUpdatesOptions::OperationsIndex(2..4), Action::Explore, ); } #[derive(Default, Debug)] struct TestStream { tensors: Vec, operations: Vec, } #[derive(Debug)] enum AssertUpdatesOptions { OperationsIndex(Range), } impl TestStream { /// Create a new test stream with `num_ops` operations registered. pub fn new(num_ops: usize) -> Self { let mut stream = Self::default(); for id in 0..num_ops { stream.new_ops(id as u64 + 1); } stream } /// The first follow should only be cache miss. pub fn assert_updates( &self, optimizations: &ExecutionPlanStore<()>, policy: &mut Policy<()>, options: AssertUpdatesOptions, action: Action, ) { match options { AssertUpdatesOptions::OperationsIndex(range) => { for i in range { let stream = &self.operations[0..i]; let next_ops = &self.operations[i]; policy.update(optimizations, next_ops); let result = policy.action(optimizations, stream, ExecutionMode::Lazy); assert_eq!(result, action); } } } } /// Add a simple operation to the stream. pub fn new_ops(&mut self, out_id: u64) { if self.tensors.is_empty() { // Root node. self.new_empty_node(0); } // Out node. self.new_empty_node(out_id); self.operations.push(OperationIr::Float( DType::F32, FloatOperationIr::Log(self.unary_description()), )); } fn new_empty_node(&mut self, id: u64) { self.tensors.push(TensorIr { id: TensorId::new(id), shape: Shape::new([32, 32, 1]), status: TensorStatus::NotInit, dtype: DType::F32, }); } fn unary_description(&self) -> UnaryOpIr { let size = self.tensors.len(); UnaryOpIr { input: self.tensors[size - 2].clone(), out: self.tensors[size - 1].clone(), } } } } ================================================ FILE: crates/burn-fusion/src/stream/execution/processor.rs ================================================ use burn_ir::OperationIr; use super::{ExecutionMode, ExplorationAction, Explorer}; use crate::search::BlockOptimization; use crate::stream::execution::{Action, Policy}; use crate::stream::store::{ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger}; use crate::{NumOperations, OperationFuser}; /// Process a [stream segment](StreamSegment) following a [policy](Policy). pub(crate) struct Processor { policy: Policy, explorer: Explorer, } /// A part of a stream that can be executed partially using [execution plan](ExecutionPlan). pub(crate) trait StreamSegment { /// The operations in the segment. fn operations(&self) -> &[OperationIr]; /// Execute part of the segment using the given plan id. fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore); } impl Processor { /// Create a new stream processor. pub fn new(optimizations: Vec>>) -> Self { Self { policy: Policy::new(), explorer: Explorer::new(optimizations), } } /// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode). pub fn process( &mut self, mut segment: Segment, store: &mut ExecutionPlanStore, mode: ExecutionMode, ) where Segment: StreamSegment, { // We assume that we always register a new operation in lazy mode. if let ExecutionMode::Lazy = mode { self.on_new_operation(&segment, store); } loop { if segment.operations().is_empty() { break; } let action = self.policy.action(store, segment.operations(), mode); match action { Action::Explore => { self.explore(&mut segment, store, mode); if self.explorer.is_up_to_date() { break; } } Action::Defer => { match mode { ExecutionMode::Lazy => break, ExecutionMode::Sync => panic!("Can't defer while sync"), }; } Action::Execute(id) => { if let ExecutionMode::Sync = mode { store.add_trigger(id, ExecutionTrigger::OnSync); } segment.execute(id, store); self.reset(store, segment.operations()); } }; } } fn on_new_operation(&mut self, segment: &Segment, store: &mut ExecutionPlanStore) where Segment: StreamSegment, { self.policy.update( store, segment .operations() .last() .expect("At least one operation in the operation list."), ); self.explorer.on_new_operation(); } fn explore>( &mut self, item: &mut Item, store: &mut ExecutionPlanStore, mode: ExecutionMode, ) { match self.explorer.explore(item.operations(), mode) { ExplorationAction::Completed(optim) => { let id = Self::on_exploration_completed( &self.policy, item.operations(), store, optim, mode, ); item.execute(id, store); self.reset(store, item.operations()); } ExplorationAction::Continue => { if let ExecutionMode::Sync = mode { panic!("Can't continue exploring when sync.") } } } } fn reset(&mut self, store: &mut ExecutionPlanStore, operations: &[OperationIr]) { self.explorer.reset(operations); self.policy.reset(); // Reset the policy state with the remaining operations for operation in operations.iter() { self.policy.update(store, operation); } } /// We found an optimization (i.e. a new execution plan). /// Cache it in the store. fn on_exploration_completed( policy: &Policy, operations: &[OperationIr], store: &mut ExecutionPlanStore, optimization: BlockOptimization, mode: ExecutionMode, ) -> ExecutionPlanId { let num_optimized = optimization.ordering.len(); let relative = &operations[0..num_optimized]; match mode { ExecutionMode::Lazy => { let next_ops = &operations[num_optimized..operations.len()]; let trigger = if next_ops.is_empty() { // Happens if the next ops is included in the fused operation, and there is no // way the builder can still continue fusing. ExecutionTrigger::Always } else { ExecutionTrigger::OnOperations(next_ops.to_vec()) }; match policy.action(store, relative, ExecutionMode::Sync) { Action::Execute(id) => { store.add_trigger(id, trigger); id } _ => { let plan = ExecutionPlan { operations: relative.to_vec(), triggers: vec![trigger], optimization, }; store.add(plan) } } } ExecutionMode::Sync => match policy.action(store, relative, ExecutionMode::Sync) { Action::Execute(id) => { store.add_trigger(id, ExecutionTrigger::OnSync); id } _ => { let plan = ExecutionPlan { operations: relative.to_vec(), triggers: vec![ExecutionTrigger::OnSync], optimization, }; store.add(plan) } }, } } } ================================================ FILE: crates/burn-fusion/src/stream/execution/tests.rs ================================================ //! A testing module that ensures the correctness of the explorer, policy, and processor. //! //! The primary focus is on validating the seamless interaction between these three components to //! execute and optimize a stream of operations accurately. //! //! To test these components effectively, we create mock types for the stream, optimization, //! optimization builder, and stream segment. These mock types aid in comprehensively //! understanding the process of optimizing streams. use std::sync::Arc; use burn_backend::{DType, Shape}; use burn_ir::{ BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarIr, ScalarOpIr, TensorId, TensorIr, TensorStatus, UnaryOpIr, }; use crate::{ FuserProperties, FuserStatus, NumOperations, OperationFuser, search::BlockOptimization, stream::store::{ ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger, }, }; use super::*; /// A fake stream of operations for testing purpose. pub struct TestStream { processor: Processor, store: ExecutionPlanStore, executed: Vec, operations: Vec, } /// A fake [optimization builder](OptimizationBuilder) for testing purpose. /// /// The optimizer tries to fuse only the `expected_operations` if they appear /// in the operations queue #[derive(Clone)] pub struct TestOptimizationBuilder { builder_id: usize, expected_operations: Vec, actual: Vec, } /// A fake optimization for testing purpose. #[derive(new, Debug, PartialEq)] pub struct TestOptimization { builder_id: usize, size: usize, } impl NumOperations for TestOptimization { fn len(&self) -> usize { self.size } } /// A fake [stream segment](StreamSegment) for testing purpose. #[derive(new)] pub struct TestSegment<'i> { operations: &'i mut Vec, executed: &'i mut Vec, } impl ExecutionStrategy { /// Create an ordered execution strategy with the given size. pub fn operations(size: usize) -> Self { Self::Operations { ordering: Arc::new((0..size).collect()), } } } impl ExecutionStrategy { /// Only use it for testing, to easily create ordered strategies. pub fn optimization(opt: TestOptimization) -> Self { let ordering = Arc::new((0..opt.size).collect()); Self::Optimization { opt, ordering } } } /// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions. /// /// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is /// crucial to verify that the stream's state is correctly updated when various cases occur consecutively. #[test] fn should_support_complex_stream() { // We have 2 different optimization builders in this test case. let builder_id_1 = 0; let builder_id_2 = 1; // We will have a total of 3 execution plans to execute. let plan_id_1 = 0; let plan_id_2 = 1; let plan_id_3 = 2; let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]); let builder_2 = TestOptimizationBuilder::new(builder_id_2, vec![operation_2(), operation_2()]); let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]); // builder_1 is still waiting to see next op is operation_2 // builder_2 is closed because it's not the right operation stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(0); // No optimization found for the first two operations. stream.add(operation_1()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(1); stream.assert_last_executed(plan_id_1); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_1(), operation_1()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()), }, ); // Nothing to execute. stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(1); // Now we should trigger the first optimization builder. stream.add(operation_2()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(2); stream.assert_last_executed(plan_id_2); stream.assert_plan( plan_id_2, ExecutionPlan { operations: vec![operation_1(), operation_2()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization::new( ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)), vec![0, 1], ), }, ); // Nothing to execute. stream.add(operation_2()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(2); // Now we should trigger the second optimization builder. stream.add(operation_2()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(3); stream.assert_last_executed(plan_id_3); stream.assert_plan( plan_id_3, ExecutionPlan { operations: vec![operation_2(), operation_2()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_2, 2)), ordering: vec![0, 1], }, }, ); // Nothing to execute. stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(3); // Now we should trigger the first optimization builder (second plan). stream.add(operation_2()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(4); stream.assert_last_executed(plan_id_2); stream.assert_plan( plan_id_2, ExecutionPlan { operations: vec![operation_1(), operation_2()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)), ordering: vec![0, 1], }, }, ); // Nothing to execute. stream.add(operation_2()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(4); // Now we should trigger the first optimization builder (third plan). stream.add(operation_2()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(5); stream.assert_last_executed(plan_id_3); } /// In this scenario we will never use an optimization, but we check that we reuse the execution plan stored. #[test] fn should_reuse_basic_operations() { let builder_id_1 = 0; let plan_id_1 = 0; let plan_id_2 = 1; let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]); let mut stream = TestStream::new(vec![Box::new(builder_1)]); stream.add(operation_3()); stream.assert_last_executed(plan_id_1); stream.assert_number_of_operations(0); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_3()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(1), ordering: vec![0], }, }, ); stream.add(operation_3()); stream.assert_last_executed(plan_id_1); stream.assert_number_of_operations(0); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_3()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(1), ordering: vec![0], }, }, ); // Lazy try to build optimization 1. stream.add(operation_1()); // But not possible. stream.add(operation_3()); // Creates a new plan with both operations. stream.assert_plan( plan_id_2, ExecutionPlan { operations: vec![operation_1(), operation_3()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(2), ordering: vec![0], }, }, ); stream.assert_number_of_operations(0); stream.assert_last_executed(plan_id_2); } // In this scenario we validate that we support multiple optimization builders with overlapping // operations. // // This is a very long scenario that validates a lot of things. #[test] fn should_support_overlapping_optimizations() { // We have 2 different optimization builders in this test case. let builder_id_1 = 0; let builder_id_2 = 0; // We will have a total of 5 execution plans to execute. let plan_id_1 = 0; let plan_id_2 = 1; let plan_id_3 = 2; let plan_id_4 = 3; let plan_id_5 = 4; let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]); let builder_2 = TestOptimizationBuilder::new( builder_id_2, vec![operation_1(), operation_2(), operation_1(), operation_1()], ); let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]); stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(0); stream.add(operation_2()); stream.assert_number_of_operations(2); stream.assert_number_of_executions(0); stream.add(operation_1()); stream.assert_number_of_operations(3); stream.assert_number_of_executions(0); stream.add(operation_2()); stream.assert_number_of_operations(2); stream.assert_number_of_executions(1); stream.assert_last_executed(plan_id_1); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_1(), operation_2()], triggers: vec![ExecutionTrigger::OnOperations(vec![ operation_1(), operation_2(), ])], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)), ordering: vec![0, 1], }, }, ); stream.add(operation_2()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(3); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_1(), operation_2()], triggers: vec![ ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]), ExecutionTrigger::OnOperations(vec![operation_2()]), ], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)), ordering: vec![0, 1], }, }, ); stream.assert_plan( plan_id_2, ExecutionPlan { operations: vec![operation_2()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(1), ordering: vec![0], }, }, ); stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(3); stream.add(operation_2()); stream.assert_number_of_operations(2); stream.assert_number_of_executions(3); stream.add(operation_1()); stream.assert_number_of_operations(3); stream.assert_number_of_executions(3); stream.add(operation_1()); stream.assert_number_of_operations(0); stream.assert_number_of_executions(4); stream.assert_plan( plan_id_3, ExecutionPlan { operations: vec![operation_1(), operation_2(), operation_1(), operation_1()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 4)), ordering: vec![0], }, }, ); stream.add(operation_1()); stream.assert_number_of_operations(1); stream.assert_number_of_executions(4); stream.add(operation_2()); stream.assert_number_of_operations(2); stream.assert_number_of_executions(4); stream.add(operation_1()); stream.assert_number_of_operations(3); stream.assert_number_of_executions(4); stream.sync(); stream.assert_number_of_operations(0); stream.assert_number_of_executions(6); stream.assert_plan( plan_id_1, ExecutionPlan { operations: vec![operation_1(), operation_2()], triggers: vec![ ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]), ExecutionTrigger::OnOperations(vec![operation_2()]), ExecutionTrigger::OnSync, ], optimization: BlockOptimization { strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)), ordering: vec![0, 1], }, }, ); stream.assert_plan( plan_id_4, ExecutionPlan { operations: vec![operation_1()], triggers: vec![ExecutionTrigger::OnSync], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(1), ordering: vec![0], }, }, ); stream.add(operation_3()); stream.assert_last_executed(plan_id_5); stream.assert_plan( plan_id_5, ExecutionPlan { operations: vec![operation_3()], triggers: vec![ExecutionTrigger::Always], optimization: BlockOptimization { strategy: ExecutionStrategy::operations(1), ordering: vec![0], }, }, ); stream.add(operation_3()); stream.assert_last_executed(plan_id_5); } impl TestStream { /// Create a new stream with the given optimization builders. fn new(optimizations: Vec>>) -> Self { Self { processor: Processor::::new(optimizations), store: ExecutionPlanStore::::new(), executed: Vec::new(), operations: Vec::new(), } } /// Add an operation to the stream. fn add(&mut self, operation: OperationIr) { self.operations.push(operation); self.processor.process( TestSegment::new(&mut self.operations, &mut self.executed), &mut self.store, ExecutionMode::Lazy, ); } /// Sync the stream. fn sync(&mut self) { self.processor.process( TestSegment::new(&mut self.operations, &mut self.executed), &mut self.store, ExecutionMode::Sync, ); } /// Assert that the plan has been executed as provided. fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan) { let actual = self.store.get_unchecked(id); assert_eq!(actual.operations, expected.operations, "Same operations"); assert_eq!(actual.triggers, expected.triggers, "Same triggers"); } /// Assert that the given plan id has been the last executed. fn assert_last_executed(&self, id: ExecutionPlanId) { match self.executed.last() { Some(last_id) => assert_eq!(*last_id, id), None => panic!("No plan has been executed"), } } /// Assert the number of executions since the start of the stream. fn assert_number_of_executions(&self, number: usize) { assert_eq!(self.executed.len(), number, "Number of execution match"); } /// Assert the number of operations queued. fn assert_number_of_operations(&self, number: usize) { assert_eq!(self.operations.len(), number); } } impl TestOptimizationBuilder { /// Create a new optimization builder that follows a pattern with a trigger. pub fn new(builder_id: usize, operations: Vec) -> Self { Self { builder_id, expected_operations: operations, actual: Vec::new(), } } } impl OperationFuser for TestOptimizationBuilder { /// Register a new operation. fn fuse(&mut self, operation: &OperationIr) { self.actual.push(operation.clone()); } /// Build the optimization. fn finish(&mut self) -> TestOptimization { TestOptimization::new(self.builder_id, self.len()) } /// Reset the state. fn reset(&mut self) { self.actual.clear(); } /// Return the optimization status. fn status(&self) -> FuserStatus { if self.actual.len() < self.expected_operations.len() { let operations = &self.expected_operations[0..self.actual.len()]; return match self.actual == operations { // Still optimizing. true => FuserStatus::Open, // Never gonna be possible on that stream. false => FuserStatus::Closed, }; } FuserStatus::Closed } /// Return the properties of this optimization. fn properties(&self) -> FuserProperties { if self.actual.len() < self.expected_operations.len() { // Optimization not possible. return FuserProperties { score: 0, ready: false, }; } let stream_is_ok = self.actual[0..self.expected_operations.len()] == self.expected_operations; if !stream_is_ok { // Optimization not possible. return FuserProperties { score: 0, ready: false, }; } // Optimization possible. FuserProperties { score: self.expected_operations.len() as u64, ready: true, } } // The number of operations that should be handle by the optimization. fn len(&self) -> usize { self.expected_operations.len() } fn clone_dyn(&self) -> Box> { Box::new(self.clone()) } } impl StreamSegment for TestSegment<'_> { // The operations in the process. fn operations(&self) -> &[OperationIr] { self.operations } // Execute the process. fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { let execution_plan = store.get_unchecked(id); self.execute_strategy(&execution_plan.optimization.strategy); self.executed.push(id); } } impl TestSegment<'_> { fn execute_strategy(&mut self, strategy: &ExecutionStrategy) { match strategy { ExecutionStrategy::Optimization { opt, .. } => { self.operations.drain(0..opt.size); } ExecutionStrategy::Operations { ordering } => { self.operations.drain(0..ordering.len()); } ExecutionStrategy::Composed(strategies) => { for strategy in strategies { self.execute_strategy(strategy); } } } } } /// Just a simple operation. pub fn operation_1() -> OperationIr { OperationIr::NumericFloat( DType::F32, NumericOperationIr::Add(BinaryOpIr { lhs: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, rhs: TensorIr { id: TensorId::new(1), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, out: TensorIr { id: TensorId::new(2), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } /// Just a simple operation. pub fn operation_2() -> OperationIr { OperationIr::NumericFloat( DType::F32, NumericOperationIr::AddScalar(ScalarOpIr { lhs: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, rhs: ScalarIr::Float(5.0), out: TensorIr { id: TensorId::new(2), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } /// Just a simple operation. pub fn operation_3() -> OperationIr { OperationIr::Float( DType::F32, FloatOperationIr::Log(UnaryOpIr { input: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, out: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } ================================================ FILE: crates/burn-fusion/src/stream/execution/validator.rs ================================================ use burn_ir::OperationIr; use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger}; /// Compare each operation in the list of operations provided by the [store](OperationsStore) /// to verify if the newly added operations match the original list. /// /// It is used by the [policy](crate::stream::execution::Policy) to check each candidate as well /// as to verify if a list of operations is optimal to execute based on their triggers. #[derive(Debug)] pub(crate) struct OperationsValidator { /// The ID used to retrieve the operation list. pub(crate) id: ID, /// The current [state](MatchingState). pub(crate) state: ValidatorState, } /// The state of the validator. #[derive(Debug)] pub(crate) enum ValidatorState { /// A matching operation list has been found. Found { size: usize }, /// No matching operation list has been found. Invalidated, /// Potentially going to find a matching operation list when more operations are added. Validating, } /// Provides a list of operations based on an Id. pub(crate) trait OperationsStore { /// The type used for the identifier. type Id: Copy; /// retrieve the list of operations corresponding on the provided id. fn get(&self, id: Self::Id) -> &[OperationIr]; } impl OperationsValidator { /// Create a new validator. pub(crate) fn new(id: ID) -> Self { Self { id, state: ValidatorState::Validating, } } /// Update the state of the validator based on the newly added operation. pub(crate) fn update(&mut self, added: &OperationIr, added_position: usize, store: &S) where S: OperationsStore, ID: PartialEq + Copy, { match &self.state { ValidatorState::Found { size: _ } => return, ValidatorState::Invalidated => return, ValidatorState::Validating => {} }; let item = store.get(self.id); let operation_candidate = match item.get(added_position) { Some(val) => val, None => { self.state = ValidatorState::Invalidated; return; } }; if operation_candidate != added { self.state = ValidatorState::Invalidated; return; } // Finished if item.len() == added_position + 1 { self.state = ValidatorState::Found { size: item.len() }; } } } /// [Operations store](OperationsStore) used to retrieve the list of operations for a trigger. #[derive(new)] pub(crate) struct TriggerOperationsStore<'a, O> { id: ExecutionPlanId, store: &'a ExecutionPlanStore, } /// Validates when operations match a trigger. #[derive(Debug)] pub(crate) enum TriggerValidator { OnOperations { matching: OperationsValidator, progress: TriggerProgress, }, Always, OnSync, } /// The progress made into the trigger validation process. #[derive(Debug)] pub(crate) enum TriggerProgress { /// When the validation hasn't started. NotInit, /// The number of operations that have been checked. NumChecked(usize), } /// An execution plan can have many triggers, so we use the position in the list to identify a /// trigger. pub(crate) type TriggerId = usize; impl OperationsStore for TriggerOperationsStore<'_, O> { type Id = TriggerId; fn get(&self, id: Self::Id) -> &[OperationIr] { match &self.store.get_unchecked(self.id).triggers[id] { ExecutionTrigger::OnOperations(operations) => operations, ExecutionTrigger::OnSync => &[], ExecutionTrigger::Always => &[], } } } /// [Operations store](OperationsStore) used to retrieve the list of operations for an /// [execution plan](crate::stream::store::ExecutionPlan). #[derive(new)] pub(crate) struct ExecutionPlanOperationsStore<'a, O> { store: &'a ExecutionPlanStore, } impl OperationsStore for ExecutionPlanOperationsStore<'_, O> { type Id = ExecutionPlanId; fn get(&self, id: Self::Id) -> &[OperationIr] { &self.store.get_unchecked(id).operations } } ================================================ FILE: crates/burn-fusion/src/stream/memory_checks.rs ================================================ use hashbrown::HashMap; use std::{ fmt::Display, sync::{ Arc, atomic::{AtomicU64, Ordering}, mpsc::SyncSender, }, thread::JoinHandle, time::Duration, }; use burn_ir::{HandleContainer, TensorId, TensorStatus}; use burn_std::id::StreamId; use crate::FusionRuntime; use super::Stream; /// Memory checks struct to validate there is no memory leak with the fusion runtime. #[derive(Clone)] pub(crate) struct MemoryChecks { sender: SyncSender, num_queued: Arc, // Keeps track of its thread. _handle: Arc>, } enum Message { Register(StreamAnalyses), Check(SyncSender), } enum MemoryReport { Success, NotReady, NotStarted, Fail(String), } #[derive(Default)] struct StreamAnalyses { streams: HashMap, num_handles: usize, } impl Display for StreamAnalyses { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("\n==== Fusion Memory Report ====\n")?; f.write_fmt(format_args!(" - Handles: {}\n", self.num_handles))?; f.write_fmt(format_args!(" - Streams: {}\n", self.streams.len()))?; for (id, analysis) in self.streams.iter() { f.write_fmt(format_args!( " - {} => operations: {} cursor: {}\n", id, analysis.num_operations, analysis.cursor ))?; for (tid, (origin, status)) in analysis.variables.iter() { f.write_fmt(format_args!( " - {tid} => origin: {origin} status: {status:?}\n", ))?; } } f.write_str("==============================\n") } } #[derive(Default, Debug)] struct Analysis { variables: HashMap, num_operations: usize, cursor: u64, } #[macro_export] /// Export memory checks tests. macro_rules! memory_checks { () => { #[cfg(test)] mod memory_checks { #[test] fn test_memory_leaks() { burn_fusion::stream::memory_checks::check_memory_leaks(); } } }; } static INSTANCE: spin::Mutex> = spin::Mutex::new(None); /// Performs memory checks and panics if a leak is discovered. pub fn check_memory_leaks() { let mut num_try_uninit = 0; let max_try = 25; loop { let report = fetch_memory_report(); match report { MemoryReport::Success => return, MemoryReport::NotReady => { num_try_uninit = 0; std::thread::sleep(Duration::from_millis(100)) } MemoryReport::NotStarted => { if num_try_uninit >= max_try { // Nothing is running on the fusion runtime. return; } num_try_uninit += 1; std::thread::sleep(Duration::from_millis(100)) } MemoryReport::Fail(msg) => panic!("{msg}"), } } } fn fetch_memory_report() -> MemoryReport { let report = INSTANCE.lock(); let report = match report.as_ref() { Some(client) => client, None => return MemoryReport::NotStarted, }; let (sender, rec) = std::sync::mpsc::sync_channel(1); match report.sender.send(Message::Check(sender)) { Ok(_) => {} Err(err) => { panic!("Channel closed can't send the check call: {err:?}") } }; match rec.recv() { Ok(report) => report, Err(err) => panic!("Received an error from fetching check results: {err}"), } } impl Default for MemoryChecks { fn default() -> Self { let mut instance = INSTANCE.lock(); let result = match instance.as_mut() { Some(client) => client.clone(), None => { let this = Self::spawn_new(); *instance = Some(this.clone()); this } }; core::mem::drop(instance); result } } impl MemoryChecks { pub(crate) fn check( &mut self, streams: &HashMap>, handles: &HandleContainer, ) { let mut analyses = StreamAnalyses { num_handles: handles.num_handles(), streams: Default::default(), }; for (id, s) in streams.iter() { let analysis = Analysis { variables: s.queue.variables.clone(), num_operations: s.queue.global.len(), cursor: s.cursor, }; analyses.streams.insert(*id, analysis); } self.num_queued.fetch_add(1, Ordering::Relaxed); match self.sender.send(Message::Register(analyses)) { Ok(..) => {} Err(err) => { panic!("Can't register memory checks analysis: {err:?}") } } } fn spawn_new() -> Self { let (sender, rec) = std::sync::mpsc::sync_channel(100); let num_queued = Arc::new(AtomicU64::new(0)); let num_queued_moved = num_queued.clone(); let handle = std::thread::spawn(move || { let mut last_analyses = None; loop { let payload = match rec.recv() { Err(_err) => { // A client has panic, safe to skip as it may be normal. continue; } Ok(payload) => payload, }; match payload { Message::Register(payload) => { last_analyses = Some(payload); num_queued_moved.fetch_sub(1, Ordering::Relaxed); } Message::Check(callback) => { if num_queued_moved.load(Ordering::Relaxed) > 1 { callback.send(MemoryReport::NotReady).unwrap(); continue; } // We assume that if nothing has been registered in the last second // while being at a count of 1, it's the end. std::thread::sleep(Duration::from_secs(5)); if num_queued_moved.load(Ordering::Relaxed) <= 1 { match last_analyses.take() { Some(val) => { callback.send(Self::final_check(val)).unwrap(); } None => { callback .send(MemoryReport::Fail("No analyses".into())) .unwrap(); } } } else { callback.send(MemoryReport::NotReady).unwrap(); } } } } }); Self { sender, num_queued, _handle: Arc::new(handle), } } fn final_check(analyses: StreamAnalyses) -> MemoryReport { if !analyses.streams.is_empty() || analyses.num_handles > 0 { return MemoryReport::Fail(format!("{analyses}")); } MemoryReport::Success } } ================================================ FILE: crates/burn-fusion/src/stream/mod.rs ================================================ pub(crate) mod execution; pub(crate) mod queue; pub(crate) mod shared_tensors; pub(crate) mod store; #[cfg(feature = "memory-checks")] /// Memory checks module. pub mod memory_checks; #[cfg(not(feature = "memory-checks"))] #[macro_export] /// Export memory checks tests. macro_rules! memory_checks { () => { #[cfg(test)] mod memory_checks { #[ignore = "'memory-checks' disabled"] #[test] fn test_memory_leaks() { // } } }; } mod base; mod context; mod multi; pub use base::*; pub use context::*; pub use execution::*; pub use multi::*; ================================================ FILE: crates/burn-fusion/src/stream/multi.rs ================================================ use std::sync::Arc; use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus}; use hashbrown::{HashMap, HashSet}; use super::{ StreamId, execution::{ExecutionMode, Operation, Processor, StreamSegment}, queue::OperationQueue, shared_tensors::SharedTensors, store::{ExecutionPlanId, ExecutionPlanStore}, }; use crate::{ DropOp, FusionRuntime, stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction}, }; /// Keep track of multiple concurrent lazy streams of operations. pub struct MultiStream { streams: HashMap>, optimizations: ExecutionPlanStore, shared_tensors: SharedTensors, device: R::FusionDevice, #[cfg(feature = "memory-checks")] memory_checks: super::memory_checks::MemoryChecks, } #[derive(Debug)] enum DropAction { SkipSharedTensor, ForceSharedTensor(Vec, TensorId), ContinueDrop, } impl MultiStream { pub(crate) fn new(device: R::FusionDevice) -> Self { Self { streams: HashMap::new(), optimizations: ExecutionPlanStore::new(), shared_tensors: SharedTensors::default(), device, #[cfg(feature = "memory-checks")] memory_checks: super::memory_checks::MemoryChecks::default(), } } /// Register a new tensor operation. pub(crate) fn register( &mut self, streams: OperationStreams, mut repr: OperationIr, operation: Arc>, handles: &mut HandleContainer, ) { let id = self.resolve_streams(&streams, handles, &mut repr); let drop_action = match &mut repr { OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)), _ => None, }; let sync = match drop_action { Some(DropAction::SkipSharedTensor) => return, Some(DropAction::ContinueDrop) => true, Some(DropAction::ForceSharedTensor(stream_ids, tid)) => { for stream_id in stream_ids { if let Some(stream) = self.streams.get_mut(&stream_id) { stream.queue.variables.remove(&tid); if stream.queue.variables.is_empty() { self.streams.remove(&stream_id); } } } true } None => false, }; let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles); if num_executed > 0 && let Some(stream) = self.streams.get_mut(&id) { let cleared = self.shared_tensors.on_executed_ops(id, stream); self.clear_shared_tensors(&cleared, id); let to_drop = self.shared_tensors.clear_tensors(cleared); self.drop_shared_tensors(to_drop, handles, id); } let stream = match self.streams.get(&id) { Some(val) => val, None => { #[cfg(feature = "memory-checks")] self.memory_checks.check(&self.streams, handles); return; } }; if !stream.queue.variables.is_empty() && sync { // Not draining the queue can cause a memory leak when a stream is closing. self.drain(handles, id); } #[cfg(feature = "memory-checks")] self.memory_checks.check(&self.streams, handles); } /// Checks if the current operation is a drop. /// /// When a tensor is shared across multiple concurrent streams, dropping a tensor might cause a /// problem when the same tensor is registered lazily on another stream, but not yet executed. fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction { match !matches!(tensor_ir.status, TensorStatus::ReadWrite) { true => { let stream = self.streams.get(&id); let on_drop = self .shared_tensors .on_drop(id, tensor_ir.id, stream.is_none()); match on_drop { SharedTensorDropAction::ForceDrop(streams) => { tensor_ir.status = TensorStatus::ReadWrite; DropAction::ForceSharedTensor(streams, tensor_ir.id) } SharedTensorDropAction::Skip => DropAction::SkipSharedTensor, } } false => DropAction::ContinueDrop, } } /// Enqueue an operation on the queue. fn enqueue_operation( &mut self, id: StreamId, repr: OperationIr, streams: &OperationStreams, operation: Arc>, handles: &mut HandleContainer, ) -> usize { let stream = match self.streams.get_mut(&id) { Some(stream) => stream, None => { let stream = Stream::new(self.device.clone()); self.streams.insert(id, stream); self.streams .get_mut(&id) .expect("Just added, so should be included in the hashmap.") } }; stream.queue.add(repr, operation, streams, id); let len_before = stream.queue.global.len(); stream.processor.process( Segment::new(&mut stream.queue, handles), &mut self.optimizations, ExecutionMode::Lazy, ); let len_after = stream.queue.global.len(); let num_executed = len_before - len_after; stream.cursor += num_executed as u64; num_executed } /// Mark a tensor as read. #[allow(unused_variables)] pub fn mark_read( &mut self, id: StreamId, ir: &TensorIr, handles: &HandleContainer, ) { if !matches!(ir.status, TensorStatus::ReadWrite) { return; }; let stream = match self.streams.get_mut(&id) { Some(val) => val, None => return, }; stream.queue.variables.remove(&ir.id); if stream.queue.variables.is_empty() { self.streams.remove(&id); } #[cfg(feature = "memory-checks")] self.memory_checks.check(&self.streams, handles); } /// Drain a stream pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { if let Some(stream) = self.streams.get_mut(&id) { let old = unsafe { StreamId::swap(id) }; let num_executed = stream.queue.global.len(); stream.processor.process( Segment::new(&mut stream.queue, handles), &mut self.optimizations, ExecutionMode::Sync, ); stream.cursor += num_executed as u64; let cleared = self.shared_tensors.on_executed_ops(id, stream); self.clear_shared_tensors(&cleared, id); let to_drop = self.shared_tensors.clear_tensors(cleared); self.drop_shared_tensors(to_drop, handles, id); unsafe { StreamId::swap(old); }; } } /// When one of the provided streams is different from the current stream, we drain them. /// /// Returns the selected stream id. fn resolve_streams( &mut self, streams: &OperationStreams, handles: &mut HandleContainer, op: &mut OperationIr, ) -> StreamId { let current = streams.current; let nodes = op.nodes(); let analysis = self.analyse_shared_tensors(&nodes, streams, current); self.merge_streams_timelines(handles, &analysis, current, &nodes); self.register_shared_tensors_drop(&analysis, op); current } /// Drain the stream only if one of the tensor in the given nodes is also included in the /// stream queue. fn resolve_stream( &mut self, handles: &mut HandleContainer, id: StreamId, nodes: &[&TensorIr], ) { if let Some(stream) = self.streams.get(&id) { for node in nodes { if stream.queue.variables.contains_key(&node.id) { self.drain(handles, id); return; } } } } fn analyse_shared_tensors( &mut self, nodes: &[&TensorIr], streams: &OperationStreams, current: StreamId, ) -> MultiSharedTensorAnalysis { let mut shared_analysis = MultiSharedTensorAnalysis::default(); for node in nodes.iter() { let analysis = self .shared_tensors .analyse(current, node, streams, &self.streams); match analysis { SharedTensorAnalysis::SharedFromCurrentStream => { shared_analysis.current.push((node.id, node.status)); } SharedTensorAnalysis::NotShared => {} SharedTensorAnalysis::SharedFromExistingStream { stream_id, original_cursor, } => { shared_analysis .existing .push((node.id, stream_id, original_cursor)); } SharedTensorAnalysis::SharedFromNewStream { stream_id } => { shared_analysis.new.push((node.id, stream_id)); } } } shared_analysis } fn merge_streams_timelines( &mut self, handles: &mut HandleContainer, analysis: &MultiSharedTensorAnalysis, current: StreamId, nodes: &[&TensorIr], ) { // If we only have current tensors that are shared, we're safe to not sync the timelines. if analysis.new.is_empty() && analysis.existing.is_empty() && analysis.current.is_empty() { return; } let mut streams_to_sync = HashSet::new(); for (_tensor_id, stream_id) in analysis.new.iter() { streams_to_sync.insert(*stream_id); } for (_tensor_id, stream_id, original_cursor) in analysis.existing.iter() { if let Some(stream) = self.streams.get(stream_id) { // We only have to sync a stream when the stream isn't up to date with // the original cursor of the current operation. if stream.cursor <= *original_cursor && *stream_id != current { streams_to_sync.insert(*stream_id); } } } for (tensor_id, status) in analysis.current.iter() { if let TensorStatus::ReadWrite = status { for stream in self.shared_tensors.streams_of(tensor_id) { streams_to_sync.insert(stream); } } } for id in streams_to_sync.drain() { log::trace!("Drain stream {id} for use in current {current}"); self.resolve_stream(handles, id, nodes); } } fn register_shared_tensors_drop( &mut self, analysis: &MultiSharedTensorAnalysis, op: &mut OperationIr, ) { let mut readonly_tensors = Vec::new(); for (tensor_id, _stream_id) in analysis.new.iter() { readonly_tensors.push(*tensor_id); } for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() { readonly_tensors.push(*tensor_id); } for (tensor_id, status) in analysis.current.iter() { if let TensorStatus::ReadOnly = status { readonly_tensors.push(*tensor_id); } } self.shared_tensors .tag_manual_drop(op.mark_read_only(&readonly_tensors)); } fn drop_shared_tensors( &mut self, tensors: Vec, handles: &mut HandleContainer, current: StreamId, ) { for (stream_id, s) in self.streams.iter_mut() { for tensor in tensors.iter() { if let Some((original, _status)) = s.queue.variables.get(&tensor.id) && original != stream_id { s.queue.variables.remove(&tensor.id); } } } for tensor in tensors { let streams = OperationStreams { streams: HashMap::new(), current, }; let op = Arc::new(DropOp { id: tensor.id }); self.register(streams, OperationIr::Drop(tensor), op, handles); } } fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) { let mut to_remove = Vec::new(); for (stream_id, s) in self.streams.iter_mut() { for tensor in tensors.iter() { s.queue.variables.remove(tensor); } if s.queue.variables.is_empty() && current != *stream_id { to_remove.push(*stream_id); } } for s in to_remove { self.streams.remove(&s); } } } pub(crate) struct Stream { pub(crate) queue: OperationQueue, processor: Processor, pub(crate) cursor: u64, } #[derive(new)] struct Segment<'a, R: FusionRuntime> { queue: &'a mut OperationQueue, handles: &'a mut HandleContainer, } impl StreamSegment for Segment<'_, R> { fn operations(&self) -> &[OperationIr] { &self.queue.relative } fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore) { self.queue.execute(id, self.handles, store) } } impl Stream { fn new(device: R::FusionDevice) -> Self { Self { processor: Processor::new(R::fusers(device)), queue: OperationQueue::new(), cursor: 0, } } } #[derive(Debug)] /// Manage the streams used for the current [operation](OperationIr). pub struct OperationStreams { pub(crate) streams: HashMap, pub(crate) current: StreamId, } impl Default for OperationStreams { fn default() -> Self { Self { streams: HashMap::new(), current: StreamId::current(), } } } impl OperationStreams { /// Register a tensor in the list of streams used for the current [operation](OperationIr). /// /// You only need to register input tensors, not the outputs. /// So init tensor operations should have no streams registered. pub fn tensor(&mut self, tensor: &crate::FusionTensor) { self.streams.insert(tensor.id, tensor.stream); } pub(crate) fn get(&self, id: TensorId) -> Option { self.streams.get(&id).cloned() } /// Create new operation streams with the given inputs. /// /// The inputs are automatically registered. pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self where I: IntoIterator>, { let mut streams = OperationStreams::default(); for tensor in tensors.into_iter() { streams.tensor(tensor) } streams } } #[derive(Default, Debug)] struct MultiSharedTensorAnalysis { /// Tensors that are shared with other streams, but we're currently executing on the same stream /// the tensor was originally created. current: Vec<(TensorId, TensorStatus)>, /// Tensors that are shared with new streams. new: Vec<(TensorId, StreamId)>, /// Tensors that are shared with existing streams. existing: Vec<(TensorId, StreamId, u64)>, } ================================================ FILE: crates/burn-fusion/src/stream/queue/base.rs ================================================ use std::sync::Arc; use crate::FusionRuntime; use crate::stream::{OperationConverter, OperationStreams, RelativeOps, execution::Operation}; use burn_backend::StreamId; use burn_ir::{OperationIr, TensorId, TensorStatus}; use hashbrown::HashMap; /// A growing list of [tensor operation descriptions](OperationIr). pub struct OperationQueue { /// List of operation descriptions. These contain the exact tensor IDs /// and shapes so that kernels can be run correctly. /// /// The length of this list is the same as the length of the `operations` list. pub(crate) global: Vec, /// List of operation descriptions. The tensor IDs and shapes are relative /// because we don't need to know the exact values, but they are sufficient to /// determine which operations can be fused. pub(crate) relative: Vec, pub(crate) converter: OperationConverter, pub(crate) operations: Vec>>, pub(crate) variables: HashMap, } impl Default for OperationQueue { fn default() -> Self { Self::new() } } impl OperationQueue { /// Create a new empty queue. pub fn new() -> Self { Self { global: Vec::new(), relative: Vec::new(), converter: OperationConverter::default(), operations: Vec::new(), variables: HashMap::new(), } } /// Add a new tensor operation to the queue. /// /// The new [operation intermediate representation](OperationIr) will be converted to a local /// representation that can be reused when the same pattern emerge in different but similar /// scenario, so that the same optimization can be used. pub fn add( &mut self, global: OperationIr, operation: Arc>, streams: &OperationStreams, current: StreamId, ) { for node in global.nodes() { if let Some(stream_id) = streams.get(node.id) { self.variables.insert(node.id, (stream_id, node.status)); } else { self.variables.insert(node.id, (current, node.status)); } } let relative = global.to_relative(&mut self.converter); self.relative.push(relative); self.global.push(global); self.operations.push(operation); } } #[cfg(all(test, feature = "std"))] mod tests { use super::*; #[test] fn stream_id_from_different_threads() { let current = StreamId::current(); let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current())); let thread2 = std::thread::spawn(StreamId::current); let (stream_1, stream_11) = thread1.join().unwrap(); let stream_2 = thread2.join().unwrap(); assert_ne!(current, stream_1, "Should be different from thread 1"); assert_ne!(current, stream_2, "Should be different from thread 2"); assert_ne!( stream_1, stream_2, "Should be different from different threads" ); assert_eq!( stream_1, stream_11, "Should be the same, since same thread." ); } } ================================================ FILE: crates/burn-fusion/src/stream/queue/execution.rs ================================================ use std::sync::Arc; use burn_ir::{HandleContainer, TensorStatus}; use crate::{ FusionRuntime, search::BlockOptimization, stream::{ Context, Operation, OperationConverter, OrderedExecution, RelativeOps, store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy}, }, }; use super::OperationQueue; impl OperationQueue { /// Execute the queue partially following the execution strategy from the plan. pub(crate) fn execute( &mut self, id: ExecutionPlanId, handles: &mut HandleContainer, store: &mut ExecutionPlanStore, ) { let plan = store.get_mut_unchecked(id); self.execute_block_optimization(&mut plan.optimization, handles); } fn execute_block_optimization( &mut self, step: &mut BlockOptimization, handles: &mut HandleContainer, ) { let mut operations = Vec::new(); core::mem::swap(&mut operations, &mut self.operations); let (operations, num_drained) = QueueExecution::run(step, &mut self.converter, handles, operations); self.operations = operations; self.drain_queue(num_drained, handles); } /// Bookkeeping after executing `num_drained` operations from the queue. fn drain_queue(&mut self, num_drained: usize, handles: &mut HandleContainer) { self.global[0..num_drained] .iter() .flat_map(|desc| desc.nodes()) .for_each(|tensor| { if tensor.status == TensorStatus::ReadWrite { self.variables.remove(&tensor.id); }; handles.free(tensor) }); self.global.drain(0..num_drained); self.reset_relative(); } fn reset_relative(&mut self) { self.relative.clear(); self.converter.clear(); for node in self.global.iter() { let relative = node.to_relative(&mut self.converter); self.relative.push(relative); } } } /// A queue execution has the responsibility to run the provided /// [optimization](FusionRuntime::Optimization) without holes. enum QueueExecution<'a, R: FusionRuntime> { Single { handles: &'a mut HandleContainer, converter: &'a mut OperationConverter, execution: OrderedExecution, }, Multiple { context: &'a mut Context<'a, R::FusionHandle>, execution: OrderedExecution, }, } impl<'a, R: FusionRuntime> QueueExecution<'a, R> { fn run( optimization: &mut BlockOptimization, converter: &'a mut OperationConverter, handles: &'a mut HandleContainer, operations: Vec>>, ) -> (Vec>>, usize) { let execution = OrderedExecution::new(operations); if matches!(&optimization.strategy, ExecutionStrategy::Composed(..)) { let mut context = converter.context(handles); let mut this = QueueExecution::Multiple { context: &mut context, execution, }; this = this.execute_strategy(&mut optimization.strategy); match this { QueueExecution::Multiple { execution, .. } => execution.finish(), _ => unreachable!(), } } else { let mut this = QueueExecution::Single { handles, converter, execution, }; this = this.execute_strategy(&mut optimization.strategy); match this { QueueExecution::Single { execution, .. } => execution.finish(), _ => unreachable!(), } } } fn execute_strategy(mut self, strategy: &mut ExecutionStrategy) -> Self { match &mut self { QueueExecution::Single { handles, converter, execution, } => match strategy { ExecutionStrategy::Optimization { ordering, opt } => { let mut context = converter.context(handles); execution.execute_optimization(opt, &mut context, ordering.clone()) } ExecutionStrategy::Operations { ordering } => { execution.execute_operations(handles, ordering) } ExecutionStrategy::Composed(_) => unreachable!(), }, QueueExecution::Multiple { context, execution } => match strategy { ExecutionStrategy::Optimization { opt, ordering } => { execution.execute_optimization(opt, context, ordering.clone()); } ExecutionStrategy::Operations { ordering } => { execution.execute_operations(context.handles, ordering); } ExecutionStrategy::Composed(items) => { for item in items.iter_mut() { self = self.execute_strategy(item); } } }, }; self } } ================================================ FILE: crates/burn-fusion/src/stream/queue/mod.rs ================================================ mod base; mod execution; pub use base::*; ================================================ FILE: crates/burn-fusion/src/stream/shared_tensors.rs ================================================ use burn_backend::StreamId; use burn_ir::{TensorId, TensorIr}; use hashbrown::HashMap; use super::{OperationStreams, Stream}; use crate::FusionRuntime; #[derive(Default)] /// Manages tensors that are shared between multiple streams. pub struct SharedTensors { shared_tensors: HashMap, shared_tensors_manual_drop: HashMap, } #[derive(Default, Debug)] /// A tensor that is shared between multiple streams. struct SharedTensor { streams: HashMap, } #[derive(Debug)] struct SharedTensorState { cursor_current: u64, cursor_origin: u64, } #[derive(Debug)] /// What do to when a tensor is dropped. pub enum SharedTensorDropAction { /// Performs the drop and removes the shared tensor from the provided list of /// stream ids. ForceDrop(Vec), /// Skip the drop. Skip, } #[derive(Debug)] /// Information about a shared tensor. pub enum SharedTensorAnalysis { /// The tensor is not shared. NotShared, /// The tensor is shared, but its original stream is the current one. SharedFromCurrentStream, /// The tensor is shared, and its original stream is an existing stream. SharedFromExistingStream { /// The stream id of the existing stream. stream_id: StreamId, /// The position of execution in the existing stream where the tensor was created. original_cursor: u64, }, /// The tensor is shared, and its original stream is a new one without any operation /// executed. SharedFromNewStream { /// The stream id of the new stream. stream_id: StreamId, }, } impl SharedTensors { /// Function to call when a drop operation is registered on the given stream and tensor. pub fn on_drop( &mut self, stream_id: StreamId, tensor_id: TensorId, stream_completed: bool, ) -> SharedTensorDropAction { let mut execute_still = false; if let Some(shared) = self.shared_tensors.get_mut(&tensor_id) { if stream_completed { shared.drop(stream_id); execute_still = shared.streams.is_empty(); } } else { execute_still = true; } if execute_still { let state = self.shared_tensors.remove(&tensor_id); self.shared_tensors_manual_drop.remove(&tensor_id); return match state { Some(val) => { let streams = val.streams.keys().copied().collect(); SharedTensorDropAction::ForceDrop(streams) } None => SharedTensorDropAction::ForceDrop(Vec::new()), }; } SharedTensorDropAction::Skip } /// Function to call when one or many operations were executed on the stream. /// /// Returns the tensor id that can be cleared with [Self::clear_tensors] pub fn on_executed_ops( &mut self, id: StreamId, stream: &mut Stream, ) -> Vec { let mut cleared = Vec::new(); for (tensor_id, state) in self.shared_tensors.iter_mut() { match state.update(id, stream) { SharedTensorUpdate::RemovedFromStream(no_more_stream) => { stream.queue.variables.remove(tensor_id); if no_more_stream { cleared.push(*tensor_id); } } SharedTensorUpdate::ReadyForCleanup => { cleared.push(*tensor_id); } SharedTensorUpdate::NoChange => {} } } cleared } /// Clear the provided tensors and returns the list of tensors that can be manually dropped. pub fn clear_tensors(&mut self, tensors: Vec) -> Vec { let mut to_drop = Vec::new(); for id in tensors { self.shared_tensors.remove(&id); if let Some(tensor) = self.shared_tensors_manual_drop.remove(&id) { to_drop.push(tensor); } } self.register_manual_drop(to_drop) } pub fn streams_of(&mut self, tensor: &TensorId) -> Vec { let mut streams = Vec::new(); if let Some(value) = self.shared_tensors.get(tensor) { for s in value.streams.keys() { streams.push(*s); } } streams } /// Analyses the current tensor and updates its state. pub fn analyse( &mut self, id: StreamId, node: &TensorIr, streams_op: &OperationStreams, streams: &HashMap>, ) -> SharedTensorAnalysis { let stream_id = match streams_op.streams.get(&node.id) { Some(val) => val, None => { return match self.shared_tensors.contains_key(&node.id) { true => SharedTensorAnalysis::SharedFromCurrentStream, false => SharedTensorAnalysis::NotShared, }; } }; if stream_id == &id { return match self.shared_tensors.contains_key(&node.id) { true => SharedTensorAnalysis::SharedFromCurrentStream, false => SharedTensorAnalysis::NotShared, }; } // Here the node is tagged as newly shared. let stream_current = streams.get(&id); let stream = streams.get(stream_id); let state = match self.shared_tensors.get_mut(&node.id) { Some(state) => state, None => { self.shared_tensors.insert(node.id, SharedTensor::default()); self.shared_tensors.get_mut(&node.id).unwrap() } }; state.register_new_stream(id, stream_current); match state.register_new_stream(*stream_id, stream) { Some(origin) => SharedTensorAnalysis::SharedFromExistingStream { stream_id: *stream_id, original_cursor: origin, }, None => SharedTensorAnalysis::SharedFromNewStream { stream_id: *stream_id, }, } } /// Tag the provided tensors as manually dropped. pub fn tag_manual_drop(&mut self, dropped: Vec) { for tensor in dropped { self.shared_tensors_manual_drop.insert(tensor.id, tensor); } } fn register_manual_drop(&mut self, mut tensors: Vec) -> Vec { if self.shared_tensors_manual_drop.is_empty() { return tensors; } let mut to_drop = Vec::new(); for id in self.shared_tensors_manual_drop.keys() { if !self.shared_tensors.contains_key(id) { to_drop.push(*id); } } for id in to_drop { let entry = self.shared_tensors_manual_drop.remove(&id).unwrap(); tensors.push(entry); } tensors } } /// The result from a [SharedTensor::update]. pub enum SharedTensorUpdate { /// The tensor is removed from the current stream. /// /// Also contains if the current stream is empty. RemovedFromStream(bool), /// If the tensor is shared across zero streams. ReadyForCleanup, /// If nothing has been done from the update. NoChange, } impl SharedTensor { /// Register the tensor as also part of the given stream. /// /// The stream might not exist yet when the current tensor is part of the first operation in /// the newly created stream. fn register_new_stream( &mut self, id: StreamId, stream: Option<&Stream>, ) -> Option { let cursor_current = match stream { Some(stream) => stream.cursor + stream.queue.global.len() as u64, None => 1, }; match self.streams.get_mut(&id) { Some(s) => { s.cursor_current = cursor_current; Some(s.cursor_origin) } None => { let state = SharedTensorState { cursor_current, cursor_origin: cursor_current, }; self.streams.insert(id, state); None } } } /// Update the current shared tensor state on the given stream. /// /// If the shared tensor is no longer needed on the stream, we will remove it from the list of /// shared streams. fn update(&mut self, id: StreamId, stream: &Stream) -> SharedTensorUpdate { let entry = match self.streams.remove(&id) { Some(val) => val, None => { return if self.streams.is_empty() { SharedTensorUpdate::ReadyForCleanup } else { SharedTensorUpdate::NoChange }; } }; // We can only free the shared tensor if the latest cursor is executed. if entry.cursor_current <= stream.cursor { SharedTensorUpdate::RemovedFromStream(self.streams.is_empty()) } else { self.streams.insert(id, entry); SharedTensorUpdate::NoChange } } fn drop(&mut self, id: StreamId) { self.streams.remove(&id); } } impl core::fmt::Debug for SharedTensors { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("\n==== Shared Tensors ====\n")?; for sh in self.shared_tensors.iter() { f.write_fmt(format_args!(" - Shared {}", sh.0))?; for (id, state) in sh.1.streams.iter() { f.write_fmt(format_args!( " [{}, cursor={}..{}] ", id, state.cursor_origin, state.cursor_current ))?; } f.write_str("\n")?; } for sh in self.shared_tensors_manual_drop.iter() { f.write_fmt(format_args!(" - Manual Drop {}", sh.0))?; f.write_str("\n")?; } f.write_str("========================\n") } } ================================================ FILE: crates/burn-fusion/src/stream/store/base.rs ================================================ use std::sync::Arc; use crate::search::BlockOptimization; use super::{ExecutionPlanIndex, InsertQuery, SearchQuery}; use burn_ir::OperationIr; use serde::{Deserialize, Serialize}; /// The store that contains all explorations done on a device. #[derive(Default)] pub(crate) struct ExecutionPlanStore { plans: Vec>, index: ExecutionPlanIndex, } /// How a list of operations should be executed. #[derive(PartialEq, Debug, Clone)] pub(crate) enum ExecutionStrategy { /// An optimization was found, and therefore should be executed. Optimization { opt: O, ordering: Arc> }, /// No optimization was found, each operation should be executed individually. Operations { ordering: Arc> }, /// A composition of multiple execution strategies. Composed(Vec>), } /// The trigger that indicates when to stop exploring. #[derive(Debug, PartialEq, Serialize, Deserialize)] pub(crate) enum ExecutionTrigger { OnOperations(Vec), OnSync, Always, } /// The unique identifier for an exploration that was executed. pub(crate) type ExecutionPlanId = usize; /// The outcome of an exploration that can be stored. #[derive(Debug)] pub(crate) struct ExecutionPlan { /// The operations on which the exploration is related to. pub(crate) operations: Vec, /// The criteria that signal when this plan should be executed. Only one trigger is necessary. pub(crate) triggers: Vec, /// The optimization that should be used when executing this plan. pub(crate) optimization: BlockOptimization, } impl ExecutionPlanStore { pub fn new() -> Self { Self { plans: Vec::new(), index: ExecutionPlanIndex::default(), } } pub fn find(&self, query: SearchQuery<'_>) -> Vec { self.index.find(query) } pub fn add(&mut self, exploration: ExecutionPlan) -> ExecutionPlanId { if exploration.operations.is_empty() { panic!("Can't add an empty optimization."); } let id = self.plans.len(); self.index.insert(InsertQuery::NewPlan { operations: &exploration.operations, id, }); self.plans.push(exploration); id } pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan { &mut self.plans[id] } pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan { &self.plans[id] } /// Add a new end condition for an optimization. pub fn add_trigger(&mut self, id: ExecutionPlanId, trigger: ExecutionTrigger) { let criteria = &mut self.plans[id].triggers; if !criteria.contains(&trigger) { criteria.push(trigger); } } } ================================================ FILE: crates/burn-fusion/src/stream/store/index.rs ================================================ use crate::stream::store::ExecutionPlanId; use burn_ir::OperationIr; use serde::{Deserialize, Serialize}; use std::{ collections::{HashMap, hash_map::DefaultHasher}, hash::{Hash, Hasher}, }; /// Index used to search optimizations. #[derive(Default, Serialize, Deserialize, Clone)] pub struct ExecutionPlanIndex { /// We can't use `HashMap>` since `OperationIr` /// doesn't implement [`Eq`](core::cmp::Eq). /// /// `OperationIr` can't implement `Eq` since float types don't implement it. /// /// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions. /// This is OK because we use `relative` operations where any scalar values are set to zeros, /// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter). /// /// Map from the hash of the `OperationIr` to a list of `(OperationIr, index)` pairs, /// where `index` is the index of all the execution plans that start with the `OperationIr` /// in the `starters` list. mapping: HashMap>, starters: Vec>, } pub enum SearchQuery<'a> { PlansStartingWith(&'a OperationIr), } pub enum InsertQuery<'a> { NewPlan { operations: &'a [OperationIr], id: ExecutionPlanId, }, } impl ExecutionPlanIndex { /// Search optimizations with the given [query](SearchQuery). pub fn find(&self, query: SearchQuery<'_>) -> Vec { match query { SearchQuery::PlansStartingWith(ops) => self.find_starting_with(ops), } } /// Register a new optimization with the given [query](InsertQuery). pub fn insert(&mut self, query: InsertQuery<'_>) { match query { InsertQuery::NewPlan { operations, id } => { if let Some(operation) = operations.first() { self.insert_new_operation(operation, id) } } } } /// Find execution plans starting with the `OperationIr` fn find_starting_with(&self, operation: &OperationIr) -> Vec { let key = self.operation_key(operation); let values = match self.mapping.get(&key) { Some(val) => val, None => return Vec::new(), }; if values.is_empty() { return Vec::new(); } let (_, index) = match values.iter().find(|value| &value.0 == operation) { Some(val) => val, None => return Vec::new(), }; match self.starters.get(*index) { Some(value) => value.clone(), None => Vec::new(), } } /// Update the index for an execution plan starting with operation `ops` fn insert_new_operation(&mut self, ops: &OperationIr, new_id: ExecutionPlanId) { let key = self.operation_key(ops); let values = match self.mapping.get_mut(&key) { Some(val) => val, None => { // New starter ops. let index = self.starters.len(); self.starters.push(vec![new_id]); self.mapping.insert(key, vec![(ops.clone(), index)]); return; } }; let (_, index) = match values.iter_mut().find(|value| &value.0 == ops) { Some(val) => val, None => { // New with hash collision. let index = self.starters.len(); self.starters.push(vec![new_id]); values.push((ops.clone(), index)); return; } }; // New optimization for an existing starter. self.starters .get_mut(*index) .expect("Should exist") .push(new_id); } // Hash the value of the first operation in a list. fn operation_key(&self, ops: &OperationIr) -> u64 { let mut hasher = DefaultHasher::new(); ops.hash(&mut hasher); hasher.finish() } } #[cfg(test)] mod tests { use burn_backend::{DType, Shape}; use burn_ir::{ BinaryOpIr, NumericOperationIr, ScalarIr, ScalarOpIr, TensorId, TensorIr, TensorStatus, }; use super::*; #[test] fn should_find_optimization_id_based_on_tensor_ops() { let mut index = ExecutionPlanIndex::default(); let stream_1 = [ops_1()]; let optimization_id_1 = 0; index.insert(InsertQuery::NewPlan { operations: &stream_1, id: optimization_id_1, }); let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0])); assert_eq!(found, vec![optimization_id_1]); } #[test] fn should_support_multiple_optimization_ids_with_same_starting_ops() { let mut index = ExecutionPlanIndex::default(); let stream_1 = [ops_1(), ops_2(), ops_1()]; let stream_2 = [ops_1(), ops_1(), ops_2()]; let optimization_id_1 = 0; let optimization_id_2 = 1; index.insert(InsertQuery::NewPlan { operations: &stream_1, id: optimization_id_1, }); index.insert(InsertQuery::NewPlan { operations: &stream_2, id: optimization_id_2, }); let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0])); assert_eq!(found, vec![optimization_id_1, optimization_id_2]); } #[test] fn should_only_find_optimization_with_correct_starting_ops() { let mut index = ExecutionPlanIndex::default(); let stream_1 = [ops_1(), ops_1()]; let stream_2 = [ops_2(), ops_1()]; let optimization_id_1 = 0; let optimization_id_2 = 1; index.insert(InsertQuery::NewPlan { operations: &stream_1, id: optimization_id_1, }); index.insert(InsertQuery::NewPlan { operations: &stream_2, id: optimization_id_2, }); let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0])); assert_eq!(found, vec![optimization_id_1]); } #[test] fn should_handle_hash_collisions() { let mut index = ExecutionPlanIndex::default(); let stream_1 = [ops_1(), ops_1()]; let stream_2 = [ops_3(), ops_1()]; let optimization_id_1 = 0; let optimization_id_2 = 1; let stream_1_key = index.operation_key(&stream_1[0]); let stream_2_key = index.operation_key(&stream_2[0]); assert_ne!( stream_1_key, stream_2_key, "Ops 1 and Ops 3 should not have the same hash" ); // ops 1 and 3 have different variants, so the hash differs assert_ne!(stream_1[0], stream_2[0], "Ops 1 and Ops 3 are different."); index.insert(InsertQuery::NewPlan { operations: &stream_1, id: optimization_id_1, }); index.insert(InsertQuery::NewPlan { operations: &stream_2, id: optimization_id_2, }); let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0])); assert_eq!(found, vec![optimization_id_1]); } fn ops_1() -> OperationIr { OperationIr::NumericFloat( DType::F32, NumericOperationIr::Add(BinaryOpIr { lhs: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, rhs: TensorIr { id: TensorId::new(1), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, out: TensorIr { id: TensorId::new(2), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } fn ops_2() -> OperationIr { OperationIr::NumericFloat( DType::F32, NumericOperationIr::AddScalar(ScalarOpIr { lhs: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, rhs: ScalarIr::Float(5.0), out: TensorIr { id: TensorId::new(2), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } fn ops_3() -> OperationIr { OperationIr::NumericFloat( DType::F32, NumericOperationIr::Sub(BinaryOpIr { lhs: TensorIr { id: TensorId::new(0), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, rhs: TensorIr { id: TensorId::new(1), shape: Shape::new([32, 32]), status: TensorStatus::ReadOnly, dtype: DType::F32, }, out: TensorIr { id: TensorId::new(2), shape: Shape::new([32, 32]), status: TensorStatus::NotInit, dtype: DType::F32, }, }), ) } } ================================================ FILE: crates/burn-fusion/src/stream/store/mod.rs ================================================ mod base; mod index; pub(crate) use base::*; pub(super) use index::*; ================================================ FILE: crates/burn-fusion/src/tensor.rs ================================================ use crate::{ Client, FusionBackend, FusionRuntime, stream::{Operation, OperationStreams, StreamId}, }; use burn_backend::{ DType, ExecutionError, QTensorPrimitive, Shape, TensorData, TensorMetadata, quantization::QuantScheme, }; use burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus}; use std::sync::{ Arc, atomic::{AtomicU32, Ordering}, }; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. pub struct FusionTensor { /// Tensor id. pub id: TensorId, /// The shape of the tensor. pub shape: Shape, /// The fusion client. pub client: Client, /// The datatype of the tensor. pub dtype: DType, /// The current stream id this tensor is on. pub stream: StreamId, pub(crate) count: Arc, } impl Clone for FusionTensor { fn clone(&self) -> Self { self.count.fetch_add(1, Ordering::Acquire); Self { id: self.id, shape: self.shape.clone(), client: self.client.clone(), dtype: self.dtype, stream: self.stream, count: self.count.clone(), } } } impl core::fmt::Debug for FusionTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str( format!( "{{ id: {:?}, shape: {:?}, device: {:?} }}", self.id, self.shape, self.client.device().clone(), ) .as_str(), ) } } impl TensorMetadata for FusionTensor { fn dtype(&self) -> DType { self.dtype } fn shape(&self) -> Shape { self.shape.clone() } fn rank(&self) -> usize { self.shape.num_dims() } } impl FusionTensor { pub(crate) fn new( id: TensorId, shape: Shape, dtype: DType, client: Client, stream: StreamId, ) -> Self { Self { id, shape, client, dtype, stream, count: Arc::new(AtomicU32::new(1)), } } fn status(&self, count: u32) -> TensorStatus { if count <= 1 { TensorStatus::ReadWrite } else { TensorStatus::ReadOnly } } /// Intermediate representation to be used when using an uninitialized tensor as output. pub fn to_ir_out(&self) -> TensorIr { TensorIr { status: TensorStatus::NotInit, shape: self.shape.clone(), id: self.id, dtype: self.dtype, } } /// Intermediate representation to be used when using an initialized tensor used as input. pub fn into_ir(mut self) -> TensorIr { let count = self.count.load(Ordering::Acquire); let status = self.status(count); let mut shape_out = Shape::from(Vec::::new()); core::mem::swap(&mut self.shape, &mut shape_out); if let TensorStatus::ReadWrite = status { // Avoids an unwanted drop on the same thread. // // Since `drop` is called after `into_ir`, we must not register a drop if the tensor // was consumed with a `ReadWrite` status. self.count.fetch_add(1, Ordering::Acquire); } TensorIr { status, shape: shape_out, id: self.id, dtype: self.dtype, } } pub(crate) async fn into_data(self) -> Result where B: FusionBackend, { let id = self.stream; let client = self.client.clone(); let desc = self.into_ir(); client.read_tensor_float::(desc, id).await } pub(crate) async fn q_into_data(self) -> Result where B: FusionBackend, { if let DType::QFloat(_scheme) = self.dtype { let id = self.stream; let client = self.client.clone(); let desc = self.into_ir(); client.read_tensor_quantized::(desc, id).await } else { panic!("Expected quantized float dtype, got {:?}", self.dtype) } } pub(crate) async fn int_into_data(self) -> Result where B: FusionBackend, { let id = self.stream; let client = self.client.clone(); let desc = self.into_ir(); client.read_tensor_int::(desc, id).await } pub(crate) async fn bool_into_data(self) -> Result where B: FusionBackend, { let id = self.stream; let client = self.client.clone(); let desc = self.into_ir(); client.read_tensor_bool::(desc, id).await } } #[derive(new, Debug)] pub(crate) struct DropOp { pub(crate) id: TensorId, } impl Operation for DropOp { fn execute(&self, handles: &mut burn_ir::HandleContainer) { handles.remove_handle(self.id); } } impl Drop for FusionTensor { fn drop(&mut self) { let count = self.count.fetch_sub(1, Ordering::Acquire); // Workaround to prevent segfaults when an operation panics if std::thread::panicking() { return; } match self.status(count) { TensorStatus::ReadWrite => { let mut shape = Shape::from(Vec::::new()); core::mem::swap(&mut shape, &mut self.shape); let ir = TensorIr { id: self.id, shape, status: TensorStatus::ReadWrite, dtype: self.dtype, }; let mut streams = OperationStreams::default(); streams.tensor(self); self.client .register(streams, OperationIr::Drop(ir), DropOp { id: self.id }); } TensorStatus::ReadOnly => {} TensorStatus::NotInit => {} } } } impl QTensorPrimitive for FusionTensor { fn scheme(&self) -> &QuantScheme { if let DType::QFloat(scheme) = &self.dtype { scheme } else { panic!( "Quantization scheme is not valid for dtype {:?}", self.dtype, ) } } } ================================================ FILE: crates/burn-ir/Cargo.toml ================================================ [package] authors = ["laggui ", "nathanielsimard "] categories = ["science"] description = "Intermediate representation for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor"] license.workspace = true name = "burn-ir" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-ir" documentation = "https://docs.rs/burn-ir" version.workspace = true [lints] workspace = true [features] default = ["std"] std = ["burn-backend/std"] doc = ["default"] tracing = [ "burn-backend/tracing", ] [dependencies] serde = { workspace = true } hashbrown = { workspace = true } # no_std compatible burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-ir/README.md ================================================ # Burn Intermediate Representation Defines an Intermediate Representation (IR) used to represent tensors and operations. The abstraction over computation allows execution across different targets (e.g., remote backend). It also enables optimization and transformation of tensor computations before execution (e.g., operator fusion). ================================================ FILE: crates/burn-ir/src/backend.rs ================================================ use burn_backend::{ Backend, Shape, tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, }; /// A tensor representation containing a reference to a tensor resource with a given shape. #[derive(Clone)] pub struct TensorHandle { /// The type that can be used to point to a tensor of any kind. pub handle: H, /// The shape associated to the tensor. pub shape: Shape, } /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor /// intermediate representation for compilation purpose or other... pub trait BackendIr: Backend { /// The type that can be used to point to a tensor of any kind. type Handle: Sync + Send + Clone; /// Convert a [handle](BackendIr::Handle) to a [float tensor](Backend::FloatTensorPrimitive). fn float_tensor(handle: TensorHandle) -> FloatTensor; /// Convert a [handle](BackendIr::Handle) to an [int tensor](Backend::IntTensorPrimitive). fn int_tensor(handle: TensorHandle) -> IntTensor; /// Convert a [handle](BackendIr::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). fn bool_tensor(handle: TensorHandle) -> BoolTensor; /// Convert a [handle](BackendIr::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor; /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](BackendIr::Handle). fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle; /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](BackendIr::Handle). fn int_tensor_handle(tensor: IntTensor) -> Self::Handle; /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](BackendIr::Handle). fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](BackendIr::Handle). fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle; } /// Handle which points to a backend tensor primitive kind. #[derive(Clone, Debug)] pub enum HandleKind { /// Float tensor handle. Float(B::FloatTensorPrimitive), /// Int tensor handle. Int(B::IntTensorPrimitive), /// Bool tensor handle. Bool(B::BoolTensorPrimitive), /// Quantized tensor handle. Quantized(B::QuantizedTensorPrimitive), } impl HandleKind { /// Returns the handle kind name. pub fn name(&self) -> &str { match self { HandleKind::Float(_) => "float", HandleKind::Int(_) => "int", HandleKind::Bool(_) => "bool", HandleKind::Quantized(_) => "quantized", } } } ================================================ FILE: crates/burn-ir/src/builder.rs ================================================ #![allow(missing_docs)] use alloc::vec::Vec; use burn_backend::{ DType, Distribution, Shape, Slice, SliceOps, calculate_matmul_output, ops::{ conv::{ calculate_conv_output_shape, calculate_conv_transpose_output_shape, calculate_pool_output_shape, }, unfold::calculate_unfold_shape, }, quantization::QuantScheme, tensor::IndexingUpdateOp, }; use crate::{ScalarIr, TensorId, TensorIr}; use super::operation::*; impl CreationOpIr { pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self { let out = TensorIr::uninit(new_id(), shape, dtype); CreationOpIr { out } } } impl InitOperationIr { pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self { let out = TensorIr::uninit(new_id(), shape, dtype); InitOperationIr { out } } } impl RandomOpIr { pub fn create( shape: Shape, dtype: DType, distribution: Distribution, new_id: impl FnOnce() -> TensorId, ) -> Self { let out = TensorIr::uninit(new_id(), shape, dtype); RandomOpIr { out, distribution } } } impl FullOpIr { pub fn create( shape: Shape, dtype: DType, value: ScalarIr, new_id: impl FnOnce() -> TensorId, ) -> Self { // TODO: check that ScalarIr dtype matches dtype? let out = TensorIr::uninit(new_id(), shape, dtype); FullOpIr { out, value } } } impl CastOpIr { pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self { let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype); CastOpIr { input, out } } } impl ShapeOpIr { pub fn expand(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self { let shape = input.shape.expand(shape).unwrap(); Self::create(input, shape, new_id) } pub fn reshape(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self { let shape = input.shape.reshape(shape).unwrap(); Self::create(input, shape, new_id) } fn create(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self { let out = TensorIr::uninit(new_id(), shape, input.dtype); ShapeOpIr { input, out } } } // "Lower" specific operations into a binary or unary op representation. // Useful when collecting inputs and outputs and don't care about the other semantics. impl From for BinaryOpIr { fn from(value: MatmulOpIr) -> Self { Self { lhs: value.lhs, rhs: value.rhs, out: value.out, } } } impl From for UnaryOpIr { fn from(value: ReduceOpIr) -> Self { Self { input: value.input, out: value.out, } } } #[derive(Debug)] #[allow(missing_docs)] pub enum IrError { DTypeMismatch, } fn dtype_compat(lhs: &DType, rhs: &DType) -> bool { let lhs_qfloat = matches!(lhs, DType::QFloat(_)); let rhs_qfloat = matches!(rhs, DType::QFloat(_)); if lhs_qfloat && (rhs_qfloat || rhs.is_float()) || lhs.is_float() && (rhs_qfloat || rhs.is_float()) { true } else { lhs == rhs } } fn output_check<'a, I>(inputs: I, compat: impl Fn(&DType, &DType) -> bool) -> Result where I: IntoIterator, { let mut iter = inputs.into_iter(); let first = iter.next().unwrap(); for d in iter { if !compat(first, d) { return Err(IrError::DTypeMismatch); } } Ok(*first) } fn output_dtype<'a, I: IntoIterator>(inputs: I) -> Result { output_check(inputs, |a, b| a == b) } fn output_dtype_mixed<'a, I: IntoIterator>(inputs: I) -> Result { output_check(inputs, dtype_compat) } /// Macro to implement `create` constructors for operations with a single output. /// /// Supports shape and dtype validation. macro_rules! impl_ir_create { (@create_fn $op:ident { $( $field:ident : $ty:ty ),* $(,)? } , $shape:expr, $dtype:expr) => { #[doc = "Create a new operation IR from the given inputs."] #[doc = "`new_id` should generate a unique `TensorId` for the uninitialized output tensor."] #[allow(clippy::too_many_arguments)] pub fn create($( $field : $ty ),*, new_id: impl FnOnce() -> crate::TensorId) -> $op { let shape = $shape; let dtype = $dtype; let out = TensorIr::uninit(new_id(), shape, dtype); $op { $( $field ),*, out } } }; // Case: simple op, single `create` ( $op:ident { $( $field:ident : $ty:ty ),* $(,)? }, shape = $shape:expr, dtype = $dtype:expr ) => { impl $op { impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype); } }; // Case: op with one additional constructor that accepts an explicit output dtype ( $op:ident { $( $field:ident : $ty:ty ),* $(,)? }, shape = $shape:expr, dtype = $dtype:expr, $fn_name:ident ( $extra:ident : $extra_ty:ty ) ) => { impl $op { impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype); #[doc = "Create a new operation IR from the given inputs and the given output dtype."] #[allow(clippy::too_many_arguments)] pub fn $fn_name($( $field : $ty ),*, $extra: $extra_ty, new_id: impl FnOnce() -> crate::TensorId) -> Self { let shape = $shape; let _ = $dtype; // still validates dtype if needed let out = TensorIr::uninit(new_id(), shape, $extra); $op { $( $field ),*, out } } } }; } impl_ir_create!( UnaryOpIr { input: TensorIr }, shape = input.shape.clone(), dtype = input.dtype, // Additional constructor for unary comparisons create_comparison(bool_dtype: DType) ); impl_ir_create!( BinaryOpIr { lhs: TensorIr, rhs: TensorIr }, shape = lhs.shape.broadcast(&rhs.shape).unwrap(), dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap(), // Additional constructor for binary comparisons create_comparison(bool_dtype: DType) ); impl_ir_create!( ScalarOpIr { lhs: TensorIr, rhs: ScalarIr }, shape = lhs.shape.clone(), dtype = lhs.dtype, // Additional constructor for scalar comparisons create_comparison(bool_dtype: DType) ); impl_ir_create!( MatmulOpIr { lhs: TensorIr, rhs: TensorIr }, shape = calculate_matmul_output(&lhs.shape, &rhs.shape).unwrap(), dtype = output_dtype_mixed([&lhs.dtype, &rhs.dtype]).unwrap(), // Additional constructor for mixed dtypes create_mixed(out_dtype: DType) ); impl_ir_create!( SwapDimsOpIr { input: TensorIr, dim1: usize, dim2: usize }, shape = input.shape.clone().swapped(dim1, dim2).unwrap(), dtype = input.dtype ); impl_ir_create!( PermuteOpIr { input: TensorIr, axes: Vec }, shape = input.shape.clone().permuted(&axes).unwrap(), dtype = input.dtype ); impl_ir_create!( RepeatDimOpIr { tensor: TensorIr, dim: usize, times: usize }, shape = tensor.shape.clone().repeat(dim, times).unwrap(), dtype = tensor.dtype ); impl_ir_create!( FlipOpIr { input: TensorIr, axes: Vec }, shape = input.shape.clone(), // TODO: check if axes are within the tensor dimensions dtype = input.dtype ); impl_ir_create!( CatOpIr { tensors: Vec, dim: usize }, shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap(), dtype = output_dtype(tensors.iter().map(|t| &t.dtype)).unwrap() ); impl_ir_create!( GatherOpIr { tensor: TensorIr, dim: usize, indices: TensorIr }, shape = indices.shape.clone(), // TODO: check dims compat between tensor and indices dtype = tensor.dtype ); impl_ir_create!( ScatterOpIr { tensor: TensorIr, dim: usize, indices: TensorIr, value: TensorIr, update: IndexingUpdateOp }, shape = tensor.shape.clone(), // TODO: check dims compat between tensor and indices dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap() ); impl_ir_create!( ReduceOpIr { input: TensorIr }, shape = [1].into(), dtype = input.dtype ); impl_ir_create!( ReduceDimOpIr { input: TensorIr, axis: usize }, shape = input.shape.clone().reduce(axis).unwrap(), dtype = input.dtype, // Additional constructor for argument reduction create_arg(ind_dtype: DType) ); impl_ir_create!( DimOpIr { input: TensorIr, axis: usize }, shape = input.shape.clone(), // TODO: check dims within rank dtype = input.dtype ); impl_ir_create!( SelectOpIr { tensor: TensorIr, dim: usize, indices: TensorIr }, // TODO: shape.select? shape = { let mut s = tensor.shape.clone(); s[dim] = indices.shape[0]; s }, dtype = tensor.dtype ); impl_ir_create!( SelectAssignOpIr { tensor: TensorIr, dim: usize, indices: TensorIr, value: TensorIr, update: IndexingUpdateOp }, // TODO: check value and indices shape match for dim shape = tensor.shape.clone(), dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap() ); impl_ir_create!( SliceOpIr { tensor: TensorIr, ranges: Vec, }, shape = tensor.shape.clone().slice(&ranges).unwrap(), dtype = tensor.dtype ); impl_ir_create!( SliceAssignOpIr { tensor: TensorIr, ranges: Vec, value: TensorIr }, // TODO: check slice and value number of elements match shape = tensor.shape.clone(), dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap() ); impl_ir_create!( MaskWhereOpIr { tensor: TensorIr, mask: TensorIr, value: TensorIr }, shape = Shape::broadcast_many([&tensor.shape, &mask.shape, &value.shape]).unwrap(), dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap() ); impl_ir_create!( MaskFillOpIr { tensor: TensorIr, mask: TensorIr, value: ScalarIr }, shape = tensor.shape.broadcast(&mask.shape).unwrap(), dtype = tensor.dtype ); impl_ir_create!( ClampOpIr { tensor: TensorIr, min: ScalarIr, max: ScalarIr }, shape = tensor.shape.clone(), dtype = tensor.dtype ); impl_ir_create!( AvgPool1dOpIr { x: TensorIr, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool }, shape = calculate_pool_output_shape( &x.shape, &[kernel_size], &[stride], &[padding], &[1], ceil_mode ) .unwrap(), dtype = x.dtype ); impl_ir_create!( AvgPool1dBackwardOpIr { x: TensorIr, grad: TensorIr, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( AvgPool2dOpIr { x: TensorIr, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool }, shape = calculate_pool_output_shape( &x.shape, &kernel_size, &stride, &padding, &[1, 1], ceil_mode ) .unwrap(), dtype = x.dtype ); impl_ir_create!( AvgPool2dBackwardOpIr { x: TensorIr, grad: TensorIr, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( MaxPool1dOpIr { x: TensorIr, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool }, shape = calculate_pool_output_shape( &x.shape, &[kernel_size], &[stride], &[padding], &[dilation], ceil_mode ) .unwrap(), dtype = x.dtype ); impl_ir_create!( MaxPool2dOpIr { x: TensorIr, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool }, shape = calculate_pool_output_shape( &x.shape, &kernel_size, &stride, &padding, &dilation, ceil_mode ) .unwrap(), dtype = x.dtype ); impl_ir_create!( MaxPool1dWithIndicesBackwardOpIr { x: TensorIr, grad: TensorIr, indices: TensorIr, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( MaxPool2dWithIndicesBackwardOpIr { x: TensorIr, grad: TensorIr, indices: TensorIr, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( AdaptiveAvgPool1dOpIr { x: TensorIr, output_size: usize }, shape = Shape::new([x.shape[0], x.shape[1], output_size]), dtype = x.dtype ); impl_ir_create!( AdaptiveAvgPool2dOpIr { x: TensorIr, output_size: [usize; 2] }, shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]), dtype = x.dtype ); impl_ir_create!( AdaptiveAvgPool1dBackwardOpIr { x: TensorIr, grad: TensorIr, }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( AdaptiveAvgPool2dBackwardOpIr { x: TensorIr, grad: TensorIr, }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( InterpolateOpIr { x: TensorIr, output_size: [usize; 2], options: InterpolateOptionsIr }, shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]), dtype = x.dtype ); impl_ir_create!( InterpolateBackwardOpIr { x: TensorIr, grad: TensorIr, output_size: [usize; 2], options: InterpolateOptionsIr }, shape = x.shape.clone(), dtype = x.dtype ); impl_ir_create!( GridSample2dOpIr { tensor: TensorIr, grid: TensorIr, options: GridSampleOptionsIr }, // Input tensor: [N, C, H_in, W_in] // Grid: [N, H_out, W_out, 2] // Output: [N, C, H_out, W_out] shape = Shape::new([ tensor.shape[0], tensor.shape[1], grid.shape[1], grid.shape[2] ]), dtype = tensor.dtype ); impl_ir_create!( Conv1dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: Conv1dOptionsIr }, shape = calculate_conv_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.dilation, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( Conv1dXBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv1dOptionsIr }, shape = x.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv1dWeightBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv1dOptionsIr }, shape = weight.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv1dBiasBackwardOpIr { x: TensorIr, bias: TensorIr, output_grad: TensorIr, }, shape = bias.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv2dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: Conv2dOptionsIr }, shape = calculate_conv_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.dilation, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( Conv2dXBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv2dOptionsIr }, shape = x.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv2dWeightBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv2dOptionsIr }, shape = weight.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv2dBiasBackwardOpIr { x: TensorIr, bias: TensorIr, output_grad: TensorIr, }, shape = bias.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv3dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: Conv3dOptionsIr }, shape = calculate_conv_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.dilation, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( Conv3dXBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv3dOptionsIr }, shape = x.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv3dWeightBackwardOpIr { x: TensorIr, weight: TensorIr, output_grad: TensorIr, options: Conv3dOptionsIr }, shape = weight.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( Conv3dBiasBackwardOpIr { x: TensorIr, bias: TensorIr, output_grad: TensorIr, }, shape = bias.shape.clone(), dtype = output_grad.dtype ); impl_ir_create!( DeformConv2dOpIr { x: TensorIr, offset: TensorIr, weight: TensorIr, mask: Option, bias: Option, options: DeformableConv2dOptionsIr }, shape = calculate_conv_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.dilation, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&offset.dtype), Some(&weight.dtype), mask.as_ref().map(|m| &m.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( ConvTranspose1dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: ConvTranspose1dOptionsIr }, shape = calculate_conv_transpose_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.padding_out, &options.dilation, options.groups, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( ConvTranspose2dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: ConvTranspose2dOptionsIr }, shape = calculate_conv_transpose_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.padding_out, &options.dilation, options.groups, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( ConvTranspose3dOpIr { x: TensorIr, weight: TensorIr, bias: Option, options: ConvTranspose3dOptionsIr }, shape = calculate_conv_transpose_output_shape( &x.shape, &weight.shape, &options.stride, &options.padding, &options.padding_out, &options.dilation, options.groups, ) .unwrap(), dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap() ); impl_ir_create!( UnfoldOpIr { input: TensorIr, dim: usize, size: usize, step: usize }, shape = calculate_unfold_shape(input.shape.clone(), dim, size, step), dtype = input.dtype ); impl_ir_create!( CrossOpIr { lhs: TensorIr, rhs: TensorIr, dim: usize }, shape = lhs.shape.broadcast(&rhs.shape).unwrap(), dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap() ); impl_ir_create!( QuantizeOpIr { tensor: TensorIr, qparams: QuantizationParametersIr, scheme: QuantScheme }, shape = tensor.shape.clone(), dtype = DType::QFloat(scheme) ); impl_ir_create!( AttentionOpIr { query: TensorIr, key: TensorIr, value: TensorIr, mask: Option, attn_bias: Option, options: AttentionOptionsIr, }, shape = Shape::new([query.shape[0], query.shape[1], query.shape[2], value.shape[3]]), dtype = query.dtype ); impl DequantizeOpIr { pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self { let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype); DequantizeOpIr { input, out } } } // Operations with multiple outputs impl ReduceDimWithIndicesOpIr { pub fn create( tensor: TensorIr, dim: usize, dtype_indices: DType, mut new_id: impl FnMut() -> TensorId, ) -> Self { let mut shape = tensor.shape.clone(); shape[dim] = 1; let out = TensorIr::uninit(new_id(), shape.clone(), tensor.dtype); let out_indices = TensorIr::uninit(new_id(), shape.clone(), dtype_indices); ReduceDimWithIndicesOpIr { tensor, dim, out, out_indices, } } } impl DeformConv2dBackwardOpIr { #[allow(clippy::too_many_arguments)] pub fn create( x: TensorIr, offset: TensorIr, weight: TensorIr, mask: Option, bias: Option, out_grad: TensorIr, options: DeformableConv2dOptionsIr, mut new_id: impl FnMut() -> TensorId, ) -> Self { let dtype = output_dtype( [ Some(&x.dtype), Some(&weight.dtype), mask.as_ref().map(|m| &m.dtype), bias.as_ref().map(|b| &b.dtype), ] .iter() .filter_map(|&d| d), ) .unwrap(); let input_grad = TensorIr::uninit(new_id(), x.shape.clone(), dtype); let offset_grad = TensorIr::uninit(new_id(), offset.shape.clone(), dtype); let weight_grad = TensorIr::uninit(new_id(), weight.shape.clone(), dtype); let mask_grad = mask .as_ref() .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype)); let bias_grad = bias .as_ref() .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype)); DeformConv2dBackwardOpIr { x, offset, weight, mask, bias, out_grad, options, input_grad, offset_grad, weight_grad, mask_grad, bias_grad, } } } impl MaxPool1dWithIndicesOpIr { #[allow(clippy::too_many_arguments)] pub fn create( x: TensorIr, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, dtype_indices: DType, mut new_id: impl FnMut() -> TensorId, ) -> Self { let shape = calculate_pool_output_shape( &x.shape, &[kernel_size], &[stride], &[padding], &[dilation], ceil_mode, ) .unwrap(); let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype); let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices); MaxPool1dWithIndicesOpIr { x, kernel_size, stride, padding, dilation, ceil_mode, out, out_indices, } } } impl MaxPool2dWithIndicesOpIr { #[allow(clippy::too_many_arguments)] pub fn create( x: TensorIr, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, dtype_indices: DType, mut new_id: impl FnMut() -> TensorId, ) -> Self { let shape = calculate_pool_output_shape( &x.shape, &kernel_size, &stride, &padding, &dilation, ceil_mode, ) .unwrap(); let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype); let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices); MaxPool2dWithIndicesOpIr { x, kernel_size, stride, padding, dilation, ceil_mode, out, out_indices, } } } ================================================ FILE: crates/burn-ir/src/handle.rs ================================================ use hashbrown::HashMap; use crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus}; /// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources /// are used optimally. #[derive(Default)] pub struct HandleContainer { handles: HashMap>, counter: u64, } impl HandleContainer { /// Fork the container, useful for autotune. pub fn fork(&self) -> Self { let mut handles = HashMap::with_capacity(self.handles.len()); for (id, handle) in self.handles.iter() { handles.insert(*id, handle.clone()); } Self { handles, counter: self.counter, } } } impl core::fmt::Debug for HandleContainer { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("HandleContainer") .field("handles", &self.handles.keys()) // only care about the IDs when debugging .field("counter", &self.counter) .finish() } } /// Backend [tensor handle](BackendIr::Handle) wrapper tracking their creation state #[derive(Clone)] pub enum Handle { /// No [tensor handle](BackendIr::Handle) has been created yet NotInit, /// A [tensor handle](BackendIr::Handle) has been created Existing(H), } impl HandleContainer { /// Create a new HandleContainer pub fn new() -> Self { Self { handles: HashMap::new(), counter: 0, } } /// Register a handle for the given [tensor id](TensorId). pub fn register_handle(&mut self, id: TensorId, handle: H) { self.handles.insert(id, Handle::Existing(handle)); } /// Whether an handle exists. pub fn has_handle(&mut self, id: &TensorId) -> bool { self.handles.contains_key(id) } /// Get the reference to a handle. pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> { self.handles .get(id) .filter(|h| !matches!(h, Handle::NotInit)) .map(|h| match h { Handle::Existing(handle) => handle, Handle::NotInit => unreachable!(), }) } /// Get the handle for the given [tensor id](TensorId). The status is used to determine if the /// tensor should be popped out of the current tensor map, necessary for inplace operations. /// /// # Warnings /// /// Make sure the status corresponds to the operation you want to execute the handle on, /// otherwise you might remove a tensor handle that will be required in the future. pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H { let (id, handle) = self .handles .remove_entry(id) .unwrap_or_else(|| panic!("Should have handle for tensor {id:?}")); match handle { Handle::Existing(handle) => match status { TensorStatus::ReadOnly => { self.handles.insert(id, Handle::Existing(handle.clone())); handle } TensorStatus::ReadWrite => handle, TensorStatus::NotInit => panic!( "Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status" ), }, Handle::NotInit => panic!("Cannot get uninitialized handle {id:?}."), } } /// Get the tensor handle for the given [tensor intermediate representation](TensorIr). pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle { TensorHandle { handle: self.get_handle(&tensor.id, &tensor.status), shape: tensor.shape.clone(), } } /// Get the [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) corresponding to the /// given [tensor intermediate representation](TensorIr). pub fn get_float_tensor(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive where B: BackendIr, { B::float_tensor(self.get_tensor_handle(tensor)) } /// Get the [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) corresponding to the /// given [tensor intermediate representation](TensorIr). pub fn get_int_tensor(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive where B: BackendIr, { B::int_tensor(self.get_tensor_handle(tensor)) } /// Get the [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) corresponding to the /// given [tensor intermediate representation](TensorIr). pub fn get_bool_tensor(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive where B: BackendIr, { B::bool_tensor(self.get_tensor_handle(tensor)) } /// Get the [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) corresponding to the /// given [tensor intermediate representation](TensorIr). pub fn get_quantized_tensor(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive where B: BackendIr, { B::quantized_tensor(self.get_tensor_handle(tensor)) } /// Register a new [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_float_tensor(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive) where B: BackendIr, { let handle = B::float_tensor_handle(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). pub fn register_quantized_tensor( &mut self, id: &TensorId, tensor: B::QuantizedTensorPrimitive, ) where B: BackendIr, { let handle = B::quantized_tensor_handle(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_int_tensor(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive) where B: BackendIr, { let handle = B::int_tensor_handle(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Register a new [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_bool_tensor(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive) where B: BackendIr, { let handle = B::bool_tensor_handle(tensor); self.handles.insert(*id, Handle::Existing(handle)); } /// Remove tensor handle from container. pub fn remove_handle(&mut self, id: TensorId) -> Option> { self.handles.remove(&id) } /// Remove tensor handle from container if writable pub fn free(&mut self, tensor: &TensorIr) { match tensor.status { TensorStatus::ReadOnly => (), TensorStatus::NotInit => (), TensorStatus::ReadWrite => { self.handles.remove(&tensor.id); } }; } /// Returns the number of handles. pub fn num_handles(&self) -> usize { self.handles.len() } } ================================================ FILE: crates/burn-ir/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! Burn intermediate representation. extern crate alloc; mod backend; mod builder; mod handle; mod operation; mod scalar; mod tensor; pub use backend::*; pub use builder::*; pub use handle::*; pub use operation::*; pub use scalar::*; pub use tensor::*; ================================================ FILE: crates/burn-ir/src/operation.rs ================================================ use burn_backend::ops::AttentionModuleOptions; use burn_backend::tensor::IndexingUpdateOp; use core::hash::Hash; use serde::{Deserialize, Serialize}; use alloc::borrow::ToOwned; use alloc::boxed::Box; use alloc::{string::String, vec::Vec}; use burn_backend::{ DType, Distribution, Slice, ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions, GridSamplePaddingMode, InterpolateMode, InterpolateOptions, }, quantization::QuantScheme, }; use crate::{ScalarIr, TensorId, TensorIr, TensorStatus}; /// Custom operation in fusion stream, declaring its inputs and outputs. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct CustomOpIr { /// Unique identifier of the operation. pub id: String, /// Input tensors used in the custom operation. pub inputs: Vec, /// Output tensors used in the custom operation. pub outputs: Vec, } impl CustomOpIr { /// Create a new custom operation intermediate representation. pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self { Self { id: id.to_owned(), inputs: inputs.to_vec(), outputs: outputs.to_vec(), } } /// Cast the intermediate representation, and get the in and output tensors. pub fn as_fixed( &self, ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) { ( self.inputs.as_slice().try_into().expect( "Wrong number of inputs expected (expected {D}, is {}), check your implementation", ), self.outputs.as_slice().try_into().expect( "Wrong number of outputs expected (expected {D}, is {}), check your implementation", ), ) } fn inputs(&self) -> Box + '_> { Box::new(self.inputs.iter()) } fn outputs(&self) -> Box + '_> { Box::new(self.outputs.iter()) } } /// Describe all tensor operations possible. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(clippy::large_enum_variant)] pub enum OperationIr { /// Basic operation on a float tensor. BaseFloat(BaseOperationIr), /// Basic operation on an int tensor. BaseInt(BaseOperationIr), /// Basic operation on a bool tensor. BaseBool(BaseOperationIr), /// Numeric operation on a float tensor. NumericFloat(DType, NumericOperationIr), /// Numeric operation on an int tensor. NumericInt(DType, NumericOperationIr), /// Operation specific to a bool tensor. Bool(BoolOperationIr), /// Operation specific to an int tensor. Int(IntOperationIr), /// Operation specific to a float tensor. Float(DType, FloatOperationIr), /// Module operation. Module(ModuleOperationIr), /// Initialize operation. Init(InitOperationIr), /// A custom operation. Custom(CustomOpIr), /// A tensor is dropped. Drop(TensorIr), } /// Operation intermediate representation specific to a float tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum FloatOperationIr { /// Operation corresponding to [exp](burn_backend::ops::FloatTensorOps::float_exp). Exp(UnaryOpIr), /// Operation corresponding to [log](burn_backend::ops::FloatTensorOps::float_log). Log(UnaryOpIr), /// Operation corresponding to [log1p](burn_backend::ops::FloatTensorOps::float_log1p). Log1p(UnaryOpIr), /// Operation corresponding to [erf](burn_backend::ops::FloatTensorOps::float_erf). Erf(UnaryOpIr), /// Operation corresponding to [powf_scalar](burn_backend::ops::FloatTensorOps::float_powf_scalar). PowfScalar(ScalarOpIr), /// Operation corresponding to [sqrt](burn_backend::ops::FloatTensorOps::float_sqrt). Sqrt(UnaryOpIr), /// Operation corresponding to [cos](burn_backend::ops::FloatTensorOps::float_cos). Cos(UnaryOpIr), /// Operation corresponding to [cosh](burn_backend::ops::FloatTensorOps::float_cosh). Cosh(UnaryOpIr), /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sin). Sin(UnaryOpIr), /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sinh). Sinh(UnaryOpIr), /// Operation corresponding to [tan](burn_backend::ops::FloatTensorOps::float_tan). Tan(UnaryOpIr), /// Operation corresponding to [tanh](burn_backend::ops::FloatTensorOps::float_tanh). Tanh(UnaryOpIr), /// Operation corresponding to [acos](burn_backend::ops::FloatTensorOps::float_acos). ArcCos(UnaryOpIr), /// Operation corresponding to [acosh](burn_backend::ops::FloatTensorOps::float_acosh). ArcCosh(UnaryOpIr), /// Operation corresponding to [asin](burn_backend::ops::FloatTensorOps::float_asin). ArcSin(UnaryOpIr), /// Operation corresponding to [asinh](burn_backend::ops::FloatTensorOps::float_asinh). ArcSinh(UnaryOpIr), /// Operation corresponding to [atan](burn_backend::ops::FloatTensorOps::float_atan). ArcTan(UnaryOpIr), /// Operation corresponding to [atanh](burn_backend::ops::FloatTensorOps::float_atanh). ArcTanh(UnaryOpIr), /// Operation corresponding to [atan2](burn_backend::ops::FloatTensorOps::float_atan2). ArcTan2(BinaryOpIr), /// Operation corresponding to [round](burn_backend::ops::FloatTensorOps::float_round). Round(UnaryOpIr), /// Operation corresponding to [floor](burn_backend::ops::FloatTensorOps::float_floor). Floor(UnaryOpIr), /// Operation corresponding to [ceil](burn_backend::ops::FloatTensorOps::float_ceil). Ceil(UnaryOpIr), /// Operation corresponding to [trunc](burn_backend::ops::FloatTensorOps::float_trunc). Trunc(UnaryOpIr), /// Operation corresponding to [into_int](burn_backend::ops::FloatTensorOps::float_into_int). IntoInt(CastOpIr), /// Operation corresponding to [matmul](burn_backend::ops::FloatTensorOps::float_matmul). Matmul(MatmulOpIr), /// Operation corresponding to [cross](burn_backend::ops::FloatTensorOps::float_cross). Cross(CrossOpIr), /// Operation corresponding to [random](burn_backend::ops::FloatTensorOps::float_random). Random(RandomOpIr), /// Operation corresponding to [recip](burn_backend::ops::FloatTensorOps::float_recip). Recip(UnaryOpIr), /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_nan). IsNan(UnaryOpIr), /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_inf). IsInf(UnaryOpIr), /// Operation corresponding to [quantize](burn_backend::ops::QTensorOps::quantize). Quantize(QuantizeOpIr), /// Operation corresponding to [dequantize](burn_backend::ops::QTensorOps::dequantize). Dequantize(DequantizeOpIr), /// Operation corresponding to [grid_sample_2d](burn_backend::ops::FloatTensorOps::float_grid_sample_2d). GridSample2d(GridSample2dOpIr), /// Operation corresponding to [powf](burn_backend::ops::FloatTensorOps::float_powi). Powf(BinaryOpIr), } /// Operation intermediate representation specific to module. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum ModuleOperationIr { /// Operation corresponding to [embedding](burn_backend::ops::ModuleOps::embedding). Embedding(EmbeddingOpIr), /// Operation corresponding to [embedding_backward](burn_backend::ops::ModuleOps::embedding_backward). EmbeddingBackward(EmbeddingBackwardOpIr), /// Operation corresponding to [conv1d](burn_backend::ops::ModuleOps::conv1d). Conv1d(Conv1dOpIr), /// Operation corresponding to [conv1d_x_backward](burn_backend::ops::ModuleOps::conv1d_x_backward). Conv1dXBackward(Conv1dXBackwardOpIr), /// Operation corresponding to [conv1d_weight_backward](burn_backend::ops::ModuleOps::conv1d_weight_backward). Conv1dWeightBackward(Conv1dWeightBackwardOpIr), /// Operation corresponding to [conv1d_bias_backward](burn_backend::ops::ModuleOps::conv1d_bias_backward). Conv1dBiasBackward(Conv1dBiasBackwardOpIr), /// Operation corresponding to [conv2d](burn_backend::ops::ModuleOps::conv2d). Conv2d(Conv2dOpIr), /// Operation corresponding to [conv2d_x_backward](burn_backend::ops::ModuleOps::conv2d_x_backward). Conv2dXBackward(Conv2dXBackwardOpIr), /// Operation corresponding to [conv2d_weight_backward](burn_backend::ops::ModuleOps::conv2d_weight_backward). Conv2dWeightBackward(Conv2dWeightBackwardOpIr), /// Operation corresponding to [conv2d_bias_backward](burn_backend::ops::ModuleOps::conv2d_bias_backward). Conv2dBiasBackward(Conv2dBiasBackwardOpIr), /// Operation corresponding to [conv3d](burn_backend::ops::ModuleOps::conv3d). Conv3d(Conv3dOpIr), /// Operation corresponding to [conv3d_x_backward](burn_backend::ops::ModuleOps::conv3d_x_backward). Conv3dXBackward(Conv3dXBackwardOpIr), /// Operation corresponding to [conv3d_weight_backward](burn_backend::ops::ModuleOps::conv3d_weight_backward). Conv3dWeightBackward(Conv3dWeightBackwardOpIr), /// Operation corresponding to [conv3d_bias_backward](burn_backend::ops::ModuleOps::conv3d_bias_backward). Conv3dBiasBackward(Conv3dBiasBackwardOpIr), /// Operation corresponding to [deform_conv2d](burn_backend::ops::ModuleOps::deform_conv2d) DeformableConv2d(Box), /// Operation corresponding to [deform_conv2d_backward](burn_backend::ops::ModuleOps::deform_conv2d_backward) DeformableConv2dBackward(Box), /// Operation corresponding to [conv transpose 1d](burn_backend::ops::ModuleOps::conv_transpose1d). ConvTranspose1d(ConvTranspose1dOpIr), /// Operation corresponding to [conv transpose 2d](burn_backend::ops::ModuleOps::conv_transpose2d). ConvTranspose2d(ConvTranspose2dOpIr), /// Operation corresponding to [conv transpose 3d](burn_backend::ops::ModuleOps::conv_transpose3d). ConvTranspose3d(ConvTranspose3dOpIr), /// Operation corresponding to [avg pool 1d](burn_backend::ops::ModuleOps::avg_pool1d). AvgPool1d(AvgPool1dOpIr), /// Operation corresponding to [avg pool 2d](burn_backend::ops::ModuleOps::avg_pool2d). AvgPool2d(AvgPool2dOpIr), /// Operation corresponding to /// [avg pool 1d backward](burn_backend::ops::ModuleOps::avg_pool1d_backward). AvgPool1dBackward(AvgPool1dBackwardOpIr), /// Operation corresponding to /// [avg pool 2d backward](burn_backend::ops::ModuleOps::avg_pool2d_backward). AvgPool2dBackward(AvgPool2dBackwardOpIr), /// Operation corresponding to /// [adaptive avg pool 1d](burn_backend::ops::ModuleOps::adaptive_avg_pool1d). AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr), /// Operation corresponding to /// [adaptive avg pool 2d](burn_backend::ops::ModuleOps::adaptive_avg_pool2d). AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr), /// Operation corresponding to /// [adaptive avg pool 1d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool1d_backward). AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr), /// Operation corresponding to /// [adaptive avg pool 2d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool2d_backward). AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr), /// Operation corresponding to /// [max pool 1d](burn_backend::ops::ModuleOps::max_pool1d). MaxPool1d(MaxPool1dOpIr), /// Operation corresponding to /// [max pool 1d with indices](burn_backend::ops::ModuleOps::max_pool1d_with_indices). MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr), /// Operation corresponding to /// [max pool 1d with indices backward](burn_backend::ops::ModuleOps::max_pool1d_with_indices_backward). MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr), /// Operation corresponding to /// [max pool 2d](burn_backend::ops::ModuleOps::max_pool1d). MaxPool2d(MaxPool2dOpIr), /// Operation corresponding to /// [max pool 2d with indices](burn_backend::ops::ModuleOps::max_pool2d_with_indices). MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr), /// Operation corresponding to /// [max pool 2d with indices backward](burn_backend::ops::ModuleOps::max_pool2d_with_indices_backward). MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr), /// Operation corresponding to [interpolate](burn_backend::ops::ModuleOps::interpolate). Interpolate(InterpolateOpIr), /// Operation corresponding to [interpolate backward](burn_backend::ops::ModuleOps::interpolate_backward). InterpolateBackward(InterpolateBackwardOpIr), /// Operation corresponding to [attention](burn_backend::ops::ModuleOps::attention). Attention(AttentionOpIr), } /// Basic operations that can be done on any tensor type. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BaseOperationIr { /// Operation corresponding to: /// /// Float => [reshape](burn_backend::ops::FloatTensorOps::float_reshape). /// Int => [reshape](burn_backend::ops::IntTensorOps::int_reshape). /// Bool => [reshape](burn_backend::ops::BoolTensorOps::bool_reshape). Reshape(ShapeOpIr), /// Operation corresponding to: /// /// Float => [swap_dims](burn_backend::ops::FloatTensorOps::float_swap_dims). /// Int => [swap_dims](burn_backend::ops::IntTensorOps::int_swap_dims). /// Bool => [swap_dims](burn_backend::ops::BoolTensorOps::bool_swap_dims). SwapDims(SwapDimsOpIr), /// Operation corresponding to: /// /// Float => [permute](burn_backend::ops::FloatTensorOps::float_permute). /// Int => [permute](burn_backend::ops::IntTensorOps::int_permute). /// Bool => [permute](burn_backend::ops::BoolTensorOps::bool_permute). Permute(PermuteOpIr), /// Operation corresponding to: /// Float => [flip](burn_backend::ops::FloatTensorOps::float_flip). /// Int => [flip](burn_backend::ops::IntTensorOps::int_flip). /// Bool => [flip](burn_backend::ops::BoolTensorOps::bool_flip). Flip(FlipOpIr), /// Operation corresponding to: /// /// Float => [expand](burn_backend::ops::FloatTensorOps::float_expand). /// Int => [expand](burn_backend::ops::IntTensorOps::int_expand). /// Bool => [expand](burn_backend::ops::BoolTensorOps::bool_expand). Expand(ShapeOpIr), /// Unfold windows along an axis. /// Unfold(UnfoldOpIr), /// Operation corresponding to: /// /// Float => [slice](burn_backend::ops::FloatTensorOps::float_slice). /// Int => [slice](burn_backend::ops::IntTensorOps::int_slice). /// Bool => [slice](burn_backend::ops::BoolTensorOps::bool_slice). Slice(SliceOpIr), /// Operation corresponding to: /// /// Float => [slice assign](burn_backend::ops::FloatTensorOps::float_slice_assign). /// Int => [slice assign](burn_backend::ops::IntTensorOps::int_slice_assign). /// Bool => [slice assign](burn_backend::ops::BoolTensorOps::bool_slice_assign). SliceAssign(SliceAssignOpIr), /// Operation corresponding to: /// /// Float => [select](burn_backend::ops::FloatTensorOps::float_select). /// Int => [select](burn_backend::ops::IntTensorOps::int_select). /// Bool => [select](burn_backend::ops::BoolTensorOps::bool_select). Select(SelectOpIr), /// Operation corresponding to: /// /// Float => [select assign](burn_backend::ops::FloatTensorOps::float_select_add). /// Int => [select assign](burn_backend::ops::IntTensorOps::int_select_add). /// Bool => [select assign](burn_backend::ops::BoolTensorOps::bool_select_or). SelectAssign(SelectAssignOpIr), /// Operation corresponding to: /// /// Float => [mask where](burn_backend::ops::FloatTensorOps::float_mask_where). /// Int => [mask where](burn_backend::ops::IntTensorOps::int_mask_where). /// Bool => [mask where](burn_backend::ops::BoolTensorOps::bool_mask_where). MaskWhere(MaskWhereOpIr), /// Operation corresponding to: /// /// Float => [mask fill](burn_backend::ops::FloatTensorOps::float_mask_fill). /// Int => [mask fill](burn_backend::ops::IntTensorOps::int_mask_fill). /// Bool => [mask fill](burn_backend::ops::BoolTensorOps::bool_mask_fill). MaskFill(MaskFillOpIr), /// Operation corresponding to: /// /// Float => [gather](burn_backend::ops::FloatTensorOps::float_gather). /// Int => [gather](burn_backend::ops::IntTensorOps::int_gather). /// Bool => [gather](burn_backend::ops::BoolTensorOps::bool_gather). Gather(GatherOpIr), /// Operation corresponding to: /// /// Float => [scatter](burn_backend::ops::FloatTensorOps::float_scatter_add). /// Int => [scatter](burn_backend::ops::IntTensorOps::int_scatter_add). /// Bool => [scatter](burn_backend::ops::BoolTensorOps::bool_scatter_or). Scatter(ScatterOpIr), /// Operation corresponding to: /// /// Float => [equal](burn_backend::ops::FloatTensorOps::float_equal). /// Int => [equal](burn_backend::ops::IntTensorOps::int_equal). /// Bool => [equal](burn_backend::ops::BoolTensorOps::bool_equal). Equal(BinaryOpIr), /// Operation corresponding to: /// /// Float => [equal elem](burn_backend::ops::FloatTensorOps::float_equal_elem). /// Int => [equal elem](burn_backend::ops::IntTensorOps::int_equal_elem). /// Bool => [equal elem](burn_backend::ops::BoolTensorOps::bool_equal_elem). EqualElem(ScalarOpIr), /// Operation corresponding to: /// /// Float => [repeat dim](burn_backend::ops::FloatTensorOps::float_repeat_dim). /// Int => [repeat dim](burn_backend::ops::IntTensorOps::int_repeat_dim). /// Bool => [repeat dim](burn_backend::ops::BoolTensorOps::bool_repeat_dim). RepeatDim(RepeatDimOpIr), /// Operation corresponding to: /// /// Float => [cat](burn_backend::ops::FloatTensorOps::float_cat). /// Int => [cat](burn_backend::ops::IntTensorOps::int_cat). /// Bool => [cat](burn_backend::ops::BoolTensorOps::bool_cat). Cat(CatOpIr), /// Cast operation, no direct operation and should be supported by fusion backend. Cast(CastOpIr), /// Operation corresponding to: /// /// Float => [empty](burn_backend::ops::FloatTensorOps::float_empty). /// Int => [empty](burn_backend::ops::IntTensorOps::int_empty). /// Bool => [empty](burn_backend::ops::BoolTensorOps::bool_empty). Empty(CreationOpIr), /// Operation corresponding to: /// /// Float => [ones](burn_backend::ops::FloatTensorOps::float_ones). /// Int => [ones](burn_backend::ops::IntTensorOps::int_ones). /// Bool => [ones](burn_backend::ops::BoolTensorOps::bool_ones). Ones(CreationOpIr), /// Operation corresponding to: /// /// Float => [zeros](burn_backend::ops::FloatTensorOps::float_zeros). /// Int => [zeros](burn_backend::ops::IntTensorOps::int_zeros). /// Bool => [zeros](burn_backend::ops::BoolTensorOps::bool_zeros). Zeros(CreationOpIr), } /// Numeric operations on int and float tensors. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum NumericOperationIr { /// Operation corresponding to: /// /// Float => [add](burn_backend::ops::FloatTensorOps::float_add). /// Int => [add](burn_backend::ops::IntTensorOps::int_add). Add(BinaryOpIr), /// Operation corresponding to: /// /// Float => [add scalar](burn_backend::ops::FloatTensorOps::float_add_scalar). /// Int => [add scalar](burn_backend::ops::IntTensorOps::int_add_scalar). AddScalar(ScalarOpIr), /// Operation corresponding to: /// /// Float => [sub](burn_backend::ops::FloatTensorOps::float_sub). /// Int => [sub](burn_backend::ops::IntTensorOps::int_sub). Sub(BinaryOpIr), /// Operation corresponding to: /// /// Float => [sub scalar](burn_backend::ops::FloatTensorOps::float_sub_scalar). /// Int => [sub scalar](burn_backend::ops::IntTensorOps::int_sub_scalar). SubScalar(ScalarOpIr), /// Operation corresponding to: /// /// Float => [div](burn_backend::ops::FloatTensorOps::float_div). /// Int => [div](burn_backend::ops::IntTensorOps::int_div). Div(BinaryOpIr), /// Operation corresponding to: /// /// Float => [div scalar](burn_backend::ops::FloatTensorOps::float_div_scalar). /// Int => [div scalar](burn_backend::ops::IntTensorOps::int_div_scalar). DivScalar(ScalarOpIr), /// Operation corresponding to: /// /// Float => [rem](burn_backend::ops::FloatTensorOps::float_remainder). /// Int => [rem](burn_backend::ops::IntTensorOps::int_remainder). Rem(BinaryOpIr), /// Operation corresponding to: /// /// Float => [rem scalar](burn_backend::ops::FloatTensorOps::float_remainder_scalar). /// Int => [rem scalar](burn_backend::ops::IntTensorOps::int_remainder_scalar). RemScalar(ScalarOpIr), /// Operation corresponding to: /// /// Float => [mul](burn_backend::ops::FloatTensorOps::float_mul). /// Int => [mul](burn_backend::ops::IntTensorOps::int_mul). Mul(BinaryOpIr), /// Operation corresponding to: /// /// Float => [mul scalar](burn_backend::ops::FloatTensorOps::float_mul_scalar). /// Int => [mul scalar](burn_backend::ops::IntTensorOps::int_mul_scalar). MulScalar(ScalarOpIr), /// Operation corresponding to: /// /// Float => [abs](burn_backend::ops::FloatTensorOps::float_abs). /// Int => [abs](burn_backend::ops::IntTensorOps::int_abs). Abs(UnaryOpIr), /// Operation corresponding to: /// /// Float => [full](burn_backend::ops::FloatTensorOps::float_full). /// Int => [full](burn_backend::ops::IntTensorOps::int_full). Full(FullOpIr), /// Operation corresponding to: /// /// Float => [mean dim](burn_backend::ops::FloatTensorOps::float_mean_dim). /// Int => [mean dim](burn_backend::ops::IntTensorOps::int_mean_dim). MeanDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [mean](burn_backend::ops::FloatTensorOps::float_mean). /// Int => [mean](burn_backend::ops::IntTensorOps::int_mean). Mean(ReduceOpIr), /// Operation corresponding to: /// /// Float => [sum](burn_backend::ops::FloatTensorOps::float_sum). /// Int => [sum](burn_backend::ops::IntTensorOps::int_sum). Sum(ReduceOpIr), /// Operation corresponding to: /// /// Float => [sum dim](burn_backend::ops::FloatTensorOps::float_sum_dim). /// Int => [sum dim](burn_backend::ops::IntTensorOps::int_sum_dim). SumDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [prod](burn_backend::ops::FloatTensorOps::float_prod). /// Int => [prod](burn_backend::ops::IntTensorOps::int_prod). Prod(ReduceOpIr), /// Operation corresponding to: /// /// Float => [prod dim](burn_backend::ops::FloatTensorOps::float_prod_dim). /// Int => [prod dim](burn_backend::ops::IntTensorOps::int_prod_dim). ProdDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [greater](burn_backend::ops::FloatTensorOps::float_greater). /// Int => [greater](burn_backend::ops::IntTensorOps::int_greater). Greater(BinaryOpIr), /// Operation corresponding to: /// /// Float => [greater elem](burn_backend::ops::FloatTensorOps::float_greater_elem). /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem). GreaterElem(ScalarOpIr), /// Operation corresponding to: /// /// Float => [greater equal](burn_backend::ops::FloatTensorOps::float_greater_elem). /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem). GreaterEqual(BinaryOpIr), /// Operation corresponding to: /// /// Float => [greater equal elem](burn_backend::ops::FloatTensorOps::float_greater_equal_elem). /// Int => [greater equal elem](burn_backend::ops::IntTensorOps::int_greater_equal_elem). GreaterEqualElem(ScalarOpIr), /// Operation corresponding to: /// /// Float => [lower](burn_backend::ops::FloatTensorOps::float_lower). /// Int => [lower](burn_backend::ops::IntTensorOps::int_lower). Lower(BinaryOpIr), /// Operation corresponding to: /// /// Float => [lower elem](burn_backend::ops::FloatTensorOps::float_lower_elem). /// Int => [lower elem](burn_backend::ops::IntTensorOps::int_lower_elem). LowerElem(ScalarOpIr), /// Operation corresponding to: /// /// Float => [lower equal](burn_backend::ops::FloatTensorOps::float_lower_equal). /// Int => [lower equal](burn_backend::ops::IntTensorOps::int_lower_equal). LowerEqual(BinaryOpIr), /// Operation corresponding to: /// /// Float => [lower equal elem](burn_backend::ops::FloatTensorOps::float_lower_equal_elem). /// Int => [lower equal elem](burn_backend::ops::IntTensorOps::int_lower_equal_elem). LowerEqualElem(ScalarOpIr), /// Operation corresponding to: /// /// Float => [argmax](burn_backend::ops::FloatTensorOps::float_argmax). /// Int => [argmax](burn_backend::ops::IntTensorOps::int_argmax). ArgMax(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [argmin](burn_backend::ops::FloatTensorOps::float_argmin). /// Int => [argmin](burn_backend::ops::IntTensorOps::int_argmin). ArgMin(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [max](burn_backend::ops::FloatTensorOps::float_max). /// Int => [max](burn_backend::ops::IntTensorOps::int_max). Max(ReduceOpIr), /// Operation corresponding to: /// /// Float => [max dim with indices](burn_backend::ops::FloatTensorOps::float_max_dim_with_indices). /// Int => [max dim with indices](burn_backend::ops::IntTensorOps::int_max_dim_with_indices). MaxDimWithIndices(ReduceDimWithIndicesOpIr), /// Operation corresponding to: /// /// Float => [min dim with indices](burn_backend::ops::FloatTensorOps::float_min_dim_with_indices). /// Int => [min dim with indices](burn_backend::ops::IntTensorOps::int_min_dim_with_indices). MinDimWithIndices(ReduceDimWithIndicesOpIr), /// Operation corresponding to: /// /// Float => [min](burn_backend::ops::FloatTensorOps::float_min). /// Int => [min](burn_backend::ops::IntTensorOps::int_min). Min(ReduceOpIr), /// Operation corresponding to: /// /// Float => [max dim](burn_backend::ops::FloatTensorOps::float_max_dim). /// Int => [max dim](burn_backend::ops::IntTensorOps::int_max_dim). MaxDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [min dim](burn_backend::ops::FloatTensorOps::float_min_dim). /// Int => [min dim](burn_backend::ops::IntTensorOps::int_min_dim). MinDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [max_abs](burn_backend::ops::FloatTensorOps::float_max_abs). /// Int => [max_abs](burn_backend::ops::IntTensorOps::int_max_abs). MaxAbs(ReduceOpIr), /// Operation corresponding to: /// /// Float => [max_abs dim](burn_backend::ops::FloatTensorOps::float_max_abs_dim). /// Int => [max_abs dim](burn_backend::ops::IntTensorOps::int_max_abs_dim). MaxAbsDim(ReduceDimOpIr), /// Operation corresponding to: /// /// Float => [clamp](burn_backend::ops::FloatTensorOps::float_clamp). /// Int => [clamp](burn_backend::ops::IntTensorOps::int_clamp). Clamp(ClampOpIr), /// Operation corresponding to: /// /// Int => [random](burn_backend::ops::IntTensorOps::int_random). IntRandom(RandomOpIr), /// Operation corresponding to: /// /// Float => [powf](burn_backend::ops::FloatTensorOps::float_powi). /// Int => [powf](burn_backend::ops::IntTensorOps::int_powi). Powi(BinaryOpIr), /// Operation corresponding to: /// /// Float => [cumsum](burn_backend::ops::FloatTensorOps::float_cumsum). /// Int => [cumsum](burn_backend::ops::IntTensorOps::int_cumsum). CumSum(DimOpIr), /// Operation corresponding to: /// /// Float => [cumprod](burn_backend::ops::FloatTensorOps::float_cumprod). /// Int => [cumprod](burn_backend::ops::IntTensorOps::int_cumprod). CumProd(DimOpIr), /// Operation corresponding to: /// /// Float => [cummin](burn_backend::ops::FloatTensorOps::float_cummin). /// Int => [cummin](burn_backend::ops::IntTensorOps::int_cummin). CumMin(DimOpIr), /// Operation corresponding to: /// /// Float => [cummax](burn_backend::ops::FloatTensorOps::float_cummax). /// Int => [cummax](burn_backend::ops::IntTensorOps::int_cummax). CumMax(DimOpIr), } /// Operation intermediate representation specific to an int tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum IntOperationIr { /// Operation corresponding to [into float](burn_backend::ops::IntTensorOps::int_into_float). IntoFloat(CastOpIr), /// Operation corresponding to: /// /// Int => [bitwise and](burn_backend::ops::IntTensorOps::bitwise_and). BitwiseAnd(BinaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise and scalar](burn_backend::ops::IntTensorOps::bitwise_and_scalar). BitwiseAndScalar(ScalarOpIr), /// Operation corresponding to: /// /// Int => [bitwise or](burn_backend::ops::IntTensorOps::bitwise_or). BitwiseOr(BinaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise or scalar](burn_backend::ops::IntTensorOps::bitwise_or_scalar). BitwiseOrScalar(ScalarOpIr), /// Operation corresponding to: /// /// Int => [bitwise xor](burn_backend::ops::IntTensorOps::bitwise_xor). BitwiseXor(BinaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise xor scalar](burn_backend::ops::IntTensorOps::bitwise_xor_scalar). BitwiseXorScalar(ScalarOpIr), /// Operation corresponding to: /// /// Int => [bitwise not](burn_backend::ops::IntTensorOps::bitwise_not). BitwiseNot(UnaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise left shift](burn_backend::ops::IntTensorOps::bitwise_left_shift). BitwiseLeftShift(BinaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise left shift scalar](burn_backend::ops::IntTensorOps::bitwise_left_shift_scalar). BitwiseLeftShiftScalar(ScalarOpIr), /// Operation corresponding to: /// /// Int => [bitwise right shift](burn_backend::ops::IntTensorOps::bitwise_right_shift). BitwiseRightShift(BinaryOpIr), /// Operation corresponding to: /// /// Int => [bitwise right shift scalar](burn_backend::ops::IntTensorOps::bitwise_right_shift_scalar). BitwiseRightShiftScalar(ScalarOpIr), /// Operation corresponding to [matmul](burn_backend::ops::IntTensorOps::int_matmul). Matmul(MatmulOpIr), } /// Operation intermediate representation specific to a bool tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BoolOperationIr { /// Operation corresponding to [into float](burn_backend::ops::BoolTensorOps::bool_into_float). IntoFloat(CastOpIr), /// Operation corresponding to [into int](burn_backend::ops::BoolTensorOps::bool_into_int). IntoInt(CastOpIr), /// Operation corresponding to [not](burn_backend::ops::BoolTensorOps::bool_not). Not(UnaryOpIr), /// Operation corresponding to [and](burn_backend::ops::BoolTensorOps::bool_and). And(BinaryOpIr), /// Operation corresponding to [or](burn_backend::ops::BoolTensorOps::bool_or). Or(BinaryOpIr), } /// Swap dim operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct SwapDimsOpIr { /// Input tensor intermediate representation. pub input: TensorIr, /// Output tensor intermediate representation. pub out: TensorIr, /// The first dim to swap. pub dim1: usize, /// The second dim to swap. pub dim2: usize, } /// Permute operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct PermuteOpIr { /// Input tensor intermediate representation. pub input: TensorIr, /// Output tensor intermediate representation. pub out: TensorIr, /// The new order of the dimensions. pub axes: Vec, } /// Shape operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct ShapeOpIr { /// Input tensor intermediate representation. pub input: TensorIr, /// Output tensor intermediate representation with the new shape. pub out: TensorIr, } /// Unfold operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct UnfoldOpIr { /// Input tensor intermediate representation. pub input: TensorIr, /// Output tensor intermediate representation. pub out: TensorIr, /// The selected dim. pub dim: usize, /// The window size. pub size: usize, /// The window step along dim. pub step: usize, } /// Flip operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct FlipOpIr { /// Input tensor intermediate representation. pub input: TensorIr, /// Output tensor intermediate representation. pub out: TensorIr, /// The dimensions to flip. pub axes: Vec, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct RandomOpIr { pub out: TensorIr, pub distribution: Distribution, } /// Creation operation intermediate representation. /// As opposed to [InitOperationIr], creation operations are lazy initialized. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct CreationOpIr { /// Output tensor intermediate representation. pub out: TensorIr, } /// Full operation intermediate representation. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct FullOpIr { /// Output tensor intermediate representation. pub out: TensorIr, /// Fill value. pub value: ScalarIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] /// Declares a tensor has been initialized. /// /// It is necessary to register for proper orphan detection and avoid memory leak. pub struct InitOperationIr { /// The initialized tensor. pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct BinaryOpIr { pub lhs: TensorIr, pub rhs: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MatmulOpIr { pub lhs: TensorIr, pub rhs: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct CrossOpIr { pub lhs: TensorIr, pub rhs: TensorIr, pub out: TensorIr, pub dim: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct UnaryOpIr { pub input: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ScalarOpIr { pub lhs: TensorIr, // TODO: Make that an enum with `Value` and `Id` variants for relative/global // conversion. pub rhs: ScalarIr, pub out: TensorIr, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub struct ReduceOpIr { pub input: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub struct ReduceDimOpIr { pub input: TensorIr, pub out: TensorIr, pub axis: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct CastOpIr { pub input: TensorIr, pub out: TensorIr, } /// IR for operations that operate along a dimension without reducing it. /// Unlike `ReduceDimOpIr`, the output shape is the same as the input shape. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub struct DimOpIr { pub input: TensorIr, pub out: TensorIr, pub axis: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct GatherOpIr { pub tensor: TensorIr, pub dim: usize, pub indices: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ScatterOpIr { pub tensor: TensorIr, pub dim: usize, pub indices: TensorIr, pub value: TensorIr, pub update: IndexingUpdateOp, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct SelectOpIr { pub tensor: TensorIr, pub dim: usize, pub indices: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct SelectAssignOpIr { pub tensor: TensorIr, pub dim: usize, pub indices: TensorIr, pub value: TensorIr, pub update: IndexingUpdateOp, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct SliceOpIr { pub tensor: TensorIr, pub ranges: Vec, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct SliceAssignOpIr { pub tensor: TensorIr, pub ranges: Vec, pub value: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaskWhereOpIr { pub tensor: TensorIr, pub mask: TensorIr, pub value: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaskFillOpIr { pub tensor: TensorIr, pub mask: TensorIr, pub value: ScalarIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ClampOpIr { pub tensor: TensorIr, pub min: ScalarIr, pub max: ScalarIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct RepeatDimOpIr { pub tensor: TensorIr, pub dim: usize, pub times: usize, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct CatOpIr { pub tensors: Vec, pub dim: usize, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ReduceDimWithIndicesOpIr { pub tensor: TensorIr, pub dim: usize, pub out: TensorIr, pub out_indices: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct EmbeddingOpIr { pub weights: TensorIr, pub indices: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct EmbeddingBackwardOpIr { pub weights: TensorIr, pub out_grad: TensorIr, pub indices: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv1dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: Conv1dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv1dXBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv1dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv1dWeightBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv1dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv1dBiasBackwardOpIr { pub x: TensorIr, pub bias: TensorIr, pub output_grad: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv2dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: Conv2dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv2dXBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv2dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv2dWeightBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv2dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv2dBiasBackwardOpIr { pub x: TensorIr, pub bias: TensorIr, pub output_grad: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DeformConv2dOpIr { pub x: TensorIr, pub offset: TensorIr, pub weight: TensorIr, pub mask: Option, pub bias: Option, pub options: DeformableConv2dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DeformConv2dBackwardOpIr { pub x: TensorIr, pub offset: TensorIr, pub weight: TensorIr, pub mask: Option, pub bias: Option, pub out_grad: TensorIr, pub options: DeformableConv2dOptionsIr, pub input_grad: TensorIr, pub offset_grad: TensorIr, pub weight_grad: TensorIr, pub mask_grad: Option, pub bias_grad: Option, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv3dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: Conv3dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv3dXBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv3dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv3dWeightBackwardOpIr { pub x: TensorIr, pub weight: TensorIr, pub output_grad: TensorIr, pub options: Conv3dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv3dBiasBackwardOpIr { pub x: TensorIr, pub bias: TensorIr, pub output_grad: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose1dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: ConvTranspose1dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose2dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: ConvTranspose2dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose3dOpIr { pub x: TensorIr, pub weight: TensorIr, pub bias: Option, pub options: ConvTranspose3dOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv1dOptionsIr { pub stride: [usize; 1], pub padding: [usize; 1], pub dilation: [usize; 1], pub groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv2dOptionsIr { pub stride: [usize; 2], pub padding: [usize; 2], pub dilation: [usize; 2], pub groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DeformableConv2dOptionsIr { pub stride: [usize; 2], pub padding: [usize; 2], pub dilation: [usize; 2], pub weight_groups: usize, pub offset_groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct Conv3dOptionsIr { pub stride: [usize; 3], pub padding: [usize; 3], pub dilation: [usize; 3], pub groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose1dOptionsIr { pub stride: [usize; 1], pub padding: [usize; 1], pub padding_out: [usize; 1], pub dilation: [usize; 1], pub groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose2dOptionsIr { pub stride: [usize; 2], pub padding: [usize; 2], pub padding_out: [usize; 2], pub dilation: [usize; 2], pub groups: usize, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ConvTranspose3dOptionsIr { pub stride: [usize; 3], pub padding: [usize; 3], pub padding_out: [usize; 3], pub dilation: [usize; 3], pub groups: usize, } /// Quantization parameters intermediate representation. #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct QuantizationParametersIr { /// The scaling factor. pub scales: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct QuantizeOpIr { pub tensor: TensorIr, pub qparams: QuantizationParametersIr, pub scheme: QuantScheme, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct DequantizeOpIr { pub input: TensorIr, pub out: TensorIr, } impl From> for Conv1dOptionsIr { fn from(value: ConvOptions<1>) -> Self { Self { stride: value.stride, padding: value.padding, dilation: value.dilation, groups: value.groups, } } } impl From> for Conv2dOptionsIr { fn from(value: ConvOptions<2>) -> Self { Self { stride: value.stride, padding: value.padding, dilation: value.dilation, groups: value.groups, } } } impl From> for Conv3dOptionsIr { fn from(value: ConvOptions<3>) -> Self { Self { stride: value.stride, padding: value.padding, dilation: value.dilation, groups: value.groups, } } } impl From> for DeformableConv2dOptionsIr { fn from(value: DeformConvOptions<2>) -> Self { Self { stride: value.stride, padding: value.padding, dilation: value.dilation, weight_groups: value.weight_groups, offset_groups: value.offset_groups, } } } impl From> for ConvTranspose1dOptionsIr { fn from(value: ConvTransposeOptions<1>) -> Self { Self { stride: value.stride, padding: value.padding, padding_out: value.padding_out, dilation: value.dilation, groups: value.groups, } } } impl From> for ConvTranspose2dOptionsIr { fn from(value: ConvTransposeOptions<2>) -> Self { Self { stride: value.stride, padding: value.padding, padding_out: value.padding_out, dilation: value.dilation, groups: value.groups, } } } impl From> for ConvTranspose3dOptionsIr { fn from(value: ConvTransposeOptions<3>) -> Self { Self { stride: value.stride, padding: value.padding, padding_out: value.padding_out, dilation: value.dilation, groups: value.groups, } } } impl From for ConvOptions<1> { fn from(val: Conv1dOptionsIr) -> Self { ConvOptions { stride: val.stride, padding: val.padding, dilation: val.dilation, groups: val.groups, } } } impl From for ConvOptions<2> { fn from(val: Conv2dOptionsIr) -> Self { ConvOptions { stride: val.stride, padding: val.padding, dilation: val.dilation, groups: val.groups, } } } impl From for ConvOptions<3> { fn from(val: Conv3dOptionsIr) -> Self { ConvOptions { stride: val.stride, padding: val.padding, dilation: val.dilation, groups: val.groups, } } } impl From for DeformConvOptions<2> { fn from(value: DeformableConv2dOptionsIr) -> Self { DeformConvOptions { stride: value.stride, padding: value.padding, dilation: value.dilation, weight_groups: value.weight_groups, offset_groups: value.offset_groups, } } } impl From for ConvTransposeOptions<1> { fn from(val: ConvTranspose1dOptionsIr) -> Self { ConvTransposeOptions { stride: val.stride, padding: val.padding, padding_out: val.padding_out, dilation: val.dilation, groups: val.groups, } } } impl From for ConvTransposeOptions<2> { fn from(val: ConvTranspose2dOptionsIr) -> Self { ConvTransposeOptions { stride: val.stride, padding: val.padding, padding_out: val.padding_out, dilation: val.dilation, groups: val.groups, } } } impl From for ConvTransposeOptions<3> { fn from(val: ConvTranspose3dOptionsIr) -> Self { ConvTransposeOptions { stride: val.stride, padding: val.padding, padding_out: val.padding_out, dilation: val.dilation, groups: val.groups, } } } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AvgPool1dOpIr { pub x: TensorIr, pub kernel_size: usize, pub stride: usize, pub padding: usize, pub count_include_pad: bool, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AvgPool2dOpIr { pub x: TensorIr, pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub count_include_pad: bool, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AvgPool1dBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub kernel_size: usize, pub stride: usize, pub padding: usize, pub count_include_pad: bool, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AvgPool2dBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub count_include_pad: bool, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dOpIr { pub x: TensorIr, pub output_size: usize, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dOpIr { pub x: TensorIr, pub output_size: [usize; 2], pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AdaptiveAvgPool1dBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AdaptiveAvgPool2dBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaxPool1dOpIr { pub x: TensorIr, pub kernel_size: usize, pub stride: usize, pub padding: usize, pub dilation: usize, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesOpIr { pub x: TensorIr, pub kernel_size: usize, pub stride: usize, pub padding: usize, pub dilation: usize, pub ceil_mode: bool, pub out: TensorIr, pub out_indices: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaxPool1dWithIndicesBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub indices: TensorIr, pub kernel_size: usize, pub stride: usize, pub padding: usize, pub dilation: usize, pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaxPool2dOpIr { pub x: TensorIr, pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub dilation: [usize; 2], pub ceil_mode: bool, pub out: TensorIr, } #[allow(missing_docs)] #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub struct MaxPool2dWithIndicesOpIr { pub x: TensorIr, pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub dilation: [usize; 2], pub ceil_mode: bool, pub out: TensorIr, pub out_indices: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct MaxPool2dWithIndicesBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub indices: TensorIr, pub kernel_size: [usize; 2], pub stride: [usize; 2], pub padding: [usize; 2], pub dilation: [usize; 2], pub ceil_mode: bool, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub enum InterpolateModeIr { Nearest, Bilinear, Bicubic, Lanczos3, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct InterpolateOptionsIr { pub mode: InterpolateModeIr, pub align_corners: bool, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct InterpolateOpIr { pub x: TensorIr, pub output_size: [usize; 2], pub options: InterpolateOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AttentionOptionsIr { pub scale: Option, pub softcap: Option, pub is_causal: bool, } impl From for AttentionModuleOptions { fn from(ir: AttentionOptionsIr) -> Self { AttentionModuleOptions { scale: ir.scale.map(|s| s.elem()), softcap: ir.softcap.map(|s| s.elem()), is_causal: ir.is_causal, } } } impl From for AttentionOptionsIr { fn from(ir: AttentionModuleOptions) -> Self { AttentionOptionsIr { scale: ir.scale.map(ScalarIr::Float), softcap: ir.softcap.map(ScalarIr::Float), is_causal: ir.is_causal, } } } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct AttentionOpIr { pub query: TensorIr, pub key: TensorIr, pub value: TensorIr, pub mask: Option, pub attn_bias: Option, pub options: AttentionOptionsIr, pub out: TensorIr, } impl From for InterpolateMode { fn from(val: InterpolateModeIr) -> Self { match val { InterpolateModeIr::Nearest => Self::Nearest, InterpolateModeIr::Bilinear => Self::Bilinear, InterpolateModeIr::Bicubic => Self::Bicubic, InterpolateModeIr::Lanczos3 => Self::Lanczos3, } } } impl From for InterpolateOptions { fn from(val: InterpolateOptionsIr) -> Self { Self::new(val.mode.into()).with_align_corners(val.align_corners) } } impl From for InterpolateModeIr { fn from(val: InterpolateMode) -> Self { match val { InterpolateMode::Nearest => Self::Nearest, InterpolateMode::Bilinear => Self::Bilinear, InterpolateMode::Bicubic => Self::Bicubic, InterpolateMode::Lanczos3 => Self::Lanczos3, } } } impl From for InterpolateOptionsIr { fn from(val: InterpolateOptions) -> Self { Self { mode: val.mode.into(), align_corners: val.align_corners, } } } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct InterpolateBackwardOpIr { pub x: TensorIr, pub grad: TensorIr, pub output_size: [usize; 2], pub options: InterpolateOptionsIr, pub out: TensorIr, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub enum GridSamplePaddingModeIr { Zeros, Border, Reflection, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct GridSampleOptionsIr { pub mode: InterpolateModeIr, pub padding_mode: GridSamplePaddingModeIr, pub align_corners: bool, } #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct GridSample2dOpIr { pub tensor: TensorIr, pub grid: TensorIr, pub options: GridSampleOptionsIr, pub out: TensorIr, } impl From for GridSamplePaddingMode { fn from(val: GridSamplePaddingModeIr) -> Self { match val { GridSamplePaddingModeIr::Zeros => Self::Zeros, GridSamplePaddingModeIr::Border => Self::Border, GridSamplePaddingModeIr::Reflection => Self::Reflection, } } } impl From for GridSamplePaddingModeIr { fn from(val: GridSamplePaddingMode) -> Self { match val { GridSamplePaddingMode::Zeros => Self::Zeros, GridSamplePaddingMode::Border => Self::Border, GridSamplePaddingMode::Reflection => Self::Reflection, } } } impl From for GridSampleOptions { fn from(val: GridSampleOptionsIr) -> Self { Self { mode: val.mode.into(), padding_mode: val.padding_mode.into(), align_corners: val.align_corners, } } } impl From for GridSampleOptionsIr { fn from(val: GridSampleOptions) -> Self { Self { mode: val.mode.into(), padding_mode: val.padding_mode.into(), align_corners: val.align_corners, } } } impl OperationIr { /// Get all input [tensors](TensorIr) involved with the current operation. pub fn inputs(&self) -> impl Iterator { match self { OperationIr::BaseFloat(repr) => repr.inputs(), OperationIr::BaseInt(repr) => repr.inputs(), OperationIr::BaseBool(repr) => repr.inputs(), OperationIr::NumericFloat(_dtype, repr) => repr.inputs(), OperationIr::NumericInt(_dtype, repr) => repr.inputs(), OperationIr::Bool(repr) => repr.inputs(), OperationIr::Int(repr) => repr.inputs(), OperationIr::Float(_dtype, repr) => repr.inputs(), OperationIr::Module(repr) => repr.inputs(), OperationIr::Init(repr) => repr.inputs(), OperationIr::Custom(repr) => repr.inputs(), OperationIr::Drop(repr) => Box::new([repr].into_iter()), } } /// Get all output [tensors](TensorIr) involved with the current operation. pub fn outputs(&self) -> impl Iterator { match self { OperationIr::BaseFloat(repr) => repr.outputs(), OperationIr::BaseInt(repr) => repr.outputs(), OperationIr::BaseBool(repr) => repr.outputs(), OperationIr::NumericFloat(_dtype, repr) => repr.outputs(), OperationIr::NumericInt(_dtype, repr) => repr.outputs(), OperationIr::Bool(repr) => repr.outputs(), OperationIr::Int(repr) => repr.outputs(), OperationIr::Float(_dtype, repr) => repr.outputs(), OperationIr::Module(repr) => repr.outputs(), OperationIr::Init(repr) => repr.outputs(), OperationIr::Custom(repr) => repr.outputs(), OperationIr::Drop(_repr) => Box::new([].into_iter()), } } /// Get all [tensor](TensorIr) involved with the current operation. pub fn nodes(&self) -> Vec<&TensorIr> { self.inputs().chain(self.outputs()).collect() } /// Set the given nodes that are [read write](super::TensorStatus::ReadWrite) to /// [read only](super::TensorStatus::ReadOnly) in the current operation. /// /// Returns the tensor that were updated with their original representation. pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { match self { OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes), OperationIr::BaseInt(repr) => repr.mark_read_only(nodes), OperationIr::BaseBool(repr) => repr.mark_read_only(nodes), OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes), OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes), OperationIr::Bool(repr) => repr.mark_read_only(nodes), OperationIr::Int(repr) => repr.mark_read_only(nodes), OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes), OperationIr::Module(repr) => repr.mark_read_only(nodes), OperationIr::Init(_) => Vec::new(), OperationIr::Drop(repr) => { let mut output = Vec::new(); repr.mark_read_only(nodes, &mut output); output } OperationIr::Custom(repr) => { let mut output = Vec::new(); for input in repr.inputs.iter_mut() { input.mark_read_only(nodes, &mut output); } output } } } } impl BaseOperationIr { fn inputs(&self) -> Box + '_> { match self { BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()), BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()), BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()), BaseOperationIr::Scatter(repr) => { Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter()) } BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()), BaseOperationIr::SelectAssign(repr) => { Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter()) } BaseOperationIr::MaskWhere(repr) => { Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter()) } BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()), BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()), BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()), BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()), BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()), BaseOperationIr::Empty(_repr) => Box::new([].into_iter()), BaseOperationIr::Ones(_repr) => Box::new([].into_iter()), BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()), } } fn outputs(&self) -> Box + '_> { match self { BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()), BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { BaseOperationIr::Reshape(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::SwapDims(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Permute(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Expand(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Flip(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Slice(repr) => { repr.tensor.mark_read_only(nodes, &mut output); } BaseOperationIr::SliceAssign(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.value.mark_read_only(nodes, &mut output); } BaseOperationIr::Gather(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); } BaseOperationIr::Scatter(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); repr.value.mark_read_only(nodes, &mut output); } BaseOperationIr::Select(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); } BaseOperationIr::SelectAssign(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); repr.value.mark_read_only(nodes, &mut output); } BaseOperationIr::MaskWhere(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.mask.mark_read_only(nodes, &mut output); repr.value.mark_read_only(nodes, &mut output); } BaseOperationIr::MaskFill(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.mask.mark_read_only(nodes, &mut output); } BaseOperationIr::Equal(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } BaseOperationIr::EqualElem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } BaseOperationIr::RepeatDim(repr) => { repr.tensor.mark_read_only(nodes, &mut output); } BaseOperationIr::Cat(repr) => { for t in repr.tensors.iter_mut() { t.mark_read_only(nodes, &mut output); } } BaseOperationIr::Cast(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Unfold(repr) => { repr.input.mark_read_only(nodes, &mut output); } BaseOperationIr::Empty(_) => {} BaseOperationIr::Zeros(_) => {} BaseOperationIr::Ones(_) => {} }; output } } impl NumericOperationIr { fn inputs(&self) -> Box + '_> { match self { NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()), NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()), NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Full(_repr) => Box::new([].into_iter()), NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()), NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()), NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()), NumericOperationIr::Powi(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), NumericOperationIr::CumMin(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::CumMax(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::CumProd(repr) => Box::new([&repr.input].into_iter()), NumericOperationIr::CumSum(repr) => Box::new([&repr.input].into_iter()), } } fn outputs(&self) -> Box + '_> { match self { NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MaxDimWithIndices(repr) => { Box::new([&repr.out, &repr.out_indices].into_iter()) } NumericOperationIr::MinDimWithIndices(repr) => { Box::new([&repr.out, &repr.out_indices].into_iter()) } NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::Powi(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()), NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { NumericOperationIr::Add(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::AddScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Sub(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::SubScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Mul(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::MulScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Div(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::DivScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Rem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::RemScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::GreaterElem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::GreaterEqualElem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::LowerElem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::LowerEqualElem(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Greater(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::GreaterEqual(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::Lower(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::LowerEqual(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::ArgMax(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::ArgMin(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Clamp(repr) => { repr.tensor.mark_read_only(nodes, &mut output); } NumericOperationIr::Abs(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Full(_) => {} NumericOperationIr::MeanDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Mean(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Sum(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::SumDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Prod(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::ProdDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::Max(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::MaxDimWithIndices(repr) => { repr.tensor.mark_read_only(nodes, &mut output); } NumericOperationIr::MinDimWithIndices(repr) => { repr.tensor.mark_read_only(nodes, &mut output); } NumericOperationIr::Min(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::MaxDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::MinDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::MaxAbs(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::MaxAbsDim(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::IntRandom(_) => {} NumericOperationIr::Powi(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } NumericOperationIr::CumSum(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::CumProd(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::CumMin(repr) => { repr.input.mark_read_only(nodes, &mut output); } NumericOperationIr::CumMax(repr) => { repr.input.mark_read_only(nodes, &mut output); } }; output } } impl FloatOperationIr { fn inputs(&self) -> Box + '_> { match self { FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), FloatOperationIr::Random(_repr) => Box::new([].into_iter()), FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()), FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Quantize(repr) => { Box::new([&repr.tensor, &repr.qparams.scales].into_iter()) } FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::GridSample2d(repr) => { Box::new([&repr.tensor, &repr.grid].into_iter()) } FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()), FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), FloatOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), } } fn outputs(&self) -> Box + '_> { match self { FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()), FloatOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { FloatOperationIr::Matmul(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } FloatOperationIr::Cross(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } FloatOperationIr::Random(_) => {} FloatOperationIr::Exp(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Log(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Log1p(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Erf(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Recip(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::PowfScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } FloatOperationIr::Sqrt(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Cos(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Sin(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Tanh(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Round(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Floor(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Ceil(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Trunc(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::Quantize(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.qparams.scales.mark_read_only(nodes, &mut output); } FloatOperationIr::Dequantize(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::IntoInt(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::IsNan(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::IsInf(repr) => { repr.input.mark_read_only(nodes, &mut output); } FloatOperationIr::GridSample2d(repr) => { repr.tensor.mark_read_only(nodes, &mut output); repr.grid.mark_read_only(nodes, &mut output); } FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output), FloatOperationIr::ArcTan2(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } FloatOperationIr::Powf(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } }; output } } impl IntOperationIr { fn inputs(&self) -> Box + '_> { match self { IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()), IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()), IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()), IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()), IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()), IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()), IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()), } } fn outputs(&self) -> Box + '_> { match self { IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()), IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { IntOperationIr::Matmul(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::IntoFloat(repr) => { repr.input.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseAnd(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseAndScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseOr(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseOrScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseXor(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseXorScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseNot(repr) => { repr.input.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseLeftShift(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseLeftShiftScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseRightShift(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } IntOperationIr::BitwiseRightShiftScalar(repr) => { repr.lhs.mark_read_only(nodes, &mut output); } }; output } } impl BoolOperationIr { fn inputs(&self) -> Box + '_> { match self { BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()), BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()), BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()), BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()), } } fn outputs(&self) -> Box + '_> { match self { BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()), BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()), BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()), BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()), BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { BoolOperationIr::IntoFloat(repr) => { repr.input.mark_read_only(nodes, &mut output); } BoolOperationIr::IntoInt(repr) => { repr.input.mark_read_only(nodes, &mut output); } BoolOperationIr::Not(repr) => { repr.input.mark_read_only(nodes, &mut output); } BoolOperationIr::And(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } BoolOperationIr::Or(repr) => { repr.lhs.mark_read_only(nodes, &mut output); repr.rhs.mark_read_only(nodes, &mut output); } }; output } } impl ModuleOperationIr { fn inputs(&self) -> Box + '_> { match self { ModuleOperationIr::Embedding(repr) => { Box::new([&repr.weights, &repr.indices].into_iter()) } ModuleOperationIr::EmbeddingBackward(repr) => { Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter()) } ModuleOperationIr::Conv1d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::Conv1dXBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv1dWeightBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv1dBiasBackward(repr) => { Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv2d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::Conv2dXBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv2dWeightBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv2dBiasBackward(repr) => { Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv3d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::Conv3dXBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv3dWeightBackward(repr) => { Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter()) } ModuleOperationIr::Conv3dBiasBackward(repr) => { Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter()) } ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) { (Some(mask), Some(bias)) => { Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter()) } (Some(mask), None) => { Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter()) } (None, Some(bias)) => { Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter()) } (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()), }, ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) { (Some(mask), Some(bias)) => Box::new( [ &repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask, bias, ] .into_iter(), ), (Some(mask), None) => Box::new( [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(), ), (None, Some(bias)) => Box::new( [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(), ), (None, None) => { Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter()) } }, ModuleOperationIr::ConvTranspose1d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::ConvTranspose2d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::ConvTranspose3d(repr) => { if let Some(bias) = &repr.bias { Box::new([&repr.x, &repr.weight, bias].into_iter()) } else { Box::new([&repr.x, &repr.weight].into_iter()) } } ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::AvgPool1dBackward(repr) => { Box::new([&repr.x, &repr.grad].into_iter()) } ModuleOperationIr::AvgPool2dBackward(repr) => { Box::new([&repr.x, &repr.grad].into_iter()) } ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => { Box::new([&repr.x, &repr.grad].into_iter()) } ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => { Box::new([&repr.x, &repr.grad].into_iter()) } ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => { Box::new([&repr.x, &repr.indices, &repr.grad].into_iter()) } ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => { Box::new([&repr.x, &repr.indices, &repr.grad].into_iter()) } ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()), ModuleOperationIr::InterpolateBackward(repr) => { Box::new([&repr.x, &repr.grad].into_iter()) } ModuleOperationIr::Attention(repr) => { if let Some(mask) = &repr.mask { if let Some(attn_bias) = &repr.attn_bias { Box::new([&repr.query, &repr.key, &repr.value, mask, attn_bias].into_iter()) } else { Box::new([&repr.query, &repr.key, &repr.value, mask].into_iter()) } } else if let Some(attn_bias) = &repr.attn_bias { Box::new([&repr.query, &repr.key, &repr.value, attn_bias].into_iter()) } else { Box::new([&repr.query, &repr.key, &repr.value].into_iter()) } } } } fn outputs(&self) -> Box + '_> { match self { ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv1dXBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv1dWeightBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv1dBiasBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv2dXBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv2dWeightBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv2dBiasBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv3dXBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv3dWeightBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Conv3dBiasBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::DeformableConv2dBackward(repr) => { match (&repr.mask_grad, &repr.bias_grad) { (Some(mask_grad), Some(bias_grad)) => Box::new( [ &repr.input_grad, &repr.offset_grad, &repr.weight_grad, mask_grad, bias_grad, ] .into_iter(), ), (Some(mask_grad), None) => Box::new( [ &repr.input_grad, &repr.offset_grad, &repr.weight_grad, mask_grad, ] .into_iter(), ), (None, Some(bias_grad)) => Box::new( [ &repr.input_grad, &repr.offset_grad, &repr.weight_grad, bias_grad, ] .into_iter(), ), (None, None) => Box::new( [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(), ), } } ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::MaxPool1dWithIndices(repr) => { Box::new([&repr.out, &repr.out_indices].into_iter()) } ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => { Box::new([&repr.out].into_iter()) } ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::MaxPool2dWithIndices(repr) => { Box::new([&repr.out, &repr.out_indices].into_iter()) } ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => { Box::new([&repr.out].into_iter()) } ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()), ModuleOperationIr::Attention(repr) => Box::new([&repr.out].into_iter()), } } fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec { let mut output = Vec::new(); match self { ModuleOperationIr::Embedding(repr) => { repr.weights.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); } ModuleOperationIr::EmbeddingBackward(repr) => { repr.weights.mark_read_only(nodes, &mut output); repr.out_grad.mark_read_only(nodes, &mut output); repr.indices.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv1d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::Conv1dXBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv1dWeightBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv1dBiasBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.bias.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv2d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::Conv2dXBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv2dWeightBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv2dBiasBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.bias.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv3d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::Conv3dXBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv3dWeightBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Conv3dBiasBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.bias.mark_read_only(nodes, &mut output); repr.output_grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::DeformableConv2d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.offset.mark_read_only(nodes, &mut output); match (&mut repr.mask, &mut repr.bias) { (Some(mask), Some(bias)) => { mask.mark_read_only(nodes, &mut output); bias.mark_read_only(nodes, &mut output); } (Some(mask), None) => { mask.mark_read_only(nodes, &mut output); } (None, Some(bias)) => { bias.mark_read_only(nodes, &mut output); } (None, None) => {} }; } ModuleOperationIr::DeformableConv2dBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); repr.offset.mark_read_only(nodes, &mut output); repr.out_grad.mark_read_only(nodes, &mut output); if let Some(mask) = repr.mask.as_mut() { mask.mark_read_only(nodes, &mut output); } if let Some(bias) = repr.bias.as_mut() { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::ConvTranspose1d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::ConvTranspose2d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::ConvTranspose3d(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.weight.mark_read_only(nodes, &mut output); if let Some(bias) = &mut repr.bias { bias.mark_read_only(nodes, &mut output); } } ModuleOperationIr::AvgPool1d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::AvgPool2d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::AvgPool1dBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::AvgPool2dBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::AdaptiveAvgPool1d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::AdaptiveAvgPool2d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool1d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool1dWithIndices(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool2d(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool2dWithIndices(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Interpolate(repr) => { repr.x.mark_read_only(nodes, &mut output); } ModuleOperationIr::InterpolateBackward(repr) => { repr.x.mark_read_only(nodes, &mut output); repr.grad.mark_read_only(nodes, &mut output); } ModuleOperationIr::Attention(repr) => { repr.query.mark_read_only(nodes, &mut output); repr.key.mark_read_only(nodes, &mut output); repr.value.mark_read_only(nodes, &mut output); if let Some(mask) = &mut repr.mask { mask.mark_read_only(nodes, &mut output); } if let Some(attn_bias) = &mut repr.attn_bias { attn_bias.mark_read_only(nodes, &mut output); } } }; output } } impl InitOperationIr { fn inputs(&self) -> Box + '_> { Box::new([].into_iter()) } fn outputs(&self) -> Box + '_> { Box::new([&self.out].into_iter()) } } impl TensorIr { fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec) { if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) { output.push(self.clone()); self.status = TensorStatus::ReadOnly; } } } impl core::hash::Hash for RandomOpIr { fn hash(&self, state: &mut H) { self.out.hash(state); match self.distribution { Distribution::Default => 1u8.hash(state), Distribution::Bernoulli(_) => 2u8.hash(state), Distribution::Uniform(_, _) => 3u8.hash(state), Distribution::Normal(_, _) => 4u8.hash(state), } } } /// Extension trait to extract outputs when registering an operation. pub trait OperationOutput { /// Extract a single output. fn output(self) -> O; /// Extract a fixed number of outputs. fn outputs(self) -> [O; N]; } impl OperationOutput for Vec { fn output(self) -> O { let [tensor] = self.outputs(); tensor } fn outputs(self) -> [O; N] { self.try_into().unwrap() } } ================================================ FILE: crates/burn-ir/src/scalar.rs ================================================ use burn_backend::{DType, Scalar}; use burn_backend::{Element, ElementConversion}; use core::hash::Hash; use serde::{Deserialize, Serialize}; /// A scalar representation. #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub enum ScalarIr { Float(f64), Int(i64), UInt(u64), Bool(bool), } impl Hash for ScalarIr { fn hash(&self, state: &mut H) { match self { ScalarIr::Float(x) => x.to_bits().hash(state), ScalarIr::Int(x) => x.hash(state), ScalarIr::UInt(x) => x.hash(state), ScalarIr::Bool(x) => x.hash(state), } } } impl ScalarIr { /// Creates a scalar with the specified data type. pub fn new(value: E, dtype: &DType) -> Self { if dtype.is_float() { Self::Float(value.elem()) } else if dtype.is_int() { Self::Int(value.elem()) } else if dtype.is_uint() { Self::UInt(value.elem()) } else if dtype.is_bool() { Self::Bool(value.elem()) } else { unimplemented!("Scalar not supported for {dtype:?}") } } /// Converts and returns the converted element. pub fn elem(self) -> E { match self { ScalarIr::Float(x) => x.elem(), ScalarIr::Int(x) => x.elem(), ScalarIr::UInt(x) => x.elem(), ScalarIr::Bool(x) => x.elem(), } } } // The enums are similar, but both types have different roles: // - `Scalar`: runtime literal value // - `ScalarIr`: serializable literal representation (used for IR) impl From for ScalarIr { fn from(value: Scalar) -> Self { match value { Scalar::Float(x) => Self::Float(x), Scalar::Int(x) => Self::Int(x), Scalar::UInt(x) => Self::UInt(x), Scalar::Bool(x) => Self::Bool(x), } } } impl From for Scalar { fn from(value: ScalarIr) -> Self { match value { ScalarIr::Float(x) => Self::Float(x), ScalarIr::Int(x) => Self::Int(x), ScalarIr::UInt(x) => Self::UInt(x), ScalarIr::Bool(x) => Self::Bool(x), } } } ================================================ FILE: crates/burn-ir/src/tensor.rs ================================================ use serde::{Deserialize, Serialize}; use burn_backend::{DType, Shape}; /// The tensor unique identifier. #[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)] pub struct TensorId { value: u64, } impl core::fmt::Display for TensorId { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!("TensorId({:?})", self.value)) } } /// The status of the current tensor. #[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum TensorStatus { /// The tensor can be read, but not written. ReadOnly, /// The tensor can be mutated inplace. ReadWrite, /// No handle exists for that tensor. NotInit, } /// A tensor definition represents a snapshot of a tensor when it was used. /// /// # Example /// /// A tensor that is used multiple times has its status updated for each operation. /// /// 1. Status::NotInit /// 2. Status::ReadOnly /// 3. Status::ReadOnly /// 4. Status::ReadWrite #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct TensorIr { /// The [tensor id](TensorId). pub id: TensorId, /// The shape of the tensor. pub shape: Shape, /// The [status](TensorStatus) of the tensor when it was used. pub status: TensorStatus, /// The [type](DType) of the tensor. pub dtype: DType, } impl TensorId { /// Create a new tensor id. pub fn new(value: u64) -> Self { Self { value } } } impl TensorIr { /// Create a new tensor that is not already initialized. pub fn uninit(id: TensorId, shape: Shape, dtype: DType) -> Self { Self { id, status: TensorStatus::NotInit, shape, dtype, } } } ================================================ FILE: crates/burn-ndarray/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Ndarray backend for the Burn framework" documentation = "https://docs.rs/burn-ndarray" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-ndarray" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-ndarray" version.workspace = true [lints] workspace = true [features] blas-accelerate = [ "blas-src/accelerate", # Accelerate framework (macOS only) "ndarray/blas", ] blas-netlib = ["blas-src/netlib", "ndarray/blas"] blas-openblas = ["blas-src/openblas", "ndarray/blas", "openblas-src"] blas-openblas-system = [ "blas-src/openblas", "ndarray/blas", "openblas-src/system", ] default = ["std", "simd", "multi-threads"] doc = ["default"] multi-threads = [ "rayon", "ndarray/rayon", "matrixmultiply/threading", ] simd = ["macerator", "bytemuck", "seq-macro", "itertools"] std = [ "burn-autodiff", "burn-std/std", "burn-backend/std", "burn-ir/std", "ndarray/std", "matrixmultiply/std", "rand/std", "rand/std_rng", "num-traits/std", "macerator/std", ] tracing = [ "burn-autodiff?/tracing", "burn-std/tracing", "burn-backend/tracing", "burn-ir/tracing", ] # Serves as a ref impl for some burn-cubecl kernels export_tests = [] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible const-random = { workspace = true } libm = { workspace = true } matrixmultiply = { workspace = true, default-features = false } ndarray = { workspace = true } num-traits = { workspace = true } openblas-src = { workspace = true, optional = true } paste = { workspace = true } rand = { workspace = true, default-features = false } # SIMD bytemuck = { workspace = true, optional = true } itertools = { version = "0.14", optional = true } macerator = { workspace = true, optional = true } seq-macro = { version = "0.3", optional = true } # Parallel rayon = { workspace = true, optional = true } [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic = { workspace = true } portable-atomic-util = { workspace = true } [dev-dependencies] bytes = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-ndarray/README.md ================================================ # Burn NdArray > [Burn](https://github.com/tracel-ai/burn) ndarray backend [![Current Crates.io Version](https://img.shields.io/crates/v/burn-ndarray.svg)](https://crates.io/crates/burn-ndarray) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-ndarray/blob/master/README.md) ## Feature Flags This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the default `std` feature. The following flags support various BLAS options: - `blas-accelerate` - Accelerate framework (macOS only) - `blas-netlib` - Netlib - `blas-openblas` - OpenBLAS static linked - `blas-openblas-system` - OpenBLAS from the system Note: under the `no_std` mode, the seed is fixed if the seed is not initialized by `Backend::seed` method. ### Platform Support | Option | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM | | :--------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: | | Pure Rust | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | | Accelerate | Yes | No | No | Yes | No | No | Yes | No | | Netlib | Yes | No | Yes | Yes | Yes | No | No | No | | Openblas | Yes | No | Yes | Yes | Yes | Yes | Yes | No | ================================================ FILE: crates/burn-ndarray/build.rs ================================================ fn main() { // https://github.com/rust-ndarray/ndarray/issues/1197 if cfg!(feature = "blas-accelerate") { println!("cargo:rustc-link-lib=framework=Accelerate"); } } ================================================ FILE: crates/burn-ndarray/src/backend.rs ================================================ use crate::rand::NdArrayRng; use crate::{NdArrayQTensor, NdArrayTensor}; use crate::{ SharedArray, element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, }; use alloc::string::String; use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue}; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; use burn_backend::{Backend, DType, DeviceId, DeviceOps}; use burn_ir::{BackendIr, HandleKind, TensorHandle}; use burn_std::BoolStore; use burn_std::stub::Mutex; use core::marker::PhantomData; use rand::SeedableRng; pub(crate) static SEED: Mutex> = Mutex::new(None); /// The device type for the ndarray backend. #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] pub enum NdArrayDevice { /// The CPU device. #[default] Cpu, } impl DeviceOps for NdArrayDevice {} impl burn_backend::Device for NdArrayDevice { fn from_id(_device_id: DeviceId) -> Self { Self::Cpu } fn to_id(&self) -> DeviceId { DeviceId { type_id: 0, index_id: 0, } } fn device_count(_type_id: u16) -> usize { 1 } } /// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. /// /// This backend is compatible with CPUs and can be compiled for almost any platform, including /// `wasm`, `arm`, and `x86`. #[derive(Clone, Copy, Default, Debug)] pub struct NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { _e: PhantomData, _i: PhantomData, _q: PhantomData, } impl Backend for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { type Device = NdArrayDevice; type FloatTensorPrimitive = NdArrayTensor; type FloatElem = E; type IntTensorPrimitive = NdArrayTensor; type IntElem = I; type BoolTensorPrimitive = NdArrayTensor; type BoolElem = bool; type QuantizedTensorPrimitive = NdArrayQTensor; fn ad_enabled(_device: &Self::Device) -> bool { false } fn name(_device: &Self::Device) -> String { String::from("ndarray") } fn seed(_device: &Self::Device, seed: u64) { let rng = NdArrayRng::seed_from_u64(seed); let mut seed = SEED.lock().unwrap(); *seed = Some(rng); } fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { match dtype { DType::F64 | DType::F32 | DType::Flex32 | DType::I64 | DType::I32 | DType::I16 | DType::I8 | DType::U64 | DType::U32 | DType::U16 | DType::U8 | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(), DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(), DType::QFloat(scheme) => { match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, #[cfg(not(feature = "export_tests"))] value: QuantValue::Q8F | QuantValue::Q8S, // For tests, "native" sub-byte quant serves as a reference for value equality. // Values are stored as i8 regardless. #[cfg(feature = "export_tests")] value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S, store: QuantStore::Native, .. } => burn_backend::DTypeUsage::general(), _scheme => burn_backend::DTypeUsageSet::empty(), } } } } } impl BackendIr for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { type Handle = HandleKind; fn float_tensor(handle: TensorHandle) -> FloatTensor { match handle.handle { HandleKind::Float(handle) => handle, _ => panic!("Expected float handle, got {}", handle.handle.name()), } } fn int_tensor(handle: TensorHandle) -> IntTensor { match handle.handle { HandleKind::Int(handle) => handle, _ => panic!("Expected int handle, got {}", handle.handle.name()), } } fn bool_tensor(handle: TensorHandle) -> BoolTensor { match handle.handle { HandleKind::Bool(handle) => handle, _ => panic!("Expected bool handle, got {}", handle.handle.name()), } } fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { match handle.handle { HandleKind::Quantized(handle) => handle, _ => panic!("Expected quantized handle, got {}", handle.handle.name()), } } fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { HandleKind::Float(tensor) } fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { HandleKind::Int(tensor) } fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { HandleKind::Bool(tensor) } fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { HandleKind::Quantized(tensor) } } #[cfg(test)] mod tests { use super::*; use burn_backend::QTensorPrimitive; #[test] fn should_support_dtypes() { type B = NdArray; let device = Default::default(); assert!(B::supports_dtype(&device, DType::F64)); assert!(B::supports_dtype(&device, DType::F32)); assert!(B::supports_dtype(&device, DType::Flex32)); assert!(B::supports_dtype(&device, DType::I64)); assert!(B::supports_dtype(&device, DType::I32)); assert!(B::supports_dtype(&device, DType::I16)); assert!(B::supports_dtype(&device, DType::I8)); assert!(B::supports_dtype(&device, DType::U64)); assert!(B::supports_dtype(&device, DType::U32)); assert!(B::supports_dtype(&device, DType::U16)); assert!(B::supports_dtype(&device, DType::U8)); assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native))); assert!(B::supports_dtype( &device, DType::QFloat(NdArrayQTensor::default_scheme()) )); assert!(!B::supports_dtype(&device, DType::F16)); assert!(!B::supports_dtype(&device, DType::BF16)); // QuantStore::U32 not supported assert!(!B::supports_dtype( &device, DType::QFloat(QuantScheme::default()) )); } } ================================================ FILE: crates/burn-ndarray/src/element.rs ================================================ use burn_backend::Element; use num_traits::Signed; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use num_traits::Pow; use libm::{log1p, log1pf}; /// A float element for ndarray backend. pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd where Self: Sized, { } /// An int element for ndarray backend. pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd {} /// A general element for ndarray backend. pub trait NdArrayElement: Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + AddAssignElement + num_traits::FromPrimitive + core::ops::AddAssign + core::cmp::PartialEq + core::ops::Rem { } /// A element for ndarray backend that supports exp ops. pub trait ExpElement { /// Exponent fn exp_elem(self) -> Self; /// Log fn log_elem(self) -> Self; /// Log1p fn log1p_elem(self) -> Self; /// Powf fn powf_elem(self, value: f32) -> Self; /// Powi fn powi_elem(self, value: i32) -> Self; /// Sqrt fn sqrt_elem(self) -> Self; /// Abs fn abs_elem(self) -> Self; } /// The addition assignment operator implemented for ndarray elements. pub trait AddAssignElement { /// Performs the addition assignment operation. /// /// For `bool`, this corresponds to logical OR assignment. fn add_assign(&mut self, rhs: Rhs); } impl AddAssignElement for E { fn add_assign(&mut self, rhs: Self) { *self += rhs; } } impl AddAssignElement for bool { fn add_assign(&mut self, rhs: Self) { *self = *self || rhs; // logical OR for bool } } /// A quantized element for the ndarray backend. pub trait QuantElement: NdArrayElement {} impl QuantElement for i8 {} impl FloatNdArrayElement for f64 {} impl FloatNdArrayElement for f32 {} impl IntNdArrayElement for i64 {} impl IntNdArrayElement for i32 {} impl IntNdArrayElement for i16 {} impl IntNdArrayElement for i8 {} impl IntNdArrayElement for u64 {} impl IntNdArrayElement for u32 {} impl IntNdArrayElement for u16 {} impl IntNdArrayElement for u8 {} macro_rules! make_float { ( $ty:ty, $log1p:expr ) => { impl NdArrayElement for $ty {} #[allow(clippy::cast_abs_to_unsigned)] impl ExpElement for $ty { #[inline(always)] fn exp_elem(self) -> Self { self.exp() } #[inline(always)] fn log_elem(self) -> Self { self.ln() } #[inline(always)] fn log1p_elem(self) -> Self { $log1p(self) } #[inline(always)] fn powf_elem(self, value: f32) -> Self { self.pow(value) } #[inline(always)] fn powi_elem(self, value: i32) -> Self { #[cfg(feature = "std")] let val = self.powi(value); #[cfg(not(feature = "std"))] let val = Self::powf_elem(self, value as f32); val } #[inline(always)] fn sqrt_elem(self) -> Self { self.sqrt() } #[inline(always)] fn abs_elem(self) -> Self { self.abs() } } }; } macro_rules! make_int { ( $ty:ty, $abs:expr ) => { impl NdArrayElement for $ty {} #[allow(clippy::cast_abs_to_unsigned)] impl ExpElement for $ty { #[inline(always)] fn exp_elem(self) -> Self { (self as f32).exp() as $ty } #[inline(always)] fn log_elem(self) -> Self { (self as f32).ln() as $ty } #[inline(always)] fn log1p_elem(self) -> Self { log1pf(self as f32) as $ty } #[inline(always)] fn powf_elem(self, value: f32) -> Self { (self as f32).pow(value) as $ty } #[inline(always)] fn powi_elem(self, value: i32) -> Self { #[cfg(feature = "std")] let val = f32::powi(self as f32, value) as $ty; #[cfg(not(feature = "std"))] let val = Self::powf_elem(self, value as f32); val } #[inline(always)] fn sqrt_elem(self) -> Self { (self as f32).sqrt() as $ty } #[inline(always)] fn abs_elem(self) -> Self { $abs(self) } } }; } make_float!(f64, log1p); make_float!(f32, log1pf); make_int!(i64, i64::wrapping_abs); make_int!(i32, i32::wrapping_abs); make_int!(i16, i16::wrapping_abs); make_int!(i8, i8::wrapping_abs); make_int!(u64, |x| x); make_int!(u32, |x| x); make_int!(u16, |x| x); make_int!(u8, |x| x); ================================================ FILE: crates/burn-ndarray/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! Burn ndarray backend. #[cfg(any( feature = "blas-netlib", feature = "blas-openblas", feature = "blas-openblas-system", ))] extern crate blas_src; mod backend; mod element; mod ops; mod parallel; mod rand; mod sharing; mod storage; mod tensor; pub use backend::*; pub use element::*; pub(crate) use sharing::*; pub(crate) use storage::*; pub use tensor::*; extern crate alloc; ================================================ FILE: crates/burn-ndarray/src/ops/activation.rs ================================================ use crate::{ NdArray, NdArrayTensor, SharedArray, element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, execute_with_numeric_dtype, ops::NdArrayMathOps, }; use burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor}; impl ActivationOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn relu(tensor: FloatTensor) -> FloatTensor { execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem())) } } ================================================ FILE: crates/burn-ndarray/src/ops/adaptive_avgpool.rs ================================================ use crate::{ SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, }; use burn_backend::ElementConversion; use ndarray::Array4; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; pub(crate) fn adaptive_avg_pool2d( x: SharedArray, output_size: [usize; 2], ) -> SharedArray { let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap(); let mut output = Array4::from_elem( (batch_size, channels, output_size[0], output_size[1]), 0.elem(), ); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output = unsafe_shared_out.get(); for h in 0..output_size[0] { for w in 0..output_size[1] { let ih_start = start_index(h, output_size[0], input_height); let ih_end = end_index(h, output_size[0], input_height); let iw_start = start_index(w, output_size[1], input_width); let iw_end = end_index(w, output_size[1], input_width); let mut sum_val: E = 0.elem(); for ih in ih_start..ih_end { for iw in iw_start..iw_end { sum_val += x[[b, c, ih, iw]]; } } let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); output[[b, c, h, w]] = sum_val / count.elem(); } } }) }); output.into_dyn().into_shared() } pub(crate) fn adaptive_avg_pool2d_backward( x: SharedArray, grad: SharedArray, ) -> SharedArray { let [_, _, input_height, input_width] = x.shape().try_into().unwrap(); let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap(); let mut output_grad = Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output_grad = unsafe_shared_out.get(); for oh in 0..output_height { for ow in 0..output_width { let ih_start = start_index(oh, output_height, input_height); let ih_end = end_index(oh, output_height, input_height); let iw_start = start_index(ow, output_width, input_width); let iw_end = end_index(ow, output_width, input_width); let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); for ih in ih_start..ih_end { for iw in iw_start..iw_end { output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem(); } } } } }) }); output_grad.into_dyn().into_shared() } fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize } fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { let index = (((output_size_index + 1) as f32 * input_size as f32) / output_size as f32).ceil() as usize; usize::min(index, input_size) } ================================================ FILE: crates/burn-ndarray/src/ops/avgpool.rs ================================================ use crate::{ SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, }; use burn_backend::ElementConversion; use burn_backend::ops::conv::calculate_pool_output_size; use ndarray::Array4; pub(crate) fn avg_pool2d( x: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> SharedArray { let [kernel_height, kernel_width] = kernel_size; let [padding_height, padding_width] = padding; let [stride_height, stride_width] = stride; let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); let out_height = calculate_pool_output_size( kernel_height, stride_height, padding_height, 1, x_height, ceil_mode, ); let out_width = calculate_pool_output_size( kernel_width, stride_width, padding_width, 1, x_width, ceil_mode, ); // Padded input bounds (for count_include_pad calculation) let padded_height = x_height + 2 * padding_height; let padded_width = x_width + 2 * padding_width; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output = unsafe_shared_out.get(); for oh in 0..out_height { for ow in 0..out_width { let mut sum_val: E = 0.elem(); let mut valid_count = 0usize; let mut padded_count = 0usize; for kh in 0..kernel_height { let ih = oh * stride_height + kh; for kw in 0..kernel_width { let iw = ow * stride_width + kw; // Check if within padded bounds (excludes ceil_mode extensions) if ih < padded_height && iw < padded_width { padded_count += 1; // Check if within valid (non-padding) input bounds if ih >= padding_height && ih < x_height + padding_height && iw >= padding_width && iw < x_width + padding_width { let ih_valid = ih - padding_height; let iw_valid = iw - padding_width; sum_val += x[[b, c, ih_valid, iw_valid]]; valid_count += 1; } } } } // count_include_pad: count positions within padded bounds (not ceil_mode extensions) // !count_include_pad: count only valid (non-padding) positions let count: E = if count_include_pad { (padded_count as i32).elem() } else { (valid_count as i32).elem() }; output[[b, c, oh, ow]] = sum_val / count; } } }) }); output.into_dyn().into_shared() } pub(crate) fn avg_pool2d_backward( x: SharedArray, grad: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, _ceil_mode: bool, ) -> SharedArray { let [kernel_height, kernel_width] = kernel_size; let [stride_height, stride_width] = stride; let [padding_height, padding_width] = padding; let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap(); // Padded input bounds (for count_include_pad calculation) let padded_height = x_height + 2 * padding_height; let padded_width = x_width + 2 * padding_width; let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output_grad = unsafe_shared_grad.get(); for oh in 0..out_height { for ow in 0..out_width { let ih_start_kernel = oh * stride_height; let iw_start_kernel = ow * stride_width; let ih_end_kernel = ih_start_kernel + kernel_height; let iw_end_kernel = iw_start_kernel + kernel_width; // Clip to valid input bounds (for gradient distribution) let ih_start = usize::max(ih_start_kernel, padding_height); let iw_start = usize::max(iw_start_kernel, padding_width); let ih_end = usize::min(ih_end_kernel, x_height + padding_height); let iw_end = usize::min(iw_end_kernel, x_width + padding_width); // Calculate count based on count_include_pad let count = if count_include_pad { // Count positions within padded bounds (not ceil_mode extensions) let ih_start_padded = ih_start_kernel; let iw_start_padded = iw_start_kernel; let ih_end_padded = usize::min(ih_end_kernel, padded_height); let iw_end_padded = usize::min(iw_end_kernel, padded_width); (ih_end_padded - ih_start_padded) * (iw_end_padded - iw_start_padded) } else { // Count only valid (non-padding) positions (ih_end - ih_start) * (iw_end - iw_start) }; for ih in ih_start..ih_end { for iw in iw_start..iw_end { let ih = ih - padding_height; let iw = iw - padding_width; output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / (count as i32).elem(); } } } } }) }); output_grad.into_dyn().into_shared() } ================================================ FILE: crates/burn-ndarray/src/ops/base.rs ================================================ use alloc::{vec, vec::Vec}; use burn_backend::element::{Element, ElementConversion}; #[cfg(feature = "simd")] use burn_backend::{DType, quantization::QuantValue}; use core::fmt::Debug; use core::marker::PhantomData; use ndarray::IntoDimension; use ndarray::SliceInfo; use ndarray::Zip; use ndarray::s; use ndarray::{Array2, ArrayD}; use num_traits::Signed; #[cfg(feature = "simd")] use paste::paste; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; #[cfg(feature = "simd")] use crate::ops::simd::{ binary::try_binary_simd, binary_elemwise::{ VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub, try_binary_scalar_simd, }, cmp::{ VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd, try_cmp_simd, }, unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd}, }; use crate::reshape; use crate::{ IntNdArrayElement, ShapeOps, ops::macros::{ cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim, }, }; use crate::{SharedArray, element::NdArrayElement}; use burn_backend::ops::unfold::calculate_unfold_shape; use burn_backend::{Shape, Slice}; use ndarray::ArrayView; use ndarray::Axis; use ndarray::Dim; use ndarray::IxDyn; use ndarray::SliceInfoElem; pub struct NdArrayOps { e: PhantomData, } pub(crate) struct NdArrayMathOps { e: PhantomData, } impl NdArrayOps where E: Copy + Debug + Element + crate::AddAssignElement, { pub fn slice(tensor: ArrayView, slices: &[Slice]) -> SharedArray { let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); tensor.slice_move(slices.as_slice()).to_shared() } pub fn slice_assign( tensor: SharedArray, slices: &[Slice], value: SharedArray, ) -> SharedArray { let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); let mut array = tensor.into_owned(); array.slice_mut(slices.as_slice()).assign(&value); array.into_shared() } pub fn mask_where( tensor: SharedArray, mask: SharedArray, source: SharedArray, ) -> SharedArray { let tensor = tensor.broadcast(mask.dim()).unwrap(); let source = source.broadcast(mask.dim()).unwrap(); Zip::from(&tensor) .and(&mask) .and(&source) .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x }) .into_shared() } pub fn mask_fill(tensor: SharedArray, mask: SharedArray, value: E) -> SharedArray { // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique let mut output = tensor.into_owned(); let broadcast_mask = mask.broadcast(output.dim()).unwrap(); Zip::from(&mut output) .and(&broadcast_mask) .for_each(|out, &mask_val| { if mask_val { *out = value; } }); output.into_shared() } pub fn gather( dim: usize, mut tensor: SharedArray, mut indices: SharedArray, ) -> SharedArray { let ndims = tensor.shape().num_dims(); if dim != ndims - 1 { tensor.swap_axes(ndims - 1, dim); indices.swap_axes(ndims - 1, dim); } let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape()); let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices[ndims - 1]); let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices); let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); let mut output = Array2::from_elem((batch_size, size_index), 0.elem::()); for b in 0..batch_size { let indices = indices.slice(s!(b, ..)); for (i, index) in indices.iter().enumerate() { output[[b, i]] = tensor[[b, index.elem::() as usize]]; } } let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices); if dim != ndims - 1 { output.swap_axes(ndims - 1, dim); } output } pub fn scatter( dim: usize, mut tensor: SharedArray, mut indices: SharedArray, mut value: SharedArray, ) -> SharedArray { let ndims = tensor.shape().num_dims(); if dim != ndims - 1 { tensor.swap_axes(ndims - 1, dim); indices.swap_axes(ndims - 1, dim); value.swap_axes(ndims - 1, dim); } let (shape_tensor, shape_indices, shape_value) = (tensor.shape().into_shape(), indices.shape(), value.shape()); let (size_tensor, size_index, size_value) = ( shape_tensor[ndims - 1], shape_indices[ndims - 1], shape_value[ndims - 1], ); let batch_size = Self::gather_batch_size(&shape_tensor, shape_indices); if shape_value != shape_indices { panic!( "Invalid dimension: the shape of the index tensor should be the same as the value \ tensor: Index {:?} value {:?}", shape_indices, shape_value ); } let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])); let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); for b in 0..batch_size { let indices = indices.slice(s!(b, ..)); for (i, index) in indices.iter().enumerate() { let index = index.elem::() as usize; tensor[[b, index]].add_assign(value[[b, i]]); } } let mut output = NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor); if dim != ndims - 1 { output.swap_axes(ndims - 1, dim); } output } fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize { let ndims = shape_tensor.num_dims(); let mut batch_size = 1; for i in 0..ndims - 1 { if shape_tensor[i] != shape_indices[i] { panic!( "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ {:?}", shape_tensor, shape_indices ); } batch_size *= shape_indices[i]; } batch_size } pub fn reshape(tensor: SharedArray, shape: Shape) -> SharedArray { reshape!( ty E, shape shape, array tensor, d shape.num_dims() ) } pub(crate) fn concatenate( arrays: &[ndarray::ArrayView], dim: usize, ) -> SharedArray { let array = ndarray::concatenate(Axis(dim), arrays) .unwrap() .into_shared(); // Transform column-major layout into row-major (standard) layout. (fix #1053) // Get shape first (via reference), then pass ownership to avoid clone let shape = array.shape().into_shape(); Self::reshape(array, shape) } pub fn cat(tensors: Vec>, dim: usize) -> SharedArray { let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect(); Self::concatenate(&arrays, dim) } #[allow(clippy::wrong_self_convention)] fn to_slice_args_with_steps( burn_slices: &[burn_backend::Slice], ndims: usize, ) -> Vec { let mut slices = vec![SliceInfoElem::NewAxis; ndims]; for i in 0..ndims { slices[i] = if i < burn_slices.len() { let slice = &burn_slices[i]; // Check for empty range (would result in no elements) if let Some(end) = slice.end && slice.start == end { SliceInfoElem::Slice { start: 0, end: Some(0), step: 1, } } else { // Pass slice parameters directly to ndarray // ndarray handles both positive and negative steps correctly: // - Positive step: iterates forward from start // - Negative step: iterates backward from the last element in range SliceInfoElem::Slice { start: slice.start, end: slice.end, step: slice.step, } } } else { // Dimension not specified in slices - use full range SliceInfoElem::Slice { start: 0, end: None, step: 1, } } } slices } pub fn swap_dims(mut tensor: SharedArray, dim1: usize, dim2: usize) -> SharedArray { tensor.swap_axes(dim1, dim2); tensor } pub fn permute(tensor: SharedArray, axes: &[usize]) -> SharedArray { tensor.permuted_axes(axes.into_dimension()) } /// Broadcasts the tensor to the given shape pub(crate) fn expand(tensor: SharedArray, shape: Shape) -> SharedArray { tensor .broadcast(shape.into_dimension()) .expect("The shapes should be broadcastable") // need to convert view to owned array because NdArrayTensor expects owned array // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides) .into_owned() .into_shared() } pub fn flip(tensor: SharedArray, axes: &[usize]) -> SharedArray { let slice_items: Vec<_> = (0..tensor.shape().num_dims()) .map(|i| { if axes.contains(&i) { SliceInfoElem::Slice { start: 0, end: None, step: -1, } } else { SliceInfoElem::Slice { start: 0, end: None, step: 1, } } }) .collect(); let slice_info = SliceInfo::, IxDyn, IxDyn>::try_from(slice_items).unwrap(); tensor.slice(slice_info).into_owned().into_shared() } /// Unfold windows along a dimension. /// /// # Warning /// /// This is a copy impl; `ndarray` doesn't expose the layout machinery /// necessary to build the stride view. /// /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// # Arguments /// /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` /// * `dim` - the dimension to unfold. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with shape ``[pre=..., windows, post=..., size]``. #[allow(unused)] pub(crate) fn unfold( tensor: SharedArray, dim: usize, size: usize, step: usize, ) -> SharedArray { let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step); let windows = result_shape[dim]; let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()]; let new_axis = slices.len(); let mut stack = Vec::with_capacity(windows); for widx in 0..windows { let start = widx * step; let end = start + size; slices[dim] = Slice::new(start as isize, Some(end as isize), 1); let mut window_slice = tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice()); window_slice.insert_axis_inplace(Axis(new_axis)); window_slice.swap_axes(dim, new_axis); stack.push(window_slice); } Self::concatenate(&stack, dim) } } #[cfg(feature = "simd")] macro_rules! dispatch_binary_simd { (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* _ => Err(($lhs, $rhs)), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* DType::QFloat(strategy) => match strategy.value { QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), _ => Err(($lhs, $rhs)), }, _ => Err(($lhs, $rhs)), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; } #[cfg(not(feature = "simd"))] macro_rules! dispatch_binary_simd { (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; } #[cfg(feature = "simd")] macro_rules! dispatch_binary_scalar_simd { (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* _ => Err($lhs), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* DType::QFloat(strategy) => match strategy.value { QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) }, _ => Err($lhs), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; } #[cfg(not(feature = "simd"))] macro_rules! dispatch_binary_scalar_simd { (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; } #[cfg(feature = "simd")] macro_rules! dispatch_cmp_simd { ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)* DType::QFloat(strategy) => match strategy.value { QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs)) }, _ => Err(($lhs, $rhs)), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; } #[cfg(not(feature = "simd"))] macro_rules! dispatch_cmp_simd { ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; } #[cfg(feature = "simd")] macro_rules! dispatch_cmp_scalar_simd { ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)* DType::QFloat(strategy) => match strategy.value { QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) }, _ => Err($lhs), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; } #[cfg(not(feature = "simd"))] macro_rules! dispatch_cmp_scalar_simd { ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; } #[cfg(feature = "simd")] macro_rules! dispatch_unary_simd { ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)* _ => Err($lhs), }; match simd { Ok(out) => return out, Err(args) => args, } } }}; } #[cfg(not(feature = "simd"))] macro_rules! dispatch_unary_simd { ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }}; } // Helper function to broadcast two tensors to a common shape for comparison operations // Returns broadcasted views that can be safely zipped fn broadcast_for_comparison<'a, E: Copy, S1, S2>( lhs: &'a ndarray::ArrayBase, rhs: &'a ndarray::ArrayBase, ) -> ( ndarray::ArrayView<'a, E, ndarray::IxDyn>, ndarray::ArrayView<'a, E, ndarray::IxDyn>, ) where S1: ndarray::Data, S2: ndarray::Data, { // Get shapes let lhs_shape = lhs.shape(); let rhs_shape = rhs.shape(); // Compute broadcast shape using ndarray's broadcast compatibility rules let ndims = lhs_shape.len().max(rhs_shape.len()); let mut broadcast_shape = vec![1; ndims]; for i in 0..ndims { let lhs_dim = if i < lhs_shape.len() { lhs_shape[lhs_shape.len() - 1 - i] } else { 1 }; let rhs_dim = if i < rhs_shape.len() { rhs_shape[rhs_shape.len() - 1 - i] } else { 1 }; if lhs_dim == rhs_dim { broadcast_shape[ndims - 1 - i] = lhs_dim; } else if lhs_dim == 1 { broadcast_shape[ndims - 1 - i] = rhs_dim; } else if rhs_dim == 1 { broadcast_shape[ndims - 1 - i] = lhs_dim; } else { panic!( "Incompatible shapes for broadcasting: {:?} and {:?}", lhs_shape, rhs_shape ); } } // Create IxDyn from broadcast shape let broadcast_dim = ndarray::IxDyn(&broadcast_shape); // Broadcast both arrays let lhs_broadcast = lhs .broadcast(broadcast_dim.clone()) .expect("Failed to broadcast lhs"); let rhs_broadcast = rhs .broadcast(broadcast_dim) .expect("Failed to broadcast rhs"); (lhs_broadcast, rhs_broadcast) } impl NdArrayMathOps where E: Copy + NdArrayElement, { pub fn add(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!( E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); let array = &lhs + &rhs; array.into_shared() } pub fn add_scalar(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( E, VecAdd, lhs, rhs.elem(), u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); let array = lhs + rhs; array.into_shared() } pub fn sub(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!( E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); let array = lhs - rhs; array.into_shared() } pub fn sub_scalar(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( E, VecSub, lhs, rhs.elem(), u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); let array = lhs - rhs; array.into_shared() } pub fn mul(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64); let array = lhs * rhs; array.into_shared() } pub fn mul_scalar(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( noq, E, VecMul, lhs, rhs.elem(), u16, i16, u32, i32, f32, f64 ); let array = lhs * rhs; array.into_shared() } pub fn div(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64); let array = lhs / rhs; array.into_shared() } pub fn div_scalar(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64); let array = lhs / rhs; array.into_shared() } pub fn remainder(lhs: SharedArray, rhs: SharedArray) -> SharedArray { // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique let mut out = lhs.into_owned(); Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| { // out_elem holds lhs value; read it before overwriting with remainder let a_f = (*out_elem).to_f64(); let b_f = b.to_f64(); let r = a_f - b_f * (a_f / b_f).floor(); *out_elem = r.elem(); }); out.into_shared() } pub fn remainder_scalar(lhs: SharedArray, rhs: E) -> SharedArray where E: core::ops::Rem, { let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs); array.into_shared() } pub fn recip(tensor: SharedArray) -> SharedArray { let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32); let array = tensor.map(|x| 1.elem::() / *x); array.into_shared() } /// Sum all elements - zero-copy for borrowed storage. pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { let sum = view.sum(); ArrayD::from_elem(IxDyn(&[1]), sum).into_shared() } /// Mean of all elements - zero-copy for borrowed storage. pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { let mean = view.mean().unwrap(); ArrayD::from_elem(IxDyn(&[1]), mean).into_shared() } /// Product of all elements - zero-copy for borrowed storage. pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { let prod = view.iter().fold(E::one(), |acc, &x| acc * x); ArrayD::from_elem(IxDyn(&[1]), prod).into_shared() } pub fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { let ndims = tensor.shape().num_dims(); match ndims { d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean), _ => panic!("Dim not supported {ndims}"), } } pub fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { let ndims = tensor.shape().num_dims(); match ndims { d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum), _ => panic!("Dim not supported {ndims}"), } } pub fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { let ndims = tensor.shape().num_dims(); match ndims { d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod), _ => panic!("Dim not supported {ndims}"), } } pub fn cumsum(tensor: SharedArray, dim: usize) -> SharedArray { cumsum_dim(tensor, dim) } pub fn cumprod(tensor: SharedArray, dim: usize) -> SharedArray { cumprod_dim(tensor, dim) } pub fn select( tensor: SharedArray, dim: usize, indices: SharedArray, ) -> SharedArray { let array = tensor.select( Axis(dim), &indices .into_iter() .map(|i| i.elem::() as usize) .collect::>(), ); array.into_shared() } pub fn select_assign( tensor: SharedArray, dim: usize, indices: SharedArray, value: SharedArray, ) -> SharedArray { let mut output_array = tensor.into_owned(); for (index_value, index) in indices.into_iter().enumerate() { let mut view = output_array.index_axis_mut(Axis(dim), index.elem::() as usize); let value = value.index_axis(Axis(dim), index_value); view.zip_mut_with(&value, |a, b| *a += *b); } output_array.into_shared() } pub(crate) fn elementwise_op( lhs: SharedArray, rhs: SharedArray, var_name: impl FnMut(&E, &OtherE) -> E, ) -> SharedArray { let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view()); let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view()); Zip::from(lhs).and(rhs).map_collect(var_name).into_shared() } pub(crate) fn elementwise_op_scalar( lhs: SharedArray, var_name: impl FnMut(E) -> E, ) -> SharedArray { lhs.mapv(var_name).into_shared() } pub(crate) fn abs(tensor: SharedArray) -> SharedArray { let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64); tensor.mapv_into(|a| a.abs_elem()).into_shared() } pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs == rhs) .into_shared() } pub(crate) fn equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_cmp_scalar_simd!( E, VecEquals, lhs, rhs.elem(), u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); lhs.mapv(|a| a == rhs).into_shared() } pub(crate) fn sign_op(tensor: SharedArray) -> SharedArray where E: Signed, { let zero = 0.elem(); let one = 1.elem::(); tensor .mapv(|x| { if x == zero { zero } else { match x.is_positive() { true => one, false => -one, } } }) .into_shared() } } impl NdArrayMathOps where E: Copy + NdArrayElement + PartialOrd, { /// Max of all elements - zero-copy for borrowed storage. pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { let max = view .iter() .copied() .reduce(|a, b| if a > b { a } else { b }) .expect("Cannot compute max of empty tensor"); ArrayD::from_elem(IxDyn(&[1]), max).into_shared() } /// Min of all elements - zero-copy for borrowed storage. pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { let min = view .iter() .copied() .reduce(|a, b| if a < b { a } else { b }) .expect("Cannot compute min of empty tensor"); ArrayD::from_elem(IxDyn(&[1]), min).into_shared() } /// Argmax along dimension - zero-copy for borrowed storage. pub fn argmax_view( view: ArrayView<'_, E, IxDyn>, dim: usize, ) -> SharedArray { arg_view(view, dim, CmpType::Max) } /// Argmin along dimension - zero-copy for borrowed storage. pub fn argmin_view( view: ArrayView<'_, E, IxDyn>, dim: usize, ) -> SharedArray { arg_view(view, dim, CmpType::Min) } pub fn cummin(tensor: SharedArray, dim: usize) -> SharedArray { cummin_dim(tensor, dim) } pub fn cummax(tensor: SharedArray, dim: usize) -> SharedArray { cummax_dim(tensor, dim) } pub fn argmax( tensor: SharedArray, dim: usize, ) -> SharedArray { arg(tensor, dim, CmpType::Max) } pub fn argmin( tensor: SharedArray, dim: usize, ) -> SharedArray { arg(tensor, dim, CmpType::Min) } pub fn clamp_min(tensor: SharedArray, min: E) -> SharedArray { let mut tensor = dispatch_binary_scalar_simd!( E, VecMax, tensor, min.elem(), u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); tensor.mapv_inplace(|x| match x < min { true => min, false => x, }); tensor } pub fn clamp_max(tensor: SharedArray, max: E) -> SharedArray { let mut tensor = dispatch_binary_scalar_simd!( E, VecMin, tensor, max.elem(), u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); tensor.mapv_inplace(|x| match x > max { true => max, false => x, }); tensor } pub fn clamp(tensor: SharedArray, min: E, max: E) -> SharedArray { let mut tensor = dispatch_binary_scalar_simd!( E, VecClamp, tensor, (min.elem(), max.elem()), u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 ); tensor.mapv_inplace(|x| match x < min { true => min, false => match x > max { true => max, false => x, }, }); tensor } pub(crate) fn greater(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs > rhs) .into_shared() } pub(crate) fn greater_elem(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_cmp_scalar_simd!( E, VecGreater, lhs, rhs.elem(), u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); lhs.mapv(|a| a > rhs).into_shared() } pub(crate) fn greater_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( E, VecGreaterEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs >= rhs) .into_shared() } pub(crate) fn greater_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_cmp_scalar_simd!( E, VecGreaterEq, lhs, rhs.elem(), u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); lhs.mapv(|a| a >= rhs).into_shared() } pub(crate) fn lower_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs <= rhs) .into_shared() } pub(crate) fn lower_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_cmp_scalar_simd!( E, VecLowerEq, lhs, rhs.elem(), u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); lhs.mapv(|a| a <= rhs).into_shared() } pub(crate) fn lower(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_cmp_simd!( E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs < rhs) .into_shared() } pub(crate) fn lower_elem(lhs: SharedArray, rhs: E) -> SharedArray { let lhs = dispatch_cmp_scalar_simd!( E, VecLower, lhs, rhs.elem(), u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 ); lhs.mapv(|a| a < rhs).into_shared() } } pub struct NdArrayBitOps(PhantomData); impl NdArrayBitOps { pub(crate) fn bitand(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { (a.elem::() & (b.elem::())).elem() }) } pub(crate) fn bitand_scalar(lhs: SharedArray, rhs: I) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( I, VecBitAnd, lhs, rhs.elem(), i8, u8, i16, u16, i32, u32, i64, u64 ); NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { (a.elem::() & rhs.elem::()).elem() }) } pub(crate) fn bitor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { (a.elem::() | (b.elem::())).elem() }) } pub(crate) fn bitor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( I, VecBitOr, lhs, rhs.elem(), i8, u8, i16, u16, i32, u32, i64, u64 ); NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { (a.elem::() | rhs.elem::()).elem() }) } pub(crate) fn bitxor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { let (lhs, rhs) = dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { (a.elem::() ^ (b.elem::())).elem() }) } pub(crate) fn bitxor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { let lhs = dispatch_binary_scalar_simd!( I, VecBitXor, lhs, rhs.elem(), i8, u8, i16, u16, i32, u32, i64, u64 ); NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { (a.elem::() ^ rhs.elem::()).elem() }) } pub(crate) fn bitnot(tensor: SharedArray) -> SharedArray { let tensor = dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64); NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) } } pub struct NdArrayBoolOps; // Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would // produce invalid values. impl NdArrayBoolOps { pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { #[cfg(feature = "simd")] let (lhs, rhs) = match try_cmp_simd::(lhs, rhs) { Ok(out) => return out, Err(args) => args, }; // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs == rhs) .into_shared() } pub(crate) fn equal_elem(lhs: SharedArray, rhs: bool) -> SharedArray { #[cfg(feature = "simd")] let lhs = match try_cmp_scalar_simd::(lhs, rhs.elem()) { Ok(out) => return out, Err(args) => args, }; lhs.mapv(|a| a == rhs).into_shared() } pub(crate) fn and(lhs: SharedArray, rhs: SharedArray) -> SharedArray { #[cfg(feature = "simd")] let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { Ok(out) => return out, Err(args) => args, }; // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs && rhs) .into_shared() } pub(crate) fn or(lhs: SharedArray, rhs: SharedArray) -> SharedArray { #[cfg(feature = "simd")] let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { Ok(out) => return out, Err(args) => args, }; // Use the helper to broadcast both arrays to a common shape let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); // Now we can safely zip and compare Zip::from(&lhs_broadcast) .and(&rhs_broadcast) .map_collect(|&lhs, &rhs| lhs || rhs) .into_shared() } /// Any element is true - zero-copy for borrowed storage. pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool { view.iter().any(|&x| x) } /// All elements are true - zero-copy for borrowed storage. pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool { view.iter().all(|&x| x) } } enum CmpType { Min, Max, } fn arg( tensor: SharedArray, dim: usize, cmp: CmpType, ) -> SharedArray { arg_view(tensor.view(), dim, cmp) } /// View-based argmax/argmin - zero-copy for borrowed storage. fn arg_view( view: ArrayView<'_, E, IxDyn>, dim: usize, cmp: CmpType, ) -> SharedArray { let mut reshape = view.shape().to_vec(); reshape[dim] = 1; let output = view.map_axis(Axis(dim), |arr| { // Find the min/max value in the array, and return its index. let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { let cmp = match cmp { CmpType::Min => e < &acc.0, CmpType::Max => e > &acc.0, }; if cmp { (*e, idx) } else { acc } }); (idx as i64).elem() }); let output = output.to_shape(Dim(reshape.as_slice())).unwrap(); output.into_shared() } #[cfg(test)] mod tests { use burn_backend::TensorData; use crate::NdArrayTensor; use super::*; #[test] fn should_generate_row_major_layout_for_cat() { let expected_shape: &[usize] = &[4, 6, 2]; let expected_strides: &[isize] = &[12, 2, 1]; let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([ [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]], [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]], [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]], [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]], ])) else { panic!() }; let expected_array = expected_storage.into_shared(); let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([ [1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24], ])) else { panic!() }; let tensor = tensor_storage.into_shared(); // unsqueeze dim on the outermost axis let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1])); let NdArrayTensor::I32(zeros_storage) = NdArrayTensor::from_data(TensorData::zeros::([4, 6, 1])) else { panic!() }; let zeros = zeros_storage.into_shared(); // make `ndarray` concatenates array on the outermost axis let array = NdArrayOps::cat([array, zeros].to_vec(), 2); assert!(array.is_standard_layout()); assert_eq!(array.shape(), expected_shape); assert_eq!(array.strides(), expected_strides); assert_eq!( array.into_iter().collect::>(), expected_array.into_iter().collect::>(), ); } } ================================================ FILE: crates/burn-ndarray/src/ops/bool_tensor.rs ================================================ // Language use alloc::vec; use alloc::vec::Vec; use burn_backend::Scalar; use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor}; use burn_backend::{ backend::ExecutionError, ops::BoolTensorOps, tensor::{BoolTensor, IntTensor}, }; use ndarray::IntoDimension; // Current crate use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}; use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor}; use crate::{NdArrayDevice, SharedArray, slice}; // Workspace crates use burn_backend::{Shape, TensorData, backend::Backend}; use super::{NdArrayBoolOps, NdArrayOps}; impl BoolTensorOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { if !data.dtype.is_bool() { unimplemented!("Unsupported dtype for `bool_from_data`") } NdArrayTensor::from_data(data) } async fn bool_into_data(tensor: NdArrayTensor) -> Result { Ok(tensor.into_data()) } fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { tensor } fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::reshape(tensor.bool(), shape).into() } fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor { slice!(tensor, slices) } fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor { // Use mapv directly instead of collecting to Vec and going through TensorData let int_array: SharedArray = tensor.bool().mapv(|b| b.elem()).into_shared(); int_array.into() } fn bool_device(_tensor: &NdArrayTensor) -> as Backend>::Device { NdArrayDevice::Cpu } fn bool_empty(shape: Shape, _device: & as Backend>::Device) -> NdArrayTensor { Self::bool_zeros(shape, _device) } fn bool_zeros(shape: Shape, _device: & as Backend>::Device) -> NdArrayTensor { let values = vec![false; shape.num_elements()]; NdArrayTensor::from_data(TensorData::new(values, shape)) } fn bool_ones(shape: Shape, _device: & as Backend>::Device) -> NdArrayTensor { let values = vec![true; shape.num_elements()]; NdArrayTensor::from_data(TensorData::new(values, shape)) } fn bool_slice_assign( tensor: NdArrayTensor, slices: &[burn_backend::Slice], value: NdArrayTensor, ) -> NdArrayTensor { NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into() } fn bool_cat(tensors: Vec, dim: usize) -> NdArrayTensor { NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into() } fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into() } fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor { tensor.bool().mapv(|a| !a).into_shared().into() } fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into() } fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into() } fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor { let arr: SharedArray = tensor.bool().mapv(|b| b.elem()).into_shared(); arr.into() } fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into() } fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { tensor.bool().permuted_axes(axes.into_dimension()).into() } fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor.bool(), shape).into() } fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { let tensor_bool = tensor.bool(); let indices_vec: Vec = indices .into_iter() .map(|i| i.elem::() as usize) .collect(); let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec); selected.into_shared().into() }) } fn bool_select_or( tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { let mut output_array = tensor.bool().into_owned(); let value_bool = value.bool(); for (index_value, index) in indices.into_iter().enumerate() { let index_usize = index.elem::() as usize; let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize); let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value); // For boolean tensors, select_assign should use logical OR operation view.zip_mut_with(&value_slice, |a, b| *a = *a || *b); } output_array.into_shared().into() }) } fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { NdArrayOps::flip(tensor.bool(), axes).into() } fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor { NdArrayOps::unfold(tensor.bool(), dim, size, step).into() } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into() } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into() } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { execute_with_int_dtype!(indices, |indices| NdArrayOps::gather( dim, tensor.bool(), indices )) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter( dim, tensor.bool(), indices, value.bool() )) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into() } fn bool_any(tensor: BoolTensor) -> BoolTensor { // Use view() for zero-copy on borrowed storage with short-circuit evaluation let result = NdArrayBoolOps::any_view(tensor.bool().view()); NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) } fn bool_all(tensor: BoolTensor) -> BoolTensor { // Use view() for zero-copy on borrowed storage with short-circuit evaluation let result = NdArrayBoolOps::all_view(tensor.bool().view()); NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) } } ================================================ FILE: crates/burn-ndarray/src/ops/conv.rs ================================================ use burn_backend::{ ElementConversion, ops::{ ConvOptions, ConvTransposeOptions, conv::{calculate_conv_output_size, calculate_conv_transpose_output_size}, }, }; use ndarray::{ Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s, }; use crate::{ NdArrayElement, SharedArray, iter_par, iter_range_par, ops::padding::{apply_padding_4d, apply_padding_5d}, run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor, }; #[inline(always)] fn conv2d_mad_inner( mut output: ArrayViewMut2, x: ArrayView2, k: E, k_xy: (usize, usize), out_xy: (usize, usize), stride: (usize, usize), dilation: (usize, usize), ) { let (kh, kw) = k_xy; let (out_width, out_height) = out_xy; let (stride_width, stride_height) = stride; let (dilation_width, dilation_height) = dilation; for oh in 0..out_height { // Construct a sub-slice view of the input row. // This is done upfront so that rustc does not have to emit bounds checks // in the hot loop below. let ir = x .row(oh * stride_height + kh * dilation_height) .to_slice() .unwrap(); // Ditto. Construct a sub-slice view of the output row, and explicitly specify // the bounds upfront as 0..out_width so that rustc can make the assumption // that all accesses are in-bounds in the below loop. let mut or = output.row_mut(oh); let or = &mut or.as_slice_mut().unwrap()[0..out_width]; #[allow(clippy::needless_range_loop)] for ow in 0..out_width { let iw = ow * stride_width + kw * dilation_width; or[ow] += ir[iw] * k; } } } #[inline(always)] fn conv3d_mad_inner( mut output: ArrayViewMut3, x: ArrayView3, k: E, k_xyz: (usize, usize, usize), out_xyz: (usize, usize, usize), stride: (usize, usize, usize), dilation: (usize, usize, usize), ) { let (kd, kh, kw) = k_xyz; let (out_width, out_height, out_depth) = out_xyz; let (stride_width, stride_height, stride_depth) = stride; let (dilation_width, dilation_height, dilation_depth) = dilation; for od in 0..out_depth { let id = od * stride_depth + kd * dilation_depth; for oh in 0..out_height { let ih = oh * stride_height + kh * dilation_height; // Construct a sub-slice view of the input row. // This is done upfront so that rustc does not have to emit bounds checks // in the hot loop below. let ir = x.slice(s![id, ih, ..]).to_slice().unwrap(); // Ditto. Construct a sub-slice view of the output row, and explicitly specify // the bounds upfront as 0..out_width so that rustc can make the assumption // that all accesses are in-bounds in the below loop. let or = &mut output .slice_mut(s![od, oh, 0..out_width]) .into_slice() .unwrap()[0..out_width]; #[allow(clippy::needless_range_loop)] for ow in 0..out_width { let iw = ow * stride_width + kw * dilation_width; or[ow] += ir[iw] * k; } } } } pub(crate) fn conv2d( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvOptions<2>, ) -> SharedArray where NdArrayTensor: From>, { let [dilation_height, dilation_width] = options.dilation; let [padding_height, padding_width] = options.padding; let [stride_height, stride_width] = options.stride; let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().try_into().unwrap(); let channels_per_group = out_channels / options.groups; let out_height = calculate_conv_output_size( kernel_height, stride_height, padding_height, dilation_height, in_height, ); let out_width = calculate_conv_output_size( kernel_width, stride_width, padding_width, dilation_width, in_width, ); let x = apply_padding_4d::(x, options.padding, 0i32.elem()); // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); let weights = weight.into_dimensionality::().unwrap(); let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); run_par!(|| { iter_par!(output.axis_iter_mut(Axis(0))) .enumerate() .for_each( #[inline(never)] |(k, mut output)| { let b = k / out_channels; let oc = k % out_channels; let g = oc / channels_per_group; for ic in (in_channels * g)..(in_channels * (g + 1)) { let weight_ic = ic - (g * in_channels); let x = x.slice(s![b, ic, .., ..]); let k = weights.slice(s![oc, weight_ic, .., ..]); for kh in 0..kernel_height { for kw in 0..kernel_width { let k = k[[kh, kw]]; // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization // in the case that the stride/dilation is 1. #[allow(clippy::if_same_then_else)] if (1, 1, 1, 1) == ( stride_width, stride_height, dilation_width, dilation_height, ) { conv2d_mad_inner( output.view_mut(), x.view(), k, (kh, kw), (out_width, out_height), (stride_width, stride_height), (dilation_width, dilation_height), ); } else { conv2d_mad_inner( output.view_mut(), x.view(), k, (kh, kw), (out_width, out_height), (stride_width, stride_height), (dilation_width, dilation_height), ); } } } } if let Some(bias) = &bias { let bias = bias[oc]; for oh in 0..out_height { // Get a mutable slice reference to the row we're looping over. // We explicitly define the bounds to 0..out_width so that rustc can make // the assumption that all accesses are in-bounds. let mut or = output.row_mut(oh); let or = &mut or.as_slice_mut().unwrap()[0..out_width]; #[allow(clippy::needless_range_loop)] for ow in 0..out_width { or[ow] += bias; } } } }, ); }); output .to_shape([batch_size, out_channels, out_height, out_width]) .unwrap() .into_dyn() .into_shared() } pub(crate) fn conv_transpose2d( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvTransposeOptions<2>, ) -> SharedArray { let [dilation_height, dilation_width] = options.dilation; let [padding_height, padding_width] = options.padding; let [stride_height, stride_width] = options.stride; let [out_padding_height, out_padding_width] = options.padding_out; let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().try_into().unwrap(); let out_height = calculate_conv_transpose_output_size( kernel_height, stride_height, padding_height, out_padding_height, dilation_height, in_height, ); let out_width = calculate_conv_transpose_output_size( kernel_width, stride_width, padding_width, out_padding_width, dilation_width, in_width, ); let x = x; let mut output = Array4::zeros(Dim([ batch_size, out_channels * options.groups, out_height, out_width, ])); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { let b = k / (out_channels * options.groups); let oc = k % out_channels; let g = (k / out_channels) % options.groups; let output = unsafe_shared_out.get(); let oc_out = oc + (out_channels * g); let ic_start = g * (in_channels / options.groups); let ic_end = ic_start + in_channels / options.groups; for ic in ic_start..ic_end { for ih in 0..in_height { for iw in 0..in_width { for kh in 0..kernel_height { for kw in 0..kernel_width { let oh = ih * stride_height + kh * dilation_height; let ow = iw * stride_width + kw * dilation_width; if oh >= out_height + padding_height || ow >= out_width + padding_width || oh < padding_height || ow < padding_width { continue; } let oh = oh - padding_height; let ow = ow - padding_width; output[[b, oc_out, oh, ow]] += x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]]; } } } } } if let Some(bias) = &bias { for oh in 0..out_height { for ow in 0..out_width { output[[b, oc_out, oh, ow]] += bias[oc_out]; } } } }); }); output.into_dyn().into_shared() } pub(crate) fn conv3d( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvOptions<3>, ) -> SharedArray where NdArrayTensor: From>, { let [dilation_depth, dilation_height, dilation_width] = options.dilation; let [padding_depth, padding_height, padding_width] = options.padding; let [stride_depth, stride_height, stride_width] = options.stride; let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); let [ out_channels, in_channels, kernel_depth, kernel_height, kernel_width, ] = weight.shape().try_into().unwrap(); let out_c_per_group = out_channels / options.groups; let out_depth = calculate_conv_output_size( kernel_depth, stride_depth, padding_depth, dilation_depth, in_depth, ); let out_height = calculate_conv_output_size( kernel_height, stride_height, padding_height, dilation_height, in_height, ); let out_width = calculate_conv_output_size( kernel_width, stride_width, padding_width, dilation_width, in_width, ); let x = apply_padding_5d::(x, options.padding, 0i32.elem()); // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); let weights = weight.into_dimensionality::().unwrap(); let mut output = Array4::zeros(Dim([ batch_size * out_channels, out_depth, out_height, out_width, ])); run_par!(|| { iter_par!(output.axis_iter_mut(Axis(0))) .enumerate() .for_each( #[inline(never)] |(k, mut output)| { let b = k / out_channels; let oc = k % out_channels; let g = oc / out_c_per_group; for ic in (in_channels * g)..(in_channels * (g + 1)) { let weight_ic = ic - (g * in_channels); let x = x.slice(s![b, ic, .., .., ..]); let k = weights.slice(s![oc, weight_ic, .., .., ..]); for kd in 0..kernel_depth { for kh in 0..kernel_height { for kw in 0..kernel_width { let k = k[[kd, kh, kw]]; // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization // in the case that the stride/dilation is 1. #[allow(clippy::if_same_then_else)] if (1, 1, 1, 1, 1, 1) == ( stride_width, stride_height, stride_depth, dilation_width, dilation_height, dilation_depth, ) { conv3d_mad_inner( output.view_mut(), x.view(), k, (kd, kh, kw), (out_width, out_height, out_depth), (stride_width, stride_height, stride_depth), (dilation_width, dilation_height, dilation_depth), ); } else { conv3d_mad_inner( output.view_mut(), x.view(), k, (kd, kh, kw), (out_width, out_height, out_depth), (stride_width, stride_height, stride_depth), (dilation_width, dilation_height, dilation_depth), ); } } } } } if let Some(bias) = &bias { let bias = bias[oc]; // Get a mutable iterator to the row we're looping over. let orows = output.rows_mut(); for mut or in orows { // We explicitly define the bounds to 0..out_width so that rustc can make // the assumption that all accesses are in-bounds. let or = &mut or.as_slice_mut().unwrap()[0..out_width]; #[allow(clippy::needless_range_loop)] for ow in 0..out_width { or[ow] += bias; } } } }, ); }); output .to_shape([batch_size, out_channels, out_depth, out_height, out_width]) .unwrap() .into_dyn() .into_shared() } pub(crate) fn conv_transpose3d( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvTransposeOptions<3>, ) -> SharedArray { let [dilation_depth, dilation_height, dilation_width] = options.dilation; let [padding_depth, padding_height, padding_width] = options.padding; let [stride_depth, stride_height, stride_width] = options.stride; let [out_padding_depth, out_padding_height, out_padding_width] = options.padding_out; let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); let [ in_channels, out_channels, kernel_depth, kernel_height, kernel_width, ] = weight.shape().try_into().unwrap(); let out_depth = calculate_conv_transpose_output_size( kernel_depth, stride_depth, padding_depth, out_padding_depth, dilation_depth, in_depth, ); let out_height = calculate_conv_transpose_output_size( kernel_height, stride_height, padding_height, out_padding_height, dilation_height, in_height, ); let out_width = calculate_conv_transpose_output_size( kernel_width, stride_width, padding_width, out_padding_width, dilation_width, in_width, ); let x = x; let mut output = Array5::zeros(Dim([ batch_size, out_channels * options.groups, out_depth, out_height, out_width, ])); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { let b = k / (out_channels * options.groups); let oc = k % out_channels; let g = (k / out_channels) % options.groups; let output = unsafe_shared_out.get(); let oc_out = oc + (out_channels * g); let ic_start = g * (in_channels / options.groups); let ic_end = ic_start + in_channels / options.groups; for ic in ic_start..ic_end { for id in 0..in_depth { for ih in 0..in_height { for iw in 0..in_width { for kd in 0..kernel_depth { for kh in 0..kernel_height { for kw in 0..kernel_width { let od = id * stride_depth + kd * dilation_depth; let oh = ih * stride_height + kh * dilation_height; let ow = iw * stride_width + kw * dilation_width; if od >= out_depth + padding_depth || oh >= out_height + padding_height || ow >= out_width + padding_width || od < padding_depth || oh < padding_height || ow < padding_width { continue; } let od = od - padding_depth; let oh = oh - padding_height; let ow = ow - padding_width; output[[b, oc_out, od, oh, ow]] += x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]]; } } } } } } } if let Some(bias) = &bias { for od in 0..out_depth { for oh in 0..out_height { for ow in 0..out_width { output[[b, oc_out, od, oh, ow]] += bias[oc_out]; } } } } }); }); output.into_dyn().into_shared() } ================================================ FILE: crates/burn-ndarray/src/ops/deform_conv.rs ================================================ use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size}; use core::ops::AddAssign; use ndarray::{ Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4, Zip, s, }; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par}; use super::matmul::matmul; #[inline(always)] #[allow(clippy::too_many_arguments)] fn deform_im2col_kernel( out_y: usize, out_x: usize, input: ArrayView2, offset: ArrayView3, mask: Option>, mut columns: ArrayViewMut2, args: DeformConvOptions<2>, (kernel_h, kernel_w): (usize, usize), ) { // position shape: [in_channels, batch_size, out_h, out_w] // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] let (height, width) = input.dim(); for kernel_y in 0..kernel_h { for kernel_x in 0..kernel_w { let mask_value = mask .map(|it| it[[kernel_y, kernel_x]]) .unwrap_or_else(|| F::from_elem(1.0)); let offset = offset.slice(s![kernel_y, kernel_x, ..]); let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - F::from_elem(args.padding[0]) + offset[0]; let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - F::from_elem(args.padding[1]) + offset[1]; let interpolated = bilinear_interpolate(input, height, width, y, x); columns[[kernel_y, kernel_x]] = mask_value * interpolated; } } } fn bilinear_interpolate( input: ArrayView2, height: usize, width: usize, y: F, x: F, ) -> F { // To simplify code let y = y.to_f32(); let x = x.to_f32(); let mut result = F::from_elem(0.0); if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { let y_low = f32::floor(y); let x_low = f32::floor(x); let y_high = (y_low + 1.) as usize; let x_high = (x_low + 1.) as usize; let zero = F::from_elem(0.0); let v1: F = if y_low >= 0. && x_low >= 0. { input[[y_low as usize, x_low as usize]] } else { zero }; let v2: F = if y_low >= 0. && x_high < width { input[[y_low as usize, x_high]] } else { zero }; let v3: F = if y_high < height && x_low >= 0. { input[[y_high, x_low as usize]] } else { zero }; let v4: F = if y_high < height && x_high < width { input[[y_high, x_high]] } else { zero }; let l_y = y - y_low; let l_x = x - x_low; let h_y = 1.0 - l_y; let h_x = 1.0 - l_x; let w1 = F::from_elem(h_y * h_x); let w2 = F::from_elem(h_y * l_x); let w3 = F::from_elem(l_y * h_x); let w4 = F::from_elem(l_y * l_x); result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; } result } pub(crate) fn deform_conv2d( input: SharedArray, offset: SharedArray, weight: SharedArray, mask: Option>, bias: Option>, args: DeformConvOptions<2>, ) -> SharedArray where NdArrayTensor: From>, { let [batch_size, _, in_height, in_width] = input.shape().dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims(); let groups = args.weight_groups; let weight = weight.as_standard_layout(); let out_h = calculate_conv_output_size( kernel_h, args.stride[0], args.padding[0], args.dilation[0], in_height, ); let out_w = calculate_conv_output_size( kernel_w, args.stride[1], args.padding[1], args.dilation[1], in_width, ); let out_dims = (out_h, out_w); let input = input.into_dimensionality::().unwrap(); let offset = offset.into_dimensionality::().unwrap(); let mask = mask.as_ref().map(|it| { it.to_shape(( batch_size, args.offset_groups, kernel_h, kernel_w, out_h, out_w, )) .unwrap() }); let columns = deform_im2col( input.view(), offset.view(), mask.as_ref().map(|it| it.view()), args, out_dims, (kernel_h, kernel_w), ); let (col_size_0, col_size_1) = columns.dim(); let col_size_0 = col_size_0 / groups; let out_c_per_group = out_channels / groups; let weight = weight .to_shape((groups, out_c_per_group, col_size_0)) .unwrap(); let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); let out = matmul( weight.to_owned().into_dyn().into_shared(), columns.to_owned().into_dyn().into_shared(), ); let mut out = out .into_shape_with_order((out_channels, batch_size, out_h, out_w)) .unwrap(); out.swap_axes(0, 1); if let Some(bias) = bias { let bias = bias.to_shape((1, out_channels, 1, 1)).unwrap(); out.add_assign(&bias); } out.into_dyn().into_shared() } pub(crate) fn deform_im2col( input: ArrayView4, offset: ArrayView4, mask: Option>, args: DeformConvOptions<2>, out_dims: (usize, usize), kernel_dims: (usize, usize), ) -> Array2 { let (batch_size, in_channels, _, _) = input.dim(); let (kernel_h, kernel_w) = kernel_dims; let (out_h, out_w) = out_dims; let channels_per_offset_group = in_channels / args.offset_groups; let mut columns = Array4::zeros(Dim([ in_channels, kernel_h, kernel_w, batch_size * out_h * out_w, ])); let groups = args.offset_groups; run_par!(|| { iter_par!(columns.axis_iter_mut(Axis(3))) .enumerate() .for_each(|(index, mut columns)| { let out_x = index % out_w; let out_y = (index / out_w) % out_h; let batch = (index / (out_w * out_h)) % batch_size; let offset = offset.slice(s![batch, .., out_y, out_x]); let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap(); let mask = mask .as_ref() .map(|it| it.slice(s![batch, .., .., .., out_y, out_x])); columns .axis_iter_mut(Axis(0)) .enumerate() .for_each(|(in_channel, mut columns)| { let group_index = in_channel / channels_per_offset_group; deform_im2col_kernel( out_y, out_x, input.slice(s![batch, in_channel, .., ..]), offset.slice(s![group_index, .., .., ..]), mask.as_ref().map(|it| it.slice(s![group_index, .., ..])), columns.view_mut(), args.clone(), kernel_dims, ); }); }); }); columns // Columns is created here, so we know it's contiguous .into_shape_with_order(( in_channels * kernel_h * kernel_w, batch_size * out_h * out_w, )) .unwrap() } pub mod backward { #[cfg(target_has_atomic = "32")] use core::sync::atomic::Ordering; use atomic_float::AtomicF32; use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; use super::*; pub(crate) type DeformConv2dBackward = ( SharedArray, SharedArray, SharedArray, Option>, Option>, ); /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. pub(crate) fn deform_conv2d_backward( input: SharedArray, offset: SharedArray, weight: SharedArray, mask: Option>, bias: Option>, out_grad: SharedArray, args: DeformConvOptions<2>, ) -> DeformConv2dBackward { let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims(); let [_, _, kernel_h, kernel_w] = weight.shape().dims(); let groups = args.weight_groups; let out_c_per_group = out_channels / groups; let col_shape_1 = batch_size * out_h * out_w; let mut out_grad = out_grad.into_dimensionality::().unwrap(); let gradient_bias = bias.map(|_| { let out_grad = out_grad .clone() .sum_axis(Axis(0)) .sum_axis(Axis(1)) .sum_axis(Axis(1)); out_grad.into_dyn().into_shared() }); out_grad.swap_axes(0, 1); let out_grad = out_grad .to_shape((groups, out_c_per_group, col_shape_1)) .unwrap(); let input = input.into_dimensionality::().unwrap(); let offset = offset.into_dimensionality::().unwrap(); let mask = mask.map(|it| { it.into_shape_with_order(( batch_size, args.offset_groups, kernel_h, kernel_w, out_h, out_w, )) .unwrap() }); let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs( input.view(), weight, offset.view(), mask.as_ref().map(|it| it.view()), out_grad.view(), &args, (kernel_h, kernel_w), ); let weight_grad = compute_weight_grad( input.view(), offset.view(), mask.as_ref().map(|it| it.view()), out_grad.view(), args, (kernel_h, kernel_w), (out_h, out_w), ); ( input_gradient, offset_gradient, weight_grad, mask_gradient, gradient_bias, ) } fn compute_weight_grad( input: ArrayView4, offset: ArrayView4, mask: Option>, out_grad: ArrayView3, options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), ) -> SharedArray { let in_channels = input.dim().1; let (groups, out_c_per_group, _) = out_grad.dim(); let (kernel_h, kernel_w) = kernel_dims; let in_c_per_group = in_channels / groups; let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); let (col_size_0, col_size_1) = columns.dim(); let col_size_0 = col_size_0 / groups; let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); columns.swap_axes(1, 2); let grad_weight = matmul( out_grad.to_owned().into_dyn().into_shared(), columns.to_owned().into_dyn().into_shared(), ); let grad_weight = grad_weight .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w)) .unwrap(); grad_weight.into_dyn().into_shared() } type InputGradients = (SharedArray, SharedArray, Option>); fn backward_gradient_inputs( image: ArrayView4, weight: SharedArray, offset: ArrayView4, mask: Option>, out_grad: ArrayView3, args: &DeformConvOptions<2>, kernel_dims: (usize, usize), ) -> InputGradients { let input_shape = image.dim(); let in_channels = input_shape.1; let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims(); let (batch_size, _, out_h, out_w) = offset.dim(); let groups = args.weight_groups; let out_c_per_group = out_channels / groups; let col_shape_0 = in_c_per_group * kernel_h * kernel_w; let mut weight = weight .to_shape((groups, out_c_per_group, col_shape_0)) .unwrap(); weight.swap_axes(1, 2); let columns = matmul( weight.to_owned().into_dyn().into_shared(), out_grad.to_owned().into_dyn().into_shared(), ); let columns = columns .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w)) .unwrap(); let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( columns.view(), image.view(), offset, mask, args, kernel_dims, ); let input_gradient = compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape); (input_gradient, offset_gradient, mask_gradient) } fn compute_offset_and_mask_gradient( columns: ArrayView6, image: ArrayView4, offset: ArrayView4, mask: Option>, args: &DeformConvOptions<2>, kernel_dims: (usize, usize), ) -> (SharedArray, Option>) { let (kernel_h, kernel_w) = kernel_dims; let (_, in_channels, height, width) = image.dim(); let (batch_size, offset_channels, out_h, out_w) = offset.dim(); let offs_groups = args.offset_groups; let channels_per_offset_group = in_channels / args.offset_groups; let mut grad_offset = Array5::zeros(( offs_groups, kernel_h, kernel_w, 2, batch_size * out_h * out_w, )); let mut grad_mask = Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w)); grad_mask .axis_iter_mut(Axis(3)) .zip(grad_offset.axis_iter_mut(Axis(4))) .enumerate() .for_each(|(index, (mut grad_mask, mut grad_offset))| { let out_x = index % out_w; let out_y = (index / out_w) % out_h; let batch = index / (out_w * out_h); let offset = offset.slice(s![batch, .., out_y, out_x]); let offset = offset .to_shape((offs_groups, kernel_h, kernel_w, 2)) .unwrap(); let mask: Option> = mask .as_ref() .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x])); let columns = columns.slice(s![.., .., .., batch, out_y, out_x]); let image = image.slice(s![batch, .., .., ..]); for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() { let grad_mask: &mut F = grad_mask; let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]); let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]); let columns = columns.slice(s![.., kernel_y, kernel_x]); let group_offset = group * channels_per_offset_group; let image = image.slice(s![group_offset.., .., ..]); let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - F::from_elem(args.padding[0]) + offset[0]; let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - F::from_elem(args.padding[1]) + offset[1]; for (i, grad_offset) in grad_offset.iter_mut().enumerate() { let is_y_direction = i % 2 == 0; let use_mask = mask.is_some(); for channel in 0..channels_per_offset_group { let mask = mask.unwrap_or_else(|| F::one()); let image = image.index_axis(Axis(0), channel); let weight = get_coordinate_weight(image, height, width, y, x, is_y_direction); *grad_offset += mask * weight * columns[channel]; if use_mask && is_y_direction { *grad_mask += columns[channel] * bilinear_interpolate(image, height, width, y, x); } } } } }); let mask_gradient = mask.map(|_| { let mut grad_mask = grad_mask .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w)) .unwrap(); grad_mask.swap_axes(0, 1); grad_mask.into_dyn().into_shared() }); let mut grad_offset = grad_offset .into_shape_with_order((offset_channels, batch_size, out_h, out_w)) .unwrap(); grad_offset.swap_axes(0, 1); let offset_gradient = grad_offset.into_dyn().into_shared(); (offset_gradient, mask_gradient) } fn get_coordinate_weight( input: ArrayView2, height: usize, width: usize, y: F, x: F, is_y_direction: bool, ) -> F { let y = y.to_f32(); let x = x.to_f32(); let y_low = f32::floor(y); let x_low = f32::floor(x); let y_high = y_low + 1.; let x_high = x_low + 1.; let valid_y_low = y_low >= 0. && y_low < height as f32; let valid_y_high = y_high >= 0. && y_high < height as f32; let valid_x_low = x_low >= 0. && x_low < width as f32; let valid_x_high = x_high >= 0. && x_high < width as f32; let bottom_left = if valid_y_low && valid_x_low { input[[y_low as usize, x_low as usize]] } else { F::zero() }; let bottom_right = if valid_y_low && valid_x_high { input[[y_low as usize, x_high as usize]] } else { F::zero() }; let top_left = if valid_y_high && valid_x_low { input[[y_high as usize, x_low as usize]] } else { F::zero() }; let top_right = if valid_y_high && valid_x_high { input[[y_high as usize, x_high as usize]] } else { F::zero() }; if is_y_direction { let delta_x = F::from_elem(x - x_low); delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left) } else { let delta_y = F::from_elem(y - y_low); delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left) } } fn compute_input_grad( columns: ArrayView6, offset: ArrayView4, mask: Option>, args: &DeformConvOptions<2>, kernel_dims: (usize, usize), input_shape: (usize, usize, usize, usize), ) -> SharedArray { let (batch_size, in_channels, height, width) = input_shape; let (kernel_h, kernel_w) = kernel_dims; let offs_groups = args.offset_groups; let channels_per_offset_group = in_channels / offs_groups; let grad_in = Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || { AtomicF32::new(0.0) }); let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| { let group = in_channel / channels_per_offset_group; let offset = offset.slice(s![batch, .., out_y, out_x]); let offset = offset .to_shape((offs_groups, kernel_h, kernel_w, 2)) .unwrap(); let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); let offset = [offset[0], offset[1]]; let mask = mask .as_ref() .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - F::from_elem(args.padding[0]) + offset[0]; let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - F::from_elem(args.padding[1]) + offset[1]; let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); }; // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise #[cfg(feature = "multi-threads")] run_par!(|| { iter_par!(Zip::indexed(columns)) .for_each(|(args0, args1)| compute_for_each(args0, args1)) }); #[cfg(not(feature = "multi-threads"))] run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) }); let grad_in: Array1 = grad_in .into_iter() .map(|it| F::from_elem(it.into_inner())) .collect(); let grad_in = grad_in .into_shape_with_order((batch_size, in_channels, height, width)) .unwrap(); grad_in.into_dyn().into_shared() } fn deform_col2img_kernel( y: f32, x: f32, mask: Option, col: f32, grad_input: ArrayView2, ) { let (height, width) = grad_input.dim(); let mask_value = mask.unwrap_or(1.0); for dy in -1..=1 { for dx in -1..=1 { let yp = f32::floor(y) + dy as f32; let xp = f32::floor(x) + dx as f32; if yp >= 0.0 && yp < height as f32 && xp >= 0.0 && xp < width as f32 && f32::abs(y - yp) < 1.0 && f32::abs(x - xp) < 1.0 { let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); #[cfg_attr(not(target_has_atomic = "32"), allow(unused))] let value = mask_value * weight * col; #[cfg(target_has_atomic = "32")] grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel); #[cfg(not(target_has_atomic = "32"))] panic!("Can't use deformable convolution backwards pass without atomics"); } } } } } ================================================ FILE: crates/burn-ndarray/src/ops/grid_sample.rs ================================================ use burn_backend::ElementConversion; use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use ndarray::Array4; use crate::SharedArray; use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par}; /// Sample a tensor using grid-based sampling. /// /// # Arguments /// /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in) /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right /// * `options` - Grid sampling options (mode, padding_mode, align_corners) /// /// # Returns /// /// A tensor with shape (N, C, H_out, W_out) pub(crate) fn grid_sample_2d( tensor: SharedArray, grid: SharedArray, options: GridSampleOptions, ) -> SharedArray { match options.mode { InterpolateMode::Bilinear => (), _ => todo!( "grid_sample_2d with {:?} mode is not implemented", options.mode ), } let tensor = tensor.into_dimensionality::().unwrap(); let grid = grid.into_dimensionality::().unwrap(); let (batch_size, channels, height_in, width_in) = tensor.dim(); let (b, height_out, width_out, d) = grid.dim(); assert!(batch_size == b); assert!(2 == d); let mut output = Array4::zeros((batch_size, channels, height_out, width_out)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); let sample_count = batch_size * channels * height_out * width_out; let strides = ( channels * height_out * width_out, height_out * width_out, width_out, ); let align = options.align_corners; let pad_mode = options.padding_mode; run_par!(|| { iter_range_par!(0, sample_count).for_each(|id| { let (b, c, y, x) = ( id / strides.0, id % strides.0 / strides.1, id % strides.1 / strides.2, id % strides.2, ); let sample_x = grid[(b, y, x, 0)].elem::(); let sample_y = grid[(b, y, x, 1)].elem::(); // Convert normalized grid coordinates [-1, 1] to pixel coordinates let (px, py) = if align { // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2 // Maps -1 to 0 and 1 to width - 1 let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0; let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0; (px, py) } else { // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5 // Maps -1 to -0.5 and 1 to width - 0.5 let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5; let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5; (px, py) }; // Bilinear interpolation with the specified padding mode let val = bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align); unsafe { let output = unsafe_shared_out.get(); output[(b, c, y, x)] = val.elem(); } }); }); output.into_dyn().into_shared() } /// Bilinear interpolation at a point with configurable padding mode. #[allow(clippy::too_many_arguments)] fn bilinear_interpolate( source: &ndarray::ArrayBase>, b: usize, c: usize, x: f64, y: f64, width: usize, height: usize, padding_mode: GridSamplePaddingMode, align_corners: bool, ) -> f64 where E: FloatNdArrayElement, S: ndarray::Data, { // Handle inf/nan coordinates if !x.is_finite() || !y.is_finite() { return match padding_mode { GridSamplePaddingMode::Zeros => 0.0, GridSamplePaddingMode::Border => { // Clamp to center of image for inf/nan let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64); let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64); source[(b, c, cy as usize, cx as usize)].elem::() } GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan }; } // Apply padding mode to get actual sampling coordinates let (x, y) = match padding_mode { GridSamplePaddingMode::Border => { // Clamp coordinates to valid range [0, size-1] let x = x.clamp(0.0, (width - 1) as f64); let y = y.clamp(0.0, (height - 1) as f64); (x, y) } GridSamplePaddingMode::Reflection => { // Reflect coordinates at boundaries let x = reflect_coordinate(x, width, align_corners); let y = reflect_coordinate(y, height, align_corners); (x, y) } GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read }; // Get the four corner indices let x0 = x.floor() as i64; let y0 = y.floor() as i64; let x1 = x0.saturating_add(1); let y1 = y0.saturating_add(1); // Compute interpolation weights (fractional part) let x_frac = x - x.floor(); let y_frac = y - y.floor(); // Helper to read a value based on padding mode let read_value = |xi: i64, yi: i64| -> f64 { match padding_mode { GridSamplePaddingMode::Zeros => { // Return 0 for out-of-bounds if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 { source[(b, c, yi as usize, xi as usize)].elem::() } else { 0.0 } } GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => { // Coordinates should already be in valid range after clamping/reflection let xi = xi.clamp(0, (width - 1) as i64) as usize; let yi = yi.clamp(0, (height - 1) as i64) as usize; source[(b, c, yi, xi)].elem::() } } }; // Read the four corners let v00 = read_value(x0, y0); let v01 = read_value(x0, y1); let v10 = read_value(x1, y0); let v11 = read_value(x1, y1); // Bilinear interpolation weights let w00 = (1.0 - x_frac) * (1.0 - y_frac); let w01 = (1.0 - x_frac) * y_frac; let w10 = x_frac * (1.0 - y_frac); let w11 = x_frac * y_frac; v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11 } /// Reflect a coordinate at the boundaries using a triangle wave pattern. /// /// For align_corners=true: reflects within [0, size-1] /// For align_corners=false: reflects within [-0.5, size-0.5] fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 { let size_f = size as f64; let (min_val, max_val) = if align_corners { (0.0, size_f - 1.0) } else { (-0.5, size_f - 0.5) }; let span = max_val - min_val; if span <= 0.0 { return min_val; } // Triangle wave formula: span - |((x mod 2*span) - span)| let period = 2.0 * span; let x = (coord - min_val).abs(); let x_mod = x - (x / period).floor() * period; span - (x_mod - span).abs() + min_val } ================================================ FILE: crates/burn-ndarray/src/ops/int_tensor.rs ================================================ // Language use crate::rand::get_seeded_rng; use alloc::vec::Vec; use burn_backend::backend::ExecutionError; use burn_backend::ops::IntTensorOps; use burn_backend::tensor::{FloatTensor, IntTensor}; use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata}; use burn_backend::ElementConversion; // Current crate use crate::cat_with_dtype; use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor}; use crate::{NdArrayDevice, SEED, slice}; use crate::{SharedArray, element::QuantElement}; use crate::{element::FloatNdArrayElement, ops::matmul::matmul}; use crate::{element::IntNdArrayElement, execute_with_int_dtype}; // Workspace crates use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps}; use burn_backend::{DType, Shape, TensorData, backend::Backend}; impl IntTensorOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { if data.dtype.is_int() || data.dtype.is_uint() { NdArrayTensor::from_data(data) } else { unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype) } } async fn int_into_data(tensor: NdArrayTensor) -> Result { Ok(tensor.into_data()) } fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { tensor } fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape)) } fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor { slice!(tensor, slices) } fn int_device(_tensor: &NdArrayTensor) -> as Backend>::Device { NdArrayDevice::Cpu } fn int_empty( shape: Shape, device: & as Backend>::Device, dtype: IntDType, ) -> NdArrayTensor { Self::int_zeros(shape, device, dtype) } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { execute_with_int_dtype!((lhs, rhs), matmul) } fn int_mask_where( tensor: NdArrayTensor, mask: NdArrayTensor, source: NdArrayTensor, ) -> NdArrayTensor { execute_with_int_dtype!((tensor, source), |tensor, source| { NdArrayOps::mask_where(tensor, mask.bool(), source) }) } fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill( array, mask.bool(), value.elem() )) } fn int_slice_assign( tensor: NdArrayTensor, slices: &[burn_backend::Slice], value: NdArrayTensor, ) -> NdArrayTensor { execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign( tensor, slices, value )) } fn int_cat(tensors: Vec, dim: usize) -> NdArrayTensor { cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8]) } fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal) } fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem())) } fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater) } fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem())) } fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal) } fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem( array, rhs.elem() )) } fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower) } fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem())) } fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal) } fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem( array, rhs.elem() )) } fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add) } fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem())) } fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub) } fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem())) } fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul) } fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem())) } fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div) } fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem())) } fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder) } fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar( array, rhs.elem() )) } fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::sum_view( array.view() )) } fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim)) } fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!( tensor, E, |array: SharedArray| NdArrayMathOps::prod_view(array.view()) ) } fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim)) } fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!( tensor, E, |array: SharedArray| NdArrayMathOps::mean_view(array.view()) ) } fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim)) } fn int_max(tensor: NdArrayTensor) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::max_view( array.view() )) } fn int_min(tensor: NdArrayTensor) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::min_view( array.view() )) } fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim)) } fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim)) } fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim)) } fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim)) } fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather( dim, array, idx_array )) }) } fn int_scatter_add( dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { execute_with_int_dtype!(indices, |idx_array| NdArrayOps::::scatter( dim, tensor, idx_array, value )) }) } fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select( array, dim, idx_array )) }) } fn int_select_add( tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::::select_assign( tensor, dim, idx_array, value )) }) } fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!(tensor, E, |array: SharedArray| { NdArrayMathOps::argmax_view::(array.view(), dim) }) } fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_int_dtype!(tensor, E, |array: SharedArray| { NdArrayMathOps::argmin_view::(array.view(), dim) }) } fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem())) } fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem())) } fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp( array, min.elem(), max.elem() )) } fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { match tensor.dtype() { DType::I64 | DType::I32 | DType::I16 | DType::I8 => { execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [ I64 => i64, I32 => i32, I16 => i16, I8 => i8 ]) } // Already unsigned DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor, other => panic!("Unsupported dtype: {other:?}"), } } fn int_into_float(tensor: NdArrayTensor) -> FloatTensor { execute_with_int_dtype!(tensor, IntElem, |array: SharedArray| array .mapv(|a: IntElem| a.elem::()) .into_shared()) } fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2)) } fn int_random( shape: Shape, distribution: Distribution, device: &NdArrayDevice, ) -> NdArrayTensor { let mut seed = SEED.lock().unwrap(); let mut rng = seed.take().unwrap_or_else(get_seeded_rng); let effective_distribution = if distribution == Distribution::Default { Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant } else { distribution }; let tensor = Self::int_from_data( TensorData::random::(shape, effective_distribution, &mut rng), device, ); *seed = Some(rng); tensor } fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op( lhs, rhs, |a: &I, b: &I| { (a.elem::().pow(b.elem::())).elem() } )) } fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes)) } fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes)) } fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { match tensor.dtype() { DType::I64 | DType::I32 | DType::I16 | DType::I8 => { execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [ I64 => i64, I32 => i32, I16 => i16, I8 => i8 ]) } DType::U64 | DType::U32 | DType::U16 | DType::U8 => { Self::int_greater_elem(tensor, 0.into()) } other => panic!("Unsupported dtype: {other:?}"), } } fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape)) } fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand) } fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem())) } fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor) } fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem())) } fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor) } fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem())) } fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot) } fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { (a.elem::() << (b.elem::())).elem() }) }) } fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, I, |array| { NdArrayMathOps::elementwise_op_scalar(array, |a: I| { (a.elem::() << rhs.elem::()).elem() }) }) } fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { (a.elem::() >> (b.elem::())).elem() }) }) } fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { execute_with_int_dtype!(lhs, I, |array| { NdArrayMathOps::elementwise_op_scalar(array, |a: I| { (a.elem::() >> rhs.elem::()).elem() }) }) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into())) } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step)) } } ================================================ FILE: crates/burn-ndarray/src/ops/interpolate.rs ================================================ use burn_backend::ElementConversion; use ndarray::{Array4, ArrayBase, DataOwned}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use crate::{FloatNdArrayElement, ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par}; pub(crate) fn nearest_interpolate( x: SharedArray, output_size: [usize; 2], ) -> SharedArray { let x = x.into_dimensionality::().unwrap(); let (batch_size, channels, in_height, in_width) = x.dim(); let [out_height, out_width] = output_size; let y_ratio = (in_height as f64) / (out_height as f64); let x_ratio = (in_width as f64) / (out_width as f64); let out_element_num = batch_size * channels * out_height * out_width; let strides = ( channels * out_height * out_width, out_height * out_width, out_width, ); let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, out_element_num).for_each(|id| { let (b, c, h, w) = ( id / strides.0, id % strides.0 / strides.1, id % strides.1 / strides.2, id % strides.2, ); let y_in = (y_ratio * h as f64).floor() as usize; let x_in = (x_ratio * w as f64).floor() as usize; unsafe { let output = unsafe_shared_out.get(); output[(b, c, h, w)] = x[(b, c, y_in, x_in)]; } }); }); output.into_dyn().into_shared() } pub(crate) fn nearest_interpolate_backward( x: SharedArray, grad: SharedArray, output_size: [usize; 2], ) -> SharedArray { let [batch_size, channels, input_height, input_width] = x.shape().dims(); let [output_height, output_width] = output_size; let mut output_grad = Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output_grad = unsafe_shared_out.get(); for oh in 0..output_height { for ow in 0..output_width { let ih = start_index(oh, output_height, input_height); let iw = start_index(ow, output_width, input_width); output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] } } }) }); output_grad.into_dyn().into_shared() } fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize } // clamp ceil(frac) to stay within bounds in case of floating-point imprecision pub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 { frac.ceil().min(max as f64) } pub(crate) fn bilinear_interpolate( x: SharedArray, output_size: [usize; 2], align_corners: bool, ) -> SharedArray { let x = x.into_dimensionality::().unwrap(); let (batch_size, channels, in_height, in_width) = x.dim(); let [out_height, out_width] = output_size; let out_element_num = batch_size * channels * out_height * out_width; let strides = ( channels * out_height * out_width, out_height * out_width, out_width, ); let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, out_element_num).for_each(|id| { let (b, c, h, w) = ( id / strides.0, id % strides.0 / strides.1, id % strides.1 / strides.2, id % strides.2, ); let (y_frac, x_frac) = if align_corners { let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); (y_ratio * h as f64, x_ratio * w as f64) } else { let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; ( y_frac.clamp(0.0, (in_height - 1) as f64), x_frac.clamp(0.0, (in_width - 1) as f64), ) }; let val = bilinear_interpolate_single(&x, b, c, x_frac, y_frac, in_width - 1, in_height - 1); unsafe { let output = unsafe_shared_out.get(); output[(b, c, h, w)] = val.elem(); } }); }); output.into_dyn().into_shared() } pub(crate) fn bicubic_interpolate( x: SharedArray, output_size: [usize; 2], align_corners: bool, ) -> SharedArray { fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 { fn cubic_convolution1(x: f64, a: f64) -> f64 { ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0 } fn cubic_convolution2(x: f64, a: f64) -> f64 { ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a } let coeffs = [ cubic_convolution2(t + 1.0, -0.75), cubic_convolution1(t, -0.75), cubic_convolution1(1.0 - t, -0.75), cubic_convolution2(2.0 - t, -0.75), ]; x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3] } let x = x.into_dimensionality::().unwrap(); let (batch_size, channels, in_height, in_width) = x.dim(); let [out_height, out_width] = output_size; let out_element_num = batch_size * channels * out_height * out_width; let strides = ( channels * out_height * out_width, out_height * out_width, out_width, ); let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, out_element_num).for_each(|id| { let (b, c, h, w) = ( id / strides.0, id % strides.0 / strides.1, id % strides.1 / strides.2, id % strides.2, ); let (y_frac, x_frac) = if align_corners { let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); (y_ratio * h as f64, x_ratio * w as f64) } else { let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; (y_frac, x_frac) }; let y0 = y_frac.floor(); let yw = y_frac - y0; let y_in = y0 as isize; let x0 = x_frac.floor(); let xw = x_frac - x0; let x_in = x0 as isize; let max_h = (in_height - 1) as isize; let max_w = (in_width - 1) as isize; let ys_in = [ (y_in - 1).clamp(0, max_h) as usize, y_in.clamp(0, max_h) as usize, (y_in + 1).clamp(0, max_h) as usize, (y_in + 2).clamp(0, max_h) as usize, ]; let xs_in = [ (x_in - 1).clamp(0, max_w) as usize, x_in.clamp(0, max_w) as usize, (x_in + 1).clamp(0, max_w) as usize, (x_in + 2).clamp(0, max_w) as usize, ]; let coefficients = ys_in.map(|y| { cubic_interp1d( x[(b, c, y, xs_in[0])].elem(), x[(b, c, y, xs_in[1])].elem(), x[(b, c, y, xs_in[2])].elem(), x[(b, c, y, xs_in[3])].elem(), xw, ) }); let result = cubic_interp1d( coefficients[0], coefficients[1], coefficients[2], coefficients[3], yw, ) .elem(); unsafe { let output = unsafe_shared_out.get(); output[(b, c, h, w)] = result; } }); }); output.into_dyn().into_shared() } pub(crate) fn lanczos3_interpolate( x: SharedArray, output_size: [usize; 2], align_corners: bool, ) -> SharedArray { fn lanczos3_weight(x: f64) -> f64 { if x == 0.0 { return 1.0; } let abs_x = x.abs(); if abs_x >= 3.0 { return 0.0; } let pi = core::f64::consts::PI; let pi_x = pi * x; let pi_x_over_3 = pi_x / 3.0; (pi_x.sin() * pi_x_over_3.sin()) / (pi_x * pi_x_over_3) } let x = x.into_dimensionality::().unwrap(); let (batch_size, channels, in_height, in_width) = x.dim(); let [out_height, out_width] = output_size; let out_element_num = batch_size * channels * out_height * out_width; let strides = ( channels * out_height * out_width, out_height * out_width, out_width, ); let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, out_element_num).for_each(|id| { let (b, c, h, w) = ( id / strides.0, id % strides.0 / strides.1, id % strides.1 / strides.2, id % strides.2, ); let (y_frac, x_frac) = if align_corners { let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); (y_ratio * h as f64, x_ratio * w as f64) } else { let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; (y_frac, x_frac) }; let y0 = y_frac.floor(); let x0 = x_frac.floor(); let max_h = (in_height - 1) as isize; let max_w = (in_width - 1) as isize; // 6x6 separable Lanczos3 filter (skip out-of-bounds positions) let mut result = 0.0; let mut weight_sum = 0.0; for ky in -2..=3 { let yi = y0 as isize + ky; if yi < 0 || yi > max_h { continue; } let y_idx = yi as usize; let wy = lanczos3_weight(y_frac - (y0 + ky as f64)); for kx in -2..=3 { let xi = x0 as isize + kx; if xi < 0 || xi > max_w { continue; } let x_idx = xi as usize; let wx = lanczos3_weight(x_frac - (x0 + kx as f64)); let w = wy * wx; let pixel: f64 = x[(b, c, y_idx, x_idx)].elem(); result += pixel * w; weight_sum += w; } } if weight_sum != 0.0 { result /= weight_sum; } unsafe { let output = unsafe_shared_out.get(); output[(b, c, h, w)] = result.elem(); } }); }); output.into_dyn().into_shared() } /// Sample an element of the source array with bilinear interpolation /// /// * `source` - The tensor to read from. Has shape (batch_size, channels, height, width) /// * `b` - The batch to read from /// * `c` - The channel to read from /// * `x` - The x position to read in the array /// * `y` - The y position to read in the array /// * `x_max` - The max x position (inclusive) /// * `y_max` - The max y position (inclusive) /// /// # Returns /// /// The interpolated value read from the array pub(crate) fn bilinear_interpolate_single( source: &ArrayBase>, b: usize, c: usize, x: f64, y: f64, x_max: usize, y_max: usize, ) -> f64 where E: FloatNdArrayElement, S: DataOwned, { let y0 = y.floor(); let y1 = ceil_clamp(y, y_max); let yw = y - y0; let x0 = x.floor(); let x1 = ceil_clamp(x, x_max); let xw = x - x0; let (x0, x1, y0, y1) = (x0 as usize, x1 as usize, y0 as usize, y1 as usize); let p_a = source[(b, c, y0, x0)].elem::() * (1.0 - xw) * (1.0 - yw); let p_b = source[(b, c, y0, x1)].elem::() * xw * (1.0 - yw); let p_c = source[(b, c, y1, x0)].elem::() * (1.0 - xw) * yw; let p_d = source[(b, c, y1, x1)].elem::() * xw * yw; p_a + p_b + p_c + p_d } ================================================ FILE: crates/burn-ndarray/src/ops/macros.rs ================================================ macro_rules! keepdim { ( $dim:expr, $self:expr, mean ) => {{ // Get shape first (via reference), then pass ownership to avoid clone let mut shape = $self.shape().into_shape(); shape[$dim] = 1; let tensor: SharedArray = mean_dim($self, $dim); NdArrayOps::reshape(tensor, shape) }}; ( $dim:expr, $self:expr, sum ) => {{ // Get shape first (via reference), then pass ownership to avoid clone let mut shape = $self.shape().into_shape(); shape[$dim] = 1; let tensor: SharedArray = sum_dim($self, $dim); NdArrayOps::reshape(tensor, shape) }}; ( $dim:expr, $self:expr, prod ) => {{ // Get shape first (via reference), then pass ownership to avoid clone let mut shape = $self.shape().into_shape(); shape[$dim] = 1; let tensor: SharedArray = prod_dim($self, $dim); NdArrayOps::reshape(tensor, shape) }}; } use burn_backend::ElementConversion; pub(crate) use keepdim; use ndarray::{Axis, Zip}; use crate::{SharedArray, element::NdArrayElement}; pub(crate) fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { tensor.mean_axis(Axis(dim)).unwrap().into_shared() } pub(crate) fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { tensor.sum_axis(Axis(dim)).into_shared() } pub(crate) fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { tensor .fold_axis(Axis(dim), 1.elem::(), |acc, &x| acc.mul(x.elem())) .into_shared() } /// Generic cumulative operation function with closure-based operation. pub(crate) fn cumulative_with_op(tensor: SharedArray, dim: usize, op: F) -> SharedArray where E: NdArrayElement, F: Fn(&mut E, &E), { let axis = Axis(dim); let shape = tensor.shape().to_vec(); // Use into_owned() instead of to_owned() - only copies if shared, avoids copy if unique let mut result = tensor.into_owned(); let dim_size = shape[dim]; for i in 1..dim_size { let prev = result.index_axis(axis, i - 1).to_owned(); let mut current = result.index_axis_mut(axis, i); Zip::from(&mut current).and(&prev).for_each(&op); } result.into_shared() } // Define all cumulative operation functions using the generic function pub(crate) fn cumsum_dim(tensor: SharedArray, dim: usize) -> SharedArray { cumulative_with_op(tensor, dim, |c, &p| *c = c.add(p.elem())) } pub(crate) fn cumprod_dim(tensor: SharedArray, dim: usize) -> SharedArray { cumulative_with_op(tensor, dim, |c, &p| *c = c.mul(p.elem())) } pub(crate) fn cummin_dim>( tensor: SharedArray, dim: usize, ) -> SharedArray { cumulative_with_op(tensor, dim, |c, &p| { if p < *c { *c = p; } }) } pub(crate) fn cummax_dim>( tensor: SharedArray, dim: usize, ) -> SharedArray { cumulative_with_op(tensor, dim, |c, &p| { if p > *c { *c = p; } }) } ================================================ FILE: crates/burn-ndarray/src/ops/matmul.rs ================================================ use crate::UnsafeSharedRef; use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par}; use alloc::{vec, vec::Vec}; use burn_backend::ElementConversion; use burn_backend::Shape; use ndarray::{IxDyn, s}; pub(crate) fn matmul( lhs: SharedArray, rhs: SharedArray, ) -> SharedArray { let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); let ndims = shape_lhs.num_dims(); let m = shape_lhs[ndims - 2]; // # of left rows let k = shape_rhs[ndims - 2]; // # of left cols and right rows let n = shape_rhs[ndims - 1]; // # of right cols let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs); let l_mat_size = m * k; // size of matrix component of left array let r_mat_size = k * n; // size of matrix component of right array let out_mat_size = m * n; // size of matrix component of output array let num_l_batches = shape_lhs.num_elements() / l_mat_size; let num_r_batches = shape_rhs.num_elements() / r_mat_size; let num_out_batches = out_shape.num_elements() / out_mat_size; let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k])); let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n])); let alpha: E = 1.0.elem(); let beta: E = 0.0.elem(); let out = run_par!(|| { let mut out_array = ndarray::Array3::::zeros((num_out_batches, m, n)); let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); iter_range_par!(0, num_out_batches).for_each(|out_batch| { // Here, we: // 1. Un-flatten the output batch into a component-based batch index. // 2. Use the strides for left and right batch indices to convert it to a flattened // batch for left and right. let out_index = strides_out.unflatten(out_batch); let l_batch = strides_lhs.flatten(&out_index); let r_batch = strides_rhs.flatten(&out_index); let lhs_slice = lhs_array.slice(s!(l_batch, .., ..)); let rhs_slice = rhs_array.slice(s!(r_batch, .., ..)); unsafe { let mut out_slice = unsafe_shared_out_array .get() .slice_mut(s!(out_batch, .., ..)); ndarray::linalg::general_mat_mul( alpha, &lhs_slice, &rhs_slice, beta, &mut out_slice, ) } }); out_array.into_shared().into_dyn() }); NdArrayOps::reshape(out, out_shape) } #[derive(Debug, PartialEq)] struct Strides { strides: Vec, } impl Strides { fn new(strides: Vec) -> Self { Strides { strides } } fn unflatten(&self, linear_index: usize) -> Vec { let mut coord = Vec::with_capacity(self.strides.len()); let mut rem = linear_index; for stride in self.strides.iter() { coord.push(rem / stride); rem %= stride; } coord } fn flatten(&self, index: &Vec) -> usize { assert_eq!(self.strides.len(), index.len()); self.strides .iter() .zip(index) .map(|(stride, index)| stride * index) .sum() } } /// Compute the (broadcasted) output shape of matrix multiplication, along with strides for /// the non-matrix dimensions of all arrays. /// /// # Arguments /// * `lsh`: Shape of the first (left-hand) matrix multiplication argument. /// * `rsh`: Shape of the second (right-hand) matrix multiplication argument. /// /// # Panics /// * If `D` is not at least 2. /// * If the matrix multiplication dimensions (last 2) are incompatible. /// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where /// one dim is equal to 1 is broadcast.) fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) { let ndims = lsh.num_dims(); if ndims < 2 { panic!("Matrix multiplication requires an array with at least 2 dimensions."); } // Fetch matrix dimensions and check compatibility. let l_rows = lsh[ndims - 2]; let l_cols = lsh[ndims - 1]; let r_rows = rsh[ndims - 2]; let r_cols = rsh[ndims - 1]; if l_cols != r_rows { panic!("Dimensions are incompatible for matrix multiplication."); } // Set matrix dimensions of the output shape. let mut osh = vec![0; ndims]; osh[ndims - 2] = l_rows; osh[ndims - 1] = r_cols; // Set other array dimensions, broadcasting as necessary. // Compute the strides inline. let mut cur_l_stride: usize = 1; let mut cur_r_stride: usize = 1; let mut cur_o_stride: usize = 1; let mut l_strides = Vec::with_capacity(ndims - 2); let mut r_strides = Vec::with_capacity(ndims - 2); let mut o_strides = Vec::with_capacity(ndims - 2); for i in (0..ndims - 2).rev() { let l_dim = lsh[i]; let r_dim = rsh[i]; // Compatible dimensions are: // 1. Both dimensions are equal. // 2. One of the dimensions is equal to 1. let o_dim: usize; if l_dim == r_dim { o_dim = l_dim; // both dimensions are equal l_strides.push(cur_l_stride); r_strides.push(cur_r_stride); } else if l_dim == 1 { o_dim = r_dim; // broadcast the left l_strides.push(0); r_strides.push(cur_r_stride); } else if r_dim == 1 { o_dim = l_dim; // broadcast the right l_strides.push(cur_l_stride); r_strides.push(0); } else { panic!("Dimensions differ and cannot be broadcasted."); } osh[i] = o_dim; o_strides.push(cur_o_stride); cur_o_stride *= o_dim; cur_l_stride *= l_dim; cur_r_stride *= r_dim; } l_strides.reverse(); r_strides.reverse(); o_strides.reverse(); ( Shape::from(osh), Strides::new(l_strides), Strides::new(r_strides), Strides::new(o_strides), ) } pub(crate) fn cross( lhs: SharedArray, rhs: SharedArray, dim: usize, ) -> SharedArray { let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); let ndims = shape_lhs.num_dims(); // Broadcast the shapes except along dim let mut broadcast_shape = vec![0; ndims]; for i in 0..ndims { if i == dim { broadcast_shape[i] = shape_lhs[i]; // already checked to be 3 } else { let l = shape_lhs[i]; let r = shape_rhs[i]; if l == r { broadcast_shape[i] = l; } else if l == 1 { broadcast_shape[i] = r; } else if r == 1 { broadcast_shape[i] = l; } else { panic!("Tensors are not broadcastable along dimension {}", i); } } } // Broadcast lhs and rhs let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() { lhs } else { NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone())) }; let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() { rhs } else { NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone())) }; // Now, move dim to the last dimension let mut perm = (0..ndims).collect::>(); perm.remove(dim); perm.push(dim); let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm); let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm); // Reshape to (*, 3) let total_elements = lhs_permuted.shape().num_elements(); let batch_size = total_elements / 3; let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3])); let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3])); // Compute cross product let mut result = ndarray::ArrayD::::zeros(IxDyn(&[batch_size, 3])); for i in 0..batch_size { let a1 = lhs_reshaped[IxDyn(&[i, 0])]; let a2 = lhs_reshaped[IxDyn(&[i, 1])]; let a3 = lhs_reshaped[IxDyn(&[i, 2])]; let b1 = rhs_reshaped[IxDyn(&[i, 0])]; let b2 = rhs_reshaped[IxDyn(&[i, 1])]; let b3 = rhs_reshaped[IxDyn(&[i, 2])]; result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2)); result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3)); result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1)); } let result_shared = result.into_shared(); // Reshape back to the broadcast shape with dim at the end let mut result_shape = broadcast_shape; result_shape.remove(dim); result_shape.push(3); let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape)); // Permute back let mut inv_perm = vec![0; ndims]; for (i, &p) in perm.iter().enumerate() { inv_perm[p] = i; } NdArrayOps::permute(result_reshaped, &inv_perm) } #[cfg(test)] mod tests { use super::*; impl Strides { fn empty() -> Self { Strides { strides: Vec::with_capacity(0), } } } #[test] fn test_output_shape() { // plain matrix multiply assert_eq!( output_shape(&[5, 3], &[3, 7]), ( Shape::from([5, 7]), Strides::empty(), Strides::empty(), Strides::empty() ) ); // matrix multiply with one extra stack dimension assert_eq!( output_shape(&[4, 5, 3], &[4, 3, 7]), ( Shape::from([4, 5, 7]), Strides::new(vec![1]), Strides::new(vec![1]), Strides::new(vec![1]) ) ); // rank 3, broadcast left assert_eq!( output_shape(&[1, 5, 3], &[4, 3, 7]), ( Shape::from([4, 5, 7]), Strides::new(vec![0]), Strides::new(vec![1]), Strides::new(vec![1]) ) ); // rank 3, broadcast right assert_eq!( output_shape(&[4, 5, 3], &[1, 3, 7]), ( Shape::from([4, 5, 7]), Strides::new(vec![1]), Strides::new(vec![0]), Strides::new(vec![1]) ) ); // rank 4, multi broadcast assert_eq!( output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]), ( Shape::from([8, 4, 5, 7]), Strides::new(vec![0, 1]), Strides::new(vec![1, 0]), Strides::new(vec![4, 1]) ) ); // rank 5, multi-broadcast assert_eq!( output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]), ( Shape::from([8, 3, 4, 5, 7]), Strides::new(vec![0, 4, 1]), Strides::new(vec![3, 1, 0]), Strides::new(vec![12, 4, 1]) ) ) } #[test] #[should_panic( expected = "Matrix multiplication requires an array with at least 2 dimensions." )] fn test_output_shape_too_small() { output_shape(&[4], &[4]); } #[test] #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")] fn test_output_shape_bad_matrix_dims() { output_shape(&[5, 3], &[4, 7]); } #[test] #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")] fn test_output_shape_non_broadcast() { output_shape(&[4, 5, 3], &[2, 3, 7]); } } ================================================ FILE: crates/burn-ndarray/src/ops/maxpool.rs ================================================ use crate::{ ShapeOps, SharedArray, element::{FloatNdArrayElement, IntNdArrayElement}, iter_range_par, ops::padding::apply_padding_4d, run_par, sharing::UnsafeSharedRef, }; use burn_backend::ElementConversion; use burn_backend::ops::conv::calculate_pool_output_size; use ndarray::Array4; pub(crate) fn max_pool2d( x: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> SharedArray { let [kernel_height, kernel_width] = kernel_size; let [padding_height, padding_width] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let [batch_size, channels, x_height, x_width] = x.shape().dims(); let inf = (-f32::INFINITY).elem::(); let out_height = calculate_pool_output_size( kernel_height, stride_height, padding_height, dilation_height, x_height, ceil_mode, ); let out_width = calculate_pool_output_size( kernel_width, stride_width, padding_width, dilation_width, x_width, ceil_mode, ); // Calculate extra padding needed for ceil_mode // The maximum input position accessed is: (out_size - 1) * stride + (kernel_size - 1) * dilation // This must be < input_size + 2 * total_padding let max_ih = (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height; let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width; let padded_height = x_height + 2 * padding_height; let padded_width = x_width + 2 * padding_width; let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1)); let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1)); let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w]; let x = apply_padding_4d::(x, total_padding, inf); // Offset to account for extra padding (extra_pad is added on both sides by apply_padding_4d) let offset_h = extra_pad_h; let offset_w = extra_pad_w; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output = unsafe_shared_out.get(); for oh in 0..out_height { for ow in 0..out_width { let mut max_val = inf; for kh in 0..kernel_height { let ih = offset_h + oh * stride_height + kh * dilation_height; for kw in 0..kernel_width { let iw = offset_w + ow * stride_width + kw * dilation_width; let val = x[[b, c, ih, iw]]; if val > max_val { max_val = val; } } } output[[b, c, oh, ow]] = max_val; } } }) }); output.into_dyn().into_shared() } pub(crate) fn max_pool2d_with_indices( x: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> (SharedArray, SharedArray) { let [kernel_height, kernel_width] = kernel_size; let [padding_height, padding_width] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let [batch_size, channels, x_height, x_width] = x.shape().dims(); let inf = (-f32::INFINITY).elem::(); let out_height = calculate_pool_output_size( kernel_height, stride_height, padding_height, dilation_height, x_height, ceil_mode, ); let out_width = calculate_pool_output_size( kernel_width, stride_width, padding_width, dilation_width, x_width, ceil_mode, ); // Calculate extra padding needed for ceil_mode let max_ih = (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height; let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width; let padded_height = x_height + 2 * padding_height; let padded_width = x_width + 2 * padding_width; let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1)); let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1)); let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w]; let x = apply_padding_4d::(x, total_padding, inf); // Offset to account for extra padding let offset_h = extra_pad_h; let offset_w = extra_pad_w; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output = unsafe_shared_out.get(); let indices = unsafe_shared_indices.get(); for oh in 0..out_height { for ow in 0..out_width { let mut max_val = inf; let mut index = 0; for kh in 0..kernel_height { let ih = offset_h + oh * stride_height + kh * dilation_height; for kw in 0..kernel_width { let iw = offset_w + ow * stride_width + kw * dilation_width; let val = x[[b, c, ih, iw]]; if val > max_val { max_val = val; // Calculate index in original (unpadded) input let ih_orig = ih as i64 - (total_padding[0]) as i64; let iw_orig = iw as i64 - (total_padding[1]) as i64; // Clamp to valid range for index calculation let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1); let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1); index = ih_clamped * x_width as i64 + iw_clamped; } } } output[[b, c, oh, ow]] = max_val; indices[[b, c, oh, ow]] = index.elem(); } } }) }); let output = output.into_dyn().into_shared(); let indices = indices.into_dyn().into_shared(); (output, indices) } #[allow(clippy::too_many_arguments)] pub(crate) fn max_pool2d_backward( x: SharedArray, _kernel_size: [usize; 2], _stride: [usize; 2], _padding: [usize; 2], _dilation: [usize; 2], _ceil_mode: bool, output_grad: SharedArray, indices: SharedArray, ) -> SharedArray { let [_batch_size, _channels, height, width] = output_grad.shape().dims(); let [batch_size, channels, height_x, width_x] = x.shape().dims(); let output_grad = output_grad; let indices = indices; let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); run_par!(|| { iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { let b = k / channels; let c = k % channels; let output = unsafe_shared_out.get(); for h in 0..height { for w in 0..width { let index = indices[[b, c, h, w]].elem::(); let grad = output_grad[[b, c, h, w]]; let index_h = index as usize / width_x; let index_w = index as usize % width_x; output[[b, c, index_h, index_w]] += grad; } } }); }); output.into_dyn().into_shared() } ================================================ FILE: crates/burn-ndarray/src/ops/mod.rs ================================================ mod activation; mod base; mod bool_tensor; mod int_tensor; mod module; mod qtensor; #[cfg(feature = "simd")] mod simd; mod tensor; mod transaction; pub(crate) mod adaptive_avgpool; pub(crate) mod avgpool; pub(crate) mod conv; pub(crate) mod deform_conv; pub(crate) mod grid_sample; pub(crate) mod interpolate; pub(crate) mod macros; pub(crate) mod matmul; pub(crate) mod maxpool; pub(crate) mod padding; pub(crate) mod quantization; pub(crate) use base::*; ================================================ FILE: crates/burn-ndarray/src/ops/module.rs ================================================ use super::{ adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, avgpool::{avg_pool2d, avg_pool2d_backward}, conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d}, deform_conv::{backward::deform_conv2d_backward, deform_conv2d}, interpolate::{ bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate, }, maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, }; #[cfg(feature = "simd")] use crate::ops::simd::{ avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd, }; use crate::{ NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype, tensor::NdArrayTensor, }; use crate::{ element::{IntNdArrayElement, QuantElement}, ops::interpolate::nearest_interpolate_backward, }; use burn_backend::{ ElementConversion, TensorMetadata, ops::{attention::attention_fallback, *}, tensor::FloatTensor, }; macro_rules! module_op { // Module op with inputs (inp), optional (opt) and arguments (args). // Converts NdArrayStorage to SharedArray for compatibility with existing operations. (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{ #[allow(unused_parens, unreachable_patterns)] match ($($x),+) { ($(NdArrayTensor::F32($x)),+) => { type $element = f32; $op( $($x.into_shared()),+ $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* ) } ($(NdArrayTensor::F64($x)),+) => { type $element = f64; $op( $($x.into_shared()),+ $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* ) } _ => panic!("Data type mismatch"), } }}; } impl ModuleOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn conv2d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option, options: ConvOptions<2>, ) -> NdArrayTensor { module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { #[cfg(feature = "simd")] let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) { Ok(out) => return out.into(), Err(args) => args, }; conv2d::(x, weight, bias, options).into() }) } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { module_op!( inp(x, offset, weight), opt(mask, bias), E, |x, offset, weight, mask, bias| deform_conv2d::( x, offset, weight, mask, bias, options ) .into() ) } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { module_op!( inp(x, offset, weight, output_grad), opt(mask, bias), E, |x, offset, weight, output_grad, mask, bias| { let (x, offset, weight, mask, bias) = deform_conv2d_backward::( x, offset, weight, mask, bias, output_grad, options, ); DeformConv2dBackward::new( x.into(), offset.into(), weight.into(), mask.map(|m| m.into()), bias.map(|b| b.into()), ) } ) } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { conv_transpose2d::(x, weight, bias, options).into() }) } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { module_op!(inp(x), opt(), E, |x| { #[cfg(feature = "simd")] let x = match if ceil_mode { // SIMD path doesn't support ceil_mode yet, skip it Err(x) } else { try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) } { Ok(out) => return out.into(), Err(x) => x, }; avg_pool2d::( x, kernel_size, stride, padding, count_include_pad, ceil_mode, ) .into() }) } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::( x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode ) .into()) } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { module_op!(inp(x), opt(), E, |x| { #[cfg(feature = "simd")] let x = match if ceil_mode { // SIMD path doesn't support ceil_mode yet, skip it Err(x) } else { try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) } { Ok(out) => return out.into(), Err(x) => x, }; max_pool2d::(x, kernel_size, stride, padding, dilation, ceil_mode).into() }) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices> { module_op!(inp(x), opt(), E, |x| { let (output, indices) = max_pool2d_with_indices::( x, kernel_size, stride, padding, dilation, ceil_mode, ); MaxPool2dWithIndices::new(output.into(), indices.into()) }) } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: NdArrayTensor, ) -> MaxPool2dBackward> { execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray| { // Convert indices from runtime dtype to the expected I type // (pool indices are bounded by tensor dimensions, so conversion is safe) let indices: SharedArray = idx_s.mapv(|x| x.elem()).into_shared(); module_op!(inp(x, output_grad), opt(), E, |x, output_grad| { let output = max_pool2d_backward::( x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices, ); MaxPool2dBackward::new(output.into()) }) }) } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::( x, output_size ) .into()) } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { module_op!(inp(x, grad), opt(), E, |x, grad| { adaptive_avg_pool2d_backward::(x, grad).into() }) } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { match options.mode { InterpolateMode::Nearest => { module_op!(inp(x), opt(), E, |x| nearest_interpolate::( x, output_size ) .into()) } InterpolateMode::Bilinear => { let align_corners = options.align_corners; module_op!(inp(x), opt(), E, |x| bilinear_interpolate::( x, output_size, align_corners ) .into()) } InterpolateMode::Bicubic => { let align_corners = options.align_corners; module_op!(inp(x), opt(), E, |x| bicubic_interpolate::( x, output_size, align_corners ) .into()) } InterpolateMode::Lanczos3 => { let align_corners = options.align_corners; module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::( x, output_size, align_corners ) .into()) } } } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { match options.mode { InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| { nearest_interpolate_backward::(x, grad, output_size).into() }), InterpolateMode::Bilinear => { panic!("bilinear interpolation backward is not supported for ndarray backend") } InterpolateMode::Bicubic => { panic!("bicubic interpolation backward is not supported for ndarray backend") } InterpolateMode::Lanczos3 => { panic!("lanczos3 interpolation backward is not supported for ndarray backend") } } } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::( x, weight, bias, options ) .into()) } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { conv_transpose3d::(x, weight, bias, options).into() }) } fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> FloatTensor { attention_fallback::(query, key, value, mask, attn_bias, options) } } ================================================ FILE: crates/burn-ndarray/src/ops/padding.rs ================================================ use crate::{NdArrayElement, SharedArray}; use ndarray::{Array4, Array5}; use super::NdArrayOps; pub(crate) fn apply_padding_4d( x: SharedArray, padding: [usize; 2], elem: E, ) -> SharedArray { let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap(); let [padding_height, padding_width] = padding; let padded_height = height + 2 * padding_height; let padded_width = width + 2 * padding_width; let x_new = Array4::from_elem( (batch_size, input_channels, padded_height, padded_width), elem, ); let mut x_new = x_new.into_shared().into_dyn(); x_new = NdArrayOps::slice_assign( x_new, &[ burn_backend::Slice::from(0..batch_size), burn_backend::Slice::from(0..input_channels), burn_backend::Slice::from(padding_height..height + padding_height), burn_backend::Slice::from(padding_width..width + padding_width), ], x, ); x_new } pub(crate) fn apply_padding_5d( x: SharedArray, padding: [usize; 3], elem: E, ) -> SharedArray { let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap(); let [padding_depth, padding_height, padding_width] = padding; let padded_depth = depth + 2 * padding_depth; let padded_height = height + 2 * padding_height; let padded_width = width + 2 * padding_width; let x_new = Array5::from_elem( ( batch_size, input_channels, padded_depth, padded_height, padded_width, ), elem, ); let mut x_new = x_new.into_shared().into_dyn(); x_new = NdArrayOps::slice_assign( x_new, &[ burn_backend::Slice::from(0..batch_size), burn_backend::Slice::from(0..input_channels), burn_backend::Slice::from(padding_depth..depth + padding_depth), burn_backend::Slice::from(padding_height..height + padding_height), burn_backend::Slice::from(padding_width..width + padding_width), ], x, ); x_new } ================================================ FILE: crates/burn-ndarray/src/ops/qtensor.rs ================================================ use alloc::{vec, vec::Vec}; use burn_backend::{ DType, ExecutionError, Shape, TensorData, TensorMetadata, ops::QTensorOps, quantization::{ QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue, QuantizationParametersPrimitive, QuantizedBytes, }, tensor::{FloatTensor, IntTensor, QuantizedTensor}, }; use crate::{ FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray, element::{IntNdArrayElement, QuantElement}, execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype, slice, }; use super::quantization::{QuantizationStrategy, SymmetricQuantization}; use super::{NdArrayMathOps, NdArrayOps}; impl QTensorOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor { match data.dtype { DType::QFloat(scheme) => { let shape = data.shape.clone(); let num_elements = data.num_elements(); let q_bytes = QuantizedBytes { bytes: data.into_bytes(), scheme, num_elements, }; match scheme { QuantScheme { level: QuantLevel::Tensor | QuantLevel::Block(_), mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S, .. } => { // We can load QuantStore::U32 w/ QuantizedBytes impl let (values, qparams) = q_bytes.into_vec_i8(); let data = TensorData::new(values, shape); // Overwrite storage let scheme = scheme.with_store(QuantStore::Native); let qparams = qparams .scales .into_iter() .map(|scales| QParams { scales }) .collect(); NdArrayQTensor { qtensor: NdArrayTensor::from_data(data), scheme, qparams, } } QuantScheme { value: QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 | QuantValue::E4M3 | QuantValue::E5M2, .. } => unimplemented!("from_data not supported for scheme {scheme:?}"), } } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", data.dtype ), } } fn quantize( tensor: FloatTensor, scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { let shape = tensor.shape(); let data_f = tensor.into_data(); let scales = qparams.scales.into_data().convert::(); // Implement with ndarray instead of QuantizationStrategy? let (data, qparams) = match scheme { QuantScheme { level: QuantLevel::Tensor, mode: QuantMode::Symmetric, #[cfg(not(feature = "export_tests"))] value: QuantValue::Q8F | QuantValue::Q8S, // For tests, "native" sub-byte quant serves as a reference for value equality. // Values are stored as i8 regardless. #[cfg(feature = "export_tests")] value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S, store: QuantStore::Native, .. } => { let scales = scales.iter().next().unwrap(); let strategy = QuantizationStrategy::PerTensorSymmetric( SymmetricQuantization::init(scales, scheme.value), ); let values = strategy.quantize(data_f.as_slice().unwrap()); ( TensorData::quantized(values, shape.clone(), *scheme, &[scales]), vec![QParams { scales }], ) } QuantScheme { level: QuantLevel::Block(block_size), mode: QuantMode::Symmetric, #[cfg(not(feature = "export_tests"))] value: QuantValue::Q8F | QuantValue::Q8S, #[cfg(feature = "export_tests")] value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S, store: QuantStore::Native, .. } => { let scales = scales.as_slice().unwrap(); let (strategy, qparams) = scales .iter() .map(|&s| { ( SymmetricQuantization::init(s, scheme.value), QParams { scales: s }, ) }) .unzip(); let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size); let values = strategy.quantize(data_f.as_slice().unwrap()); ( TensorData::quantized(values, shape.clone(), *scheme, scales), qparams, ) } scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"), }; let num_elements = data.num_elements(); let q_bytes = QuantizedBytes { bytes: data.into_bytes(), scheme: *scheme, num_elements, }; let (values, _) = q_bytes.into_vec_i8(); let data = TensorData::new(values, shape).convert::(); NdArrayQTensor { qtensor: NdArrayTensor::from_data(data), scheme: *scheme, qparams, } } fn dequantize(tensor: QuantizedTensor) -> FloatTensor { let strategy = tensor.strategy(); let scheme = tensor.scheme; let shape = tensor.shape(); let data = match tensor.qtensor { NdArrayTensor::I8(storage) => { let data = storage.into_shared().into_iter().collect(); dequantize(data, shape, scheme, &strategy) } _ => unreachable!(), }; NdArrayTensor::from_data(data) } fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { NdArrayDevice::Cpu } fn q_to_device( tensor: QuantizedTensor, _device: &NdArrayDevice, ) -> QuantizedTensor { tensor } fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { NdArrayQTensor { qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::reshape(array, shape) }), scheme: tensor.scheme, qparams: tensor.qparams, } } async fn q_into_data(tensor: QuantizedTensor) -> Result { let shape = tensor.qtensor.shape(); let scales = tensor.qparams.iter().map(|q| q.scales).collect::>(); Ok(execute_with_numeric_dtype!( tensor.qtensor, E, |array: SharedArray| { let values = array.into_iter().collect(); TensorData::quantized(values, shape, tensor.scheme, &scales) } )) } fn q_swap_dims( tensor: QuantizedTensor, dim1: usize, dim2: usize, ) -> QuantizedTensor { NdArrayQTensor { qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::swap_dims(array, dim1, dim2) }), scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { NdArrayQTensor { qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::permute(array, axes) }), scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { NdArrayQTensor { qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::flip(array, axes) }), scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_gather( dim: usize, tensor: QuantizedTensor, indices: IntTensor, ) -> QuantizedTensor { let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< IntElem, >| -> NdArrayTensor { execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::gather(dim, array, idx_array) }) }); NdArrayQTensor { qtensor, scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_select( tensor: QuantizedTensor, dim: usize, indices: IntTensor, ) -> QuantizedTensor { let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< IntElem, >| -> NdArrayTensor { execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayMathOps::select(array, dim, idx_array) }) }); NdArrayQTensor { qtensor, scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_slice( tensor: QuantizedTensor, slices: &[burn_backend::Slice], ) -> QuantizedTensor { NdArrayQTensor { qtensor: slice!(tensor.qtensor, slices), scheme: tensor.scheme, qparams: tensor.qparams, } } fn q_argmax(tensor: QuantizedTensor, dim: usize) -> IntTensor { execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayMathOps::argmax::(array, dim) }) } fn q_argmin(tensor: QuantizedTensor, dim: usize) -> IntTensor { execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayMathOps::argmin::(array, dim) }) } fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { NdArrayQTensor { qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { NdArrayOps::expand(array, shape) }), scheme: tensor.scheme, qparams: tensor.qparams, } } } fn dequantize( data: Vec, shape: Shape, scheme: QuantScheme, strategy: &QuantizationStrategy, ) -> TensorData { let qparams = match strategy { QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale], QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => { quant.iter().map(|q| q.scale).collect() } }; let q_bytes = QuantizedBytes::new(data, scheme, &qparams); let (values, _qparams) = q_bytes.into_vec_i8(); TensorData::new(strategy.dequantize(&values), shape) } ================================================ FILE: crates/burn-ndarray/src/ops/quantization.rs ================================================ use alloc::vec::Vec; use num_traits::{Float, PrimInt}; use burn_backend::quantization::{BlockSize, QuantValue}; // NOTE: this mainly serves as a simple reference implementation. // The de/quantization ops should be refactored to use ndarray. /// Quantization strategy. #[derive(Debug, Clone, PartialEq, Eq)] pub enum QuantizationStrategy { /// Per-tensor symmetric quantization. PerTensorSymmetric(SymmetricQuantization), /// Per-block symmetric quantization. PerBlockSymmetric(Vec>, BlockSize), } impl QuantizationStrategy { /// Quantize the values to a lower precision data type. pub fn quantize(&self, values: &[f32]) -> Vec { match self { QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values), QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { let block_elems = block_size.num_elements(); let num_blocks = strategy.len(); let numel = values.len(); assert_eq!( numel / block_elems, num_blocks, "Invalid per-block quantization with num blocks {num_blocks} and {numel} values" ); values .chunks(block_elems) .enumerate() .flat_map(|(block_id, block)| strategy[block_id].quantize(block)) .collect() } } } /// Dequantize the values to a higher precision data type. pub fn dequantize(&self, values: &[i8]) -> Vec { match self { QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values), QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { let block_elems = block_size.num_elements(); let num_blocks = strategy.len(); let numel = values.len(); assert_eq!( numel / block_elems, num_blocks, "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values" ); values .chunks(block_elems) .enumerate() .flat_map(|(block_id, block)| strategy[block_id].dequantize(block)) .collect() } } } } /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision /// data type `Q` and vice-versa. pub trait Quantization { /// Returns the quantization range `[a, b]`. fn range(&self) -> (E, E); /// Convert the values to a lower precision data type. fn quantize(&self, values: &[E]) -> Vec; /// Convert a single value to a lower precision data type. fn quantize_one(&self, value: E) -> Q; /// Convert the values back to a higher precision data type. fn dequantize(&self, values: &[Q]) -> Vec; /// Convert a single value back to a higher precision data type. fn dequantize_one(&self, value: Q) -> E; } fn valid_scale(mut scale: E) -> E { // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the // scale to 0.1 to avoid division by zero. if scale.eq(&E::zero()) { scale = E::from(0.1).unwrap(); } scale } /// Symmetric quantization scheme. #[derive(Debug, Clone, Copy)] pub struct SymmetricQuantization { /// The scaling factor. pub scale: E, // The quantization value data type. value: QuantValue, } impl SymmetricQuantization { /// Initialize a symmetric quantization scheme with the given parameters. pub fn init(scale: E, value: QuantValue) -> Self { Self { scale: valid_scale(scale), value, } } #[allow(dead_code)] /// Create a new quantization scheme for an input range `[alpha, beta]`. fn new(alpha: E, beta: E, value: QuantValue) -> Self { let (a, b) = value.range(); let a = E::from(a).unwrap(); let b = E::from(b).unwrap(); // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range let alpha = alpha.abs().max(beta.abs()); let scale = valid_scale((alpha + alpha) / (b - a)); Self { scale, value } } } impl Quantization for SymmetricQuantization { fn quantize(&self, values: &[E]) -> Vec { values.iter().map(|x| self.quantize_one(*x)).collect() } fn dequantize(&self, values: &[Q]) -> Vec { values.iter().map(|x_q| self.dequantize_one(*x_q)).collect() } fn quantize_one(&self, value: E) -> Q { let (a, b) = self.range(); // x_q = clamp(round(x / scale), a, b) Q::from(value.div(self.scale).round().clamp(a, b)).unwrap() } fn dequantize_one(&self, value: Q) -> E { // x = scale * x_q self.scale * E::from(value).unwrap() } fn range(&self) -> (E, E) { let (a, b) = self.value.range(); let a = E::from(a).unwrap(); let b = E::from(b).unwrap(); (a, b) } } impl PartialEq for SymmetricQuantization { fn eq(&self, other: &Self) -> bool { self.scale == other.scale } } impl Eq for SymmetricQuantization {} #[cfg(test)] mod tests { use burn_backend::TensorData; use super::*; use alloc::vec; #[test] fn test_int8_symmetric_quantization() { let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5]; let expected_q = vec![-127, -71, 0, 35]; let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063]; let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); let q: Vec = symmetric.quantize(&x); assert_eq!(q, expected_q); let d = symmetric.dequantize(&expected_q); assert_eq!(d, expected_d); } #[test] fn test_int8_symmetric_quantization_per_block() { let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5]; let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35]; let expected_d = vec![ -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063, ]; let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); let strategy = QuantizationStrategy::PerBlockSymmetric( vec![symmetric, symmetric], BlockSize::new([4]), ); let q: Vec = strategy.quantize(&x); assert_eq!(q, expected_q); let d = symmetric.dequantize(&expected_q); assert_eq!(d, expected_d); } #[test] fn should_support_dequantize() { let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization { scale: 0.1, value: QuantValue::Q8S, }); let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]); let output = TensorData::new(output, [2, 3]); output.assert_approx_eq::( &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]), Default::default(), ); } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/avgpool.rs ================================================ use core::{marker::PhantomData, mem::transmute}; use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; use burn_backend::DType; use burn_backend::{Element, ElementConversion}; use bytemuck::Zeroable; use macerator::{Simd, VAdd, VDiv}; use ndarray::{Array4, s}; use nhwc::avg_pool_nhwc; use super::should_use_simd; #[macerator::with_simd] fn is_accelerated(_x: PhantomData) -> bool { ::is_accelerated::() && ::is_accelerated::() } pub(crate) fn try_avg_pool2d_simd( x: SharedArray, ksize: [usize; 2], stride: [usize; 2], padding: [usize; 2], with_pad: bool, ) -> Result, SharedArray> { // Strides must be unit, dilation isn't supported, rows must be contiguous if x.strides()[1] != 1 || !should_use_simd(x.shape()[1]) { return Err(x); } match E::dtype() { DType::F64 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( cast(x), ksize, stride, padding, with_pad, ))), DType::F32 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( cast(x), ksize, stride, padding, with_pad, ))), _ => Err(x), } } fn cast(tensor: SharedArray) -> SharedArray { unsafe { transmute::, SharedArray>(tensor) } } mod nhwc { use itertools::Itertools; use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned}; use ndarray::{ArrayView3, ArrayViewMut3}; use seq_macro::seq; use crate::ops::simd::lanes; use super::*; // Until you can use associated constants as array size, we need to hardcode this. // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. const BLOCK_REGISTERS: usize = 8; pub(crate) fn avg_pool_nhwc( x: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], with_pad: bool, ) -> SharedArray { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); let lanes = lanes::(); let ch_block = lanes * BLOCK_REGISTERS; let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1; let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1; let mut output = unsafe { Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() }; let unsafe_shared_out = UnsafeSharedRef::new(&mut output); let x = x.view(); let x = x.permuted_axes(vec![0, 2, 3, 1]); // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. // An exclusive loop will always have `lanes * blocking factor` elements in bounds. let blocks = channels / ch_block; let blocks_end = blocks * ch_block; // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An // exclusive loop will always have `lanes` elements in bounds. let simd_end = channels / lanes * lanes; let num_simd_unblocked = (simd_end - blocks_end) / lanes; let remainder = channels - simd_end; run_par!(|| { // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { let block = k % blocks; let b = k / blocks; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_blocked(x, out, kernel_size, stride, padding, with_pad, block); }); // SAFETY: See `loop_unblocked` iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe { let ch = (k % num_simd_unblocked) * lanes + blocks_end; let b = k / num_simd_unblocked; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch); }); // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { let ch = (k % remainder) + simd_end; let b = k / remainder; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch); }); }); output = output.permuted_axes([0, 3, 1, 2]); output.into_dyn().into_shared() } /// Execute the blocked (unrolled) portion of the pool. #[allow( clippy::too_many_arguments, clippy::erasing_op, clippy::identity_op, unused_mut )] #[macerator::with_simd] fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>( x: ArrayView3<'a, E>, mut out: ArrayViewMut3<'a, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], with_pad: bool, block: usize, ) where 'a: 'a, { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); let lanes = E::lanes::(); let ch_block = lanes * BLOCK_REGISTERS; // If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds for oh in pad_h..out_height.saturating_sub(pad_h) { for ow in pad_w..out_width.saturating_sub(pad_w) { seq!(N in 0..8 { let mut sum~N: Vector = Zeroable::zeroed(); }); let ch = block * ch_block; let ch_end = ch + ch_block; let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); for kh in 0..kernel_height { let ih = oh * stride_height + kh - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw - pad_w; let x = x.slice(s![ih, iw, ch..ch_end]); seq!(N in 0..8 { // SAFETY: // Load a full vector from x[N * lanes]. This is bounds checked by the // slice above. sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; }); } } let count = kernel_height * kernel_width; let count = (count as u64).elem::(); let count_v = count.splat(); seq!(N in 0..8 { let s~N = sum~N / count_v; // SAFETY: // Store a full vector to out[N * lanes]. This is bounds checked by the // slice above. unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; }); } } // Border pixels need bounds checks if (pad_h, pad_w) != (0, 0) { let v_borders = (0..pad_h) .chain(out_height.saturating_sub(pad_h)..out_height) .cartesian_product(0..out_width); let h_borders = (0..out_height) .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); for (oh, ow) in v_borders.chain(h_borders) { seq!(N in 0..8 { let mut sum~N: Vector = Zeroable::zeroed(); }); let mut count: usize = 0; let ch = block * ch_block; let ch_end = ch + ch_block; let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); for kh in 0..kernel_height { let ih = oh * stride_height + kh; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; count += 1; let x = x.slice(s![ih, iw, ch..ch_end]); seq!(N in 0..8 { // SAFETY: // Load a full vector from x[N * lanes]. This is bounds checked by the // slice above. sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; }); } } if with_pad { count = kernel_height * kernel_width; } let count = (count as u64).elem::(); let count_v = count.splat(); seq!(N in 0..8 { let s~N = sum~N / count_v; // SAFETY: // Store a full vector to out[N * lanes]. This is bounds checked by the // slice above. unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; }); } } } /// Execute the unblocked (not unrolled) portion of the pool. /// /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. #[allow(clippy::too_many_arguments, unused_mut)] #[macerator::with_simd] unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>( x: ArrayView3<'a, E>, mut out: ArrayViewMut3<'a, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], with_pad: bool, ch: usize, ) where 'a: 'a, { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); // If pixels are not within padding range, bounds checks are always true for oh in pad_h..out_height - pad_h { for ow in pad_w..out_width - pad_w { let mut sum: Vector = Zeroable::zeroed(); for kh in 0..kernel_height { let ih = oh * stride_height + kh - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw - pad_w; // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; sum += s0; } } let count = kernel_height * kernel_width; let count: E = (count as u64).elem(); let count_v = count.splat(); let s0 = sum / count_v; // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; } } // Border pixels need bounds checks if (pad_h, pad_w) != (0, 0) { let v_borders = (0..pad_h) .chain(out_height.saturating_sub(pad_h)..out_height) .cartesian_product(0..out_width); let h_borders = (0..out_height) .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); for (oh, ow) in v_borders.chain(h_borders) { let mut sum: Vector = Zeroable::zeroed(); let mut count: usize = 0; for kh in 0..kernel_height { let ih = oh * stride_height + kh; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; count += 1; // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; } } if with_pad { count = kernel_height * kernel_width; } let count = (count as u64).elem::(); let count_v = count.splat(); let s0 = sum / count_v; // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; } } } /// Execute scalar portion of the pooling #[allow(clippy::too_many_arguments)] fn loop_scalar( x: ArrayView3<'_, E>, mut out: ArrayViewMut3<'_, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], with_pad: bool, ch: usize, ) { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); // If pixels are not within padding range, bounds checks are always true for oh in pad_h..out_height.saturating_sub(pad_h) { for ow in pad_w..out_width.saturating_sub(pad_w) { let mut sum: E = Zeroable::zeroed(); for kh in 0..kernel_height { let ih = oh * stride_height + kh - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw - pad_w; sum = sum + x[[ih, iw, ch]]; } } let count = (kernel_height * kernel_width) as u64; out[[oh, ow, ch]] = sum / count.elem(); } } // Border pixels need bounds checks if (pad_h, pad_w) != (0, 0) { let v_borders = (0..pad_h) .chain(out_height.saturating_sub(pad_h)..out_height) .cartesian_product(0..out_width); let h_borders = (0..out_height) .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); for (oh, ow) in v_borders.chain(h_borders) { let mut sum: E = Zeroable::zeroed(); let mut count: usize = 0; for kh in 0..kernel_height { let ih = oh * stride_height + kh; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; count += 1; sum = sum + x[[ih, iw, ch]]; } } if with_pad { count = kernel_height * kernel_width; } out[[oh, ow, ch]] = sum / (count as u64).elem(); } } } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/base.rs ================================================ use core::{marker::PhantomData, mem::MaybeUninit}; use macerator::{Arch, Scalar, Simd}; use ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder}; /// Whether SIMD instructions are worth using #[cfg(all( any( target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32", target_arch = "loongarch64" ), not(test) ))] pub fn should_use_simd(len: usize) -> bool { len >= 32 } /// Whether SIMD instructions are worth using #[cfg(all( not(any( target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32", target_arch = "loongarch64" )), not(test) ))] pub fn should_use_simd(_len: usize) -> bool { false } #[cfg(test)] pub fn should_use_simd(_len: usize) -> bool { true } pub(crate) fn lanes() -> usize { #[allow(non_camel_case_types)] struct lanes<__T0>(__T0); impl ::macerator::WithSimd for lanes> { type Output = usize; #[inline(always)] fn with_simd<__S: ::macerator::Simd>(self) -> ::Output { let Self(__ty) = self; #[allow(unused_unsafe)] unsafe { lanes_simd::<__S, E>(__ty) } } } (Arch::new()).dispatch(lanes(PhantomData::)) } fn lanes_simd(_ty: PhantomData) -> usize { E::lanes::() } pub(crate) fn uninit_array_like(reference: &ArcArray) -> ArrayD { let shape = reference.raw_dim(); let strides = reference.strides(); let strides = strides.iter().map(|it| *it as usize).collect::>(); let shape_strides = shape.strides(IxDyn(&strides)); let size = reference.len(); let mut out_data: Vec> = Vec::with_capacity(size); unsafe { out_data.set_len(size) }; unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() } } pub trait MinMax { fn min(self, other: Self) -> Self; fn max(self, other: Self) -> Self; } macro_rules! impl_minmax { ($ty: ty) => { impl MinMax for $ty { fn min(self, other: Self) -> Self { Ord::min(self, other) } fn max(self, other: Self) -> Self { Ord::max(self, other) } } }; ($($ty: ty),*) => { $(impl_minmax!($ty);)* } } impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64); impl MinMax for f32 { fn min(self, other: Self) -> Self { self.min(other) } fn max(self, other: Self) -> Self { self.max(other) } } impl MinMax for f64 { fn min(self, other: Self) -> Self { self.min(other) } fn max(self, other: Self) -> Self { self.max(other) } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/binary.rs ================================================ use core::{marker::PhantomData, slice}; use burn_backend::Element; use macerator::{ Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned, vstore_unaligned, }; use ndarray::ArrayD; use seq_macro::seq; use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; use super::{ MinMax, binary_elemwise::{ VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub, }, should_use_simd, }; pub trait SimdBinop { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector; fn apply(lhs: T, rhs: T) -> Out; fn is_accelerated() -> bool; } impl SimdBinop for VecAdd { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs + rhs } fn apply(lhs: T, rhs: T) -> T { lhs + rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecDiv { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs / rhs } fn apply(lhs: T, rhs: T) -> T { lhs / rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecMul { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs * rhs } fn apply(lhs: T, rhs: T) -> T { lhs * rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecSub { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs - rhs } fn apply(lhs: T, rhs: T) -> T { lhs - rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecMin { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs.min(rhs) } fn apply(lhs: T, rhs: T) -> T { MinMax::min(lhs, rhs) } fn is_accelerated() -> bool { ::is_min_max_accelerated::() } } impl SimdBinop for VecMax { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs.max(rhs) } fn apply(lhs: T, rhs: T) -> T { MinMax::max(lhs, rhs) } fn is_accelerated() -> bool { ::is_min_max_accelerated::() } } impl SimdBinop for VecBitAnd { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs & rhs } fn apply(lhs: T, rhs: T) -> T { lhs.bitand(rhs) } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecBitOr { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs | rhs } fn apply(lhs: T, rhs: T) -> T { lhs.bitor(rhs) } fn is_accelerated() -> bool { ::is_accelerated::() } } impl SimdBinop for VecBitXor { fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { lhs ^ rhs } fn apply(lhs: T, rhs: T) -> T { lhs.bitxor(rhs) } fn is_accelerated() -> bool { ::is_accelerated::() } } #[macerator::with_simd] fn is_accelerated>( _x: PhantomData<(T, Out, Op)>, ) -> bool { Op::is_accelerated::() } #[allow(clippy::result_large_err)] pub fn try_binary_simd< E: Element, EOut: Element, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdBinop, >( lhs: SharedArray, rhs: SharedArray, ) -> Result, (SharedArray, SharedArray)> { let lhs_len = lhs.len(); let rhs_len = rhs.len(); if !should_use_simd(lhs_len.max(rhs_len)) || !lhs.is_standard_layout() || !rhs.is_standard_layout() || lhs.shape() != rhs.shape() || !is_accelerated::(PhantomData) { return Err((lhs, rhs)); } // Used to assert traits based on the dynamic `DType`. let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; let out = binary_simd_same::(lhs, rhs); // Used to assert traits based on the dynamic `DType`. let out = unsafe { core::mem::transmute::, SharedArray>(out) }; Ok(out) } fn binary_simd_same< T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdBinop, >( lhs: SharedArray, rhs: SharedArray, ) -> SharedArray { let out = if lhs.is_unique() { let mut buf = lhs.into_owned(); let lhs = buf.as_slice_mut().unwrap(); let rhs = rhs.as_slice().unwrap(); let out = unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) }; binary(lhs, rhs, out, PhantomData::); unsafe { core::mem::transmute::, ArrayD>(buf) } } else if rhs.is_unique() { let mut buf = rhs.into_owned(); let lhs = lhs.as_slice().unwrap(); let rhs = buf.as_slice_mut().unwrap(); let out = unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) }; binary(lhs, rhs, out, PhantomData::); unsafe { core::mem::transmute::, ArrayD>(buf) } } else { let mut out = uninit_array_like(&lhs); let lhs = lhs.as_slice().unwrap(); let rhs = rhs.as_slice().unwrap(); let out_slice = out.as_slice_mut().unwrap(); binary(lhs, rhs, out_slice, PhantomData::); out }; out.into_shared() } #[allow(clippy::erasing_op, clippy::identity_op)] #[macerator::with_simd] fn binary< 'a, S: Simd, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdBinop, >( lhs: &'a [T], rhs: &'a [T], out: &'a mut [Out], _op: PhantomData, ) where 'a: 'a, { let lanes = T::lanes::(); let mut chunks_lhs = lhs.chunks_exact(8 * lanes); let mut chunks_rhs = rhs.chunks_exact(8 * lanes); let mut chunks_out = out.chunks_exact_mut(8 * lanes); while let Some(((lhs, rhs), out)) = chunks_lhs .next() .zip(chunks_rhs.next()) .zip(chunks_out.next()) { seq!(N in 0..8 { // Load one full vector from `lhs`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; // Load one full vector from `rhs`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; let s~N = Op::apply_vec(lhs~N, rhs~N); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; }); } let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); while let Some(((lhs, rhs), out)) = chunks_lhs .next() .zip(chunks_rhs.next()) .zip(chunks_out.next()) { // Load one full vector from `lhs`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; // Load one full vector from `rhs`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; let s0 = Op::apply_vec(lhs0, rhs0); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == lanes` unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; } for ((lhs, rhs), out) in chunks_lhs .remainder() .iter() .zip(chunks_rhs.remainder()) .zip(chunks_out.into_remainder()) { *out = Op::apply(*lhs, *rhs) } } /// Unsafely alias a slice to use as an inline argument fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { let ptr = slice.as_mut_ptr(); let len = slice.len(); unsafe { slice::from_raw_parts_mut(ptr, len) } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/binary_elemwise.rs ================================================ use core::marker::PhantomData; use bytemuck::cast; use macerator::{ Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload, vload_unaligned, vstore, vstore_unaligned, }; use ndarray::ArrayD; use seq_macro::seq; use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; use super::{MinMax, should_use_simd}; pub trait ScalarSimdBinop { type Rhs: Copy; type RhsVec: Copy; fn splat(rhs: Self::Rhs) -> Self::RhsVec; fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector; fn apply(lhs: T, rhs: Self::Rhs) -> Out; fn is_accelerated() -> bool; } pub struct VecAdd; pub struct VecDiv; pub struct VecMul; pub struct VecSub; pub struct VecMin; pub struct VecMax; pub struct VecClamp; pub struct VecBitAnd; pub struct VecBitOr; pub struct VecBitXor; impl ScalarSimdBinop for VecAdd { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs + rhs } fn apply(lhs: T, rhs: T) -> T { lhs + rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecDiv { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs / rhs } fn apply(lhs: T, rhs: T) -> T { lhs / rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecMul { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs * rhs } fn apply(lhs: T, rhs: T) -> T { lhs * rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecSub { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs - rhs } fn apply(lhs: T, rhs: T) -> T { lhs - rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecMin { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs.min(rhs) } fn apply(lhs: T, rhs: T) -> T { lhs.min(rhs) } fn is_accelerated() -> bool { ::is_min_max_accelerated::() } } impl ScalarSimdBinop for VecMax { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs.max(rhs) } fn apply(lhs: T, rhs: T) -> T { lhs.max(rhs) } fn is_accelerated() -> bool { ::is_min_max_accelerated::() } } impl ScalarSimdBinop for VecClamp { type Rhs = (T, T); type RhsVec = (Vector, Vector); fn splat((min, max): Self::Rhs) -> Self::RhsVec { (min.splat(), max.splat()) } fn apply_vec(lhs: Vector, (min, max): Self::RhsVec) -> Vector { lhs.min(max).max(min) } fn apply(lhs: T, (min, max): Self::Rhs) -> T { lhs.min(max).max(min) } fn is_accelerated() -> bool { ::is_min_max_accelerated::() } } impl ScalarSimdBinop for VecBitAnd { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs & rhs } fn apply(lhs: T, rhs: Self::Rhs) -> T { lhs & rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecBitOr { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs | rhs } fn apply(lhs: T, rhs: Self::Rhs) -> T { lhs | rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } impl ScalarSimdBinop for VecBitXor { type Rhs = T; type RhsVec = Vector; fn splat(rhs: Self::Rhs) -> Self::RhsVec { rhs.splat() } fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { lhs ^ rhs } fn apply(lhs: T, rhs: Self::Rhs) -> T { lhs ^ rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } #[macerator::with_simd] fn is_accelerated>( _x: PhantomData<(T, Out, Op)>, ) -> bool { Op::is_accelerated::() } pub fn try_binary_scalar_simd< E: NdArrayElement, EOut: NdArrayElement, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: ScalarSimdBinop, >( input: SharedArray, elem: Op::Rhs, ) -> Result, SharedArray> { if !should_use_simd(input.len()) || input.as_slice_memory_order().is_none() || !is_accelerated::(PhantomData) { return Err(input); } // Used to assert traits based on the dynamic `DType`. let input = unsafe { core::mem::transmute::, SharedArray>(input) }; let out = if size_of::() == size_of::() && align_of::() >= align_of::() && input.is_unique() { unsafe { binary_scalar_simd_inplace::(input, elem) } } else { binary_scalar_simd_owned::(input, elem) }; // Used to assert traits based on the dynamic `DType`. let out = unsafe { core::mem::transmute::, SharedArray>(out) }; Ok(out) } /// Execute operation in place on an owned tensor /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. unsafe fn binary_scalar_simd_inplace< T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: ScalarSimdBinop, >( input: SharedArray, elem: Op::Rhs, ) -> SharedArray { let mut buffer = input.into_owned(); let slice = buffer.as_slice_memory_order_mut().unwrap(); unsafe { binary_scalar_slice_inplace::(slice, elem, PhantomData) }; // Buffer has the same elem size and is filled with the operation output, so this is safe let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; out.into_shared() } /// Create a new copy of the tensor as the output fn binary_scalar_simd_owned< T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: ScalarSimdBinop, >( input: SharedArray, elem: Op::Rhs, ) -> SharedArray { let mut out = uninit_array_like(&input); let input = input.as_slice_memory_order().unwrap(); let out_slice = out.as_slice_memory_order_mut().unwrap(); binary_scalar_slice::(input, out_slice, elem, PhantomData); out.into_shared() } #[inline(always)] #[allow(clippy::erasing_op, clippy::identity_op)] #[macerator::with_simd] fn binary_scalar_slice< 'a, S: Simd, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: ScalarSimdBinop, >( input: &'a [T], out: &'a mut [Out], rhs: Op::Rhs, _op: PhantomData, ) where 'a: 'a, { let lanes = T::lanes::(); let mut chunks_input = input.chunks_exact(8 * lanes); let mut chunks_out = out.chunks_exact_mut(8 * lanes); let rhs_vec = Op::splat::(rhs); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { seq!(N in 0..8 { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; let s~N = Op::apply_vec(s~N, rhs_vec); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; }); } let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let s0 = unsafe { vload_unaligned(input.as_ptr()) }; let s0 = Op::apply_vec(s0, rhs_vec); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == lanes` unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; } for (input, out) in chunks_input .remainder() .iter() .zip(chunks_out.into_remainder()) { *out = Op::apply(*input, rhs) } } /// Execute operation in line. /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. #[inline(always)] #[macerator::with_simd] unsafe fn binary_scalar_slice_inplace< 'a, S: Simd, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: ScalarSimdBinop, >( buf: &'a mut [T], rhs: Op::Rhs, _op: PhantomData<(Out, Op)>, ) where 'a: 'a, { let (head, main, tail) = unsafe { buf.align_to_mut::>() }; for elem in head.iter_mut().chain(tail) { *elem = cast(Op::apply(*elem, rhs)); } let mut chunks = main.chunks_exact_mut(8); let rhs = Op::splat::(rhs); for elem in chunks.by_ref() { seq!(N in 0..8 { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; let s~N = Op::apply_vec(s~N, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) }; }); } for elem in chunks.into_remainder() { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s0 = unsafe { vload(elem as *const _ as *const T) }; let s0 = Op::apply_vec(s0, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { vstore(elem as *mut _ as *mut Out, s0) }; } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/cmp.rs ================================================ use core::{marker::PhantomData, slice}; use burn_backend::Element; use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned}; use ndarray::ArrayD; use seq_macro::seq; use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; use super::should_use_simd; pub trait SimdCmpOp { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask; fn apply(lhs: T, rhs: T) -> bool; fn is_accelerated() -> bool; } pub struct VecEquals; impl SimdCmpOp for VecEquals { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { lhs.eq(rhs) } fn apply(lhs: T, rhs: T) -> bool { lhs == rhs } fn is_accelerated() -> bool { ::is_accelerated::() } } pub struct VecGreater; impl SimdCmpOp for VecGreater { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { lhs.gt(rhs) } fn apply(lhs: T, rhs: T) -> bool { lhs > rhs } fn is_accelerated() -> bool { ::is_cmp_accelerated::() } } pub struct VecGreaterEq; impl SimdCmpOp for VecGreaterEq { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { lhs.ge(rhs) } fn apply(lhs: T, rhs: T) -> bool { lhs >= rhs } fn is_accelerated() -> bool { ::is_cmp_accelerated::() } } pub struct VecLowerEq; impl SimdCmpOp for VecLowerEq { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { lhs.le(rhs) } fn apply(lhs: T, rhs: T) -> bool { lhs <= rhs } fn is_accelerated() -> bool { ::is_cmp_accelerated::() } } pub struct VecLower; impl SimdCmpOp for VecLower { fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { lhs.lt(rhs) } fn apply(lhs: T, rhs: T) -> bool { lhs < rhs } fn is_accelerated() -> bool { ::is_cmp_accelerated::() } } #[macerator::with_simd] fn is_accelerated>(_x: PhantomData<(T, Op)>) -> bool { Op::is_accelerated::() } #[allow(clippy::result_large_err)] pub fn try_cmp_simd>( lhs: SharedArray, rhs: SharedArray, ) -> Result, (SharedArray, SharedArray)> { let lhs_len = lhs.len(); let rhs_len = rhs.len(); if !should_use_simd(lhs_len.max(rhs_len)) || !lhs.is_standard_layout() || !rhs.is_standard_layout() || lhs.shape() != rhs.shape() || !is_accelerated::(PhantomData) { return Err((lhs, rhs)); } // Used to assert traits based on the dynamic `DType`. let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; let out = cmp_simd_same::(lhs, rhs); Ok(out) } fn cmp_simd_same>( lhs: SharedArray, rhs: SharedArray, ) -> SharedArray { let out = if lhs.is_unique() && size_of::() == size_of::() { let mut buf = lhs.into_owned(); let lhs = buf.as_slice_mut().unwrap(); let rhs = rhs.as_slice().unwrap(); let out = unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs)) }; cmp(lhs, rhs, out, PhantomData::); unsafe { core::mem::transmute::, ArrayD>(buf) } } else if rhs.is_unique() && size_of::() == size_of::() { let mut buf = rhs.into_owned(); let lhs = lhs.as_slice().unwrap(); let rhs = buf.as_slice_mut().unwrap(); let out = unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs)) }; cmp(lhs, rhs, out, PhantomData::); unsafe { core::mem::transmute::, ArrayD>(buf) } } else { let mut out = uninit_array_like(&lhs); let lhs = lhs.as_slice().unwrap(); let rhs = rhs.as_slice().unwrap(); let out_slice = out.as_slice_mut().unwrap(); cmp(lhs, rhs, out_slice, PhantomData::); out }; out.into_shared() } #[allow(clippy::erasing_op, clippy::identity_op)] #[macerator::with_simd] fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( lhs: &'a [T], rhs: &'a [T], out: &'a mut [bool], _op: PhantomData, ) where 'a: 'a, { let lanes = T::lanes::(); let mut chunks_lhs = lhs.chunks_exact(8 * lanes); let mut chunks_rhs = rhs.chunks_exact(8 * lanes); let mut chunks_out = out.chunks_exact_mut(8 * lanes); while let Some(((lhs, rhs), out)) = chunks_lhs .next() .zip(chunks_rhs.next()) .zip(chunks_out.next()) { seq!(N in 0..8 { // Load one full vector from `lhs`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; // Load one full vector from `rhs`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; let s~N = Op::apply_vec(lhs~N, rhs~N); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; }); } let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); while let Some(((lhs, rhs), out)) = chunks_lhs .next() .zip(chunks_rhs.next()) .zip(chunks_out.next()) { // Load one full vector from `lhs`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; // Load one full vector from `rhs`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; let s0 = Op::apply_vec(lhs0, rhs0); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == lanes` unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; } for ((lhs, rhs), out) in chunks_lhs .remainder() .iter() .zip(chunks_rhs.remainder()) .zip(chunks_out.into_remainder()) { *out = Op::apply(*lhs, *rhs) } } /// Unsafely alias a slice to use as an inline argument fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { let ptr = slice.as_mut_ptr(); let len = slice.len(); unsafe { slice::from_raw_parts_mut(ptr, len) } } pub use elemwise::try_cmp_scalar_simd; mod elemwise { use bytemuck::cast; use macerator::vload; use super::*; pub fn try_cmp_scalar_simd>( input: SharedArray, elem: T, ) -> Result, SharedArray> { if !should_use_simd(input.len()) || input.as_slice_memory_order().is_none() || !is_accelerated::(PhantomData) { return Err(input); } // Used to assert traits based on the dynamic `DType`. let input = unsafe { core::mem::transmute::, SharedArray>(input) }; let out = if size_of::() == size_of::() && align_of::() >= align_of::() && input.is_unique() { unsafe { cmp_scalar_simd_inplace::(input, elem) } } else { cmp_scalar_simd_owned::(input, elem) }; Ok(out) } /// Execute operation in place on an owned tensor /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. unsafe fn cmp_scalar_simd_inplace>( input: SharedArray, elem: T, ) -> SharedArray { let mut buffer = input.into_owned(); let slice = buffer.as_slice_memory_order_mut().unwrap(); unsafe { cmp_scalar_slice_inplace::(slice, elem, PhantomData) }; // Buffer has the same elem size and is filled with the operation output, so this is safe let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; out.into_shared() } /// Create a new copy of the tensor as the output fn cmp_scalar_simd_owned>( input: SharedArray, elem: T, ) -> SharedArray { let mut out = uninit_array_like(&input); let input = input.as_slice_memory_order().unwrap(); let out_slice = out.as_slice_memory_order_mut().unwrap(); cmp_scalar_slice::(input, out_slice, elem, PhantomData); out.into_shared() } #[inline(always)] #[allow(clippy::erasing_op, clippy::identity_op)] #[macerator::with_simd] fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( input: &'a [T], out: &'a mut [bool], rhs: T, _op: PhantomData, ) where 'a: 'a, { let lanes = T::lanes::(); let mut chunks_input = input.chunks_exact(8 * lanes); let mut chunks_out = out.chunks_exact_mut(8 * lanes); let rhs_vec = rhs.splat::(); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { seq!(N in 0..8 { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; let s~N = Op::apply_vec(s~N, rhs_vec); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; }); } let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let s0 = unsafe { vload_unaligned(input.as_ptr()) }; let s0 = Op::apply_vec(s0, rhs_vec); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == lanes` unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; } for (input, out) in chunks_input .remainder() .iter() .zip(chunks_out.into_remainder()) { *out = Op::apply(*input, rhs) } } /// Execute operation in line. /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. #[inline(always)] #[macerator::with_simd] unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( buf: &'a mut [T], rhs: T, _op: PhantomData, ) where 'a: 'a, { let (head, main, tail) = unsafe { buf.align_to_mut::>() }; for elem in head.iter_mut().chain(tail) { *elem = cast(Op::apply(*elem, rhs)); } let mut chunks = main.chunks_exact_mut(8); let rhs = rhs.splat::(); for elem in chunks.by_ref() { seq!(N in 0..8 { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; let s~N = Op::apply_vec(s~N, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) }; }); } for elem in chunks.into_remainder() { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s0 = unsafe { vload(elem as *const _ as *const T) }; let s0 = Op::apply_vec(s0, rhs); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) }; } } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/conv.rs ================================================ use core::{marker::PhantomData, mem::transmute}; use burn_backend::{ DType, Element, ops::{ConvOptions, conv::calculate_conv_output_size}, }; use bytemuck::Zeroable; use macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned}; use ndarray::{ ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s, }; use seq_macro::seq; use crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par}; type Args = (SharedArray, SharedArray, Option>); #[allow(clippy::result_large_err)] pub fn try_conv2d_simd( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvOptions<2>, ) -> Result, Args> { match E::dtype() { DType::F64 => conv2d::(x, weight, bias, options, PhantomData), DType::F32 => conv2d::(x, weight, bias, options, PhantomData), DType::I64 => conv2d::(x, weight, bias, options, PhantomData), DType::I32 => conv2d::(x, weight, bias, options, PhantomData), DType::I16 => conv2d::(x, weight, bias, options, PhantomData), DType::U64 => conv2d::(x, weight, bias, options, PhantomData), DType::U32 => conv2d::(x, weight, bias, options, PhantomData), DType::U16 => conv2d::(x, weight, bias, options, PhantomData), _ => Err((x, weight, bias)), } } fn cast(tensor: SharedArray) -> SharedArray { unsafe { transmute::, SharedArray>(tensor) } } /// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on /// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018). /// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures. /// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. . #[allow(clippy::result_large_err)] fn conv2d( x: SharedArray, weight: SharedArray, bias: Option>, options: ConvOptions<2>, _ty: PhantomData, ) -> Result, Args> { let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap(); let channels_per_group = out_channels / options.groups; #[macerator::with_simd] fn precheck(_ty: PhantomData) -> (usize, bool) { (E::lanes::(), E::is_accelerated::()) } let (lanes, accelerated) = precheck::(PhantomData); if !accelerated || !channels_per_group.is_multiple_of(lanes) { return Err((x, weight, bias)); } let x = cast::<_, E>(x); let weight = cast::<_, E>(weight); let bias = bias.map(|bias| cast::<_, E>(bias)); let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); let [dilate_h, dilate_w] = options.dilation; let [stride_h, stride_w] = options.stride; let [pad_h, pad_w] = options.padding; let padded = options.padding != [0, 0]; let strided = options.stride != [1, 1] || options.dilation != [1, 1]; let grouped = options.groups != 1; let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height); let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width); let x = x.into_dimensionality::().unwrap(); let weights = weight.into_dimensionality::().unwrap(); let weights = weights.permuted_axes([1, 2, 3, 0]); let weights = weights.as_standard_layout(); let bias = bias.map(|bias| bias.into_dimensionality::().unwrap()); // floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`. let oc_blocks = out_channels / lanes; let mut out = unsafe { Array4::::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init() }; let unsafe_shared_out = UnsafeSharedRef::new(&mut out); run_par!(|| { // SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference // is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function. iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe { let b = k / oc_blocks; let ob = k % oc_blocks; let x = x.slice(s![b, .., .., ..]); let out = unsafe_shared_out.get(); let mut out = out.slice_mut(s![b, .., .., ..]); let w = weights.view(); match (padded, strided, grouped) { (true, true, true) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (true, false, true) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (false, true, true) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (false, false, true) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (true, true, false) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (true, false, false) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (false, true, false) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } (false, false, false) => { conv2d_launch::(x, w, &bias, &mut out, &options, ob) } } }); }); let output = out.permuted_axes([0, 3, 1, 2]); Ok(cast(output.into_dyn().into_shared())) } /// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support /// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might /// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16). /// This should always be conservative, since oversizing it will cause register spills and that's /// **much** worse than the performance lost with lower values. const REGISTER_BLOCK: usize = 8; inner_with_register_blocking_size!(8); /// Run a loop of conv2d. /// # SAFETY /// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`. /// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and /// `out` must have unit stride for the out channels. #[inline(always)] #[macerator::with_simd] unsafe fn conv2d_launch< 'a, S: Simd, E: VMulAdd, const PAD: bool, const STRIDE: bool, const GROUPS: bool, >( x: ArrayView3<'a, E>, weights: ArrayView4<'a, E>, bias: &'a Option>, out: &'a mut ArrayViewMut3<'a, E>, options: &'a ConvOptions<2>, ob: usize, ) where 'a: 'a, { let (in_channels, k_height, k_width, out_channels) = weights.dim(); let (out_height, out_width, _) = out.dim(); let channels_per_group = out_channels / options.groups; let lanes = E::lanes::(); let [mut pad_h, mut pad_w] = options.padding; let [stride_h, stride_w] = options.stride; let [dilate_h, dilate_w] = options.dilation; // Trick compiler into inlining 0 to padding if !PAD { pad_h = 0; pad_w = 0; } let oc_b = channels_per_group.min(lanes); let ow_b = REGISTER_BLOCK; let ow_start = pad_w; let ow_width = out_width.saturating_sub(2 * pad_w); let oh_start = pad_h; let oh_end = out_height.saturating_sub(pad_h); let ow_blocks = ow_width / ow_b; let oc = ob * oc_b; let group = oc / channels_per_group; let mut ic_off = group * in_channels; if !GROUPS { ic_off = 0; } unsafe { let bias = if let Some(bias) = &bias { vload_unaligned::(&bias[oc]) } else { Zeroable::zeroed() }; for oh in oh_start..oh_end { let mut out = out.slice_mut(s![oh, .., ..]); for ow_block in 0..ow_blocks { let ow = ow_block * ow_b + ow_start; #[allow(clippy::if_same_then_else)] if STRIDE { conv2d_inner_nopad( &x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w, dilate_h, dilate_w, k_height, k_width, pad_h, pad_w, ); } else { conv2d_inner_nopad_nostride( &x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h, pad_w, ); } } } conv2d_remainder( x, weights, out, bias, oc, ic_off, ow_blocks * ow_b, stride_h, stride_w, dilate_h, dilate_w, pad_h, pad_w, k_height, k_width, ); } } /// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is /// much slower, so we want to minimize the amount of pixels that need to be processed by this /// /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). #[allow(clippy::too_many_arguments)] #[inline(always)] unsafe fn conv2d_remainder( x: ArrayView3, weights: ArrayView4, out: &mut ArrayViewMut3, bias: Vector, oc: usize, ic_off: usize, owb_end: usize, stride_h: usize, stride_w: usize, dilate_h: usize, dilate_w: usize, pad_h: usize, pad_w: usize, k_height: usize, k_width: usize, ) { let in_channels = weights.shape()[0]; let (_, in_height, in_width) = x.dim(); let (out_height, out_width, _) = out.dim(); let oh_start = pad_h; let oh_end = out_height.saturating_sub(pad_h); let ow_start = pad_w; let height1 = in_height + pad_h; let width1 = in_width + pad_w; for oh in (0..oh_start).chain(oh_end..out_height) { for ow in 0..out_width { let mut acc = bias; for ic in 0..in_channels { for kh in 0..k_height { let ih = oh * stride_h + kh * dilate_h; if (ih < pad_h) | (ih >= height1) { continue; } let ih = ih - pad_h; for kw in 0..k_width { let iw = ow * stride_w + kw * dilate_w; if (iw < pad_w) | (iw >= width1) { continue; } let iw = iw - pad_w; // Load a full vector from the weights. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and out channels are last. // We need to ensure the weights are reshaped appropriately. let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the // compiler can't prove this. We can't use `as_slice` with fixed bounds // because we want to support arbitrary input layouts. So an unchecked load // is used. let i0 = unsafe { x.uget([ic, ih, iw]) }.splat::(); acc = i0.mul_add(f0, acc); } } } // Store a full vector from the output. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with // channels last, so this always holds. unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; } } for ow in (0..ow_start).chain(owb_end..out_width) { for oh in 0..out_height { let mut acc = bias; for ic in 0..in_channels { for kh in 0..k_height { let ih = oh * stride_h + kh * dilate_h; if (ih < pad_h) | (ih >= height1) { continue; } let ih = ih - pad_h; for kw in 0..k_width { let iw = ow * stride_w + kw * dilate_w; if (iw < pad_w) | (iw >= width1) { continue; } let iw = iw - pad_w; // Load a full vector from the weights. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and out channels are last. // We need to ensure the weights are reshaped appropriately. let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the // compiler can't prove this. We can't use `as_slice` with fixed bounds // because we want to support arbitrary input layouts. So an unchecked load // is used. let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::(); acc = i0.mul_add(f0, acc); } } } // Store a full vector from the output. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with // channels last, so this always holds. unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; } } } macro_rules! inner_with_register_blocking_size { ($rb: literal) => { /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is /// guaranteed to always be in bounds (because of the way out size is calculated). /// /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] #[inline(always)] unsafe fn conv2d_inner_nopad( x: &ArrayView3, weights: &ArrayView4, out: &mut ArrayViewMut2, bias: Vector, oh: usize, ow: usize, oc: usize, ic_off: usize, stride_h: usize, stride_w: usize, dilate_h: usize, dilate_w: usize, k_height: usize, k_width: usize, pad_h: usize, pad_w: usize, ) { let in_channels = weights.shape()[0]; seq!(N in 0..$rb { let mut acc~N = bias; }); for ic in 0..in_channels { for kh in 0..k_height { let ih = oh * stride_h + kh * dilate_h - pad_h; for kw in 0..k_width { // Load a full vector from the weights. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and out channels are last. // We need to ensure the weights are reshaped appropriately. let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; let iw = ow * stride_w + kw * dilate_w - pad_w; seq!(N in 0..$rb { // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the // compiler can't prove this. We can't use `as_slice` with fixed bounds // because we want to support arbitrary input layouts. So an unchecked load // is used. let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N * stride_w]) }.splat::(); }); seq!(N in 0..$rb { acc~N = i~N.mul_add(f0, acc~N); }); } } } seq!(N in 0..$rb { // Store a full vector from the output. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with // channels last, so this always holds. unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; }); } /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is /// guaranteed to always be in bounds (because of the way out size is calculated). /// /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] #[inline(always)] unsafe fn conv2d_inner_nopad_nostride( x: &ArrayView3, weights: &ArrayView4, out: &mut ArrayViewMut2, bias: Vector, oh: usize, ow: usize, oc: usize, ic_off: usize, k_height: usize, k_width: usize, pad_h: usize, pad_w: usize, ) { let in_channels = weights.shape()[0]; seq!(N in 0..$rb { let mut acc~N = bias; }); for ic in 0..in_channels { for kh in 0..k_height { let ih = oh + kh - pad_h; for kw in 0..k_width { // Load a full vector from the weights. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and out channels are last. // We need to ensure the weights are reshaped appropriately. let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; let iw = ow + kw - pad_w; seq!(N in 0..$rb { // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the // compiler can't prove this. We can't use `as_slice` with fixed bounds // because we want to support arbitrary input layouts. So an unchecked load // is used. let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N]) }.splat::(); }); seq!(N in 0..$rb { acc~N = i~N.mul_add(f0, acc~N); }); } } } seq!(N in 0..$rb { // Store a full vector from the output. This is guaranteed to be in bounds // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with // channels last, so this always holds. unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; }); } }; } pub(crate) use inner_with_register_blocking_size; ================================================ FILE: crates/burn-ndarray/src/ops/simd/maxpool.rs ================================================ use core::{marker::PhantomData, mem::transmute}; use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; use burn_backend::{BoolStore, DType, Element, quantization::QuantValue}; use macerator::{Simd, VOrd}; use ndarray::{Array4, s}; use nhwc::max_pool2d_nhwc; use super::{MinMax, should_use_simd}; #[macerator::with_simd] fn is_accelerated_impl(_x: PhantomData) -> bool { ::is_min_max_accelerated::() } fn is_accelerated() -> bool { is_accelerated_impl::(PhantomData) } macro_rules! launch_kernel { ($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => { match <$ty as Element>::dtype() { DType::F64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::F32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::I8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::U64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::U32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::U16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::U8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::Bool(BoolStore::Native) if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::QFloat(scheme) => match scheme.value { QuantValue::Q8F | QuantValue::Q8S if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), _ => Err($x) }, _ => Err($x), } }; } pub(crate) fn try_max_pool2d_simd( x: SharedArray, ksize: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ) -> Result, SharedArray> { let [_, c, _, _] = x.shape().try_into().unwrap(); if !should_use_simd(c) || x.strides()[1] != 1 { return Err(x); } launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation) } fn cast(tensor: SharedArray) -> SharedArray { unsafe { transmute::, SharedArray>(tensor) } } mod nhwc { use itertools::Itertools; use macerator::{Simd, vload_unaligned, vstore_unaligned}; use ndarray::{ArrayView3, ArrayViewMut3, Ix4}; use seq_macro::seq; use crate::ops::simd::lanes; use super::*; // Until you can use associated constants as array size, we need to hardcode this. // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. const BLOCK_REGISTERS: usize = 8; pub(crate) fn max_pool2d_nhwc( x: SharedArray, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ) -> SharedArray { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); let lanes = lanes::(); let ch_block = lanes * BLOCK_REGISTERS; let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1) / stride_height) + 1; let out_width = ((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; let mut output = unsafe { Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() }; let unsafe_shared_out = UnsafeSharedRef::new(&mut output); let x = x.into_dimensionality::().unwrap(); let x = x.view(); let x = x.permuted_axes([0, 2, 3, 1]); // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. // An exclusive loop will always have `lanes * blocking factor` elements in bounds. let blocks = channels / ch_block; let blocks_end = blocks * ch_block; // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An // exclusive loop will always have `lanes` elements in bounds. let simd_end = channels / lanes * lanes; let simd_unblocked = (simd_end - blocks_end) / lanes; let remainder = channels - simd_end; run_par!(|| { // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { let block = k % blocks; let b = k / blocks; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_blocked(x, out, kernel_size, stride, padding, dilation, block); }); // SAFETY: See `loop_unblocked` iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe { let ch = (k % simd_unblocked) * lanes + blocks_end; let b = k / simd_unblocked; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch); }); // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { let ch = (k % remainder) + simd_end; let b = k / remainder; let output = unsafe_shared_out.get(); let x = x.slice(s![b, .., .., ..]); let out = output.slice_mut(s![b, .., .., ..]); loop_scalar(x, out, kernel_size, stride, padding, dilation, ch); }); }); output = output.permuted_axes([0, 3, 1, 2]); output.into_dyn().into_shared() } /// Execute the blocked (unrolled) portion of the pool. #[allow( clippy::too_many_arguments, clippy::erasing_op, clippy::identity_op, unused_mut )] #[inline(always)] #[macerator::with_simd] fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>( x: ArrayView3<'a, E>, mut out: ArrayViewMut3<'a, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], block: usize, ) where 'a: 'a, { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); let lanes = E::lanes::(); let ch_block = lanes * BLOCK_REGISTERS; let min = E::MIN.splat::(); // If outside padding area, kernels are guaranteed to be in bounds for oh in pad_h..out_height.saturating_sub(pad_h) { for ow in pad_w..out_width.saturating_sub(pad_w) { seq!(N in 0..8 { let mut acc~N = min; }); let ch = block * ch_block; let ch_end = ch + ch_block; let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); for kh in 0..kernel_height { let ih = oh * stride_height + kh * dilation_height - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw * dilation_width - pad_w; let x = x.slice(s![ih, iw, ch..ch_end]); seq!(N in 0..8 { // SAFETY: // Load a full vector from x[N * lanes]. This is bounds checked by the // slice above. acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); }); } } seq!(N in 0..8 { // SAFETY: // Store a full vector to out[N * lanes]. This is bounds checked by the // slice above. unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; }); } } // Border pixels need bounds checks if (pad_h, pad_w) != (0, 0) { let v_borders = (0..pad_h) .chain(out_height.saturating_sub(pad_h)..out_height) .cartesian_product(0..out_width); let h_borders = (0..out_height) .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); for (oh, ow) in v_borders.chain(h_borders) { seq!(N in 0..8 { let mut acc~N = min; }); let ch = block * ch_block; let ch_end = ch + ch_block; let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); for kh in 0..kernel_height { let ih = oh * stride_height + kh * dilation_height; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw * dilation_width; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; let x = x.slice(s![ih, iw, ch..ch_end]); seq!(N in 0..8 { // SAFETY: // Load a full vector from x[N * lanes]. This is bounds checked by the // slice above. acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); }); } } seq!(N in 0..8 { // SAFETY: // Store a full vector to out[N * lanes]. This is bounds checked by the // slice above. unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; }); } } } /// Execute the unblocked (not unrolled) portion of the pool. /// /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. #[allow(clippy::too_many_arguments, unused_mut)] #[inline(always)] #[macerator::with_simd] unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>( x: ArrayView3<'a, E>, mut out: ArrayViewMut3<'a, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ch: usize, ) where 'a: 'a, { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); for oh in pad_h..out_height.saturating_sub(pad_h) { for ow in pad_w..out_width.saturating_sub(pad_w) { let mut acc = E::MIN.splat::(); let out = &mut out[[oh, ow, ch]]; for kh in 0..kernel_height { let ih = oh * stride_height + kh * dilation_height - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw * dilation_width - pad_w; // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); } } // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. unsafe { vstore_unaligned(out, acc) }; } } // Border pixels need bounds checks if (pad_h, pad_w) != (0, 0) { let v_borders = (0..pad_h) .chain(out_height.saturating_sub(pad_h)..out_height) .cartesian_product(0..out_width); let h_borders = (0..out_height) .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); for (oh, ow) in v_borders.chain(h_borders) { let mut acc = E::MIN.splat::(); let out = &mut out[[oh, ow, ch]]; for kh in 0..kernel_height { let ih = oh * stride_height + kh * dilation_height; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw * dilation_width; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); } } // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. unsafe { vstore_unaligned(out, acc) }; } } } fn loop_scalar( x: ArrayView3<'_, E>, mut out: ArrayViewMut3<'_, E>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ch: usize, ) { let [kernel_height, kernel_width] = kernel_size; let [pad_h, pad_w] = padding; let [stride_height, stride_width] = stride; let [dilation_height, dilation_width] = dilation; let (x_height, x_width, _) = x.dim(); let (out_height, out_width, _) = out.dim(); for oh in 0..out_height { for ow in 0..out_width { let mut acc = E::MIN; for kh in 0..kernel_height { let ih = oh * stride_height + kh * dilation_height; if ih < pad_h || ih >= x_height + pad_h { continue; } let ih = ih - pad_h; for kw in 0..kernel_width { let iw = ow * stride_width + kw * dilation_width; if iw < pad_w || iw >= x_width + pad_w { continue; } let iw = iw - pad_w; acc = acc.max(x[[ih, iw, ch]]); } } out[[oh, ow, ch]] = acc; } } } } ================================================ FILE: crates/burn-ndarray/src/ops/simd/mod.rs ================================================ pub(crate) mod avgpool; mod base; pub(crate) mod binary; pub(crate) mod binary_elemwise; pub(crate) mod cmp; pub(crate) mod conv; pub(crate) mod maxpool; pub(crate) mod unary; pub use base::*; ================================================ FILE: crates/burn-ndarray/src/ops/simd/unary.rs ================================================ use core::marker::PhantomData; use bytemuck::cast; use macerator::{ Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned, }; use ndarray::ArrayD; use num_traits::Signed; use seq_macro::seq; use crate::{NdArrayElement, SharedArray}; use super::should_use_simd; pub trait SimdUnop { fn apply_vec(input: Vector) -> Vector; fn apply(input: T) -> Out; fn is_accelerated() -> bool; } pub struct RecipVec; impl SimdUnop for RecipVec { fn apply_vec(input: Vector) -> Vector { input.recip() } fn apply(input: f32) -> f32 { input.recip() } fn is_accelerated() -> bool { ::is_accelerated::() } } pub struct VecAbs; impl SimdUnop for VecAbs { fn apply_vec(input: Vector) -> Vector { input.abs() } fn apply(input: T) -> T { input.abs() } fn is_accelerated() -> bool { ::is_accelerated::() } } pub struct VecBitNot; impl SimdUnop for VecBitNot { fn apply_vec(input: Vector) -> Vector { !input } fn apply(input: T) -> T { input.not() } fn is_accelerated() -> bool { ::is_accelerated::() } } #[macerator::with_simd] fn is_accelerated>( _x: PhantomData<(T, Out, Op)>, ) -> bool { Op::is_accelerated::() } pub fn try_unary_simd< E: NdArrayElement, EOut: NdArrayElement, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdUnop, >( input: SharedArray, ) -> Result, SharedArray> { if !should_use_simd(input.len()) || input.as_slice_memory_order().is_none() || !is_accelerated::(PhantomData) { return Err(input); } // Used to assert traits based on the dynamic `DType`. let input = unsafe { core::mem::transmute::, SharedArray>(input) }; let out = if size_of::() == size_of::() && align_of::() >= align_of::() && input.is_unique() { unsafe { unary_scalar_simd_inplace::(input) } } else { unary_scalar_simd_owned::(input) }; // Used to assert traits based on the dynamic `DType`. let out = unsafe { core::mem::transmute::, SharedArray>(out) }; Ok(out) } /// Execute operation in line. /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. unsafe fn unary_scalar_simd_inplace< T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdUnop, >( input: SharedArray, ) -> SharedArray { let mut buffer = input.into_owned(); let slice = buffer.as_slice_memory_order_mut().unwrap(); // This is only called when in and out have the same size, so it's safe unsafe { unary_slice_inplace::(slice, PhantomData) }; // Buffer has the same elem size and is filled with the operation output, so this is safe let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; out.into_shared() } fn unary_scalar_simd_owned< T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdUnop, >( input: SharedArray, ) -> SharedArray { let mut out = unsafe { ArrayD::uninit(input.shape()).assume_init() }; let input = input.as_slice_memory_order().unwrap(); let out_slice = out.as_slice_memory_order_mut().unwrap(); unary_slice::(input, out_slice, PhantomData); out.into_shared() } #[allow(clippy::erasing_op, clippy::identity_op)] #[macerator::with_simd] fn unary_slice< 'a, S: Simd, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdUnop, >( input: &'a [T], out: &'a mut [Out], _op: PhantomData, ) where 'a: 'a, { let lanes = T::lanes::(); let mut chunks_input = input.chunks_exact(8 * lanes); let mut chunks_out = out.chunks_exact_mut(8 * lanes); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { seq!(N in 0..8 { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; let s~N = Op::apply_vec::(s~N); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; }); } let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { // Load one full vector from `input`. // SAFETY: Guaranteed to be in bounds because `len == lanes` let s0 = unsafe { vload_unaligned(input.as_ptr()) }; let s0 = Op::apply_vec::(s0); // Store one full vector to `out`. // SAFETY: Guaranteed to be in bounds because `len == lanes` unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; } for (input, out) in chunks_input .remainder() .iter() .zip(chunks_out.into_remainder()) { *out = Op::apply(*input) } } /// Execute operation in line. /// SAFETY: /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. #[macerator::with_simd] unsafe fn unary_slice_inplace< 'a, S: Simd, T: NdArrayElement + Scalar, Out: NdArrayElement + Scalar, Op: SimdUnop, >( buf: &'a mut [T], _op: PhantomData<(Out, Op)>, ) where 'a: 'a, { let (head, main, tail) = unsafe { buf.align_to_mut::>() }; for elem in head.iter_mut().chain(tail) { *elem = cast(Op::apply(*elem)); } let mut chunks = main.chunks_exact_mut(8); for elem in chunks.by_ref() { seq!(N in 0..8 { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; let s~N = Op::apply_vec::(s~N); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) }; }); } for elem in chunks.into_remainder() { // Load a full vector from the aligned portion of the buffer. // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is // always a full vector in bounds. let s0 = unsafe { vload(elem as *const _ as *const T) }; let s0 = Op::apply_vec::(s0); // Store a full vector at the same position as the input. Cast is safe because `Out` is // size and align compatible unsafe { vstore(elem as *mut _ as *mut Out, s0) }; } } ================================================ FILE: crates/burn-ndarray/src/ops/tensor.rs ================================================ // Language use alloc::vec::Vec; use burn_backend::backend::ExecutionError; use burn_backend::ops::GridSampleOptions; use burn_backend::tensor::FloatTensor; use burn_backend::{TensorMetadata, element::cast::ToElement}; // Current crate use super::{ NdArrayMathOps, NdArrayOps, matmul::{cross, matmul}, }; use crate::{ NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor, }; use crate::{NdArrayDevice, SEED, slice}; use crate::{ SharedArray, element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement}, }; use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d}; // Workspace crates use crate::rand::get_seeded_rng; use burn_backend::{Distribution, FloatDType, Scalar}; use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float; use libm::erf; #[cfg(feature = "std")] #[allow(dead_code)] fn round_ties_even_wrapper(x: f64) -> f64 { x.round_ties_even() } #[cfg(not(feature = "std"))] #[allow(dead_code)] fn round_ties_even_wrapper(x: f64) -> f64 { if (x - x.floor()) == 0.5 { (x * 0.5).round() * 2.0 } else { x.round() } } impl FloatTensorOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor { NdArrayTensor::from_data(data) } fn float_random( shape: Shape, distribution: Distribution, device: &NdArrayDevice, ) -> FloatTensor { let mut seed = SEED.lock().unwrap(); let mut rng = seed.take().unwrap_or_else(get_seeded_rng); let tensor = Self::float_from_data( TensorData::random::(shape, distribution, &mut rng), device, ); *seed = Some(rng); tensor } async fn float_into_data(tensor: FloatTensor) -> Result { Ok(tensor.into_data()) } fn float_device(_tensor: &FloatTensor) -> NdArrayDevice { NdArrayDevice::Cpu } fn float_to_device(tensor: FloatTensor, _device: &NdArrayDevice) -> FloatTensor { tensor } fn float_empty( shape: Shape, device: & as Backend>::Device, dtype: FloatDType, ) -> FloatTensor { Self::float_zeros(shape, device, dtype) } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add) } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::add_scalar(array, rhs.elem()) }) } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub) } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::sub_scalar(array, rhs.elem()) }) } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul) } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::mul_scalar(array, rhs.elem()) }) } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div) } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::div_scalar(array, rhs.elem()) }) } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder) } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::remainder_scalar(array, rhs.elem()) }) } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), matmul) } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim)) } fn float_recip(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::recip(array) }) } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::swap_dims(array, dim1, dim2) }) } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::reshape(array, shape) }) } fn float_gather( dim: usize, tensor: FloatTensor, indices: NdArrayTensor, ) -> FloatTensor { execute_with_int_dtype!( indices, IntElem, |idx_array: SharedArray| -> NdArrayTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::gather(dim, array, idx_array) }) } ) } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: NdArrayTensor, value: FloatTensor, ) -> FloatTensor { execute_with_int_dtype!( indices, IntElem, |idx_array: SharedArray| -> NdArrayTensor { execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter( dim, tensor, idx_array, value )) } ) } fn float_select( tensor: FloatTensor, dim: usize, indices: NdArrayTensor, ) -> FloatTensor { execute_with_int_dtype!( indices, IntElem, |idx_array: SharedArray| -> NdArrayTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::select(array, dim, idx_array) }) } ) } fn float_select_add( tensor: FloatTensor, dim: usize, indices: NdArrayTensor, value: FloatTensor, ) -> FloatTensor { execute_with_int_dtype!( indices, IntElem, |idx_array: SharedArray| -> NdArrayTensor { execute_with_float_dtype!((tensor, value), |tensor, value| { NdArrayMathOps::select_assign(tensor, dim, idx_array, value) }) } ) } fn float_slice(tensor: FloatTensor, slices: &[burn_backend::Slice]) -> FloatTensor { slice!(tensor, slices) } fn float_slice_assign( tensor: FloatTensor, slices: &[burn_backend::Slice], value: FloatTensor, ) -> FloatTensor { execute_with_float_dtype!((tensor, value), |tensor, value| { NdArrayOps::slice_assign(tensor, slices, value) }) } fn float_mask_where( tensor: FloatTensor, mask: NdArrayTensor, value: FloatTensor, ) -> FloatTensor { execute_with_float_dtype!((tensor, value), |tensor, value| { NdArrayOps::mask_where(tensor, mask.bool(), value) }) } fn float_mask_fill( tensor: FloatTensor, mask: NdArrayTensor, value: Scalar, ) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::mask_fill(array, mask.bool(), value.elem()) }) } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) }) } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> NdArrayTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::equal_elem(array, rhs.elem()) }) } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) }) } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> NdArrayTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::greater_elem(array, rhs.elem()) }) } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater_equal(lhs, rhs) }) } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> NdArrayTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::greater_equal_elem(array, rhs.elem()) }) } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) }) } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> NdArrayTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::lower_elem(array, rhs.elem()) }) } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower_equal(lhs, rhs) }) } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> NdArrayTensor { execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { NdArrayMathOps::lower_equal_elem(array, rhs.elem()) }) } fn float_detach(tensor: FloatTensor) -> FloatTensor { tensor } fn float_mean(tensor: FloatTensor) -> FloatTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::mean_view(array.view()) }) } fn float_sum(tensor: FloatTensor) -> FloatTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::sum_view(array.view()) }) } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::mean_dim(array, dim) }) } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::cumsum(array, dim) }) } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::cumprod(array, dim) }) } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::cummin(array, dim) }) } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::cummax(array, dim) }) } fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::sum_dim(array, dim) }) } fn float_argmax(tensor: FloatTensor, dim: usize) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::argmax_view::(array.view(), dim) }) } fn float_argmin(tensor: FloatTensor, dim: usize) -> NdArrayTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::argmin_view::(array.view(), dim) }) } fn float_exp(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared() }) } fn float_log(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array.mapv_into(|a: FloatElem| a.log_elem()).into_shared() }) } fn float_prod(tensor: FloatTensor) -> FloatTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::prod_view(array.view()) }) } fn float_prod_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::prod_dim(array, dim) }) } fn float_max(tensor: FloatTensor) -> FloatTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::max_view(array.view()) }) } fn float_min(tensor: FloatTensor) -> FloatTensor { // Use view() for zero-copy on borrowed storage execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::min_view(array.view()) }) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared() }) } fn float_powf_scalar_impl(tensor: FloatTensor, value: Scalar) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| a.powf_elem(value.elem())) .into_shared() }) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared() }) } fn float_abs(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::abs(array) }) } fn float_cos(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem()) .into_shared() }) } fn float_cosh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem()) .into_shared() }) } fn float_sin(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem()) .into_shared() }) } fn float_sinh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem()) .into_shared() }) } fn float_tan(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem()) .into_shared() }) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem()) .into_shared() }) } fn float_acos(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem()) .into_shared() }) } fn float_acosh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem()) .into_shared() }) } fn float_asin(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem()) .into_shared() }) } fn float_asinh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem()) .into_shared() }) } fn float_atan(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem()) .into_shared() }) } fn float_atanh(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem()) .into_shared() }) } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| { NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b)) }) } fn float_round(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem()) .into_shared() }) } fn float_floor(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem()) .into_shared() }) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem()) .into_shared() }) } fn float_trunc(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem()) .into_shared() }) } fn float_erf(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array .mapv_into(|a: FloatElem| erf(a.to_f64()).elem()) .into_shared() }) } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { cat_with_dtype!(tensors, dim, [F64, F32]) } fn float_clamp_min(tensor: FloatTensor, min: Scalar) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::clamp_min(array, min.elem()) }) } fn float_clamp_max(tensor: FloatTensor, max: Scalar) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::clamp_max(array, max.elem()) }) } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::clamp(array, min.elem(), max.elem()) }) } fn float_into_int(tensor: FloatTensor) -> NdArrayTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { array.mapv(|a: FloatElem| a.elem::()).into_shared() }) } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| { NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b)) }) } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::permute(array, axes) }) } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::flip(array, axes) }) } fn float_sign(tensor: FloatTensor) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayMathOps::sign_op(array) }) } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::expand(array, shape) }) } fn float_cast(tensor: FloatTensor, dtype: FloatDType) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { cast_to_dtype(array, dtype.into()) }) } fn float_grid_sample_2d( tensor: FloatTensor, grid: FloatTensor, options: GridSampleOptions, ) -> FloatTensor { execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d( tensor, grid, options )) } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray| { NdArrayOps::unfold(array, dim, size, step) }) } } ================================================ FILE: crates/burn-ndarray/src/ops/transaction.rs ================================================ use crate::{ FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray, element::{IntNdArrayElement, QuantElement}, }; use burn_backend::ops::TransactionOps; impl TransactionOps for NdArray where NdArrayTensor: From>, NdArrayTensor: From>, { } ================================================ FILE: crates/burn-ndarray/src/parallel.rs ================================================ /// Macro for running a function in parallel. #[cfg(feature = "multi-threads")] #[macro_export(local_inner_macros)] macro_rules! run_par { ( $func:expr ) => {{ use rayon::prelude::*; #[allow(clippy::redundant_closure_call)] rayon::scope(|_| $func()) }}; } /// Macro for running a function in parallel. #[cfg(not(feature = "multi-threads"))] #[macro_export(local_inner_macros)] macro_rules! run_par { ( $func:expr ) => {{ $func() }}; } /// Macro for iterating in parallel. #[cfg(not(feature = "multi-threads"))] #[macro_export(local_inner_macros)] macro_rules! iter_par { ( $iter:expr ) => {{ $iter }}; } /// Macro for iterating in parallel. #[cfg(feature = "multi-threads")] #[macro_export(local_inner_macros)] macro_rules! iter_par { ( $iter:expr ) => {{ $iter.into_par_iter() }}; } /// Macro for iterating in parallel. #[cfg(feature = "multi-threads")] #[macro_export(local_inner_macros)] macro_rules! iter_slice_par { ( $slice:expr ) => {{ $slice.into_par_iter() }}; } /// Macro for iterating in parallel. #[cfg(not(feature = "multi-threads"))] #[macro_export(local_inner_macros)] macro_rules! iter_slice_par { ( $slice:expr ) => {{ $slice.iter() }}; } /// Macro for iterating over a range in parallel. #[cfg(feature = "multi-threads")] #[macro_export(local_inner_macros)] macro_rules! iter_range_par { ( $start:expr, $end:expr ) => {{ ($start..$end).into_par_iter() }}; } /// Macro for iterating over a range in parallel. #[cfg(not(feature = "multi-threads"))] #[macro_export(local_inner_macros)] macro_rules! iter_range_par { ( $start:expr, $end:expr ) => {{ ($start..$end) }}; } ================================================ FILE: crates/burn-ndarray/src/rand.rs ================================================ //! Random number generation utilities for burn-ndarray #[cfg(not(feature = "std"))] use rand::rngs::SmallRng; #[cfg(feature = "std")] use rand::rngs::StdRng; /// Type alias for the RNG used by burn-ndarray #[cfg(feature = "std")] pub type NdArrayRng = StdRng; #[cfg(not(feature = "std"))] pub type NdArrayRng = SmallRng; #[cfg(not(feature = "std"))] use rand::SeedableRng; /// Get a seeded random number generator /// /// For std builds, uses OS entropy. /// For no_std builds, uses a compile-time random seed. #[cfg(feature = "std")] pub fn get_seeded_rng() -> NdArrayRng { // Use the standard implementation from burn-std burn_std::rand::get_seeded_rng() } /// Get a seeded random number generator /// /// For std builds, uses OS entropy. /// For no_std builds, uses a compile-time random seed. #[cfg(not(feature = "std"))] pub fn get_seeded_rng() -> NdArrayRng { // Use compile-time random seed for no_std const SEED: u64 = const_random::const_random!(u64); SmallRng::seed_from_u64(SEED) } ================================================ FILE: crates/burn-ndarray/src/sharing.rs ================================================ use core::cell::UnsafeCell; /// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439). pub(crate) struct UnsafeSharedRef<'a, T> { cell: UnsafeCell<&'a mut T>, } unsafe impl Sync for UnsafeSharedRef<'_, T> {} impl<'a, T> UnsafeSharedRef<'a, T> { pub fn new(data: &'a mut T) -> Self { Self { cell: UnsafeCell::new(data), } } pub unsafe fn get(&self) -> &'a mut T { unsafe { core::ptr::read(self.cell.get()) } } } ================================================ FILE: crates/burn-ndarray/src/storage.rs ================================================ //! Copy-on-write storage for zero-copy tensor loading. //! //! This module provides `NdArrayStorage`, which enables true zero-copy loading //! from burnpack files. When data is borrowed from external memory (like mmap'd files //! or static data), it remains zero-copy until a mutating operation is performed, //! at which point it's copied (copy-on-write semantics). //! //! This integrates with ndarray's existing COW patterns - operations that check //! `is_unique()` will see borrowed data as non-unique, triggering the allocation path. use burn_backend::Element; use burn_std::{Bytes, Shape}; use core::mem; use ndarray::{ArcArray, ArrayView, IxDyn}; /// Storage that supports both owned data and borrowed (zero-copy) data. /// /// # Copy-on-Write Semantics /// /// - **Borrowed**: Data from external source (burnpack, mmap, static). /// Reports `is_unique() == false` to trigger copy on mutation. /// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount. /// /// # Example /// /// ```ignore /// // Zero-copy load /// let storage = NdArrayStorage::from_borrowed(bytes, shape); /// storage.is_unique(); // false - will copy on mutation /// /// // Read operations use view() - zero-copy /// let view = storage.view(); /// /// // Mutation converts to owned /// let owned = storage.into_owned(); // Copies here /// ``` #[derive(Debug)] pub enum NdArrayStorage { /// Borrowed from external source (e.g., burnpack zero-copy load). /// Keeps `Bytes` alive to ensure the referenced memory is valid. Borrowed { /// Source bytes - keeps external memory alive via reference counting bytes: Bytes, /// Shape of the tensor shape: Shape, }, /// Standard owned storage with ArcArray COW semantics. Owned(ArcArray), } impl Clone for NdArrayStorage { fn clone(&self) -> Self { match self { // For borrowed data, clone the Bytes (cheap Arc clone) and shape Self::Borrowed { bytes, shape } => Self::Borrowed { bytes: bytes.clone(), shape: shape.clone(), }, // For owned data, clone the ArcArray (cheap Arc clone) Self::Owned(arr) => Self::Owned(arr.clone()), } } } impl NdArrayStorage { /// Create borrowed storage from external bytes. /// /// Returns the bytes and shape back on failure (misaligned or too small), /// enabling zero-copy even for native allocations by avoiding defensive cloning. /// /// # Requirements /// /// The caller must ensure that: /// - The `Bytes` contain valid data for the element type `E` /// - The data is contiguous in row-major (C) order matching the provided shape /// /// These requirements are upheld when loading from `TensorData` (burnpack, etc.) /// which always stores data contiguously in row-major order. pub fn from_borrowed(bytes: Bytes, shape: impl Into) -> Result { let shape = shape.into(); // Validate alignment let ptr = bytes.as_ptr(); if !(ptr as usize).is_multiple_of(mem::align_of::()) { return Err((bytes, shape)); } // Validate size (using checked arithmetic to prevent overflow) let num_elements = match shape .iter() .try_fold(1usize, |acc, &dim| acc.checked_mul(dim)) { Some(n) => n, None => return Err((bytes, shape)), }; let expected_size = match num_elements.checked_mul(mem::size_of::()) { Some(s) => s, None => return Err((bytes, shape)), }; if bytes.len() < expected_size { return Err((bytes, shape)); } Ok(Self::Borrowed { bytes, shape }) } /// Create owned storage from an ArcArray. #[inline] pub fn from_owned(array: ArcArray) -> Self { Self::Owned(array) } /// Returns whether this storage is uniquely owned and can be mutated in-place. /// /// - **Borrowed**: Always returns `false` to trigger copy-on-write. /// - **Owned**: Delegates to `ArcArray::is_unique()`. /// /// This integrates with existing SIMD code patterns like: /// ```ignore /// if tensor.is_unique() { /// // mutate in place /// } else { /// // allocate new /// } /// ``` #[inline] pub fn is_unique(&self) -> bool { match self { Self::Borrowed { .. } => false, // Force copy path Self::Owned(arr) => arr.is_unique(), } } /// Get a read-only view of the data. /// /// This is zero-copy for both borrowed and owned variants. #[inline] pub fn view(&self) -> ArrayView<'_, E, IxDyn> { match self { Self::Borrowed { bytes, shape } => { let ptr = bytes.as_ptr() as *const E; let dim = IxDyn(shape); // SAFETY: // - `bytes` is kept alive for the lifetime of `self` // - Alignment was validated in `from_borrowed` // - Size was validated in `from_borrowed` unsafe { ArrayView::from_shape_ptr(dim, ptr) } } Self::Owned(arr) => arr.view(), } } /// Convert to owned ArcArray. /// /// - **Borrowed**: Copies the data into a new ArcArray. /// - **Owned + unique**: Returns the array without copying. /// - **Owned + shared**: Clones the data. pub fn into_owned(self) -> ArcArray { match self { Self::Borrowed { bytes, shape } => { let ptr = bytes.as_ptr() as *const E; let dim = IxDyn(&shape); // SAFETY: Same as view() - bytes is valid for this scope let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) }; view.to_owned().into_shared() } Self::Owned(arr) => arr, } } /// Convert to shared ArcArray, suitable for returning from operations. /// /// This is equivalent to `into_owned()` but named for clarity. #[inline] pub fn into_shared(self) -> ArcArray { self.into_owned() } /// Get the shape of the tensor. pub fn shape(&self) -> &[usize] { match self { Self::Borrowed { shape, .. } => shape, Self::Owned(arr) => arr.shape(), } } /// Get the number of dimensions. #[inline] pub fn ndim(&self) -> usize { self.shape().len() } /// Get the total number of elements. #[inline] pub fn len(&self) -> usize { self.shape().iter().product() } /// Check if the tensor is empty. #[inline] pub fn is_empty(&self) -> bool { self.len() == 0 } /// Returns `true` if this is borrowed (zero-copy) storage. #[inline] pub fn is_borrowed(&self) -> bool { matches!(self, Self::Borrowed { .. }) } /// Returns `true` if this is owned storage. #[inline] pub fn is_owned(&self) -> bool { matches!(self, Self::Owned(_)) } /// Ensure owned and return mutable reference to the ArcArray. /// /// Converts borrowed to owned if necessary. pub fn ensure_owned(&mut self) -> &mut ArcArray { if let Self::Borrowed { bytes, shape } = self { let ptr = bytes.as_ptr() as *const E; let dim = IxDyn(shape); // SAFETY: Same as view() let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) }; *self = Self::Owned(view.to_owned().into_shared()); } match self { Self::Owned(arr) => arr, Self::Borrowed { .. } => unreachable!(), } } } /// Convert from ArcArray to NdArrayStorage. impl From> for NdArrayStorage { fn from(array: ArcArray) -> Self { Self::Owned(array) } } #[cfg(test)] mod tests { use super::*; use alloc::{vec, vec::Vec}; use burn_std::Bytes; #[test] fn test_borrowed_is_not_unique() { let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); assert!(!storage.is_unique()); assert!(storage.is_borrowed()); } #[test] fn test_owned_unique_when_single_ref() { let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared(); let storage = NdArrayStorage::from_owned(array); assert!(storage.is_unique()); assert!(storage.is_owned()); } #[test] fn test_owned_not_unique_when_cloned() { let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared(); let storage = NdArrayStorage::from_owned(array); let _clone = storage.clone(); assert!(!storage.is_unique()); } #[test] fn test_view_zero_copy() { let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); let view = storage.view(); assert_eq!(view[[0, 0]], 1.0); assert_eq!(view[[1, 1]], 4.0); } #[test] fn test_into_owned_copies_borrowed() { let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); let owned = storage.into_owned(); assert_eq!(owned[[0, 0]], 1.0); assert_eq!(owned[[1, 1]], 4.0); } #[test] fn test_from_borrowed_validates_alignment() { use burn_std::AllocationProperty; // Test 1: Properly aligned data should succeed let aligned_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let aligned_bytes = Bytes::from_elems(aligned_data); // Verify test setup - should be 4-byte aligned for f32 assert_eq!( (aligned_bytes.as_ptr() as usize) % core::mem::align_of::(), 0, "Test setup: f32 data should be properly aligned" ); let result = NdArrayStorage::::from_borrowed(aligned_bytes, [2, 2]); assert!( result.is_ok(), "from_borrowed should succeed for properly aligned data" ); // Test 2: Misaligned data should fail // Create a buffer large enough to find a misaligned offset // (static data placement varies by platform, so we find an offset dynamically) let buffer: &[u8] = &[0u8; 32]; let shared = bytes::Bytes::from_static(buffer); let base = shared.as_ptr() as usize; let align = core::mem::align_of::(); // Find an offset in 1..align that produces misalignment (at least one must exist) let misalign_offset = (1..align) .find(|&off| !(base + off).is_multiple_of(align)) .expect("Should find a misaligned offset"); let sliced = shared.slice(misalign_offset..(misalign_offset + 16)); let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other); // Verify test setup - should NOT be 4-byte aligned assert_ne!( (misaligned_bytes.as_ptr() as usize) % align, 0, "Test setup: sliced data should be misaligned for f32" ); let result = NdArrayStorage::::from_borrowed(misaligned_bytes, [4]); assert!( result.is_err(), "from_borrowed should return Err for misaligned data" ); } #[test] fn test_insufficient_size_returns_err() { // Create bytes that are too small for the requested shape let data: Vec = vec![1.0, 2.0]; // 8 bytes let bytes = Bytes::from_elems(data); // Try to create storage for 4 elements (needs 16 bytes) let result = NdArrayStorage::::from_borrowed(bytes, [4]); assert!( result.is_err(), "from_borrowed should return Err when bytes are too small" ); } // ========================================================================== // Zero-copy hardening tests // These tests verify the zero-copy guarantee is maintained. If any of these // fail, it indicates a regression in zero-copy functionality. // ========================================================================== #[test] fn test_zero_copy_native_allocation() { // CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy // on initial load. The view() must return a pointer to the SAME memory. // // Note: Native allocations copy on clone (this is expected), but the initial // load is still zero-copy, avoiding an extra copy in the common case where // the tensor is used without cloning. let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let original_ptr = bytes.as_ptr(); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); // Initial load must be zero-copy let view = storage.view(); let view_ptr = view.as_ptr() as *const u8; assert_eq!( original_ptr, view_ptr, "ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes" ); // Verify data integrity assert_eq!(view[[0, 0]], 1.0); assert_eq!(view[[0, 1]], 2.0); assert_eq!(view[[1, 0]], 3.0); assert_eq!(view[[1, 1]], 4.0); } #[test] fn test_zero_copy_shared_bytes_pointer_identity() { // CRITICAL: Test with SharedBytesAllocationController for true zero-copy. // This simulates the actual burnpack/mmap loading path. use burn_std::AllocationProperty; // Create static-like data using bytes::Bytes let data: &[u8] = &[ 0, 0, 128, 63, // 1.0f32 in little-endian 0, 0, 0, 64, // 2.0f32 0, 0, 64, 64, // 3.0f32 0, 0, 128, 64, // 4.0f32 ]; let shared = bytes::Bytes::from_static(data); let original_ptr = shared.as_ptr(); // Create Bytes with SharedBytesAllocationController let bytes = Bytes::from_shared(shared, AllocationProperty::Other); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); // Verify pointer identity let view_ptr = storage.view().as_ptr() as *const u8; assert_eq!( original_ptr, view_ptr, "ZERO-COPY REGRESSION: SharedBytes view must point to original static data" ); // Clone should also share the same memory let cloned = storage.clone(); let cloned_ptr = cloned.view().as_ptr() as *const u8; assert_eq!( original_ptr, cloned_ptr, "ZERO-COPY REGRESSION: SharedBytes clone must share memory" ); } #[test] fn test_clone_borrowed_stays_borrowed() { // Verify that cloning borrowed storage produces another borrowed storage. // Note: The underlying Bytes may or may not share memory depending on // the allocation controller (native allocations copy, file-backed may share). let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); let cloned = storage.clone(); // Both should still be borrowed (the storage type is preserved) assert!( storage.is_borrowed(), "ZERO-COPY REGRESSION: original should remain borrowed after clone" ); assert!( cloned.is_borrowed(), "ZERO-COPY REGRESSION: clone should be borrowed type" ); // Both should report not unique (important for COW behavior) assert!( !storage.is_unique(), "ZERO-COPY REGRESSION: original should not be unique after clone" ); assert!( !cloned.is_unique(), "ZERO-COPY REGRESSION: clone should not be unique" ); // Data should be identical assert_eq!(storage.view(), cloned.view(), "Clone should have same data"); } #[test] fn test_zero_copy_triggers_copy_on_mutation() { // Verify that into_owned() on borrowed data creates a NEW allocation // (this is the "copy" in copy-on-write) let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let original_ptr = bytes.as_ptr(); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); assert!(storage.is_borrowed(), "should start as borrowed"); let owned = storage.into_owned(); let owned_ptr = owned.as_ptr() as *const u8; assert_ne!( original_ptr, owned_ptr, "into_owned() on borrowed data MUST allocate new memory (copy-on-write)" ); } #[test] fn test_borrowed_reports_not_unique() { // CRITICAL: Borrowed storage must report is_unique() == false // This is what triggers copy-on-write in mutation operations let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); assert!( !storage.is_unique(), "ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \ to trigger copy-on-write. If this is true, mutations will corrupt shared data!" ); } } ================================================ FILE: crates/burn-ndarray/src/tensor.rs ================================================ use burn_backend::{ DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata, quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue}, }; use burn_std::BoolStore; use crate::NdArrayStorage; use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization}; use alloc::vec::Vec; use ndarray::{ArcArray, ArrayD, IxDyn}; /// Concrete storage type for ndarray (owned with COW semantics via Arc) pub type SharedArray = ArcArray; /// Tensor primitive used by the [ndarray backend](crate::NdArray). /// /// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`. /// When data is borrowed from external sources (like burnpack files), /// it remains zero-copy until a mutating operation is performed. #[derive(Debug, Clone)] #[allow(missing_docs)] pub enum NdArrayTensor { F64(NdArrayStorage), F32(NdArrayStorage), I64(NdArrayStorage), I32(NdArrayStorage), I16(NdArrayStorage), I8(NdArrayStorage), U64(NdArrayStorage), U32(NdArrayStorage), U16(NdArrayStorage), U8(NdArrayStorage), Bool(NdArrayStorage), } impl NdArrayTensor { /// Extract bool array, converting to owned if necessary. pub(crate) fn bool(self) -> SharedArray { match self { NdArrayTensor::Bool(storage) => storage.into_shared(), _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()), } } /// Returns true if this tensor uses borrowed (zero-copy) storage. #[inline] pub fn is_borrowed(&self) -> bool { macro_rules! check { ($($variant:ident),*) => { match self { $(NdArrayTensor::$variant(s) => s.is_borrowed(),)* } }; } check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool) } } pub(crate) fn cast_to_dtype(array: SharedArray, dtype: DType) -> NdArrayTensor where NdArrayTensor: From>, { fn cast(array: SharedArray) -> SharedArray { array.mapv(|a| a.elem()).into_shared() } if E1::dtype() == dtype { return array.into(); } match dtype { DType::F64 => cast::(array).into(), DType::F32 => cast::(array).into(), DType::Flex32 => cast::(array).into(), DType::I64 => cast::(array).into(), DType::I32 => cast::(array).into(), DType::I16 => cast::(array).into(), DType::I8 => cast::(array).into(), DType::U64 => cast::(array).into(), DType::U32 => cast::(array).into(), DType::U16 => cast::(array).into(), DType::U8 => cast::(array).into(), DType::Bool(BoolStore::Native) => cast::(array).into(), dtype => panic!("Unsupported dtype: {dtype:?}"), } } macro_rules! impl_from { ($($ty: ty => $dtype: ident),*) => { // From SharedArray (owned) -> NdArrayTensor $(impl From> for NdArrayTensor { fn from(value: SharedArray<$ty>) -> NdArrayTensor { NdArrayTensor::$dtype(NdArrayStorage::from_owned(value)) } })* // From NdArrayStorage -> NdArrayTensor $(impl From> for NdArrayTensor { fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor { NdArrayTensor::$dtype(value) } })* }; } impl_from!( f64 => F64, f32 => F32, i64 => I64, i32 => I32, i16 => I16, i8 => I8, u64 => U64, u32 => U32, u16 => U16, u8 => U8, bool => Bool ); /// Macro to execute an operation on a given element type. /// /// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation. /// /// # Panics /// Since there is no automatic type cast at this time, binary operations for different /// floating point precision data types will panic with a data type mismatch. #[macro_export] macro_rules! execute_with_dtype { (($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{ let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs); let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs); match ($lhs, $rhs) { $( ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => { #[allow(unused)] type $element = $ty; // Convert storage to SharedArray for compatibility with existing operations $op(lhs.into_shared(), rhs.into_shared()).into() } )* _ => panic!( "Data type mismatch (lhs: {:?}, rhs: {:?})", lhs_dtype, rhs_dtype ), } }}; // Binary op: type automatically inferred by the compiler (($lhs:expr, $rhs:expr), $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), E, $op) }}; // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ F64 => f64, F32 => f32, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8, Bool => bool ]) }}; ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{ match $tensor { $( $crate::NdArrayTensor::$dtype(storage) => { #[allow(unused)] type $element = $ty; // Convert to SharedArray for compatibility with most operations $op(storage.into_shared()).into() } )* #[allow(unreachable_patterns)] other => unimplemented!("unsupported dtype: {:?}", other.dtype()) } }}; // Unary op: type automatically inferred by the compiler ($tensor:expr, $op:expr) => {{ $crate::execute_with_dtype!($tensor, E, $op) }}; // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ F64 => f64, F32 => f32, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8, Bool => bool ]) }}; } /// Macro to execute an operation a given element type. /// Only handles float types. /// /// # Panics /// Since there is no automatic type cast at this time, binary operations for different /// floating point precision data types will panic with a data type mismatch. #[macro_export] macro_rules! execute_with_float_dtype { // Binary op: type automatically inferred by the compiler (($lhs:expr, $rhs:expr), $op:expr) => {{ $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op) }}; // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ F64 => f64, F32 => f32 ]) }}; // Unary op: type automatically inferred by the compiler ($tensor:expr, $op:expr) => {{ $crate::execute_with_float_dtype!($tensor, E, $op) }}; // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ F64 => f64, F32 => f32 ]) }}; } /// Macro to execute an operation a given element type. /// Only handles int types. /// /// # Panics /// Since there is no automatic type cast at this time, binary operations for different /// floating point precision data types will panic with a data type mismatch. #[macro_export] macro_rules! execute_with_int_dtype { // Binary op: type automatically inferred by the compiler (($lhs:expr, $rhs:expr), $op:expr) => {{ $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op) }}; // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) }}; // Unary op: type automatically inferred by the compiler ($tensor:expr, $op:expr) => {{ $crate::execute_with_int_dtype!($tensor, E, $op) }}; // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) }}; } /// Macro to execute an operation a given element type. /// Only handles numeric types /// /// # Panics /// Since there is no automatic type cast at this time, binary operations for different /// floating point precision data types will panic with a data type mismatch. #[macro_export] macro_rules! execute_with_numeric_dtype { // Binary op: type automatically inferred by the compiler (($lhs:expr, $rhs:expr), $op:expr) => {{ $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op) }}; // Binary op: generic type cannot be inferred for an operation (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{ $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [ F64 => f64, F32 => f32, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) }}; // Unary op: type automatically inferred by the compiler ($tensor:expr, $op:expr) => {{ $crate::execute_with_numeric_dtype!($tensor, E, $op) }}; // Unary op: generic type cannot be inferred for an operation ($tensor:expr, $element:ident, $op:expr) => {{ $crate::execute_with_dtype!($tensor, $element, $op, [ F64 => f64, F32 => f32, I64 => i64, I32 => i32, I16 => i16, I8 => i8, U64 => u64, U32 => u32, U16 => u16, U8 => u8 ]) }}; } /// Macro to execute a cat operation on a given set of element types. /// /// Uses zero-copy views from storage for concatenation. /// /// # Panics /// Since there is no automatic type cast at this time, binary operations for different /// floating point precision data types will panic with a data type mismatch. #[macro_export] macro_rules! cat_with_dtype { ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => { match &$tensors[0] { $(NdArrayTensor::$dtype(_) => { let tensors = $tensors .iter() .map(|t| { if let NdArrayTensor::$dtype(storage) = t { // Use storage.view() for zero-copy access storage.view() } else { panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype()) } }) .collect::>(); NdArrayOps::concatenate(&tensors, $dim).into() })* _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype()) } }; } impl TensorMetadata for NdArrayTensor { fn dtype(&self) -> DType { match self { NdArrayTensor::F64(_) => DType::F64, NdArrayTensor::F32(_) => DType::F32, NdArrayTensor::I64(_) => DType::I64, NdArrayTensor::I32(_) => DType::I32, NdArrayTensor::I16(_) => DType::I16, NdArrayTensor::I8(_) => DType::I8, NdArrayTensor::U64(_) => DType::U64, NdArrayTensor::U32(_) => DType::U32, NdArrayTensor::U16(_) => DType::U16, NdArrayTensor::U8(_) => DType::U8, NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native), } } fn shape(&self) -> Shape { // Use storage's shape method (works for both borrowed and owned) macro_rules! get_shape { ($($variant:ident),*) => { match self { $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)* } }; } get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool) } fn rank(&self) -> usize { self.shape().num_dims() } } pub(crate) trait ShapeOps { fn num_dims(self) -> usize; fn num_elements(self) -> usize; fn dims(self) -> [usize; N]; fn into_shape(self) -> Shape; } impl ShapeOps for &[usize] { fn num_dims(self) -> usize { self.len() } fn num_elements(self) -> usize { self.iter().product() } fn dims(self) -> [usize; N] { self.try_into().unwrap() } fn into_shape(self) -> Shape { Shape::from(self) } } mod utils { use burn_std::tensor::is_contiguous; use super::*; impl NdArrayTensor { pub(crate) fn into_data(self) -> TensorData { let shape = self.shape(); let contiguous = self.is_contiguous(); fn inner( shape: Shape, is_contiguous: bool, array: ArcArray, ) -> TensorData { let vec = if is_contiguous { match array.try_into_owned_nocopy() { Ok(owned) => { let (mut vec, offset) = owned.into_raw_vec_and_offset(); if let Some(offset) = offset { vec.drain(..offset); } if vec.len() > shape.num_elements() { vec.drain(shape.num_elements()..vec.len()); } vec } Err(array) => array.into_iter().collect(), } } else { array.into_iter().collect() }; TensorData::new(vec, shape) } // Convert storage to owned array before extracting data execute_with_dtype!(self, |arr| inner(shape, contiguous, arr)) } pub(crate) fn is_contiguous(&self) -> bool { // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous) // For owned data, we check the strides macro_rules! check_contiguous { ($($variant:ident),*) => { match self { $(NdArrayTensor::$variant(storage) => { match storage { NdArrayStorage::Borrowed { .. } => { // Borrowed storage requires contiguous row-major data // (see NdArrayStorage::from_borrowed documentation) true } NdArrayStorage::Owned(array) => { let shape = array.shape(); let mut strides = Vec::with_capacity(array.strides().len()); for &stride in array.strides() { if stride <= 0 { return false; } strides.push(stride as usize); } is_contiguous(shape, &strides) } } })* } }; } check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool) } } } /// Converts a slice of usize to a typed dimension. #[macro_export(local_inner_macros)] macro_rules! to_typed_dims { ( $n:expr, $dims:expr, justdim ) => {{ let mut dims = [0; $n]; for i in 0..$n { dims[i] = $dims[i]; } let dim: Dim<[usize; $n]> = Dim(dims); dim }}; } /// Reshapes an array into a tensor. #[macro_export(local_inner_macros)] macro_rules! reshape { ( ty $ty:ty, n $n:expr, shape $shape:expr, array $array:expr ) => {{ let dim = $crate::to_typed_dims!($n, $shape, justdim); let array = match $array.is_standard_layout() { true => { match $array.to_shape(dim) { Ok(val) => val.into_shared(), Err(err) => { core::panic!("Shape should be compatible shape={dim:?}: {err:?}"); } } }, false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(), }; array.into_dyn() }}; ( ty $ty:ty, shape $shape:expr, array $array:expr, d $D:expr ) => {{ match $D { 1 => reshape!(ty $ty, n 1, shape $shape, array $array), 2 => reshape!(ty $ty, n 2, shape $shape, array $array), 3 => reshape!(ty $ty, n 3, shape $shape, array $array), 4 => reshape!(ty $ty, n 4, shape $shape, array $array), 5 => reshape!(ty $ty, n 5, shape $shape, array $array), 6 => reshape!(ty $ty, n 6, shape $shape, array $array), _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D), } }}; } /// Slice a tensor #[macro_export] macro_rules! slice { ($tensor:expr, $slices:expr) => { slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool) }; ($tensor:expr, $slices:expr, $($variant:ident),*) => { match $tensor { $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })* } }; } impl NdArrayTensor { /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData). /// /// This method attempts zero-copy loading when possible. If the data has properly /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise, /// it falls back to copying the data. /// /// Zero-copy loading works when: /// - The data's bytes are properly aligned for the element type /// - The bytes can be borrowed (e.g., from mmap'd file or static data) pub fn from_data(data: TensorData) -> NdArrayTensor { // Only use Borrowed storage for non-native allocations (e.g., burnpack mmap/file). // For native Rust heap allocations (the common case), go directly to owned storage: // `from_data_owned` reclaims the Vec zero-copy via `into_vec`, while // Borrowed storage would trigger a full memcopy on every single operation // (because `is_unique()` always returns false for Borrowed). use burn_backend::AllocationProperty; if data.bytes.property() != AllocationProperty::Native { match Self::try_from_data_borrowed(data) { Ok(tensor) => return tensor, Err(data) => return Self::from_data_owned(data), } } Self::from_data_owned(data) } /// Try to create a tensor with borrowed storage (zero-copy). /// /// Takes ownership of TensorData and returns it back on failure. /// No cloning occurs - bytes are moved into storage or returned on failure. /// /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data). fn try_from_data_borrowed(data: TensorData) -> Result { let TensorData { bytes, shape, dtype, } = data; macro_rules! try_borrow { ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => { match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) { Ok(storage) => return Ok(NdArrayTensor::$variant(storage)), Err((bytes, shape)) => (bytes, shape), } }; } // Try to create borrowed storage; get bytes back on failure let (bytes, shape) = match dtype { DType::F64 => try_borrow!(f64, F64, bytes, shape), DType::F32 => try_borrow!(f32, F32, bytes, shape), DType::I64 => try_borrow!(i64, I64, bytes, shape), DType::I32 => try_borrow!(i32, I32, bytes, shape), DType::I16 => try_borrow!(i16, I16, bytes, shape), DType::I8 => try_borrow!(i8, I8, bytes, shape), DType::U64 => try_borrow!(u64, U64, bytes, shape), DType::U32 => try_borrow!(u32, U32, bytes, shape), DType::U16 => try_borrow!(u16, U16, bytes, shape), DType::U8 => try_borrow!(u8, U8, bytes, shape), DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape), _ => (bytes, shape), // QFloat not supported for zero-copy }; Err(TensorData { bytes, shape, dtype, }) } /// Create a tensor with owned storage. /// /// This may or may not copy data depending on whether the underlying bytes /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned, /// no copy occurs; otherwise data is copied to a new allocation. fn from_data_owned(data: TensorData) -> NdArrayTensor { let shape = data.shape.to_vec(); // TODO: into_vec macro_rules! execute { ($data: expr, [$($dtype: pat => $ty: ty),*]) => { match $data.dtype { $( $dtype => { match data.into_vec::<$ty>() { Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(), Err(err) => panic!("Data should have the same element type as the tensor {err:?}"), }.into() }, )* other => unimplemented!("Unsupported dtype {other:?}"), } }; } execute!(data, [ DType::F64 => f64, DType::F32 => f32, DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8, DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8, DType::Bool(BoolStore::Native) => bool ]) } } /// A quantized tensor for the ndarray backend. #[derive(Clone, Debug)] pub struct NdArrayQTensor { /// The quantized tensor. pub qtensor: NdArrayTensor, /// The quantization scheme. pub scheme: QuantScheme, /// The quantization parameters. pub qparams: Vec>, } impl NdArrayQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { QuantScheme { level: QuantLevel::Tensor, mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 | QuantValue::Q2F | QuantValue::Q2S, .. } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init( self.qparams[0].scales, self.scheme.value, )), QuantScheme { level: QuantLevel::Block(block_size), mode: QuantMode::Symmetric, value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 | QuantValue::Q2F | QuantValue::Q2S, .. } => QuantizationStrategy::PerBlockSymmetric( self.qparams .iter() .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value)) .collect(), block_size, ), } } } impl QTensorPrimitive for NdArrayQTensor { fn scheme(&self) -> &QuantScheme { &self.scheme } fn default_scheme() -> QuantScheme { QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native) } } impl TensorMetadata for NdArrayQTensor { fn dtype(&self) -> DType { DType::QFloat(self.scheme) } fn shape(&self) -> Shape { self.qtensor.shape() } fn rank(&self) -> usize { self.shape().num_dims() } } #[cfg(test)] mod tests { use crate::NdArray; use alloc::vec; use super::*; use burn_backend::{ Distribution, ops::{FloatTensorOps, QTensorOps}, quantization::{QuantStore, QuantizationParametersPrimitive}, }; use burn_std::rand::get_seeded_rng; #[test] fn should_support_into_and_from_data_1d() { let data_expected = TensorData::random::( Shape::new([3]), Distribution::Default, &mut get_seeded_rng(), ); let tensor = NdArrayTensor::from_data(data_expected.clone()); let data_actual = tensor.into_data(); assert_eq!(data_expected, data_actual); } #[test] fn should_support_into_and_from_data_2d() { let data_expected = TensorData::random::( Shape::new([2, 3]), Distribution::Default, &mut get_seeded_rng(), ); let tensor = NdArrayTensor::from_data(data_expected.clone()); let data_actual = tensor.into_data(); assert_eq!(data_expected, data_actual); } #[test] fn should_support_into_and_from_data_3d() { let data_expected = TensorData::random::( Shape::new([2, 3, 4]), Distribution::Default, &mut get_seeded_rng(), ); let tensor = NdArrayTensor::from_data(data_expected.clone()); let data_actual = tensor.into_data(); assert_eq!(data_expected, data_actual); } #[test] fn should_support_into_and_from_data_4d() { let data_expected = TensorData::random::( Shape::new([2, 3, 4, 2]), Distribution::Default, &mut get_seeded_rng(), ); let tensor = NdArrayTensor::from_data(data_expected.clone()); let data_actual = tensor.into_data(); assert_eq!(data_expected, data_actual); } #[test] fn should_support_qtensor_strategy() { type B = NdArray; let scale: f32 = 0.009_019_608; let device = Default::default(); let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device); let scheme = QuantScheme::default() .with_value(QuantValue::Q8S) .with_store(QuantStore::Native); let qparams = QuantizationParametersPrimitive { scales: B::float_from_data(TensorData::from([scale]), &device), }; let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams); assert_eq!(qtensor.scheme(), &scheme); assert_eq!( qtensor.strategy(), QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init( scale, QuantValue::Q8S )) ); } // ========================================================================== // Zero-copy integration tests // These tests verify end-to-end zero-copy behavior through NdArrayTensor. // ========================================================================== #[test] fn zero_copy_creates_borrowed_storage_for_non_native() { // Verify that from_data creates borrowed storage for non-native allocations // (e.g. burnpack mmap/file data tagged with AllocationProperty::Other or File). // Native heap allocations intentionally use Owned storage for performance. use burn_backend::AllocationProperty; use burn_std::Bytes; let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); // Tag as Other to simulate burnpack / mmap data (non-native backing storage) let non_native_bytes = Bytes::from_shared( bytes::Bytes::copy_from_slice(&*bytes), AllocationProperty::Other, ); let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32); let tensor = NdArrayTensor::from_data(tensor_data); match &tensor { NdArrayTensor::F32(storage) => { assert!( storage.is_borrowed(), "ZERO-COPY REGRESSION: from_data should create borrowed storage \ for non-native (e.g. burnpack) TensorData" ); assert!( !storage.is_unique(), "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false" ); } _ => panic!("Expected F32 tensor"), } } #[test] fn native_alloc_creates_owned_storage() { // Native heap allocations must use Owned storage so that is_unique() // returns true and ndarray can perform in-place mutations without copying. use burn_std::Bytes; let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); // AllocationProperty::Native let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32); let tensor = NdArrayTensor::from_data(tensor_data); match &tensor { NdArrayTensor::F32(storage) => { assert!( !storage.is_borrowed(), "PERF REGRESSION: from_data must NOT create borrowed storage \ for native heap allocations (is_unique() would always be false)" ); } _ => panic!("Expected F32 tensor"), } } #[test] fn zero_copy_data_integrity() { // Verify data is correctly accessible through borrowed storage use burn_std::Bytes; let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; let bytes = Bytes::from_elems(data); let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32); let tensor = NdArrayTensor::from_data(tensor_data); match &tensor { NdArrayTensor::F32(storage) => { let view = storage.view(); assert_eq!(view[[0, 0]], 1.0); assert_eq!(view[[0, 1]], 2.0); assert_eq!(view[[1, 0]], 3.0); assert_eq!(view[[1, 1]], 4.0); } _ => panic!("Expected F32 tensor"), } } #[test] fn zero_copy_fallback_when_bytes_owned() { // When TensorData owns bytes exclusively, it may use the copy path // This is expected behavior - verify it still works correctly let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]); let tensor = NdArrayTensor::from_data(data.clone()); let result = tensor.into_data(); assert_eq!(data, result, "Data should round-trip correctly"); } } ================================================ FILE: crates/burn-nn/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Neural network building blocks for the Burn deep learning framework" documentation = "https://docs.rs/burn-nn" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-nn" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-nn" version.workspace = true [lints] workspace = true [features] default = [ "std", "burn-core/default", ] doc = [ "std", # Doc features "burn-core/doc", "pretrained", ] pretrained = ["std", "burn-store/pytorch", "burn-std/network", "dirs"] # Added for some test cases that should only be run locally # (e.g., test cases with pretrained weights for gram matrix loss) test-local = [] std = [ "burn-core/std", "num-traits/std", "burn-store?/std", "burn-std?/std", ] tracing = [ "burn-core/tracing", "burn-cuda?/tracing", "burn-rocm?/tracing", "burn-tch?/tracing", "burn-wgpu?/tracing", "burn-fusion?/tracing", ] test-cuda = [ "burn-cuda/default", ] # To use cuda during testing, default uses ndarray. test-rocm = [ "burn-rocm/default", ] # To use hip during testing, default uses ndarray. test-tch = [ "burn-tch/default", ] # To use tch during testing, default uses ndarray. test-wgpu = [ "burn-wgpu/default", ] # To use wgpu during testing, default uses ndarray. test-vulkan = [ "test-wgpu", "burn-wgpu/vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. test-metal = [ "test-wgpu", "burn-wgpu/metal", ] # To use wgpu-spirv during testing, default uses ndarray. # Memory checks are disabled by default test-memory-checks = ["burn-fusion/memory-checks"] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false } num-traits = { workspace = true } # FOR TESTING burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } # For loss functions requiring pretrained models (e.g., Gram Matrix Loss) burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, default-features = false } dirs = { workspace = true, optional = true } [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" } rstest = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-nn/README.md ================================================ # Burn Neural Networks Core building blocks for Burn neural networks. ================================================ FILE: crates/burn-nn/src/activation/activation_wrapper.rs ================================================ use burn_core as burn; use crate::activation::{ Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid, HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu, Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig, Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig, }; use burn::config::Config; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// [`Activation`] Configuration. #[derive(Config, Debug)] #[non_exhaustive] pub enum ActivationConfig { /// [`Gelu`] activation layer. Gelu, /// [`Gelu`] activation layer with tanh approximation. GeluApproximate, /// [`PRelu`] activation layer. PRelu(PReluConfig), /// [`Relu`] activation layer. Relu, /// [`LeakyRelu`] activation layer. LeakyRelu(LeakyReluConfig), /// [`SwiGlu`] activation layer. SwiGlu(SwiGluConfig), /// [`Selu`] activation layer. Selu, /// [`Sigmoid`] activation layer. Sigmoid, /// [`Tanh`] activation layer. Tanh, /// [`HardSigmoid`] activation layer. HardSigmoid(HardSigmoidConfig), /// [`HardSwish`] activation layer. HardSwish, /// [`Softplus`] activation layer. Softplus(SoftplusConfig), /// [`Softsign`] activation layer. Softsign, /// [`Elu`] activation layer. Elu(EluConfig), /// [`Celu`] activation layer. Celu(CeluConfig), /// [`ThresholdedRelu`] activation layer. ThresholdedRelu(ThresholdedReluConfig), /// [`HardShrink`] activation layer. HardShrink(HardShrinkConfig), /// [`SoftShrink`] activation layer. SoftShrink(SoftShrinkConfig), /// [`Shrink`] activation layer. Shrink(ShrinkConfig), } impl From for ActivationConfig { fn from(config: PReluConfig) -> Self { Self::PRelu(config) } } impl From for ActivationConfig { fn from(config: LeakyReluConfig) -> Self { Self::LeakyRelu(config) } } impl From for ActivationConfig { fn from(config: SwiGluConfig) -> Self { Self::SwiGlu(config) } } impl From for ActivationConfig { fn from(config: HardSigmoidConfig) -> Self { Self::HardSigmoid(config) } } impl From for ActivationConfig { fn from(config: SoftplusConfig) -> Self { Self::Softplus(config) } } impl From for ActivationConfig { fn from(config: EluConfig) -> Self { Self::Elu(config) } } impl From for ActivationConfig { fn from(config: CeluConfig) -> Self { Self::Celu(config) } } impl From for ActivationConfig { fn from(config: ThresholdedReluConfig) -> Self { Self::ThresholdedRelu(config) } } impl From for ActivationConfig { fn from(config: HardShrinkConfig) -> Self { Self::HardShrink(config) } } impl From for ActivationConfig { fn from(config: SoftShrinkConfig) -> Self { Self::SoftShrink(config) } } impl From for ActivationConfig { fn from(config: ShrinkConfig) -> Self { Self::Shrink(config) } } impl ActivationConfig { /// Initialize a wrapped activation layer. pub fn init(&self, device: &B::Device) -> Activation { match self { ActivationConfig::Relu => Relu.into(), ActivationConfig::LeakyRelu(conf) => conf.init().into(), ActivationConfig::Gelu => Gelu::new().into(), ActivationConfig::GeluApproximate => Gelu::new_approximate().into(), ActivationConfig::PRelu(conf) => conf.init(device).into(), ActivationConfig::SwiGlu(conf) => conf.init(device).into(), ActivationConfig::HardSigmoid(conf) => conf.init().into(), ActivationConfig::HardSwish => HardSwish.into(), ActivationConfig::Softplus(conf) => conf.init().into(), ActivationConfig::Selu => Selu.into(), ActivationConfig::Sigmoid => Sigmoid.into(), ActivationConfig::Tanh => Tanh.into(), ActivationConfig::Softsign => Softsign.into(), ActivationConfig::Elu(conf) => conf.init().into(), ActivationConfig::Celu(conf) => conf.init().into(), ActivationConfig::HardShrink(conf) => conf.init().into(), ActivationConfig::SoftShrink(conf) => conf.init().into(), ActivationConfig::Shrink(conf) => conf.init().into(), ActivationConfig::ThresholdedRelu(conf) => conf.init().into(), } } } /// Activation Layer Wrapper. /// /// Provides support for many in-built `burn::nn` activations. #[derive(Module, Debug)] #[non_exhaustive] #[allow(clippy::large_enum_variant)] pub enum Activation { /// [`Gelu`] activation layer. Gelu(Gelu), /// [`PRelu`] activation layer. PRelu(PRelu), /// [`Relu`] activation layer. Relu(Relu), /// [`LeakyRelu`] activation layer. LeakyRelu(LeakyRelu), /// [`SwiGlu`] activation layer. SwiGlu(SwiGlu), /// [`Selu`] activation layer. Selu(Selu), /// [`Sigmoid`] activation layer. Sigmoid(Sigmoid), /// [`Tanh`] activation layer. Tanh(Tanh), /// [`HardSigmoid`] activation layer. HardSigmoid(HardSigmoid), /// [`HardSwish`] activation layer. HardSwish(HardSwish), /// [`Softplus`] activation layer. Softplus(Softplus), /// [`Softsign`] activation layer. Softsign(Softsign), /// [`Elu`] activation layer. Elu(Elu), /// [`Celu`] activation layer. Celu(Celu), /// [`ThresholdedRelu`] activation layer. ThresholdedRelu(ThresholdedRelu), /// [`HardShrink`] activation layer. HardShrink(HardShrink), /// [`SoftShrink`] activation layer. SoftShrink(SoftShrink), /// [`Shrink`] activation layer. Shrink(Shrink), } impl From for Activation { fn from(layer: Gelu) -> Self { Self::Gelu(layer) } } impl From> for Activation { fn from(layer: PRelu) -> Self { Self::PRelu(layer) } } impl From for Activation { fn from(layer: Relu) -> Self { Self::Relu(layer) } } impl From for Activation { fn from(layer: LeakyRelu) -> Self { Self::LeakyRelu(layer) } } impl From> for Activation { fn from(layer: SwiGlu) -> Self { Self::SwiGlu(layer) } } impl From for Activation { fn from(layer: Selu) -> Self { Self::Selu(layer) } } impl From for Activation { fn from(layer: Sigmoid) -> Self { Self::Sigmoid(layer) } } impl From for Activation { fn from(layer: Tanh) -> Self { Self::Tanh(layer) } } impl From for Activation { fn from(layer: HardSigmoid) -> Self { Self::HardSigmoid(layer) } } impl From for Activation { fn from(layer: HardSwish) -> Self { Self::HardSwish(layer) } } impl From for Activation { fn from(layer: Softplus) -> Self { Self::Softplus(layer) } } impl From for Activation { fn from(layer: Softsign) -> Self { Self::Softsign(layer) } } impl From for Activation { fn from(layer: Elu) -> Self { Self::Elu(layer) } } impl From for Activation { fn from(layer: Celu) -> Self { Self::Celu(layer) } } impl From for Activation { fn from(layer: ThresholdedRelu) -> Self { Self::ThresholdedRelu(layer) } } impl From for Activation { fn from(layer: HardShrink) -> Self { Self::HardShrink(layer) } } impl From for Activation { fn from(layer: SoftShrink) -> Self { Self::SoftShrink(layer) } } impl From for Activation { fn from(layer: Shrink) -> Self { Self::Shrink(layer) } } impl Activation { /// Forward pass. pub fn forward(&self, input: Tensor) -> Tensor { match self { Activation::Relu(layer) => layer.forward(input), Activation::LeakyRelu(layer) => layer.forward(input), Activation::Gelu(layer) => layer.forward(input), Activation::PRelu(layer) => layer.forward(input), Activation::SwiGlu(layer) => layer.forward(input), Activation::HardSigmoid(layer) => layer.forward(input), Activation::HardSwish(layer) => layer.forward(input), Activation::Softplus(layer) => layer.forward(input), Activation::Selu(layer) => layer.forward(input), Activation::Sigmoid(layer) => layer.forward(input), Activation::Tanh(layer) => layer.forward(input), Activation::Softsign(layer) => layer.forward(input), Activation::Elu(layer) => layer.forward(input), Activation::Celu(layer) => layer.forward(input), Activation::ThresholdedRelu(layer) => layer.forward(input), Activation::HardShrink(layer) => layer.forward(input), Activation::SoftShrink(layer) => layer.forward(input), Activation::Shrink(layer) => layer.forward(input), } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::module::Module; fn make_input(device: &B::Device) -> Tensor { Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device) } fn expect_tensor(actual: Tensor, expected: Tensor) { actual.to_data().assert_eq(&expected.to_data(), true); } fn check_stateless_config_output( config: ActivationConfig, input: Tensor, expected: Tensor, device: &B::Device, ) { let act = config.init(device); let output = act.forward(input); expect_tensor(output, expected); } #[test] fn test_gelu() { let device = Default::default(); let input = make_input::(&device); let expected = Gelu::new().forward(input.clone()); check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device) } #[test] fn test_gelu_approximate() { let device = Default::default(); let input = make_input::(&device); let expected = Gelu::new_approximate().forward(input.clone()); check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device) } #[test] fn test_prelu() { let device = Default::default(); let input = make_input::(&device); let inner_config = PReluConfig::new(); let expected = inner_config.init(&device).forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_relu() { let device = Default::default(); let input = make_input::(&device); let expected = Relu.forward(input.clone()); check_stateless_config_output(ActivationConfig::Relu, input, expected, &device) } #[test] fn test_leaky_relu() { let device = Default::default(); let input = make_input::(&device); let inner_config = LeakyReluConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_swi_glu() { let device = Default::default(); let input = make_input::(&device); let d_input = input.shape()[1]; let d_output = 2 * d_input; let inner_config = SwiGluConfig::new(d_input, d_output); let mut reference: SwiGlu = inner_config.init(&device); let config: ActivationConfig = inner_config.into(); let layer = config.init(&device); match &layer { Activation::SwiGlu(inner) => { // Clone the initialized weights. let state = inner.clone().into_record(); reference = reference.load_record(state); } _ => unreachable!(), }; expect_tensor( layer.forward(input.clone()), reference.forward(input.clone()), ) } #[test] fn test_selu() { let device = Default::default(); let input = make_input::(&device); let expected = Selu.forward(input.clone()); check_stateless_config_output(ActivationConfig::Selu, input, expected, &device) } #[test] fn test_sigmoid() { let device = Default::default(); let input = make_input::(&device); let expected = Sigmoid.forward(input.clone()); check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device) } #[test] fn test_tanh() { let device = Default::default(); let input = make_input::(&device); let expected = Tanh.forward(input.clone()); check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device) } #[test] fn test_hard_sigmoid() { let device = Default::default(); let input = make_input::(&device); let inner_config = HardSigmoidConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_softsign() { let device = Default::default(); let input = make_input::(&device); let expected = Softsign.forward(input.clone()); check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device) } #[test] fn test_elu() { let device = Default::default(); let input = make_input::(&device); let inner_config = EluConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_softplus() { let device = Default::default(); let input = make_input::(&device); let inner_config = SoftplusConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_celu() { let device = Default::default(); let input = make_input::(&device); let inner_config = CeluConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_thresholded_relu() { let device = Default::default(); let input = make_input::(&device); let inner_config = ThresholdedReluConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_hard_shrink() { let device = Default::default(); let input = make_input::(&device); let inner_config = HardShrinkConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_soft_shrink() { let device = Default::default(); let input = make_input::(&device); let inner_config = SoftShrinkConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } #[test] fn test_shrink() { let device = Default::default(); let input = make_input::(&device); let inner_config = ShrinkConfig::new(); let expected = inner_config.init().forward(input.clone()); check_stateless_config_output(inner_config.into(), input, expected, &device) } } ================================================ FILE: crates/burn-nn/src/activation/celu.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::celu; use burn::tensor::backend::Backend; /// CELU (Continuously Differentiable Exponential Linear Unit) layer. /// /// Applies the CELU function element-wise: /// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))` /// /// Should be created with [CeluConfig](CeluConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Celu { /// The alpha value for the CELU formulation. pub alpha: f64, } /// Configuration to create a [Celu](Celu) layer using the [init function](CeluConfig::init). #[derive(Config, Debug)] pub struct CeluConfig { /// The alpha value for the CELU formulation. Default is 1.0 #[config(default = "1.0")] pub alpha: f64, } impl CeluConfig { /// Initialize a new [Celu](Celu) Layer pub fn init(&self) -> Celu { Celu { alpha: self.alpha } } } impl ModuleDisplay for Celu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("alpha", &self.alpha).optional() } } impl Celu { /// Forward pass for the Celu layer. /// /// See [celu](burn::tensor::activation::celu) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { celu(input, self.alpha) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_celu_forward() { let device = ::Device::default(); let model: Celu = CeluConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.5, -0.5, -1.0]]), &device); let out = model.forward(input); // celu(0.5, 1) = 0.5 // celu(-0.5, 1) = 1 * (exp(-0.5) - 1) = -0.393469 // celu(-1.0, 1) = 1 * (exp(-1) - 1) = -0.632121 let expected = TensorData::from([[0.5, -0.393469, -0.632121]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_celu_with_alpha() { let device = ::Device::default(); let model: Celu = CeluConfig::new().with_alpha(2.0).init(); let input = Tensor::::from_data(TensorData::from([[0.0, -2.0]]), &device); let out = model.forward(input); // celu(0, 2) = 0 // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241 let expected = TensorData::from([[0.0, -1.264241]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = CeluConfig::new().init(); assert_eq!(alloc::format!("{config}"), "Celu {alpha: 1}"); } } ================================================ FILE: crates/burn-nn/src/activation/elu.rs ================================================ use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_core as burn; use burn::tensor::activation::elu; /// ELU (Exponential Linear Unit) layer. /// /// Should be created with [EluConfig](EluConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Elu { /// The alpha value. pub alpha: f64, } /// Configuration to create an [Elu](Elu) layer using the [init function](EluConfig::init). #[derive(Config, Debug)] pub struct EluConfig { /// The alpha value. Default is 1.0 #[config(default = "1.0")] pub alpha: f64, } impl EluConfig { /// Initialize a new [Elu](Elu) Layer pub fn init(&self) -> Elu { Elu { alpha: self.alpha } } } impl ModuleDisplay for Elu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("alpha", &self.alpha).optional() } } impl Elu { /// Forward pass for the ELU layer. /// /// See [elu](burn::tensor::activation::elu) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { elu(input, self.alpha) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_elu_forward() { let device = ::Device::default(); let model: Elu = EluConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.4410, -0.2507]]), &device); let out = model.forward(input); // elu(0.4410, 1.0) = 0.4410 // elu(-0.2507, 1.0) = 1.0 * (exp(-0.2507) - 1) = -0.22186 let expected = TensorData::from([[0.4410, -0.22186]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = EluConfig::new().init(); assert_eq!(alloc::format!("{config}"), "Elu {alpha: 1}"); } } ================================================ FILE: crates/burn-nn/src/activation/gelu.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the Gaussian Error Linear Units function element-wise. /// /// See also [gelu](burn::tensor::activation::gelu) /// /// When `approximate` is true, uses the tanh approximation: /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` #[derive(Module, Clone, Debug, Default)] pub struct Gelu { /// Whether to use tanh approximation. pub approximate: bool, } impl Gelu { /// Create the module with exact GELU. pub fn new() -> Self { Self::default() } /// Create the module with tanh approximation. pub fn new_approximate() -> Self { Self { approximate: true } } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { if self.approximate { burn::tensor::activation::gelu_approximate(input) } else { burn::tensor::activation::gelu(input) } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::Tolerance; use burn::tensor::ops::FloatElem; type FT = FloatElem; #[test] fn display() { let layer = Gelu::new(); assert_eq!(alloc::format!("{layer}"), "Gelu {\n approximate: false\n}"); } #[test] fn forward_approximate() { let device = Default::default(); let input = Tensor::::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device); let output = Gelu::new_approximate().forward(input); // PyTorch: torch.nn.functional.gelu(x, approximate="tanh") let expected = Tensor::::from_data( [ [-0.1588079929, 0.0000000000, 0.8411920071], [0.3457140028, -0.1542859972, 1.9545977116], ], &device, ); output .into_data() .assert_approx_eq::(&expected.into_data(), Tolerance::rel_abs(1e-5, 1e-5)); } } ================================================ FILE: crates/burn-nn/src/activation/glu.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the gated linear unit function. /// /// See also [glu](burn::tensor::activation::glu) #[derive(Module, Clone, Debug, Default)] pub struct GLU { dim: usize, } impl GLU { /// Create the module. /// /// # Arguments /// * `dim` - The dimension on which to split the input. pub fn new(dim: usize) -> Self { Self { dim } } /// Applies the gated linear unit function. /// /// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half. /// /// **Note**: /// * The size of the input tensor along `dim` must be divisible by 2. /// /// ### Arguments /// * `tensor` - The input tensor. /// /// ### Returns /// * A tensor with the same shape as the input, except the size along `dim` is halved. pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::glu(input, self.dim) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = GLU::new(1); assert_eq!(alloc::format!("{layer}"), "GLU {\n dim: 1\n}"); } } ================================================ FILE: crates/burn-nn/src/activation/hard_shrink.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::hard_shrink; use burn::tensor::backend::Backend; /// Hard Shrink layer. /// /// Applies the Hard Shrink function element-wise: /// `hard_shrink(x) = x if |x| > lambda else 0` /// /// Should be created with [HardShrinkConfig](HardShrinkConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct HardShrink { /// The lambda value for the Hard Shrink formulation. pub lambda: f64, } /// Configuration to create a [HardShrink](HardShrink) layer using the [init function](HardShrinkConfig::init). #[derive(Config, Debug)] pub struct HardShrinkConfig { /// The lambda value for the Hard Shrink formulation. Default is 0.5 #[config(default = "0.5")] pub lambda: f64, } impl HardShrinkConfig { /// Initialize a new [HardShrink](HardShrink) Layer pub fn init(&self) -> HardShrink { HardShrink { lambda: self.lambda, } } } impl ModuleDisplay for HardShrink { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("lambda", &self.lambda).optional() } } impl HardShrink { /// Forward pass for the Hard Shrink layer. /// /// See [hard_shrink](burn::tensor::activation::hard_shrink) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { hard_shrink(input, self.lambda) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn test_hard_shrink_forward() { let device = ::Device::default(); let model: HardShrink = HardShrinkConfig::new().init(); let input = Tensor::::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -1.0], [8.0, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn test_hard_shrink_with_lambda() { let device = ::Device::default(); let model: HardShrink = HardShrinkConfig::new().with_lambda(0.2).init(); let input = Tensor::::from_data([[0.1, -0.1, -0.3], [0.5, 0.1, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -0.3], [0.5, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn display() { let config = HardShrinkConfig::new().init(); assert_eq!(alloc::format!("{config}"), "HardShrink {lambda: 0.5}"); } } ================================================ FILE: crates/burn-nn/src/activation/hard_sigmoid.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::hard_sigmoid; use burn::tensor::backend::Backend; /// Hard Sigmoid layer. /// /// Should be created with [HardSigmoidConfig](HardSigmoidConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct HardSigmoid { /// The alpha value. pub alpha: f64, /// The beta value. pub beta: f64, } /// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init). #[derive(Config, Debug)] pub struct HardSigmoidConfig { /// The alpha value. Default is 0.2 #[config(default = "0.2")] pub alpha: f64, /// The beta value. Default is 0.5 #[config(default = "0.5")] pub beta: f64, } impl HardSigmoidConfig { /// Initialize a new [Hard Sigmoid](HardSigmoid) Layer pub fn init(&self) -> HardSigmoid { HardSigmoid { alpha: self.alpha, beta: self.beta, } } } impl ModuleDisplay for HardSigmoid { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("alpha", &self.alpha) .add("beta", &self.beta) .optional() } } impl HardSigmoid { /// Forward pass for the Hard Sigmoid layer. /// /// See [hard_sigmoid](burn::tensor::activation::hard_sigmoid) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { hard_sigmoid(input, self.alpha, self.beta) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_hard_sigmoid_forward() { let device = ::Device::default(); let model: HardSigmoid = HardSigmoidConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.4410, -0.2507]]), &device); let out = model.forward(input); let expected = TensorData::from([[0.5882, 0.44986]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = HardSigmoidConfig::new().init(); assert_eq!( alloc::format!("{config}"), "HardSigmoid {alpha: 0.2, beta: 0.5}" ); } } ================================================ FILE: crates/burn-nn/src/activation/hard_swish.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::activation::hard_swish; use burn::tensor::backend::Backend; /// Hard Swish layer. #[derive(Module, Clone, Debug, Default)] pub struct HardSwish; impl HardSwish { /// Create the module. pub fn new() -> Self { Self } /// Forward pass for the Hard Swish layer. /// /// See [hard_swish](burn::tensor::activation::hard_swish) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { hard_swish(input) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_hard_swish_forward() { let device = ::Device::default(); let model = HardSwish::new(); let input = Tensor::::from_data( TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]), &device, ); let out = model.forward(input); let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let layer = HardSwish::new(); assert_eq!(alloc::format!("{layer}"), "HardSwish"); } } ================================================ FILE: crates/burn-nn/src/activation/leaky_relu.rs ================================================ use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_core as burn; use burn::tensor::activation::leaky_relu; /// Leaky ReLu layer. /// /// Should be created with [LeakyReluConfig](LeakyReluConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct LeakyRelu { /// The negative slope. pub negative_slope: f64, } /// Configuration to create a [Leaky Relu](LeakyRelu) layer using the [init function](LeakyReluConfig::init). #[derive(Config, Debug)] pub struct LeakyReluConfig { /// The negative slope. Default is 0.01 #[config(default = "0.01")] pub negative_slope: f64, } impl LeakyReluConfig { /// Initialize a new [Leaky Relu](LeakyRelu) Layer pub fn init(&self) -> LeakyRelu { LeakyRelu { negative_slope: self.negative_slope, } } } impl ModuleDisplay for LeakyRelu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("negative_slope", &self.negative_slope) .optional() } } impl LeakyRelu { /// Forward pass for the Leaky ReLu layer. /// /// See [leaky_relu](burn::tensor::activation::leaky_relu) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { leaky_relu(input, self.negative_slope) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_leaky_relu_forward() { let device = ::Device::default(); let model: LeakyRelu = LeakyReluConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.4410, -0.2507]]), &device); let out = model.forward(input); let expected = TensorData::from([[0.4410, -0.002507]]); out.to_data().assert_eq(&expected, false); } #[test] fn test_leaky_relu_forward_multi_dim() { let input = [ [ [-1.0222, 1.5810, 0.3457, -1.3530], [0.0231, 0.8681, 0.2473, -0.0377], [0.3520, -1.1199, 1.2219, 0.2804], ], [ [1.0002, 0.7259, 0.8779, 0.2084], [1.5615, -0.1057, -0.4886, -1.5184], [-0.5523, -0.2741, -0.0210, -1.1352], ], ]; let expected = TensorData::from([ [ [-1.0222e-02, 1.5810e+00, 3.457e-01, -1.3530e-02], [2.31e-02, 8.681e-01, 2.473e-01, -3.77e-04], [3.52e-01, -1.1199e-02, 1.2219e+00, 2.804e-01], ], [ [1.0002e+00, 7.259e-01, 8.779e-01, 2.084e-01], [1.5615e+00, -1.057e-03, -4.886e-03, -1.5184e-02], [-5.523e-03, -2.741e-03, -2.1e-04, -1.1352e-02], ], ]); let device = ::Device::default(); let model: LeakyRelu = LeakyReluConfig::new().init(); let input_data = Tensor::::from_data(TensorData::from(input), &device); let actual_output = model.forward(input_data); actual_output .to_data() .assert_approx_eq::(&expected, Tolerance::default()) } #[test] fn display() { let config = LeakyReluConfig::new().init(); assert_eq!( alloc::format!("{config}"), "LeakyRelu {negative_slope: 0.01}" ); } } ================================================ FILE: crates/burn-nn/src/activation/mod.rs ================================================ //! # Activation Layers //! //! Users who desire a selectable activation function should //! consider [`Activation`], which provides an abstraction over: //! * [`Relu`] - the default, //! * ['PRelu'] //! * [`Gelu`] //! * [`LeakyRelu`] //! * [`SwiGlu`] //! * [`Selu`] //! * [`Sigmoid`] //! * [`HardSigmoid`] //! * [`HardSwish`] //! * [`Softplus`] //! * [`Softsign`] //! * [`Tanh`] //! * [`Elu`] //! * [`Celu`] //! * [`ThresholdedRelu`] //! //! The activation layer [`GLU`] has shape-changing behaviors //! not compatible with the common API, and is not included //! in the abstraction wrappers. mod activation_wrapper; // These are pub(crate) for dual-export in `nn` without re-exporting // all of `nn.activation`, or manually listing each symbol. pub(crate) mod celu; pub(crate) mod elu; pub(crate) mod gelu; pub(crate) mod glu; pub(crate) mod hard_shrink; pub(crate) mod hard_sigmoid; pub(crate) mod hard_swish; pub(crate) mod leaky_relu; pub(crate) mod prelu; pub(crate) mod relu; pub(crate) mod selu; pub(crate) mod shrink; pub(crate) mod sigmoid; pub(crate) mod soft_shrink; pub(crate) mod softplus; pub(crate) mod softsign; pub(crate) mod swiglu; pub(crate) mod tanh; pub(crate) mod thresholded_relu; pub use activation_wrapper::*; pub use celu::*; pub use elu::*; pub use gelu::*; pub use glu::*; pub use hard_shrink::*; pub use hard_sigmoid::*; pub use hard_swish::*; pub use leaky_relu::*; pub use prelu::*; pub use relu::*; pub use selu::*; pub use shrink::*; pub use sigmoid::*; pub use soft_shrink::*; pub use softplus::*; pub use softsign::*; pub use swiglu::*; pub use tanh::*; pub use thresholded_relu::*; ================================================ FILE: crates/burn-nn/src/activation/prelu.rs ================================================ use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_core as burn; /// Parametric Relu layer. /// /// Should be created using [PReluConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct PRelu { /// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must /// be the same as number of channels in the input tensor pub alpha: Param>, /// Alpha value for the PRelu layer pub alpha_value: f64, } impl ModuleDisplay for PRelu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [num_parameters] = self.alpha.shape().dims(); content .add("num_parameters", &num_parameters) .add("alpha_value", &self.alpha_value) .optional() } } /// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init). #[derive(Config, Debug)] pub struct PReluConfig { /// The number of parameters. #[config(default = "1")] pub num_parameters: usize, /// The learnable weight alpha. Default is 0.25 #[config(default = "0.25")] pub alpha: f64, } impl PReluConfig { /// Initialize a new [Parametric Relu](PRelu) Layer pub fn init(&self, device: &B::Device) -> PRelu { PRelu { // alpha is a tensor of length num_parameters alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device), alpha_value: self.alpha, } } } impl PRelu { /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` /// /// See also [prelu](burn::tensor::activation::prelu) for more information. pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::prelu(input, self.alpha.val()) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn display() { let layer = PReluConfig::new().init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}" ); } } ================================================ FILE: crates/burn-nn/src/activation/relu.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the rectified linear unit function element-wise /// See also [relu](burn::tensor::activation::relu) /// #[derive(Module, Clone, Debug, Default)] pub struct Relu; impl Relu { /// Create the module. pub fn new() -> Self { Self {} } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::relu(input) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = Relu::new(); assert_eq!(alloc::format!("{layer}"), "Relu"); } } ================================================ FILE: crates/burn-nn/src/activation/selu.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the Scaled Exponential Linear Unit function element-wise. /// See also [selu](burn::tensor::activation::selu) #[derive(Module, Clone, Debug, Default)] pub struct Selu; impl Selu { /// Create the module. pub fn new() -> Self { Self {} } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::selu(input) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = Selu::new(); assert_eq!(alloc::format!("{layer}"), "Selu"); } } ================================================ FILE: crates/burn-nn/src/activation/shrink.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::shrink; use burn::tensor::backend::Backend; /// Shrink layer. /// /// Applies the Shrink function element-wise: /// `shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise` /// /// Should be created with [ShrinkConfig](ShrinkConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Shrink { /// The lambda value for the Shrink formulation. pub lambda: f64, /// The bias value for the Shrink formulation. // Usually bias = lambda, but need this to handle onnx spec https://onnx.ai/onnx/operators/onnx__Shrink.html pub bias: f64, } /// Configuration to create a [Shrink](Shrink) layer using the [init function](ShrinkConfig::init). #[derive(Config, Debug)] pub struct ShrinkConfig { /// The lambda value for the Shrink formulation. Default is 0.5 #[config(default = "0.5")] pub lambda: f64, /// The bias value for the Shrink formulation. Default is 0.5. #[config(default = "0.5")] pub bias: f64, } impl ShrinkConfig { /// Initialize a new [Shrink](Shrink) Layer pub fn init(&self) -> Shrink { Shrink { lambda: self.lambda, bias: self.bias, } } } impl ModuleDisplay for Shrink { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("lambda", &self.lambda) .add("bias", &self.bias) .optional() } } impl Shrink { /// Forward pass for the Shrink layer. /// /// See [shrink](burn::tensor::activation::shrink) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { shrink(input, self.lambda, self.bias) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn test_shrink_forward() { let device = ::Device::default(); let model: Shrink = ShrinkConfig::new().init(); let input = Tensor::::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn test_shrink_with_lambda_and_bias() { let device = ::Device::default(); let model: Shrink = ShrinkConfig::new() .with_lambda(0.25) .with_bias(0.125) .init(); let input = Tensor::::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -0.375], [0.625, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn display() { let config = ShrinkConfig::new().init(); assert_eq!( alloc::format!("{config}"), "Shrink {lambda: 0.5, bias: 0.5}" ); } } ================================================ FILE: crates/burn-nn/src/activation/sigmoid.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the sigmoid function element-wise /// See also [sigmoid](burn::tensor::activation::sigmoid) #[derive(Module, Clone, Debug, Default)] pub struct Sigmoid; impl Sigmoid { /// Create the module. pub fn new() -> Self { Self {} } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::sigmoid(input) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = Sigmoid::new(); assert_eq!(alloc::format!("{layer}"), "Sigmoid"); } } ================================================ FILE: crates/burn-nn/src/activation/soft_shrink.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::soft_shrink; use burn::tensor::backend::Backend; /// Soft Shrink layer. /// /// Applies the Soft Shrink function element-wise: /// `soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise` /// /// Should be created with [SoftShrinkConfig](SoftShrinkConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct SoftShrink { /// The lambda value for the Soft Shrink formulation. pub lambda: f64, } /// Configuration to create a [SoftShrink](SoftShrink) layer using the [init function](SoftShrinkConfig::init). #[derive(Config, Debug)] pub struct SoftShrinkConfig { /// The lambda value for the Soft Shrink formulation. Default is 0.5 #[config(default = "0.5")] pub lambda: f64, } impl SoftShrinkConfig { /// Initialize a new [SoftShrink](SoftShrink) Layer pub fn init(&self) -> SoftShrink { SoftShrink { lambda: self.lambda, } } } impl ModuleDisplay for SoftShrink { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("lambda", &self.lambda).optional() } } impl SoftShrink { /// Forward pass for the Soft Shrink layer. /// /// See [soft_shrink](burn::tensor::activation::soft_shrink) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { soft_shrink(input, self.lambda) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn test_soft_shrink_forward() { let device = ::Device::default(); let model: SoftShrink = SoftShrinkConfig::new().init(); let input = Tensor::::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn test_soft_shrink_with_lambda() { let device = ::Device::default(); let model: SoftShrink = SoftShrinkConfig::new().with_lambda(0.25).init(); let input = Tensor::::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device); let out = model.forward(input); let expected = TensorData::from([[0.0_f32, 0.0, -0.25], [0.5, 0.0, 0.0]]); assert_eq!(out.into_data(), expected); } #[test] fn display() { let config = SoftShrinkConfig::new().init(); assert_eq!(alloc::format!("{config}"), "SoftShrink {lambda: 0.5}"); } } ================================================ FILE: crates/burn-nn/src/activation/softplus.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::activation::softplus; use burn::tensor::backend::Backend; /// Softplus layer. /// /// Applies the softplus function element-wise: /// `softplus(x) = (1/beta) * log(1 + exp(beta * x))` /// /// Should be created with [SoftplusConfig](SoftplusConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Softplus { /// The beta value. pub beta: f64, } /// Configuration to create a [Softplus](Softplus) layer using the [init function](SoftplusConfig::init). #[derive(Config, Debug)] pub struct SoftplusConfig { /// The beta value. Default is 1.0 #[config(default = "1.0")] pub beta: f64, } impl SoftplusConfig { /// Initialize a new [Softplus](Softplus) Layer pub fn init(&self) -> Softplus { Softplus { beta: self.beta } } } impl ModuleDisplay for Softplus { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("beta", &self.beta).optional() } } impl Softplus { /// Forward pass for the Softplus layer. /// /// See [softplus](burn::tensor::activation::softplus) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { softplus(input, self.beta) } } #[cfg(test)] #[allow(clippy::approx_constant)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_softplus_forward() { let device = ::Device::default(); let model: Softplus = SoftplusConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.0, 1.0, -1.0]]), &device); let out = model.forward(input); // softplus(0) = log(2) ≈ 0.6931 // softplus(1) = log(1 + e) ≈ 1.3133 // softplus(-1) = log(1 + e^-1) ≈ 0.3133 let expected = TensorData::from([[0.6931, 1.3133, 0.3133]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_softplus_with_beta() { let device = ::Device::default(); let model: Softplus = SoftplusConfig::new().with_beta(2.0).init(); let input = Tensor::::from_data(TensorData::from([[0.0, 1.0]]), &device); let out = model.forward(input); // softplus(0, beta=2) = (1/2) * log(1 + exp(0)) = 0.5 * log(2) ≈ 0.3466 // softplus(1, beta=2) = (1/2) * log(1 + exp(2)) = 0.5 * log(8.389) ≈ 1.0635 let expected = TensorData::from([[0.3466, 1.0635]]); out.to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = SoftplusConfig::new().init(); assert_eq!(alloc::format!("{config}"), "Softplus {beta: 1}"); } } ================================================ FILE: crates/burn-nn/src/activation/softsign.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the softsign function element-wise /// See also [softsign](burn::tensor::activation::softsign) #[derive(Module, Clone, Debug, Default)] pub struct Softsign; impl Softsign { /// Create the module. pub fn new() -> Self { Self {} } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::softsign(input) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = Softsign::new(); assert_eq!(alloc::format!("{layer}"), "Softsign"); } } ================================================ FILE: crates/burn-nn/src/activation/swiglu.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::activation::silu; use burn::tensor::{Tensor, backend::Backend}; use crate::{Linear, LinearConfig, LinearLayout}; /// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init). #[derive(Config, Debug)] pub struct SwiGluConfig { /// The size of the input features. pub d_input: usize, /// The size of the output features. pub d_output: usize, /// If a bias should be applied during the linear transformation. Default behaviour is False /// for SwiGLU activation implementations. #[config(default = false)] pub bias: bool, /// The type of function used to initialize the linear layer parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, /// The layout in which the linear parameters are stored. #[config(default = "LinearLayout::Row")] pub layout: LinearLayout, } /// Applies the SwiGLU or Swish Gated Linear Unit to the input tensor. /// The SwiGLU activation function is defined as: /// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)` /// /// Should be created with [SwiGluConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct SwiGlu { /// The inner linear layer for Swish activation function /// with `d_input` input features and `d_output` output features. pub linear_inner: Linear, /// The outer linear layer for element wise multiplication /// with `d_input` input features and `d_output` output features. pub linear_outer: Linear, } impl ModuleDisplay for SwiGlu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, d_output] = self.linear_inner.weight.shape().dims(); content .add("d_input", &d_input) .add("d_output", &d_output) .add("bias", &self.linear_inner.bias.is_some()) .optional() } } impl SwiGluConfig { /// Initialize a new [SwiGLU](SwiGlu) activation layer. pub fn init(&self, device: &B::Device) -> SwiGlu { SwiGlu { linear_inner: LinearConfig::new(self.d_input, self.d_output) .with_bias(self.bias) .with_initializer(self.initializer.clone()) .with_layout(self.layout) .init(device), linear_outer: LinearConfig::new(self.d_input, self.d_output) .with_bias(self.bias) .with_initializer(self.initializer.clone()) .with_layout(self.layout) .init(device), } } } impl SwiGlu { /// Applies the Swish Gated Linear Unit to the input tensor. /// /// # Shapes /// /// - input: `[batch_size, seq_length, d_input]` /// - output: `[batch_size, seq_length, d_output]` pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input.clone()); let x = silu(x); x.mul(self.linear_outer.forward(input)) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_swiglu_forward_no_bias() { let device = Default::default(); TestBackend::seed(&device, 0); let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 }); let swiglu = config.init(&device); let input = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let output = swiglu.forward(input); let expected_output = Tensor::::from_data( [[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]], &device, ); output .to_data() .assert_approx_eq::(&expected_output.to_data(), Tolerance::default()); } #[test] fn test_swiglu_forward_with_bias() { let device = Default::default(); TestBackend::seed(&device, 0); let config = SwiGluConfig::new(3, 3) .with_bias(true) .with_initializer(Initializer::Constant { value: 0.5 }); let swiglu = config.init(&device); let input = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); let output = swiglu.forward(input); let expected_output = Tensor::::from_data( [[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]], &device, ); output .to_data() .assert_approx_eq::(&expected_output.to_data(), Tolerance::default()); } #[test] fn display() { let config = SwiGluConfig::new(3, 5); let swiglu = config.init::(&Default::default()); assert_eq!( alloc::format!("{swiglu}"), "SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}" ); } } ================================================ FILE: crates/burn-nn/src/activation/tanh.rs ================================================ use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Applies the tanh activation function element-wise /// See also [tanh](burn::tensor::activation::tanh) #[derive(Module, Clone, Debug, Default)] pub struct Tanh; impl Tanh { /// Create the module. pub fn new() -> Self { Self {} } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { burn::tensor::activation::tanh(input) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let layer = Tanh::new(); assert_eq!(alloc::format!("{layer}"), "Tanh"); } } ================================================ FILE: crates/burn-nn/src/activation/thresholded_relu.rs ================================================ use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_core as burn; use burn::tensor::activation::thresholded_relu; /// Thresholded ReLU layer. /// /// Should be created with [ThresholdedReluConfig](ThresholdedReluConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct ThresholdedRelu { /// The alpha threshold. pub alpha: f64, } /// Configuration to create a [ThresholdedRelu](ThresholdedRelu) layer using the [init function](ThresholdedReluConfig::init). #[derive(Config, Debug)] pub struct ThresholdedReluConfig { /// The alpha threshold. Default is 1.0 #[config(default = "1.0")] pub alpha: f64, } impl ThresholdedReluConfig { /// Initialize a new [ThresholdedRelu](ThresholdedRelu) layer. pub fn init(&self) -> ThresholdedRelu { ThresholdedRelu { alpha: self.alpha } } } impl ModuleDisplay for ThresholdedRelu { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("alpha", &self.alpha).optional() } } impl ThresholdedRelu { /// Forward pass for the Thresholded ReLU layer. /// /// See [thresholded_relu](burn::tensor::activation::thresholded_relu) for more information. /// /// # Shapes /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { thresholded_relu(input, self.alpha) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn test_thresholded_relu_forward() { let device = ::Device::default(); let model: ThresholdedRelu = ThresholdedReluConfig::new().init(); let input = Tensor::::from_data(TensorData::from([[0.5, 1.5, -0.2]]), &device); let out = model.forward(input); let expected = TensorData::from([[0.0, 1.5, 0.0]]); out.to_data().assert_eq(&expected, false); } #[test] fn display() { let config = ThresholdedReluConfig::new().init(); assert_eq!(alloc::format!("{config}"), "ThresholdedRelu {alpha: 1}"); } } ================================================ FILE: crates/burn-nn/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![recursion_limit = "256"] //! Burn neural network module. /// Loss module pub mod loss; /// Neural network modules implementations. pub mod modules; pub use modules::*; pub mod activation; pub use activation::{ celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*, relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*, tanh::*, thresholded_relu::*, }; mod padding; pub use padding::*; // For backward compat, `burn::nn::Initializer` pub use burn_core::module::Initializer; extern crate alloc; /// Backend for test cases #[cfg(all( test, not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda"), not(feature = "test-rocm") ))] pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] /// Backend for test cases pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] /// Backend for test cases pub type TestBackend = burn_wgpu::Wgpu; #[cfg(all(test, feature = "test-cuda"))] /// Backend for test cases pub type TestBackend = burn_cuda::Cuda; #[cfg(all(test, feature = "test-rocm"))] /// Backend for test cases pub type TestBackend = burn_rocm::Rocm; /// Backend for autodiff test cases #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[cfg(all(test, feature = "test-memory-checks"))] mod tests { burn_fusion::memory_checks!(); } ================================================ FILE: crates/burn-nn/src/loss/binary_cross_entropy.rs ================================================ use burn_core as burn; use alloc::vec::Vec; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::activation::log_sigmoid; use burn::tensor::{Int, Tensor, backend::Backend}; use burn::{config::Config, module::Module}; /// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init). #[derive(Config, Debug)] pub struct BinaryCrossEntropyLossConfig { /// Create weighted binary cross-entropy with a weight for each class. /// /// The loss of a specific sample will simply be multiplied by its label weight. pub weights: Option>, /// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629). /// /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`. /// Alpha = 0 would be the same as default. pub smoothing: Option, /// Treat the inputs as logits, applying a sigmoid activation when computing the loss. #[config(default = false)] pub logits: bool, } impl BinaryCrossEntropyLossConfig { /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss). pub fn init(&self, device: &B::Device) -> BinaryCrossEntropyLoss { self.assertions(); BinaryCrossEntropyLoss { weights: self .weights .as_ref() .map(|e| Tensor::::from_floats(e.as_slice(), device)), smoothing: self.smoothing, logits: self.logits, } } fn assertions(&self) { if let Some(alpha) = self.smoothing { assert!( (0.0..=1.).contains(&alpha), "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}" ); }; if let Some(weights) = self.weights.as_ref() { assert!( weights.iter().all(|e| e > &0.), "Weights of cross-entropy have to be positive." ); } } } /// Calculate the binary cross entropy loss from the input logits and the targets. /// /// Should be created using [BinaryCrossEntropyLossConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct BinaryCrossEntropyLoss { /// Weights for cross-entropy. pub weights: Option>, /// Label smoothing alpha. pub smoothing: Option, /// Treat the inputs as logits pub logits: bool, } impl ModuleDisplay for BinaryCrossEntropyLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("weights", &self.weights) .add("smoothing", &self.smoothing) .add("logits", &self.logits) .optional() } } impl BinaryCrossEntropyLoss { /// Compute the criterion on the input tensor. /// /// # Shapes /// /// Binary: /// - logits: `[batch_size]` /// - targets: `[batch_size]` /// /// Multi-label: /// - logits: `[batch_size, num_classes]` /// - targets: `[batch_size, num_classes]` pub fn forward( &self, logits: Tensor, targets: Tensor, ) -> Tensor { self.assertions(&logits, &targets); let mut targets_float = targets.clone().float(); let shape = targets.dims(); if let Some(alpha) = self.smoothing { let num_classes = if D > 1 { shape[D - 1] } else { 2 }; targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32; } let mut loss = if self.logits { // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)` (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits) } else { // - (target * log(input) + (1 - target) * log(1 - input)) // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0) - targets_float * logits.log().clamp_min(-100.0) }; if let Some(weights) = &self.weights { let weights = if D > 1 { weights.clone().expand(shape) } else { // Flatten targets and expand resulting weights to make it compatible with // Tensor for binary 1-D case weights .clone() .gather(0, targets.flatten(0, 0)) .expand(shape) }; loss = loss * weights; } loss.mean() } fn assertions(&self, logits: &Tensor, targets: &Tensor) { let logits_dims = logits.dims(); let targets_dims = targets.dims(); assert!( logits_dims == targets_dims, "Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?})." ); if let Some(weights) = &self.weights && D > 1 { let targets_classes = targets_dims[D - 1]; let weights_classes = weights.dims()[0]; assert!( weights_classes == targets_classes, "The number of classes ({weights_classes}) does not match the weights provided ({targets_classes})." ); } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{TensorData, activation::sigmoid}; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_binary_cross_entropy_preds_all_correct() { let device = Default::default(); let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device); let targets = Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); let loss_actual = BinaryCrossEntropyLossConfig::new() .init(&device) .forward(preds, targets) .into_data(); let loss_expected = TensorData::from([0.000]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::default()); } #[test] fn test_binary_cross_entropy_preds_all_incorrect() { let device = Default::default(); let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device); let targets = Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); let loss_actual = BinaryCrossEntropyLossConfig::new() .init(&device) .forward(preds, targets) .into_data(); let loss_expected = TensorData::from([100.000]); // clamped value loss_actual.assert_approx_eq::(&loss_expected, Tolerance::default()); } #[test] fn test_binary_cross_entropy() { // import torch // from torch import nn // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355]) // target = torch.tensor([0., 1., 0., 1.]) // loss = nn.BCELoss() // sigmoid = nn.Sigmoid() // out = loss(sigmoid(input), target) # tensor(0.7491) let device = Default::default(); let logits = Tensor::::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device); let targets = Tensor::::from_data(TensorData::from([0, 1, 0, 1]), &device); let loss_actual = BinaryCrossEntropyLossConfig::new() .init(&device) .forward(sigmoid(logits), targets) .into_data(); let loss_expected = TensorData::from([0.7491]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::relative(1e-4)); } #[test] fn test_binary_cross_entropy_with_logits() { let device = Default::default(); let logits = Tensor::::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device); let targets = Tensor::::from_data(TensorData::from([0, 1, 0, 1]), &device); let loss_actual = BinaryCrossEntropyLossConfig::new() .with_logits(true) .init(&device) .forward(logits, targets) .into_data(); let loss_expected = TensorData::from([0.7491]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::relative(1e-4)); } #[test] fn test_binary_cross_entropy_with_weights() { // import torch // from torch import nn // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355]) // target = torch.tensor([0, 1, 0, 1]) // weights = torch.tensor([3., 7.]).gather(0, target) // loss = nn.BCELoss(weights) // sigmoid = nn.Sigmoid() // out = loss(sigmoid(input), target.float()) # tensor(3.1531) let device = Default::default(); let logits = Tensor::::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device); let targets = Tensor::::from_data(TensorData::from([0, 1, 0, 1]), &device); let weights = [3., 7.]; let loss_actual = BinaryCrossEntropyLossConfig::new() .with_weights(Some(weights.to_vec())) .init(&device) .forward(sigmoid(logits), targets) .into_data(); let loss_expected = TensorData::from([3.1531]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::relative(1e-4)); } #[test] fn test_binary_cross_entropy_with_smoothing() { // import torch // from torch import nn // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355]) // target = torch.tensor([0., 1., 0., 1.]) // target_smooth = target * (1 - 0.1) + (0.1 / 2) // loss = nn.BCELoss() // sigmoid = nn.Sigmoid() // out = loss(sigmoid(input), target_smooth) # tensor(0.7490) let device = Default::default(); let logits = Tensor::::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device); let targets = Tensor::::from_data(TensorData::from([0, 1, 0, 1]), &device); let loss_actual = BinaryCrossEntropyLossConfig::new() .with_smoothing(Some(0.1)) .init(&device) .forward(sigmoid(logits), targets) .into_data(); let loss_expected = TensorData::from([0.7490]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::relative(1e-4)); } #[test] fn test_binary_cross_entropy_multilabel() { // import torch // from torch import nn // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]]) // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]]) // weights = torch.tensor([3., 7., 0.9]) // loss = nn.BCEWithLogitsLoss() // out = loss(input, target) # tensor(0.7112) let device = Default::default(); let logits = Tensor::::from_floats( [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]], &device, ); let targets = Tensor::::from_data( TensorData::from([[1, 0, 1], [1, 0, 0]]), &device, ); let loss_actual = BinaryCrossEntropyLossConfig::new() .with_logits(true) .init(&device) .forward(logits, targets) .into_data(); let loss_expected = TensorData::from([0.7112]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::relative(1e-4)); } #[test] fn test_binary_cross_entropy_multilabel_with_weights() { // import torch // from torch import nn // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]]) // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]]) // loss = nn.BCEWithLogitsLoss() // out = loss(input, target) # tensor(3.1708) let device = Default::default(); let logits = Tensor::::from_floats( [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]], &device, ); let targets = Tensor::::from_data( TensorData::from([[1, 0, 1], [1, 0, 0]]), &device, ); let weights = [3., 7., 0.9]; let loss_actual = BinaryCrossEntropyLossConfig::new() .with_logits(true) .with_weights(Some(weights.to_vec())) .init(&device) .forward(logits, targets) .into_data(); let loss_expected = TensorData::from([3.1708]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::default()); } #[test] fn test_binary_cross_entropy_multilabel_with_smoothing() { // import torch // from torch import nn // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]]) // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]]) // target_smooth = target * (1 - 0.1) + (0.1 / 3) // loss = nn.BCELoss() // sigmoid = nn.Sigmoid() // out = loss(sigmoid(input), target_smooth) # tensor(0.7228) let device = Default::default(); let logits = Tensor::::from_floats( [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]], &device, ); let targets = Tensor::::from_data( TensorData::from([[1, 0, 1], [1, 0, 0]]), &device, ); let loss_actual = BinaryCrossEntropyLossConfig::new() .with_smoothing(Some(0.1)) .init(&device) .forward(sigmoid(logits), targets) .into_data(); let loss_expected = TensorData::from([0.7228]); loss_actual.assert_approx_eq::(&loss_expected, Tolerance::default()); } #[test] #[should_panic = "The number of classes"] fn multilabel_weights_should_match_target() { // import torch // from torch import nn // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]]) // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]]) // loss = nn.BCEWithLogitsLoss() // out = loss(input, target) # tensor(3.1708) let device = Default::default(); let logits = Tensor::::from_floats( [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]], &device, ); let targets = Tensor::::from_data( TensorData::from([[1, 0, 1], [1, 0, 0]]), &device, ); let weights = [3., 7.]; let _loss = BinaryCrossEntropyLossConfig::new() .with_logits(true) .with_weights(Some(weights.to_vec())) .init(&device) .forward(logits, targets); } #[test] fn display() { let config = BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9])); let loss = config.init::(&Default::default()); assert_eq!( alloc::format!("{loss}"), "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}" ); } } ================================================ FILE: crates/burn-nn/src/loss/cosine_embedding.rs ================================================ use alloc::format; use burn::tensor::linalg::cosine_similarity; use burn_core as burn; use crate::loss::reduction::Reduction; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::{Int, Tensor, activation::relu, backend::Backend}; /// Configuration for CosineEmbeddingLoss. #[derive(Config, Debug)] pub struct CosineEmbeddingLossConfig { /// Margin for negative samples. #[config(default = 0.0)] pub margin: f32, /// Specifies the reduction to apply to the output. #[config(default = "Reduction::Mean")] pub reduction: Reduction, } impl CosineEmbeddingLossConfig { /// Initialize CosineEmbeddingLoss. pub fn init(&self) -> CosineEmbeddingLoss { CosineEmbeddingLoss { margin: self.margin, reduction: self.reduction.clone(), } } } /// Cosine embedding loss between two tensors. /// /// Measures cosine distance between tensors. /// Used for learning embeddings or similarity. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct CosineEmbeddingLoss { /// Margin value. Default: 0.0 pub margin: f32, /// Reduction method pub reduction: Reduction, } impl Default for CosineEmbeddingLoss { fn default() -> Self { CosineEmbeddingLossConfig::new().init() } } impl ModuleDisplay for CosineEmbeddingLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("margin", &self.margin) .add("reduction", format!("{:?}", &self.reduction).as_str()) .optional() } } impl CosineEmbeddingLoss { /// Creates a new instance pub fn new() -> Self { CosineEmbeddingLossConfig::new().init() } /// Compute loss with reduction. /// /// # Shapes /// /// - input1: ``[batch_size, embedding_dim]`` /// - input2: ``[batch_size, embedding_dim]`` /// - target: ``[batch_size]`` with values 1 or -1 /// /// # Returns /// /// Loss tensor of shape ``[1]`` pub fn forward( &self, input1: Tensor, input2: Tensor, target: Tensor, ) -> Tensor { let tensor = self.forward_no_reduction(input1, input2, target); match &self.reduction { Reduction::Mean | Reduction::Auto => tensor.mean(), Reduction::Sum => tensor.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Compute loss without applying reduction. /// /// # Arguments /// /// * `input1` - First input tensor of shape ``[batch_size, embedding_dim]`` /// * `input2` - Second input tensor of shape ``[batch_size, embedding_dim]`` /// * `target` - Target tensor of shape ``[batch_size]`` with values 1 or -1 /// /// # Returns /// /// Tensor of per-element losses with shape ``[batch_size]`` pub fn forward_no_reduction( &self, input1: Tensor, input2: Tensor, target: Tensor, ) -> Tensor { self.assertions(&input1, &input2, &target); // cos_sim shape: [batch_size, 1] let cos_sim = cosine_similarity(input1, input2, 1, None); // cos_sim shape: [batch_size] let cos_sim: Tensor = cos_sim.squeeze_dim(1); let mut loss = cos_sim.zeros_like(); // Similar pairs (target == 1) - Formula: L = 1 - cos_sim let similar_mask = target.clone().equal_elem(1); let similar_loss = cos_sim.clone().neg().add_scalar(1); loss = loss.mask_where(similar_mask, similar_loss); // Dissimilar pairs (target == -1) - Formula: L = max(0, cos_sim - margin) let dissimilar_mask = target.equal_elem(-1); let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin)); loss = loss.mask_where(dissimilar_mask, dissimilar_loss); // return loss shape: [batch_size] loss } fn assertions( &self, input1: &Tensor, input2: &Tensor, target: &Tensor, ) { let [batch_size1, dim1] = input1.dims(); let [batch_size2, dim2] = input2.dims(); let [batch_size_target] = target.dims(); assert_eq!( batch_size1, batch_size2, "Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})" ); assert_eq!( dim1, dim2, "Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})" ); assert_eq!( batch_size1, batch_size_target, "Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})" ); } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn cosine_embedding_loss_positive_target() { let device = Default::default(); // Two identical vectors should have cosine similarity of 1 let input1 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); let input2 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); // Target 1 means that inputs should be similar let target = Tensor::::from_data(TensorData::from([1, 1]), &device); let loss = CosineEmbeddingLossConfig::new().init(); let loss_no_reduction = loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone()); let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone()); let loss_sum = loss.forward(input1, input2, target); // For identical vectors, 1 - cos_sim = 1 - 1 = 0 let expected_no_reduction = TensorData::from([0.0, 0.0]); loss_no_reduction .into_data() .assert_approx_eq::(&expected_no_reduction, Tolerance::default()); let expected_mean = TensorData::from([0.0]); loss_mean .into_data() .assert_approx_eq::(&expected_mean, Tolerance::default()); let expected_sum = TensorData::from([0.0]); loss_sum .into_data() .assert_approx_eq::(&expected_sum, Tolerance::default()); } #[test] fn cosine_embedding_loss_negative_target() { let device = Default::default(); // Two identical vectors should have cosine similarity of 1 let input1 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); let input2 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); // Target -1 means that inputs should be dissimilar let target = Tensor::::from_data(TensorData::from([-1, -1]), &device); // With margin 0.0, max(0, cos_sim - margin) = max(0, 1 - 0) = 1 let loss = CosineEmbeddingLossConfig::new().init(); let loss_no_reduction = loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone()); let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone()); // Create a loss with Sum reduction for testing let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum); let loss_sum = loss_sum_config .init() .forward(input1.clone(), input2.clone(), target.clone()); let expected_no_reduction = TensorData::from([1.0, 1.0]); loss_no_reduction .into_data() .assert_approx_eq::(&expected_no_reduction, Tolerance::default()); let expected_mean = TensorData::from([1.0]); loss_mean .into_data() .assert_approx_eq::(&expected_mean, Tolerance::default()); let expected_sum = TensorData::from([2.0]); loss_sum .into_data() .assert_approx_eq::(&expected_sum, Tolerance::default()); // With margin 0.5, max(0, cos_sim - margin) = max(0, 1 - 0.5) = 0.5 let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init(); let loss_with_margin = loss_with_margin.forward(input1, input2, target); let expected = TensorData::from([0.5]); loss_with_margin .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn cosine_embedding_loss_mixed_targets() { let device = Default::default(); let input1 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); let input2 = Tensor::::from_data( TensorData::from([[1.0, 0.0], [0.0, 1.0]]), &device, ); // Mixed targets let target = Tensor::::from_data(TensorData::from([1, -1]), &device); let loss = CosineEmbeddingLossConfig::new().init(); let loss_no_reduction = loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone()); let loss_mean = loss.forward(input1, input2, target); let expected_no_reduction = TensorData::from([0.0, 1.0]); loss_no_reduction .into_data() .assert_approx_eq::(&expected_no_reduction, Tolerance::default()); let expected_mean = TensorData::from([0.5]); loss_mean .into_data() .assert_approx_eq::(&expected_mean, Tolerance::default()); } #[test] fn display() { let config = CosineEmbeddingLossConfig::new().with_margin(0.5); let loss = config.init(); assert_eq!( alloc::format!("{loss}"), "CosineEmbeddingLoss {margin: 0.5, reduction: Mean}" ); } } ================================================ FILE: crates/burn-nn/src/loss/cross_entropy.rs ================================================ use burn_core as burn; use burn_core::tensor::IndexingUpdateOp; use alloc::string::ToString; use alloc::vec; use alloc::vec::Vec; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::activation::log_softmax; use burn::tensor::{Bool, Int, Tensor, backend::Backend}; use burn::{config::Config, module::Module}; /// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init). #[derive(Config, Debug)] pub struct CrossEntropyLossConfig { /// Create padded cross entropy. /// /// Prevents pad tokens from impacting loss calculation. pub pad_tokens: Option>, /// Create weighted cross-entropy. /// /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1, /// /// # Pre-conditions /// - The order of the weight vector should correspond to the label integer assignment. /// - Targets assigned negative Int's will not be allowed. pub weights: Option>, /// Create cross-entropy with label smoothing. /// /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes. /// Alpha = 0 would be the same as default. pub smoothing: Option, /// Create cross-entropy with probabilities as input instead of logits. /// #[config(default = true)] pub logits: bool, } impl CrossEntropyLossConfig { /// Initialize [Cross-entropy loss](CrossEntropyLoss). pub fn init(&self, device: &B::Device) -> CrossEntropyLoss { self.assertions(); CrossEntropyLoss { pad_tokens: self.pad_tokens.clone(), weights: self .weights .as_ref() .map(|e| Tensor::::from_floats(e.as_slice(), device)), smoothing: self.smoothing, logits: self.logits, } } fn assertions(&self) { if let Some(alpha) = self.smoothing { assert!( (0.0..=1.).contains(&alpha), "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}" ); }; if let Some(weights) = self.weights.as_ref() { assert!( weights.iter().all(|e| e > &0.), "Weights of cross-entropy have to be positive." ); } } } /// Calculate the cross entropy loss from the input logits and the targets. /// /// Should be created using [CrossEntropyLossConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct CrossEntropyLoss { /// Pad tokens to ignore in the loss calculation. pub pad_tokens: Option>, /// Weights for cross-entropy. pub weights: Option>, /// Label smoothing factor. pub smoothing: Option, /// Use logits as input. pub logits: bool, } impl ModuleDisplay for CrossEntropyLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens { alloc::format!("Vec<0..{}>", pad_tokens.len()) } else { "None".to_string() }; content .add("pad_tokens", &pad_tokens) .add("weights", &self.weights) .add("smoothing", &self.smoothing) .add("logits", &self.logits) .optional() } } impl CrossEntropyLoss { /// For backward compatibility. pub fn new(pad_index: Option, device: &B::Device) -> Self { CrossEntropyLossConfig::new() .with_pad_tokens(pad_index.map(|e| vec![e])) .init(device) } /// Compute the criterion on the input tensor. /// /// # Shapes /// /// - logits: `[batch_size, num_targets]` /// - targets: `[batch_size]` pub fn forward(&self, logits: Tensor, targets: Tensor) -> Tensor { Self::assertions(logits.clone(), targets.clone()); match self.smoothing { Some(alpha) => self.forward_smoothed(logits, targets, alpha), _ => self.forward_default(logits, targets), } } fn forward_smoothed( &self, logits: Tensor, targets: Tensor, alpha: f32, ) -> Tensor { let mask = self.padding_mask(&targets); let tensor = if self.logits { log_softmax(logits, 1) } else { logits.log() }; let [batch_size, nr_classes] = tensor.dims(); let tensor = tensor * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha); match &self.weights { Some(weights) => { let tensor = tensor * weights .clone() .reshape([1, nr_classes]) .repeat_dim(0, batch_size); let weights = weights.clone().gather(0, targets); let tensor = Self::apply_mask_2d(tensor, mask); tensor.sum().neg() / weights.sum() } None => { let tensor = Self::apply_mask_2d(tensor, mask); tensor.sum_dim(1).mean().neg() } } } fn forward_default(&self, logits: Tensor, targets: Tensor) -> Tensor { let [batch_size] = targets.dims(); let mask = self.padding_mask(&targets); let tensor = log_softmax(logits, 1); let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1])); match &self.weights { Some(weights) => { let weights = weights.clone().gather(0, targets); let tensor = tensor.reshape([batch_size]) * weights.clone(); let tensor = Self::apply_mask_1d(tensor, mask); tensor.sum().neg() / weights.sum() } None => { let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask); tensor.mean().neg() } } } fn compute_smoothed_targets( shape: [usize; 2], targets: Tensor, alpha: f32, ) -> Tensor { let [batch_size, nr_classes] = shape; let device = &targets.device(); let targets_matrix = Tensor::::zeros(shape, device).scatter( 1, targets.reshape([batch_size, 1]), Tensor::ones([batch_size, 1], device), IndexingUpdateOp::Add, ); targets_matrix * (1. - alpha) + alpha / nr_classes as f32 } fn padding_mask(&self, targets: &Tensor) -> Option> { let mut mask = None; if let Some(pad_tokens) = &self.pad_tokens { let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int(); for x in pad_tokens { res = res + targets.clone().equal_elem(*x as i64).int(); } mask = Some(res.greater_elem(0)); } mask } fn apply_mask_1d(mut tensor: Tensor, mask: Option>) -> Tensor { if let Some(mask) = mask { tensor = tensor.mask_fill(mask, 0); } tensor } fn apply_mask_2d(mut tensor: Tensor, mask: Option>) -> Tensor { if let Some(mask) = mask { let [batch_size, nr_classes] = tensor.dims(); tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0); } tensor } fn assertions(logits: Tensor, targets: Tensor) { let [logits_height, _] = logits.dims(); let [targets_height] = targets.dims(); assert!( logits_height == targets_height, "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})." ); } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem}; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; macro_rules! setup { () => {{ let [batch_size, num_targets] = [4, 5]; let device = Default::default(); let logits = Tensor::::random( [batch_size, num_targets], Distribution::Normal(0., 1.0), &device, ); let targets = Tensor::::from_data(TensorData::from([2, 0, 4, 1]), &device); let targets_logits = Tensor::::from_data( TensorData::from([ [0.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0, 0.0], ]), &device, ); (logits, targets, targets_logits) }}; } macro_rules! setup_padded { () => {{ let [batch_size, num_targets, pad_index] = [4, 5, 1]; let device = Default::default(); let logits = Tensor::::random( [batch_size, num_targets], Distribution::Normal(0., 1.0), &device, ); let targets = Tensor::::from_data( TensorData::from([2, 0, 4, pad_index as i64]).convert::>(), &device, ); let targets_logits = Tensor::::from_data( TensorData::from([ [0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0], ]), &device, ); (logits, targets, targets_logits) }}; } #[test] fn test_cross_entropy_loss_with_weights() { let (logits, targets, targets_logits) = setup!(); let weights = vec![1.0, 2., 3., 4., 5.]; let device = Default::default(); let loss_1 = CrossEntropyLossConfig::new() .with_weights(Some(weights.clone())) .init(&device) .forward(logits.clone(), targets); let tensor = log_softmax(logits, 1); let loss_2 = tensor * targets_logits * Tensor::::from_floats(weights.as_slice(), &device) .unsqueeze() .repeat_dim(0, 4); let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_label_smoothing_with_weights_and_alpha_zero() { let (logits, targets, _) = setup!(); let device = Default::default(); let weights = vec![1.0, 2., 3., 4., 5.]; let loss_1 = CrossEntropyLossConfig::new() .with_weights(Some(weights.clone())) .init(&device) .forward(logits.clone(), targets.clone()); let loss_2 = CrossEntropyLossConfig::new() .with_weights(Some(weights.clone())) .with_smoothing(Some(0.)) .init(&device) .forward(logits.clone(), targets); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_cross_entropy_loss() { let (logits, targets, targets_logits) = setup!(); let device = Default::default(); let loss_1 = CrossEntropyLossConfig::new() .init(&device) .forward(logits.clone(), targets); let loss_2 = cross_entropy_with_logits(logits, targets_logits); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_label_smoothing_alpha_equal_zero() { let (logits, targets, _) = setup!(); let device = Default::default(); let loss_1 = CrossEntropyLossConfig::new() .init(&device) .forward(logits.clone(), targets.clone()); let loss_2 = CrossEntropyLossConfig::new() .with_smoothing(Some(0.)) .init(&device) .forward(logits, targets); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_cross_entropy_loss_with_pad_token() { let (logits, targets, targets_logits) = setup_padded!(); let pad_index = 1; let loss_1 = CrossEntropyLossConfig::new() .with_pad_tokens(Some(vec![pad_index, 2])) .init(&logits.device()) .forward(logits.clone(), targets); let loss_2 = cross_entropy_with_logits(logits, targets_logits); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_label_smoothing_with_zero_alpha_and_pad_token() { let (logits, targets, _) = setup_padded!(); let pad_index = 1; let loss_1 = CrossEntropyLossConfig::new() .with_pad_tokens(Some(vec![pad_index, 2])) .init(&logits.device()) .forward(logits.clone(), targets.clone()); let loss_2 = CrossEntropyLossConfig::new() .with_pad_tokens(Some(vec![pad_index, 2])) .with_smoothing(Some(0.)) .init(&logits.device()) .forward(logits.clone(), targets); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn test_label_smoothing_target_conversion() { let (logits, targets, _) = setup!(); let smoothed_targets = CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05); let targets_logits = Tensor::::from_data( TensorData::from([ [0.01, 0.01, 0.96, 0.01, 0.01], [0.96, 0.01, 0.01, 0.01, 0.01], [0.01, 0.01, 0.01, 0.01, 0.96], [0.01, 0.96, 0.01, 0.01, 0.01], ]), &Default::default(), ); smoothed_targets .into_data() .assert_approx_eq::(&targets_logits.into_data(), Tolerance::default()); } #[test] fn test_label_smoothing() { let (logits, targets, _) = setup!(); let device = Default::default(); let loss_1 = CrossEntropyLossConfig::new() .with_smoothing(Some(0.05)) .init(&device) .forward(logits.clone(), targets); let targets_logits = Tensor::::from_data( TensorData::from([ [0.01, 0.01, 0.96, 0.01, 0.01], [0.96, 0.01, 0.01, 0.01, 0.01], [0.01, 0.01, 0.01, 0.01, 0.96], [0.01, 0.96, 0.01, 0.01, 0.01], ]), &device, ); let x = log_softmax(logits, 1); let loss_2 = (x * targets_logits).sum_dim(1).mean().neg(); loss_1 .into_data() .assert_approx_eq::(&loss_2.into_data(), Tolerance::default()); } #[test] fn display() { let config = CrossEntropyLossConfig::new() .with_weights(Some(alloc::vec![3., 7., 0.9])) .with_smoothing(Some(0.5)); let loss = config.init::(&Default::default()); assert_eq!( alloc::format!("{loss}"), "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}" ); } } ================================================ FILE: crates/burn-nn/src/loss/ctc.rs ================================================ #![allow(clippy::excessive_precision)] use super::Reduction; use alloc::vec; use burn::config::Config; use burn::module::Module; use burn::tensor::{Bool, Element, Int, Tensor, backend::Backend, s}; use burn_core as burn; use burn_core::tensor::Numeric; use core::f32; /// Configuration for the [CTC Loss](CTCLoss) module. #[derive(Config, Debug)] pub struct CTCLossConfig { /// The index number used to represent the blank label. Default value is `0`. #[config(default = 0)] pub blank: usize, /// Whether to zero infinite losses and the associated gradients. Default value is `false`. #[config(default = false)] pub zero_infinity: bool, } impl CTCLossConfig { /// Initialize a new [CTC Loss](CTCLoss) module pub fn init(&self) -> CTCLoss { CTCLoss { blank: self.blank, zero_infinity: self.zero_infinity, } } } /// Computes the Connectionist Temporal Classification (CTC) loss. /// /// Calculates the loss between a continuous (unsegmented) time series and a target sequence. /// CTC sums over the probability of all possible alignments of the input to the target, /// producing a loss value that is differentiable with respect to each input node. /// /// The input to this loss is expected to be **log-probabilities** (e.g,, via `log_softmax`), /// not raw logits. /// /// # References /// /// - [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://www.cs.toronto.edu/~graves/icml_2006.pdf) /// /// # Example /// /// ```rust,ignore /// use burn::tensor::{Tensor, Int}; /// use burn::tensor::activation::log_softmax; /// use burn::nn::loss::{CTCLossConfig, CTCLoss}; /// /// let device = Default::default(); /// /// // Initialize CTC Loss with default configuration /// let ctc_loss = CTCLossConfig::new().init(); /// /// // Initialize CTC Loss with custom configuration /// let ctc_loss = CTCLossConfig::new() /// .with_blank(1) /// .with_zero_infinity(true) /// .init(); /// /// // Prepare inputs (Logits shape: [Time, Batch, Class]) /// // In your actual code, the logits would be the output of your model /// let logits = Tensor::::ones([10, 2, 5], &device); /// let log_probs = log_softmax(logits, 2); /// /// // Targets shape: [Batch, Max_Target_Len] /// // Note: Targets should not contain the blank index (1). /// let targets = Tensor::::from_data([[0, 2], [3, 4]], &device); /// /// // Lengths shape: [Batch] /// let input_lengths = Tensor::::from_data([10, 8], &device); /// let target_lengths = Tensor::::from_data([2, 2], &device); /// /// // Compute loss /// let loss = ctc_loss.forward(log_probs, targets, input_lengths, target_lengths); /// ``` #[derive(Module, Clone, Debug)] pub struct CTCLoss { blank: usize, zero_infinity: bool, } impl CTCLoss { /// Computes the CTC loss for the input log-probabilities and targets with no reduction applied. /// /// # Arguments /// /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`). /// - `targets`: A 2D tensor containing the target class indices. These indices should not /// include the blank index used in CTC loss. The targets are padded to the length of the longest sequence. /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each batch. This /// allows retrieving the actual sequence of log-probabilities from `log_probs` if the batch contains /// sequences of varying lengths. /// - `target_lengths`: A 1D tensor containing the actual length of the target sequence for each target /// sequence in `targets`. /// /// # Returns /// /// - A 1D tensor of shape `[batch_size]` containing the loss for each sample. /// /// # Shapes /// /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank. /// - `targets`: `[batch_size, max_target_length]` /// - `input_lengths`: `[batch_size]` /// - `target_lengths`: `[batch_size]` pub fn forward( &self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, ) -> Tensor { let device = log_probs.device(); let [max_input_length, batch_size, num_classes] = log_probs.dims(); // [T, N, C] let max_target_len = targets.dims()[1]; let input_lengths_len = input_lengths.dims()[0]; let target_lengths_len = target_lengths.dims()[0]; self.assertions( batch_size, num_classes, targets.clone(), input_lengths_len, target_lengths_len, ); // Build the modified label sequence l' by inserting blanks around every label let blank_inserted_targets = self.insert_blanks::(&targets, batch_size, max_target_len, &device); // Initialize the forward variable alpha let max_l_prime_len = 2 * max_target_len + 1; let mut log_alpha_t_s = Tensor::::full([batch_size, max_l_prime_len], f32::NEG_INFINITY, &device); log_alpha_t_s = self.initialize_log_alpha( log_probs.clone(), blank_inserted_targets.clone(), log_alpha_t_s, ); let l_prime_combined_mask = self.create_l_prime_mask( blank_inserted_targets.clone(), batch_size, max_l_prime_len, &device, ); let s_mask = self.create_s_mask(max_l_prime_len, batch_size, target_lengths.clone(), &device); // Loop over time steps since an arbitrary time step t depends on t - 1 for t in 1..max_input_length { let combined_s_t_mask = self.create_combined_s_t_mask( input_lengths.clone(), t, batch_size, max_l_prime_len, s_mask.clone(), ); log_alpha_t_s = self.compute_log_alpha_t_s( t, combined_s_t_mask, log_alpha_t_s, l_prime_combined_mask.clone(), log_probs.clone(), blank_inserted_targets.clone(), ); } let last_blank_indices = target_lengths.mul_scalar(2).reshape([batch_size, 1]); let last_label_indices = last_blank_indices.clone().sub_scalar(1); let log_alpha_last_blank = log_alpha_t_s .clone() .gather(1, last_blank_indices) .squeeze_dim::<1>(1); let log_alpha_last_label = log_alpha_t_s .clone() .gather(1, last_label_indices) .squeeze_dim::<1>(1); let log_likelihood = self.log_sum_exp(log_alpha_last_blank, log_alpha_last_label, &device); let mut ctc_loss_tensor = log_likelihood.neg(); if self.zero_infinity { let inf_mask = ctc_loss_tensor.clone().is_inf(); ctc_loss_tensor = ctc_loss_tensor .clone() .mask_where(inf_mask, ctc_loss_tensor.clone().zeros_like()); } ctc_loss_tensor } /// Computes the CTC loss for the input log-probabilities and targets with reduction. /// /// # Arguments /// /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`). /// - `targets`: A 2D tensor containing the target class indices. These indices should not /// include the blank index used in CTC loss. The targets are padded to the length of the longest sequence. /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each batch. This /// allows retrieving the actual sequence of log-probabilities from `log_probs` if the batch contains /// sequences of varying lengths. /// - `target_lengths`: A 1D tensor containing the actual length of the target sequence for each target /// sequence in `targets`. /// - `reduction`: The reduction stratey to apply to the loss tensor containing the CTC loss values for /// each sample (e.g., mean, sum). For the mean reduction strategy, the output losses will be divided /// by the target lengths and then the mean over the batch is taken. This follows PyTorch's behavior. /// /// # Returns /// /// - A 1D tensor of shape `[1]` containing the reduced loss value. /// /// # Shapes /// /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank. /// - `targets`: `[batch_size, max_target_length]` /// - `input_lengths`: `[batch_size]` /// - `target_lengths`: `[batch_size]` /// /// # Panics /// - If `reduction` is not one of `Reduction::Auto`, `Reduction::Mean`, and `Reduction::Sum`. /// - If `blank` index is greater than or equal to `num_classes`. /// - If the batch dimension of `log_probs`, `targets`, `input_lengths`, and `target_lengths` do not match. pub fn forward_with_reduction( &self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, reduction: Reduction, ) -> Tensor { let ctc_loss_tensor = self.forward(log_probs, targets, input_lengths, target_lengths.clone()); match reduction { Reduction::Auto | Reduction::Mean => { // Following PyTorch's behavior where the output losses are divided // by the target lengths and then the mean over the batch is taken let target_lengths_float = target_lengths.float(); ctc_loss_tensor.div(target_lengths_float).mean() } Reduction::Sum => ctc_loss_tensor.sum(), other => panic!("{other:?} reduction is not supported"), } } fn assertions( &self, batch_size: usize, num_classes: usize, targets: Tensor, input_lengths_len: usize, target_lengths_len: usize, ) { assert!( self.blank < num_classes, "blank index {} must be less than num_classes {}", self.blank, num_classes ); assert_eq!( targets.dims()[0], batch_size, "targets batch dimension {} must equal batch_size {}", targets.dims()[0], batch_size ); assert_eq!( input_lengths_len, batch_size, "input_lengths length {} must equal batch_size {}", input_lengths_len, batch_size ); assert_eq!( target_lengths_len, batch_size, "target_lengths length {} must equal batch_size {}", target_lengths_len, batch_size ); } fn insert_blanks( &self, targets: &Tensor, batch_size: usize, max_target_len: usize, device: &B::Device, ) -> Tensor { // The modified label sequences have (max_target_len + 1) blank labels let blank_tensor = Tensor::::full( [batch_size, 2 * max_target_len + 1], self.blank as i64, device, ); blank_tensor.slice_assign(s![.., 1..;2], targets.clone()) } fn initialize_log_alpha( &self, log_probs: Tensor, blank_inserted_targets: Tensor, log_alpha_t_s: Tensor, ) -> Tensor { // Given alpha_t(s), we have: // alpha_1(1) = (y_blank)^1 => log_alpha_1(1) = ln(y_blank)^1 // alpha_1(2) = (y_l1)^1 => log_alpha_1(2) = ln(y_l1)^1 // alpha_1(s) = 0 (for every s > 2) => log_alpha_1(s) = -neg_inf let log_probs_t0 = log_probs .clone() .slice(s![0..1, .., ..]) .squeeze_dim::<2>(0); // shape: [N, C] // log_alpha shape: [N, 2*S+1] // log_probs shape: [T, N, C] // log_alpha[:, 0] = log_probs[0, :, blank] let first_blank = blank_inserted_targets.clone().slice(s![.., 0..1]); // [N, 1] // log_probs_t0 have C columns where each represents a unique class (includes blank) let log_prob_blank = log_probs_t0.clone().gather(1, first_blank); // [N, 1] let temp_log_alpha_t_s = log_alpha_t_s.slice_assign(s![.., 0..1], log_prob_blank); // log_alpha[:, 1] = log_probs[0, :, targets[:, 0]] let first_label = blank_inserted_targets.clone().slice(s![.., 1..2]); // [N, 1] let log_prob_first_label = log_probs_t0.gather(1, first_label); // [N, 1] temp_log_alpha_t_s.slice_assign(s![.., 1..2], log_prob_first_label) } fn right_shift_2d_tensor( &self, org_2d_tensor: Tensor, shift_by: usize, device: &B::Device, ) -> Tensor where K: Numeric, K::Elem: Element, { assert!( shift_by == 1 || shift_by == 2, "The parameter shift_by must 1 or 2" ); let [rows, cols] = org_2d_tensor.dims(); let padding_shape = [rows, shift_by]; let padding_tensor = if org_2d_tensor.dtype().is_float() { Tensor::::full(padding_shape, f32::NEG_INFINITY, device) } else { Tensor::::full(padding_shape, 0, device) }; let org_tensor_shortened = org_2d_tensor.slice(s![.., ..cols - shift_by]); Tensor::cat(vec![padding_tensor, org_tensor_shortened], 1) } fn create_l_prime_mask( &self, blank_inserted_targets: Tensor, batch_size: usize, max_l_prime_len: usize, device: &B::Device, ) -> Tensor { let l_prime_s = blank_inserted_targets.clone(); let l_prime_s_minus_2 = self.right_shift_2d_tensor(blank_inserted_targets, 2, device); // Create a single mask that is true for entries where alpha_{t-1}(s - 2) should also // be added to compute alpha_{t}(s) let s_is_not_blank_mask = l_prime_s.clone().not_equal_elem(self.blank as i64); let s_not_equal_s_minus_2_mask = l_prime_s.not_equal(l_prime_s_minus_2); // The 2 leftmost columns of the returned mask should only contain false. // These are invalid positions since s - 2 is a valid index only when s >= 2. let col_indices = Tensor::::arange(0..(max_l_prime_len as i64), device) .reshape([1, max_l_prime_len]) .expand([batch_size, max_l_prime_len]); let s_greater_than_1_mask = col_indices.greater_equal_elem(2); s_is_not_blank_mask .bool_and(s_not_equal_s_minus_2_mask) .bool_and(s_greater_than_1_mask) } fn create_s_mask( &self, max_l_prime_len: usize, batch_size: usize, target_lengths: Tensor, device: &B::Device, ) -> Tensor { let col_indices = Tensor::::arange(0..max_l_prime_len as i64, device) .reshape([1, max_l_prime_len]); let col_indices_expanded = col_indices.expand([batch_size, max_l_prime_len]); let blank_inserted_target_lengths = target_lengths .mul_scalar(2) .add_scalar(1) .reshape([batch_size, 1]); let target_lengths_expanded = blank_inserted_target_lengths.expand([batch_size, max_l_prime_len]); col_indices_expanded.lower(target_lengths_expanded) } fn log_sum_exp( &self, log_tensor1: Tensor, log_tensor2: Tensor, device: &B::Device, ) -> Tensor { let shape = log_tensor1.dims(); let ones_tensor = Tensor::::ones(shape, device); // Let A and B represent parameters tensor1 and tensor2 respectively. // Let C be the tensor this method returns. // If an entry in both A and B are neg_inf, then the same entry // in C should also contain neg_inf. // If an entry in only one of A or B is neg_inf, then the same entry in // C should contain the value of the other tensor entry which is not neg_inf. let tensor1_is_neg_inf = log_tensor1.clone().equal_elem(f32::NEG_INFINITY); let tensor2_is_neg_inf = log_tensor2.clone().equal_elem(f32::NEG_INFINITY); let temp_tensor1 = ones_tensor .clone() .mask_where(tensor1_is_neg_inf.clone(), log_tensor2.clone()); let neg_inf_lse_tensor = temp_tensor1.mask_where(tensor2_is_neg_inf.clone(), log_tensor1.clone()); // Create sanitized tensors for math operations to prevent NaN. Replace neg_inf // with 0.0. The tensor neg_inf_lse_tensor contains correct values for entries // where at least one of the corresponding entries in log_tensor1 or log_tensor2 // is neg_inf. Hence, the math operations below is computing the values for entries // that are not already filled with their actual/correct values. Thus, result for // these positions (where we sanitize) are not used anyway since the // unfilled_entries_mask is applied at the end. let tensor1_safe = log_tensor1 .clone() .mask_fill(tensor1_is_neg_inf.clone(), 0.0); let tensor2_safe = log_tensor2 .clone() .mask_fill(tensor2_is_neg_inf.clone(), 0.0); // Create a mask which contains true for entries whose values were not // set by operations above let filled_entries_mask = tensor1_is_neg_inf.bool_or(tensor2_is_neg_inf); let unfilled_entries_mask = filled_entries_mask.bool_not(); let max_tensor = tensor1_safe.clone().max_pair(tensor2_safe.clone()); let diff_tensor = tensor1_safe.sub(tensor2_safe); let exp_tensor = diff_tensor.abs().neg().exp(); let ln_tensor = ones_tensor.add(exp_tensor).log(); let lse_tensor = max_tensor.add(ln_tensor); neg_inf_lse_tensor.mask_where(unfilled_entries_mask, lse_tensor) } fn create_combined_s_t_mask( &self, input_lengths: Tensor, t: usize, batch_size: usize, max_l_prime_len: usize, s_mask: Tensor, ) -> Tensor { // Create masks for valid t and s let t_mask_1d = input_lengths .clone() .greater_elem(t as i64) .reshape([batch_size, 1]); let t_mask = t_mask_1d.expand([batch_size, max_l_prime_len]); t_mask.bool_and(s_mask.clone()) } fn compute_log_alpha_t_s( &self, t: usize, combined_s_t_mask: Tensor, log_alpha_t_s: Tensor, l_prime_combined_mask: Tensor, log_probs: Tensor, blank_inserted_targets: Tensor, ) -> Tensor { let device = log_probs.device(); let log_alpha_t_minus_1 = log_alpha_t_s.clone(); // No move from last time step: alpha_{t-1}(s) let log_alpha_s = log_alpha_t_minus_1.clone(); // Single move from last time step: alpha_{t-1}(s - 1) let log_alpha_s_minus_1 = self.right_shift_2d_tensor(log_alpha_t_minus_1.clone(), 1, &device); // A skip move (moving 2 positions) from last time step: alpha_{t-1}(s - 2) let log_alpha_s_minus_2 = self.right_shift_2d_tensor(log_alpha_t_minus_1.clone(), 2, &device); // Compute alpha_{t}(s) using recursion, corresponding to equation 6 of the paper. let log_alpha_bar = self.log_sum_exp(log_alpha_s, log_alpha_s_minus_1, &device); let log_alpha_bar_plus_log_alpha_s_minus_2 = self.log_sum_exp(log_alpha_bar.clone(), log_alpha_s_minus_2, &device); let log_alpha_s_to_s_minus_2 = log_alpha_bar.mask_where( l_prime_combined_mask.clone(), log_alpha_bar_plus_log_alpha_s_minus_2, ); // [N, 2 * U + 1] let log_probs_t = log_probs.clone().slice(s![t, .., ..]).squeeze_dim::<2>(0); // [N, C] let log_probs_l_prime_s = log_probs_t.gather(1, blank_inserted_targets.clone()); let temp_log_alpha_t_s = log_alpha_s_to_s_minus_2.add(log_probs_l_prime_s); log_alpha_t_s.mask_where(combined_s_t_mask, temp_log_alpha_t_s) } } #[cfg(test)] mod tests { use super::*; use burn_ndarray::{NdArray, NdArrayDevice}; type TestBackend = NdArray; fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) { assert_eq!( actual.len(), expected.len(), "Length mismatch: actual {} vs expected {}", actual.len(), expected.len() ); for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() { assert!( (a - e).abs() < tol, "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})", i, e, a, (a - e).abs() ); } } // --------------------------------------------------------------- // insert_blanks tests // --------------------------------------------------------------- #[test] fn test_insert_blanks_single_sample() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); let targets = Tensor::::from_data([[1_i64, 2, 3]], &device); let result = ctc.insert_blanks::(&targets, 1, 3, &device); let result_data = result.into_data().to_vec::().unwrap(); assert_eq!(result_data, vec![0, 1, 0, 2, 0, 3, 0]); } #[test] fn test_insert_blanks_batch() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); let targets = Tensor::::from_data([[1_i64, 2], [3, 4]], &device); let result = ctc.insert_blanks::(&targets, 2, 2, &device); let result_data = result.into_data().to_vec::().unwrap(); assert_eq!(result_data, vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0]); } #[test] fn test_insert_blanks_custom_blank() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_blank(2).init(); let targets = Tensor::::from_data([[0_i64, 1]], &device); let result = ctc.insert_blanks::(&targets, 1, 2, &device); let result_data = result.into_data().to_vec::().unwrap(); // l' = [blank=2, 0, blank=2, 1, blank=2] assert_eq!(result_data, vec![2, 0, 2, 1, 2]); } // --------------------------------------------------------------- // Assertions // --------------------------------------------------------------- #[test] #[should_panic(expected = "blank index")] fn test_ctc_loss_panics_invalid_blank_index() { let device = NdArrayDevice::Cpu; // blank=5 is out of bounds for num_classes=3 let ctc = CTCLossConfig::new().with_blank(5).init(); let log_probs = Tensor::::zeros([2, 1, 3], &device); let targets = Tensor::::from_data([[1]], &device); let input_lengths = Tensor::::from_data([2], &device); let target_lengths = Tensor::::from_data([1], &device); ctc.forward(log_probs, targets, input_lengths, target_lengths); } #[test] #[should_panic(expected = "must equal batch_size")] fn test_ctc_loss_panics_mismatched_batch_size() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); // Logits batch size = 2 let log_probs = Tensor::::zeros([2, 2, 3], &device); // Targets batch size = 1 (Mismatch) let targets = Tensor::::from_data([[1]], &device); let input_lengths = Tensor::::from_data([2, 2], &device); let target_lengths = Tensor::::from_data([1, 1], &device); ctc.forward(log_probs, targets, input_lengths, target_lengths); } #[test] #[should_panic(expected = "input_lengths length")] fn test_ctc_loss_panics_input_lengths_mismatch() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); // Logits batch size = 2 let log_probs = Tensor::::zeros([2, 2, 3], &device); let targets = Tensor::::from_data([[1], [2]], &device); // Input lengths size = 1 (Mismatch) let input_lengths = Tensor::::from_data([2], &device); let target_lengths = Tensor::::from_data([1, 1], &device); ctc.forward(log_probs, targets, input_lengths, target_lengths); } #[test] #[should_panic(expected = "target_lengths length")] fn test_ctc_loss_panics_target_lengths_mismatch() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); // Logits batch size = 2 let log_probs = Tensor::::zeros([2, 2, 3], &device); let targets = Tensor::::from_data([[1], [2]], &device); let input_lengths = Tensor::::from_data([2, 2], &device); // Target lengths size = 1 (Mismatch) let target_lengths = Tensor::::from_data([1], &device); ctc.forward(log_probs, targets, input_lengths, target_lengths); } // --------------------------------------------------------------- // Edge Case & Config Tests // --------------------------------------------------------------- #[test] fn test_ctc_loss_repeated_labels_minimum_input_length() { // T=3, N=1, C=2, blank=0, target=[1, 1], uniform P = 1/2. // // The minimum T for target [1, 1] is 3: the only valid path is (1, 0, 1). // prob = (1/2)^3 = 1/8 // Loss = -ln(1/8) = 3 * ln(2) let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); let log_probs = Tensor::::full([3, 1, 2], 0.5_f32.ln(), &device); let targets = Tensor::::from_data([[1_i64, 1]], &device); let input_lengths = Tensor::::from_data([3_i64], &device); let target_lengths = Tensor::::from_data([2_i64], &device); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.into_data().to_vec::().unwrap(); let expected = 3.0 * 2.0_f32.ln(); assert_approx_equal(&loss_data, &[expected], 1e-3); } #[test] fn test_ctc_loss_custom_blank_uniform() { // T=3, N=1, C=3, blank=2, target=[0, 1], uniform P = 1/3. // // Two distinct labels, 3 classes, 3 time steps, just with // blank=2 instead of 0. // 5 valid paths → total = 5/27 // Loss = -ln(5/27) let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_blank(2).init(); let log_probs = Tensor::::full([3, 1, 3], (1.0_f32 / 3.0).ln(), &device); let targets = Tensor::::from_data([[0_i64, 1]], &device); let input_lengths = Tensor::::from_data([3_i64], &device); let target_lengths = Tensor::::from_data([2_i64], &device); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.into_data().to_vec::().unwrap(); let expected = -(5.0_f32 / 27.0).ln(); assert_approx_equal(&loss_data, &[expected], 1e-3); } // --------------------------------------------------------------- // zero_infinity tests // --------------------------------------------------------------- #[test] fn test_ctc_loss_zero_infinity_produces_inf_when_disabled() { // T=2, N=1, C=3, blank=0, target=[1, 1], input_length=2 // Target [1, 1] requires at least 3 time steps → no valid paths → loss = +inf let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_zero_infinity(false).init(); let log_probs = Tensor::::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device); let targets = Tensor::::from_data([[1_i64, 1]], &device); let input_lengths = Tensor::::from_data([2_i64], &device); let target_lengths = Tensor::::from_data([2_i64], &device); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.into_data().to_vec::().unwrap(); assert!( loss_data[0].is_infinite() && loss_data[0] > 0.0, "Expected +inf, got {}", loss_data[0] ); } #[test] fn test_ctc_loss_zero_infinity_masks_inf_when_enabled() { // Same inputs as above, but zero_infinity=true → loss should be 0.0 let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_zero_infinity(true).init(); let log_probs = Tensor::::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device); let targets = Tensor::::from_data([[1_i64, 1]], &device); let input_lengths = Tensor::::from_data([2_i64], &device); let target_lengths = Tensor::::from_data([2_i64], &device); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.into_data().to_vec::().unwrap(); assert_approx_equal(&loss_data, &[0.0], 1e-6); } #[test] fn test_ctc_loss_zero_infinity_does_not_affect_finite_loss() { // Verify that zero_infinity=true does not change a finite loss value. let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_zero_infinity(true).init(); let log_probs = Tensor::::full([2, 1, 2], 0.5_f32.ln(), &device); let targets = Tensor::::from_data([[1_i64]], &device); let input_lengths = Tensor::::from_data([2_i64], &device); let target_lengths = Tensor::::from_data([1_i64], &device); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.into_data().to_vec::().unwrap(); let expected = -(0.75_f32).ln(); assert_approx_equal(&loss_data, &[expected], 1e-3); } } #[cfg(test)] mod pytorch_comparison_tests { use super::*; use burn::tensor::activation::log_softmax; use burn_autodiff::Autodiff; use burn_core::tensor::TensorData; use burn_ndarray::{NdArray, NdArrayDevice}; type InnerBackend = NdArray; type TestBackend = Autodiff; fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) { assert_eq!( actual.len(), expected.len(), "Length mismatch: actual {} vs expected {}", actual.len(), expected.len() ); for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() { assert!( (a - e).abs() < tol, "Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})", i, e, a, (a - e).abs() ); } } /// Deterministic logits: sin((t*7 + n*13 + c*3) * 0.1). fn generate_logits( t_size: usize, n_size: usize, c_size: usize, device: &NdArrayDevice, ) -> Tensor { let mut data = Vec::with_capacity(t_size * n_size * c_size); for t in 0..t_size { for n in 0..n_size { for c in 0..c_size { data.push(((t * 7 + n * 13 + c * 3) as f32 * 0.1).sin()); } } } Tensor::::from_data(TensorData::new(data, [t_size, n_size, c_size]), device) } /// Runs a CTC forward + backward test and asserts against expected values from PyTorch. /// /// This helper performs the following steps: /// 1. Generates deterministic logits using a sine-wave formula. /// 2. Computes the CTC loss (forward pass). /// 3. Asserts the computed loss matches `expected_losses`. /// 4. Backpropagates the sum of the loss. /// 5. Asserts the resulting gradients w.r.t. logits match `expected_grad_flat`. /// /// # Arguments /// /// - `expected_losses`: per-sample loss values from PyTorch (reduction='none'). /// - `expected_grad_flat`: flattened gradient of sum(loss) w.r.t. logits. #[allow(clippy::too_many_arguments)] fn run_comparison( label: &str, t_size: usize, n_size: usize, c_size: usize, targets_flat: Vec, target_shape: [usize; 2], input_lengths: Vec, target_lengths: Vec, blank: usize, expected_losses: &[f32], expected_grad_flat: &[f32], loss_tol: f32, grad_tol: f32, ) { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().with_blank(blank).init(); let logits = generate_logits(t_size, n_size, c_size, &device).require_grad(); let log_probs = log_softmax(logits.clone(), 2); let targets = Tensor::::from_data( TensorData::new(targets_flat, target_shape), &device, ); let input_lengths = Tensor::::from_data( TensorData::new(input_lengths, [n_size]), &device, ); let target_lengths = Tensor::::from_data( TensorData::new(target_lengths, [n_size]), &device, ); let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths); let loss_data = loss.clone().into_data().to_vec::().unwrap(); println!("=== {} ===", label); println!(" Loss: {:?}", loss_data); assert_approx_equal(&loss_data, expected_losses, loss_tol); let loss_sum = loss.sum(); let grads = loss_sum.backward(); let logits_grad = logits.grad(&grads).unwrap(); let grad_data = logits_grad.into_data().to_vec::().unwrap(); assert_approx_equal(&grad_data, expected_grad_flat, grad_tol); } #[test] fn test_ctc_loss_uniform_input_lengths() { // T=5, N=3, C=4, all input_lengths = 5 // Expected losses and gradient from PyTorch let expected_losses = [3.5236570835113525_f32, 3.495313882827759, 4.262677192687988]; let expected_grad_flat = [ -0.1679008007_f32, -0.4595540464, 0.2795598209, 0.3478950262, -0.3913056254, -0.0832268298, 0.2535884976, 0.2209439576, -0.0502742566, 0.2766197622, 0.2054125518, -0.4317580462, -0.0544800088, -0.3144550920, 0.0847885981, 0.2841464877, -0.1844545156, -0.2063435912, 0.2222184092, 0.1685796976, 0.0278018005, 0.2657383382, -0.0336986706, -0.2598414719, -0.0482986756, -0.0098767160, -0.1533526182, 0.2115280181, -0.1380317956, -0.2198686600, 0.2042596638, 0.1536407918, 0.0534787849, 0.1819230020, -0.2805589139, 0.0451571345, -0.0895631388, 0.1996460557, -0.2741115987, 0.1640286744, -0.2200077325, -0.1693530381, 0.2101601064, 0.1792006642, 0.0398471877, -0.1131042913, -0.2363226712, 0.3095797896, -0.2163617164, 0.2740726173, -0.2124865055, 0.1547756046, -0.4312027395, -0.0446923785, 0.2330704331, 0.2428246588, -0.0050083841, -0.6256869435, 0.2689785957, 0.3617166877, ]; run_comparison( "T=5, N=3, C=4 (uniform input lengths)", 5, 3, 4, vec![1, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3], vec![5, 5, 5], vec![2, 1, 3], 0, &expected_losses, &expected_grad_flat, 1e-3, 1e-3, ); } #[test] fn test_ctc_loss_repeated_labels() { // T=8, N=4, C=6, includes consecutive repeated label [1,1,2] // Expected losses and gradient from PyTorch let expected_losses = [ 8.84203052520752_f32, 9.023029327392578, 9.398024559020996, 9.008068084716797, ]; let expected_grad_flat = [ -0.2766432464, -0.5202965736, 0.1523768753, 0.1896236390, 0.2200277001, 0.2349116206, -0.1854365915, 0.2031330466, -0.4260218740, 0.1678018719, 0.1360142529, 0.1045092493, -0.6603536606, 0.2278252542, 0.1691786796, 0.1262856424, 0.0972681716, 0.0397959016, -0.0894432291, -0.5457318425, 0.1490373611, 0.1462858170, 0.1569476575, 0.1829041988, -0.2842915654, -0.4220107496, 0.1822281033, 0.1889107376, 0.1791101843, 0.1560532600, -0.1155678406, 0.2295538932, -0.2645366490, -0.0288553704, 0.1027252972, 0.0766806602, -0.5448347330, 0.2031028718, 0.1589304954, 0.1322451383, 0.1189499870, -0.0683937520, -0.0873993114, -0.3051757514, -0.2355299890, 0.1586059481, 0.2018169016, 0.2676822543, -0.3225219846, -0.2611543834, 0.1922984123, 0.1632783115, 0.1297036558, 0.0983960181, -0.1507159024, 0.2256962359, -0.1040333956, -0.1514528394, 0.0985243544, 0.0819815546, -0.2940836251, 0.1586865336, 0.1468491107, 0.1485087872, 0.1639631987, -0.3239239752, -0.0767390430, -0.0434846729, -0.4023587406, -0.0052628326, 0.2273432612, 0.3005020022, -0.2598774135, -0.2188862711, 0.1678501070, 0.1352078766, 0.1002781317, 0.0754275694, -0.1502914876, 0.1930875033, -0.0709601715, -0.2219523191, 0.1243555173, 0.1257609427, -0.0574148744, 0.1152269915, 0.1307857931, 0.1599020809, 0.2068412602, -0.5553412437, -0.0536844917, 0.0758557543, -0.2106334567, -0.2509877980, 0.1757438034, 0.2637061775, -0.1759711355, -0.2431350052, 0.1071053818, 0.1259848624, 0.1004033238, 0.0856125653, -0.1173698306, 0.1213828772, -0.1768893301, -0.2070008069, 0.1709136516, 0.2089634240, 0.0153109450, 0.0967332721, 0.1268781722, 0.1706230640, 0.2291058898, -0.6386513710, -0.0536664203, 0.1378114969, 0.0360041447, -0.2989685237, -0.0084722806, 0.1872915775, -0.1523490399, -0.2111770809, -0.0390694551, 0.1366800815, 0.1302325875, 0.1356829405, -0.0982905105, -0.0127884001, -0.3586881459, -0.0259541404, 0.2114149332, 0.2843062580, -0.0324133746, 0.1084750593, 0.1447229236, 0.1862253845, 0.2259712219, -0.6329812407, -0.1173689738, 0.1914442331, 0.1654772907, -0.1376858056, -0.2194855511, 0.1176188141, -0.1529908478, -0.0606661662, -0.3384291232, 0.1524862647, 0.1777049750, 0.2218948901, -0.0923086405, -0.2855934799, -0.3215619624, 0.1726681292, 0.2303666323, 0.2964293361, -0.2508065701, 0.1479703039, 0.1753441393, 0.1917535067, 0.1919818372, -0.4562432170, -0.2350299209, 0.2257601619, 0.1863904297, 0.0388212129, -0.2966264784, 0.0806845874, -0.1992894858, 0.1068909168, -0.5761897564, 0.1624972969, 0.2155302167, 0.2905607820, -0.1168124676, -0.6870660186, 0.1488010883, 0.1881926507, 0.2230074406, 0.2438773215, -0.5771554708, 0.1980127096, 0.1924194694, 0.1714663208, 0.1415647417, -0.1263078004, -0.3408652246, 0.2292248607, 0.1707807332, 0.1269564927, -0.2634142637, 0.0773174241, ]; run_comparison( "T=8, N=4, C=6 (repeated labels)", 8, 4, 6, vec![1, 1, 2, 0, 2, 3, 2, 1, 5, 0, 0, 0, 1, 2, 3, 4], [4, 4], vec![8, 8, 8, 8], vec![3, 4, 1, 4], 0, &expected_losses, &expected_grad_flat, 1e-3, 1e-3, ); } #[test] fn test_ctc_loss_long_sequence() { // T=10, N=2, C=8 // Expected losses and gradient from PyTorch let expected_losses = [12.629399299621582, 12.298524856567383]; let expected_grad_flat = [ -0.2570972741, -0.6013792753, 0.1061997041, 0.1321590245, 0.1533492655, 0.1637226790, 0.1598964781, 0.1431493312, -0.2540431321, 0.1788398325, -0.4038805366, 0.1477340311, 0.1197479516, 0.0920107216, 0.0686140805, 0.0509770736, -0.1364373565, -0.3724762201, 0.1489177048, -0.0966964588, 0.1463697106, 0.1275274903, 0.1033692732, 0.0794258416, -0.1771971881, 0.2073454857, -0.3109439015, 0.1249521226, -0.0101635465, 0.0692621097, 0.0533472970, 0.0433975980, -0.1398337185, -0.0874802172, 0.1705365479, -0.2174201906, 0.1150254831, 0.0460043959, 0.0647982135, 0.0483694859, -0.2332949787, 0.1969220787, -0.1270586401, 0.1098557115, -0.1364655048, 0.0715296715, 0.0553609394, 0.0631506816, -0.2169117928, 0.0929956511, 0.1624538749, -0.2009791434, 0.0904926360, -0.0248185843, 0.0532633252, 0.0435040221, -0.2313277274, 0.1497355998, -0.0024202778, 0.1029939279, -0.2776987851, 0.0963881761, 0.0351882279, 0.1271408647, -0.2590557337, 0.1577988416, 0.1429322213, -0.1401246637, 0.0866033062, -0.1151762009, 0.0683368817, 0.0586853735, -0.1322475076, 0.0806737095, 0.0528722852, 0.0920089707, -0.3037962914, 0.1280544847, -0.1391123086, 0.2215466499, -0.1918463260, 0.1376975775, 0.1160097718, -0.0549413785, 0.0970225409, -0.2708687484, 0.1147320047, 0.0521945432, -0.0504456684, -0.0012221609, 0.0644332916, 0.0818370953, -0.1036835983, 0.1512031406, -0.4072600305, 0.2651379406, -0.0681083873, 0.0860663429, 0.0810486302, 0.0434282124, 0.1056238264, -0.2994530201, 0.1729898751, -0.1215954795, -0.0481944978, -0.1697723418, 0.0725984722, 0.0692019314, 0.0859903544, 0.1680216491, -0.4071443677, 0.2292988002, -0.0205532499, 0.0566616580, 0.0326749459, 0.0861379728, 0.1142501161, -0.0448331088, 0.2054910213, -0.4298293889, -0.0647637174, -0.4240962267, 0.1013666242, -0.0110451467, 0.1519176364, 0.1661346704, -0.0719586164, 0.1524447650, -0.0496110357, 0.0562372655, -0.1889088154, 0.1013496071, 0.1339637935, 0.1694275290, 0.2007708699, -0.4232292175, -0.0401752405, -0.2951072752, 0.1443216652, -0.2857291698, 0.1489982456, 0.1327733696, 0.1096193567, 0.0852990299, -0.0413062274, 0.0820900649, -0.7903561592, 0.1329460591, 0.1535883099, 0.1631743014, 0.1585651338, 0.1412984729, -0.1033771932, 0.1799504310, 0.1697744429, -0.5749052763, 0.1189445183, 0.0911802500, 0.0679325759, 0.0505003072, ]; run_comparison( "T=10, N=2, C=8", 10, 2, 8, vec![1, 3, 5, 7, 2, 2, 4, 6, 1, 3], [2, 5], vec![10, 10], vec![5, 5], 0, &expected_losses, &expected_grad_flat, 1e-3, 1e-3, ); } #[test] fn test_ctc_loss_mixed_input_lengths() { // T=12, N=3, C=5, input_lengths=[12, 7, 10] // Expected losses and gradient from PyTorch let expected_losses = [10.595505714416504, 6.8078508377075195, 7.705057144165039]; let expected_grad_flat = [ -0.4790987670, -0.2554937005, 0.1991624236, 0.2478453964, 0.2875846624, -0.3495813310, 0.2268397957, 0.2150714993, -0.2442178279, 0.1518878639, -0.2764556706, 0.2474014312, -0.2137086987, 0.1371368915, 0.1056260392, -0.2729502618, -0.3609606028, 0.2159237266, 0.2238420397, 0.1941450834, -0.2953839302, 0.1920599341, 0.1974952668, -0.2054278404, 0.1112565696, -0.1719199270, 0.2299505472, -0.2864859998, 0.1497263014, 0.0787290633, -0.2035763413, -0.3042884767, 0.2126964629, 0.1810975969, 0.1140707731, -0.2759391963, 0.0975771844, 0.1823379993, -0.1112988219, 0.1073228419, -0.1336459517, 0.1869296581, -0.1996247321, 0.1846873760, -0.0383463502, -0.2254105806, -0.1834360659, 0.1925925612, 0.1462381780, 0.0700158924, -0.2259973884, -0.0393539183, 0.1802661419, -0.0571591072, 0.1422442794, -0.0609069727, 0.1089282706, -0.0313654318, 0.2186669111, -0.2353227735, -0.2840364873, -0.0632198900, 0.1755636632, 0.1377806067, 0.0339120962, -0.1904856712, -0.2139032930, 0.1827126741, 0.0056131603, 0.2160631120, -0.0243270602, -0.0070458520, 0.1070247591, 0.2239368409, -0.2995886803, -0.2955487072, 0.0309870224, 0.1654911339, 0.1581364125, -0.0590658709, -0.2191396207, -0.3791662455, 0.1803640425, 0.1225430891, 0.2953987718, -0.0436352938, -0.1575258970, 0.1785279512, 0.1756918877, -0.1530586481, -0.1834939867, 0.0909025446, 0.1423641294, 0.1959712654, -0.2457439601, -0.3619639874, -0.3929221630, 0.1820438206, 0.2454170734, 0.3274252713, -0.0628800318, -0.2567180395, 0.2112283260, 0.0507859327, 0.0575838275, -0.0587697029, 0.1174769849, 0.0783569664, 0.2290501744, -0.3661144078, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, -0.0725664943, -0.1532069892, 0.2162397504, -0.1248963475, 0.1344300956, -0.0362483934, 0.1295878887, -0.0502482466, 0.2470482886, -0.2901395261, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, -0.1349253207, 0.0867646411, 0.1998746395, -0.2658679783, 0.1141540110, -0.0705668628, 0.1519546807, -0.2509805560, 0.2475892603, -0.0779965296, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, -0.2338010073, 0.2471641302, 0.1834627241, -0.3026831448, 0.1058573127, -0.1155209392, 0.1921830922, -0.4129956067, 0.2229512781, 0.1133821756, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, -0.2636392713, 0.2323469073, -0.2913427949, 0.1800564528, 0.1425786912, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, ]; run_comparison( "T=12, N=3, C=5 (mixed input lengths)", 12, 3, 5, vec![1, 4, 2, 0, 3, 1, 0, 0, 2, 4, 1, 3], [3, 4], vec![12, 7, 10], vec![3, 2, 4], 0, &expected_losses, &expected_grad_flat, 1e-3, 1e-3, ); } #[test] fn test_ctc_loss_sum_reduction() { // Same inputs as comparison_uniform_input_lengths, sum reduction let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); let logits = generate_logits(5, 3, 4, &device).require_grad(); let log_probs = log_softmax(logits.clone(), 2); let targets = Tensor::::from_data( TensorData::new(vec![1_i64, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]), &device, ); let il = Tensor::::from_data([5_i64, 5, 5], &device); let tl = Tensor::::from_data([2_i64, 1, 3], &device); let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Sum); let loss_data = loss.clone().into_data().to_vec::().unwrap(); let expected_sum = 11.2816486359_f32; // Expected value from PyTorch assert_approx_equal(&loss_data, &[expected_sum], 1e-3); let grads = loss.backward(); let logits_grad = logits.grad(&grads).unwrap(); let grad_data = logits_grad.into_data().to_vec::().unwrap(); // Expected gradient from PyTorch let expected_grad = [ -0.1679008007_f32, -0.4595540464, 0.2795598209, 0.3478950262, -0.3913056254, -0.0832268298, 0.2535884976, 0.2209439576, -0.0502742566, 0.2766197622, 0.2054125518, -0.4317580462, -0.0544800088, -0.3144550920, 0.0847885981, 0.2841464877, -0.1844545156, -0.2063435912, 0.2222184092, 0.1685796976, 0.0278018005, 0.2657383382, -0.0336986706, -0.2598414719, -0.0482986756, -0.0098767160, -0.1533526182, 0.2115280181, -0.1380317956, -0.2198686600, 0.2042596638, 0.1536407918, 0.0534787849, 0.1819230020, -0.2805589139, 0.0451571345, -0.0895631388, 0.1996460557, -0.2741115987, 0.1640286744, -0.2200077325, -0.1693530381, 0.2101601064, 0.1792006642, 0.0398471877, -0.1131042913, -0.2363226712, 0.3095797896, -0.2163617164, 0.2740726173, -0.2124865055, 0.1547756046, -0.4312027395, -0.0446923785, 0.2330704331, 0.2428246588, -0.0050083841, -0.6256869435, 0.2689785957, 0.3617166877, ]; assert_approx_equal(&grad_data, &expected_grad, 1e-3); } #[test] fn test_ctc_loss_mean_reduction() { let device = NdArrayDevice::Cpu; let ctc = CTCLossConfig::new().init(); let logits = generate_logits(5, 3, 4, &device).require_grad(); let log_probs = log_softmax(logits.clone(), 2); let targets = Tensor::::from_data( TensorData::new(vec![1_i64, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]), &device, ); let il = Tensor::::from_data([5_i64, 5, 5], &device); let tl = Tensor::::from_data([2_i64, 1, 3], &device); let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Mean); let loss_data = loss.clone().into_data().to_vec::().unwrap(); let expected_mean = 2.2260115147_f32; // Expected value from PyTorch assert_approx_equal(&loss_data, &[expected_mean], 1e-3); let grads = loss.backward(); let logits_grad = logits.grad(&grads).unwrap(); let grad_data = logits_grad.into_data().to_vec::().unwrap(); // Expected gradient from PyTorch let expected_grad = [ -0.0279834662_f32, -0.0765923411, 0.0465933047, 0.0579825081, -0.1304352134, -0.0277422778, 0.0845294967, 0.0736479908, -0.0055860290, 0.0307355281, 0.0228236169, -0.0479731150, -0.0090800021, -0.0524091832, 0.0141314333, 0.0473577492, -0.0614848398, -0.0687812045, 0.0740728080, 0.0561932363, 0.0030890885, 0.0295264814, -0.0037442972, -0.0288712755, -0.0080497796, -0.0016461194, -0.0255587716, 0.0352546684, -0.0460105985, -0.0732895583, 0.0680865571, 0.0512135960, 0.0059420872, 0.0202136654, -0.0311732125, 0.0050174589, -0.0149271907, 0.0332743451, -0.0456852652, 0.0273381118, -0.0733359158, -0.0564510152, 0.0700533763, 0.0597335547, 0.0044274656, -0.0125671430, -0.0262580756, 0.0343977548, -0.0360602848, 0.0456787720, -0.0354144201, 0.0257959347, -0.1437342465, -0.0148974592, 0.0776901469, 0.0809415579, -0.0005564869, -0.0695207715, 0.0298865121, 0.0401907414, ]; assert_approx_equal(&grad_data, &expected_grad, 1e-3); } } ================================================ FILE: crates/burn-nn/src/loss/huber.rs ================================================ use burn_core as burn; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::{config::Config, module::Module}; use super::Reduction; /// Configuration to create a [Huber loss](HuberLoss). #[derive(Config, Debug)] pub struct HuberLossConfig { /// The bound where the Huber loss function changes from quadratic to linear behaviour. pub delta: f32, } impl HuberLossConfig { /// Initialize [Huber loss](HuberLoss). pub fn init(&self) -> HuberLoss { self.assertions(); HuberLoss { delta: self.delta, lin_bias: self.delta * self.delta * 0.5, } } fn assertions(&self) { assert!( self.delta >= 0., // This also tests for normality "Delta for Huber loss must be a non-negative number." ); } } /// Calculate the Huber loss between the inputs and the target. /// /// The loss for each element of the residuals `r = targets - predictions` is given by /// /// ```text /// L(r) = 0.5 * r^2 if |r| <= d /// L(r) = 0.5 * d^2 + d * (|r| - d) if |r| > d /// ``` /// /// where `d` is the configured `delta`. In particular, this is equal to the /// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`, /// but behaves linearly instead of quadratically for large residuals. /// /// This loss function is less sensitive to outliers than the mean squared error loss. /// /// See also: #[derive(Module, Debug, Clone)] #[module(custom_display)] pub struct HuberLoss { /// The bound where the Huber loss function changes from quadratic to linear behaviour. pub delta: f32, /// Precomputed value for the linear bias. pub lin_bias: f32, // delta * delta * 0.5 precomputed } impl ModuleDisplay for HuberLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("delta", &self.delta) .add("lin_bias", &self.lin_bias) .optional() } } impl HuberLoss { /// Compute the loss element-wise for the predictions and targets, then reduce /// to a single loss value. /// /// `Reduction::Auto` behaves as `Reduction::Mean`. /// /// # Shapes /// /// - predictions: \[...dims\] /// - targets: \[...dims\] /// - output: \[1\] pub fn forward( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let loss = self.forward_no_reduction(predictions, targets); match reduction { Reduction::Mean | Reduction::Auto => loss.mean(), Reduction::Sum => loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Compute the loss element-wise for the predictions and targets. /// /// # Shapes /// /// - predictions: [...dims] /// - targets: [...dims] /// - output: [...dims] pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { let residuals = targets - predictions; self.forward_residuals(residuals) } /// Compute the loss element-wise for the given residuals. /// /// # Shapes /// /// - residuals: [...dims] /// - output: [...dims] pub fn forward_residuals( &self, residuals: Tensor, ) -> Tensor { let is_large = residuals.clone().abs().greater_elem(self.delta); // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the // `sign()` function, in general, suffers from a jump at 0. // Instead the following tensor implements `delta * sign(r)` for values outside // the bound: let softsign = residuals.clone().clamp(-self.delta, self.delta); // 0.5 * d^2 + d * (|r| - d) = // d * |r| - 0.5 * d^2 // Moreover |r| = sign(r) * r let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias); let inside = residuals.square().mul_scalar(0.5); inside.mask_where(is_large, outside) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; type TestTensor = Tensor; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_huber_loss() { let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]); let targets = TensorData::from([0., 0., 0., 0., 0.]); let device = Default::default(); let predict = TestTensor::<1>::from_data(predict, &device); let targets = TestTensor::<1>::from_data(targets, &device); let huber = HuberLossConfig::new(0.5).init(); let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum); let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto); let loss_no_reduction = huber.forward_no_reduction(predict, targets); let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]); loss_no_reduction .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([0.284]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([1.42]); loss_sum .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[cfg(feature = "std")] #[test] fn test_huber_ad_loss() { type TestAutodiffTensor = Tensor; let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]); let targets = TensorData::from([0., 0., 0., 0., 0.]); let device = Default::default(); let predict = TestAutodiffTensor::from_data(predict, &device).require_grad(); let targets = TestAutodiffTensor::from_data(targets, &device); let loss = HuberLossConfig::new(0.5).init(); let loss = loss.forward_no_reduction(predict.clone(), targets); let grads = loss.backward(); let grads_predict = predict.grad(&grads).unwrap(); let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]); grads_predict .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = HuberLossConfig::new(0.5); let loss = config.init(); assert_eq!( alloc::format!("{loss}"), "HuberLoss {delta: 0.5, lin_bias: 0.125}" ); } } ================================================ FILE: crates/burn-nn/src/loss/kldiv.rs ================================================ use burn_core as burn; use super::Reduction; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::{config::Config, module::Module}; /// Configuration to create a [KLDiv loss](KLDivLoss). #[derive(Config, Debug)] pub struct KLDivLossConfig { /// Specifies whether target is the log space. Default: False. #[config(default = false)] pub log_target: bool, } impl KLDivLossConfig { /// Initialize [KLDiv Loss](KLDivLoss). pub fn init(&self) -> KLDivLoss { KLDivLoss { log_target: self.log_target, } } } /// Kullback-Leibler Divergence Loss /// /// KL Divergence shows the difference between two probability distributions by measuring information loss /// /// KLDivLoss = /// ```tex /// y_{true} \cdot (\log{y_{true}} - \log{y_{pred}}) /// ``` /// By default, the loss expects the input in the log-space. /// The targets may also be provided in the log-space if `log_target` is true. /// /// See /// - [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence) #[derive(Module, Debug, Clone)] #[module(custom_display)] pub struct KLDivLoss { /// Specifies whether target is the log space. Default: False. pub log_target: bool, } impl ModuleDisplay for KLDivLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("log_target", &self.log_target).optional() } } impl KLDivLoss { /// Compute the criterion on the input tensor. /// /// `Reduction::Auto` behaves as `Reduction::BatchMean`,`Reduction::Mean` dose not align with the math definition. /// /// # Shapes /// /// - predictions: \[batch_size,num_targets\] /// - targets: \[batch_size,num_targets\] /// - output: \[1\] pub fn forward( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let loss = self.forward_no_reduction(predictions, targets); match reduction { Reduction::BatchMean | Reduction::Auto => { let batch_size = loss.dims()[0] as f32; loss.sum().div_scalar(batch_size) } Reduction::Mean => loss.mean(), Reduction::Sum => loss.sum(), } } /// Compute the criterion on the input tensor without reducing. pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { match self.log_target { true => targets.clone().exp().mul(targets.sub(predictions)), false => { let epsilon = 1e-8; let log_target = targets.clone().clamp(epsilon, 1.0).log(); targets.mul(log_target.sub(predictions)) } } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; type TestTensor = Tensor; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_kl_div_loss() { let predict = TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]); let targets = TensorData::from([[0.4, 0.6], [0.1, 0.9]]); let device = Default::default(); let predict = TestTensor::<2>::from_data(predict, &device); let targets = TestTensor::<2>::from_data(targets, &device); let kl_loss = KLDivLossConfig { log_target: false }.init(); let loss_sum = kl_loss.forward(predict.clone(), targets.clone(), Reduction::Sum); let loss_batch_mean = kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean); let loss_no_reduction = kl_loss.forward_no_reduction(predict, targets); let expected_no_reduction = TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]); loss_no_reduction .into_data() .assert_approx_eq::(&expected_no_reduction, Tolerance::absolute(1e-5)); let expected_sum = TensorData::from([0.08191]); loss_sum .into_data() .assert_approx_eq::(&expected_sum, Tolerance::absolute(1e-5)); let expected_batch_mean = TensorData::from([0.04095]); loss_batch_mean .into_data() .assert_approx_eq::(&expected_batch_mean, Tolerance::absolute(1e-5)); } #[test] fn test_kl_div_loss_log_target() { let device = Default::default(); let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device); let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device); let kl_loss = KLDivLossConfig { log_target: true }.init(); let loss_no_reduction = kl_loss.forward_no_reduction(predict.clone(), targets.clone()); let expected_none = TensorData::from([0.3032653299, 0.1115650801]); loss_no_reduction .into_data() .assert_approx_eq::(&expected_none, Tolerance::absolute(1e-5)); let loss_batch_mean = kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean); let expected_bm = TensorData::from([0.207415204965]); loss_batch_mean .into_data() .assert_approx_eq::(&expected_bm, Tolerance::absolute(1e-5)); let loss_sum = kl_loss.forward(predict, targets, Reduction::Sum); let expected_sum = TensorData::from([0.414830409931]); loss_sum .into_data() .assert_approx_eq::(&expected_sum, Tolerance::absolute(1e-5)); } #[cfg(feature = "std")] #[test] fn test_kl_div_ad_loss() { type TestAutodiffTensor = Tensor; let device = Default::default(); let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad(); let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device); let kl_loss = KLDivLossConfig { log_target: false }.init(); let loss = kl_loss.forward(predict.clone(), targets, Reduction::Sum); let grads = loss.backward(); let grads_predict = predict.grad(&grads).unwrap(); // d/d_pred [target * (log_target - pred)] = -target let expected = TensorData::from([[-0.4, -0.6]]); grads_predict .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = KLDivLossConfig { log_target: true }; let loss = config.init(); assert_eq!(alloc::format!("{loss}"), "KLDivLoss {log_target: true}"); } } ================================================ FILE: crates/burn-nn/src/loss/lp_loss.rs ================================================ use super::Reduction; use burn::config::Config; use burn::module::Module; use burn::tensor::{Tensor, backend::Backend}; use burn_core as burn; /// Configuration for the [Lp Loss](LpLoss) module. /// /// # Example /// /// ```ignore /// use burn_nn::loss::{LpLossConfig, Reduction}; /// /// // Create L1 loss (MAE when using mean reduction) /// let l1_loss = LpLossConfig::l1(); /// /// // Create L2 loss (MSE when using mean reduction) /// let l2_loss = LpLossConfig::l2(); /// /// // Create custom Lp loss with p=3 /// let l3_loss = LpLossConfig::new(3.0).init(); /// ``` #[derive(Config, Debug)] pub struct LpLossConfig { /// The exponent `p` determining the type of error measurement. /// /// Common values: /// - `p = 1.0`: L1 loss (MAE with mean reduction) - robust to outliers /// - `p = 2.0`: L2 loss (MSE with mean reduction) - standard choice, differentiable everywhere /// - `p > 2.0`: Increasingly sensitive to large errors (outliers) /// - `0 < p < 1`: More robust to outliers than L1 (quasi-norm) pub p: f64, } impl LpLossConfig { /// Initializes a [Lp Loss](LpLoss) module. /// /// # Panics /// /// Panics if `p <= 0`. pub fn init(&self) -> LpLoss { self.assertions(); LpLoss { p: self.p } } /// Creates L1 loss (p=1). /// /// When used with `Reduction::Mean`, this computes Mean Absolute Error (MAE). /// When used with `Reduction::Sum`, this computes Sum of Absolute Errors (SAE). pub fn l1() -> LpLoss { LpLoss { p: 1.0 } } /// Creates L2 loss (p=2). /// /// When used with `Reduction::Mean`, this computes Mean Squared Error (MSE). /// When used with `Reduction::Sum`, this computes Sum of Squared Errors (SSE). pub fn l2() -> LpLoss { LpLoss { p: 2.0 } } fn assertions(&self) { assert!(self.p > 0.0, "The order of the norm p must be positive.") } } /// Computes the Lp Loss between predictions and targets. /// /// This loss function computes the element-wise p-th power of absolute errors, /// then reduces them via mean or sum. /// /// # Mathematical Definition /// /// For predictions `ŷ` and targets `y`, the element-wise loss is: /// /// ```text /// Lᵢ = |ŷᵢ - yᵢ|ᵖ /// ``` /// /// With mean reduction (default), the final loss is: /// /// ```text /// L = (1/n) × Σᵢ |ŷᵢ - yᵢ|ᵖ /// ``` /// /// # Notes /// /// - This implementation computes `|error|^p`, **not** the Lp norm `(Σ|error|^p)^(1/p)`. /// - The `p = 1` case uses an optimized `abs()` operation. /// - The `p = 2` case uses an optimized computation `error * error` instead of `powf`. /// /// # Example /// /// ```ignore /// use burn_nn::loss::{LpLossConfig, Reduction}; /// use burn::tensor::Tensor; /// /// // Create L2 loss /// let l2_loss = LpLossConfig::l2(); /// /// let predictions: Tensor = /* model output */; /// let targets: Tensor = /* ground truth */; /// /// // Compute loss with mean reduction (MSE) /// let mse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Mean); /// /// // Compute loss with sum reduction (SSE) /// let sse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Sum); /// /// // Compute loss with no reduction /// let unreduced_l2_loss = l2_loss.forward_no_reduction(predictions, targets); /// ``` #[derive(Module, Clone, Debug)] pub struct LpLoss { /// The order of the norm (e.g., 1 for L1, 2 for L2). /// Equivalently, the exponent `p` for computing `|error|^p`. pub p: f64, } impl LpLoss { /// Computes the element-wise loss `|error|^p` with reduction. /// /// # Arguments /// /// * `predictions` - The model's predicted values. /// * `targets` - The ground truth target values. /// * `reduction` - Specifies how to reduce the element-wise losses: /// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses. /// - `Reduction::Sum`: Returns the sum of all element-wise losses. /// /// # Returns /// /// A scalar tensor containing the reduced loss value. /// /// # Shapes /// /// - predictions: `[...dims]` - Any shape /// - targets: `[...dims]` - Must match predictions shape /// - output: `[1]` - Scalar loss value pub fn forward( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let unreduced_loss = self.forward_no_reduction(predictions, targets); match reduction { Reduction::Mean | Reduction::Auto => unreduced_loss.mean(), Reduction::Sum => unreduced_loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Computes the element-wise loss `|error|^p` without reduction. /// /// # Arguments /// /// * `predictions` - The model's predicted values. /// * `targets` - The ground truth target values. /// /// # Returns /// /// A tensor of the same shape as the inputs, containing `|prediction - target|^p` /// for each element. /// /// # Shapes /// /// - predictions: `[...dims]` - Any shape /// - targets: `[...dims]` - Must match predictions shape /// - output: `[...dims]` - Same shape as inputs pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { let error = predictions.sub(targets); // Use simplified/optimized expressions for common cases (p = 1, p = 2) if self.p == 1.0 { // L1 loss error.abs() } else if self.p == 2.0 { // L2 loss error.clone().mul(error) } else { error.abs().powf_scalar(self.p) } } /// Computes the element-wise loss `|error|^p` with reduction over specified dimensions. /// /// Calculates element-wise `|predictions - targets|^p`, then takes the mean /// over the specified dimensions. Useful for per-sample or per-channel losses (e.g., when /// working with images). /// /// Dimensions can be provided in any order. They are sorted internally and /// reduced from highest to lowest to ensure indices remain valid. /// /// # Arguments /// /// * `predictions` - The model's predicted values. /// * `targets` - The ground truth target values. /// * `dims` - Dimensions to reduce over. /// /// # Returns /// /// A tensor with the specified dimensions reduced to size 1. /// /// # Example /// /// ```ignore /// // Image tensor: [batch, C, H, W] /// let l2_loss = LpLossConfig::l2(); /// /// // Per-image MSE for PSNR: reduce over C, H, W → [batch, 1, 1, 1] /// let mse_per_image = l2_loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]); /// ``` pub fn forward_reduce_dims( &self, predictions: Tensor, targets: Tensor, dims: &[usize], ) -> Tensor { let error = self.forward_no_reduction(predictions, targets); // Sort the dimensions to ascending order let mut sorted_dims = dims.to_vec(); sorted_dims.sort(); // Reduce over specified dimensions error.mean_dims(sorted_dims.as_slice()) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_lp_loss_l1_constructor() { let loss_func_l1 = LpLossConfig::l1(); let loss_func_p1 = LpLossConfig::new(1.0).init(); assert_eq!(loss_func_l1.p, 1.0); assert_eq!(loss_func_l1.p, loss_func_p1.p); } #[test] fn test_lp_loss_l2_constructor() { let loss_func_l2 = LpLossConfig::l2(); let loss_func_p2 = LpLossConfig::new(2.0).init(); assert_eq!(loss_func_l2.p, 2.0); assert_eq!(loss_func_l2.p, loss_func_p2.p); } #[test] fn test_lp_loss_l1() { let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 2.0]]), &device, ); let loss_func = LpLossConfig::l1(); let loss_no_reduction = loss_func.forward_no_reduction(predictions.clone(), targets.clone()); let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum); let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]); loss_no_reduction.into_data().assert_eq(&expected, false); let expected = TensorData::from([1.0]); loss_auto.into_data().assert_eq(&expected, false); let expected = TensorData::from([4.0]); loss_sum.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_l2() { let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 2.0]]), &device, ); let loss_func = LpLossConfig::l2(); let loss_no_reduction = loss_func.forward_no_reduction(predictions.clone(), targets.clone()); let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum); let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]); loss_no_reduction.into_data().assert_eq(&expected, false); let expected = TensorData::from([1.5]); loss_auto.into_data().assert_eq(&expected, false); let expected = TensorData::from([6.0]); loss_sum.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_p_half() { // L0.5 quasi-norm: more robust to outliers than L1 let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 0.0]]), &device, ); let loss_func = LpLossConfig::new(0.5).init(); let loss_no_reduction = loss_func.forward_no_reduction(predictions.clone(), targets.clone()); let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum); // |1-2|^0.5 = 1, |2-1|^0.5 = 1, |3-3|^0.5 = 0, |4-0|^0.5 = 2 let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]); loss_no_reduction.into_data().assert_eq(&expected, false); let expected = TensorData::from([1.0]); loss_auto.into_data().assert_eq(&expected, false); let expected = TensorData::from([4.0]); loss_sum.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_p3() { // L3 norm: more sensitive to outliers than L2 let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 2.0]]), &device, ); let loss_func = LpLossConfig::new(3.0).init(); let loss_no_reduction = loss_func.forward_no_reduction(predictions.clone(), targets.clone()); let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum); // |1-2|^3 = 1, |2-1|^3 = 1, |3-3|^3 = 0, |4-2|^3 = 8 let expected = TensorData::from([[1.0, 1.0], [0.0, 8.0]]); loss_no_reduction.into_data().assert_eq(&expected, false); let expected = TensorData::from([2.5]); loss_auto.into_data().assert_eq(&expected, false); let expected = TensorData::from([10.0]); loss_sum.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_zero_error() { // Test when predictions exactly match targets let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = predictions.clone(); let loss_func_l1 = LpLossConfig::l1(); let loss_func_l2 = LpLossConfig::l2(); let l1_loss = loss_func_l1.forward(predictions.clone(), targets.clone(), Reduction::Auto); let l2_loss = loss_func_l2.forward(predictions, targets, Reduction::Auto); let expected = TensorData::from([0.0]); l1_loss.into_data().assert_eq(&expected, false); l2_loss.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_negative_errors() { // Test that negative errors are handled correctly (absolute value) let device = Default::default(); let predictions = Tensor::::from_data(TensorData::from([1.0, 2.0, 3.0]), &device); let targets = Tensor::::from_data(TensorData::from([3.0, 4.0, 5.0]), &device); let loss_func_l1 = LpLossConfig::l1(); let loss_func_p1 = LpLossConfig::new(1.0).init(); let loss_no_reduction_l1 = loss_func_l1.forward_no_reduction(predictions.clone(), targets.clone()); let loss_no_reduction_p1 = loss_func_p1.forward_no_reduction(predictions, targets); // All errors are negative: 1-3=-2, 2-4=-2, 3-5=-2, but |error| = 2 let expected = TensorData::from([2.0, 2.0, 2.0]); loss_no_reduction_l1.into_data().assert_eq(&expected, false); loss_no_reduction_p1.into_data().assert_eq(&expected, false); } #[test] fn test_lp_loss_3d_tensor() { let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]), &device, ); let loss_func_l2 = LpLossConfig::l2(); let loss_func_p2 = LpLossConfig::new(2.0).init(); let loss_l2 = loss_func_l2.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_p2 = loss_func_p2.forward(predictions, targets, Reduction::Auto); // Errors: 1, 0, 0, -1, 1, 0, 0, -2 // Squared: 1, 0, 0, 1, 1, 0, 0, 4 // Mean: 7/8 = 0.875 let expected = TensorData::from([0.875]); loss_l2.into_data().assert_eq(&expected, false); loss_p2.into_data().assert_eq(&expected, false); } #[test] #[should_panic(expected = "The order of the norm p must be positive.")] fn test_lp_loss_negative_p_panics() { let _ = LpLossConfig::new(-1.0).init(); } #[test] #[should_panic(expected = "The order of the norm p must be positive.")] fn test_lp_loss_zero_p_panics() { let _ = LpLossConfig::new(0.0).init(); } #[test] fn test_lp_loss_fractional_p() { // Test p = 1.5 let device = Default::default(); let predictions = Tensor::::from_data(TensorData::from([0.0, 4.0]), &device); let targets = Tensor::::from_data(TensorData::from([1.0, 0.0]), &device); let loss_func = LpLossConfig::new(1.5).init(); let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets); // |0-1|^1.5 = 1, |4-0|^1.5 = 8 let expected = TensorData::from([1.0, 8.0]); loss_no_reduction.into_data().assert_eq(&expected, false); } #[test] fn test_forward_reduce_dims_single_dim() { let device = Default::default(); // Shape: [2, 3] let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]), &device, ); let loss_func_l2 = LpLossConfig::l2(); let loss_func_p2 = LpLossConfig::new(2.0).init(); // Reduce over dim 1 -> should give [2, 1] shape let loss_l2 = loss_func_l2.forward_reduce_dims(predictions.clone(), targets.clone(), &[1]); let loss_p2 = loss_func_p2.forward_reduce_dims(predictions, targets, &[1]); // Errors row 0: [1, 0, -3] -> squared: [1, 0, 9] -> mean: 10/3 // Errors row 1: [3, 0, 0] -> squared: [9, 0, 0] -> mean: 3 let expected = TensorData::from([[10.0 / 3.0], [3.0]]); loss_l2 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); loss_p2 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_first_dim() { let device = Default::default(); // Shape: [2, 3] let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]), &device, ); let loss_func = LpLossConfig::l2(); // Reduce over dim 0 -> should give [1, 3] shape let loss = loss_func.forward_reduce_dims(predictions, targets, &[0]); // Squared errors: [[1, 0, 9], [9, 0, 0]] // Mean over dim 0: [5, 0, 4.5] let expected = TensorData::from([[5.0, 0.0, 4.5]]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_multiple_dims() { let device = Default::default(); // Shape: [2, 2, 2] let predictions = Tensor::::from_data( TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]), &device, ); let loss_func = LpLossConfig::l2(); // Reduce over dims 1 and 2 -> should give [2, 1, 1] shape let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]); // Batch 0 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25 // Batch 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25 let expected = TensorData::from([[[1.25]], [[1.25]]]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_all_dims() { let device = Default::default(); // Shape: [2, 2] let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 2.0]]), &device, ); let loss_func = LpLossConfig::l2(); // Reduce over all dims -> should give [1, 1] shape let loss = loss_func.forward_reduce_dims(predictions, targets, &[0, 1]); // Errors: [[-1, 1], [0, 2]] -> squared: [[1, 1], [0, 4]] -> mean: 1.5 let expected = TensorData::from([[1.5]]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_image_batch() { // Simulate per-image loss for [batch, C, H, W] tensor (common use case for PSNR) let device = Default::default(); // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2) let predictions = Tensor::::from_data( TensorData::from([ [[[1.0, 2.0], [3.0, 4.0]]], // Image 1 [[[5.0, 6.0], [7.0, 8.0]]], // Image 2 ]), &device, ); let targets = Tensor::::from_data( TensorData::from([ [[[0.0, 2.0], [3.0, 6.0]]], // Target 1 [[[5.0, 5.0], [7.0, 7.0]]], // Target 2 ]), &device, ); let loss_func = LpLossConfig::l2(); // Reduce over C, H, W (dims 1, 2, 3) to get per-image MSE let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2, 3]); // Image 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 1.25 // Image 2 errors: [[0, 1], [0, 1]] -> squared: [[0, 1], [0, 1]] -> mean: 0.5 let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_with_p1() { let device = Default::default(); // Shape: [2, 3] let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]), &device, ); let loss_func = LpLossConfig::l1(); // Reduce over dim 1 -> should give [2, 1] shape let loss = loss_func.forward_reduce_dims(predictions, targets, &[1]); // Abs errors row 0: [1, 3, 0] -> mean: 4/3 // Abs errors row 1: [3, 0, 3] -> mean: 2 let expected = TensorData::from([[4.0 / 3.0], [2.0]]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_forward_reduce_dims_empty_dims() { // Reducing over no dimensions should return the unreduced loss let device = Default::default(); let predictions = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[0.0, 2.0], [3.0, 6.0]]), &device, ); let loss_func = LpLossConfig::l2(); let loss_reduce_dims = loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]); let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets); // Should be equivalent loss_reduce_dims .into_data() .assert_eq(&loss_no_reduction.into_data(), true); } #[test] fn test_forward_reduce_dims_zero_error() { let device = Default::default(); // Shape: [2, 2, 2] let predictions = Tensor::::from_data( TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), &device, ); let targets = predictions.clone(); let loss_func = LpLossConfig::l2(); let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]); // All zeros, reduced to shape: [2, 1, 1] let expected = TensorData::from([[[0.0]], [[0.0]]]); loss.into_data().assert_eq(&expected, false); } } ================================================ FILE: crates/burn-nn/src/loss/mod.rs ================================================ #[cfg(feature = "pretrained")] #[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))] mod pretrained; #[cfg(feature = "pretrained")] #[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))] pub use pretrained::*; mod binary_cross_entropy; mod cosine_embedding; mod cross_entropy; mod ctc; mod huber; mod kldiv; mod lp_loss; mod mse; mod poisson; mod reduction; mod rnnt; mod smooth_l1; pub use binary_cross_entropy::*; pub use cosine_embedding::*; pub use cross_entropy::*; pub use ctc::*; pub use huber::*; pub use kldiv::*; pub use lp_loss::*; pub use mse::*; pub use poisson::*; pub use reduction::*; pub use rnnt::*; pub use smooth_l1::*; ================================================ FILE: crates/burn-nn/src/loss/mse.rs ================================================ use burn_core as burn; use crate::loss::reduction::Reduction; use burn::module::Module; use burn::tensor::{Tensor, backend::Backend}; /// Calculate the mean squared error loss from the input logits and the targets. #[derive(Module, Clone, Debug)] pub struct MseLoss; impl Default for MseLoss { fn default() -> Self { Self::new() } } impl MseLoss { /// Create the criterion. pub fn new() -> Self { Self } /// Compute the criterion on the input tensor. /// /// # Shapes /// /// - logits: [batch_size, num_targets] /// - targets: [batch_size, num_targets] pub fn forward( &self, logits: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let tensor = self.forward_no_reduction(logits, targets); match reduction { Reduction::Mean | Reduction::Auto => tensor.mean(), Reduction::Sum => tensor.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Compute the criterion on the input tensor without reducing. pub fn forward_no_reduction( &self, logits: Tensor, targets: Tensor, ) -> Tensor { logits.sub(targets).square() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn test_mse_loss() { let device = Default::default(); let logits = Tensor::::from_data( TensorData::from([[1.0, 2.0], [3.0, 4.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[2.0, 1.0], [3.0, 2.0]]), &device, ); let mse = MseLoss::new(); let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone()); let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto); let loss_sum = mse.forward(logits, targets, Reduction::Sum); let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]); loss_no_reduction.into_data().assert_eq(&expected, false); let expected = TensorData::from([1.5]); loss.into_data().assert_eq(&expected, false); let expected = TensorData::from([6.0]); loss_sum.into_data().assert_eq(&expected, false); } #[test] fn display() { let loss = MseLoss::new(); assert_eq!(alloc::format!("{loss}"), "MseLoss"); } } ================================================ FILE: crates/burn-nn/src/loss/poisson.rs ================================================ use burn_core as burn; use core::f32::consts::PI; use burn::tensor::cast::ToElement; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::{config::Config, module::Module}; use super::Reduction; /// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. /// /// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss /// behavior, such as whether the input is in log-space, whether to include the Stirling /// approximation term, and a small epsilon value to avoid numerical instability. #[derive(Config, Debug)] pub struct PoissonNllLossConfig { /// If `true`, the predictions are expected to be in log-space. /// /// When `log_input` is `true`, the loss is computed as: /// ```text /// L(predictions, target) = exp(predictions) - target * predictions /// ``` /// When `log_input` is `false`, the loss is computed as: /// ```text /// L(predictions, target) = predictions - target * log(predictions + eps) /// ``` #[config(default = true)] pub log_input: bool, /// Whether to compute the full loss, including the Stirling approximation term. /// /// When `full` is `true`, the Stirling approximation term is added to the loss: /// ```text /// target * log(target) - target + 0.5 * log(2 * PI * target) /// ``` #[config(default = false)] pub full: bool, /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. /// /// This epsilon value is added to the predictions to ensure numerical stability /// when computing the logarithm. #[config(default = 1e-8)] pub eps: f64, } impl PoissonNllLossConfig { /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. /// /// # Panics /// - Panics if `eps` is not a positive number. pub fn init(&self) -> PoissonNllLoss { self.assertions(); PoissonNllLoss { log_input: self.log_input, full: self.full, eps: self.eps, } } /// Validates the configuration parameters. /// /// # Panics /// - Panics if `eps` is not a positive number. fn assertions(&self) { assert!( self.eps > 0., "eps for PoissonNllLoss must be a positive number." ); } } /// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target. /// /// This loss function is used when the target values are assumed to follow a Poisson distribution. /// The loss is defined as: /// ```text /// target ~ Poisson(input) /// L(predictions, target) = predictions - target * log(predictions) + log(target!) /// ``` /// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula. /// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss. /// /// For more details, see: /// #[derive(Module, Debug, Clone)] #[module(custom_display)] pub struct PoissonNllLoss { /// If `true`, the predictions are expected to be in log-space. pub log_input: bool, /// Whether to compute the full loss, including the Stirling approximation term. pub full: bool, /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. pub eps: f64, } impl ModuleDisplay for PoissonNllLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("log_input", &self.log_input) .add("full", &self.full) .add("eps", &self.eps) .optional() } } impl PoissonNllLoss { /// Computes the loss element-wise for the given predictions and targets, then reduces /// the result to a single loss value. /// /// # Arguments /// - `predictions`: The predicted values. /// - `targets`: The target values. /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`. /// /// # Shapes /// - `predictions`: `[...dims]` /// - `targets`: `[...dims]` /// - `output`: `[1]` /// /// # Panics /// - Panics if the shapes of `predictions` and `targets` do not match. /// - Panics if any target value is negative. /// - Panics if `log_input` is `false` and any prediction value is negative. pub fn forward( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let loss = self.forward_no_reduction(predictions, targets); match reduction { Reduction::Mean | Reduction::Auto => loss.mean(), Reduction::Sum => loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Computes the loss element-wise for the given predictions and targets without reduction. /// /// # Arguments /// - `predictions`: The predicted values. /// - `targets`: The target values. /// /// # Shapes /// - `predictions`: `[...dims]` /// - `targets`: `[...dims]` /// - `output`: `[...dims]` /// /// # Panics /// - Panics if the shapes of `predictions` and `targets` do not match. /// - Panics if any target value is negative. /// - Panics if `log_input` is `false` and any prediction value is negative. pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { self.assertions(&predictions, &targets); let mut loss; if self.log_input { loss = predictions.clone().exp() - targets.clone() * predictions; } else { loss = predictions.clone() - targets.clone() * (predictions + self.eps).log(); } if self.full { let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone() + (targets.clone() * 2. * PI).log() * 0.5; loss = loss + log_stirling_term .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like()); } loss } /// Validates the input tensors for the loss computation. /// /// # Panics /// - Panics if the shapes of `predictions` and `targets` do not match. /// - Panics if any target value is negative. /// - Panics if `log_input` is `false` and any prediction value is negative. fn assertions( &self, predictions: &Tensor, targets: &Tensor, ) { let predictions_dims = predictions.dims(); let targets_dims = targets.dims(); assert!( predictions_dims == targets_dims, "Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?})." ); assert!( targets .clone() .greater_equal_elem(0.) .all() .into_scalar() .to_bool(), "All the values of `targets` must be non-negative." ); if !self.log_input { assert!( predictions .clone() .greater_equal_elem(0.) .all() .into_scalar() .to_bool(), "When `log_input` is `false`, all the values of `predictions` must be non-negative." ); } } } #[cfg(test)] mod tests { #![allow(clippy::approx_constant)] use super::*; use crate::TestBackend; use burn::tensor::TensorData; type TestTensor = Tensor; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_poisson_nll_loss() { let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().init(); let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]); loss_no_reduction .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([21.0321]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([126.1929]); loss_sum .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_poisson_nll_loss_no_log_input() { let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]); let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]); loss_no_reduction .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_poisson_nll_loss_full() { let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().with_full(true).init(); let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]); loss_no_reduction .into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([21.9920]); loss.into_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([131.9518]); loss_sum .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[cfg(feature = "std")] #[test] fn test_poisson_nll_loss_gradients() { type TestAutodiffTensor = Tensor; let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); let device = Default::default(); let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad(); let predictions2 = predictions1.clone(); let targets = TestAutodiffTensor::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().with_full(false).init(); let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); let loss_full_sum = poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum); let grads = loss_sum.backward(); let grads_full = loss_full_sum.backward(); let grads_predictions1 = predictions1.grad(&grads).unwrap(); let grads_predictions2 = predictions2.grad(&grads_full).unwrap(); let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]); grads_predictions1 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); grads_predictions2 .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] #[should_panic = "eps for PoissonNllLoss must be a positive number."] fn test_negative_eps() { let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); } #[test] #[should_panic = "All the values of `targets` must be non-negative."] fn test_targets_with_negative_values() { let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]); let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().init(); let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); } #[test] #[should_panic = "Shape of targets"] fn test_shape_tensors() { let predictions = TensorData::from([0., 1., 2.]); let targets = TensorData::from([0., 1.]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().init(); let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); } #[test] #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."] fn test_exp_predictions_non_negative() { let predictions = TensorData::from([0.3, -0.1, 0.4]); let targets = TensorData::from([0., 1., 0.]); let device = Default::default(); let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); } #[test] fn display() { let config = PoissonNllLossConfig::new(); let loss = config.init(); assert_eq!( alloc::format!("{loss}"), "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" ); } } ================================================ FILE: crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs ================================================ use burn_core as burn; use super::vgg19::Vgg19; use super::weights::load_vgg19_weights; use crate::loss::Reduction; use burn::config::Config; use burn::module::Module; use burn::tensor::{Tensor, backend::Backend}; /// Configuration for the [Gram Matrix Loss](GramMatrixLoss) module. /// /// Gram Matrix Loss (often used in Neural Style Transfer) measures the difference in /// texture or style between two images. It does this by comparing the spatial correlations /// of their feature maps extracted from a pretrained VGG19 network. /// /// # Example /// /// ```rust,ignore /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig; /// /// // Create Gram Matrix Loss with equal weights for all 5 layers /// let device = Default::default(); /// let gram_loss = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]) /// .with_use_avg_pool(true) /// .init::(&device); /// ``` /// /// # Reference /// [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) #[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))] #[derive(Config, Debug)] pub struct GramMatrixLossConfig { /// The weights of the layer contributing to the total loss. /// Should have a length of 5 since Gram Matrix Loss uses 5 specific VGG19 layers. pub layer_weights: Vec, /// If true, uses average pooling in the VGG19 feature extractor. /// If false, uses the max pooling. #[config(default = "false")] pub use_avg_pool: bool, } impl GramMatrixLossConfig { /// Initializes a [Gram Matrix Loss](GramMatrixLoss) module. /// /// This will automatically download and load the pretrained VGG19 weights /// if they are not already cached locally. /// /// # Panics /// /// - If `layer_weights` does not contain exactly 5 elements. /// - If any of the weights in `layer_weights` is not non-negative. /// /// # Example /// /// ```rust,ignore /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig; /// /// // Create Gram Matrix Loss with equal weights for all 5 layers /// let device = Default::default(); /// let gram_loss = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]) /// .init::(&device); /// ``` pub fn init(&self, device: &B::Device) -> GramMatrixLoss { self.assertions(); let vgg19 = Vgg19::new(self.use_avg_pool, device); let pretrained_vgg19 = load_vgg19_weights(vgg19).no_grad(); GramMatrixLoss { layer_weights: self.layer_weights.clone(), feat_extractor: pretrained_vgg19, } } fn assertions(&self) { assert!( self.layer_weights.len() == 5, "The layer_weights vector must contain exactly 5 elements" ); assert!( self.layer_weights.iter().all(|&w| w >= 0.0), "All layer weights must be non-negative" ); } } /// Computes the Gram Matrix Loss between predictions and targets. /// /// This loss function extracts features from 5 specific layers of a pretrained VGG19 network /// (`conv1_1`, `conv2_1`, `conv3_1`, `conv4_1`, `conv5_1`). It computes the Gram matrix for each /// layer's feature map, which captures the style/texture information, and calculates the /// Mean Squared Error between the Gram matrices of the predictions and targets. /// /// # Note /// /// The Gram Matrix Loss assumes the input tensors are already in the \[0.0, 1.0\] range. /// /// # Example /// /// ```rust,ignore /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig; /// /// // Initialize the loss function via its config /// let device = Default::default(); /// // Uses max pool by default /// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::(&device); /// ``` /// /// # Reference /// [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) #[cfg_attr(docsrs, doc(cfg(feature = "pretrained")))] #[derive(Module, Debug)] pub struct GramMatrixLoss { /// The weights of the layer contributing to the total loss. /// Should have a length of 5 since Gram Matrix Loss uses 5 layers. pub layer_weights: Vec, /// Pretrained VGG19 feature extractor pub feat_extractor: Vgg19, } impl GramMatrixLoss { /// Computes the Gram Matrix Loss with reduction. /// /// # Arguments /// /// - `predictions` - The model's predicted images. The pixels should be in the \[0.0, 1.0\] range. /// - `targets` - The ground truth target images. The pixels should be in the \[0.0, 1.0\] range. /// - `reduction` - Specifies how to reduce the batch losses. /// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of batch losses. /// - `Reduction::Sum`: Returns the sum of batch losses. /// /// # Returns /// /// A scalar tensor containing the reduced loss value. /// /// # Shapes /// /// - predictions: `[batch_size, 3, height, width]` /// - targets: `[batch_size, 3, height, width]` /// - output: `[1]` /// /// # Panics /// /// - If the `reduction` type is not supported. /// - If the input tensors do not have exactly 3 channels. /// /// # Example /// /// ```ignore /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig; /// use burn::loss::Reduction; /// /// let device = Default::default(); /// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::(&device); /// /// let predictions = /* [N, 3, H, W] */; /// let targets = /* [N, 3, H, W] */; /// /// # Returns a tensor with shape [1] containing a single loss value /// let loss = loss_fn.forward(predictions, targets, Reduction::Mean); /// ``` pub fn forward( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let unreduced_loss = self.forward_no_reduction(predictions, targets); match reduction { Reduction::Mean | Reduction::Auto => unreduced_loss.mean(), Reduction::Sum => unreduced_loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Computes the unreduced Gram Matrix Loss per sample in the batch. /// /// # Arguments /// /// - `predictions` - The model's predicted images. The pixels should be in the \[0.0, 1.0\] range. /// - `targets` - The ground truth target images. The pixels should be in the \[0.0, 1.0\] range. /// /// # Returns /// /// A 1D tensor containing the total weighted loss for each sample in the batch. /// /// # Shapes /// /// - predictions: `[batch_size, 3, height, width]` /// - targets: `[batch_size, 3, height, width]` /// - output: `[batch_size]` /// /// # Panics /// /// - If the input tensors do not have exactly 3 channels. /// /// # Example /// /// ```rust,ignore /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig; /// /// let device = Default::default(); /// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::(&device); /// /// let predictions = /* [N, 3, H, W] */; /// let targets = /* [N, 3, H, W] */; /// /// // Returns a tensor of shape [N] containing the loss for each sample /// let unreduced_loss = loss_fn.forward_no_reduction(predictions, targets); /// ``` pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { let pred_processed = self.preprocess_input(predictions); let target_processed = self.preprocess_input(targets); // Both vectors contain 5 entries since there are 5 layers // Both feature map tensors already have the shape [N, C, H * W] let pred_features = self.feat_extractor.forward(pred_processed); let mut pred_normalization_factors = Vec::with_capacity(5); for feature_tensor in &pred_features { let [_, c, h_times_w] = feature_tensor.dims(); let (c_f, hw_f) = (c as f32, h_times_w as f32); pred_normalization_factors.push(4.0 * c_f * c_f * hw_f * hw_f); } let target_features = self.feat_extractor.forward(target_processed); // Create vector which will hold loss tensors for each layer let mut loss_tensors = Vec::with_capacity(pred_features.len()); // Compute and add the weighted loss for each layer to the final loss tensor. // Note that the loss tensor for each layer and the final loss tensors // contains a loss value for each sample in the batch. for (pred_f, target_f) in pred_features.into_iter().zip(target_features) { // Compute Gram matrix as G = F(F^T) // [N, C, H*W] times [N, H*W, C] equals [N, C, C] let pred_gram_matrices = pred_f.clone().matmul(pred_f.clone().transpose()); let target_gram_matrices = target_f.clone().matmul(target_f.clone().transpose()); let gram_matrices_diff = pred_gram_matrices - target_gram_matrices; let gram_matrices_diff_squared = gram_matrices_diff.powi_scalar(2); // For each sample, sum over all the entries of the gram matrix. // Equivalently, sum over the last two dimensions (the two C dimensions). let loss = gram_matrices_diff_squared .sum_dims(&[1, 2]) .squeeze_dims::<1>(&[1, 2]); loss_tensors.push(loss); } // Sum each layer's loss in the vector of loss tensors let scaled_loss_tensors: Vec> = loss_tensors .into_iter() .zip(pred_normalization_factors) .zip(self.layer_weights.clone()) .map(|((loss_tensor, norm_factor), weight)| { loss_tensor.div_scalar(norm_factor).mul_scalar(weight) }) .collect(); let stacked_loss_tensors = Tensor::stack::<2>(scaled_loss_tensors, 1); stacked_loss_tensors.sum_dim(1).squeeze_dim(1) } /// Applies standard ImageNet normalization to the input tensor for the VGG19 network. /// /// # Note /// /// This method assumes the input tensor is already in the \[0.0, 1.0\] range. /// /// # Panics /// /// - If the input tensor does not have exactly 3 channels. fn preprocess_input(&self, tensor: Tensor) -> Tensor { let device = &tensor.device(); let channels = tensor.dims()[1]; assert!( channels == 3, "Expected input tensor to have exactly 3 channels, but got {}", channels ); // ImageNet normalization constants let mean = Tensor::::from_floats([0.485, 0.456, 0.406], device).reshape([1, 3, 1, 1]); let std = Tensor::::from_floats([0.229, 0.224, 0.225], device).reshape([1, 3, 1, 1]); (tensor - mean) / std } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::Distribution; #[test] #[should_panic(expected = "The layer_weights vector must contain exactly 5 elements")] fn test_gram_matrix_loss_config_invalid_length() { let device = Default::default(); GramMatrixLossConfig::new(vec![1.0, 1.0]).init::(&device); } #[test] #[should_panic(expected = "All layer weights must be non-negative")] fn test_gram_matrix_loss_config_negative_weights() { let device = Default::default(); GramMatrixLossConfig::new(vec![1.0, -1.0, 1.0, 1.0, 1.0]).init::(&device); } #[test] fn test_gram_matrix_loss_config_valid_weights() { let device = Default::default(); let layer_weights = vec![0.0, 0.2, 0.2, 0.25, 0.4]; let loss_fn = GramMatrixLossConfig::new(layer_weights.clone()).init::(&device); assert_eq!( loss_fn.layer_weights, layer_weights, "Expected layer weights vector {:?}, got {:?}", loss_fn.layer_weights, layer_weights ); } #[test] #[should_panic(expected = "Expected input tensor to have exactly 3 channels, but got 1")] fn test_gram_matrix_loss_1_channel_panic() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; // 1 channel (Grayscale) should panic let tensor1: Tensor = Tensor::random([2, 1, 16, 16], Distribution::Default, &device); let tensor2 = tensor1.clone(); let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean); } #[test] #[should_panic(expected = "Expected input tensor to have exactly 3 channels, but got 4")] fn test_gram_matrix_loss_4_channel_panic() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; // 4 channels (e.g., RGBA) should panic let tensor1: Tensor = Tensor::random([2, 4, 16, 16], Distribution::Default, &device); let tensor2 = tensor1.clone(); let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean); } #[test] fn test_gram_matrix_loss_zero_for_identical_inputs() { let device = Default::default(); // Instantiate using Vgg19::new() to use random weights let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; let tensor1: Tensor = Tensor::random([2, 3, 16, 16], Distribution::Default, &device); let tensor2 = tensor1.clone(); let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean); let loss_val: f32 = loss.into_scalar(); // Loss should be exactly 0 (or extremely close due to floating point) when inputs are identical assert!( loss_val.abs() < 1e-4, "Loss should be zero for identical inputs" ); } #[test] fn test_gram_matrix_loss_greater_than_zero_for_different_inputs() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; let tensor1: Tensor = Tensor::ones([2, 3, 16, 16], &device); let tensor2: Tensor = Tensor::zeros([2, 3, 16, 16], &device); let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean); let loss_val: f32 = loss.into_scalar(); assert!( loss_val > 0.0, "Loss should be positive for different inputs" ); } #[test] fn test_gram_matrix_loss_forward_no_reduction_shape() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; let batch_size = 4; let tensor1: Tensor = Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device); let tensor2: Tensor = Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device); let unreduced_loss = loss_fn.forward_no_reduction(tensor1, tensor2); // Unreduced loss should return a 1D tensor with shape [batch_size] assert_eq!(unreduced_loss.dims(), [batch_size]); } #[test] fn test_gram_matrix_loss_reduction_sum_vs_mean() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::new(false, &device), }; let batch_size = 4; let tensor1: Tensor = Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device); let tensor2: Tensor = Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device); let loss_mean: f32 = loss_fn .forward(tensor1.clone(), tensor2.clone(), Reduction::Mean) .into_scalar(); let loss_sum: f32 = loss_fn .forward(tensor1, tensor2, Reduction::Sum) .into_scalar(); let expected_sum = loss_mean * (batch_size as f32); let diff = (loss_sum - expected_sum).abs(); // The sum reduction should be equal to the mean reduction multiplied by the batch size assert!( diff < 1e-4, "Sum reduction should equal batch_size * Mean reduction" ); } #[test] fn test_gram_matrix_loss_with_avg_pool() { let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], // Initialize with use_avg_pool = true feat_extractor: Vgg19::new(true, &device), }; let batch_size = 4; let tensor1: Tensor = Tensor::ones([batch_size, 3, 16, 16], &device); let tensor2: Tensor = Tensor::zeros([batch_size, 3, 16, 16], &device); let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean); let loss_val: f32 = loss.into_scalar(); assert!( loss_val > 0.0, "Loss should be positive for different inputs using avg pooling" ); } #[test] fn test_gram_matrix_loss_autodiff() { use crate::TestAutodiffBackend; let device = Default::default(); let loss_fn = GramMatrixLoss { layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0], feat_extractor: Vgg19::::new(false, &device).no_grad(), }; // The prediction tensor requires gradients let predictions: Tensor = Tensor::ones([2, 3, 16, 16], &device).require_grad(); // The target tensor does not require gradients let targets: Tensor = Tensor::zeros([2, 3, 16, 16], &device); let loss = loss_fn.forward(predictions.clone(), targets, Reduction::Mean); let grads = loss.backward(); // Verify that gradients were successfully computed for the predictions tensor let pred_grad = predictions.grad(&grads); assert!( pred_grad.is_some(), "Gradients should be computed for the predictions tensor" ); // Verify that VGG19 parameters do not have gradients let conv1_1_weight_grad = loss_fn.feat_extractor.conv1_1.weight.val().grad(&grads); assert!( conv1_1_weight_grad.is_none(), "Gradients should not be computed for VGG19 parameters" ); } #[test] #[cfg(feature = "test-local")] fn test_gram_matrix_loss_pretrained_weights_identical_inputs() { let device = Default::default(); let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::(&device); let tensor1: Tensor = Tensor::random([2, 3, 16, 16], Distribution::Default, &device); let tensor2 = tensor1.clone(); let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean); let loss_val: f32 = loss.into_scalar(); // Loss should be exactly 0 (or extremely close due to floating point) when inputs are identical assert!( loss_val.abs() < 1e-4, "Loss should be zero for identical inputs" ); } #[test] #[cfg(feature = "test-local")] fn test_gram_matrix_loss_pretrained_weights_different_inputs() { let device = Default::default(); let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::(&device); let tensor1: Tensor = Tensor::ones([2, 3, 16, 16], &device); let tensor2: Tensor = Tensor::zeros([2, 3, 16, 16], &device); let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean); let loss_val: f32 = loss.into_scalar(); assert!( loss_val > 0.0, "Loss should be positive for different inputs" ); } } ================================================ FILE: crates/burn-nn/src/loss/pretrained/gram_matrix/mod.rs ================================================ mod gram_matrix_loss; mod vgg19; mod weights; pub use gram_matrix_loss::*; ================================================ FILE: crates/burn-nn/src/loss/pretrained/gram_matrix/vgg19.rs ================================================ use burn_core as burn; use crate::PaddingConfig2d; use crate::conv::{Conv2d, Conv2dConfig}; use burn::module::Module; use burn::tensor::{ Tensor, activation::relu, backend::Backend, module::{avg_pool2d, max_pool2d}, }; /// VGG19 feature extractor for the Gram Matrix Loss. /// /// This module implements the VGG19 architecture up to the 5th convolutional block. /// It is specifically tailored for Neural Style Transfer and Gram Matrix Loss, /// extracting and flattening features from the following 5 layers: /// - `conv1_1` /// - `conv2_1` /// - `conv3_1` /// - `conv4_1` /// - `conv5_1` #[derive(Module, Debug)] pub struct Vgg19 { use_avg_pool: bool, // Block 1 // Field is made public for testing whether the weights are frozen or not pub conv1_1: Conv2d, conv1_2: Conv2d, // Block 2 conv2_1: Conv2d, conv2_2: Conv2d, // Block 3 conv3_1: Conv2d, conv3_2: Conv2d, conv3_3: Conv2d, conv3_4: Conv2d, // Block 4 conv4_1: Conv2d, conv4_2: Conv2d, conv4_3: Conv2d, conv4_4: Conv2d, // Block 5 conv5_1: Conv2d, } impl Vgg19 { /// Creates a new VGG19 feature extractor. /// /// The network is initialized with standard VGG19 configurations (3x3 kernels, /// stride 1, padding 1). Note that the weights are randomly initialized here so /// they should be overwritten by `load_vgg19_weights` before use. pub fn new(use_avg_pool: bool, device: &B::Device) -> Self { // All convolutions use a kernel size of 3 by 3, stride of 1, and // padding of 1. // This combination of kernel size and padding preserves input // dimensions. Thus, `PaddingConfig2d::Same` can be used instead. let conv_config = |in_ch, out_ch| { Conv2dConfig::new([in_ch, out_ch], [3, 3]) .with_stride([1, 1]) .with_padding(PaddingConfig2d::Same) .init(device) }; Self { use_avg_pool, // Block 1 conv1_1: conv_config(3, 64), conv1_2: conv_config(64, 64), // Block 2 conv2_1: conv_config(64, 128), conv2_2: conv_config(128, 128), // Block 3 conv3_1: conv_config(128, 256), conv3_2: conv_config(256, 256), conv3_3: conv_config(256, 256), conv3_4: conv_config(256, 256), // Block 4 conv4_1: conv_config(256, 512), conv4_2: conv_config(512, 512), conv4_3: conv_config(512, 512), conv4_4: conv_config(512, 512), // Block 5 conv5_1: conv_config(512, 512), } } /// Performs a forward pass to extract features for the Gram Matrix Loss. /// /// # Arguments /// /// - `x` - Input image tensor of shape `[batch_size, 3, height, width]`. /// /// # Returns /// /// A tuple containing: /// - `features`: A `Vec` of 5 tensors, each representing the flattened feature map /// from one of the target layers. Shape of each tensor: `[batch_size, channels, height * width]`. /// - `normalization_factors`: A `Vec` of 5 `f32` values, representing the normalization /// factor `4 * N^2 * M^2` for each layer, used to scale the Gram matrix loss. pub fn forward(&self, x: Tensor) -> Vec> { let pool_2d = |x| { if self.use_avg_pool { avg_pool2d(x, [2, 2], [2, 2], [0, 0], false, false) } else { max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], false) } }; let mut features = Vec::with_capacity(5); // Block 1 let x1_1 = relu(self.conv1_1.forward(x)); let flattened_x1_1 = x1_1.clone().flatten(2, 3); features.push(flattened_x1_1); let x1_2 = relu(self.conv1_2.forward(x1_1)); let x1 = pool_2d(x1_2); // Block 2 let x2_1 = relu(self.conv2_1.forward(x1)); let flattened_x2_1 = x2_1.clone().flatten(2, 3); features.push(flattened_x2_1); let x2_2 = relu(self.conv2_2.forward(x2_1)); let x2 = pool_2d(x2_2); // Block 3 let x3_1 = relu(self.conv3_1.forward(x2)); let flattened_x3_1 = x3_1.clone().flatten(2, 3); features.push(flattened_x3_1); let x3_2 = relu(self.conv3_2.forward(x3_1)); let x3_3 = relu(self.conv3_3.forward(x3_2)); let x3_4 = relu(self.conv3_4.forward(x3_3)); let x3 = pool_2d(x3_4); // Block 4 let x4_1 = relu(self.conv4_1.forward(x3)); let flattened_x4_1 = x4_1.clone().flatten(2, 3); features.push(flattened_x4_1); let x4_2 = relu(self.conv4_2.forward(x4_1)); let x4_3 = relu(self.conv4_3.forward(x4_2)); let x4_4 = relu(self.conv4_4.forward(x4_3)); let x4 = pool_2d(x4_4); // Block 5 let x5_1 = relu(self.conv5_1.forward(x4)); let flattened_x5_1 = x5_1.flatten(2, 3); features.push(flattened_x5_1); features } } ================================================ FILE: crates/burn-nn/src/loss/pretrained/gram_matrix/weights.rs ================================================ use burn_core as burn; use super::vgg19::Vgg19; use burn::tensor::backend::Backend; use burn_std::network::downloader::download_file_as_bytes; use burn_store::{ModuleSnapshot, PytorchStore}; use std::fs::{File, create_dir_all, rename}; use std::io::Write; use std::path::PathBuf; const VGG19_URL: &str = "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth"; /// Resolves and returns the local cache directory for the VGG19 weights. /// /// Creates the directory `~/.cache/burn-pretrained-models/loss/vgg19/` /// (or OS equivalent) if it does not already exist. fn get_cache_dir() -> PathBuf { let cache_dir = dirs::cache_dir() .expect("Failed to get cache directory for Gram Matrix Loss") .join("burn-pretrained-models") .join("loss") .join("vgg19"); if !cache_dir.exists() { create_dir_all(&cache_dir).expect("Failed to create cache directory for Gram Matrix Loss"); } cache_dir } /// Downloads the pretrained weights to the `cache_path` if they don't exist already. /// /// Requires an active internet connection on the first run. Subsequent runs will /// use the locally cached `.pth` file. fn download_weights_if_not_saved(cache_path: &PathBuf) { if !cache_path.exists() { let bytes = download_file_as_bytes( VGG19_URL, "Downloading VGG19 ImageNet weights for Gram Matrix Loss...", ); // Write to a temporary file. If writing gets completed, then rename to the actual/correct name. // If writing is not completed, the file with the correct name (i.e. `cache_path`) will not exist // so this code block can run again which is the desired behavior. let temp_path = cache_path.with_extension("pth.tmp"); let mut file = File::create(&temp_path) .expect("Failed to create VGG19 cache file for Gram Matrix Loss"); file.write_all(&bytes) .expect("Failed to write VGG19 weights to the cache file for Gram Matrix Loss"); rename(temp_path, cache_path) .expect("Failed to rename temporary file to the actual VGG19 cache file name for Gram Matrix Loss"); } } /// Loads ImageNet pretrained weights into the provided VGG19 feature extractor. /// /// This function downloads the official PyTorch VGG19 weights, remaps the keys /// from PyTorch's `features.X` format to Burn's `convX_Y` format, and loads /// them into the module. /// /// # Arguments /// /// - `vgg19` - An initialized VGG19 module with random weights. /// /// # Returns /// /// The VGG19 module with pretrained ImageNet weights loaded. pub fn load_vgg19_weights(mut vgg19: Vgg19) -> Vgg19 { let cache_dir = get_cache_dir(); let cache_path = cache_dir.join("vgg19.pth"); download_weights_if_not_saved(&cache_path); // Download the pretrained weights from PyTorch let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) // Block 1 .with_key_remapping(r"^features\.0\.", "conv1_1.") .with_key_remapping(r"^features\.2\.", "conv1_2.") // Block 2 .with_key_remapping(r"^features\.5\.", "conv2_1.") .with_key_remapping(r"^features\.7\.", "conv2_2.") // Block 3 .with_key_remapping(r"^features\.10\.", "conv3_1.") .with_key_remapping(r"^features\.12\.", "conv3_2.") .with_key_remapping(r"^features\.14\.", "conv3_3.") .with_key_remapping(r"^features\.16\.", "conv3_4.") // Block 4 .with_key_remapping(r"^features\.19\.", "conv4_1.") .with_key_remapping(r"^features\.21\.", "conv4_2.") .with_key_remapping(r"^features\.23\.", "conv4_3.") .with_key_remapping(r"^features\.25\.", "conv4_4.") // Block 5 .with_key_remapping(r"^features\.28\.", "conv5_1."); let result = vgg19.load_from(&mut store); if let Err(e) = result { eprintln!("Warning: Some VGG19 weights could not be loaded: {:?}", e); } vgg19 } ================================================ FILE: crates/burn-nn/src/loss/pretrained/mod.rs ================================================ mod gram_matrix; pub use gram_matrix::*; ================================================ FILE: crates/burn-nn/src/loss/reduction.rs ================================================ use burn_core as burn; use burn::config::Config; /// The reduction type for the loss. #[derive(Config, Debug)] pub enum Reduction { /// The mean of the losses will be returned. Mean, /// The sum of the losses will be returned. Sum, /// The sum of the losses divided by the batch_size will be returned. BatchMean, /// The mean of the losses will be returned. Auto, } ================================================ FILE: crates/burn-nn/src/loss/rnnt.rs ================================================ use super::Reduction; use burn::config::Config; use burn::module::Module; use burn::tensor::{Bool, Int, Tensor, backend::Backend, s}; use burn_core as burn; use core::f32; /// Configuration for [RNNTLoss](RNNTLoss). #[derive(Config, Debug)] pub struct RNNTLossConfig { /// Index of the blank label in the vocabulary. Default: `0`. #[config(default = 0)] pub blank: usize, /// Treat the inputs as logits, applying a log-softmax on the last dimension internally. /// If `false`, the input must already be log-probabilities. Default: `true`. #[config(default = true)] pub logits: bool, } impl RNNTLossConfig { /// Initializes a [RNNTLoss](RNNTLoss) module. pub fn init(&self) -> RNNTLoss { RNNTLoss { blank: self.blank, logits: self.logits, } } } /// RNN Transducer (RNNT) loss, as described in /// [Sequence Transduction with Recurrent Neural Networks](https://arxiv.org/abs/1211.3711). /// /// Computes the negative log-likelihood over a 2D lattice of encoder time steps (T) /// and output labels (U), marginalizing over all valid alignments. /// /// # Example /// /// ```rust,ignore /// let rnnt = RNNTLossConfig::new().init(); /// /// // logits: [B, T, U+1, V] from the joiner network /// let loss = rnnt.forward(logits, targets, logit_lengths, target_lengths); /// ``` #[derive(Module, Clone, Debug)] pub struct RNNTLoss { blank: usize, logits: bool, } impl RNNTLoss { /// Computes per-sample RNNT loss (no reduction). Returns shape `[B]`. /// /// - `logits`: `[B, T, U+1, V]` — joiner output (raw logits or log-probs) /// - `targets`: `[B, U]` — target label indices (must not contain blank) /// - `logit_lengths`: `[B]` — actual encoder lengths per sample /// - `target_lengths`: `[B]` — actual target lengths per sample pub fn forward( &self, logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor, ) -> Tensor { let device = logits.device(); let [b, max_t, max_up1, v] = logits.dims(); let max_u = max_up1 - 1; self.check_inputs(b, v, &targets, &logit_lengths, &target_lengths, max_u); let log_probs = if self.logits { let vocab_dim = 3; // last dim of [B, T, U+1, V] burn::tensor::activation::log_softmax(logits, vocab_dim) } else { logits }; let (lpb, lpl) = self.extract_log_probs(log_probs, targets); let u_mask = self.create_u_mask(&target_lengths, b, max_up1, &device); let neg_inf = Tensor::::full([b, max_up1], f32::NEG_INFINITY, &device); // Forward pass: compute log_alpha across the (T, U) lattice let mut alpha = self.init_alpha(&lpl, b, max_up1, &device); alpha = neg_inf.clone().mask_where(u_mask.clone(), alpha); let logit_lengths_exp = logit_lengths.clone().reshape([b, 1]).expand([b, max_up1]); for t in 1..max_t { let new = self.step_alpha(&alpha, &lpb, &lpl, t); let new = neg_inf.clone().mask_where(u_mask.clone(), new); // Only update alpha for samples where t < logit_lengths[b] let valid = logit_lengths_exp.clone().greater_elem(t as i64); alpha = alpha.mask_where(valid, new); } self.gather_loss(alpha, &lpb, logit_lengths, target_lengths, b, max_up1) } /// Computes RNNT loss with the given reduction. Returns shape `[1]`. pub fn forward_with_reduction( &self, logits: Tensor, targets: Tensor, logit_lengths: Tensor, target_lengths: Tensor, reduction: Reduction, ) -> Tensor { let loss = self.forward(logits, targets, logit_lengths, target_lengths); match reduction { Reduction::Auto | Reduction::Mean => loss.mean(), Reduction::Sum => loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Gathers `log_prob_blank[B, T, U+1]` and `log_prob_label[B, T, U]` from the full /// log-probability tensor by indexing into the vocab dimension. fn extract_log_probs( &self, log_probs: Tensor, targets: Tensor, ) -> (Tensor, Tensor) { let [b, max_t, max_up1, v] = log_probs.dims(); let max_u = max_up1 - 1; let device = log_probs.device(); let vocab_dim = 3; // Blank probabilities: gather blank index across vocab dim let blank_idx = Tensor::::full([b, max_t, max_up1, 1], self.blank as i64, &device); let lpb = log_probs .clone() .gather(vocab_dim, blank_idx) .squeeze_dim::<3>(vocab_dim); // Label probabilities: gather target labels across vocab dim (only first U positions) let tgt = targets .reshape([b, 1, max_u, 1]) .expand([b, max_t, max_u, 1]); let lpl = log_probs .slice(s![.., .., 0..max_u, 0..v]) .gather(vocab_dim, tgt) .squeeze_dim::<3>(vocab_dim); (lpb, lpl) } /// Sets up log_alpha at t=0: `alpha(0,0) = 0`, then cumsum of label probs along u. fn init_alpha( &self, lpl: &Tensor, b: usize, max_up1: usize, device: &B::Device, ) -> Tensor { let mut alpha = Tensor::::full([b, max_up1], f32::NEG_INFINITY, device); alpha = alpha.slice_assign(s![.., 0..1], Tensor::zeros([b, 1], device)); // Label probs at t=0 let lpl_0 = lpl.clone().slice(s![.., 0..1, ..]).squeeze_dim::<2>(1); for u in 1..max_up1 { let prev = alpha.clone().slice(s![.., (u - 1)..u]); let lp = lpl_0.clone().slice(s![.., (u - 1)..u]); alpha = alpha.slice_assign(s![.., u..(u + 1)], prev.add(lp)); } alpha } /// Boolean mask `[B, U+1]` that is true where `u <= target_lengths[b]`. fn create_u_mask( &self, target_lengths: &Tensor, b: usize, max_up1: usize, device: &B::Device, ) -> Tensor { let indices = Tensor::::arange(0..max_up1 as i64, device) .reshape([1, max_up1]) .expand([b, max_up1]); let lengths = target_lengths.clone().reshape([b, 1]).expand([b, max_up1]); indices.lower_equal(lengths) } /// One time step of the forward recurrence: /// /// alpha(t, u) = logaddexp( /// alpha(t-1, u) + blank(t-1, u), /// alpha(t, u-1) + label(t, u-1), /// ) fn step_alpha( &self, alpha: &Tensor, lpb: &Tensor, lpl: &Tensor, t: usize, ) -> Tensor { let [b, max_up1] = alpha.dims(); let device = alpha.device(); // Blank transition: alpha(t-1, :) + blank_prob(t-1, :) let blank_prob = lpb .clone() .slice(s![.., (t - 1)..t, ..]) .squeeze_dim::<2>(1); let from_blank = alpha.clone().add(blank_prob); let mut new = Tensor::::full([b, max_up1], f32::NEG_INFINITY, &device); new = new.slice_assign(s![.., 0..1], from_blank.clone().slice(s![.., 0..1])); // Label probs at time t let label_prob = lpl .clone() .slice(s![.., t..(t + 1), ..]) .squeeze_dim::<2>(1); for u in 1..max_up1 { let via_blank = from_blank.clone().slice(s![.., u..(u + 1)]); let via_label = new .clone() .slice(s![.., (u - 1)..u]) .add(label_prob.clone().slice(s![.., (u - 1)..u])); new = new.slice_assign(s![.., u..(u + 1)], self.log_sum_exp(via_blank, via_label)); } new } /// Extracts `-(alpha(T_b, U_b) + blank(T_b, U_b))` for each sample in the batch. fn gather_loss( &self, alpha: Tensor, lpb: &Tensor, logit_lengths: Tensor, target_lengths: Tensor, b: usize, max_up1: usize, ) -> Tensor { let t_idx = logit_lengths.sub_scalar(1); let u_idx = target_lengths; let alpha_tu = alpha .gather(1, u_idx.clone().reshape([b, 1])) .squeeze_dim::<1>(1); // Gather blank prob at (T_b, U_b) let t_exp = t_idx.reshape([b, 1, 1]).expand([b, 1, max_up1]); let lpb_t = lpb.clone().gather(1, t_exp).squeeze_dim::<2>(1); let lpb_tu = lpb_t.gather(1, u_idx.reshape([b, 1])).squeeze_dim::<1>(1); alpha_tu.add(lpb_tu).neg() } fn check_inputs( &self, b: usize, v: usize, targets: &Tensor, logit_lengths: &Tensor, target_lengths: &Tensor, max_u: usize, ) { assert!( self.blank < v, "blank index {} must be less than vocab_size {}", self.blank, v ); assert_eq!( targets.dims()[0], b, "targets batch dimension {} must equal batch_size {}", targets.dims()[0], b ); assert_eq!( targets.dims()[1], max_u, "targets length dimension {} must equal max_target_len (max_u) {}", targets.dims()[1], max_u ); assert_eq!( logit_lengths.dims()[0], b, "logit_lengths length {} must equal batch_size {}", logit_lengths.dims()[0], b ); assert_eq!( target_lengths.dims()[0], b, "target_lengths length {} must equal batch_size {}", target_lengths.dims()[0], b ); } /// Numerically stable `log(exp(a) + exp(b))`, handling `-inf` inputs. fn log_sum_exp( &self, a: Tensor, b: Tensor, ) -> Tensor { let a_inf = a.clone().equal_elem(f32::NEG_INFINITY); let b_inf = b.clone().equal_elem(f32::NEG_INFINITY); // Replace -inf with 0 to prevent NaN in the subtraction (masked out below) let a_safe = a.clone().mask_fill(a_inf.clone(), 0.0); let b_safe = b.clone().mask_fill(b_inf.clone(), 0.0); // log(exp(a) + exp(b)) = max(a,b) + log(1 + exp(-|a-b|)) let max = a_safe.clone().max_pair(b_safe.clone()); let result = max.add(a_safe.sub(b_safe).abs().neg().exp().add_scalar(1.0).log()); // If a=-inf, result is b; if b=-inf, result is a; if both -inf, stays -inf let result = result.mask_where(a_inf, b); result.mask_where(b_inf, a) } } #[cfg(test)] mod tests { use super::*; use burn::tensor::{TensorData, Tolerance}; use burn_ndarray::{NdArray, NdArrayDevice}; type B = NdArray; const NUM_LABELS: usize = 2; // vocab size for simple unit tests #[test] fn config_defaults() { let cfg = RNNTLossConfig::new(); assert_eq!(cfg.blank, 0); assert!(cfg.logits); } #[test] #[should_panic(expected = "blank index")] fn panics_on_invalid_blank() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().with_blank(5).init(); rnnt.forward( Tensor::::zeros([1, 2, 2, 3], &dev), Tensor::::from_data([[1_i64]], &dev), Tensor::::from_data([2], &dev), Tensor::::from_data([1], &dev), ); } #[test] #[should_panic(expected = "must equal batch_size")] fn panics_on_batch_mismatch() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); rnnt.forward( Tensor::::zeros([2, 3, 2, 3], &dev), Tensor::::from_data([[1_i64]], &dev), Tensor::::from_data([3, 3], &dev), Tensor::::from_data([1, 1], &dev), ); } #[test] #[should_panic(expected = "logit_lengths length")] fn panics_on_logit_lengths_mismatch() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); rnnt.forward( Tensor::::zeros([2, 3, 2, 3], &dev), Tensor::::from_data([[1_i64], [2]], &dev), Tensor::::from_data([3], &dev), Tensor::::from_data([1, 1], &dev), ); } #[test] #[should_panic(expected = "target_lengths length")] fn panics_on_target_lengths_mismatch() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); rnnt.forward( Tensor::::zeros([2, 3, 2, 3], &dev), Tensor::::from_data([[1_i64], [2]], &dev), Tensor::::from_data([3, 3], &dev), Tensor::::from_data([1], &dev), ); } #[test] fn single_token_uniform_probs() { // B=1, T=2, U=1, V=2, uniform probs: P(blank) = P(label) = 1/V // // Two alignment paths (label emitted at t=0 or t=1), each with T+U emissions: // total_prob = T * (1/V)^(T+1) = 2 * (1/2)^3 = 1/4 // loss = -ln(1/4) = 2*ln(2) let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().with_logits(false).init(); let time_steps = 2; let target_len = 1; let v = NUM_LABELS as f32; let log_uniform = (1.0 / v).ln(); let loss = rnnt.forward( Tensor::::full( [1, time_steps, target_len + 1, NUM_LABELS], log_uniform, &dev, ), Tensor::::from_data([[1_i64]], &dev), Tensor::::from_data([time_steps as i64], &dev), Tensor::::from_data([target_len as i64], &dev), ); // Each path: T-1 blanks + U labels + 1 final blank = T + U emissions let num_paths = time_steps as f32; let emissions_per_path = (time_steps + target_len) as f32; let total_prob = num_paths * v.powf(-emissions_per_path); let expected_loss = -total_prob.ln(); loss.into_data().assert_approx_eq::( &TensorData::from([expected_loss]), Tolerance::absolute(1e-4), ); } #[test] fn empty_target() { // B=1, T=3, U=0, V=2, uniform probs: only the all-blanks path exists. // // Single path with T emissions (T-1 blanks + 1 final blank, all at u=0): // total_prob = (1/V)^T = (1/2)^3 = 1/8 // loss = T*ln(V) = 3*ln(2) let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().with_logits(false).init(); let time_steps = 3; let target_len = 0; let v = NUM_LABELS as f32; let log_uniform = (1.0 / v).ln(); let loss = rnnt.forward( Tensor::::full([1, time_steps, 2, NUM_LABELS], log_uniform, &dev), Tensor::::from_data([[1_i64]], &dev), Tensor::::from_data([time_steps as i64], &dev), Tensor::::from_data([target_len as i64], &dev), ); // T + U = T emissions total for U=0 let expected_loss = -v.powf(-((time_steps + target_len) as f32)).ln(); loss.into_data().assert_approx_eq::( &TensorData::from([expected_loss]), Tolerance::absolute(1e-4), ); } #[test] fn logits_equivalence() { // Verify that logits=true (internal log_softmax on raw logits) // gives the same loss as logits=false with external log_softmax. let dev = NdArrayDevice::Cpu; let [bs, time_steps, up1, vocab] = [1, 2, 3, 4]; let num_elements = bs * time_steps * up1 * vocab; let target_len = up1 - 1; let data: Vec = (0..num_elements).map(|i| (i as f32 * 0.3).sin()).collect(); let logits = Tensor::::from_data( burn_core::tensor::TensorData::new(data, [bs, time_steps, up1, vocab]), &dev, ); let targets = Tensor::::from_data([[1_i64, 2]], &dev); let logit_lengths = Tensor::::from_data([time_steps as i64], &dev); let target_lengths = Tensor::::from_data([target_len as i64], &dev); let vocab_dim = 3; let fused = RNNTLossConfig::new().with_logits(true).init().forward( logits.clone(), targets.clone(), logit_lengths.clone(), target_lengths.clone(), ); let log_probs = burn::tensor::activation::log_softmax(logits, vocab_dim); let manual = RNNTLossConfig::new().with_logits(false).init().forward( log_probs, targets, logit_lengths, target_lengths, ); fused .into_data() .assert_approx_eq::(&manual.into_data(), Tolerance::absolute(1e-4)); } } /// Tests comparing forward loss and backward gradients against torchaudio.functional.rnnt_loss. /// /// Logits are generated deterministically via sin((b*11+t*7+u*13+v*3)*0.1) so the same /// values can be reproduced in a Python script for cross-checking. #[cfg(test)] #[allow(clippy::identity_op, clippy::too_many_arguments)] mod pytorch_comparison_tests { use super::*; use burn::tensor::{TensorData, Tolerance}; use burn_autodiff::Autodiff; use burn_ndarray::{NdArray, NdArrayDevice}; type B = Autodiff>; fn tol() -> Tolerance { Tolerance::absolute(1e-3) } /// Deterministic logits matching the Python reference generator. /// Uses coprime coefficients to avoid repeating patterns across dimensions. fn make_logits(bs: usize, t: usize, u: usize, v: usize, dev: &NdArrayDevice) -> Tensor { let mut data = Vec::with_capacity(bs * t * u * v); for bi in 0..bs { for ti in 0..t { for ui in 0..u { for vi in 0..v { let idx = bi * 11 + ti * 7 + ui * 13 + vi * 3; data.push((idx as f32 * 0.1).sin()); } } } } Tensor::from_data(TensorData::new(data, [bs, t, u, v]), dev) } /// Checks that gradients along the vocab dim sum to ~0 at every (b, t, u) position. /// This must hold because log_softmax is applied on the last dim, /// and the Jacobian of log_softmax has the property that each row sums to zero. fn check_vocab_grad_sums(grad: &[f32], bs: usize, t: usize, up1: usize, v: usize) { for bi in 0..bs { for ti in 0..t { for ui in 0..up1 { let base = ((bi * t + ti) * up1 + ui) * v; let sum: f32 = (0..v).map(|vi| grad[base + vi]).sum(); TensorData::from([sum]) .assert_approx_eq::(&TensorData::from([0.0f32]), tol()); } } } } /// Returns the V-sized gradient slice at position (b, t, u) in a flattened [B, T, U+1, V] grad. fn grad_at( grad: &[f32], b: usize, t: usize, u: usize, max_t: usize, up1: usize, v: usize, ) -> &[f32] { let base = ((b * max_t + t) * up1 + u) * v; &grad[base..base + v] } /// Asserts that a gradient slice at position (b, t, u) matches expected values. fn assert_grad( grad: &[f32], b: usize, t: usize, u: usize, max_t: usize, up1: usize, v: usize, expected: &[f32], ) { TensorData::from(grad_at(grad, b, t, u, max_t, up1, v)) .assert_approx_eq::(&TensorData::from(expected), tol()); } #[test] fn basic_b1() { // B=1, T=4, U+1=3, V=3, targets=[1,2] let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); let logits = make_logits(1, 4, 3, 3, &dev).require_grad(); let loss = rnnt.forward( logits.clone(), Tensor::::from_data([[1_i64, 2]], &dev), Tensor::::from_data([4_i64], &dev), Tensor::::from_data([2_i64], &dev), ); loss.clone() .into_data() .assert_approx_eq::(&TensorData::from([4.4491f32]), tol()); let grads = loss.sum().backward(); let grad = logits .grad(&grads) .unwrap() .into_data() .to_vec::() .unwrap(); // Spot-check first, middle, and last (t, u) positions against torchaudio assert_grad(&grad, 0, 0, 0, 4, 3, 3, &[-0.2041, -0.2246, 0.4287]); assert_grad(&grad, 0, 2, 0, 4, 3, 3, &[0.0079, -0.0640, 0.0561]); assert_grad(&grad, 0, 3, 2, 4, 3, 3, &[-0.6899, 0.3231, 0.3667]); check_vocab_grad_sums(&grad, 1, 4, 3, 3); } #[test] fn batched_b2() { // B=2, T=5, U+1=4, V=4, targets=[[1,2,3],[2,1,3]] let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); let logits = make_logits(2, 5, 4, 4, &dev).require_grad(); let loss = rnnt.forward( logits.clone(), Tensor::::from_data( TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]), &dev, ), Tensor::::from_data([5_i64, 5], &dev), Tensor::::from_data([3_i64, 3], &dev), ); loss.clone() .into_data() .assert_approx_eq::(&TensorData::from([7.9356f32, 7.2033]), tol()); let grads = loss.sum().backward(); let grad = logits .grad(&grads) .unwrap() .into_data() .to_vec::() .unwrap(); // Spot-check: first position of each sample, and last position assert_grad(&grad, 0, 0, 0, 5, 4, 4, &[-0.3161, -0.3113, 0.2796, 0.3479]); assert_grad(&grad, 1, 0, 0, 5, 4, 4, &[-0.2766, 0.2602, -0.2248, 0.2411]); assert_grad(&grad, 0, 4, 3, 5, 4, 4, &[-0.8216, 0.2296, 0.2786, 0.3133]); assert_grad(&grad, 1, 4, 3, 5, 4, 4, &[-0.7185, 0.2735, 0.2437, 0.2012]); check_vocab_grad_sums(&grad, 2, 5, 4, 4); } #[test] fn variable_lengths_b3() { // B=3, T=6, U+1=4, V=5 // logit_lengths=[6,4,5], target_lengths=[3,2,1] // Tests that masking works correctly for variable-length sequences. let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); let logits = make_logits(3, 6, 4, 5, &dev).require_grad(); let loss = rnnt.forward( logits.clone(), Tensor::::from_data( TensorData::new(vec![1_i64, 2, 3, 4, 1, 0, 2, 0, 0], [3, 3]), &dev, ), Tensor::::from_data([6_i64, 4, 5], &dev), Tensor::::from_data([3_i64, 2, 1], &dev), ); loss.clone() .into_data() .assert_approx_eq::(&TensorData::from([10.7458f32, 8.0196, 8.3316]), tol()); let grads = loss.sum().backward(); let grad = logits .grad(&grads) .unwrap() .into_data() .to_vec::() .unwrap(); let stride = 4 * 5; // U+1 * V per time step let zeros = vec![0.0f32; 5]; // Sample 0 (full length=6): spot-check first and last active positions assert_grad( &grad, 0, 0, 0, 6, 4, 5, &[-0.4232, -0.3114, 0.1992, 0.2478, 0.2876], ); assert_grad( &grad, 0, 5, 3, 6, 4, 5, &[-0.8016, 0.2170, 0.2172, 0.1991, 0.1683], ); // Sample 1 (logit_length=4): gradients beyond t=3 should be zero assert_grad( &grad, 1, 0, 0, 6, 4, 5, &[-0.2502, 0.2160, 0.2173, 0.2002, -0.3833], ); let sample1_t4_start = 1 * 6 * stride + 4 * stride; for i in 0..(2 * stride) { // t=4 and t=5 should all be zero assert!( grad[sample1_t4_start + i].abs() < 1e-3, "sample 1, t>=4: grad[{}] = {} (expected 0)", i, grad[sample1_t4_start + i] ); } // Sample 1 (target_length=2): u=3 positions should be zero within active time steps for ti in 0..4 { assert_grad(&grad, 1, ti, 3, 6, 4, 5, &zeros); } // Sample 2 (logit_length=5): t=5 should be zero let sample2_t5_start = 2 * 6 * stride + 5 * stride; for i in 0..stride { assert!( grad[sample2_t5_start + i].abs() < 1e-3, "sample 2, t=5: grad[{}] = {} (expected 0)", i, grad[sample2_t5_start + i] ); } check_vocab_grad_sums(&grad, 3, 6, 4, 5); } #[test] fn sum_reduction() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); let logits = make_logits(2, 5, 4, 4, &dev).require_grad(); let tgt = Tensor::::from_data( TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]), &dev, ); let il = Tensor::::from_data([5_i64, 5], &dev); let tl = Tensor::::from_data([3_i64, 3], &dev); let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Sum); // 7.9356 + 7.2033 = 15.1389 loss.clone() .into_data() .assert_approx_eq::(&TensorData::from([15.1389f32]), tol()); let grads = loss.backward(); let g = logits .grad(&grads) .unwrap() .into_data() .to_vec::() .unwrap(); TensorData::from(&g[..4]).assert_approx_eq::( &TensorData::from([-0.3161f32, -0.3113, 0.2796, 0.3479]), tol(), ); } #[test] fn mean_reduction() { let dev = NdArrayDevice::Cpu; let rnnt = RNNTLossConfig::new().init(); let logits = make_logits(2, 5, 4, 4, &dev).require_grad(); let tgt = Tensor::::from_data( TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]), &dev, ); let il = Tensor::::from_data([5_i64, 5], &dev); let tl = Tensor::::from_data([3_i64, 3], &dev); let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Mean); // 15.1389 / 2 = 7.5694 loss.clone() .into_data() .assert_approx_eq::(&TensorData::from([7.5694f32]), tol()); // Gradients should be half the sum-reduction gradients (mean over batch of 2) let grads = loss.backward(); let g = logits .grad(&grads) .unwrap() .into_data() .to_vec::() .unwrap(); TensorData::from(&g[..4]).assert_approx_eq::( &TensorData::from([-0.1581f32, -0.1557, 0.1398, 0.1739]), tol(), ); } } ================================================ FILE: crates/burn-nn/src/loss/smooth_l1.rs ================================================ use super::Reduction; use burn::config::Config; use burn::module::Module; use burn::tensor::{Tensor, backend::Backend}; use burn_core as burn; /// Configuration for the [SmoothL1Loss](SmoothL1Loss) module. /// /// Smooth L1 loss combines L1 and L2 loss, using L2 loss for small errors (below beta) /// and L1 loss for large errors (above beta). This makes it less sensitive to outliers /// than MSE while maintaining smooth gradients near zero. /// /// # Example /// /// ```ignore /// use burn_nn::loss::{SmoothL1LossConfig, Reduction}; /// /// // Create Smooth L1 loss with default beta=1.0 /// let smooth_l1 = SmoothL1LossConfig::new().init(); /// /// // Create with custom beta /// let smooth_l1_custom = SmoothL1LossConfig::new().with_beta(0.5).init(); /// ``` #[derive(Config, Debug)] pub struct SmoothL1LossConfig { /// Specifies the threshold at which to change between L1 and L2 loss. /// The value must be positive. Default: 1.0 #[config(default = 1.0)] pub beta: f32, } impl SmoothL1LossConfig { /// Initializes a [Smooth L1 Loss](SmoothL1Loss) module. /// /// # Panics /// /// Panics if `beta <= 0`. pub fn init(&self) -> SmoothL1Loss { self.assertions(); SmoothL1Loss { beta: self.beta } } fn assertions(&self) { assert!(self.beta > 0.0, "The parameter beta must be positive.") } } /// Computes the Smooth L1 Loss between predictions and targets. /// /// This loss function uses L2 loss for small errors (below beta) and L1 loss for /// large errors (above beta), providing robustness to outliers while maintaining /// smooth gradients near |x - y| = 0. /// /// # Mathematical Definition /// /// For predictions `x` and targets `y`, the element-wise loss is: /// /// - L_i = 0.5 * (x_i - y_i)² / beta , if |x_i - y_i| < beta /// - L_i = |x_i - y_i| - 0.5 * beta , otherwise /// /// # Notes /// /// Smooth L1 loss is closely related to HuberLoss since it is equivalent to HuberLoss /// scaled by `1/beta`: /// `SmoothL1(x, y, beta) = Huber(x, y, beta) / beta` /// /// This leads to the following differences: /// /// - As beta approaches 0, Smooth L1 loss converges to L1Loss, while HuberLoss converges to 0. /// When beta = 0, Smooth L1 loss is equivalent to L1 loss. Thus, the `beta` /// parameter in Burn must be positive. L1Loss should be used for beta = 0. /// - As beta approaches positive infinity, Smooth L1 loss converges to a constant 0 loss, while /// HuberLoss converges to L2Loss. /// /// # Example /// /// ```rust,ignore /// use burn_nn::loss::{SmoothL1LossConfig, Reduction}; /// use burn::tensor::Tensor; /// /// // Create Smooth L1 loss with the default beta=1.0 /// let smooth_l1 = SmoothL1LossConfig::new().init(); /// /// let predictions: Tensor = /* model output */; /// let targets: Tensor = /* ground truth */; /// /// // Compute element-wise loss without reduction /// let element_wise = smooth_l1.forward(predictions.clone(), targets.clone()); /// /// // Compute loss with mean reduction /// let loss = smooth_l1.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean); /// /// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1] /// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]); /// ``` #[derive(Module, Clone, Debug)] pub struct SmoothL1Loss { /// Specifies the threshold at which to change between L1 and L2 loss. /// The value must be positive. Default: 1.0 pub beta: f32, } impl SmoothL1Loss { /// Computes the element-wise smooth L1 loss without reduction. /// /// # Arguments /// /// - `predictions` - The model's predicted values. /// - `targets` - The ground truth target values. /// /// # Returns /// /// A tensor of the same shape as the inputs, containing the smooth L1 loss /// for each element. /// /// # Shapes /// /// - predictions: `[...dims]` - Any shape /// - targets: `[...dims]` - Must match predictions shape /// - output: `[...dims]` - Same shape as inputs pub fn forward( &self, predictions: Tensor, targets: Tensor, ) -> Tensor { let error = predictions.sub(targets); let abs_error = error.clone().abs(); // The L1 case: |error| - 0.5 * beta (when |error| >= beta) let l1_loss = abs_error.clone().sub_scalar(0.5 * self.beta); // The L2 case: 0.5 * (error)^2 / beta (when |error| < beta) let l2_loss = error.square().mul_scalar(0.5).div_scalar(self.beta); let l2_mask = abs_error.lower_elem(self.beta); l1_loss.mask_where(l2_mask, l2_loss) } /// Computes the smooth L1 loss with reduction. /// /// # Arguments /// /// - `predictions` - The model's predicted values. /// - `targets` - The ground truth target values. /// - `reduction` - Specifies how to reduce the element-wise losses: /// - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses. /// - `Reduction::Sum`: Returns the sum of all element-wise losses. /// /// # Returns /// /// A scalar tensor containing the reduced loss value. /// /// # Shapes /// /// - predictions: `[...dims]` - Any shape /// - targets: `[...dims]` - Must match predictions shape /// - output: `[1]` - Scalar loss value pub fn forward_with_reduction( &self, predictions: Tensor, targets: Tensor, reduction: Reduction, ) -> Tensor { let unreduced_loss = self.forward(predictions, targets); match reduction { Reduction::Mean | Reduction::Auto => unreduced_loss.mean(), Reduction::Sum => unreduced_loss.sum(), other => panic!("{other:?} reduction is not supported"), } } /// Computes the smooth L1 loss with reduction over specified dimensions. /// /// Calculates element-wise smooth L1 loss, then takes the mean /// over the specified dimensions. Useful for per-sample or per-channel losses. /// /// Dimensions can be provided in any order. They are sorted internally and /// reduced from highest to lowest to ensure indices remain valid. /// /// # Arguments /// /// - `predictions` - The model's predicted values. /// - `targets` - The ground truth target values. /// - `dims` - Dimensions to reduce over. /// /// # Returns /// /// A tensor with the specified dimensions reduced to size 1. /// /// # Example /// /// ```ignore /// // Consider image tensor with shape [batch, C, H, W] /// let smooth_l1 = SmoothL1LossConfig::new().init(); /// /// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1] /// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]); /// ``` pub fn forward_reduce_dims( &self, predictions: Tensor, targets: Tensor, dims: &[usize], ) -> Tensor { let error = self.forward(predictions, targets); // Sort the dimensions to ascending order let mut sorted_dims = dims.to_vec(); sorted_dims.sort(); // Reduce over specified dimensions error.mean_dims(sorted_dims.as_slice()) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; // ========================================================================= // Configuration Tests // ========================================================================= #[test] fn test_smooth_l1_config_default_beta() { let loss = SmoothL1LossConfig::new().init(); assert_eq!(loss.beta, 1.0); } #[test] fn test_smooth_l1_config_custom_beta() { let loss = SmoothL1LossConfig::new().with_beta(2.5).init(); assert_eq!(loss.beta, 2.5); } #[test] #[should_panic(expected = "The parameter beta must be positive")] fn test_smooth_l1_config_beta_zero_panics() { SmoothL1LossConfig::new().with_beta(0.0).init(); } #[test] #[should_panic(expected = "The parameter beta must be positive")] fn test_smooth_l1_config_beta_negative_panics() { SmoothL1LossConfig::new().with_beta(-1.0).init(); } // ========================================================================= // Forward Pass (Element-wise) Tests // ========================================================================= #[test] fn test_smooth_l1_forward_l2_region() { // Beta = 1.0, errors = 0.0 and 0.5 (both < beta, use L2 formula) // L2 formula: 0.5 * error^2 / beta // error = 0.0 -> loss = 0.5 * 0.0 / 1.0 = 0.0 // error = 0.5 -> loss = 0.5 * 0.25 / 1.0 = 0.125 let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([[0.0_f32, 0.5]]), &device); let targets = Tensor::::from_data(TensorData::from([[0.0_f32, 0.0]]), &device); let output = loss.forward(predictions, targets); let expected = TensorData::from([[0.0_f32, 0.125]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_forward_l1_region() { // Beta = 1.0, errors = 0.0 and 2.0 (2.0 >= beta, use L1 formula) // L1 formula: |error| - 0.5 * beta // L2 formula: 0.5 * (error)^2 / beta // error = 0.0 -> loss = 0.0 // error = 2.0 -> loss = 2.0 - 0.5 = 1.5 let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([[0.0_f32, 2.0]]), &device); let targets = Tensor::::from_data(TensorData::from([[0.0_f32, 0.0]]), &device); let output = loss.forward(predictions, targets); let expected = TensorData::from([[0.0_f32, 1.5]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_forward_zero_error() { let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([[1.0_f32, 2.0, 3.0]]), &device); let targets = predictions.clone(); let output = loss.forward(predictions, targets); let expected = TensorData::from([[0.0_f32, 0.0, 0.0]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_forward_negative_errors() { // Ensure absolute value is used correctly // L1 formula: |error| - 0.5 * beta // L2 formula: 0.5 * (error)^2 / beta // Beta = 1.0, error = -3.0 (L1: 3.0 - 0.5 = 2.5) let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([-3.0_f32]), &device); let targets = Tensor::::zeros([1], &device); let output = loss.forward(predictions, targets); let expected = TensorData::from([2.5_f32]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_forward_mixed_regions() { // Test with errors in both L1 and L2 regions // Beta = 1.0 // L1 formula: |error| - 0.5 * beta // L2 formula: 0.5 * (error)^2 / beta // error = 0.5 -> L2: 0.5 * 0.25 / 1 = 0.125 // error = 1.5 -> L1: 1.5 - 0.5 = 1.0 // error = 3.0 -> L1: 3.0 - 0.5 = 2.5 let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([0.5_f32, 1.5, 3.0]), &device); let targets = Tensor::::zeros([3], &device); let output = loss.forward(predictions, targets); let expected = TensorData::from([0.125_f32, 1.0, 2.5]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_custom_beta_values() { // Test with beta = 0.5 // error = 0.25 (< beta): L2 = 0.5 * 0.0625 / 0.5 = 0.0625 // error = 1.0 (>= beta): L1 = 1.0 - 0.25 = 0.75 let device = Default::default(); let loss = SmoothL1LossConfig::new().with_beta(0.5).init(); let predictions = Tensor::::from_data(TensorData::from([0.25_f32, 1.0]), &device); let targets = Tensor::::zeros([2], &device); let output = loss.forward(predictions, targets); let expected = TensorData::from([0.0625_f32, 0.75]); output.into_data().assert_eq(&expected, false); } // ========================================================================= // forward_with_reduction Tests // ========================================================================= #[test] fn test_smooth_l1_reduction_mean() { // Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5) // Mean: (0.125 + 1.5) / 2 = 0.8125 let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([[0.5_f32, 2.0]]), &device); let targets = Tensor::::from_data(TensorData::from([[0.0_f32, 0.0]]), &device); let output = loss.forward_with_reduction(predictions, targets, Reduction::Mean); let expected = TensorData::from([0.8125_f32]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_reduction_sum() { // Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5) // Sum: 1.625 let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([[0.5_f32, 2.0]]), &device); let targets = Tensor::::from_data(TensorData::from([[0.0_f32, 0.0]]), &device); let output = loss.forward_with_reduction(predictions, targets, Reduction::Sum); let expected = TensorData::from([1.625_f32]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_reduction_auto_equals_mean() { let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data(TensorData::from([2.0_f32]), &device); let targets = Tensor::::zeros([1], &device); let mean_out = loss.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean); let auto_out = loss.forward_with_reduction(predictions, targets, Reduction::Auto); mean_out.into_data().assert_eq(&auto_out.into_data(), false); } // ========================================================================= // Dimension Reduction Tests // ========================================================================= #[test] fn test_smooth_l1_forward_reduce_dims_single_dim() { // Beta = 2.0 // L1 formula: |error| - 0.5 * beta // L2 formula: 0.5 * (error)^2 / beta // Row 0: errors [0.0, 1.0, 4.0] // error = 0.0 -> L2: 0.0 // error = 1.0 -> L2: 0.5 * 1.0 / 2.0 = 0.25 // error = 4.0 -> L1: 4.0 - 1.0 = 3.0 // Mean = 3.25 / 3 = 1.083333... // Row 1: errors [0.0, 0.0, 0.0] -> Mean = 0.0 let device = Default::default(); let loss = SmoothL1LossConfig::new().with_beta(2.0).init(); let predictions = Tensor::::from_data( TensorData::from([[0.0_f32, 1.0, 4.0], [5.0_f32, 5.0, 5.0]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[0.0_f32, 0.0, 0.0], [5.0_f32, 5.0, 5.0]]), &device, ); let output = loss.forward_reduce_dims(predictions, targets, &[1]); let expected = TensorData::from([[3.25_f32 / 3.0], [0.0]]); // 3.25/3 = 1.0833... output .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_smooth_l1_forward_reduce_dims_image_batch() { // Simulate per-image Smooth L1 loss for [batch, C, H, W] tensor // (common in object detection like Fast R-CNN) let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); // beta = 1.0 // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2) let predictions = Tensor::::from_data( TensorData::from([ [[[0.5_f32, 2.0], [0.0, 3.0]]], // Image 1 [[[1.0_f32, 0.0], [0.5, 1.5]]], // Image 2 ]), &device, ); let targets = Tensor::::zeros([2, 1, 2, 2], &device); // Reduce over C, H, W (dims 1, 2, 3) to get per-image loss let output = loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]); // Image 1: losses [[0.125, 1.5], [0.0, 2.5]] -> mean: 4.125 / 4 = 1.03125 // Image 2: losses [[0.5, 0.0], [0.125, 1.0]] -> mean: 1.625 / 4 = 0.40625 let expected = TensorData::from([[[[1.03125_f32]]], [[[0.40625_f32]]]]); output.into_data().assert_eq(&expected, false); } #[test] fn test_smooth_l1_forward_reduce_dims_unsorted() { // Test that unsorted dimensions are handled correctly (sorted internally) let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data( TensorData::from([[[1.0_f32, 2.0], [3.0, 4.0]], [[5.0_f32, 6.0], [7.0, 8.0]]]), &device, ); let targets = Tensor::::zeros([2, 2, 2], &device); // Pass dims in reverse order let output = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[2, 1]); let expected_output = loss.forward_reduce_dims(predictions, targets, &[1, 2]); output .into_data() .assert_eq(&expected_output.into_data(), false); } #[test] fn test_smooth_l1_forward_reduce_dims_empty_dims() { // Reducing over no dimensions should return the unreduced loss let device = Default::default(); let loss = SmoothL1LossConfig::new().init(); let predictions = Tensor::::from_data( TensorData::from([[0.5_f32, 2.0], [0.0, 3.0]]), &device, ); let targets = Tensor::::zeros([2, 2], &device); let loss_reduce_dims = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[]); let loss_no_reduction = loss.forward(predictions, targets); loss_reduce_dims .into_data() .assert_eq(&loss_no_reduction.into_data(), false); } } ================================================ FILE: crates/burn-nn/src/modules/attention/cross_attention.rs ================================================ //! Cross-Attention Module for Burn //! //! Features: //! - Asymmetric Input Shapes (Query vs Context) //! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support //! - Quantization-Safe Masking (min_float) //! - Sparse-Ready (quiet_softmax) //! - KV Caching for Streaming Inference use crate::cache::TensorCache; use crate::modules::{Linear, LinearConfig}; use crate::{Dropout, DropoutConfig}; use burn_core as burn; use burn::{ config::Config, module::{Initializer, Module}, tensor::{ Bool, Tensor, activation::{quiet_softmax, softmax}, backend::Backend, }, }; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; #[derive(Config, Debug)] /// Configuration to create a [CrossAttention](CrossAttention) layer using the [init function](CrossAttentionConfig::init). pub struct CrossAttentionConfig { /// Dimension of the Query (e.g., Decoder state). pub d_model: usize, /// Dimension of the Context (e.g., Encoder audio embeddings). pub d_context: usize, /// Number of heads for the Query. pub n_heads: usize, /// Number of heads for Key/Value (Set to 1 for MQA, set to n_heads for MHA). pub n_heads_kv: usize, /// Dimension of a single head. pub d_head: usize, /// Dropout rate. #[config(default = 0.1)] pub dropout: f64, /// Masking value. Use -1.0e4 for f16/bf16 safety. #[config(default = -1.0e4)] pub min_float: f64, /// Use quiet_softmax to allow zero-attention (good for sparse/quantized models). #[config(default = false)] pub quiet_softmax: bool, } #[derive(Module, Debug)] /// The Cross attention module /// /// # Params /// /// - `query`: [`Linear`] layer with `d_model` input and output features. /// - `key`: [`Linear`] layer with `d_model` input and output features. /// - `value`: [`Linear`] layer with `d_model` input and output features. /// - `output`: [`Linear`] layer with `d_model` input and output features. /// /// Should be created with [CrossAttentionConfig]. pub struct CrossAttention { query: Linear, key: Linear, value: Linear, output: Linear, dropout: Dropout, n_heads: usize, n_heads_kv: usize, d_head: usize, scale: f64, min_float: f64, quiet_softmax: bool, } /// Cache for the [Cross Attention](CrossAttention) layer. /// /// To be used during inference when context is constant. pub struct CrossAttentionCache { /// Cached key tensor. pub k: TensorCache, /// Cached value tensor. pub v: TensorCache, } impl CrossAttentionCache { /// Create a new empty cache. pub fn new() -> Self { Self { k: TensorCache::empty(), v: TensorCache::empty(), } } } impl Default for CrossAttentionCache { fn default() -> Self { Self::new() } } impl CrossAttentionConfig { /// Initializes a new cross-attention module. /// /// # Arguments /// /// * `device` - The device on which to initialize the module. /// /// # Returns /// /// A new [CrossAttention] module. pub fn init(&self, device: &B::Device) -> CrossAttention { // Safety Rail for GQA assert_eq!( self.n_heads % self.n_heads_kv, 0, "Query heads must be divisible by KV heads" ); let init_linear = |in_dim, out_dim| { LinearConfig::new(in_dim, out_dim) .with_initializer(Initializer::KaimingUniform { gain: 1.0 / (self.d_head as f64).sqrt(), fan_out_only: false, }) .init(device) }; CrossAttention { // ADVICE: Asymmetric Projections query: init_linear(self.d_model, self.n_heads * self.d_head), key: init_linear(self.d_context, self.n_heads_kv * self.d_head), value: init_linear(self.d_context, self.n_heads_kv * self.d_head), output: init_linear(self.n_heads * self.d_head, self.d_model), dropout: DropoutConfig::new(self.dropout).init(), n_heads: self.n_heads, n_heads_kv: self.n_heads_kv, d_head: self.d_head, scale: (self.d_head as f64).sqrt().recip(), min_float: self.min_float, quiet_softmax: self.quiet_softmax, } } } impl CrossAttention { /// Applies cross-attention to query using context as key and value. /// /// # Arguments /// /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`. /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`. /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask. /// /// # Returns /// /// Output tensor of shape `[batch, seq_len_query, d_model]`. pub fn forward( &self, query: Tensor, context: Tensor, mask: Option>, ) -> Tensor { let [batch, l_q, _] = query.dims(); let [_, l_k, _] = context.dims(); // 1. Projections let q = self.query.forward(query); let k = self.key.forward(context.clone()); let v = self.value.forward(context); // 2. Reshape Heads // Q: [Batch, Heads, L_q, D_head] let q = q .reshape([batch, l_q, self.n_heads, self.d_head]) .swap_dims(1, 2); // K, V: [Batch, Heads_KV, L_k, D_head] let k = k .reshape([batch, l_k, self.n_heads_kv, self.d_head]) .swap_dims(1, 2); let v = v .reshape([batch, l_k, self.n_heads_kv, self.d_head]) .swap_dims(1, 2); // 3. GQA Expansion // ADVICE: Handle GQA by repeating KV heads to match Query heads let (k, v) = if self.n_heads != self.n_heads_kv { let n_rep = self.n_heads / self.n_heads_kv; (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep)) } else { (k, v) }; // 4. Score Calculation let scores = q.matmul(k.transpose()) * self.scale; // 5. Masking // ADVICE: Use min_float for F16/FP8 safety let scores = if let Some(mask) = mask { let mask = mask.reshape([batch, 1, 1, l_k]); scores.mask_fill(mask, self.min_float) } else { scores }; // 6. Softmax // ADVICE: Optional Quiet Softmax for sparse networks let weights = if self.quiet_softmax { quiet_softmax(scores, 3) } else { softmax(scores, 3) }; let weights = self.dropout.forward(weights); // 7. Aggregate & Output let output = weights.matmul(v); let output = output .swap_dims(1, 2) .reshape([batch, l_q, self.n_heads * self.d_head]); self.output.forward(output) } /// Applies cross-attention to query using context as key and value. /// /// This method uses a cache to avoid recomputing key and value tensors when the context is the same. /// /// # Arguments /// /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`. /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`. /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask. /// * `cache` - The cache to use. /// /// # Returns /// /// Output tensor of shape `[batch, seq_len_query, d_model]`. pub fn forward_cache( &self, query: Tensor, context: Tensor, mask: Option>, cache: &mut CrossAttentionCache, ) -> Tensor { let [batch, l_q, _] = query.dims(); // 1. Projections let q = self.query.forward(query); let k_compute = |context: Tensor| { let [batch, l_k, _] = context.dims(); self.key .forward(context) .reshape([batch, l_k, self.n_heads_kv, self.d_head]) .swap_dims(1, 2) }; let v_compute = |context: Tensor| { let [batch, l_k, _] = context.dims(); self.value .forward(context) .reshape([batch, l_k, self.n_heads_kv, self.d_head]) .swap_dims(1, 2) }; let k = cache.k.forward_full(context.clone(), k_compute); let v = cache.v.forward_full(context, v_compute); let [_, _, l_k, _] = k.dims(); // 2. Reshape Heads // Q: [Batch, Heads, L_q, D_head] let q = q .reshape([batch, l_q, self.n_heads, self.d_head]) .swap_dims(1, 2); // K, V are already in their correct shape from k_compute and v_compute // 3. GQA Expansion // ADVICE: Handle GQA by repeating KV heads to match Query heads let (k, v) = if self.n_heads != self.n_heads_kv { let n_rep = self.n_heads / self.n_heads_kv; (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep)) } else { (k, v) }; // 4. Score Calculation let scores = q.matmul(k.transpose()) * self.scale; // 5. Masking // ADVICE: Use min_float for F16/FP8 safety let scores = if let Some(mask) = mask { let mask = mask.reshape([batch, 1, 1, l_k]); scores.mask_fill(mask, self.min_float) } else { scores }; // 6. Softmax // ADVICE: Optional Quiet Softmax for sparse networks let weights = if self.quiet_softmax { quiet_softmax(scores, 3) } else { softmax(scores, 3) }; let weights = self.dropout.forward(weights); // 7. Aggregate & Output let output = weights.matmul(v); let output = output .swap_dims(1, 2) .reshape([batch, l_q, self.n_heads * self.d_head]); self.output.forward(output) } /// Helper for Grouped Query Attention fn repeat_kv(&self, x: Tensor, n_rep: usize) -> Tensor { let [b, h, l, d] = x.dims(); x.reshape([b, h, 1, l, d]) .expand([b, h, n_rep, l, d]) .reshape([b, h * n_rep, l, d]) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance}; #[test] fn test_cross_attention_mha_shapes() { let [ batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, d_head, ] = [7, 13, 15, 32, 40, 4, 8]; let device = Default::default(); let config = CrossAttentionConfig { d_model, d_context, n_heads, n_heads_kv: n_heads, // MHA case d_head, dropout: 0.1, min_float: -1.0e4, quiet_softmax: false, }; let cross_attn = config.init::(&device); let query = Tensor::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let context = Tensor::random( [batch_size, seq_len_context, d_context], Distribution::Default, &device, ); let output = cross_attn.forward(query, context, None); assert_eq!( output.shape(), Shape::new([batch_size, seq_len_query, d_model]), "Output should have the correct shape", ); } #[test] fn test_cross_attention_gqa_shapes() { let [ batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, n_heads_kv, d_head, ] = [7, 13, 15, 32, 40, 4, 2, 8]; let device = Default::default(); let config = CrossAttentionConfig { d_model, d_context, n_heads, n_heads_kv, // GQA case d_head, dropout: 0.1, min_float: -1.0e4, quiet_softmax: false, }; let cross_attn = config.init::(&device); let query = Tensor::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let context = Tensor::random( [batch_size, seq_len_context, d_context], Distribution::Default, &device, ); let output = cross_attn.forward(query, context, None); assert_eq!( output.shape(), Shape::new([batch_size, seq_len_query, d_model]), "Output should have the correct shape", ); } #[test] fn test_cross_attention_mqa_shapes() { let [ batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, d_head, ] = [7, 13, 15, 32, 40, 4, 8]; let device = Default::default(); let config = CrossAttentionConfig { d_model, d_context, n_heads, n_heads_kv: 1, // MQA case d_head, dropout: 0.1, min_float: -1.0e4, quiet_softmax: false, }; let cross_attn = config.init::(&device); let query = Tensor::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let context = Tensor::random( [batch_size, seq_len_context, d_context], Distribution::Default, &device, ); let output = cross_attn.forward(query, context, None); assert_eq!( output.shape(), Shape::new([batch_size, seq_len_query, d_model]), "Output should have the correct shape", ); } #[test] fn test_cross_attention_mask() { let [ batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, d_head, ] = [3, 6, 8, 12, 16, 4, 3]; let num_padded = 2; let device = Default::default(); let config = CrossAttentionConfig { d_model, d_context, n_heads, n_heads_kv: n_heads, d_head, dropout: 0.0, // No dropout for deterministic test min_float: -1.0e4, quiet_softmax: false, }; let cross_attn = config.init::(&device); // Create a padding mask for the context let mut mask: Tensor = Tensor::zeros([batch_size, seq_len_context], &device); mask = mask.slice_assign( [0..batch_size, seq_len_context - num_padded..seq_len_context], Tensor::ones([batch_size, num_padded], &device), ); let mask_bool = mask.equal_elem(1); let query = Tensor::::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let context_1 = Tensor::::random( [batch_size, seq_len_context, d_context], Distribution::Default, &device, ); // Change the padded part of the context tensor let context_2 = context_1.clone().slice_assign( [ 0..batch_size, seq_len_context - num_padded..seq_len_context, 0..d_context, ], Tensor::random( [batch_size, num_padded, d_context], Distribution::Default, &device, ), ); // The outputs should be the same since the changed part is masked. let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone())); let output_2 = cross_attn.forward(query, context_2, Some(mask_bool)); output_1 .into_data() .assert_approx_eq(&output_2.into_data(), Tolerance::::default()); } #[test] #[should_panic] fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() { let device = Default::default(); let config = CrossAttentionConfig { d_model: 32, d_context: 32, n_heads: 5, n_heads_kv: 2, d_head: 8, dropout: 0.1, min_float: -1.0e4, quiet_softmax: false, }; config.init::(&device); } #[test] fn test_cross_attention_cache() { let [ batch_size, seq_len_query, seq_len_context, d_model, d_context, n_heads, d_head, ] = [3, 6, 8, 12, 16, 4, 3]; let device = Default::default(); let config = CrossAttentionConfig { d_model, d_context, n_heads, n_heads_kv: n_heads, d_head, dropout: 0.0, // No dropout for deterministic test min_float: -1.0e4, quiet_softmax: false, }; let cross_attn = config.init::(&device); let query1 = Tensor::::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let context = Tensor::::random( [batch_size, seq_len_context, d_context], Distribution::Default, &device, ); // First forward pass, no cache let output1 = cross_attn.forward(query1.clone(), context.clone(), None); // Second forward pass with cache let mut cache = CrossAttentionCache::new(); let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache); // The two outputs should be identical output1 .into_data() .assert_approx_eq(&output2.into_data(), Tolerance::::default()); // Third forward pass with different query, but same context and cache let query2 = Tensor::::random( [batch_size, seq_len_query, d_model], Distribution::Default, &device, ); let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache); // For control, do a forward pass without cache with query2 let output4 = cross_attn.forward(query2.clone(), context.clone(), None); // output3 and output4 should be identical output3 .into_data() .assert_approx_eq(&output4.into_data(), Tolerance::::default()); } } ================================================ FILE: crates/burn-nn/src/modules/attention/mask.rs ================================================ use burn_core as burn; use burn_core::config::Config; use alloc::vec::Vec; use burn::tensor::ops::IntElem; use burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend}; /// Generate an autoregressive attention mask. /// /// The mask can be used in Transformer modules to train models to generate tensors sequentially. pub fn generate_autoregressive_mask( batch_size: usize, seq_length: usize, device: &B::Device, ) -> Tensor { let mask = Tensor::::tril_mask([seq_length, seq_length], 0, device); mask.expand([batch_size, seq_length, seq_length]) } /// Generate a padding attention mask. pub struct GeneratePaddingMask { /// The generated tensor. pub tensor: Tensor, /// The generated mask. pub mask: Tensor, } /// Defines an enumeration to specify sequence length options for padding #[derive(Config, Debug, Copy)] pub enum SeqLengthOption { /// No maximum length; use the longest sequence NoMax, /// Maximum length specified, truncate if necessary Max(usize), /// Fixed length, pad or truncate to this exact length Fixed(usize), } impl From> for SeqLengthOption { fn from(val: Option) -> Self { match val { Some(max) => SeqLengthOption::Max(max), None => SeqLengthOption::NoMax, } } } /// Generates a padding attention mask for a batch of token sequences. /// /// # Arguments /// /// * `pad_token` - The token ID used for padding /// * `tokens_list` - Vector of token sequences (each sequence is a vector of token IDs) /// * `seq_length` - Sequence length option (NoMax, Max, or Fixed) /// * `device` - The device for tensor operations /// /// # Returns /// /// A `GeneratePaddingMask` containing the padded tensor and corresponding mask pub fn generate_padding_mask( pad_token: usize, tokens_list: Vec>, seq_length: impl Into, device: &B::Device, ) -> GeneratePaddingMask { let tokens_max = || { tokens_list .iter() .map(|tokens| tokens.len()) .max() .unwrap_or(1) }; let size = match seq_length.into() { SeqLengthOption::NoMax => tokens_max(), SeqLengthOption::Max(max) => usize::min(tokens_max(), max), SeqLengthOption::Fixed(limit) => limit, }; let batch_size = tokens_list.len(); let mut tensor = Tensor::zeros([batch_size, size], device); tensor = tensor.add_scalar(pad_token as i64); for (index, tokens) in tokens_list.into_iter().enumerate() { let seq_length = tokens.len().min(size); tensor = tensor.slice_assign( [index..index + 1, 0..seq_length], Tensor::from_data( TensorData::new( tokens .into_iter() .take(size) .map(|e| (e as i64).elem::>()) .collect(), Shape::new([1, seq_length]), ), device, ), ); } let mask = tensor.clone().equal_elem(pad_token as i64); GeneratePaddingMask { tensor, mask } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use alloc::vec; use burn::tensor::TensorData; #[test] fn test_generate_autoregressive_mask() { let device = ::Device::default(); let mask = generate_autoregressive_mask::(2, 3, &device); mask.into_data().assert_eq( &TensorData::from([ [ [false, true, true], [false, false, true], [false, false, false], ], [ [false, true, true], [false, false, true], [false, false, false], ], ]), false, ); } #[test] fn test_generate_padding_mask() { let device = ::Device::default(); let tokens = vec![ vec![3, 3, 3], vec![3, 3, 3], vec![3, 3, 3, 4], vec![3, 3, 3, 4, 10, 15], ]; let mask = generate_padding_mask::(0, tokens, None, &device); mask.mask.into_data().assert_eq( &TensorData::from([ [false, false, false, true, true, true], [false, false, false, true, true, true], [false, false, false, false, true, true], [false, false, false, false, false, false], ]), false, ); } } ================================================ FILE: crates/burn-nn/src/modules/attention/mha.rs ================================================ use burn_core as burn; use crate::activation::Gelu; use crate::cache::TensorCache; use crate::{Dropout, DropoutConfig, Linear, LinearConfig}; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::{Bool, Tensor, backend::Backend}; use burn::tensor::activation::{quiet_softmax, softmax}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init). #[derive(Config, Debug)] pub struct MultiHeadAttentionConfig { /// The size of each linear layer. pub d_model: usize, /// The number of heads. pub n_heads: usize, /// The dropout rate. Default: 0.1 #[config(default = 0.1)] pub dropout: f64, /// The minimum value a float can take. Default: -1.0e4 /// This is used to mask attention scores before calculating attention weights. /// A value too low might result in NaN. #[config(default = -1.0e4)] pub min_float: f64, /// Use "quiet softmax" instead of regular softmax. /// /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. /// /// Reference: #[config(default = false)] pub quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, } /// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). /// /// # Params /// /// - `query`: [`Linear`] layer with `d_model` input and output features. /// - `key`: [`Linear`] layer with `d_model` input and output features. /// - `value`: [`Linear`] layer with `d_model` input and output features. /// - `output`: [`Linear`] layer with `d_model` input and output features. /// /// Should be created with [MultiHeadAttentionConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct MultiHeadAttention { /// Linear layer to transform the input features into the query space. pub query: Linear, /// Linear layer to transform the input features into the key space. pub key: Linear, /// Linear layer to transform the input features into the value space. pub value: Linear, /// Linear layer to transform the output features back to the original space. pub output: Linear, /// Dropout layer. pub dropout: Dropout, /// Activation function. pub activation: Gelu, /// The size of each linear layer. pub d_model: usize, /// The number of heads. pub n_heads: usize, /// Size of the key and query vectors. pub d_k: usize, /// Minimum value a float can take. pub min_float: f64, /// Use "quiet softmax" instead of regular softmax. pub quiet_softmax: bool, } impl ModuleDisplay for MultiHeadAttention { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("d_model", &self.d_model) .add("n_heads", &self.n_heads) .add("d_k", &self.d_k) .add("dropout", &self.dropout.prob) .add("min_float", &self.min_float) .add("quiet_softmax", &self.quiet_softmax) .optional() } } /// [Multihead attention](MultiHeadAttention) forward pass input argument. #[derive(Debug, Clone)] pub struct MhaInput { /// Shape `[batch_size, seq_length_1, d_model]` query: Tensor, /// Shape `[batch_size, seq_length_2, d_model]` key: Tensor, /// Shape `[batch_size, seq_length_2, d_model]` value: Tensor, mask_pad: Option>, mask_attn: Option>, } impl MultiHeadAttentionConfig { /// Initialize a new [multihead attention](MultiHeadAttention) module. pub fn init(&self, device: &B::Device) -> MultiHeadAttention { let linear = |config: &Self| { LinearConfig::new(config.d_model, config.d_model) .with_initializer(self.initializer.clone()) .init(device) }; MultiHeadAttention { query: linear(self), key: linear(self), value: linear(self), output: linear(self), dropout: DropoutConfig::new(self.dropout).init(), activation: Gelu::new(), n_heads: self.n_heads, d_k: self.d_model / self.n_heads, min_float: self.min_float, quiet_softmax: self.quiet_softmax, d_model: self.d_model, } } } impl MhaInput { /// Create a [multihead attention](MultiHeadAttention) input argument /// by setting the query, key and value to the given tensor. /// /// # Shape /// - tensor: `[batch_size, seq_length, d_model]` pub fn self_attn(tensor: Tensor) -> Self { Self { query: tensor.clone(), key: tensor.clone(), value: tensor, mask_pad: None, mask_attn: None, } } /// Create a [multihead attention](MultiHeadAttention) input argument. pub fn new(query: Tensor, key: Tensor, value: Tensor) -> Self { Self { query, key, value, mask_pad: None, mask_attn: None, } } /// Register the padding mask. pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { self.mask_pad = Some(mask_pad); self } /// Register the attention mask. pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { self.mask_attn = Some(mask_attn); self } } /// [Multihead attention](MultiHeadAttention) outputs. #[derive(Debug, Clone)] pub struct MhaOutput { /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`. pub weights: Tensor, /// The context tensor `[batch_size, seq_length_1, d_model]`. pub context: Tensor, } impl MultiHeadAttention { /// Applies the forward pass on the input tensors. /// /// See [MultiHeadAttention](MultiHeadAttention) for more information. /// /// # Shapes /// /// - query: `[batch_size, seq_length_1, d_model]` /// - key: `[batch_size, seq_length_2, d_model]` /// - value: `[batch_size, seq_length_2, d_model]` /// - output: `[batch_size, seq_length_1, d_model]` pub fn forward(&self, input: MhaInput) -> MhaOutput { let [batch_size, seq_length_1, d_model] = input.query.dims(); let query = self.attention_linear(input.query, &self.query); let key = self.attention_linear(input.key, &self.key); let value = self.attention_linear(input.value, &self.value); let attn_scores = self.attn_scores(query, key); let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); let context = weights.clone().matmul(value); let context = context .swap_dims(1, 2) .reshape([batch_size, seq_length_1, d_model]); let context = self.output.forward(context); MhaOutput { weights, context } } /// Applies the forward pass using a cache. /// /// # Shapes /// /// - query: `[batch_size, seq_length_1, d_model]` /// - key: `[batch_size, seq_length_2, d_model]` /// - value: `[batch_size, seq_length_2, d_model]` /// - output: `[batch_size, seq_length_1, d_model]` pub fn forward_cache(&self, input: MhaInput, cache: &mut MhaCache) -> MhaOutput { let [batch_size, seq_length_1, d_model] = input.query.dims(); let query = cache .query .forward(input.query, |t| self.attention_linear(t, &self.query)); let key = cache .key .forward(input.key, |t| self.attention_linear(t, &self.key)); let value = cache .value .forward(input.value, |t| self.attention_linear(t, &self.value)); let attn_scores = self.attn_scores(query, key); let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn); let context = weights.clone().matmul(value); let context = context .swap_dims(1, 2) .reshape([batch_size, seq_length_1, d_model]); let context = cache.output.forward(context, |t| self.output.forward(t)); MhaOutput { weights, context } } fn attn_scores(&self, query: Tensor, key: Tensor) -> Tensor { let attn_scores = query .matmul(key.transpose()) .div_scalar((self.d_k as f32).sqrt()); self.dropout.forward(attn_scores) } fn attn_weights( &self, mut attn_scores: Tensor, mask_pad: Option>, mask_attn: Option>, ) -> Tensor { if let Some(mask_pad) = mask_pad { let [batch_size, seq_length] = mask_pad.dims(); attn_scores = attn_scores.mask_fill( mask_pad.reshape([batch_size, 1, 1, seq_length]), self.min_float, ); } if let Some(mask_attn) = mask_attn { let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims(); attn_scores = attn_scores.mask_fill( mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]), self.min_float, ); } if self.quiet_softmax { quiet_softmax(attn_scores, 3) } else { softmax(attn_scores, 3) } } fn attention_linear(&self, x: Tensor, linear: &Linear) -> Tensor { let [batch_size, seq_length, _d_model] = x.dims(); linear .forward(x) .reshape([batch_size, seq_length, self.n_heads, self.d_k]) .swap_dims(1, 2) } } /// Cache for the [Multi Head Attention](MultiHeadAttention) layer. /// /// To be used during inference when decoding tokens. pub struct MhaCache { query: MhaLinearCache, key: MhaLinearCache, value: MhaLinearCache, output: MhaLinearCache, } enum MhaLinearCache { Autoregressive(TensorCache, usize), Full(TensorCache), } impl MhaCache { /// Initialize a cache for autoregressive inference. pub fn autoregressive() -> Self { Self { query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), } } /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and /// values (cross-attention). pub fn autoregressive_cross_attention() -> Self { Self { query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2), key: MhaLinearCache::Full(TensorCache::empty()), value: MhaLinearCache::Full(TensorCache::empty()), output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1), } } } impl MhaLinearCache { pub fn forward) -> Tensor>( &mut self, tensor: Tensor, func: F, ) -> Tensor { match self { MhaLinearCache::Autoregressive(cache, dim) => { cache.forward_autoregressive(tensor, *dim, func) } MhaLinearCache::Full(cache) => cache.forward_full(tensor, func), } } } #[cfg(test)] mod tests { use super::*; use crate::{TestBackend, attention::generate_autoregressive_mask}; use alloc::vec::Vec; use burn::tensor::Int; use burn::tensor::Tolerance; use burn::tensor::ops::FloatElem; use burn::tensor::{Distribution, Shape}; #[test] fn test_self_attention_shapes() { let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4]; let device = Default::default(); let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(&device); let input = MhaInput::self_attn(Tensor::random( [batch_size, seq_length, d_model], Distribution::Default, &device, )); let output = mha.forward(input); assert_eq!( output.context.shape(), Shape::new([batch_size, seq_length, d_model]), "Context should have the correct shape", ); assert_eq!( output.weights.shape(), Shape::new([batch_size, n_heads, seq_length, seq_length]), "Weights should have the correct shape", ); } #[test] fn test_generic_mha_shapes() { let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4]; let mha = MultiHeadAttentionConfig::new(d_model, n_heads) .init::(&Default::default()); let device = Default::default(); let input = MhaInput::new( Tensor::random( [batch_size, seq_length_1, d_model], Distribution::Default, &device, ), Tensor::random( [batch_size, seq_length_2, d_model], Distribution::Default, &device, ), Tensor::random( [batch_size, seq_length_2, d_model], Distribution::Default, &device, ), ); let output = mha.forward(input); assert_eq!( output.context.shape(), Shape::new([batch_size, seq_length_1, d_model]), "Context should have the correct shape", ); assert_eq!( output.weights.shape(), Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]), "Weights should have the correct shape", ); } #[test] fn test_self_attention_mask_pad() { let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2]; let device = Default::default(); let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(&device); // Create a padding mask let mask_pad: Tensor = Tensor::zeros([batch_size, seq_length], &device); let mask_pad = mask_pad.slice_assign( [0..batch_size, seq_length - num_padded..seq_length], Tensor::ones([batch_size, num_padded], &device), ); let mask_pad = mask_pad.equal_elem(1).to_device(&device); let tensor_1 = Tensor::::random( [batch_size, seq_length, d_model], Distribution::Default, &device, ); // Change the end of the tensor let tensor_2 = tensor_1.clone().slice_assign( [ 0..batch_size, seq_length - num_padded..seq_length, 0..d_model, ], Tensor::random( [batch_size, num_padded, d_model], Distribution::Default, &device, ), ); let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()); let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad); let output_1 = mha.forward(input_1); let output_2 = mha.forward(input_2); // Check that the beginning of each tensor is the same output_1 .context .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) .into_data() .assert_approx_eq( &output_2 .context .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model]) .into_data(), Tolerance::::default(), ); } #[test] fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() { let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2]; let device = Default::default(); let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::(&device); let tensor = Tensor::::random( [batch_size, seq_length, d_model], Distribution::Default, &device, ); let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn); let output_1 = mha.forward(input); let mut output_2 = Vec::new(); let mut cache = MhaCache::autoregressive(); for i in 1..seq_length + 1 { let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); let input = MhaInput::self_attn(tensor); let next_tok = mha.forward_cache(input, &mut cache).context.slice([ 0..batch_size, i - 1..i, 0..d_model, ]); output_2.push(next_tok); } let output_2 = Tensor::cat(output_2, 1); output_1 .context .into_data() .assert_approx_eq::>( &output_2.into_data(), Tolerance::default(), ); } #[test] fn display() { let config = MultiHeadAttentionConfig::new(2, 4); let mha = config.init::(&Default::default()); assert_eq!( alloc::format!("{mha}"), "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \ dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}" ); } } ================================================ FILE: crates/burn-nn/src/modules/attention/mod.rs ================================================ mod cross_attention; mod mask; mod mha; pub use cross_attention::*; pub use mask::*; pub use mha::*; ================================================ FILE: crates/burn-nn/src/modules/cache/autoregressive.rs ================================================ use alloc::vec; use burn_core as burn; use super::{CacheState, TensorCache}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; impl TensorCache { pub(crate) fn forward_autoregressive( &mut self, tensor: Tensor, dim_cat: usize, func: F, ) -> Tensor where F: Fn(Tensor) -> Tensor, { let mut tensor_old = CacheState::Empty; core::mem::swap(&mut self.state, &mut tensor_old); let tensor_new = match tensor_old { CacheState::Value(tensor_old) => { let [batch_size, seq_length, d_model] = tensor.dims(); let next_seq_token = tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]); let next_seq_token = func(next_seq_token); Tensor::cat(vec![tensor_old, next_seq_token], dim_cat) } _ => func(tensor), }; self.state = CacheState::Value(tensor_new.clone()); tensor_new } pub(crate) fn forward_full(&mut self, tensor: Tensor, func: F) -> Tensor where F: Fn(Tensor) -> Tensor, { let mut tensor_old = CacheState::Empty; core::mem::swap(&mut self.state, &mut tensor_old); let tensor_new = match tensor_old { CacheState::Value(tensor_old) => tensor_old, _ => func(tensor), }; self.state = CacheState::Value(tensor_new.clone()); tensor_new } } ================================================ FILE: crates/burn-nn/src/modules/cache/base.rs ================================================ use burn_core as burn; use burn::tensor::Tensor; use burn::tensor::backend::Backend; pub(crate) enum CacheState { Value(T), Empty, } /// A cache for a tensor. pub struct TensorCache { pub(crate) state: CacheState>, } impl TensorCache { /// Creates a new empty cache. /// /// # Returns /// /// The empty cache. pub fn empty() -> Self { Self { state: CacheState::Empty, } } } ================================================ FILE: crates/burn-nn/src/modules/cache/mod.rs ================================================ mod autoregressive; mod base; pub use base::*; ================================================ FILE: crates/burn-nn/src/modules/conv/checks.rs ================================================ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) { let channels_in_div_by_group = channels_in.is_multiple_of(groups); let channels_out_div_by_group = channels_out.is_multiple_of(groups); if !channels_in_div_by_group || !channels_out_div_by_group { panic!( "Both channels must be divisible by the number of groups. Got \ channels_in={channels_in}, channels_out={channels_out}, groups={groups}" ); } } // https://github.com/tracel-ai/burn/issues/2676 /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel /// size is not supported as it will not produce the same output size. pub(crate) fn check_same_padding_support(kernel_size: &[usize]) { for k in kernel_size.iter() { if k % 2 == 0 { unimplemented!("Same padding with an even kernel size is not supported"); } } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv1d.rs ================================================ use alloc::format; use burn_core as burn; use crate::{PaddingConfig1d, conv::checks}; use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::PaddedConvOptions}; use burn::{ config::Config, module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param}, }; /// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init). #[derive(Config, Debug)] pub struct Conv1dConfig { /// The number of input channels. pub channels_in: usize, /// The number of output channels. pub channels_out: usize, /// The size of the kernel. pub kernel_size: usize, /// The stride of the convolution. #[config(default = "1")] pub stride: usize, /// Spacing between kernel elements. #[config(default = "1")] pub dilation: usize, /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 1D convolution over input tensors. /// /// Should be created with [Conv1dConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Conv1d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: usize, /// Size of the kernel. pub kernel_size: usize, /// Spacing between kernel elements. pub dilation: usize, /// Controls the connections between input and output channels. pub groups: usize, /// Padding configuration. pub padding: PaddingConfig1d, } impl ModuleDisplay for Conv1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { // Format stride/dilation as strings let stride = format!("{:?}", self.stride); let kernel_size = format!("{:?}", self.kernel_size); let dilation = format!("{:?}", self.dilation); // Extract channels in/out from weight dims let [channels_out, group_channels_in, _] = self.weight.dims(); let channels_in = group_channels_in * self.groups; let ch_out = format!("{:?}", channels_out); let ch_in = format!("{:?}", channels_in); content .add("ch_in", &ch_in) .add("ch_out", &ch_out) .add("stride", &stride) .add("kernel_size", &kernel_size) .add("dilation", &dilation) .add("groups", &self.groups) .add_debug_attribute("padding", &self.padding) .optional() } } impl Conv1dConfig { /// Initialize a new [conv1d](Conv1d) module. pub fn init(&self, device: &B::Device) -> Conv1d { checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); let shape = [ self.channels_out, self.channels_in / self.groups, self.kernel_size, ]; let fan_in: usize = self.channels_in / self.groups * self.kernel_size; let weight = self .initializer .init_with(shape, Some(fan_in), None, device); let mut bias = None; if self.bias { bias = Some( self.initializer .init_with([self.channels_out], Some(fan_in), None, device), ); } Conv1d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, padding: self.padding.clone(), dilation: self.dilation, groups: self.groups, } } } impl Conv1d { /// Applies the forward pass on the input tensor. /// /// See [conv1d](burn::tensor::module::conv1d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels_in, length_in]` /// - output: `[batch_size, channels_out, length_out]` pub fn forward(&self, input: Tensor) -> Tensor { let length = input.dims()[2]; // Calculate padding as pair - handles Same, Valid, and Explicit uniformly let (left, right) = self.padding .calculate_padding_1d_pair(length, self.kernel_size, self.stride); let options = PaddedConvOptions::asymmetric( [self.stride], [left], [right], [self.dilation], self.groups, ); conv1d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), options, ) } } #[cfg(test)] mod tests { use burn::tensor::{ElementConversion, ops::FloatElem}; type FT = FloatElem; use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv1dConfig::new(5, 5, 5); let k = (config.channels_in * config.kernel_size) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&device); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros); let conv = config.init::(&Default::default()); assert_eq!(config.initializer, Initializer::Zeros); conv.weight .to_data() .assert_eq(&TensorData::zeros::(conv.weight.shape()), false); } #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = Conv1dConfig::new(4, 4, 2) .with_padding(PaddingConfig1d::Same) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=4, length=5] let input = Tensor::::ones([1, 4, 5], &device); let output = conv.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 4, 5]); } #[test] fn display() { let config = Conv1dConfig::new(5, 5, 5); let conv = config.init::(&Default::default()); assert_eq!( alloc::format!("{conv}"), "Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = Conv1dConfig::new(5, 3, 3); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10], &Default::default()); let _ = conv.forward(input); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create conv with asymmetric padding: left=1, right=2 let config = Conv1dConfig::new(2, 3, 3) .with_padding(PaddingConfig1d::Explicit(1, 2)) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = conv.forward(input); // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7 // Output length = (7 - 3) / 1 + 1 = 5 assert_eq!(output.dims(), [1, 3, 5]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create conv with symmetric explicit padding: left=2, right=2 let config = Conv1dConfig::new(2, 3, 3) .with_padding(PaddingConfig1d::Explicit(2, 2)) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = conv.forward(input); // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8 // Output length = (8 - 3) / 1 + 1 = 6 assert_eq!(output.dims(), [1, 3, 6]); } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv2d.rs ================================================ use alloc::format; use burn_core as burn; use crate::PaddingConfig2d; use burn::config::Config; use burn::module::Initializer; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::conv2d; use burn::tensor::ops::PaddedConvOptions; use crate::conv::checks; /// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init). #[derive(Config, Debug)] pub struct Conv2dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 2], /// The stride of the convolution. #[config(default = "[1, 1]")] pub stride: [usize; 2], /// Spacing between kernel elements. #[config(default = "[1, 1]")] pub dilation: [usize; 2], /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 2D convolution over input tensors. /// /// Should be created with [Conv2dConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Conv2d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: [usize; 2], /// Size of the kernel. pub kernel_size: [usize; 2], /// Spacing between kernel elements. pub dilation: [usize; 2], /// Controls the connections between input and output channels. pub groups: usize, /// The padding configuration. pub padding: PaddingConfig2d, } impl Conv2dConfig { /// Initialize a new [conv2d](Conv2d) module. pub fn init(&self, device: &B::Device) -> Conv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); let shape = [ self.channels[1], self.channels[0] / self.groups, self.kernel_size[0], self.kernel_size[1], ]; let k = self.kernel_size.iter().product::(); let fan_in = self.channels[0] / self.groups * k; let fan_out = self.channels[1] / self.groups * k; let weight = self .initializer .init_with(shape, Some(fan_in), Some(fan_out), device); let mut bias = None; if self.bias { bias = Some(self.initializer.init_with( [self.channels[1]], Some(fan_in), Some(fan_out), device, )); } Conv2d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, padding: self.padding.clone(), groups: self.groups, } } } impl ModuleDisplay for Conv2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed. let stride = format!("{:?}", self.stride); let kernel_size = format!("{:?}", self.kernel_size); let dilation = format!("{:?}", self.dilation); let [channels_out, group_channels_in, _, _] = self.weight.dims(); let channels_in = group_channels_in * self.groups; let ch_out = format!("{:?}", channels_out); let ch_in = format!("{:?}", channels_in); content .add("ch_in", &ch_in) .add("ch_out", &ch_out) .add("stride", &stride) .add("kernel_size", &kernel_size) .add("dilation", &dilation) .add("groups", &self.groups) .add_debug_attribute("padding", &self.padding) .optional() } } impl Conv2d { /// Applies the forward pass on the input tensor. /// /// See [conv2d](burn::tensor::module::conv2d) for more information. /// /// # Shapes /// - `input`: `[batch_size, channels_in, height_in, width_in]` /// - `output`: `[batch_size, channels_out, height_out, width_out]` /// /// # Example /// ```rust,ignore /// use burn::nn::conv::Conv2dConfig; /// use burn::tensor::Tensor; /// /// // Assuming backend type alias `B` /// let device = Default::default(); /// let conv = Conv2dConfig::new([3, 8], [3, 3]).init::(&device); /// /// let x = Tensor::::zeros([1, 3, 28, 28], &device); /// let y = conv.forward(x); /// /// println!("{:?}", y.dims()); // [1, 8, 26, 26] /// ``` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels_in, height_in, width_in] = input.dims(); // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs( height_in, width_in, &self.kernel_size, &self.stride, ); let options = PaddedConvOptions::asymmetric( self.stride, [top, left], [bottom, right], self.dilation, self.groups, ); conv2d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), options, ) } } #[cfg(test)] mod tests { use burn::tensor::ops::FloatElem; use burn::tensor::{ElementConversion, Tolerance}; use super::*; use crate::TestBackend; use burn::tensor::TensorData; type FT = FloatElem; // Float test #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv2dConfig::new([5, 1], [5, 5]); let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&device); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); let conv = config.init::(&device); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn initializer_fan_out() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, // test that fan_out is passed to `init_with()` }; let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone()); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] fn initializer_fan_with_groups_is_valid() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, }; let config = Conv2dConfig::new([4, 4], [1, 1]) .with_initializer(init.clone()) .with_groups(4); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] #[should_panic = "Both channels must be divisible by the number of groups."] fn channels_with_groups_is_invalid() { let device = Default::default(); let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4); let _ = config.init::(&device); } #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = Conv2dConfig::new([4, 4], [2, 2]) .with_padding(PaddingConfig2d::Same) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=4, height=5, width=5] let input = Tensor::::ones([1, 4, 5, 5], &device); let output = conv.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 4, 5, 5]); } #[test] fn display() { let config = Conv2dConfig::new([5, 1], [5, 5]); let conv = config.init::(&Default::default()); assert_eq!( alloc::format!("{conv}"), "Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = Conv2dConfig::new([5, 3], [3, 3]); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10, 10], &Default::default()); let _ = conv.forward(input); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create conv with asymmetric padding: top=1, left=2, bottom=3, right=4 let config = Conv2dConfig::new([2, 3], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4)) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = conv.forward(input); // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9 assert_eq!(output.dims(), [1, 3, 6, 9]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create conv with symmetric explicit padding: top=2, left=2, bottom=2, right=2 let config = Conv2dConfig::new([2, 3], [3, 3]) .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2)) .with_initializer(Initializer::Constant { value: 1.0 }) .with_bias(false); let conv = config.init::(&device); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = conv.forward(input); // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7 assert_eq!(output.dims(), [1, 3, 6, 7]); } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv3d.rs ================================================ use alloc::format; use burn_core as burn; use crate::PaddingConfig3d; use burn::config::Config; use burn::module::Initializer; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::conv3d; use burn::tensor::ops::ConvOptions; use crate::conv::checks; /// Configuration to create a [3D convolution](Conv3d) layer, using the [init function](Conv3dConfig::init). #[derive(Config, Debug)] pub struct Conv3dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 3], /// The stride of the convolution. #[config(default = "[1, 1, 1]")] pub stride: [usize; 3], /// Spacing between kernel elements. #[config(default = "[1, 1, 1]")] pub dilation: [usize; 3], /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. #[config(default = "PaddingConfig3d::Valid")] pub padding: PaddingConfig3d, /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 3D convolution over input tensors. /// /// Should be created with [Conv3dConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Conv3d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2, kernel_size_3]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: [usize; 3], /// Size of the kernel. pub kernel_size: [usize; 3], /// Spacing between kernel elements. pub dilation: [usize; 3], /// Controls the connections between input and output channels. pub groups: usize, /// The padding configuration. pub padding: PaddingConfig3d, } impl Conv3dConfig { /// Initialize a new [conv3d](Conv3d) module. pub fn init(&self, device: &B::Device) -> Conv3d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); if self.padding == PaddingConfig3d::Same { checks::check_same_padding_support(&self.kernel_size); } let shape = [ self.channels[1], self.channels[0] / self.groups, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], ]; let k = self.kernel_size.iter().product::(); let fan_in = self.channels[0] / self.groups * k; let fan_out = self.channels[1] / self.groups * k; let weight = self .initializer .init_with(shape, Some(fan_in), Some(fan_out), device); let mut bias = None; if self.bias { bias = Some(self.initializer.init_with( [self.channels[1]], Some(fan_in), Some(fan_out), device, )); } Conv3d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, padding: self.padding.clone(), groups: self.groups, } } } impl ModuleDisplay for Conv3d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { // Format arrays as strings (consistent with Conv2d/Conv1d). let stride = format!("{:?}", self.stride); let kernel_size = format!("{:?}", self.kernel_size); let dilation = format!("{:?}", self.dilation); // Weight dims: [channels_out, channels_in/groups, k1, k2, k3] let [channels_out, group_channels_in, _, _, _] = self.weight.dims(); let channels_in = group_channels_in * self.groups; let ch_out = format!("{:?}", channels_out); let ch_in = format!("{:?}", channels_in); content .add("ch_in", &ch_in) .add("ch_out", &ch_out) .add("stride", &stride) .add("kernel_size", &kernel_size) .add("dilation", &dilation) .add("groups", &self.groups) .add_debug_attribute("padding", &self.padding) .optional() } } impl Conv3d { /// Applies the forward pass on the input tensor. /// /// See [conv3d](burn::tensor::module::conv3d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels_in, depth_in, height_in, width_in]` /// - output: `[batch_size, channels_out, depth_out, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims(); let padding = self.padding.calculate_padding_3d( depth_in, height_in, width_in, &self.kernel_size, &self.stride, ); conv3d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), ConvOptions::new(self.stride, padding, self.dilation, self.groups), ) } } #[cfg(test)] mod tests { use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv3dConfig::new([5, 1], [5, 5, 5]); let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1] * config.kernel_size[2]) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&device); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = Conv3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros); let device = Default::default(); let conv = config.init::(&device); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn initializer_fan_out() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, // test that fan_out is passed to `init_with()` }; let config = Conv3dConfig::new([5, 1], [5, 5, 5]).with_initializer(init.clone()); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] fn initializer_fan_with_groups_is_valid() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, }; let config = Conv3dConfig::new([4, 4], [1, 1, 1]) .with_initializer(init.clone()) .with_groups(4); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] #[should_panic = "Same padding with an even kernel size is not supported"] fn same_with_even_kernel_is_invalid() { let device = Default::default(); let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same); let _ = config.init::(&device); } #[test] fn display() { let config = Conv3dConfig::new([5, 1], [5, 5, 5]); let conv = config.init::(&Default::default()); assert_eq!( alloc::format!("{conv}"), "Conv3d {ch_in: 5, ch_out: 1, stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: Valid, params: 626}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = Conv3dConfig::new([5, 3], [3, 3, 3]); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10, 10, 10], &Default::default()); let _ = conv.forward(input); } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv_transpose1d.rs ================================================ use alloc::format; use burn_core as burn; use crate::conv::checks; use burn::config::Config; use burn::module::Content; use burn::module::DisplaySettings; use burn::module::Initializer; use burn::module::Module; use burn::module::ModuleDisplay; use burn::module::Param; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::conv_transpose1d; use burn::tensor::ops::ConvTransposeOptions; /// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer /// using the [init function](ConvTranspose1dConfig::init). #[derive(Config, Debug)] pub struct ConvTranspose1dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: usize, /// The stride of the convolution. #[config(default = "1")] pub stride: usize, /// Spacing between kernel elements. #[config(default = "1")] pub dilation: usize, /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. #[config(default = "0")] pub padding: usize, /// The padding output configuration. #[config(default = "0")] pub padding_out: usize, /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 1D transposed convolution over input tensors. #[derive(Module, Debug)] #[module(custom_display)] pub struct ConvTranspose1d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: usize, /// Size of the kernel. pub kernel_size: usize, /// Spacing between kernel elements. pub dilation: usize, /// Controls the connections between input and output channels. pub groups: usize, /// The padding configuration. pub padding: usize, /// The padding output configuration. pub padding_out: usize, /// The number of channels. pub channels: [usize; 2], } impl ModuleDisplay for ConvTranspose1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("channels", &format!("{:?}", &self.channels)) .add("stride", &self.stride) .add("kernel_size", &self.kernel_size) .add("dilation", &self.dilation) .add("groups", &self.groups) .add("padding", &self.padding) .add("padding_out", &self.padding_out) .optional() } } impl ConvTranspose1dConfig { /// Initialize a new [conv transpose 1d](ConvTranspose1d) module. pub fn init(&self, device: &B::Device) -> ConvTranspose1d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); let shape = [ self.channels[0], self.channels[1] / self.groups, self.kernel_size, ]; let fan_in = self.channels[1] / self.groups * self.kernel_size; let weight = self .initializer .init_with(shape, Some(fan_in), None, device); let mut bias = None; if self.bias { bias = Some( self.initializer .init_with([self.channels[1]], Some(fan_in), None, device), ); } ConvTranspose1d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, groups: self.groups, padding: self.padding, padding_out: self.padding_out, channels: self.channels, } } } impl ConvTranspose1d { /// Applies the forward pass on the input tensor. /// /// See also [conv_transpose1d](burn::tensor::module::conv_transpose1d). /// /// # Shapes /// /// - input: `[batch_size, channels_in, length_in]` /// - output: `[batch_size, channels_out, length_out]` pub fn forward(&self, input: Tensor) -> Tensor { conv_transpose1d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), ConvTransposeOptions::new( [self.stride], [self.padding], [self.padding_out], [self.dilation], self.groups, ), ) } } #[cfg(test)] mod tests { use burn::tensor::ops::FloatElem; use burn::tensor::{ElementConversion, Tolerance}; use super::*; use crate::TestBackend; use burn::tensor::TensorData; type FT = FloatElem; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose1dConfig::new([5, 1], 5); let k = (config.channels[1] * config.kernel_size) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&Default::default()); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros); let conv = config.init::(&Default::default()); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn display() { let config = ConvTranspose1dConfig::new([5, 2], 5); let conv = config.init::(&Default::default()); assert_eq!( format!("{conv}"), "ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = ConvTranspose1dConfig::new([5, 3], 3); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10], &Default::default()); let _ = conv.forward(input); } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv_transpose2d.rs ================================================ use alloc::format; use burn_core as burn; use crate::conv::checks; use burn::config::Config; use burn::module::Content; use burn::module::DisplaySettings; use burn::module::Initializer; use burn::module::Module; use burn::module::ModuleDisplay; use burn::module::Param; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::conv_transpose2d; use burn::tensor::ops::ConvTransposeOptions; /// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer /// using the [init function](ConvTranspose2dConfig::init). #[derive(Config, Debug)] pub struct ConvTranspose2dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 2], /// The stride of the convolution. #[config(default = "[1, 1]")] pub stride: [usize; 2], /// Spacing between kernel elements. #[config(default = "[1, 1]")] pub dilation: [usize; 2], /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. #[config(default = "[0, 0]")] pub padding: [usize; 2], /// The padding output configuration. #[config(default = "[0, 0]")] pub padding_out: [usize; 2], /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 2D transposed convolution over input tensors. #[derive(Module, Debug)] #[module(custom_display)] pub struct ConvTranspose2d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: [usize; 2], /// Size of the kernel. pub kernel_size: [usize; 2], /// Spacing between kernel elements. pub dilation: [usize; 2], /// Controls the connections between input and output channels. pub groups: usize, /// Padding configuration. pub padding: [usize; 2], /// Padding output configuration. pub padding_out: [usize; 2], /// Number of channels. pub channels: [usize; 2], } impl ModuleDisplay for ConvTranspose2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("channels", &format!("{:?}", &self.channels)) .add("stride", &format!("{:?}", &self.stride)) .add("kernel_size", &format!("{:?}", &self.kernel_size)) .add("dilation", &format!("{:?}", &self.dilation)) .add("groups", &self.groups) .add("padding", &format!("{:?}", &self.padding)) .add("padding_out", &format!("{:?}", &self.padding_out)) .optional() } } impl ConvTranspose2dConfig { /// Initialize a new [conv transpose 2d](ConvTranspose2d) module. pub fn init(&self, device: &B::Device) -> ConvTranspose2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); let shape = [ self.channels[0], self.channels[1] / self.groups, self.kernel_size[0], self.kernel_size[1], ]; let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); let weight = self .initializer .init_with(shape, Some(fan_in), None, device); let mut bias = None; if self.bias { bias = Some( self.initializer .init_with([self.channels[1]], Some(fan_in), None, device), ); } ConvTranspose2d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, groups: self.groups, padding: self.padding, padding_out: self.padding_out, channels: self.channels, } } } impl ConvTranspose2d { /// Applies the forward pass on the input tensor. /// /// See also [conv_transpose2d](burn::tensor::module::conv_transpose2d). /// /// # Shapes /// /// - input: `[batch_size, channels_in, height_in, width_in]` /// - output: `[batch_size, channels_out, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { conv_transpose2d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), ConvTransposeOptions::new( self.stride, self.padding, self.padding_out, self.dilation, self.groups, ), ) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose2dConfig::new([5, 1], [5, 5]); let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&Default::default()); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); let conv = config.init::(&Default::default()); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn display() { let config = ConvTranspose2dConfig::new([5, 2], [5, 5]); let conv = config.init::(&Default::default()); assert_eq!( format!("{conv}"), "ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = ConvTranspose2dConfig::new([5, 3], [3, 3]); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10, 10], &Default::default()); let _ = conv.forward(input); } } ================================================ FILE: crates/burn-nn/src/modules/conv/conv_transpose3d.rs ================================================ use alloc::format; use burn_core as burn; use crate::conv::checks; use burn::config::Config; use burn::module::Content; use burn::module::DisplaySettings; use burn::module::Initializer; use burn::module::Module; use burn::module::ModuleDisplay; use burn::module::Param; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::conv_transpose3d; use burn::tensor::ops::ConvTransposeOptions; /// Configuration to create an [3D transposed convolution](ConvTranspose3d) layer /// using the [init function](ConvTranspose3dConfig::init). #[derive(Config, Debug)] pub struct ConvTranspose3dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 3], /// The stride of the convolution. #[config(default = "[1, 1, 1]")] pub stride: [usize; 3], /// Spacing between kernel elements. #[config(default = "[1, 1, 1]")] pub dilation: [usize; 3], /// Controls the connections between input and output channels. #[config(default = "1")] pub groups: usize, /// The padding configuration. #[config(default = "[0, 0, 0]")] pub padding: [usize; 3], /// The padding output configuration. #[config(default = "[0, 0, 0]")] pub padding_out: [usize; 3], /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a 3D transposed convolution over input tensors. #[derive(Module, Debug)] #[module(custom_display)] pub struct ConvTranspose3d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2, kernel_size_3]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: [usize; 3], /// Size of the kernel. pub kernel_size: [usize; 3], /// Spacing between kernel elements. pub dilation: [usize; 3], /// Controls the connections between input and output channels. pub groups: usize, /// Padding configuration. pub padding: [usize; 3], /// Padding output configuration. pub padding_out: [usize; 3], /// Number of channels. pub channels: [usize; 2], } impl ModuleDisplay for ConvTranspose3d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("channels", &format!("{:?}", &self.channels)) .add("stride", &format!("{:?}", &self.stride)) .add("kernel_size", &format!("{:?}", &self.kernel_size)) .add("dilation", &format!("{:?}", &self.dilation)) .add("groups", &self.groups) .add("padding", &format!("{:?}", &self.padding)) .add("padding_out", &format!("{:?}", &self.padding_out)) .optional() } } impl ConvTranspose3dConfig { /// Initialize a new [conv transpose 2d](ConvTranspose3d) module. pub fn init(&self, device: &B::Device) -> ConvTranspose3d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); let shape = [ self.channels[0], self.channels[1] / self.groups, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2], ]; let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::(); let weight = self .initializer .init_with(shape, Some(fan_in), None, device); let mut bias = None; if self.bias { bias = Some( self.initializer .init_with([self.channels[1]], Some(fan_in), None, device), ); } ConvTranspose3d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, groups: self.groups, padding: self.padding, padding_out: self.padding_out, channels: self.channels, } } } impl ConvTranspose3d { /// Applies the forward pass on the input tensor. /// /// See also [conv_transpose3d](burn::tensor::module::conv_transpose3d). /// /// # Shapes /// /// - input: `[batch_size, channels_in, depth_in, height_in, width_in]` /// - output: `[batch_size, channels_out, depth_out, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { conv_transpose3d( input, self.weight.val(), self.bias.as_ref().map(|bias| bias.val()), ConvTransposeOptions::new( self.stride, self.padding, self.padding_out, self.dilation, self.groups, ), ) } } #[cfg(test)] mod tests { use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose3dConfig::new([5, 1], [5, 5, 5]); let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1] * config.kernel_size[2]) as f64; let k = (config.groups as f64 / k).sqrt().elem::(); let conv = config.init::(&Default::default()); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = ConvTranspose3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros); let conv = config.init::(&Default::default()); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn display() { let config = ConvTranspose3dConfig::new([5, 2], [5, 5, 5]); let conv = config.init::(&Default::default()); assert_eq!( format!("{conv}"), "ConvTranspose3d {channels: [5, 2], stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: [0, 0, 0], padding_out: [0, 0, 0], params: 1252}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = ConvTranspose3dConfig::new([5, 3], [3, 3, 3]); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10, 10, 10], &Default::default()); let _ = conv.forward(input); } } ================================================ FILE: crates/burn-nn/src/modules/conv/deform_conv2d.rs ================================================ use alloc::format; use burn::tensor::ops::DeformConvOptions; use burn_core as burn; use crate::PaddingConfig2d; use burn::config::Config; use burn::module::Initializer; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::deform_conv2d; use crate::conv::checks; /// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init). #[derive(Config, Debug)] pub struct DeformConv2dConfig { /// The number of channels. pub channels: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 2], /// The stride of the convolution. #[config(default = "[1, 1]")] pub stride: [usize; 2], /// Spacing between kernel elements. #[config(default = "[1, 1]")] pub dilation: [usize; 2], /// Controls the connections between input and output channels. #[config(default = "1")] pub weight_groups: usize, /// Offset groups. #[config(default = "1")] pub offset_groups: usize, /// The padding configuration. /// /// ### Warning /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}" )] pub initializer: Initializer, } /// Applies a deformable 2D convolution over input tensors. /// /// Should be created with [DeformConv2dConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct DeformConv2d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, /// Stride of the convolution. pub stride: [usize; 2], /// Size of the kernel. pub kernel_size: [usize; 2], /// Spacing between kernel elements. pub dilation: [usize; 2], /// Controls the connections between input and output channels. pub weight_groups: usize, /// Offset groups. pub offset_groups: usize, /// The padding configuration. pub padding: PaddingConfig2d, } impl DeformConv2dConfig { /// Initialize a new [DeformConv2d](DeformConv2d) module. pub fn init(&self, device: &B::Device) -> DeformConv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups); if self.padding == PaddingConfig2d::Same { checks::check_same_padding_support(&self.kernel_size); } let shape = [ self.channels[1], self.channels[0] / self.weight_groups, self.kernel_size[0], self.kernel_size[1], ]; let k = self.kernel_size.iter().product::(); let fan_in = self.channels[0] / self.weight_groups * k; let fan_out = self.channels[1] / self.weight_groups * k; let weight = self .initializer .init_with(shape, Some(fan_in), Some(fan_out), device); let mut bias = None; if self.bias { bias = Some(self.initializer.init_with( [self.channels[1]], Some(fan_in), Some(fan_out), device, )); } DeformConv2d { weight, bias, stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, padding: self.padding.clone(), weight_groups: self.weight_groups, offset_groups: self.weight_groups, } } } impl ModuleDisplay for DeformConv2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed. let stride = format!("{:?}", self.stride); let kernel_size = format!("{:?}", self.kernel_size); let dilation = format!("{:?}", self.dilation); content .add("stride", &stride) .add("kernel_size", &kernel_size) .add("dilation", &dilation) .add("weight_groups", &self.weight_groups) .add("offset_groups", &self.offset_groups) .add_debug_attribute("padding", &self.padding) .optional() } } impl DeformConv2d { /// Applies the forward pass on the input tensor. /// /// See [deform_conv2d](burn::tensor::module::deform_conv2d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels_in, height_in, width_in]` /// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]` /// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]` /// - output: `[batch_size, channels_out, height_out, width_out]` pub fn forward( &self, input: Tensor, offset: Tensor, mask: Option>, ) -> Tensor { let [_batch_size, _channels_in, height_in, width_in] = input.dims(); let padding = self.padding .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); deform_conv2d( input, offset, self.weight.val(), mask, self.bias.as_ref().map(|bias| bias.val()), DeformConvOptions::new( self.stride, padding, self.dilation, self.weight_groups, self.offset_groups, ), ) } } #[cfg(test)] mod tests { use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; use super::*; use crate::TestBackend; use burn::tensor::TensorData; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = DeformConv2dConfig::new([5, 1], [5, 5]); let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; let k = (config.offset_groups as f64 / k).sqrt().elem::(); let conv = config.init::(&device); conv.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); let conv = config.init::(&device); assert_eq!(config.initializer, Initializer::Zeros); conv.weight.to_data().assert_approx_eq::( &TensorData::zeros::(conv.weight.shape()), Tolerance::default(), ); } #[test] fn initializer_fan_out() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, // test that fan_out is passed to `init_with()` }; let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone()); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] fn initializer_fan_with_groups_is_valid() { let device = Default::default(); TestBackend::seed(&device, 0); let init = Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: true, }; let config = DeformConv2dConfig::new([4, 4], [1, 1]) .with_initializer(init.clone()) .with_weight_groups(4); let _ = config.init::(&device); assert_eq!(config.initializer, init); } #[test] #[should_panic = "Both channels must be divisible by the number of groups."] fn channels_with_groups_is_invalid() { let device = Default::default(); let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4); let _ = config.init::(&device); } #[test] #[should_panic = "Same padding with an even kernel size is not supported"] fn same_with_even_kernel_is_invalid() { let device = Default::default(); let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same); let _ = config.init::(&device); } #[test] fn display() { let config = DeformConv2dConfig::new([5, 1], [5, 5]); let conv = config.init::(&Default::default()); assert_eq!( alloc::format!("{conv}"), "DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}" ); } #[test] #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"] fn input_channels_mismatch() { let config = DeformConv2dConfig::new([5, 3], [3, 3]); let conv = config.init::(&Default::default()); let input = Tensor::::zeros([1, 4, 10, 10], &Default::default()); let offset = Tensor::::zeros([1, 2 * 3 * 3, 10, 10], &Default::default()); let _ = conv.forward(input, offset, None); } } ================================================ FILE: crates/burn-nn/src/modules/conv/mod.rs ================================================ mod conv1d; mod conv2d; mod conv3d; mod conv_transpose1d; mod conv_transpose2d; mod conv_transpose3d; mod deform_conv2d; pub(crate) mod checks; pub use conv_transpose1d::*; pub use conv_transpose2d::*; pub use conv_transpose3d::*; pub use conv1d::*; pub use conv2d::*; pub use conv3d::*; pub use deform_conv2d::*; ================================================ FILE: crates/burn-nn/src/modules/dropout.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::backend::Backend; use burn::tensor::{Distribution, Tensor}; /// Configuration to create a [Dropout](Dropout) layer using the [init function](DropoutConfig::init). #[derive(Config, Debug)] pub struct DropoutConfig { /// The probability of randomly zeroes some elements of the input tensor during training. pub prob: f64, } /// Set at random some elements of the input tensor to zero during training. /// /// This is an effective regularization technique as describe in the paper /// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580). /// /// The input is also scaled during training to `1 / (1 - prob_keep)`. /// /// Should be created with [DropoutConfig]. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Dropout { /// The probability of randomly zeroes some elements of the input tensor during training. pub prob: f64, } impl DropoutConfig { /// Initialize a new [dropout](Dropout) module. pub fn init(&self) -> Dropout { if self.prob < 0.0 || self.prob > 1.0 { panic!( "Dropout probability should be between 0 and 1, but got {}", self.prob ); } Dropout { prob: self.prob } } } impl Dropout { /// Applies the forward pass on the input tensor. /// /// See [Dropout](Dropout) for more information. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { if !B::ad_enabled(&input.device()) || self.prob == 0.0 { return input; } let prob_keep = 1.0 - self.prob; let random = input.random_like(Distribution::Bernoulli(prob_keep)); let x = input * random; x * (1.0 / prob_keep) } } impl ModuleDisplay for Dropout { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("prob", &self.prob).optional() } } #[cfg(test)] mod tests { use super::*; use burn::tensor::Shape; #[cfg(feature = "std")] use crate::{TestAutodiffBackend, TestBackend}; #[cfg(not(feature = "std"))] use crate::TestBackend; #[cfg(feature = "std")] #[test] fn with_ad_backend_should_mark_input() { let tensor = Tensor::::ones(Shape::new([100, 100]), &Default::default()); let dropout = DropoutConfig::new(0.5).init(); let output = dropout.forward(tensor.clone()); assert_ne!(tensor.to_data(), output.to_data()); } #[test] fn without_ad_backend_should_not_change_input() { let tensor = Tensor::::ones(Shape::new([100, 100]), &Default::default()); let dropout = DropoutConfig::new(0.5).init(); let output = dropout.forward(tensor.clone()); assert_eq!(tensor.to_data(), output.to_data()); } #[test] fn display() { let config = DropoutConfig::new(0.5); let layer = config.init(); assert_eq!(alloc::format!("{layer}"), "Dropout {prob: 0.5}"); } #[test] #[should_panic = "Dropout probability should be between 0 and 1,"] fn dropout_prob_invalid() { let config = DropoutConfig::new(-10.); let _layer = config.init(); } } ================================================ FILE: crates/burn-nn/src/modules/embedding.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Initializer; use burn::module::Module; use burn::module::Param; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Int; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::embedding; /// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init). #[derive(Config, Debug)] pub struct EmbeddingConfig { /// The number of embedding vectors. pub n_embedding: usize, /// The size of each vector. pub d_model: usize, /// The type of function used to initialize neural network parameters #[config(default = "Initializer::Normal{mean:0.0, std:1.0}")] pub initializer: Initializer, } /// Lookup table to store a fix number of vectors. /// /// Should be created with [EmbeddingConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Embedding { /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized /// from a normal distribution `N(0, 1)`. pub weight: Param>, } impl ModuleDisplay for Embedding { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [n_embedding, d_model] = self.weight.shape().dims(); content .add("n_embedding", &n_embedding) .add("d_model", &d_model) .optional() } } impl EmbeddingConfig { /// Initialize a new [embedding](Embedding) module. pub fn init(&self, device: &B::Device) -> Embedding { let weight = self .initializer .init([self.n_embedding, self.d_model], device); Embedding { weight } } } impl Embedding { /// Applies the forward pass on the input tensor. /// /// See also [embedding](burn::tensor::module::embedding). /// /// # Shapes /// /// - input: `[batch_size, seq_length]` /// - output: `[batch_size, seq_length, d_model]` pub fn forward(&self, input: Tensor) -> Tensor { embedding(self.weight.val(), input) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros); let embed = config.init::(&Default::default()); assert_eq!(config.initializer, Initializer::Zeros); embed.weight.to_data().assert_approx_eq::( &TensorData::zeros::(embed.weight.shape()), Tolerance::default(), ); } #[test] fn display() { let config = EmbeddingConfig::new(100, 10); let embed = config.init::(&Default::default()); assert_eq!( alloc::format!("{embed}"), "Embedding {n_embedding: 100, d_model: 10, params: 1000}" ); } } ================================================ FILE: crates/burn-nn/src/modules/interpolate/interpolate1d.rs ================================================ use alloc::format; use burn::tensor::module::interpolate; use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::InterpolateOptions; use super::InterpolateMode; /// Configuration for the 1D interpolation module. /// /// This struct defines the configuration options for the 1D interpolation operation. /// It allows specifying the output size, scale factor, and interpolation mode. #[derive(Config, Debug)] pub struct Interpolate1dConfig { /// Output size of the interpolated tensor. /// If specified, this takes precedence over `scale_factor`. #[config(default = "None")] pub output_size: Option, /// Scale factor for resizing the input tensor. /// This is used when `output_size` is not specified. #[config(default = "None")] pub scale_factor: Option, /// Interpolation mode to use for resizing. /// Determines how the output values are calculated. #[config(default = "InterpolateMode::Nearest")] pub mode: InterpolateMode, /// If `true`, the input and output tensors are aligned by their corner pixels. /// If `false`, half-pixel coordinate mapping is used instead. #[config(default = true)] pub align_corners: bool, } /// Interpolate module for resizing 1D tensors with shape [N, C, L]. /// /// This struct represents a 1D interpolation module that can resize tensors /// using various interpolation methods. It provides flexibility in specifying /// either an output size or a scale factor for resizing, along with options /// for the interpolation mode. /// /// The module can be used to upsample or downsample 1D tensors, preserving the /// number of channels and batch size while adjusting the length dimension. /// /// The module can be created using the [Interpolate1dConfig] struct and the /// `init` method, which returns an instance of the [Interpolate1d] struct. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Interpolate1d { /// Output size of the interpolated tensor pub output_size: Option, /// Scale factor for resizing the input tensor pub scale_factor: Option, /// Interpolation mode used for resizing pub mode: InterpolateMode, /// Whether to align corner pixels pub align_corners: bool, } impl Interpolate1dConfig { /// Initialize the interpolation module pub fn init(self) -> Interpolate1d { Interpolate1d { output_size: self.output_size, scale_factor: self.scale_factor, mode: self.mode, align_corners: self.align_corners, } } } impl Interpolate1d { /// Performs the forward pass of the 1D interpolation module /// /// # Arguments /// /// * `input` - Input tensor with shape [N, C, L] /// /// # Returns /// /// Resized tensor with shape [N, C, L'], where L' is determined by /// the output_size or scale_factor specified in the module configuration /// /// # Example /// /// ```ignore /// let input = Tensor::::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device); /// let interpolate = Interpolate1dConfig::new() /// .with_output_size(Some(128)) /// .init(); /// let output = interpolate.forward(input); /// assert_eq!(output.dims(), [1, 3, 128]); /// ``` pub fn forward(&self, input: Tensor) -> Tensor { let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor); // Use the interpolate operation to resize the temporal input tensor // by adding a new dimension for the interpolation axis let input = input.unsqueeze_dim(2); let result = interpolate( input, [1, output_size], InterpolateOptions::new(self.mode.clone().into()) .with_align_corners(self.align_corners), ); result.squeeze_dims(&[2]) } } /// Calculate output size based on input dimensions, output size, and scale factor /// /// # Arguments /// /// * `input_dims` - Input dimensions of the tensor /// * `output_size` - Output size for the interpolated tensor /// * `scale_factor` - Scale factor for resizing the tensor /// /// # Returns /// /// Output size for the interpolated tensor /// /// # Panics /// /// Panics if neither output_size nor scale_factor is provided /// or if the scale factor is too large fn calculate_output_size( input_dims: [usize; 3], output_size: Option, scale_factor: Option, ) -> usize { match (output_size, scale_factor) { (Some(output_size), None) => { // Use provided output_size } (None, Some(scale_factor)) => { // Calculate output size based on scale factor let [_, _, l] = input_dims; let new_dim = (l as f64) * (scale_factor as f64); if new_dim > usize::MAX as f64 { panic!("Scale factor is too large"); } new_dim as usize } _ => panic!("Either output_size or scale_factor must be provided"), } } impl ModuleDisplay for Interpolate1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add_debug_attribute("mode", &self.mode) .add("output_size", &format!("{:?}", self.output_size)) .add("scale_factor", &self.scale_factor) .optional() } } #[cfg(test)] mod tests { use burn::tensor::Distribution; use super::*; use crate::TestBackend; #[test] fn test_calculate_output_size() { let input_dims = [1, 1, 4]; let output_size = calculate_output_size(input_dims, Some(2), None); assert_eq!(output_size, 2); let output_size = calculate_output_size(input_dims, None, Some(2.0)); assert_eq!(output_size, 8); let output_size = calculate_output_size(input_dims, None, Some(0.5)); assert_eq!(output_size, 2); let output_size = calculate_output_size(input_dims, None, Some(1.5)); assert_eq!(output_size, 6); } #[test] #[should_panic(expected = "Either output_size or scale_factor must be provided")] fn test_panic() { let input_dims = [1, 1, 4]; calculate_output_size(input_dims, None, None); } #[test] #[should_panic(expected = "Scale factor is too large")] fn test_large_scale_factor() { let input_dims = [1, 1, usize::MAX - 1]; calculate_output_size(input_dims, None, Some(2.0)); } #[test] fn test_module() { let input = Tensor::::random( [2, 3, 4], Distribution::Uniform(0.0, 1.0), &Default::default(), ); // Test with output_size let config = Interpolate1dConfig::new().with_output_size(Some(8)); let interpolate = config.init(); let output = interpolate.forward(input.clone()); assert_eq!(output.dims(), [2, 3, 8]); // Test with scale_factor let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5)); let interpolate = config.init(); let output = interpolate.forward(input.clone()); assert_eq!(output.dims(), [2, 3, 2]); // Test with different interpolation mode let config = Interpolate1dConfig::new() .with_output_size(Some(6)) .with_mode(InterpolateMode::Linear); let interpolate = config.init(); let output = interpolate.forward(input); assert_eq!(output.dims(), [2, 3, 6]); } #[test] fn display() { let config = Interpolate1dConfig::new().with_output_size(Some(20)); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "Interpolate1d {mode: Nearest, output_size: Some(20), \ scale_factor: None}" ); } } ================================================ FILE: crates/burn-nn/src/modules/interpolate/interpolate2d.rs ================================================ use alloc::format; use burn::tensor::module::interpolate; use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::InterpolateOptions; use super::InterpolateMode; /// Configuration for the 2D interpolation module. /// /// This struct defines the configuration options for the 2D interpolation operation. /// It allows specifying the output size, scale factor, and interpolation mode. #[derive(Config, Debug)] pub struct Interpolate2dConfig { /// Output size of the interpolated tensor. /// If specified, this takes precedence over `scale_factor`. #[config(default = "None")] pub output_size: Option<[usize; 2]>, /// Scale factor for resizing the input tensor. /// This is used when `output_size` is not specified. #[config(default = "None")] pub scale_factor: Option<[f32; 2]>, /// Interpolation mode to use for resizing. /// Determines how the output values are calculated. #[config(default = "InterpolateMode::Nearest")] pub mode: InterpolateMode, /// If `true`, the input and output tensors are aligned by their corner pixels. /// If `false`, half-pixel coordinate mapping is used instead. #[config(default = true)] pub align_corners: bool, } /// Interpolate module for resizing tensors with shape [N, C, H, W]. /// /// This struct represents an interpolation module that can resize tensors /// using various interpolation methods. It provides flexibility in specifying /// either an output size or a scale factor for resizing, along with options /// for the interpolation mode. /// /// The module can be used to upsample or downsample tensors, preserving the /// number of channels and batch size while adjusting the height and width /// dimensions. /// /// The module can be created using the [Interpolate2dConfig] struct and the /// `init` method, which returns an instance of the [Interpolate2d] struct. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Interpolate2d { /// Output size of the interpolated tensor pub output_size: Option<[usize; 2]>, /// Scale factor for resizing the input tensor pub scale_factor: Option<[f32; 2]>, /// Interpolation mode used for resizing pub mode: InterpolateMode, /// Whether to align corner pixels pub align_corners: bool, } impl Interpolate2dConfig { /// Initialize the interpolation module pub fn init(self) -> Interpolate2d { Interpolate2d { output_size: self.output_size, scale_factor: self.scale_factor, mode: self.mode, align_corners: self.align_corners, } } } impl Interpolate2d { /// Performs the forward pass of the interpolation module /// /// # Arguments /// /// * `input` - Input tensor with shape [N, C, H, W] /// /// # Returns /// /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by /// the output_size or scale_factor specified in the module configuration /// /// # Example /// /// ```ignore /// let input = Tensor::::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device); /// let interpolate = Interpolate2dConfig::new() /// .with_output_size(Some([128, 128])) /// .init(); /// let output = interpolate.forward(input); /// assert_eq!(output.dims(), [1, 3, 128, 128]); /// ``` pub fn forward(&self, input: Tensor) -> Tensor { let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor); interpolate( input, output_size, InterpolateOptions::new(self.mode.clone().into()) .with_align_corners(self.align_corners), ) } } /// Calculates the output size for tensor interpolation. /// /// # Arguments /// /// * `input_dims` - The dimensions of the input tensor [N, C, H, W]. /// * `output_size` - Optional desired output size [H', W']. /// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w]. /// /// # Returns /// /// A tuple [H', W'] representing the calculated output size. /// /// # Panics /// /// Panics if neither `output_size` nor `scale_factor` is provided, /// or if the scale factor results in dimensions exceeding usize::MAX. fn calculate_output_size( input_dims: [usize; 4], output_size: Option<[usize; 2]>, scale_factor: Option<[f32; 2]>, ) -> [usize; 2] { match (output_size, scale_factor) { (Some(output_size), None) => { // Use provided output_size } (None, Some(scale_factor)) => { // Calculate output size based on scale factor let [_, _, h, w] = input_dims; let new_dim_h = (h as f64) * (scale_factor[0] as f64); if new_dim_h > usize::MAX as f64 { panic!("Scale factor for height is too large"); } let new_dim_w = (w as f64) * (scale_factor[1] as f64); if new_dim_w > usize::MAX as f64 { panic!("Scale factor for width is too large"); } [new_dim_h as usize, new_dim_w as usize] } _ => panic!("Either output_size or scale_factor must be provided"), } } impl ModuleDisplay for Interpolate2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add_debug_attribute("mode", &self.mode) .add("output_size", &format!("{:?}", self.output_size)) .add("scale_factor", &self.scale_factor) .optional() } } #[cfg(test)] mod tests { use burn::tensor::Distribution; use crate::TestBackend; use super::*; #[test] fn test_calculate_output_size() { let input_dims = [1, 1, 4, 4]; let output_size = calculate_output_size(input_dims, Some([2, 2]), None); assert_eq!(output_size, [2, 2]); let output_size = calculate_output_size(input_dims, None, Some([2.0, 2.0])); assert_eq!(output_size, [8, 8]); let output_size = calculate_output_size([1, 1, 4, 4], None, Some([0.5, 0.5])); assert_eq!(output_size, [2, 2]); let output_size = calculate_output_size([1, 1, 4, 4], None, Some([2.0, 1.5])); assert_eq!(output_size, [8, 6]); } #[test] #[should_panic(expected = "Either output_size or scale_factor must be provided")] fn test_missing_params() { calculate_output_size([1, 1, 4, 4], None, None); } #[test] #[should_panic(expected = "Scale factor for height is too large")] fn test_infinite_height() { calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0])); } #[test] #[should_panic(expected = "Scale factor for width is too large")] fn test_infinite_width() { calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0])); } #[test] fn test_module() { let input = Tensor::::random( [2, 3, 4, 4], Distribution::Uniform(0.0, 1.0), &Default::default(), ); // Test with output_size let config = Interpolate2dConfig::new().with_output_size(Some([8, 8])); let interpolate = config.init(); let output = interpolate.forward(input.clone()); assert_eq!(output.dims(), [2, 3, 8, 8]); // Test with scale_factor let config = Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5])); let interpolate = config.init(); let output = interpolate.forward(input.clone()); assert_eq!(output.dims(), [2, 3, 2, 2]); // Test with different interpolation mode let config = Interpolate2dConfig::new() .with_output_size(Some([6, 6])) .with_mode(InterpolateMode::Linear); let interpolate = config.init(); let output = interpolate.forward(input); assert_eq!(output.dims(), [2, 3, 6, 6]); } #[test] fn display() { let config = Interpolate2dConfig::new().with_output_size(Some([20, 20])); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \ scale_factor: None}" ); } } ================================================ FILE: crates/burn-nn/src/modules/interpolate/mod.rs ================================================ mod interpolate1d; mod interpolate2d; pub use interpolate1d::*; pub use interpolate2d::*; use burn_core as burn; use burn::config::Config; use burn::tensor::ops::InterpolateMode as OpsInterpolateMode; /// Algorithm used for downsampling and upsampling /// /// This enum defines different interpolation modes for resampling data. #[derive(Config, Debug)] pub enum InterpolateMode { /// Nearest-neighbor interpolation /// /// This mode selects the value of the nearest sample point for each output pixel. /// It is applicable for both temporal and spatial data. Nearest, /// Linear interpolation /// /// This mode calculates the output value using linear /// interpolation between nearby sample points. /// /// It is applicable for both temporal and spatial data. Linear, /// Cubic interpolation /// /// This mode uses cubic interpolation to calculate the output value /// based on surrounding sample points. /// /// It is applicable for both temporal and spatial data and generally /// provides smoother results than linear interpolation. Cubic, /// Lanczos3 interpolation /// /// This mode uses a 6-tap sinc-based Lanczos filter (a=3) to calculate /// the output value. It generally provides high-quality results, /// especially for downsampling. Lanczos, } impl From for OpsInterpolateMode { fn from(mode: InterpolateMode) -> Self { match mode { InterpolateMode::Nearest => OpsInterpolateMode::Nearest, InterpolateMode::Linear => OpsInterpolateMode::Bilinear, InterpolateMode::Cubic => OpsInterpolateMode::Bicubic, InterpolateMode::Lanczos => OpsInterpolateMode::Lanczos3, } } } ================================================ FILE: crates/burn-nn/src/modules/linear.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Param; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::module::linear; use burn::tensor::{Tensor, backend::Backend}; /// Configuration to create a [`Linear`] layer using the [init function](LinearConfig::init). #[derive(Config, Debug)] pub struct LinearConfig { /// The size of the input features. pub d_input: usize, /// The size of the output features. pub d_output: usize, /// If a bias should be applied during the linear transformation. #[config(default = true)] pub bias: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, /// The layout in which the linear parameters are stored. #[config(default = "LinearLayout::Row")] pub layout: LinearLayout, } #[derive(Config, Debug, Copy)] /// The layout in which the linear parameters are stored. /// /// This can have performance impacts. pub enum LinearLayout { /// Parameters are stored in Row major. Row, /// Parameters are stored in Col major. Col, } /// Applies a linear transformation to the input tensor. /// /// Should be created with [LinearConfig] /// /// `O = IW + b` #[derive(Module, Debug)] #[module(custom_display)] pub struct Linear { /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: /// `U(-k, k)`, where `k = sqrt(1 / d_input)` pub weight: Param>, /// Vector of size `d_output` initialized from a uniform distribution: /// `U(-k, k)`, where `k = sqrt(1 / d_input)` pub bias: Option>>, } impl LinearConfig { /// Initialize a new [`Linear`] module. pub fn init(&self, device: &B::Device) -> Linear { let weight = match self.layout { LinearLayout::Row => { let shape = [self.d_input, self.d_output]; self.initializer .init_with(shape, Some(self.d_input), Some(self.d_output), device) } LinearLayout::Col => { let shape = [self.d_output, self.d_input]; self.initializer .init_with(shape, Some(self.d_output), Some(self.d_input), device) // The param is already transposed when init. We re-transpose to have // [d_output, d_input] while saving. .save_mapper(move |tensor| { B::sync(&tensor.device()).unwrap(); let tensor = tensor.transpose(); B::sync(&tensor.device()).unwrap(); tensor }) // When loading from record we have to transpose. .load_mapper(move |tensor| { B::sync(&tensor.device()).unwrap(); let tensor = tensor.transpose(); B::sync(&tensor.device()).unwrap(); tensor }) // When loading from initialization, we have to transpose. .init_mapper(|tensor| { B::sync(&tensor.device()).unwrap(); let tensor = tensor.transpose(); B::sync(&tensor.device()).unwrap(); tensor }) } }; let bias = if self.bias { Some(self.initializer.init_with( [self.d_output], Some(self.d_input), Some(self.d_output), device, )) } else { None }; Linear { weight, bias } } } impl Linear { /// Applies the forward pass on the input tensor. /// /// # Arguments /// /// - `input` - The input tensor of shape `[..., d_input]`. /// /// # Shapes /// /// - input: `[..., d_input]` /// - output: `[..., d_output]` /// /// # Returns /// /// The transformed tensor of shape `[..., d_output]`. pub fn forward(&self, input: Tensor) -> Tensor { linear( input, self.weight.val(), self.bias.as_ref().map(|b| b.val()), ) } } impl ModuleDisplay for Linear { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, d_output] = self.weight.shape().dims(); content .add("d_input", &d_input) .add("d_output", &d_output) .add("bias", &self.bias.is_some()) .optional() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::module::ParamId; use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder}; use burn::tensor::ElementConversion; use burn::tensor::{Shape, TensorData}; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn initializer_default() { let device = Default::default(); TestBackend::seed(&device, 0); let config = LinearConfig::new(5, 5); let k = (1.0 / config.d_input as f64).sqrt().elem::(); let linear = config.init::(&device); assert_eq!( config.initializer, Initializer::KaimingUniform { gain: 1.0 / 3.0f64.sqrt(), fan_out_only: false } ); linear.weight.to_data().assert_within_range(-k..k); } #[test] fn initializer_zeros() { let device = Default::default(); TestBackend::seed(&device, 0); let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros); let linear = config.init::(&device); assert_eq!(config.initializer, Initializer::Zeros); linear.weight.to_data().assert_approx_eq::( &TensorData::zeros::(linear.weight.shape()), Tolerance::default(), ); } #[test] fn test_linear_forward_no_bias() { let device = Default::default(); TestBackend::seed(&device, 0); let value = 2.; let config = LinearConfig::new(2, 3) .with_initializer(Initializer::Constant { value }) .with_bias(false); let linear = config.init::(&device); let input = Tensor::::ones(Shape::new([1, 2]), &device); let result = linear.forward(input); let expected_result = Tensor::::from_data([[4., 4., 4.]], &device); assert_eq!(result.into_data(), expected_result.into_data()); } #[test] fn test_linear_forward_with_bias() { let device = Default::default(); TestBackend::seed(&device, 0); let device = Default::default(); let value = 2.; let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); let linear = config.init::(&device); let input = Tensor::::ones(Shape::new([1, 2]), &device); let result = linear.forward(input); let expected_result = Tensor::::from_data([[6., 6., 6.]], &device); assert_eq!(result.into_data(), expected_result.into_data()); } #[test] fn test_linear_1d() { let device = Default::default(); TestBackend::seed(&device, 0); let device = Default::default(); let value = 2.; let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); let linear = config.init::(&device); let input_1d = Tensor::::ones(Shape::new([2]), &device); let input_2d = Tensor::::ones(Shape::new([1, 2]), &device); let result_1d = linear.forward(input_1d).unsqueeze::<2>(); let result_2d = linear.forward(input_2d); assert_eq!(result_1d.into_data(), result_2d.into_data()); } #[test] fn display() { let config = LinearConfig::new(3, 5); let linear = config.init::(&Default::default()); assert_eq!( alloc::format!("{linear}"), "Linear {d_input: 3, d_output: 5, bias: true, params: 20}" ); } #[test] fn layout() { let device = Default::default(); let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col); let linear = config.init::(&device); assert_eq!(linear.weight.dims(), [6, 12], "Shape is as configured"); let recorder = BinBytesRecorder::::new(); // We go through serialization to trigger the mappers.. let record = linear.into_record(); let data = recorder.record(record, ()).unwrap(); let record = recorder.load(data.clone(), &device).unwrap(); let config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row); let linear_row = config.init::(&device).load_record(record); assert_eq!( linear_row.weight.dims(), [12, 6], "Shape should be transposed" ); let record = recorder.load(data.clone(), &device).unwrap(); let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col); let linear_col = config.init::(&device).load_record(record); assert_eq!( linear_col.weight.dims(), [6, 12], "Shape should be as configured" ); // We go through serialization to trigger the mappers. // // The test will fail if the mapper is not correctly given to the module after loading a // record. let record = linear_col.into_record(); let data = recorder.record(record, ()).unwrap(); let record = recorder.load(data, &device).unwrap(); let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col); let linear_col = config.init::(&device).load_record(record); assert_eq!( linear_col.weight.dims(), [6, 12], "Shape should be as configured" ); } #[test] fn col_row_same_result() { let device = Default::default(); let config_col = LinearConfig::new(6, 12).with_layout(LinearLayout::Col); let linear_col = config_col.init::(&device); let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device); let value = linear_col.forward(signal.clone()); let data_1 = value.into_data(); let weights = linear_col.weight.val().into_data(); let weights = Tensor::from_data(weights, &device); let linear = Linear { weight: Param::initialized(ParamId::new(), weights), bias: linear_col .bias .map(|b| Param::initialized(ParamId::new(), b.val())), }; let value = linear.forward(signal); let data_2 = value.into_data(); data_1.assert_approx_eq::(&data_2, Default::default()); } } ================================================ FILE: crates/burn-nn/src/modules/mod.rs ================================================ /// Attention module pub mod attention; /// Cache module pub mod cache; /// Convolution module pub mod conv; /// Pooling module pub mod pool; /// Transformer module pub mod transformer; /// Interpolate module pub mod interpolate; mod dropout; mod embedding; mod linear; mod noise; mod pos_encoding; mod rnn; mod rope_encoding; mod unfold; pub mod norm; pub use norm::{batch::*, group::*, instance::*, layer::*, rms::*}; pub use dropout::*; pub use embedding::*; pub use linear::*; pub use noise::*; pub use pos_encoding::*; pub use rnn::*; pub use rope_encoding::*; pub use unfold::*; ================================================ FILE: crates/burn-nn/src/modules/noise.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::backend::Backend; use burn::tensor::{Distribution, Tensor}; /// Configuration to create a [GaussianNoise](GaussianNoise) layer using the [init function](GaussianNoiseConfig::init). #[derive(Config, Debug)] pub struct GaussianNoiseConfig { /// Standard deviation of the normal noise distribution. pub std: f64, } /// Add pseudorandom Gaussian noise to an arbitrarily shaped tensor. /// /// This is an effective regularization technique that also contributes to data augmentation. /// Please keep in mind that the value of [std](GaussianNoise::std) should be chosen with care in order to avoid /// distortion. /// /// Should be created with [GaussianNoiseConfig]. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct GaussianNoise { /// Standard deviation of the normal noise distribution. pub std: f64, } impl GaussianNoiseConfig { /// Initialize a new [Gaussian noise](GaussianNoise) module. pub fn init(&self) -> GaussianNoise { if self.std.is_sign_negative() { panic!( "Standard deviation is required to be non-negative, but got {}", self.std ); } GaussianNoise { std: self.std } } } impl GaussianNoise { /// Applies the forward pass on the input tensor. /// /// See [GaussianNoise](GaussianNoise) for more information. /// /// # Shapes /// /// - input: `[..., any]` /// - output: `[..., any]` pub fn forward(&self, input: Tensor) -> Tensor { if B::ad_enabled(&input.device()) && self.std != 0.0 { let noise = Tensor::random( input.shape(), Distribution::Normal(0.0, self.std), &input.device(), ); input + noise } else { input } } } impl ModuleDisplay for GaussianNoise { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("std", &self.std).optional() } } #[cfg(test)] mod tests { use super::*; use burn::tensor::Shape; #[cfg(feature = "std")] use crate::{TestAutodiffBackend, TestBackend}; #[cfg(not(feature = "std"))] use crate::TestBackend; #[cfg(feature = "std")] #[test] fn with_ad_backend_should_mark_input() { let tensor = Tensor::::ones(Shape::new([100, 100]), &Default::default()); let noise = GaussianNoiseConfig::new(0.5).init(); let output = noise.forward(tensor.clone()); assert_ne!(tensor.to_data(), output.to_data()); } #[test] fn without_ad_backend_should_not_change_input() { let tensor = Tensor::::ones(Shape::new([100, 100]), &Default::default()); let noise = GaussianNoiseConfig::new(0.5).init(); let output = noise.forward(tensor.clone()); assert_eq!(tensor.to_data(), output.to_data()); } #[test] #[should_panic(expected = "Standard deviation is required to be non-negative")] fn negative_std_should_panic() { GaussianNoiseConfig { std: -0.5 }.init(); } #[test] fn display() { let config = GaussianNoiseConfig::new(0.5); let layer = config.init(); assert_eq!(alloc::format!("{layer}"), "GaussianNoise {std: 0.5}"); } } ================================================ FILE: crates/burn-nn/src/modules/norm/batch.rs ================================================ use burn_core as burn; use burn::module::Initializer; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::{Tensor, backend::Backend}; use burn::{ config::Config, module::{Module, Param, RunningState}, }; /// [`BatchNorm`] Configuration. /// /// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`]. #[derive(Config, Debug)] pub struct BatchNormConfig { /// The number of features. pub num_features: usize, /// A value required for numerical stability. Default: 1e-5 #[config(default = 1e-5)] pub epsilon: f64, /// Momentum used to update the metrics. Default: 0.1 #[config(default = 0.1)] pub momentum: f64, } /// Applies Batch Normalization over a tensor. /// /// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167). /// /// Assumes input tensor is of shape ``[batch_size, channels, ...]``. /// /// `Y = norm(X) * γ + β` /// /// Where: /// - `X` is the input tensor /// - `Y` is the output tensor /// - `norm` is the normalization function /// - `γ` is the learnable weight /// - `β` is the learnable bias /// /// Should be created using [`BatchNormConfig`]. #[derive(Module, Debug)] #[module(custom_display)] pub struct BatchNorm { /// The learnable weight gamma. pub gamma: Param>, /// The learnable weight beta. pub beta: Param>, /// The running mean. pub running_mean: RunningState>, /// The running variance. pub running_var: RunningState>, /// Momentum used to update the metrics. pub momentum: f64, /// A value required for numerical stability. pub epsilon: f64, } impl BatchNormConfig { /// Initializes a new [batch norm](BatchNorm) module. pub fn init(&self, device: &B::Device) -> BatchNorm { let gamma = Initializer::Ones.init([self.num_features], device); let beta = Initializer::Zeros.init([self.num_features], device); let running_mean = Tensor::zeros([self.num_features], device); let running_var = Tensor::ones([self.num_features], device); BatchNorm { gamma, beta, running_mean: RunningState::new(running_mean), running_var: RunningState::new(running_var), momentum: self.momentum, epsilon: self.epsilon, } } } impl BatchNorm { /// Applies the forward pass on the input tensor. /// /// See [`BatchNorm`] for more information. /// /// # Shapes /// /// - `input`: ``[batch_size, channels, ...]`` /// - `output`: ``[batch_size, channels, ...]`` /// /// # Panics /// /// This function will panic if the input tensor has rank < 2. pub fn forward(&self, input: Tensor) -> Tensor { // Should be move to a compilation error when const generic support that kind of // validation. https://github.com/rust-lang/rust/issues/76560 if D < 2 { panic!( "BatchNorm can only be applied on tensors of rank >= 2 with the following shape \ [batch_size, channels, ...], received {}D tensor", D ); } match B::ad_enabled(&input.device()) { true => self.forward_train(input), false => self.forward_inference(input), } } fn forward_inference(&self, input: Tensor) -> Tensor { let device = input.device(); let channels = input.dims()[1]; let mean = self.running_mean.value().to_device(&device); let var = self.running_var.value().to_device(&device); let mut shape = [1; D]; shape[1] = channels; self.forward_shared(input, mean.reshape(shape), var.reshape(shape)) } fn forward_train(&self, input: Tensor) -> Tensor { let device = input.device(); let dims = input.dims(); let batch_size = dims[0]; let channels = dims[1]; let mut shape_unsqueeze = [1; D]; let mut flatten_size = batch_size; shape_unsqueeze[1] = channels; for dim in dims.iter().take(D).skip(2) { flatten_size *= dim; } let mean = input .clone() .swap_dims(0, 1) .reshape([channels, flatten_size]) .mean_dim(1) .reshape(shape_unsqueeze); let var = input .clone() .sub(mean.clone()) .square() .swap_dims(0, 1) .reshape([channels, flatten_size]) .mean_dim(1) .reshape(shape_unsqueeze); let running_mean = self.running_mean.value_sync().to_device(&device); let running_var = self.running_var.value_sync().to_device(&device); let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add( mean.clone() .detach() .mul_scalar(self.momentum) .reshape([channels]), ); let running_var = running_var.mul_scalar(1.0 - self.momentum).add( var.clone() .detach() .mul_scalar(self.momentum) .reshape([channels]), ); self.running_mean.update(running_mean.detach()); self.running_var.update(running_var.detach()); self.forward_shared(input, mean, var) } fn forward_shared( &self, x: Tensor, mean: Tensor, var: Tensor, ) -> Tensor { let channels = x.dims()[1]; let mut shape = [1; D]; shape[1] = channels; let std = var.add_scalar(self.epsilon).sqrt(); let x = x.sub(mean); let x = x.div(std); let x = x.mul(self.gamma.val().reshape(shape)); x.add(self.beta.val().reshape(shape)) } } impl ModuleDisplay for BatchNorm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [num_features] = self.beta.shape().dims(); content .add("num_features", &num_features) .add("momentum", &self.momentum) .add("epsilon", &self.epsilon) .optional() } } #[cfg(feature = "std")] #[cfg(test)] mod tests_1d { use super::*; use crate::TestAutodiffBackend; use burn::module::AutodiffModule; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn batch_norm_forward_train() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let output = module.forward(input_tensor(&device)); output .to_data() .assert_approx_eq::(&expected_train(), Tolerance::rel_abs(0.1, 0.001)); } #[test] fn batch_norm_forward_inference() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); module.forward(input_tensor(&device)); let module = module.valid(); let output = module.forward(input_tensor(&device)); output .to_data() .assert_approx_eq::(&expected_valid(), Tolerance::default()); } fn expected_valid() -> TensorData { TensorData::from([ [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]], [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]], ]) } fn expected_train() -> TensorData { TensorData::from([ [ [1.1483e+00, 3.7521e-01], [1.6272e-03, 7.5067e-01], [1.6204e+00, -4.5168e-02], ], [ [6.8856e-02, -1.5923e+00], [-1.6318e+00, 8.7949e-01], [-5.3368e-01, -1.0416e+00], ], ]) } fn input_tensor(device: &B::Device) -> Tensor { Tensor::::from_floats( [ [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]], [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]], ], device, ) } #[test] fn batch_norm_forward_train_inference() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); module.forward(input_tensor(&device)); let module = module.valid(); let output = module.forward(input_tensor(&device)); output .to_data() .assert_approx_eq::(&expected_valid(), Tolerance::default()); let module = module.train::(); let output = module.forward(input_tensor(&device)); output .to_data() .assert_approx_eq::(&expected_train(), Tolerance::default()); } } #[cfg(feature = "std")] #[cfg(test)] mod tests_2d { use super::*; use crate::TestAutodiffBackend; use burn::module::AutodiffModule; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn batch_norm_forward_train() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let output = module.forward(input_tensor(&device)); let expected = TensorData::from([ [ [[1.5136, 0.7506], [-1.2216, 0.1477]], [[0.3135, 1.2252], [-0.4150, 0.6130]], [[1.4186, 0.3372], [-1.5183, 1.5262]], ], [ [[0.4483, -1.1914], [-1.2010, 0.7537]], [[-1.6752, 1.3822], [-0.5058, -0.9381]], [[0.0200, -0.3097], [-0.5715, -0.9026]], ], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::rel_abs(0.1, 0.001)); } #[test] fn batch_norm_forward_inference() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); module.forward(input_tensor(&device)); let module = module.valid(); let output = module.forward(input_tensor(&device)); let expected = TensorData::from([ [ [[0.9538, 0.7103], [0.0808, 0.5179]], [[0.6015, 0.8910], [0.3703, 0.6966]], [[0.9171, 0.6912], [0.3037, 0.9395]], ], [ [[0.6138, 0.0904], [0.0874, 0.7113]], [[-0.0297, 0.9408], [0.3415, 0.2042]], [[0.6250, 0.5561], [0.5013, 0.4323]], ], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn batch_norm_running_mean() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let _output = module.forward(input_tensor(&device)); let running_mean = module.running_mean.value_sync(); let expected = TensorData::from([0.0499, 0.0532, 0.0656]); running_mean .reshape([3]) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn batch_norm_running_var() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let _output = module.forward(input_tensor(&device)); let running_var = module.running_var.value_sync(); let expected = TensorData::from([0.9106, 0.9105, 0.9045]); running_var .reshape([3]) .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn batch_norm_running_mean_inner_module() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let _output = module.forward(input_tensor(&device)); let module_valid = module.valid(); let running_mean = module_valid.running_mean.value(); let running_mean_after = module.running_mean.value(); running_mean_after .into_data() .assert_approx_eq::(&running_mean.into_data(), Tolerance::default()); } #[test] fn batch_norm_grads() { let device = Default::default(); let module = BatchNormConfig::new(3).init::(&device); let input = input_tensor(&device).require_grad(); let output = module.forward(input.clone()); let grads = output.backward(); let tolerance = Tolerance::rel_abs(0.1, 0.001); let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]); module .gamma .grad(&grads) .unwrap() .reshape([3]) .into_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([8., 8., 8.]); module .beta .grad(&grads) .unwrap() .reshape([3]) .into_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([ [ [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]], [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]], ], [ [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]], [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]], [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]], ], ]); input .grad(&grads) .unwrap() .into_data() .assert_approx_eq::(&expected, tolerance); } fn input_tensor(device: &B::Device) -> Tensor { Tensor::::from_floats( [ [ [[0.9601, 0.7277], [0.1270, 0.5441]], [[0.6272, 0.9034], [0.4066, 0.7179]], [[0.9378, 0.7230], [0.3544, 0.9591]], ], [ [[0.6356, 0.1362], [0.1333, 0.7287]], [[0.0249, 0.9509], [0.3791, 0.2481]], [[0.6600, 0.5945], [0.5424, 0.4767]], ], ], device, ) } #[test] fn display() { let batch_norm = BatchNormConfig::new(3).init::(&Default::default()); assert_eq!( format!("{batch_norm}"), "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}" ); } } ================================================ FILE: crates/burn-nn/src/modules/norm/group.rs ================================================ use burn::module::Initializer; use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::Param; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init). #[derive(Debug, Config)] pub struct GroupNormConfig { /// The number of groups to separate the channels into pub num_groups: usize, /// The number of channels expected in the input pub num_channels: usize, /// A value required for numerical stability. Default: 1e-5 #[config(default = 1e-5)] pub epsilon: f64, /// A boolean value that when set to `true`, this module has learnable /// per-channel affine parameters initialized to ones (for weights) /// and zeros (for biases). Default: `true` #[config(default = true)] pub affine: bool, } /// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494). /// /// `Y = groupnorm(X) * γ + β` /// /// Where: /// - `X` is the input tensor /// - `Y` is the output tensor /// - `γ` is the learnable weight /// - `β` is the learnable bias /// /// Should be created using [GroupNormConfig](GroupNormConfig). #[derive(Module, Debug)] #[module(custom_display)] pub struct GroupNorm { /// The learnable weight pub gamma: Option>>, /// The learnable bias pub beta: Option>>, /// The number of groups to separate the channels into pub num_groups: usize, /// The number of channels expected in the input pub num_channels: usize, /// A value required for numerical stability pub epsilon: f64, /// A boolean value that when set to `true`, this module has learnable pub affine: bool, } impl ModuleDisplay for GroupNorm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("num_groups", &self.num_groups) .add("num_channels", &self.num_channels) .add("epsilon", &self.epsilon) .add("affine", &self.affine) .optional() } } impl GroupNormConfig { /// Initialize a new [group norm](GroupNorm) module. pub fn init(&self, device: &B::Device) -> GroupNorm { assert_eq!( self.num_channels % self.num_groups, 0, "The number of channels must be divisible by the number of groups" ); let (gamma, beta) = if self.affine { let gamma = Initializer::Ones.init([self.num_channels], device); let beta = Initializer::Zeros.init([self.num_channels], device); (Some(gamma), Some(beta)) } else { (None, None) }; GroupNorm { num_groups: self.num_groups, num_channels: self.num_channels, gamma, beta, epsilon: self.epsilon, affine: self.affine, } } } impl GroupNorm { /// Applies the forward pass on the input tensor. /// /// See [GroupNorm](GroupNorm) for more information. /// /// # Shapes /// /// - input: `[batch_size, num_channels, *]` /// - output: `[batch_size, num_channels, *]` pub fn forward(&self, input: Tensor) -> Tensor { if input.shape()[1] != self.num_channels { panic!( "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}", self.num_channels, input.shape()[1] ); } let gamma = self.gamma.as_ref().map(|x| x.val()); let beta = self.beta.as_ref().map(|x| x.val()); group_norm( input, gamma, beta, self.num_groups, self.epsilon, self.affine, ) } } /// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494). /// /// `Y = groupnorm(X) * γ + β` /// /// Where: /// - `X` is the input tensor /// - `Y` is the output tensor /// - `γ` is the learnable weight /// - `β` is the learnable bias /// pub(crate) fn group_norm( input: Tensor, gamma: Option>, beta: Option>, num_groups: usize, epsilon: f64, affine: bool, ) -> Tensor { if (beta.is_none() || gamma.is_none()) && affine { panic!("Affine is set to true, but gamma or beta is None"); } let shape = input.shape(); if shape.num_elements() <= 2 { panic!( "input rank for GroupNorm should be at least 3, but got {}", shape.num_elements() ); } let batch_size = shape[0]; let num_channels = shape[1]; let hidden_size = shape[2..].iter().product::() * num_channels / num_groups; let input = input.reshape([batch_size, num_groups, hidden_size]); let mean = input.clone().sum_dim(2) / hidden_size as f64; let input = input.sub(mean); let var = input.clone().square().sum_dim(2) / hidden_size as f64; let input_normalized = input.div(var.add_scalar(epsilon).sqrt()); if affine { let mut affine_shape = [1; D]; affine_shape[1] = num_channels; input_normalized .reshape(shape) .mul(gamma.clone().unwrap().reshape(affine_shape)) .add(beta.clone().unwrap().reshape(affine_shape)) } else { input_normalized.reshape(shape) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use alloc::format; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn group_norm_forward_affine_false() { let device = Default::default(); let module = GroupNormConfig::new(2, 6) .with_affine(false) .init::(&device); assert!(module.gamma.is_none()); assert!(module.beta.is_none()); let input = Tensor::::from_data( TensorData::from([ [ [-0.3034, 0.2726, -0.9659], [-1.1845, -1.3236, 0.0172], [1.9507, 1.2554, -0.8625], [1.0682, 0.3604, 0.3985], [-0.4957, -0.4461, -0.9721], [1.5157, -0.1546, -0.5596], ], [ [-1.6698, -0.4040, -0.7927], [0.3736, -0.0975, -0.1351], [-0.9461, 0.5461, -0.6334], [-1.0919, -0.1158, 0.1213], [-0.9535, 0.1281, 0.4372], [-0.2845, 0.3488, 0.5641], ], ]), &device, ); let output = module.forward(input); let expected = TensorData::from([ [ [-0.1653, 0.3748, -0.7866], [-0.9916, -1.1220, 0.1353], [1.9485, 1.2965, -0.6896], [1.2769, 0.3628, 0.4120], [-0.7427, -0.6786, -1.3578], [1.8547, -0.3022, -0.8252], ], [ [-1.9342, 0.0211, -0.5793], [1.2223, 0.4945, 0.4365], [-0.8163, 1.4887, -0.3333], [-1.7960, -0.0392, 0.3875], [-1.5469, 0.3998, 0.9561], [-0.3428, 0.7970, 1.1845], ], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn group_norm_forward_affine_true() { let device = Default::default(); let module = GroupNormConfig::new(3, 6) .with_affine(true) .init::(&device); let tolerance = Tolerance::permissive(); module .gamma .as_ref() .expect("gamma should not be None") .val() .to_data() .assert_approx_eq::(&TensorData::ones::([6]), tolerance); module .beta .as_ref() .expect("beta should not be None") .val() .to_data() .assert_approx_eq::(&TensorData::zeros::([6]), tolerance); let input = Tensor::::from_data( TensorData::from([ [ [0.3345, 0.4429, 0.6639], [0.5041, 0.4175, 0.8437], [0.6159, 0.3758, 0.4071], [0.5417, 0.5785, 0.7671], [0.3837, 0.9883, 0.0420], [0.4808, 0.8989, 0.6144], ], [ [0.3930, 0.2098, 0.0602], [0.2298, 0.9425, 0.0333], [0.7409, 0.8172, 0.8879], [0.4846, 0.0486, 0.2029], [0.6741, 0.9765, 0.6864], [0.2827, 0.5534, 0.2125], ], ]), &device, ); let output = module.forward(input); let expected = TensorData::from([ [ [-1.1694, -0.5353, 0.7572], [-0.1775, -0.6838, 1.8087], [0.5205, -1.3107, -1.0723], [-0.0459, 0.2351, 1.6734], [-0.5796, 1.3218, -1.6544], [-0.2744, 1.0406, 0.1459], ], [ [0.2665, -0.3320, -0.8205], [-0.2667, 2.0612, -0.9085], [0.6681, 0.9102, 1.1345], [-0.1453, -1.5287, -1.0389], [0.4253, 1.5962, 0.4731], [-1.0903, -0.0419, -1.3623], ], ]); output .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn display() { let config = GroupNormConfig::new(3, 6); let group_norm = config.init::(&Default::default()); assert_eq!( format!("{group_norm}"), "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" ); } } ================================================ FILE: crates/burn-nn/src/modules/norm/instance.rs ================================================ use burn_core as burn; use crate::norm::group_norm; use burn::config::Config; use burn::module::Initializer; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::module::{Module, Param}; use burn::tensor::{Tensor, backend::Backend}; /// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init). #[derive(Debug, Config)] pub struct InstanceNormConfig { /// The number of channels expected in the input pub num_channels: usize, /// A value required for numerical stability. Default: 1e-5 #[config(default = 1e-5)] pub epsilon: f64, /// A boolean value that when set to `true`, this module has learnable /// per-channel affine parameters initialized to ones (for weights) /// and zeros (for biases). Default: `true` #[config(default = true)] pub affine: bool, } /// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022) /// /// Should be created using [InstanceNormConfig](InstanceNormConfig). #[derive(Module, Debug)] #[module(custom_display)] pub struct InstanceNorm { /// The learnable weight pub gamma: Option>>, /// The learnable bias pub beta: Option>>, /// The number of channels expected in the input pub num_channels: usize, /// A value required for numerical stability pub epsilon: f64, /// A boolean value that when set to `true`, this module has learnable pub affine: bool, } impl ModuleDisplay for InstanceNorm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("num_channels", &self.num_channels) .add("epsilon", &self.epsilon) .add("affine", &self.affine) .optional() } } impl InstanceNormConfig { /// Initialize a new [instance norm](InstanceNorm) module. pub fn init(&self, device: &B::Device) -> InstanceNorm { let (gamma, beta) = if self.affine { let gamma = Initializer::Ones.init([self.num_channels], device); let beta = Initializer::Zeros.init([self.num_channels], device); (Some(gamma), Some(beta)) } else { (None, None) }; InstanceNorm { gamma, beta, num_channels: self.num_channels, epsilon: self.epsilon, affine: self.affine, } } } impl InstanceNorm { /// Applies the forward pass on the input tensor. /// /// See also [InstanceNormConfig](InstanceNormConfig) for more information. /// /// # Shapes /// /// - input: `[batch_size, num_channels, *]` /// - output: `[batch_size, num_channels, *]` pub fn forward(&self, input: Tensor) -> Tensor { // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels. let num_groups = self.num_channels; let gamma = self.gamma.as_ref().map(|x| x.val()); let beta = self.beta.as_ref().map(|x| x.val()); group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use alloc::format; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn instance_norm_forward_affine_false() { let device = Default::default(); let module = InstanceNormConfig::new(6) .with_affine(false) .init::(&device); let input = Tensor::::from_data( TensorData::from([ [ [-0.3034, 0.2726, -0.9659], [-1.1845, 1.4078, 0.9774], [0.3963, -1.3738, 1.4125], [1.0682, 0.3604, 0.3985], [-0.4957, -0.4461, -0.9721], [1.5157, -0.1546, -0.5596], ], [ [-1.6698, -0.4040, -0.7927], [0.3736, -0.0975, -0.1351], [-0.9461, 0.5461, -0.6334], [-1.0919, -0.1158, 0.1213], [-0.9535, 0.1281, 0.4372], [-0.2845, 0.3488, 0.5641], ], ]), &device, ); let output = module.forward(input); let expected = TensorData::from([ [ [0.0569, 1.1952, -1.2522], [-1.3971, 0.8883, 0.5088], [0.2183, -1.3192, 1.1009], [1.4126, -0.7649, -0.6477], [0.5999, 0.8091, -1.409], [1.39, -0.4696, -0.9205], ], [ [-1.3492, 1.0417, 0.3075], [1.411, -0.6243, -0.7867], [-0.9363, 1.386, -0.4497], [-1.3899, 0.4692, 0.9208], [-1.3822, 0.4319, 0.9503], [-1.3714, 0.3868, 0.9846], ], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn instance_norm_forward_affine_true() { let device = Default::default(); let module = InstanceNormConfig::new(6) .with_affine(true) .init::(&device); let input = Tensor::::from_data( TensorData::from([ [ [0.3345, 0.4429, 0.6639], [0.5041, 0.4175, 0.8437], [0.6159, 0.3758, 0.4071], [0.5417, 0.5785, 0.7671], [0.3837, 0.9883, 0.0420], [0.4808, 0.8989, 0.6144], ], [ [0.3930, 0.2098, 0.0602], [0.2298, 0.9425, 0.0333], [0.7409, 0.8172, 0.8879], [0.4846, 0.0486, 0.2029], [0.6741, 0.9765, 0.6864], [0.2827, 0.5534, 0.2125], ], ]), &device, ); let output = module.forward(input); let expected = TensorData::from([ [ [-1.06458, -0.2738, 1.33838], [-0.45848, -0.92929, 1.38777], [1.40388, -0.84877, -0.55511], [-0.88515, -0.51245, 1.3976], [-0.22397, 1.32124, -1.09727], [-1.05468, 1.34316, -0.28848], ], [ [1.26372, -0.08229, -1.18144], [-0.44049, 1.38403, -0.94354], [-1.23828, 0.03109, 1.2072], [1.32524, -1.08999, -0.23524], [-0.75061, 1.4132, -0.66259], [-0.45469, 1.38697, -0.93228], ], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = InstanceNormConfig::new(6); let instance_norm = config.init::(&Default::default()); assert_eq!( format!("{instance_norm}"), "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" ); } } ================================================ FILE: crates/burn-nn/src/modules/norm/layer.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Content; use burn::module::DisplaySettings; use burn::module::Initializer; use burn::module::Module; use burn::module::ModuleDisplay; use burn::module::Param; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init). #[derive(Debug, Config)] pub struct LayerNormConfig { /// The size of the input features. pub d_model: usize, /// A value required for numerical stability. Default: 1e-5 #[config(default = 1e-5)] pub epsilon: f64, /// If a bias (beta) should be applied during the normalization. Default: true #[config(default = true)] pub bias: bool, } /// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450). /// /// `Y = norm(X) * γ + β` /// /// Where: /// - `X` is the input tensor /// - `Y` is the output tensor /// - `γ` is the learnable weight (scale) /// - `β` is the learnable bias (optional) /// /// Should be created using [LayerNormConfig](LayerNormConfig). #[derive(Module, Debug)] #[module(custom_display)] pub struct LayerNorm { /// The learnable weight (scale). pub gamma: Param>, /// The learnable bias (optional). pub beta: Option>>, /// A value required for numerical stability. epsilon: f64, } impl LayerNormConfig { /// Initialize a new [layer norm](LayerNorm) module. pub fn init(&self, device: &B::Device) -> LayerNorm { let gamma = Initializer::Ones.init([self.d_model], device); let beta = if self.bias { Some(Initializer::Zeros.init([self.d_model], device)) } else { None }; LayerNorm { gamma, beta, epsilon: self.epsilon, } } } impl LayerNorm { /// Applies the forward pass on the input tensor. /// /// See the [LayerNorm](LayerNorm) documentation for more information. /// /// # Shapes /// /// - input: `[..., any, d_model]` /// - output: `[..., any, d_model]` pub fn forward(&self, input: Tensor) -> Tensor { let (var, mean) = input.clone().var_mean_bias(D - 1); let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt()); let output = input_normalized.mul(self.gamma.val().unsqueeze()); match &self.beta { Some(beta) => output.add(beta.val().unsqueeze()), None => output, } } } impl ModuleDisplay for LayerNorm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_model] = self.gamma.shape().dims(); content .add("d_model", &d_model) .add("epsilon", &self.epsilon) .add("bias", &self.beta.is_some()) .optional() } } #[cfg(test)] mod tests { use super::*; use alloc::format; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[cfg(feature = "std")] use crate::{TestAutodiffBackend, TestBackend}; #[cfg(not(feature = "std"))] use crate::TestBackend; #[test] fn layer_norm_forward() { let device = Default::default(); let module = LayerNormConfig::new(10).init::(&device); let input = Tensor::::from_data( TensorData::from([[ -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, ]]), &device, ); let output = module.forward(input); let expected = TensorData::from([[ -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915, ]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn layer_norm_forward_large_epsilon() { let device = Default::default(); let module = LayerNormConfig::new(10) .with_epsilon(1e-1) .init::(&device); let input = Tensor::::from_data( TensorData::from([[ -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728, ]]), &device, ); let output = module.forward(input); let expected = TensorData::from([[ -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790, ]]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[cfg(feature = "std")] #[test] fn layer_norm_backward() { let device = Default::default(); let module = LayerNormConfig::new(2).init::(&device); let tensor_1 = Tensor::::from_data( TensorData::from([[0.0, 1.0], [3.0, 4.0]]), &device, ) .require_grad(); let tensor_2 = Tensor::::from_data( TensorData::from([[6.0, 7.0], [9.0, 10.0]]), &device, ) .require_grad(); let x = tensor_1.clone().matmul(tensor_2.clone()); let output = module.forward(x); let grads = output.backward(); let tensor_1_grad = tensor_1.grad(&grads).unwrap(); let tensor_2_grad = tensor_2.grad(&grads).unwrap(); let gamma_grad = module.gamma.grad(&grads).unwrap(); let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap(); let expected = TensorData::from([-2.0, 2.0]); gamma_grad .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::from([2.0, 2.0]); beta_grad .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::zeros::(tensor_1_grad.shape()); tensor_1_grad .to_data() .assert_approx_eq::(&expected, Tolerance::default()); let expected = TensorData::zeros::(tensor_2_grad.shape()); tensor_2_grad .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = LayerNormConfig::new(6); let layer_norm = config.init::(&Default::default()); assert_eq!( format!("{layer_norm}"), "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}" ); } #[test] fn display_no_bias() { let config = LayerNormConfig::new(6).with_bias(false); let layer_norm = config.init::(&Default::default()); assert_eq!( format!("{layer_norm}"), "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}" ); } } ================================================ FILE: crates/burn-nn/src/modules/norm/mod.rs ================================================ //! # Normalization Layers //! //! Users who wish to provide an abstraction over swappable normalization //! layers can use the [`Normalization`] wrapper, with support for: //! * [`Normalization::Batch`] - [`BatchNorm`] //! * [`Normalization::Group`] - [`GroupNorm`] //! * [`Normalization::Instance`] - [`InstanceNorm`] //! * [`Normalization::Layer`] - [`LayerNorm`] //! * [`Normalization::Rms`] - [`RmsNorm`] //! //! [`NormalizationConfig`] can be used as a generic normalization policy: //! * Construct a config with arbitrary input features (we suggest `0`). //! * Clone and match that config to the target input layer, //! using the [`NormalizationConfig::with_num_features()`] method. pub(crate) mod batch; pub(crate) mod group; pub(crate) mod instance; pub(crate) mod layer; pub(crate) mod rms; mod normalization_wrapper; pub use batch::*; pub use group::*; pub use instance::*; pub use layer::*; pub use normalization_wrapper::*; pub use rms::*; ================================================ FILE: crates/burn-nn/src/modules/norm/normalization_wrapper.rs ================================================ use burn_core as burn; use crate::{ BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig, LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig, }; use burn::prelude::{Config, Module}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// ['Normalization'] Configuration. /// /// The enum is non-exhaustive to prepare for future additions. /// /// Can be used as a generic configuration for normalization layers: /// * Construct a config with arbitrary input features (we suggest `0`). /// * Clone and match that config to the target input layer, /// using the [`NormalizationConfig::with_num_features()`] method. #[derive(Config, Debug)] #[non_exhaustive] pub enum NormalizationConfig { /// ['BatchNorm'] Configuration. Batch(BatchNormConfig), /// ['GroupNorm'] Configuration. Group(GroupNormConfig), /// ['InstanceNorm'] Configuration. Instance(InstanceNormConfig), /// ['LayerNorm'] Configuration. Layer(LayerNormConfig), /// ['RmsNorm'] Configuration. Rms(RmsNormConfig), } impl From for NormalizationConfig { fn from(config: BatchNormConfig) -> Self { Self::Batch(config) } } impl From for NormalizationConfig { fn from(config: GroupNormConfig) -> Self { Self::Group(config) } } impl From for NormalizationConfig { fn from(config: InstanceNormConfig) -> Self { Self::Instance(config) } } impl From for NormalizationConfig { fn from(config: LayerNormConfig) -> Self { Self::Layer(config) } } impl From for NormalizationConfig { fn from(config: RmsNormConfig) -> Self { Self::Rms(config) } } impl NormalizationConfig { /// Initialize a ['Norm'] layer. pub fn init(&self, device: &B::Device) -> Normalization { match self { NormalizationConfig::Batch(config) => config.init(device).into(), NormalizationConfig::Group(config) => config.init(device).into(), NormalizationConfig::Instance(config) => config.init(device).into(), NormalizationConfig::Layer(config) => config.init(device).into(), NormalizationConfig::Rms(config) => config.init(device).into(), } } /// Set the number of features. pub fn with_num_features(self, num_features: usize) -> Self { match self { NormalizationConfig::Batch(config) => BatchNormConfig { num_features, ..config } .into(), NormalizationConfig::Group(config) => GroupNormConfig { num_channels: num_features, ..config } .into(), NormalizationConfig::Instance(config) => InstanceNormConfig { num_channels: num_features, ..config } .into(), NormalizationConfig::Layer(config) => LayerNormConfig { d_model: num_features, ..config } .into(), NormalizationConfig::Rms(config) => RmsNormConfig { d_model: num_features, ..config } .into(), } } /// Get the number of features. pub fn num_features(&self) -> usize { match self { NormalizationConfig::Batch(config) => config.num_features, NormalizationConfig::Group(config) => config.num_channels, NormalizationConfig::Instance(config) => config.num_channels, NormalizationConfig::Layer(config) => config.d_model, NormalizationConfig::Rms(config) => config.d_model, } } } /// Normalization Layer Wrapper /// /// Provides support for built-in ``burn::nn::norm`` norm layers: /// * [`Normalization::Batch`] - [`BatchNorm`] /// * [`Normalization::Group`] - [`GroupNorm`] /// * [`Normalization::Instance`] - [`InstanceNorm`] /// * [`Normalization::Layer`] - [`LayerNorm`] /// * [`Normalization::Rms`] - [`RmsNorm`] /// /// The enum is non-exhaustive, to prepare for future additions. #[derive(Module, Debug)] #[non_exhaustive] pub enum Normalization { /// [`BatchNorm`] layer. Batch(BatchNorm), /// [`GroupNorm`] layer. Group(GroupNorm), /// ['InstanceNorm'] layer. Instance(InstanceNorm), /// [`LayerNorm`] layer. Layer(LayerNorm), /// ['RmsNorm'] layer. Rms(RmsNorm), } impl From> for Normalization { fn from(layer: BatchNorm) -> Self { Self::Batch(layer) } } impl From> for Normalization { fn from(layer: GroupNorm) -> Self { Self::Group(layer) } } impl From> for Normalization { fn from(layer: InstanceNorm) -> Self { Self::Instance(layer) } } impl From> for Normalization { fn from(layer: LayerNorm) -> Self { Self::Layer(layer) } } impl From> for Normalization { fn from(layer: RmsNorm) -> Self { Self::Rms(layer) } } impl Normalization { /// Applies normalization to a tensor. /// /// The normalization contract depends upon the wrapped norm layer; /// but all norm layers assume an input of at least rank 2; /// and produce an output of the same rank and shape. pub fn forward(&self, input: Tensor) -> Tensor { match self { Normalization::Batch(norm) => norm.forward(input), Normalization::Group(norm) => norm.forward(input), Normalization::Instance(norm) => norm.forward(input), Normalization::Layer(norm) => norm.forward(input), Normalization::Rms(norm) => norm.forward(input), } } /// Get the number of features. pub fn num_features(&self) -> usize { match self { Normalization::Batch(norm) => norm.gamma.shape()[0], Normalization::Group(norm) => norm.num_channels, Normalization::Instance(norm) => norm.num_channels, Normalization::Layer(norm) => norm.gamma.shape()[0], Normalization::Rms(norm) => norm.gamma.shape()[0], } } } #[cfg(feature = "std")] #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_match_feature_size() { let config: NormalizationConfig = BatchNormConfig::new(0).into(); assert_eq!(config.num_features(), 0); let config = config.with_num_features(12); assert_eq!(config.num_features(), 12); let config: NormalizationConfig = GroupNormConfig::new(4, 0).into(); assert_eq!(config.num_features(), 0); let config = config.with_num_features(12); assert_eq!(config.num_features(), 12); let config: NormalizationConfig = InstanceNormConfig::new(0).into(); assert_eq!(config.num_features(), 0); let config = config.with_num_features(12); assert_eq!(config.num_features(), 12); let config: NormalizationConfig = LayerNormConfig::new(0).into(); assert_eq!(config.num_features(), 0); let config = config.with_num_features(12); assert_eq!(config.num_features(), 12); let config: NormalizationConfig = RmsNormConfig::new(0).into(); assert_eq!(config.num_features(), 0); let config = config.with_num_features(12); assert_eq!(config.num_features(), 12); } #[test] fn test_batch_norm() { type B = TestAutodiffBackend; let device = Default::default(); let num_features = 12; let input: Tensor = Tensor::ones([2, num_features, 3, 4], &device); let config: NormalizationConfig = BatchNormConfig::new(12).into(); let layer: Normalization = config.init(&device); assert_eq!(layer.num_features(), 12); let expected = match &layer { Normalization::Batch(inner) => inner.forward(input.clone()), _ => panic!("Unexpected layer type"), }; let output = layer.forward(input); output.to_data().assert_eq(&expected.to_data(), true); } #[test] fn test_group_norm() { type B = TestAutodiffBackend; let device = Default::default(); let num_features = 12; let input: Tensor = Tensor::ones([2, num_features, 3, 4], &device); let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into(); let layer: Normalization = config.init(&device); assert_eq!(layer.num_features(), 12); let expected = match &layer { Normalization::Group(inner) => inner.forward(input.clone()), _ => panic!("Unexpected layer type"), }; let output = layer.forward(input); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] fn test_instance_norm() { type B = TestAutodiffBackend; let device = Default::default(); let num_features = 12; let input: Tensor = Tensor::ones([2, num_features, 3, 4], &device); let config: NormalizationConfig = InstanceNormConfig::new(num_features).into(); let layer: Normalization = config.init(&device); assert_eq!(layer.num_features(), 12); let expected = match &layer { Normalization::Instance(inner) => inner.forward(input.clone()), _ => panic!("Unexpected layer type"), }; let output = layer.forward(input); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] fn test_layer_norm() { type B = TestAutodiffBackend; let device = Default::default(); let num_features = 12; let input: Tensor = Tensor::ones([2, 3, 4, num_features], &device); let config: NormalizationConfig = LayerNormConfig::new(num_features).into(); let layer: Normalization = config.init(&device); assert_eq!(layer.num_features(), 12); let expected = match &layer { Normalization::Layer(inner) => inner.forward(input.clone()), _ => panic!("Unexpected layer type"), }; let output = layer.forward(input); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] fn test_rms_norm() { type B = TestAutodiffBackend; let device = Default::default(); let num_features = 12; let input: Tensor = Tensor::ones([2, 3, 4, num_features], &device); let config: NormalizationConfig = RmsNormConfig::new(num_features).into(); let layer: Normalization = config.init(&device); assert_eq!(layer.num_features(), 12); let expected = match &layer { Normalization::Rms(inner) => inner.forward(input.clone()), _ => panic!("Unexpected layer type"), }; let output = layer.forward(input); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-nn/src/modules/norm/rms.rs ================================================ use burn::tensor::DType; use burn_core as burn; use burn::config::Config; use burn::module::Initializer; use burn::module::Module; use burn::module::Param; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init). #[derive(Config, Debug)] pub struct RmsNormConfig { /// The size of the input features. pub d_model: usize, /// A value required for numerical stability. Default: 1e-5 #[config(default = 1e-5)] pub epsilon: f64, } impl RmsNormConfig { /// Initialize a new [RMS Norm](RmsNorm) module. /// /// # Panics /// /// Panics if `epsilon` is not positive. pub fn init(&self, device: &B::Device) -> RmsNorm { assert!(self.epsilon > 0.0, "epsilon must be positive."); let gamma = Initializer::Ones.init([self.d_model], device); RmsNorm { gamma, epsilon: self.epsilon, } } } /// Applies RMS Normalization over an input tensor along the last dimension. /// /// `Y = X / sqrt(mean(X^2) + eps) * gamma` /// /// Where: /// - `X` is the input tensor /// - `Y` is the output tensor /// - `gamma` is the learnable weight /// - `mean` is the mean operation /// - `eps` is a small value to avoid division by zero. /// /// Should be created using the [RmsNormConfig](RmsNormConfig) configuration. #[derive(Module, Debug)] #[module(custom_display)] pub struct RmsNorm { /// The learnable parameter to scale the normalized tensor pub gamma: Param>, /// A value required for numerical stability pub epsilon: f64, } impl RmsNorm { /// Applies the forward pass on the input tensor. /// /// See the [RmsNorm](RmsNorm) documentation for more information. /// /// # Shapes /// /// - input: `[..., any, d_model]` /// - output: `[..., any, d_model]` pub fn forward(&self, x: Tensor) -> Tensor { // Calculate the root-mean-square norm of the input tensor along the last dimension let dtype = x.dtype(); let rms = (x.clone().cast(DType::F32).square().mean_dim(D - 1) + self.epsilon).sqrt(); (x / rms.cast(dtype)) * self.gamma.val().unsqueeze() } } impl ModuleDisplay for RmsNorm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_model] = self.gamma.shape().dims(); content .add("d_model", &d_model) .add("epsilon", &self.epsilon) .optional() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use alloc::format; use burn::tensor::TensorData; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn rms_norm_forward() { let device = Default::default(); let module = RmsNormConfig::new(3) .with_epsilon(1e-5) .init::(&device); let input = Tensor::arange(0..9, &device).float().reshape([3, 3]); let output = module.forward(input); let expected = TensorData::from([ [0.0000, 0.7746, 1.5492], [0.7348, 0.9798, 1.2247], [0.8514, 0.9933, 1.1352], ]); output .to_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn display() { let config = RmsNormConfig::new(6); let layer_norm = config.init::(&Default::default()); assert_eq!( format!("{layer_norm}"), "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}" ); } } ================================================ FILE: crates/burn-nn/src/modules/pool/adaptive_avg_pool1d.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::adaptive_avg_pool1d; /// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer using the [init function](AdaptiveAvgPool1dConfig::init). #[derive(Config, Debug)] pub struct AdaptiveAvgPool1dConfig { /// The size of the output. pub output_size: usize, } /// Applies a 1D adaptive avg pooling over input tensors. /// /// Should be created with [AdaptiveAvgPool1dConfig]. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AdaptiveAvgPool1d { /// The size of the output. pub output_size: usize, } impl ModuleDisplay for AdaptiveAvgPool1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content.add("output_size", &self.output_size).optional() } } impl AdaptiveAvgPool1dConfig { /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. pub fn init(&self) -> AdaptiveAvgPool1d { AdaptiveAvgPool1d { output_size: self.output_size, } } } impl AdaptiveAvgPool1d { /// Applies the forward pass on the input tensor. /// /// See [adaptive_avg_pool1d](burn::tensor::module::adaptive_avg_pool1d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, length]` /// - output: `[batch_size, channels, length_out]` pub fn forward(&self, input: Tensor) -> Tensor { adaptive_avg_pool1d(input, self.output_size) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let config = AdaptiveAvgPool1dConfig::new(3); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "AdaptiveAvgPool1d {output_size: 3}" ); } } ================================================ FILE: crates/burn-nn/src/modules/pool/adaptive_avg_pool2d.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::adaptive_avg_pool2d; /// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init). #[derive(Config, Debug)] pub struct AdaptiveAvgPool2dConfig { /// The size of the output. pub output_size: [usize; 2], } /// Applies a 2D adaptive avg pooling over input tensors. /// /// Should be created with [AdaptiveAvgPool2dConfig]. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AdaptiveAvgPool2d { /// The size of the output. pub output_size: [usize; 2], } impl ModuleDisplay for AdaptiveAvgPool2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let output_size = alloc::format!("{:?}", self.output_size); content.add("output_size", &output_size).optional() } } impl AdaptiveAvgPool2dConfig { /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. pub fn init(&self) -> AdaptiveAvgPool2d { AdaptiveAvgPool2d { output_size: self.output_size, } } } impl AdaptiveAvgPool2d { /// Applies the forward pass on the input tensor. /// /// See [adaptive_avg_pool2d](burn::tensor::module::adaptive_avg_pool2d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, height_in, width_in]` /// - output: `[batch_size, channels, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { adaptive_avg_pool2d(input, self.output_size) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let config = AdaptiveAvgPool2dConfig::new([3, 3]); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "AdaptiveAvgPool2d {output_size: [3, 3]}" ); } } ================================================ FILE: crates/burn-nn/src/modules/pool/avg_pool1d.rs ================================================ use burn_core as burn; use crate::PaddingConfig1d; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::PadMode; use burn::tensor::module::avg_pool1d; /// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init). #[derive(Config, Debug)] pub struct AvgPool1dConfig { /// The size of the kernel. pub kernel_size: usize, /// The stride. #[config(default = "kernel_size")] pub stride: usize, /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If the padding is counted in the denominator when computing the average. #[config(default = "true")] pub count_include_pad: bool, /// If true, use ceiling instead of floor for output size calculation. #[config(default = "false")] pub ceil_mode: bool, } /// Applies a 1D avg pooling over input tensors. /// /// Should be created with [AvgPool1dConfig](AvgPool1dConfig). /// /// # Remarks /// /// The zero-padding values will be included in the calculation /// of the average. This means that the zeros are counted as /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool1d { /// The stride. pub stride: usize, /// The size of the kernel. pub kernel_size: usize, /// The padding configuration. pub padding: PaddingConfig1d, /// If the padding is counted in the denominator when computing the average. pub count_include_pad: bool, /// If true, use ceiling instead of floor for output size calculation. pub ceil_mode: bool, } impl ModuleDisplay for AvgPool1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("kernel_size", &self.kernel_size) .add("stride", &self.stride) .add_debug_attribute("padding", &self.padding) .add("count_include_pad", &self.count_include_pad) .add("ceil_mode", &self.ceil_mode) .optional() } } impl AvgPool1dConfig { /// Initialize a new [avg pool 1d](AvgPool1d) module. pub fn init(&self) -> AvgPool1d { AvgPool1d { stride: self.stride, kernel_size: self.kernel_size, padding: self.padding.clone(), count_include_pad: self.count_include_pad, ceil_mode: self.ceil_mode, } } } impl AvgPool1d { /// Applies the forward pass on the input tensor. /// /// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, length_in]` /// - output: `[batch_size, channels, length_out]` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels, length] = input.dims(); // Calculate padding as pair - handles Same, Valid, and Explicit uniformly let (left, right) = self.padding .calculate_padding_1d_pair(length, self.kernel_size, self.stride); // TODO: Move asymmetric padding to functional level via PoolOptions // See: https://github.com/tracel-ai/burn/issues/4362 // Handle asymmetric padding by applying explicit pad operation first if left != right { // Burn's pad takes (left, right, top, bottom) for the last two dimensions // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0 let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0)); // Use zero padding for the pool operation since we already padded avg_pool1d( padded, self.kernel_size, self.stride, 0, self.count_include_pad, self.ceil_mode, ) } else { // Symmetric padding avg_pool1d( input, self.kernel_size, self.stride, left, self.count_include_pad, self.ceil_mode, ) } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use rstest::rstest; #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = AvgPool1dConfig::new(2) .with_stride(1) .with_padding(PaddingConfig1d::Same); let pool = config.init(); // Input: [batch=1, channels=2, length=5] let input = Tensor::::ones([1, 2, 5], &device); let output = pool.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 2, 5]); } #[test] fn display() { let config = AvgPool1dConfig::new(3); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}" ); } #[rstest] #[case(1)] #[case(2)] fn default_strides_match_kernel_size(#[case] kernel_size: usize) { let config = AvgPool1dConfig::new(kernel_size); assert_eq!( config.stride, kernel_size, "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor", config.stride, config.kernel_size ); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create avg pool with asymmetric padding: left=1, right=2 let config = AvgPool1dConfig::new(3) .with_stride(1) .with_padding(PaddingConfig1d::Explicit(1, 2)); let pool = config.init(); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = pool.forward(input); // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7 // Output length = (7 - 3) / 1 + 1 = 5 assert_eq!(output.dims(), [1, 2, 5]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create avg pool with symmetric explicit padding: left=2, right=2 let config = AvgPool1dConfig::new(3) .with_stride(1) .with_padding(PaddingConfig1d::Explicit(2, 2)); let pool = config.init(); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = pool.forward(input); // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8 // Output length = (8 - 3) / 1 + 1 = 6 assert_eq!(output.dims(), [1, 2, 6]); } } ================================================ FILE: crates/burn-nn/src/modules/pool/avg_pool2d.rs ================================================ use burn_core as burn; use crate::PaddingConfig2d; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::PadMode; use burn::tensor::module::avg_pool2d; /// Configuration to create a [2D avg pooling](AvgPool2d) layer using the [init function](AvgPool2dConfig::init). #[derive(Config, Debug)] pub struct AvgPool2dConfig { /// The size of the kernel. pub kernel_size: [usize; 2], /// The strides. #[config(default = "kernel_size")] pub strides: [usize; 2], /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If the padding is counted in the denominator when computing the average. #[config(default = "true")] pub count_include_pad: bool, /// If true, use ceiling instead of floor for output size calculation. #[config(default = "false")] pub ceil_mode: bool, } /// Applies a 2D avg pooling over input tensors. /// /// Should be created with [AvgPool2dConfig](AvgPool2dConfig). /// /// # Remarks /// /// The zero-padding values will be included in the calculation /// of the average. This means that the zeros are counted as /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool2d { /// Stride of the pooling. pub stride: [usize; 2], /// Size of the kernel. pub kernel_size: [usize; 2], /// Padding configuration. pub padding: PaddingConfig2d, /// If the padding is counted in the denominator when computing the average. pub count_include_pad: bool, /// If true, use ceiling instead of floor for output size calculation. pub ceil_mode: bool, } impl ModuleDisplay for AvgPool2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) .add("stride", &alloc::format!("{:?}", &self.stride)) .add_debug_attribute("padding", &self.padding) .add("count_include_pad", &self.count_include_pad) .add("ceil_mode", &self.ceil_mode) .optional() } } impl AvgPool2dConfig { /// Initialize a new [avg pool 2d](AvgPool2d) module. pub fn init(&self) -> AvgPool2d { AvgPool2d { stride: self.strides, kernel_size: self.kernel_size, padding: self.padding.clone(), count_include_pad: self.count_include_pad, ceil_mode: self.ceil_mode, } } } impl AvgPool2d { /// Applies the forward pass on the input tensor. /// /// See [avg_pool2d](burn::tensor::module::avg_pool2d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, height_in, width_in]` /// - output: `[batch_size, channels, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels_in, height_in, width_in] = input.dims(); // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs( height_in, width_in, &self.kernel_size, &self.stride, ); // TODO: Move asymmetric padding to functional level via PoolOptions // See: https://github.com/tracel-ai/burn/issues/4362 // Handle asymmetric padding by applying explicit pad operation first if top != bottom || left != right { // Burn's pad takes (left, right, top, bottom) for the last two dimensions let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0)); // Use zero padding for the pool operation since we already padded avg_pool2d( padded, self.kernel_size, self.stride, [0, 0], self.count_include_pad, self.ceil_mode, ) } else { // Symmetric padding avg_pool2d( input, self.kernel_size, self.stride, [top, left], self.count_include_pad, self.ceil_mode, ) } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use rstest::rstest; #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = AvgPool2dConfig::new([2, 2]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Same); let pool = config.init(); // Input: [batch=1, channels=2, height=5, width=5] let input = Tensor::::ones([1, 2, 5, 5], &device); let output = pool.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 2, 5, 5]); } #[test] fn display() { let config = AvgPool2dConfig::new([3, 3]); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "AvgPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}" ); } #[rstest] #[case([2, 2])] #[case([1, 2])] fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) { let config = AvgPool2dConfig::new(kernel_size); assert_eq!( config.strides, kernel_size, "Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool2dConfig::new constructor", config.strides, config.kernel_size ); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create avg pool with asymmetric padding: top=1, left=2, bottom=3, right=4 let config = AvgPool2dConfig::new([3, 3]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4)); let pool = config.init(); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = pool.forward(input); // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9 assert_eq!(output.dims(), [1, 2, 6, 9]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create avg pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2 let config = AvgPool2dConfig::new([3, 3]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2)); let pool = config.init(); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = pool.forward(input); // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7 assert_eq!(output.dims(), [1, 2, 6, 7]); } } ================================================ FILE: crates/burn-nn/src/modules/pool/max_pool1d.rs ================================================ use burn_core as burn; use crate::PaddingConfig1d; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::PadMode; use burn::tensor::module::max_pool1d; /// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init). #[derive(Config, Debug)] pub struct MaxPool1dConfig { /// The size of the kernel. pub kernel_size: usize, /// The stride. #[config(default = "kernel_size")] pub stride: usize, /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// The dilation. #[config(default = "1")] pub dilation: usize, /// If true, use ceiling instead of floor for output size calculation. #[config(default = "false")] pub ceil_mode: bool, } /// Applies a 1D max pooling over input tensors. /// /// Should be created with [MaxPool1dConfig](MaxPool1dConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct MaxPool1d { /// The stride. pub stride: usize, /// The size of the kernel. pub kernel_size: usize, /// The padding configuration. pub padding: PaddingConfig1d, /// The dilation. pub dilation: usize, /// If true, use ceiling instead of floor for output size calculation. pub ceil_mode: bool, } impl ModuleDisplay for MaxPool1d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("kernel_size", &self.kernel_size) .add("stride", &self.stride) .add_debug_attribute("padding", &self.padding) .add("dilation", &self.dilation) .add("ceil_mode", &self.ceil_mode) .optional() } } impl MaxPool1dConfig { /// Initialize a new [max pool 1d](MaxPool1d) module. pub fn init(&self) -> MaxPool1d { MaxPool1d { stride: self.stride, kernel_size: self.kernel_size, padding: self.padding.clone(), dilation: self.dilation, ceil_mode: self.ceil_mode, } } } impl MaxPool1d { /// Applies the forward pass on the input tensor. /// /// See [max_pool1d](burn::tensor::module::max_pool1d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, length_in]` /// - output: `[batch_size, channels, length_out]` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels, length] = input.dims(); // Calculate padding as pair - handles Same, Valid, and Explicit uniformly let (left, right) = self.padding .calculate_padding_1d_pair(length, self.kernel_size, self.stride); // TODO: Move asymmetric padding to functional level via PoolOptions // See: https://github.com/tracel-ai/burn/issues/4362 // Handle asymmetric padding by applying explicit pad operation first if left != right { // For 1D (NCL format), pad the length dimension with (left, right) // and no padding for channel dimension (top=0, bottom=0) // Use -inf for max pooling so padded values don't affect the max let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY)); // Use zero padding for the pool operation since we already padded max_pool1d( padded, self.kernel_size, self.stride, 0, self.dilation, self.ceil_mode, ) } else { // Symmetric padding max_pool1d( input, self.kernel_size, self.stride, left, self.dilation, self.ceil_mode, ) } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use rstest::rstest; #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = MaxPool1dConfig::new(2) .with_stride(1) .with_padding(PaddingConfig1d::Same); let pool = config.init(); // Input: [batch=1, channels=2, length=5] let input = Tensor::::ones([1, 2, 5], &device); let output = pool.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 2, 5]); } #[test] fn display() { let config = MaxPool1dConfig::new(3); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}" ); } #[rstest] #[case(1)] #[case(2)] fn default_strides_match_kernel_size(#[case] kernel_size: usize) { let config = MaxPool1dConfig::new(kernel_size); assert_eq!( config.stride, kernel_size, "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor", config.stride, config.kernel_size ); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create max pool with asymmetric padding: left=1, right=2 let config = MaxPool1dConfig::new(3) .with_stride(1) .with_padding(PaddingConfig1d::Explicit(1, 2)); let pool = config.init(); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = pool.forward(input); // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7 // Output length = (7 - 3) / 1 + 1 = 5 assert_eq!(output.dims(), [1, 2, 5]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create max pool with symmetric explicit padding: left=2, right=2 let config = MaxPool1dConfig::new(3) .with_stride(1) .with_padding(PaddingConfig1d::Explicit(2, 2)); let pool = config.init(); // Input: [batch=1, channels=2, length=4] let input = Tensor::::ones([1, 2, 4], &device); let output = pool.forward(input); // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8 // Output length = (8 - 3) / 1 + 1 = 6 assert_eq!(output.dims(), [1, 2, 6]); } } ================================================ FILE: crates/burn-nn/src/modules/pool/max_pool2d.rs ================================================ use burn_core as burn; use crate::PaddingConfig2d; use burn::config::Config; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::ops::PadMode; use burn::tensor::module::max_pool2d; /// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init). #[derive(Debug, Config)] pub struct MaxPool2dConfig { /// The size of the kernel. pub kernel_size: [usize; 2], /// The strides. #[config(default = "kernel_size")] pub strides: [usize; 2], /// The padding configuration. /// /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes /// will automatically use asymmetric padding to preserve input dimensions. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// The dilation. #[config(default = "[1, 1]")] pub dilation: [usize; 2], /// If true, use ceiling instead of floor for output size calculation. #[config(default = "false")] pub ceil_mode: bool, } /// Applies a 2D max pooling over input tensors. /// /// Should be created with [MaxPool2dConfig](MaxPool2dConfig). #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct MaxPool2d { /// The strides. pub stride: [usize; 2], /// The size of the kernel. pub kernel_size: [usize; 2], /// The padding configuration. pub padding: PaddingConfig2d, /// The dilation. pub dilation: [usize; 2], /// If true, use ceiling instead of floor for output size calculation. pub ceil_mode: bool, } impl ModuleDisplay for MaxPool2d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) .add("stride", &alloc::format!("{:?}", &self.stride)) .add_debug_attribute("padding", &self.padding) .add("dilation", &alloc::format!("{:?}", &self.dilation)) .add("ceil_mode", &self.ceil_mode) .optional() } } impl MaxPool2dConfig { /// Initialize a new [max pool 2d](MaxPool2d) module. pub fn init(&self) -> MaxPool2d { MaxPool2d { stride: self.strides, kernel_size: self.kernel_size, padding: self.padding.clone(), dilation: self.dilation, ceil_mode: self.ceil_mode, } } } impl MaxPool2d { /// Applies the forward pass on the input tensor. /// /// See [max_pool2d](burn::tensor::module::max_pool2d) for more information. /// /// # Shapes /// /// - input: `[batch_size, channels, height_in, width_in]` /// - output: `[batch_size, channels, height_out, width_out]` pub fn forward(&self, input: Tensor) -> Tensor { let [_batch_size, _channels_in, height_in, width_in] = input.dims(); // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs( height_in, width_in, &self.kernel_size, &self.stride, ); // TODO: Move asymmetric padding to functional level via PoolOptions // See: https://github.com/tracel-ai/burn/issues/4362 // Handle asymmetric padding by applying explicit pad operation first if top != bottom || left != right { // Burn's pad takes (left, right, top, bottom) for the last two dimensions // Use -inf for max pooling so padded values don't affect the max let padded = input.pad( (left, right, top, bottom), PadMode::Constant(f32::NEG_INFINITY), ); // Use zero padding for the pool operation since we already padded max_pool2d( padded, self.kernel_size, self.stride, [0, 0], self.dilation, self.ceil_mode, ) } else { // Symmetric padding max_pool2d( input, self.kernel_size, self.stride, [top, left], self.dilation, self.ceil_mode, ) } } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use rstest::rstest; #[test] fn same_with_even_kernel_uses_asymmetric_padding() { let device = Default::default(); let config = MaxPool2dConfig::new([2, 2]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Same); let pool = config.init(); // Input: [batch=1, channels=2, height=5, width=5] let input = Tensor::::ones([1, 2, 5, 5], &device); let output = pool.forward(input); // Same padding should preserve spatial dimensions assert_eq!(output.dims(), [1, 2, 5, 5]); } #[test] fn display() { let config = MaxPool2dConfig::new([3, 3]); let layer = config.init(); assert_eq!( alloc::format!("{layer}"), "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}" ); } #[rstest] #[case([2, 2])] #[case([1, 2])] fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) { let config = MaxPool2dConfig::new(kernel_size); assert_eq!( config.strides, kernel_size, "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor", config.strides, config.kernel_size ); } #[test] fn asymmetric_padding_forward() { let device = Default::default(); // Create max pool with asymmetric padding: top=1, left=2, bottom=3, right=4 let config = MaxPool2dConfig::new([3, 3]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4)); let pool = config.init(); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = pool.forward(input); // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9 assert_eq!(output.dims(), [1, 2, 6, 9]); } #[test] fn symmetric_explicit_padding_forward() { let device = Default::default(); // Create max pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2 let config = MaxPool2dConfig::new([3, 3]) .with_strides([1, 1]) .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2)); let pool = config.init(); // Input: [batch=1, channels=2, height=4, width=5] let input = Tensor::::ones([1, 2, 4, 5], &device); let output = pool.forward(input); // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6 // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7 assert_eq!(output.dims(), [1, 2, 6, 7]); } } ================================================ FILE: crates/burn-nn/src/modules/pool/mod.rs ================================================ mod adaptive_avg_pool1d; mod adaptive_avg_pool2d; mod avg_pool1d; mod avg_pool2d; mod max_pool1d; mod max_pool2d; pub use adaptive_avg_pool1d::*; pub use adaptive_avg_pool2d::*; pub use avg_pool1d::*; pub use avg_pool2d::*; pub use max_pool1d::*; pub use max_pool2d::*; ================================================ FILE: crates/burn-nn/src/modules/pos_encoding.rs ================================================ use burn_core as burn; use alloc::vec::Vec; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::TensorData; use burn::tensor::backend::Backend; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init). #[derive(Config, Debug)] pub struct PositionalEncodingConfig { /// Maximum sequence size to use. #[config(default = "5_000")] pub max_sequence_size: usize, /// The size of each vector. pub d_model: usize, /// Max time scale to use. #[config(default = "10_000")] pub max_timescale: usize, } /// Positional encoding layer for transformer models. /// /// This layer adds positional information to the input embeddings, allowing the transformer model /// to take into account the order of the sequence. The positional encoding is added to the input /// embeddings by computing a set of sinusoidal functions with different frequencies and phases. /// /// Sinusoids are used for positional embedding introduced in /// [Attention is all you need](https://arxiv.org/abs/1706.03762). /// /// The reference implementation can be found here: /// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT /// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) /// /// Should be created using [PositionalEncodingConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct PositionalEncoding { /// The sinusoids used to add positional information to the input embeddings. pub sinusoids: Tensor, /// The maximum sequence size to use. pub max_sequence_size: usize, /// Max time scale to use. pub max_timescale: usize, } impl ModuleDisplay for PositionalEncoding { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [_, _, d_model] = self.sinusoids.shape().dims(); content .add("d_model", &d_model) .add("max_sequence_size", &self.max_sequence_size) .add("max_timescale", &self.max_timescale) .optional() } } impl PositionalEncodingConfig { /// Initialize a new [PositionalEncoding](PositionalEncoding) module. pub fn init(&self, device: &B::Device) -> PositionalEncoding { let sinusoids = generate_sinusoids::( self.max_sequence_size, self.d_model, self.max_timescale, device, ) .unsqueeze::<3>(); PositionalEncoding { sinusoids, max_sequence_size: self.max_sequence_size, max_timescale: self.max_timescale, } } } impl PositionalEncoding { /// Applies the forward pass on the input tensor by adding the sinusoids to the input. /// /// # Shapes /// /// * input: [batch_size, seq_length, d_model] /// * output: [batch_size, seq_length, d_model] /// /// /// # Panics /// /// * Panics if the input sequence length is greater than the maximum sequence size. /// * Panics if the input d_model is not equal to the d_model of the sinusoids. pub fn forward(&self, input: Tensor) -> Tensor { let [_, seq_length, d_model_input] = input.dims(); let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims(); assert!( max_sequence_size >= seq_length, "max_sequence_size({max_sequence_size}) must be greater or equal than length({seq_length})" ); assert!( d_model_input == d_model, "d_model({d_model_input}) of the input must be equal to d_model of encoding({d_model})" ); let slices = [0..batch_size, 0..seq_length, 0..d_model]; input.add(self.sinusoids.clone().slice(slices)) } } /// Returns sinusoids for positional embedding introduced in /// [Attention is all you need](https://arxiv.org/abs/1706.03762). /// /// The reference implementation can be found here: /// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT /// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) /// /// # Arguments /// /// * `length` - The length of the sequence. /// * `d_model` - The size of each vector. /// * `max_timescale` - The maximum time scale to use. /// /// # Returns /// /// A tensor of shape [length, d_model] containing the sinusoids. pub fn generate_sinusoids( length: usize, d_model: usize, max_timescale: usize, device: &B::Device, ) -> Tensor { assert!(d_model.is_multiple_of(2), "d_model must be even"); assert!( max_timescale >= length, "max_timescale must be greater than length" ); // Calculate the increment for the logarithmic timescale let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32; // Create a vector to hold the sinusoids let mut scaled_time_sin_cos = Vec::with_capacity(length); // Loop over each position in the sequence for i in 0..length { // Create a vector to hold the sinusoids for this position let mut row = Vec::with_capacity(d_model / 2); // Loop over each dimension of the sinusoids for k in (0..d_model).step_by(2) { // Calculate the division term for this dimension let div_term = (k as f32 * log_timescale_increment).exp(); // Calculate the sine and cosine values for this dimension and position row.push((div_term * i as f32).sin()); row.push((div_term * i as f32).cos()); } // Add the sinusoids for this position to the vector scaled_time_sin_cos.push(row); } // Convert the sinusoids to a tensor and return it let data = TensorData::new( scaled_time_sin_cos.into_iter().flatten().collect(), [length, d_model], ); Tensor::::from_data(data, device) } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_module() { let d_model = 6; let length = 3; // expected to broadcast let batch_size = 2; let device = Default::default(); let pe = PositionalEncodingConfig::new(d_model).init::(&device); // Use a tensor of zeros as input for easy verification of the output // The output should be the sinusoids broadcasted to the input shape let tensor = Tensor::zeros([batch_size, length, d_model], &device); let output = pe.forward(tensor); assert_eq!(&*output.shape(), [batch_size, length, d_model]); let expected = Tensor::::from_floats( [ [ [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], ], [ [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], ], ], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] fn test_generate_sinusoids() { let device = Default::default(); let sinusoids = generate_sinusoids::(12, 6, 10_000, &device); // The values are taken from the pytorch reference implementation let expected = Tensor::::from_floats( [ [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000], [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000], [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999], [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998], [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996], [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994], [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992], [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989], [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985], [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981], [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977], [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972], ], &device, ); sinusoids .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] #[should_panic] fn d_model_input_should_match() { let d_model = 8; let device = Default::default(); let pe = PositionalEncodingConfig::new(d_model).init::(&device); let input = Tensor::zeros([1, 5, 10], &device); let _output = pe.forward(input); } #[test] #[should_panic] fn input_length_should_be_less_than_max_len() { let d_model = 8; let device = Default::default(); let pe = PositionalEncodingConfig::new(d_model).init::(&device); let input = Tensor::zeros([1, 6_000, d_model], &device); let _output = pe.forward(input); } #[test] fn display() { let config = PositionalEncodingConfig::new(4); let pe = config.init::(&Default::default()); assert_eq!( alloc::format!("{pe}"), "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}" ); } } ================================================ FILE: crates/burn-nn/src/modules/rnn/basic.rs ================================================ use burn_core as burn; use crate::GateController; use crate::activation::{Activation, ActivationConfig}; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// A RnnState is used to store hidden state in RNN. pub struct RnnState { /// The hidden state. pub hidden: Tensor, } impl RnnState { /// Initialize a new [RNN State](RnnState). pub fn new(hidden: Tensor) -> Self { Self { hidden } } } /// Configuration to create a [Rnn](Rnn) module using the [init function](RnnConfig::init). #[derive(Config, Debug)] pub struct RnnConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the Rnn transformation. pub bias: bool, /// Rnn initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`. /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`. #[config(default = true)] pub batch_first: bool, /// If true, process the sequence in reverse order. /// This is useful for implementing reverse-direction RNNs (e.g., ONNX reverse direction). #[config(default = false)] pub reverse: bool, /// Optional hidden state clip threshold. If provided, hidden state values are clipped /// to the range `[-clip, +clip]` after each timestep. This can help prevent /// exploding values during inference. pub clip: Option, /// Activation function applied to the hidden state before computing hidden output. /// Default is Tanh, which is standard for Rnn. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, } /// The Rnn module. This implementation is for a unidirectional, stateless, Rnn. /// Should be created with [RnnConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Rnn { /// gate controller for Rnn (has single gate). pub gate: GateController, /// The hidden state of the Rnn. pub d_hidden: usize, /// If true, input is `[batch_size, seq_length, input_size]`. /// If false, input is `[seq_length, batch_size, input_size]`. pub batch_first: bool, /// If true, process the sequence in reverse order. pub reverse: bool, /// Optional hidden state clip threshold. pub clip: Option, /// Activation function for hidden output. pub hidden_activation: Activation, } impl ModuleDisplay for Rnn { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self.gate.input_transform.weight.shape().dims(); let bias = self.gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .optional() } } impl RnnConfig { /// Initialize a new [Rnn](Rnn) module. pub fn init(&self, device: &B::Device) -> Rnn { let d_output = self.d_hidden; let new_gate = || { GateController::new( self.d_input, d_output, self.bias, self.initializer.clone(), device, ) }; Rnn { gate: new_gate(), d_hidden: self.d_hidden, batch_first: self.batch_first, reverse: self.reverse, clip: self.clip, hidden_activation: self.hidden_activation.init(device), } } } impl Rnn { /// Applies the forward pass on the input tensor. This RNN implementation /// returns the state for each element in a sequence (i.e., across seq_length) and a final state. /// /// ## Parameters: /// - batched_input: The input tensor of shape: /// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default) /// - `[sequence_length, batch_size, input_size]` if `batch_first` is false /// - state: An optional `RnnState` representing the initial hidden state. /// The state tensor has shape `[batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: /// - output: A tensor represents the output features of Rnn. Shape: /// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true /// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false /// - state: A `RnnState` represents the final hidden state. The hidden state tensor has the shape /// `[batch_size, hidden_size]`. pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> (Tensor, RnnState) { // Convert to batch-first layout internally if needed let batched_input = if self.batch_first { batched_input } else { batched_input.swap_dims(0, 1) }; let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.dims(); // Process sequence in forward or reverse order based on config let (output, state) = if self.reverse { self.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), state, batch_size, seq_length, &device, ) } else { self.forward_iter( batched_input.iter_dim(1).zip(0..seq_length), state, batch_size, seq_length, &device, ) }; // Convert output back to seq-first layout if needed let output = if self.batch_first { output } else { output.swap_dims(0, 1) }; (output, state) } fn forward_iter, usize)>>( &self, input_timestep_iter: I, state: Option>, batch_size: usize, seq_length: usize, device: &B::Device, ) -> (Tensor, RnnState) { let mut batched_hidden_state = Tensor::empty([batch_size, seq_length, self.d_hidden], device); let mut hidden_state = match state { Some(state) => state.hidden, None => Tensor::zeros([batch_size, self.d_hidden], device), }; for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze_dim(1); // Compute gate output: h_t = activation(W_i @ x_t + W_h @ h_{t-1} + b) let biased_gate_sum = self .gate .gate_product(input_t.clone(), hidden_state.clone()); let output_values = self.hidden_activation.forward(biased_gate_sum); // Update hidden state hidden_state = output_values; // Apply hidden state clipping if configured if let Some(clip) = self.clip { hidden_state = hidden_state.clamp(-clip, clip); } let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1); // store the hidden state for this timestep batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], unsqueezed_hidden_state.clone(), ); } (batched_hidden_state, RnnState::new(hidden_state)) } } /// Configuration to create a [BiRnn](BiRnn) module using the [init function](BiRnnConfig::init). #[derive(Config, Debug)] pub struct BiRnnConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the BiRnn transformation. pub bias: bool, /// BiRnn initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`. /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`. #[config(default = true)] pub batch_first: bool, /// Optional hidden state clip threshold. pub clip: Option, /// Activation function applied to the hidden state before computing hidden output. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, } /// The BiRnn module. This implementation is for Bidirectional RNN. /// Should be created with [BiRnnConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct BiRnn { /// RNN for the forward direction. pub forward: Rnn, /// RNN for the reverse direction. pub reverse: Rnn, /// The size of the hidden state. pub d_hidden: usize, /// If true, input is `[batch_size, seq_length, input_size]`. /// If false, input is `[seq_length, batch_size, input_size]`. pub batch_first: bool, } impl ModuleDisplay for BiRnn { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims(); let bias = self.forward.gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .optional() } } impl BiRnnConfig { /// Initialize a new [Bidirectional RNN](BiRnn) module. pub fn init(&self, device: &B::Device) -> BiRnn { // Internal RNNs always use batch_first=true; BiRnn handles layout conversion let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias) .with_initializer(self.initializer.clone()) .with_batch_first(true) .with_clip(self.clip) .with_hidden_activation(self.hidden_activation.clone()); BiRnn { forward: base_config.clone().init(device), reverse: base_config.init(device), d_hidden: self.d_hidden, batch_first: self.batch_first, } } } impl BiRnn { /// Applies the forward pass on the input tensor. This Bidirectional RNN implementation /// returns the state for each element in a sequence (i.e., across seq_length) and a final state. /// /// ## Parameters: /// - batched_input: The input tensor of shape: /// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default) /// - `[sequence_length, batch_size, input_size]` if `batch_first` is false /// - state: An optional `RnnState` representing the hidden state. /// Each state tensor has shape `[2, batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: /// - output: A tensor represents the output features of RNN. Shape: /// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true /// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false /// - state: A `RnnState` represents the final forward and reverse states. /// The `state.hidden` have the shape `[2, batch_size, hidden_size]`. pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> (Tensor, RnnState) { // Convert to batch-first layout internally if needed let batched_input = if self.batch_first { batched_input } else { batched_input.swap_dims(0, 1) }; let device = batched_input.clone().device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); let [init_state_forward, init_state_reverse] = match state { Some(state) => { let hidden_state_forward = state .hidden .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); let hidden_state_reverse = state .hidden .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); [ Some(RnnState::new(hidden_state_forward)), Some(RnnState::new(hidden_state_reverse)), ] } None => [None, None], }; // forward direction let (batched_hidden_state_forward, final_state_forward) = self .forward .forward(batched_input.clone(), init_state_forward); // reverse direction let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), init_state_reverse, batch_size, seq_length, &device, ); let output = Tensor::cat( [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), 2, ); // Convert output back to seq-first layout if needed let output = if self.batch_first { output } else { output.swap_dims(0, 1) }; let state = RnnState::new(Tensor::stack( [final_state_forward.hidden, final_state_reverse.hidden].to_vec(), 0, )); (output, state) } } #[cfg(test)] mod tests { use super::*; use crate::{LinearRecord, TestBackend}; use burn::module::Param; use burn::tensor::{Device, Distribution, TensorData}; use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; #[cfg(feature = "std")] use crate::TestAutodiffBackend; fn create_single_feature_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, device: &Device, ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; let record_2 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; GateController::create_with_weights( d_input, d_output, bias, initializer, record_1, record_2, ) } #[test] fn test_with_uniform_initializer() { let device = Default::default(); TestBackend::seed(&device, 0); let config = RnnConfig::new(5, 5, false) .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); let rnn = config.init::(&Default::default()); let gate_to_data = |gate: GateController| gate.input_transform.weight.val().to_data(); gate_to_data(rnn.gate).assert_within_range::(0.elem()..1.elem()); } /// Test forward pass with simple input vector. /// /// Simple RNN: h_t = tanh(W_input @ x_t + W_hidden @ h_{t-1} + b) /// With input=0.1, weight_input=0.5, bias=0.0, h_0=0.0, weight_hidden=0.5 /// h_t = tanh(0.5*0.1 + 0.5*0) = tanh(0.05) = 0.04995 #[test] fn test_forward_single_input_single_feature() { let device = Default::default(); TestBackend::seed(&device, 0); let config = RnnConfig::new(1, 1, false); let device = Default::default(); let mut rnn = config.init::(&device); rnn.gate = create_single_feature_gate_controller( 0.5, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); // single timestep with single feature let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); let (output, state) = rnn.forward(input, None); let tolerance = Tolerance::default(); let expected = TensorData::from([[0.04995]]); state .hidden .to_data() .assert_approx_eq::(&expected, tolerance); output .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0) .to_data() .assert_approx_eq::(&state.hidden.to_data(), tolerance); } #[test] fn test_batched_forward_pass_batch_of_one() { let device = Default::default(); let rnn = RnnConfig::new(64, 1024, true).init(&device); let batched_input = Tensor::::random([1, 2, 64], Distribution::Default, &device); let (output, state) = rnn.forward(batched_input, None); assert_eq!(output.dims(), [1, 2, 1024]); assert_eq!(state.hidden.dims(), [1, 1024]); } #[test] #[cfg(feature = "std")] fn test_batched_backward_pass() { use burn::tensor::Shape; let device = Default::default(); let rnn = RnnConfig::new(64, 32, true).init(&device); let shape: Shape = [8, 10, 64].into(); let batched_input = Tensor::::random(shape, Distribution::Default, &device); let (output, _) = rnn.forward(batched_input.clone(), None); let fake_loss = output; let grads = fake_loss.backward(); let some_gradient = rnn.gate.hidden_transform.weight.grad(&grads).unwrap(); // Asserts that the gradients exist and are non-zero assert_ne!( some_gradient .any() .into_data() .iter::() .next() .unwrap(), 0.0 ); } #[test] fn test_bidirectional() { let device = Default::default(); TestBackend::seed(&device, 0); let config = BiRnnConfig::new(2, 3, true); let mut rnn = config.init(&device); fn create_gate_controller( input_weights: [[f32; D1]; D2], input_biases: [f32; D1], hidden_weights: [[f32; D1]; D1], hidden_biases: [f32; D1], device: &Device, ) -> GateController { let d_input = input_weights[0].len(); let d_output = input_weights.len(); let input_record = LinearRecord { weight: Param::from_data(TensorData::from(input_weights), device), bias: Some(Param::from_data(TensorData::from(input_biases), device)), }; let hidden_record = LinearRecord { weight: Param::from_data(TensorData::from(hidden_weights), device), bias: Some(Param::from_data(TensorData::from(hidden_biases), device)), }; GateController::create_with_weights( d_input, d_output, true, Initializer::XavierUniform { gain: 1.0 }, input_record, hidden_record, ) } // [batch_size=1, seq_length=4, input_size=2] let input = Tensor::::from_data( TensorData::from([[ [0.949, -0.861], [0.892, 0.927], [-0.173, -0.301], [-0.081, 0.992], ]]), &device, ); // [2, batch_size=1, hidden_size=3] let h0 = Tensor::::from_data( TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), &device, ); rnn.forward.gate = create_gate_controller( // input_weights: [input_size=2, hidden_size=3] [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]], // input_biases: [hidden_size=3] [-0.196, 0.354, 0.209], // hidden_weights: [hidden_size=3, hidden_size=3] [ [-0.320, 0.232, -0.165], [0.093, -0.572, -0.315], [-0.467, 0.325, 0.046], ], // hidden_biases: [hidden_size=3] [0.181, -0.190, -0.245], &device, ); rnn.reverse.gate = create_gate_controller( [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]], [0.540, -0.164, 0.033], [ [0.159, 0.180, -0.037], [-0.443, 0.485, -0.488], [0.098, -0.085, -0.140], ], [-0.510, 0.105, 0.114], &device, ); // [batch_size=1, sequence_length=4, hidden_size * 2 = 6] // The expected output values were computed from PyTorch let expected_output_with_init_state = TensorData::from([[ [0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602], [0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766], [-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725], [0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387], ]]); let expected_output_without_init_state = TensorData::from([[ [0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607], [0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002], [-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727], [0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282], ]]); //`[2, batch_size=1, hidden_size=3]` let expected_hn_with_init_state = TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]); let expected_hn_without_init_state = TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]); let (output_with_init_state, state_with_init_state) = rnn.forward(input.clone(), Some(RnnState::new(h0))); let (output_without_init_state, state_without_init_state) = rnn.forward(input, None); let tolerance = Tolerance::permissive(); output_with_init_state .to_data() .assert_approx_eq::(&expected_output_with_init_state, tolerance); output_without_init_state .to_data() .assert_approx_eq::(&expected_output_without_init_state, tolerance); state_with_init_state .hidden .to_data() .assert_approx_eq::(&expected_hn_with_init_state, tolerance); state_without_init_state .hidden .to_data() .assert_approx_eq::(&expected_hn_without_init_state, tolerance); } #[test] fn display_rnn() { let config = RnnConfig::new(2, 3, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}" ); } #[test] fn display_birnn() { let config = BiRnnConfig::new(2, 3, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}" ); } #[test] fn test_rnn_clipping() { let device = Default::default(); // Create Rnn with clipping enabled let clip_value = 0.3; let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value)); let rnn = config.init::(&device); let input = Tensor::::random([2, 5, 4], Distribution::Default, &device); let (_, state) = rnn.forward(input, None); // Verify output values are within the clip range let hidden_state: Vec = state.hidden.to_data().to_vec().unwrap(); for val in hidden_state { assert!( val >= -clip_value as f32 && val <= clip_value as f32, "Value {} is outside clip range [-{}, {}]", val, clip_value, clip_value ); } } #[test] fn test_forward_reverse_sequence() { let device = Default::default(); TestBackend::seed(&device, 0); // Create RNN with reverse=true to process sequence in reverse order let config = RnnConfig::new(1, 1, false).with_reverse(true); let mut rnn = config.init::(&device); rnn.gate = create_single_feature_gate_controller( 0.5, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); // Create input with 3 timesteps: [0.1, 0.2, 0.3] // Shape: [batch_size=1, seq_length=3, input_features=1] let input = Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device); let (output, state) = rnn.forward(input, None); // With reverse=true and weight=0.5, sequence is processed in reverse: // t=2 (last): h = tanh(0.5*0.3 + 0.5*0) = tanh(0.15) ≈ 0.1488850 // t=1 (mid): h = tanh(0.5*0.2 + 0.5*0.1488850) ≈ 0.17269433 // t=0 (first): h = tanh(0.5*0.1 + 0.5*0.17269433) ≈ 0.135508 let expected_final_hidden = TensorData::from([[0.135508]]); let tolerance = Tolerance::default(); state .hidden .to_data() .assert_approx_eq::(&expected_final_hidden, tolerance); // Verify output tensor has correct shape and matches state at final timestep assert_eq!(output.dims(), [1, 3, 1]); } } ================================================ FILE: crates/burn-nn/src/modules/rnn/gate_controller.rs ================================================ use burn_core as burn; use crate::{Linear, LinearConfig, LinearLayout}; use burn::module::{Initializer, Module}; use burn::tensor::{Tensor, backend::Backend}; /// A GateController represents a gate in an LSTM cell. An /// LSTM cell generally contains three gates: an input gate, /// forget gate, and output gate. Additionally, cell gate /// is just used to compute the cell state. /// /// An Lstm gate is modeled as two linear transformations. /// The results of these transformations are used to calculate /// the gate's output. #[derive(Module, Debug)] pub struct GateController { /// Represents the affine transformation applied to input vector pub input_transform: Linear, /// Represents the affine transformation applied to the hidden state pub hidden_transform: Linear, } impl GateController { /// Initialize a new [gate_controller](GateController) module. pub fn new( d_input: usize, d_output: usize, bias: bool, initializer: Initializer, device: &B::Device, ) -> Self { Self { input_transform: LinearConfig { d_input, d_output, bias, initializer: initializer.clone(), layout: LinearLayout::Row, } .init(device), hidden_transform: LinearConfig { d_input: d_output, d_output, bias, initializer, layout: LinearLayout::Row, } .init(device), } } /// Helper function for performing weighted matrix product for a gate and adds /// bias, if any. /// /// Mathematically, performs `Wx*X + Wh*H + b`, where: /// Wx = weight matrix for the connection to input vector X /// Wh = weight matrix for the connection to hidden state H /// X = input vector /// H = hidden state /// b = bias terms pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { self.input_transform.forward(input) + self.hidden_transform.forward(hidden) } /// Used to initialize a gate controller with known weight layers, /// allowing for predictable behavior. Used only for testing in /// lstm. #[cfg(test)] pub fn create_with_weights( d_input: usize, d_output: usize, bias: bool, initializer: Initializer, input_record: crate::LinearRecord, hidden_record: crate::LinearRecord, ) -> Self { let l1 = LinearConfig { d_input, d_output, bias, initializer: initializer.clone(), layout: LinearLayout::Row, } .init(&input_record.weight.device()) .load_record(input_record); let l2 = LinearConfig { d_input, d_output, bias, initializer, layout: LinearLayout::Row, } .init(&hidden_record.weight.device()) .load_record(hidden_record); Self { input_transform: l1, hidden_transform: l2, } } } ================================================ FILE: crates/burn-nn/src/modules/rnn/gru.rs ================================================ use burn_core as burn; use super::gate_controller::GateController; use crate::activation::{Activation, ActivationConfig}; use burn::config::Config; use burn::module::Initializer; use burn::module::Module; use burn::module::{Content, DisplaySettings, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init). #[derive(Config, Debug)] pub struct GruConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the Gru transformation. pub bias: bool, /// If reset gate should be applied after weight multiplication. /// /// This configuration option controls how the reset gate is applied to the hidden state. /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for /// Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by /// the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU). /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine /// Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication. /// /// The differing implementations can give slightly different numerical results and have different efficiencies. For more /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs). /// /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`). #[config(default = "true")] pub reset_after: bool, /// Gru initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// Activation function for the update and reset gates. /// Default is Sigmoid, which is standard for GRU gates. #[config(default = "ActivationConfig::Sigmoid")] pub gate_activation: ActivationConfig, /// Activation function for the new/candidate gate. /// Default is Tanh, which is standard for GRU. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, /// Optional hidden state clip threshold. If provided, hidden state values are clipped /// to the range `[-clip, +clip]` after each timestep. This can help prevent /// exploding values during inference. pub clip: Option, } /// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless, Gru. /// /// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078). /// /// Should be created with [GruConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Gru { /// The update gate controller. pub update_gate: GateController, /// The reset gate controller. pub reset_gate: GateController, /// The new gate controller. pub new_gate: GateController, /// The size of the hidden state. pub d_hidden: usize, /// If reset gate should be applied after weight multiplication. pub reset_after: bool, /// Activation function for gates (update, reset). pub gate_activation: Activation, /// Activation function for new/candidate gate. pub hidden_activation: Activation, /// Optional hidden state clip threshold. pub clip: Option, } impl ModuleDisplay for Gru { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self.update_gate.input_transform.weight.shape().dims(); let bias = self.update_gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .add("reset_after", &self.reset_after) .optional() } } impl GruConfig { /// Initialize a new [gru](Gru) module. pub fn init(&self, device: &B::Device) -> Gru { let d_output = self.d_hidden; let update_gate = GateController::new( self.d_input, d_output, self.bias, self.initializer.clone(), device, ); let reset_gate = GateController::new( self.d_input, d_output, self.bias, self.initializer.clone(), device, ); let new_gate = GateController::new( self.d_input, d_output, self.bias, self.initializer.clone(), device, ); Gru { update_gate, reset_gate, new_gate, d_hidden: self.d_hidden, reset_after: self.reset_after, gate_activation: self.gate_activation.init(device), hidden_activation: self.hidden_activation.init(device), clip: self.clip, } } } impl Gru { /// Applies the forward pass on the input tensor. This GRU implementation /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`. /// /// # Parameters /// - batched_input: `[batch_size, sequence_length, input_size]`. /// - state: An optional tensor representing an initial cell state with dimensions /// `[batch_size, hidden_size]`. If none is provided, an empty state will be used. /// /// # Returns /// - output: `[batch_size, sequence_length, hidden_size]` pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> Tensor { let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); self.forward_iter( batched_input.iter_dim(1).zip(0..seq_length), state, batch_size, seq_length, &device, ) .0 } /// Forward pass variant that accepts an iterator over timesteps. /// Used by BiGru to process sequences in either direction. /// /// # Parameters /// - input_timestep_iter: Iterator yielding (input_tensor, timestep_index) pairs. /// The timestep_index determines where in the output tensor to store results. /// - state: Optional initial hidden state with shape `[batch_size, hidden_size]`. /// - batch_size: Batch size of the input. /// - seq_length: Sequence length of the input. /// - device: Device to create tensors on. /// /// # Returns /// - output: `[batch_size, sequence_length, hidden_size]` /// - final_hidden: Final hidden state `[batch_size, hidden_size]` pub(crate) fn forward_iter, usize)>>( &self, input_timestep_iter: I, state: Option>, batch_size: usize, seq_length: usize, device: &B::Device, ) -> (Tensor, Tensor) { let mut batched_hidden_state = Tensor::empty([batch_size, seq_length, self.d_hidden], device); let mut hidden_t = match state { Some(state) => state, None => Tensor::zeros([batch_size, self.d_hidden], device), }; for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze_dim(1); // u(pdate)g(ate) tensors let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, None, &self.update_gate); let update_values = self.gate_activation.forward(biased_ug_input_sum); // r(eset)g(ate) tensors let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, None, &self.reset_gate); let reset_values = self.gate_activation.forward(biased_rg_input_sum); // n(ew)g(ate) tensor let biased_ng_input_sum = if self.reset_after { self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate) } else { let reset_t = hidden_t.clone().mul(reset_values); self.gate_product(&input_t, &reset_t, None, &self.new_gate) }; let candidate_state = self.hidden_activation.forward(biased_ng_input_sum); // calculate linear interpolation between previous hidden state and candidate state: // h_t = (1 - z_t) * g_t + z_t * h_{t-1} let one_minus_z = update_values.clone().neg().add_scalar(1.0); hidden_t = candidate_state.mul(one_minus_z) + update_values.mul(hidden_t); // Apply hidden state clipping if configured if let Some(clip) = self.clip { hidden_t = hidden_t.clamp(-clip, clip); } let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1); batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], unsqueezed_hidden_state, ); } (batched_hidden_state, hidden_t) } /// Helper function for performing weighted matrix product for a gate and adds /// bias, if any, and optionally applies reset to hidden state. /// /// Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where: /// Wx = weight matrix for the connection to input vector X /// Wh = weight matrix for the connection to hidden state H /// X = input vector /// H = hidden state /// b = bias terms /// r = reset state fn gate_product( &self, input: &Tensor, hidden: &Tensor, reset: Option<&Tensor>, gate: &GateController, ) -> Tensor { let input_product = input.clone().matmul(gate.input_transform.weight.val()); let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val()); let input_part = match &gate.input_transform.bias { Some(bias) => input_product + bias.val().unsqueeze(), None => input_product, }; let hidden_part = match &gate.hidden_transform.bias { Some(bias) => hidden_product + bias.val().unsqueeze(), None => hidden_product, }; match reset { Some(r) => input_part + r.clone().mul(hidden_part), None => input_part + hidden_part, } } } /// Configuration to create a [BiGru](BiGru) module using the [init function](BiGruConfig::init). #[derive(Config, Debug)] pub struct BiGruConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the BiGru transformation. pub bias: bool, /// If reset gate should be applied after weight multiplication. #[config(default = "true")] pub reset_after: bool, /// BiGru initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`. /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`. #[config(default = true)] pub batch_first: bool, /// Activation function for the update and reset gates. #[config(default = "ActivationConfig::Sigmoid")] pub gate_activation: ActivationConfig, /// Activation function for the new/candidate gate. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, /// Optional hidden state clip threshold. pub clip: Option, } /// The BiGru module. This implementation is for Bidirectional GRU. /// /// Based on the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078). /// /// Should be created with [BiGruConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct BiGru { /// GRU for the forward direction. pub forward: Gru, /// GRU for the reverse direction. pub reverse: Gru, /// The size of the hidden state. pub d_hidden: usize, /// If true, input is `[batch_size, seq_length, input_size]`. /// If false, input is `[seq_length, batch_size, input_size]`. pub batch_first: bool, } impl ModuleDisplay for BiGru { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self .forward .update_gate .input_transform .weight .shape() .dims(); let bias = self.forward.update_gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .optional() } } impl BiGruConfig { /// Initialize a new [Bidirectional GRU](BiGru) module. pub fn init(&self, device: &B::Device) -> BiGru { // Internal GRUs always use batch_first=true; BiGru handles layout conversion let base_config = GruConfig::new(self.d_input, self.d_hidden, self.bias) .with_initializer(self.initializer.clone()) .with_reset_after(self.reset_after) .with_gate_activation(self.gate_activation.clone()) .with_hidden_activation(self.hidden_activation.clone()) .with_clip(self.clip); BiGru { forward: base_config.clone().init(device), reverse: base_config.init(device), d_hidden: self.d_hidden, batch_first: self.batch_first, } } } impl BiGru { /// Applies the forward pass on the input tensor. This Bidirectional GRU implementation /// returns the state for each element in a sequence (i.e., across seq_length) and a final state. /// /// ## Parameters: /// - batched_input: The input tensor of shape: /// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default) /// - `[sequence_length, batch_size, input_size]` if `batch_first` is false /// - state: An optional tensor representing the initial hidden state with shape /// `[2, batch_size, hidden_size]`. If no initial state is provided, it is initialized to zeros. /// /// ## Returns: /// - output: A tensor representing the output features. Shape: /// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true /// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false /// - state: The final forward and reverse hidden states stacked along dimension 0 /// with shape `[2, batch_size, hidden_size]`. pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> (Tensor, Tensor) { // Convert to batch-first layout internally if needed let batched_input = if self.batch_first { batched_input } else { batched_input.swap_dims(0, 1) }; let device = batched_input.clone().device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); let [init_state_forward, init_state_reverse] = match state { Some(state) => { let hidden_state_forward = state .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); let hidden_state_reverse = state .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); [Some(hidden_state_forward), Some(hidden_state_reverse)] } None => [None, None], }; // forward direction let (batched_hidden_state_forward, final_state_forward) = self.forward.forward_iter( batched_input.clone().iter_dim(1).zip(0..seq_length), init_state_forward, batch_size, seq_length, &device, ); // reverse direction let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), init_state_reverse, batch_size, seq_length, &device, ); let output = Tensor::cat( [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), 2, ); // Convert output back to seq-first layout if needed let output = if self.batch_first { output } else { output.swap_dims(0, 1) }; let state = Tensor::stack([final_state_forward, final_state_reverse].to_vec(), 0); (output, state) } } #[cfg(test)] mod tests { use super::*; use crate::{LinearRecord, TestBackend}; use burn::module::Param; use burn::tensor::{Distribution, TensorData}; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; fn init_gru(reset_after: bool, device: &B::Device) -> Gru { fn create_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, device: &B::Device, ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; let record_2 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; GateController::create_with_weights( d_input, d_output, bias, initializer, record_1, record_2, ) } let config = GruConfig::new(1, 1, false).with_reset_after(reset_after); let mut gru = config.init::(device); gru.update_gate = create_gate_controller( 0.5, 0.0, 1, 1, false, Initializer::XavierNormal { gain: 1.0 }, device, ); gru.reset_gate = create_gate_controller( 0.6, 0.0, 1, 1, false, Initializer::XavierNormal { gain: 1.0 }, device, ); gru.new_gate = create_gate_controller( 0.7, 0.0, 1, 1, false, Initializer::XavierNormal { gain: 1.0 }, device, ); gru } /// Test forward pass with simple input vector. /// /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 /// /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 #[test] fn tests_forward_single_input_single_feature() { let device = Default::default(); TestBackend::seed(&device, 0); let mut gru = init_gru::(false, &device); let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); let expected = TensorData::from([[0.034]]); // Reset gate applied to hidden state before the matrix multiplication let state = gru.forward(input.clone(), None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0); let tolerance = Tolerance::default(); output .to_data() .assert_approx_eq::(&expected, tolerance); // Reset gate applied to hidden state after the matrix multiplication gru.reset_after = true; // override forward behavior let state = gru.forward(input, None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0); output .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn tests_forward_seq_len_3() { let device = Default::default(); TestBackend::seed(&device, 0); let mut gru = init_gru::(true, &device); let input = Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device); let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]); let result = gru.forward(input.clone(), None); let output = result .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0); let tolerance = Tolerance::default(); output .to_data() .assert_approx_eq::(&expected, tolerance); // Reset gate applied to hidden state before the matrix multiplication gru.reset_after = false; // override forward behavior let state = gru.forward(input, None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0); output .to_data() .assert_approx_eq::(&expected, tolerance); } #[test] fn test_batched_forward_pass() { let device = Default::default(); let gru = GruConfig::new(64, 1024, true).init::(&device); let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); let hidden_state = gru.forward(batched_input, None); assert_eq!(&*hidden_state.shape(), [8, 10, 1024]); } #[test] fn display() { let config = GruConfig::new(2, 8, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}" ); } #[test] fn test_bigru_batched_forward_pass() { let device = Default::default(); let bigru = BiGruConfig::new(64, 1024, true).init::(&device); let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); let (output, state) = bigru.forward(batched_input, None); // Output should have hidden_size * 2 features (forward + reverse concatenated) assert_eq!(&*output.shape(), [8, 10, 2048]); // State should have shape [2, batch_size, hidden_size] assert_eq!(&*state.shape(), [2, 8, 1024]); } #[test] fn test_bigru_with_initial_state() { let device = Default::default(); let bigru = BiGruConfig::new(32, 64, true).init::(&device); let batched_input = Tensor::::random([4, 5, 32], Distribution::Default, &device); let initial_state = Tensor::::random([2, 4, 64], Distribution::Default, &device); let (output, state) = bigru.forward(batched_input, Some(initial_state)); assert_eq!(&*output.shape(), [4, 5, 128]); assert_eq!(&*state.shape(), [2, 4, 64]); } #[test] fn test_bigru_seq_first() { let device = Default::default(); let bigru = BiGruConfig::new(32, 64, true) .with_batch_first(false) .init::(&device); // Input shape: [seq_length, batch_size, input_size] when batch_first=false let batched_input = Tensor::::random([5, 4, 32], Distribution::Default, &device); let (output, state) = bigru.forward(batched_input, None); // Output shape: [seq_length, batch_size, hidden_size * 2] assert_eq!(&*output.shape(), [5, 4, 128]); assert_eq!(&*state.shape(), [2, 4, 64]); } /// Test BiGru against PyTorch reference implementation. /// Expected values computed with PyTorch nn.GRU(bidirectional=True). #[test] fn test_bigru_against_pytorch() { use burn::tensor::Device; let device = Default::default(); TestBackend::seed(&device, 0); let config = BiGruConfig::new(2, 3, true); let mut bigru = config.init::(&device); fn create_gate_controller( input_weights: [[f32; D1]; D2], input_biases: [f32; D1], hidden_weights: [[f32; D1]; D1], hidden_biases: [f32; D1], device: &Device, ) -> GateController { let d_input = input_weights[0].len(); let d_output = input_weights.len(); let input_record = LinearRecord { weight: Param::from_data(TensorData::from(input_weights), device), bias: Some(Param::from_data(TensorData::from(input_biases), device)), }; let hidden_record = LinearRecord { weight: Param::from_data(TensorData::from(hidden_weights), device), bias: Some(Param::from_data(TensorData::from(hidden_biases), device)), }; GateController::create_with_weights( d_input, d_output, true, Initializer::XavierUniform { gain: 1.0 }, input_record, hidden_record, ) } let input = Tensor::::from_data( TensorData::from([[ [0.949, -0.861], [0.892, 0.927], [-0.173, -0.301], [-0.081, 0.992], ]]), &device, ); let h0 = Tensor::::from_data( TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), &device, ); // Forward GRU gates (weights from PyTorch with seed 42, transposed for burn) bigru.forward.update_gate = create_gate_controller( [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]], [0.2932, -0.3519, -0.5715], [ [-0.3471, 0.5214, 0.0961], [0.0545, -0.4904, -0.1875], [-0.5702, 0.4457, 0.3568], ], [-0.0100, 0.4518, -0.4102], &device, ); bigru.forward.reset_gate = create_gate_controller( [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]], [-0.2524, 0.3333, 0.1033], [ [-0.2695, -0.0677, -0.4557], [0.1472, -0.2345, -0.2662], [-0.2660, 0.3830, -0.1630], ], [0.1663, 0.2391, 0.1826], &device, ); bigru.forward.new_gate = create_gate_controller( [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]], [-0.2231, -0.4428, 0.4737], [ [0.0900, -0.1821, 0.2430], [0.4665, 0.1551, 0.5155], [0.0631, -0.1566, 0.3337], ], [0.0364, -0.3941, 0.1780], &device, ); // Reverse GRU gates bigru.reverse.update_gate = create_gate_controller( [[-0.3444, 0.1924, -0.4765], [0.5193, 0.5556, -0.5727]], [0.1090, 0.1779, -0.5385], [ [0.1221, 0.3925, 0.5287], [-0.1472, -0.4187, -0.1948], [0.3441, -0.3082, -0.2047], ], [0.0016, -0.2148, -0.0400], &device, ); bigru.reverse.reset_gate = create_gate_controller( [[-0.1988, -0.1203, -0.3422], [0.1769, 0.4788, -0.3443]], [-0.5053, -0.3676, 0.5771], [ [-0.3936, 0.3504, -0.4486], [0.3063, -0.1370, -0.2914], [-0.2334, 0.3303, 0.1760], ], [-0.5080, -0.2488, -0.3456], &device, ); bigru.reverse.new_gate = create_gate_controller( [[-0.4517, 0.2339, 0.4797], [-0.3884, 0.2067, -0.2982]], [-0.3792, -0.1922, 0.0903], [ [-0.5586, -0.0762, -0.3944], [-0.3306, -0.4191, -0.4898], [0.1442, 0.0135, -0.3179], ], [-0.3912, -0.3963, -0.3368], &device, ); // Expected values from PyTorch let expected_output_with_init = TensorData::from([[ [0.24537, 0.14018, 0.19449, -0.49777, -0.15647, 0.48392], [0.27468, -0.14514, 0.56205, -0.60381, -0.04986, 0.15683], [-0.04062, -0.33486, 0.52330, -0.42244, -0.12644, -0.12034], [-0.11743, -0.53873, 0.54429, -0.64943, 0.30127, -0.41943], ]]); let expected_hn_with_init = TensorData::from([ [[-0.11743, -0.53873, 0.54429]], [[-0.49777, -0.15647, 0.48392]], ]); let expected_output_without_init = TensorData::from([[ [0.07452, -0.08247, 0.46677, -0.46770, -0.18086, 0.47519], [0.15843, -0.27144, 0.65781, -0.50286, -0.12806, 0.14884], [-0.10704, -0.41573, 0.53954, -0.24794, -0.24003, -0.10294], [-0.16505, -0.57952, 0.53565, -0.23598, -0.07137, -0.28937], ]]); let expected_hn_without_init = TensorData::from([ [[-0.16505, -0.57952, 0.53565]], [[-0.46770, -0.18086, 0.47519]], ]); let (output_with_init, hn_with_init) = bigru.forward(input.clone(), Some(h0)); let (output_without_init, hn_without_init) = bigru.forward(input, None); let tolerance = Tolerance::permissive(); output_with_init .to_data() .assert_approx_eq::(&expected_output_with_init, tolerance); output_without_init .to_data() .assert_approx_eq::(&expected_output_without_init, tolerance); hn_with_init .to_data() .assert_approx_eq::(&expected_hn_with_init, tolerance); hn_without_init .to_data() .assert_approx_eq::(&expected_hn_without_init, tolerance); } #[test] fn bigru_display() { let config = BiGruConfig::new(2, 8, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "BiGru {d_input: 2, d_hidden: 8, bias: true, params: 576}" ); } #[test] fn test_gru_custom_activations() { let device = Default::default(); // Create GRU with custom activations (ReLU instead of Sigmoid/Tanh) let config = GruConfig::new(4, 8, true) .with_gate_activation(ActivationConfig::Relu) .with_hidden_activation(ActivationConfig::Relu); let gru = config.init::(&device); let input = Tensor::::random([2, 3, 4], Distribution::Default, &device); // Should run without panicking and produce valid output let output = gru.forward(input, None); assert_eq!(&*output.shape(), [2, 3, 8]); } #[test] fn test_bigru_custom_activations() { let device = Default::default(); // Create BiGRU with custom activations let config = BiGruConfig::new(4, 8, true) .with_gate_activation(ActivationConfig::Relu) .with_hidden_activation(ActivationConfig::Relu); let bigru = config.init::(&device); let input = Tensor::::random([2, 3, 4], Distribution::Default, &device); let (output, state) = bigru.forward(input, None); assert_eq!(&*output.shape(), [2, 3, 16]); // hidden_size * 2 assert_eq!(&*state.shape(), [2, 2, 8]); } #[test] fn test_gru_clipping() { let device = Default::default(); // Create GRU with clipping enabled let clip_value = 0.5; let config = GruConfig::new(4, 8, true).with_clip(Some(clip_value)); let gru = config.init::(&device); let input = Tensor::::random([2, 5, 4], Distribution::Default, &device); let output = gru.forward(input, None); // Verify output values are within the clip range let output_data: Vec = output.to_data().to_vec().unwrap(); for val in output_data { assert!( val >= -clip_value as f32 && val <= clip_value as f32, "Value {} is outside clip range [-{}, {}]", val, clip_value, clip_value ); } } #[test] fn test_bigru_clipping() { let device = Default::default(); // Create BiGRU with clipping enabled let clip_value = 0.3; let config = BiGruConfig::new(4, 8, true).with_clip(Some(clip_value)); let bigru = config.init::(&device); let input = Tensor::::random([2, 5, 4], Distribution::Default, &device); let (output, state) = bigru.forward(input, None); // Verify output values are within the clip range let output_data: Vec = output.to_data().to_vec().unwrap(); for val in output_data { assert!( val >= -clip_value as f32 && val <= clip_value as f32, "Output value {} is outside clip range [-{}, {}]", val, clip_value, clip_value ); } // Verify state values are within the clip range let state_data: Vec = state.to_data().to_vec().unwrap(); for val in state_data { assert!( val >= -clip_value as f32 && val <= clip_value as f32, "State value {} is outside clip range [-{}, {}]", val, clip_value, clip_value ); } } /// Test Gru against PyTorch reference implementation. /// Expected values computed with PyTorch nn.GRU (seed=42 for weights, seed=123 for input). #[test] fn test_gru_against_pytorch() { use burn::tensor::Device; let device = Default::default(); TestBackend::seed(&device, 0); let config = GruConfig::new(2, 3, true); let mut gru = config.init::(&device); fn create_gate_controller( input_weights: [[f32; D1]; D2], input_biases: [f32; D1], hidden_weights: [[f32; D1]; D1], hidden_biases: [f32; D1], device: &Device, ) -> GateController { let d_input = input_weights[0].len(); let d_output = input_weights.len(); let input_record = LinearRecord { weight: Param::from_data(TensorData::from(input_weights), device), bias: Some(Param::from_data(TensorData::from(input_biases), device)), }; let hidden_record = LinearRecord { weight: Param::from_data(TensorData::from(hidden_weights), device), bias: Some(Param::from_data(TensorData::from(hidden_biases), device)), }; GateController::create_with_weights( d_input, d_output, true, Initializer::XavierUniform { gain: 1.0 }, input_record, hidden_record, ) } // Input: [batch=1, seq=4, input=2] let input = Tensor::::from_data( TensorData::from([[ [-0.11147, 0.12036], [-0.36963, -0.24042], [-1.19692, 0.20927], [-0.97236, -0.75505], ]]), &device, ); // Initial hidden state: [batch=1, hidden=3] let h0 = Tensor::::from_data( TensorData::from([[0.3239, -0.10852, 0.21033]]), &device, ); // Update gate (z) - weights from PyTorch, transposed for Burn's Row layout gru.update_gate = create_gate_controller( [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]], [0.2932, -0.3519, -0.5715], [ [-0.3471, 0.5214, 0.0961], [0.0545, -0.4904, -0.1875], [-0.5702, 0.4457, 0.3568], ], [-0.0100, 0.4518, -0.4102], &device, ); // Reset gate (r) gru.reset_gate = create_gate_controller( [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]], [-0.2524, 0.3333, 0.1033], [ [-0.2695, -0.0677, -0.4557], [0.1472, -0.2345, -0.2662], [-0.2660, 0.3830, -0.1630], ], [0.1663, 0.2391, 0.1826], &device, ); // New gate (n) gru.new_gate = create_gate_controller( [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]], [-0.2231, -0.4428, 0.4737], [ [0.0900, -0.1821, 0.2430], [0.4665, 0.1551, 0.5155], [0.0631, -0.1566, 0.3337], ], [0.0364, -0.3941, 0.1780], &device, ); // Expected values from PyTorch let expected_output_with_h0 = TensorData::from([[ [0.05665, -0.34932, 0.43267], [-0.1737, -0.49246, 0.38099], [-0.35401, -0.68099, 0.05061], [-0.47854, -0.70427, -0.13648], ]]); let expected_output_no_h0 = TensorData::from([[ [-0.0985, -0.31661, 0.36126], [-0.24563, -0.47784, 0.34609], [-0.39497, -0.67659, 0.03083], [-0.50146, -0.70066, -0.14894], ]]); let output_with_h0 = gru.forward(input.clone(), Some(h0)); let output_no_h0 = gru.forward(input, None); let tolerance = Tolerance::permissive(); output_with_h0 .to_data() .assert_approx_eq::(&expected_output_with_h0, tolerance); output_no_h0 .to_data() .assert_approx_eq::(&expected_output_no_h0, tolerance); } } ================================================ FILE: crates/burn-nn/src/modules/rnn/lstm.rs ================================================ use burn_core as burn; use crate::GateController; use crate::activation::{Activation, ActivationConfig}; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// A LstmState is used to store cell state and hidden state in LSTM. pub struct LstmState { /// The cell state. pub cell: Tensor, /// The hidden state. pub hidden: Tensor, } impl LstmState { /// Initialize a new [LSTM State](LstmState). pub fn new(cell: Tensor, hidden: Tensor) -> Self { Self { cell, hidden } } } /// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init). #[derive(Config, Debug)] pub struct LstmConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the Lstm transformation. pub bias: bool, /// Lstm initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`. /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`. #[config(default = true)] pub batch_first: bool, /// If true, process the sequence in reverse order. /// This is useful for implementing reverse-direction LSTMs (e.g., ONNX reverse direction). #[config(default = false)] pub reverse: bool, /// Optional cell state clip threshold. If provided, cell state values are clipped /// to the range `[-clip, +clip]` after each timestep. This can help prevent /// exploding values during inference. pub clip: Option, /// If true, couples the input and forget gates: `f_t = 1 - i_t`. /// This reduces the number of parameters and is based on GRU-style simplification. #[config(default = false)] pub input_forget: bool, /// Activation function for the input, forget, and output gates. /// Default is Sigmoid, which is standard for LSTM gates. #[config(default = "ActivationConfig::Sigmoid")] pub gate_activation: ActivationConfig, /// Activation function for the cell gate (candidate cell state). /// Default is Tanh, which is standard for LSTM. #[config(default = "ActivationConfig::Tanh")] pub cell_activation: ActivationConfig, /// Activation function applied to the cell state before computing hidden output. /// Default is Tanh, which is standard for LSTM. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, } /// The Lstm module. This implementation is for a unidirectional, stateless, Lstm. /// /// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244). /// /// Should be created with [LstmConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct Lstm { /// The input gate regulates which information to update and store in the cell state at each time step. pub input_gate: GateController, /// The forget gate is used to control which information to discard or keep in the memory cell at each time step. /// Note: When `input_forget` is true, this gate is not used (forget = 1 - input). pub forget_gate: GateController, /// The output gate determines which information from the cell state to output at each time step. pub output_gate: GateController, /// The cell gate is used to compute the cell state that stores and carries information through time. pub cell_gate: GateController, /// The hidden state of the LSTM. pub d_hidden: usize, /// If true, input is `[batch_size, seq_length, input_size]`. /// If false, input is `[seq_length, batch_size, input_size]`. pub batch_first: bool, /// If true, process the sequence in reverse order. pub reverse: bool, /// Optional cell state clip threshold. pub clip: Option, /// If true, couples input and forget gates: f_t = 1 - i_t. pub input_forget: bool, /// Activation function for gates (input, forget, output). pub gate_activation: Activation, /// Activation function for cell gate (candidate cell state). pub cell_activation: Activation, /// Activation function for hidden output. pub hidden_activation: Activation, } impl ModuleDisplay for Lstm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self.input_gate.input_transform.weight.shape().dims(); let bias = self.input_gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .optional() } } impl LstmConfig { /// Initialize a new [lstm](Lstm) module. pub fn init(&self, device: &B::Device) -> Lstm { let d_output = self.d_hidden; let new_gate = || { GateController::new( self.d_input, d_output, self.bias, self.initializer.clone(), device, ) }; Lstm { input_gate: new_gate(), forget_gate: new_gate(), output_gate: new_gate(), cell_gate: new_gate(), d_hidden: self.d_hidden, batch_first: self.batch_first, reverse: self.reverse, clip: self.clip, input_forget: self.input_forget, gate_activation: self.gate_activation.init(device), cell_activation: self.cell_activation.init(device), hidden_activation: self.hidden_activation.init(device), } } } impl Lstm { /// Applies the forward pass on the input tensor. This LSTM implementation /// returns the state for each element in a sequence (i.e., across seq_length) and a final state. /// /// ## Parameters: /// - batched_input: The input tensor of shape: /// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default) /// - `[sequence_length, batch_size, input_size]` if `batch_first` is false /// - state: An optional `LstmState` representing the initial cell state and hidden state. /// Each state tensor has shape `[batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: /// - output: A tensor represents the output features of LSTM. Shape: /// - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true /// - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false /// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape /// `[batch_size, hidden_size]`. pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> (Tensor, LstmState) { // Convert to batch-first layout internally if needed let batched_input = if self.batch_first { batched_input } else { batched_input.swap_dims(0, 1) }; let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.dims(); // Process sequence in forward or reverse order based on config let (output, state) = if self.reverse { self.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), state, batch_size, seq_length, &device, ) } else { self.forward_iter( batched_input.iter_dim(1).zip(0..seq_length), state, batch_size, seq_length, &device, ) }; // Convert output back to seq-first layout if needed let output = if self.batch_first { output } else { output.swap_dims(0, 1) }; (output, state) } fn forward_iter, usize)>>( &self, input_timestep_iter: I, state: Option>, batch_size: usize, seq_length: usize, device: &B::Device, ) -> (Tensor, LstmState) { let mut batched_hidden_state = Tensor::empty([batch_size, seq_length, self.d_hidden], device); let (mut cell_state, mut hidden_state) = match state { Some(state) => (state.cell, state.hidden), None => ( Tensor::zeros([batch_size, self.d_hidden], device), Tensor::zeros([batch_size, self.d_hidden], device), ), }; for (input_t, t) in input_timestep_iter { let input_t = input_t.squeeze_dim(1); // i(nput)g(ate) tensors let biased_ig_input_sum = self .input_gate .gate_product(input_t.clone(), hidden_state.clone()); let input_values = self.gate_activation.forward(biased_ig_input_sum); // f(orget)g(ate) tensors - either computed or coupled to input gate let forget_values = if self.input_forget { // Coupled mode: f_t = 1 - i_t input_values.clone().neg().add_scalar(1.0) } else { let biased_fg_input_sum = self .forget_gate .gate_product(input_t.clone(), hidden_state.clone()); self.gate_activation.forward(biased_fg_input_sum) }; // o(output)g(ate) tensors let biased_og_input_sum = self .output_gate .gate_product(input_t.clone(), hidden_state.clone()); let output_values = self.gate_activation.forward(biased_og_input_sum); // c(ell)g(ate) tensors let biased_cg_input_sum = self .cell_gate .gate_product(input_t.clone(), hidden_state.clone()); let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum); cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values; // Apply cell state clipping if configured if let Some(clip) = self.clip { cell_state = cell_state.clamp(-clip, clip); } hidden_state = output_values * self.hidden_activation.forward(cell_state.clone()); let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1); // store the hidden state for this timestep batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], unsqueezed_hidden_state.clone(), ); } ( batched_hidden_state, LstmState::new(cell_state, hidden_state), ) } } /// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init). #[derive(Config, Debug)] pub struct BiLstmConfig { /// The size of the input features. pub d_input: usize, /// The size of the hidden state. pub d_hidden: usize, /// If a bias should be applied during the BiLstm transformation. pub bias: bool, /// BiLstm initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`. /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`. #[config(default = true)] pub batch_first: bool, /// Optional cell state clip threshold. pub clip: Option, /// If true, couples the input and forget gates. #[config(default = false)] pub input_forget: bool, /// Activation function for the input, forget, and output gates. #[config(default = "ActivationConfig::Sigmoid")] pub gate_activation: ActivationConfig, /// Activation function for the cell gate (candidate cell state). #[config(default = "ActivationConfig::Tanh")] pub cell_activation: ActivationConfig, /// Activation function applied to the cell state before computing hidden output. #[config(default = "ActivationConfig::Tanh")] pub hidden_activation: ActivationConfig, } /// The BiLstm module. This implementation is for Bidirectional LSTM. /// /// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf). /// /// Should be created with [BiLstmConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct BiLstm { /// LSTM for the forward direction. pub forward: Lstm, /// LSTM for the reverse direction. pub reverse: Lstm, /// The size of the hidden state. pub d_hidden: usize, /// If true, input is `[batch_size, seq_length, input_size]`. /// If false, input is `[seq_length, batch_size, input_size]`. pub batch_first: bool, } impl ModuleDisplay for BiLstm { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_input, _] = self .forward .input_gate .input_transform .weight .shape() .dims(); let bias = self.forward.input_gate.input_transform.bias.is_some(); content .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) .optional() } } impl BiLstmConfig { /// Initialize a new [Bidirectional LSTM](BiLstm) module. pub fn init(&self, device: &B::Device) -> BiLstm { // Internal LSTMs always use batch_first=true; BiLstm handles layout conversion let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias) .with_initializer(self.initializer.clone()) .with_batch_first(true) .with_clip(self.clip) .with_input_forget(self.input_forget) .with_gate_activation(self.gate_activation.clone()) .with_cell_activation(self.cell_activation.clone()) .with_hidden_activation(self.hidden_activation.clone()); BiLstm { forward: base_config.clone().init(device), reverse: base_config.init(device), d_hidden: self.d_hidden, batch_first: self.batch_first, } } } impl BiLstm { /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation /// returns the state for each element in a sequence (i.e., across seq_length) and a final state. /// /// ## Parameters: /// - batched_input: The input tensor of shape: /// - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default) /// - `[sequence_length, batch_size, input_size]` if `batch_first` is false /// - state: An optional `LstmState` representing the initial cell state and hidden state. /// Each state tensor has shape `[2, batch_size, hidden_size]`. /// If no initial state is provided, these tensors are initialized to zeros. /// /// ## Returns: /// - output: A tensor represents the output features of LSTM. Shape: /// - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true /// - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and /// `state.hidden` have the shape `[2, batch_size, hidden_size]`. pub fn forward( &self, batched_input: Tensor, state: Option>, ) -> (Tensor, LstmState) { // Convert to batch-first layout internally if needed let batched_input = if self.batch_first { batched_input } else { batched_input.swap_dims(0, 1) }; let device = batched_input.clone().device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); let [init_state_forward, init_state_reverse] = match state { Some(state) => { let cell_state_forward = state .cell .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); let hidden_state_forward = state .hidden .clone() .slice([0..1, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); let cell_state_reverse = state .cell .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); let hidden_state_reverse = state .hidden .slice([1..2, 0..batch_size, 0..self.d_hidden]) .squeeze_dim(0); [ Some(LstmState::new(cell_state_forward, hidden_state_forward)), Some(LstmState::new(cell_state_reverse, hidden_state_reverse)), ] } None => [None, None], }; // forward direction let (batched_hidden_state_forward, final_state_forward) = self .forward .forward(batched_input.clone(), init_state_forward); // reverse direction let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter( batched_input.iter_dim(1).rev().zip((0..seq_length).rev()), init_state_reverse, batch_size, seq_length, &device, ); let output = Tensor::cat( [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(), 2, ); // Convert output back to seq-first layout if needed let output = if self.batch_first { output } else { output.swap_dims(0, 1) }; let state = LstmState::new( Tensor::stack( [final_state_forward.cell, final_state_reverse.cell].to_vec(), 0, ), Tensor::stack( [final_state_forward.hidden, final_state_reverse.hidden].to_vec(), 0, ), ); (output, state) } } #[cfg(test)] mod tests { use super::*; use crate::{LinearRecord, TestBackend}; use burn::module::Param; use burn::tensor::{Device, Distribution, TensorData}; use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem}; type FT = FloatElem; #[cfg(feature = "std")] use crate::TestAutodiffBackend; #[test] fn test_with_uniform_initializer() { let device = Default::default(); TestBackend::seed(&device, 0); let config = LstmConfig::new(5, 5, false) .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 }); let lstm = config.init::(&Default::default()); let gate_to_data = |gate: GateController| gate.input_transform.weight.val().to_data(); gate_to_data(lstm.input_gate).assert_within_range::(0.elem()..1.elem()); gate_to_data(lstm.forget_gate).assert_within_range::(0.elem()..1.elem()); gate_to_data(lstm.output_gate).assert_within_range::(0.elem()..1.elem()); gate_to_data(lstm.cell_gate).assert_within_range::(0.elem()..1.elem()); } /// Test forward pass with simple input vector. /// /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928 /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725 /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723 /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937 /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243 /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648 #[test] fn test_forward_single_input_single_feature() { let device = Default::default(); TestBackend::seed(&device, 0); let config = LstmConfig::new(1, 1, false); let device = Default::default(); let mut lstm = config.init::(&device); fn create_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, device: &Device, ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; let record_2 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), }; GateController::create_with_weights( d_input, d_output, bias, initializer, record_1, record_2, ) } lstm.input_gate = create_gate_controller( 0.5, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); lstm.forget_gate = create_gate_controller( 0.7, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); lstm.cell_gate = create_gate_controller( 0.9, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); lstm.output_gate = create_gate_controller( 1.1, 0.0, 1, 1, false, Initializer::XavierUniform { gain: 1.0 }, &device, ); // single timestep with single feature let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); let (output, state) = lstm.forward(input, None); let expected = TensorData::from([[0.046]]); let tolerance = Tolerance::default(); state .cell .to_data() .assert_approx_eq::(&expected, tolerance); let expected = TensorData::from([[0.0242]]); state .hidden .to_data() .assert_approx_eq::(&expected, tolerance); output .select(0, Tensor::arange(0..1, &device)) .squeeze_dim::<2>(0) .to_data() .assert_approx_eq::(&state.hidden.to_data(), tolerance); } #[test] fn test_batched_forward_pass() { let device = Default::default(); let lstm = LstmConfig::new(64, 1024, true).init(&device); let batched_input = Tensor::::random([8, 10, 64], Distribution::Default, &device); let (output, state) = lstm.forward(batched_input, None); assert_eq!(output.dims(), [8, 10, 1024]); assert_eq!(state.cell.dims(), [8, 1024]); assert_eq!(state.hidden.dims(), [8, 1024]); } #[test] fn test_batched_forward_pass_batch_of_one() { let device = Default::default(); let lstm = LstmConfig::new(64, 1024, true).init(&device); let batched_input = Tensor::::random([1, 2, 64], Distribution::Default, &device); let (output, state) = lstm.forward(batched_input, None); assert_eq!(output.dims(), [1, 2, 1024]); assert_eq!(state.cell.dims(), [1, 1024]); assert_eq!(state.hidden.dims(), [1, 1024]); } #[test] #[cfg(feature = "std")] fn test_batched_backward_pass() { use burn::tensor::Shape; let device = Default::default(); let lstm = LstmConfig::new(64, 32, true).init(&device); let shape: Shape = [8, 10, 64].into(); let batched_input = Tensor::::random(shape, Distribution::Default, &device); let (output, _) = lstm.forward(batched_input.clone(), None); let fake_loss = output; let grads = fake_loss.backward(); let some_gradient = lstm .output_gate .hidden_transform .weight .grad(&grads) .unwrap(); // Asserts that the gradients exist and are non-zero assert_ne!( some_gradient .any() .into_data() .iter::() .next() .unwrap(), 0.0 ); } #[test] fn test_bidirectional() { let device = Default::default(); TestBackend::seed(&device, 0); let config = BiLstmConfig::new(2, 3, true); let device = Default::default(); let mut lstm = config.init(&device); fn create_gate_controller( input_weights: [[f32; D1]; D2], input_biases: [f32; D1], hidden_weights: [[f32; D1]; D1], hidden_biases: [f32; D1], device: &Device, ) -> GateController { let d_input = input_weights[0].len(); let d_output = input_weights.len(); let input_record = LinearRecord { weight: Param::from_data(TensorData::from(input_weights), device), bias: Some(Param::from_data(TensorData::from(input_biases), device)), }; let hidden_record = LinearRecord { weight: Param::from_data(TensorData::from(hidden_weights), device), bias: Some(Param::from_data(TensorData::from(hidden_biases), device)), }; GateController::create_with_weights( d_input, d_output, true, Initializer::XavierUniform { gain: 1.0 }, input_record, hidden_record, ) } let input = Tensor::::from_data( TensorData::from([[ [0.949, -0.861], [0.892, 0.927], [-0.173, -0.301], [-0.081, 0.992], ]]), &device, ); let h0 = Tensor::::from_data( TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]), &device, ); let c0 = Tensor::::from_data( TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]), &device, ); lstm.forward.input_gate = create_gate_controller( [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]], [-0.196, 0.354, 0.209], [ [-0.320, 0.232, -0.165], [0.093, -0.572, -0.315], [-0.467, 0.325, 0.046], ], [0.181, -0.190, -0.245], &device, ); lstm.forward.forget_gate = create_gate_controller( [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]], [0.315, -0.413, -0.041], [ [0.453, 0.063, 0.561], [0.211, 0.149, 0.213], [-0.499, -0.158, 0.068], ], [-0.431, -0.535, 0.125], &device, ); lstm.forward.cell_gate = create_gate_controller( [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]], [-0.358, 0.282, -0.078], [ [-0.358, 0.109, 0.139], [-0.345, 0.091, -0.368], [-0.508, 0.221, -0.507], ], [0.502, -0.509, -0.247], &device, ); lstm.forward.output_gate = create_gate_controller( [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]], [-0.227, -0.274, 0.039], [ [-0.383, 0.449, 0.222], [-0.357, -0.093, 0.449], [-0.106, 0.236, 0.360], ], [-0.361, -0.209, -0.454], &device, ); lstm.reverse.input_gate = create_gate_controller( [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]], [0.540, -0.164, 0.033], [ [0.159, 0.180, -0.037], [-0.443, 0.485, -0.488], [0.098, -0.085, -0.140], ], [-0.510, 0.105, 0.114], &device, ); lstm.reverse.forget_gate = create_gate_controller( [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]], [0.141, 0.004, 0.055], [ [-0.005, -0.277, -0.515], [-0.011, -0.101, -0.365], [0.426, 0.379, 0.337], ], [-0.382, 0.331, -0.176], &device, ); lstm.reverse.cell_gate = create_gate_controller( [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]], [-0.206, -0.546, 0.462], [ [0.449, -0.240, 0.071], [-0.045, 0.131, 0.124], [0.138, -0.201, 0.191], ], [-0.030, 0.211, -0.352], &device, ); lstm.reverse.output_gate = create_gate_controller( [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]], [-0.387, -0.250, 0.066], [ [-0.030, 0.268, 0.299], [-0.019, -0.280, -0.314], [0.466, -0.365, -0.248], ], [-0.398, -0.199, -0.566], &device, ); let expected_output_with_init_state = TensorData::from([[ [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798], [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742], [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012], [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872], ]]); let expected_output_without_init_state = TensorData::from([[ [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863], [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142], [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846], [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550], ]]); let expected_hn_with_init_state = TensorData::from([ [[-0.03420, 0.07774, -0.09774]], [[-0.15635, -0.03366, -0.05798]], ]); let expected_cn_with_init_state = TensorData::from([ [[-0.13593, 0.17125, -0.22395]], [[-0.45425, -0.11206, -0.12908]], ]); let expected_hn_without_init_state = TensorData::from([ [[-0.04026, 0.07178, -0.10189]], [[-0.15969, -0.05322, -0.08863]], ]); let expected_cn_without_init_state = TensorData::from([ [[-0.15839, 0.15923, -0.23569]], [[-0.47407, -0.17493, -0.19643]], ]); let (output_with_init_state, state_with_init_state) = lstm.forward(input.clone(), Some(LstmState::new(c0, h0))); let (output_without_init_state, state_without_init_state) = lstm.forward(input, None); let tolerance = Tolerance::permissive(); output_with_init_state .to_data() .assert_approx_eq::(&expected_output_with_init_state, tolerance); output_without_init_state .to_data() .assert_approx_eq::(&expected_output_without_init_state, tolerance); state_with_init_state .hidden .to_data() .assert_approx_eq::(&expected_hn_with_init_state, tolerance); state_with_init_state .cell .to_data() .assert_approx_eq::(&expected_cn_with_init_state, tolerance); state_without_init_state .hidden .to_data() .assert_approx_eq::(&expected_hn_without_init_state, tolerance); state_without_init_state .cell .to_data() .assert_approx_eq::(&expected_cn_without_init_state, tolerance); } #[test] fn display_lstm() { let config = LstmConfig::new(2, 3, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}" ); } #[test] fn display_bilstm() { let config = BiLstmConfig::new(2, 3, true); let layer = config.init::(&Default::default()); assert_eq!( alloc::format!("{layer}"), "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}" ); } } ================================================ FILE: crates/burn-nn/src/modules/rnn/mod.rs ================================================ mod gate_controller; /// Basic RNN. pub mod basic; /// Gated Recurrent Unit module. pub mod gru; /// Long Short-Term Memory module. pub mod lstm; pub use basic::*; pub use gate_controller::*; pub use gru::*; pub use lstm::*; ================================================ FILE: crates/burn-nn/src/modules/rope_encoding.rs ================================================ use burn_core as burn; use alloc::vec; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Int; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use core::ops::Range; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init). #[derive(Config, Debug)] pub struct RotaryEncodingConfig { /// Maximum sequence length of input pub max_sequence_length: usize, /// Size of the input embedding or hidden dimension pub d_model: usize, /// Scaling factor for frequency computation. Defaults to 10000.0 #[config(default = "10000.0")] pub theta: f32, } impl RotaryEncodingConfig { /// Initialize a new [RotaryEncoding](RotaryEncoding) module. /// /// # Panics /// /// Panics if the size of input embedding dimension is not even. /// Panics if the theta parameter is not positive. pub fn init(&self, device: &B::Device) -> RotaryEncoding { self.initialize(|x| x, device) } /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function. /// This is useful to apply different RoPE extensions. /// /// # Panics /// /// Panics if the size of input embedding dimension is not even. /// Panics if the theta parameter is not positive. pub fn init_with_frequency_scaling( &self, scaling: impl Fn(Tensor) -> Tensor, device: &B::Device, ) -> RotaryEncoding { self.initialize(scaling, device) } /// Initialize a new [RotaryEncoding](RotaryEncoding) module. /// /// # Panics /// /// Panics if the size of input embedding dimension is not even. /// Panics if the theta parameter is not positive. fn initialize( &self, scaling: impl Fn(Tensor) -> Tensor, device: &B::Device, ) -> RotaryEncoding { assert_eq!( self.d_model % 2, 0, "The input embedding dimension must be even" ); assert!( self.theta > 0.0, "Theta parameter must be positive (default: 10000)." ); // Calculate the rotation frequencies for positional embeddings based on the formula // `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]` let exponent = Tensor::::arange_step(0..self.d_model as i64, 2, device) .float() .div_scalar(self.d_model as f32); // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))` // This is done since burn doesn't support exponentiation of scalar to tensor let theta = exponent.mul_scalar(self.theta.ln()).exp().recip(); let theta = scaling(theta); let freq_complex = RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone()); RotaryEncoding { freq_complex, theta, start_offset: 0, } } } /// A module that applies rotary positional encoding to a tensor. /// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes /// absolute positional information with rotation matrix and naturally incorporates /// explicit relative position dependency in self-attention formulation. /// /// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) /// /// Should be created using [RotaryEncodingConfig]. #[derive(Module, Debug)] #[module(custom_display)] pub struct RotaryEncoding { /// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components // Essentially a cache of pre-computed RoPE values. pub freq_complex: Tensor, /// Frequency vector used to compute/apply the complex rotations. pub theta: Tensor, start_offset: usize, } impl ModuleDisplay for RotaryEncoding { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims(); content .add("d_model", &d_model) .add("max_sequence_length", &max_sequence_length) .optional() } } #[allow(clippy::single_range_in_vec_init)] impl RotaryEncoding { /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model) /// /// # Arguments: /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors /// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim) /// respectively. /// /// # Returns: /// Output tensor with the same shape as input tensor after applying rotary encoding. /// /// # Panics /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension. pub fn forward(&self, x: Tensor) -> Tensor { self.apply(x, 0) } /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model) /// /// # Arguments: /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors /// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim) /// respectively. /// * `start` - Sequence start position index. /// /// # Returns: /// Output tensor with the same shape as input tensor after applying rotary encoding. /// /// # Panics /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension. pub fn apply(&self, x: Tensor, start: usize) -> Tensor { assert!( D >= 2, "Input tensor must have at least 2 dimensions for sequence length and hidden dimension" ); let device = x.device(); let input_shape = x.shape(); // Extract the sequence length and embedding dimension, other dimensions are kept generic // to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads) let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]); let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model); // Create a dummy tensor with signed ones based on the 2D rotation matrix // [[cos, -sin], [sin, cos]] let sign_tensor = Tensor::::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device); // Rotate input using the frequency tensor. Slice the frequencies till input sequence length let out: Tensor = x .reshape([dummy_dim_size, seq_len, d_model / 2, 2]) .matmul(sign_tensor.unsqueeze()) .reshape([dummy_dim_size, seq_len, d_model, 2]) * self .freq_complex .clone() .slice([start..start + seq_len]) .unsqueeze(); // Sum the real and imaginary components to get output tensor and reshape to original shape out.sum_dim(-1).reshape(input_shape) } /// Shifts the pre-computed rotary frequency to cover a new range of positions. /// /// This method updates the internal frequency tensor `freq_complex` to store /// the rotary positional encodings for a new window of positions starting at `start`. pub fn shift(&mut self, start: usize) { let max_seq_len = self.freq_complex.dims()[0]; assert!( start > self.start_offset, "Shift start position must be monotonically increasing" ); let current_end = self.start_offset + max_seq_len; if start >= current_end { // Overwrite the whole buffer let new_freqs = Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone()); self.freq_complex .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs)); } else { // Shift the tail let num_keep = current_end - start; let start_rel = start - self.start_offset; let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]); self.freq_complex .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs)); // Compute the rest and assign let new_freqs = Self::compute_rotary_frequencies( current_end..start + max_seq_len, self.theta.clone(), ); self.freq_complex .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs)); } self.start_offset = start; } /// Computes the positional rotation frequencies (cosine and sine values) used in RoPE. /// /// # Arguments /// - `range`: Range of position indices `[start, end)`. /// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies. /// /// # Returns /// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency. fn compute_rotary_frequencies(range: Range, theta: Tensor) -> Tensor { let d_model = theta.dims()[0] * 2; let num_positions = range.end - range.start; // Generate frequency values for positional embeddings let frequencies: Tensor = Tensor::::arange(range.start as i64..range.end as i64, &theta.device()) .float() .unsqueeze() .transpose() .repeat_dim(1, d_model / 2) * theta.unsqueeze(); // Convert frequency values to complex numbers (polar form) let p_cos = frequencies.clone().cos(); let p_sin = frequencies.sin(); Tensor::cat(vec![p_cos, p_sin], 1) .reshape([num_positions, 2, d_model / 2]) .transpose() .unsqueeze_dim::<4>(2) .repeat_dim(2, 2) .reshape([num_positions, d_model, 2]) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_rotary_encoding_forward() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::(&device); let input = Tensor::::from_floats( [ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], ], &device, ); // Input = [Batch size, Num of heads, Seq_len, d_model] let input = input.unsqueeze::<4>(); let output = rotary_encoding.forward(input); let expected_output = Tensor::::from_floats( [ [ [1.0000, 2.0000, 3.0000, 4.0000], [-2.3473, 7.4492, 6.9197, 8.0696], ], [ [9.0000, 10.0000, 11.0000, 12.0000], [-4.7567, 18.5034, 14.8393, 16.1492], ], ], &device, ); output .squeeze_dim::<3>(0) .to_data() .assert_approx_eq::(&expected_output.to_data(), Tolerance::default()); } #[test] fn test_rotary_encoding_3d() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::(&device); let input = Tensor::::from_floats( [ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], ], &device, ); // Input = [Batch size, Num of heads, Seq_len, d_model] // let input = input.unsqueeze::<4>(); let output = rotary_encoding.forward(input); let expected_output = Tensor::::from_floats( [ [ [1.0000, 2.0000, 3.0000, 4.0000], [-2.3473, 7.4492, 6.9197, 8.0696], ], [ [9.0000, 10.0000, 11.0000, 12.0000], [-4.7567, 18.5034, 14.8393, 16.1492], ], ], &device, ); output .to_data() .assert_approx_eq::(&expected_output.to_data(), Tolerance::default()); } #[test] fn test_zero_input_rotary_encoding_forward() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::(&device); // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well let input = Tensor::::zeros([1, 2, 2, 4], &device); let output = rotary_encoding.forward(input); let expected_output = Tensor::::from_floats( [ [ [0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000], ], [ [0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000], ], ], &device, ); output .squeeze_dim::<3>(0) .to_data() .assert_approx_eq::(&expected_output.to_data(), Tolerance::default()); } #[test] #[should_panic] fn test_valid_input_hidden_dim() { // Hidden dimension must be even to be able to split into real and imaginary components // for rotation let d_model = 15; let device = Default::default(); let pe = RotaryEncodingConfig::new(10, d_model).init::(&device); let input = Tensor::::zeros([1, 5, d_model], &device); let _output = pe.forward(input); } #[test] fn test_rotary_encoding_frequencies() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::(&device); let expected_freqs = Tensor::::from_floats( [ [ [1.0000, 0.0000], [1.0000, 0.0000], [1.0000, 0.0000], [1.0000, 0.0000], ], [ [5.4030e-01, 8.4147e-01], [9.9500e-01, 9.9833e-02], [9.9995e-01, 9.9998e-03], [9.9999e-01, 9.9999e-04], ], ], &device, ) .unsqueeze_dim::<4>(2) .repeat_dim(2, 2) .reshape([2, 8, 2]); rotary_encoding .freq_complex .to_data() .assert_approx_eq::(&expected_freqs.to_data(), Tolerance::default()); } fn apply_freq_scaling_by_parts(freqs: Tensor) -> Tensor { // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45 let scale_factor = 8.; let low_freq_factor = 1.; let high_freq_factor = 4.; let old_context_len = 8192.; let low_freq_wavelen = old_context_len / low_freq_factor; let high_freq_wavelen = old_context_len / high_freq_factor; let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI); // if wavelen >= high_freq_wavelen let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen); let smooth = wavelen .clone() .recip() .mul_scalar(old_context_len) .sub_scalar(low_freq_factor) .div_scalar(high_freq_factor - low_freq_factor); // (1 - smooth) * freq / scale_factor + smooth * freq let new_freqs = smooth .clone() .neg() .add_scalar(1.) .mul(freqs.clone().div_scalar(scale_factor)) .add(smooth.clone().mul(freqs.clone())); let new_freqs = freqs.clone().mask_where(cond, new_freqs); // if wavelen > low_freq_wavelen let cond = wavelen.clone().greater_elem(low_freq_wavelen); let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor)); // if wavelen < high_freq_wavelen let cond = wavelen.lower_elem(high_freq_wavelen); new_freqs.mask_where(cond, freqs) } #[test] fn test_rotary_encoding_with_frequency_scaling() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(2, 8) .init_with_frequency_scaling::(apply_freq_scaling_by_parts, &device); let expected_freqs = Tensor::::from_floats( [ [ [1.0000, 0.0000], [1.0000, 0.0000], [1.0000, 0.0000], [1.0000, 0.0000], ], [ [5.4030e-01, 8.4148e-01], [9.9500e-01, 9.9833e-02], [9.9995e-01, 9.9998e-03], [1.0000, 2.1361e-04], ], ], &device, ) .unsqueeze_dim::<4>(2) .repeat_dim(2, 2) .reshape([2, 8, 2]); rotary_encoding .freq_complex .to_data() .assert_approx_eq::(&expected_freqs.to_data(), Tolerance::default()); } #[test] fn test_rotary_encoding_shift_full() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::(&device); // Input = [Batch size, Num of heads, Seq_len, d_model] let input = Tensor::::from_floats( [ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], ], &device, ) .unsqueeze::<4>(); // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same // initial position let expected_output = rotary_encoding.apply(input.clone(), 6); let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::(&device); rotary_encoding.shift(6); // start > 4 will perform a full re-compute let output = rotary_encoding.apply(input, 0); output .into_data() .assert_approx_eq::(&expected_output.into_data(), Tolerance::default()); } #[test] fn test_rotary_encoding_shift() { let device = Default::default(); let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::(&device); // Input = [Batch size, Num of heads, Seq_len, d_model] let input = Tensor::::from_floats( [ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]], ], &device, ) .unsqueeze::<4>(); // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same // initial position let expected_output = rotary_encoding.apply(input.clone(), 2); let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::(&device); rotary_encoding.shift(2); // start < 4 will shift the (current_end - start) freqs and compute the rest let output = rotary_encoding.apply(input, 0); output .into_data() .assert_approx_eq::(&expected_output.into_data(), Tolerance::default()); } #[test] fn test_rotary_encoding_shift_multiple() { let device = Default::default(); let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::(&device); rotary_encoding.shift(2); rotary_encoding.shift(5); } #[test] #[should_panic = "Shift start position must be monotonically increasing"] fn test_rotary_encoding_shift_should_increase() { let device = Default::default(); let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::(&device); rotary_encoding.shift(6); rotary_encoding.shift(4); // should be monotonically increasing } #[test] fn display() { let config = RotaryEncodingConfig::new(10, 4); let pe = config.init::(&Default::default()); assert_eq!( alloc::format!("{pe}"), "RotaryEncoding {d_model: 4, max_sequence_length: 10}" ); } } ================================================ FILE: crates/burn-nn/src/modules/transformer/decoder.rs ================================================ use burn_core as burn; use alloc::vec::Vec; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::{Bool, Tensor, backend::Backend}; use crate::activation::ActivationConfig; use crate::cache::TensorCache; use crate::{ Dropout, DropoutConfig, LayerNorm, LayerNormConfig, attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, }; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; /// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init). #[derive(Config, Debug)] pub struct TransformerDecoderConfig { /// The size of the model. pub d_model: usize, /// The size of the position-wise feed-forward network. pub d_ff: usize, /// The number of attention heads. pub n_heads: usize, /// The number of layers. pub n_layers: usize, /// The dropout rate. Default: 0.1 #[config(default = 0.1)] pub dropout: f64, /// Layer norm will be applied first instead of after the other modules. #[config(default = false)] pub norm_first: bool, /// Use "quiet softmax" instead of regular softmax. /// /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. /// /// Reference: #[config(default = false)] pub quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, /// The activation function used in the position-wise feed-forward network. Default: Gelu #[config(default = "ActivationConfig::Gelu")] pub activation: ActivationConfig, /// The epsilon value for layer normalization. Default: 1e-5 #[config(default = 1e-5)] pub layer_norm_eps: f64, } /// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). /// /// # Params /// /// - layers: transformer decoder layers with `d_model` input and output features. /// /// Should be created using [TransformerDecoderConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct TransformerDecoder { /// Transformer decoder layers. pub layers: Vec>, /// The size of the model. pub d_model: usize, /// The size of the position-wise feed-forward network. pub d_ff: usize, /// The number of attention heads. pub n_heads: usize, /// The number of layers. pub n_layers: usize, /// The dropout rate. Default: 0.1 pub dropout: f64, /// Layer norm will be applied first instead of after the other modules. pub norm_first: bool, /// Use "quiet softmax" instead of regular softmax. pub quiet_softmax: bool, } impl ModuleDisplay for TransformerDecoder { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("d_model", &self.d_model) .add("d_ff", &self.d_ff) .add("n_heads", &self.n_heads) .add("n_layers", &self.n_layers) .add("dropout", &self.dropout) .add("norm_first", &self.norm_first) .add("quiet_softmax", &self.quiet_softmax) .optional() } } impl TransformerDecoderConfig { /// Initialize a new [Transformer Decoder](TransformerDecoder) module. pub fn init(&self, device: &B::Device) -> TransformerDecoder { let layers = (0..self.n_layers) .map(|_| TransformerDecoderLayer::new(self, device)) .collect::>(); TransformerDecoder { layers, d_model: self.d_model, d_ff: self.d_ff, n_heads: self.n_heads, n_layers: self.n_layers, dropout: self.dropout, norm_first: self.norm_first, quiet_softmax: self.quiet_softmax, } } } /// [Transformer Decoder](TransformerDecoder) forward pass input argument. #[derive(Debug)] pub struct TransformerDecoderInput { target: Tensor, target_mask_pad: Option>, target_mask_attn: Option>, memory: Tensor, memory_mask_pad: Option>, memory_mask_attn: Option>, } impl TransformerDecoderInput { /// Create a [transformer decoder](TransformerDecoder) input argument. pub fn new(target: Tensor, memory: Tensor) -> Self { Self { target, target_mask_pad: None, target_mask_attn: None, memory, memory_mask_pad: None, memory_mask_attn: None, } } /// Register the memory padding mask. pub fn memory_mask_pad(mut self, mask_pad: Tensor) -> Self { self.memory_mask_pad = Some(mask_pad); self } /// Register the memory attention mask. pub fn memory_mask_attn(mut self, mask_attn: Tensor) -> Self { self.memory_mask_attn = Some(mask_attn); self } /// Register the target padding mask. pub fn target_mask_pad(mut self, mask_pad: Tensor) -> Self { self.target_mask_pad = Some(mask_pad); self } /// Register the target attention mask. pub fn target_mask_attn(mut self, mask_attn: Tensor) -> Self { self.target_mask_attn = Some(mask_attn); self } } /// [Transformer Decoder](TransformerDecoder) layer module. #[derive(Module, Debug)] pub struct TransformerDecoderLayer { /// Cross-attention module. pub cross_attn: MultiHeadAttention, /// Self-attention module. pub self_attn: MultiHeadAttention, /// Position-wise feed-forward module. pub pwff: PositionWiseFeedForward, /// First layer norm. pub norm_1: LayerNorm, /// Second layer norm. pub norm_2: LayerNorm, /// Third layer norm. pub norm_3: LayerNorm, /// Dropout. pub dropout: Dropout, /// Whether to apply norm first. pub norm_first: bool, } /// Autoregressive cache for a single [Transformer Decoder Layer](TransformerDecoderLayer). pub struct TransformerDecoderLayerAutoregressiveCache { /// Cross-attention cache. pub cross_attn: MhaCache, /// Self-attention cache. pub self_attn: MhaCache, /// Position-wise feed-forward cache. pub pwff: TensorCache, /// First layer norm cache. pub norm_1: TensorCache, /// Second layer norm cache. pub norm_2: TensorCache, /// Third layer norm cache. pub norm_3: TensorCache, } impl TransformerDecoderLayerAutoregressiveCache { /// Create an empty cache. pub fn empty() -> Self { Self { cross_attn: MhaCache::autoregressive_cross_attention(), self_attn: MhaCache::autoregressive(), pwff: TensorCache::empty(), norm_1: TensorCache::empty(), norm_2: TensorCache::empty(), norm_3: TensorCache::empty(), } } } /// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer. /// /// To be used during inference when decoding tokens. pub struct TransformerDecoderAutoregressiveCache { layers: Vec>, } impl TransformerDecoderAutoregressiveCache { fn empty(num_layers: usize) -> Self { Self { layers: (0..num_layers) .map(|_| TransformerDecoderLayerAutoregressiveCache::empty()) .collect(), } } } impl TransformerDecoderLayer { /// Create a new [TransformerDecoderLayer](TransformerDecoderLayer). pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self { let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) .with_quiet_softmax(config.quiet_softmax) .init(device); let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) .with_quiet_softmax(config.quiet_softmax) .init(device); let norm_1 = LayerNormConfig::new(config.d_model) .with_epsilon(config.layer_norm_eps) .init(device); let norm_2 = LayerNormConfig::new(config.d_model) .with_epsilon(config.layer_norm_eps) .init(device); let norm_3 = LayerNormConfig::new(config.d_model) .with_epsilon(config.layer_norm_eps) .init(device); let dropout = DropoutConfig::new(config.dropout).init(); let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) .with_activation(config.activation.clone()) .init(device); Self { cross_attn, self_attn, norm_1, norm_2, norm_3, pwff, dropout, norm_first: config.norm_first, } } /// Applies the TransformerDecoder forward pass to the input tensor. pub fn forward(&self, mut input: TransformerDecoderInput) -> TransformerDecoderInput { // Self attention residual path. let x = input.target; let mut residual_path = x.clone(); // Normalize. if self.norm_first { residual_path = self.norm_3.forward(residual_path); } // Self attention. let mut self_attn_input = MhaInput::self_attn(residual_path); if let Some(mask_pad) = &input.target_mask_pad { self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); } if let Some(mask_attn) = &input.target_mask_attn { self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); } let residual_path = self.self_attn.forward(self_attn_input).context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Cross attention residual path. // Normalize. let residual_path = if self.norm_first { self.norm_1.forward(x.clone()) } else { x = self.norm_1.forward(x); x.clone() }; // Cross attention. let mut cross_attn_input = MhaInput::new(residual_path, input.memory.clone(), input.memory.clone()); if let Some(mask_pad) = &input.memory_mask_pad { cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); } if let Some(mask_attn) = &input.memory_mask_attn { cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); } let residual_path = self.cross_attn.forward(cross_attn_input).context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Feed forward residual path. // Normalize. let residual_path = if self.norm_first { self.norm_2.forward(x.clone()) } else { x = self.norm_2.forward(x); x.clone() }; let residual_path = self.pwff.forward(residual_path); let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Main path. // Normalize. if !self.norm_first { x = self.norm_3.forward(x) } input.target = x; input } /// Applies the forward pass using an autoregressive cache. pub fn forward_autoregressive_inference( &self, mut input: TransformerDecoderInput, cache: &mut TransformerDecoderLayerAutoregressiveCache, ) -> TransformerDecoderInput { // Self attention residual path. let x = input.target; let mut residual_path = x.clone(); // Normalize. if self.norm_first { residual_path = cache .norm_3 .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x)); } // Self attention. let mut self_attn_input = MhaInput::self_attn(residual_path); if let Some(mask_pad) = &input.target_mask_pad { self_attn_input = self_attn_input.mask_pad(mask_pad.clone()); } if let Some(mask_attn) = &input.target_mask_attn { self_attn_input = self_attn_input.mask_attn(mask_attn.clone()); } let residual_path = self .self_attn .forward_cache(self_attn_input, &mut cache.self_attn) .context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Cross attention residual path. // Normalize. let residual_path = if self.norm_first { cache .norm_1 .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x)) } else { x = cache .norm_1 .forward_autoregressive(x, 1, |x| self.norm_1.forward(x)); x.clone() }; // Cross attention. let mut cross_attn_input = MhaInput::new(residual_path, input.memory.clone(), input.memory.clone()); if let Some(mask_pad) = &input.memory_mask_pad { cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone()); } if let Some(mask_attn) = &input.memory_mask_attn { cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone()); } let residual_path = self .cross_attn .forward_cache(cross_attn_input, &mut cache.cross_attn) .context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Feed forward residual path. // Normalize. let residual_path = if self.norm_first { cache .norm_2 .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x)) } else { x = cache .norm_2 .forward_autoregressive(x, 1, |x| self.norm_2.forward(x)); x.clone() }; let residual_path = cache .pwff .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x)); let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Main path. // Normalize. if !self.norm_first { x = cache .norm_3 .forward_autoregressive(x, 1, |x| self.norm_3.forward(x)) } input.target = x; input } } impl TransformerDecoder { /// Applies the forward pass. pub fn forward(&self, mut input: TransformerDecoderInput) -> Tensor { for layer in self.layers.iter() { input = layer.forward(input); } input.target } /// Applies the forward pass on the input using autoregressive cache. pub fn forward_autoregressive_inference( &self, mut input: TransformerDecoderInput, cache: &mut TransformerDecoderAutoregressiveCache, ) -> Tensor { for i in 0..self.layers.len() { let layer = self.layers.get(i).unwrap(); let cache = cache.layers.get_mut(i).unwrap(); input = layer.forward_autoregressive_inference(input, cache); } input.target } /// Create an empty autoregressive cache. pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache { TransformerDecoderAutoregressiveCache::empty(self.layers.len()) } } #[cfg(test)] mod tests { use burn::tensor::Device; use super::*; use crate::{TestBackend, attention::generate_autoregressive_mask}; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_autoregressive_norm_last() { let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; let device = Default::default(); TestBackend::seed(&device, 0); test_autoregressive( TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers) .with_norm_first(false), ) } #[test] fn test_autoregressive_norm_first() { let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; let device = Default::default(); TestBackend::seed(&device, 0); test_autoregressive( TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), ) } fn test_autoregressive(config: TransformerDecoderConfig) { let device: Device = Default::default(); let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; let transformer = config.init::(&device); let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) .float() .reshape([batch_size, seq_length, d_model]); let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) .float() .reshape([batch_size, seq_length, d_model]); let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); let input = TransformerDecoderInput::new(target.clone(), memory.clone()) .target_mask_attn(mask_attn); // Normal forward using masking. let output_1 = transformer.forward(input); // Forward using the autoregressive cache. let mut output_2 = Vec::new(); let mut cache = transformer.new_autoregressive_cache(); for i in 1..seq_length + 1 { let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]); let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device()); let input = TransformerDecoderInput::new(target.clone(), memory.clone()) .target_mask_attn(mask_attn); let next_tok = transformer // Greedy sampling .forward_autoregressive_inference(input, &mut cache) .slice([0..batch_size, i - 1..i, 0..d_model]); output_2.push(next_tok); } let output_2 = Tensor::cat(output_2, 1); // Should produce the same tokens. let tolerance = Tolerance::rel_abs(5e-3, 1e-4); output_1 .into_data() .assert_approx_eq::(&output_2.into_data(), tolerance); } #[test] fn display() { let config = TransformerDecoderConfig::new(2, 4, 2, 3); let transformer = config.init::(&Default::default()); assert_eq!( alloc::format!("{transformer}"), "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \ dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}" ); } } ================================================ FILE: crates/burn-nn/src/modules/transformer/encoder.rs ================================================ use burn_core as burn; use alloc::vec::Vec; use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ Dropout, DropoutConfig, LayerNorm, LayerNormConfig, activation::ActivationConfig, attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, cache::TensorCache, }; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::{Bool, Tensor, backend::Backend}; /// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init). #[derive(Config, Debug)] pub struct TransformerEncoderConfig { /// The size of the model. pub d_model: usize, /// The size of the position-wise feed-forward network. pub d_ff: usize, /// The number of attention heads. pub n_heads: usize, /// The number of layers. pub n_layers: usize, /// The dropout rate. Default: 0.1 #[config(default = 0.1)] pub dropout: f64, /// Layer norm will be applied first instead of after the other modules. #[config(default = false)] pub norm_first: bool, /// Use "quiet softmax" instead of regular softmax. /// /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head). /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression. /// /// Reference: #[config(default = false)] pub quiet_softmax: bool, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, /// The activation function used in the position-wise feed-forward network. Default: Gelu #[config(default = "ActivationConfig::Gelu")] pub activation: ActivationConfig, /// The epsilon value for layer normalization. Default: 1e-5 #[config(default = 1e-5)] pub layer_norm_eps: f64, } /// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). /// /// # Params /// /// - layers: transformer encoder layers with `d_model` input and output features. /// /// Should be created using [TransformerEncoderConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct TransformerEncoder { /// The transformer encoder layers. pub layers: Vec>, /// The size of the model. pub d_model: usize, /// The size of the position-wise feed-forward network. pub d_ff: usize, /// The number of attention heads. pub n_heads: usize, /// The number of layers. pub n_layers: usize, /// The dropout rate. Default: 0.1 pub dropout: f64, /// Layer norm will be applied first instead of after the other modules. pub norm_first: bool, /// Use "quiet softmax" instead of regular softmax. pub quiet_softmax: bool, } impl ModuleDisplay for TransformerEncoder { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("d_model", &self.d_model) .add("d_ff", &self.d_ff) .add("n_heads", &self.n_heads) .add("n_layers", &self.n_layers) .add("dropout", &self.dropout) .add("norm_first", &self.norm_first) .add("quiet_softmax", &self.quiet_softmax) .optional() } } /// [Transformer Encoder](TransformerEncoder) forward pass input argument. #[derive(Debug)] pub struct TransformerEncoderInput { tensor: Tensor, mask_pad: Option>, mask_attn: Option>, } impl TransformerEncoderInput { /// Create a [transformer encoder](TransformerEncoder) input argument. pub fn new(tensor: Tensor) -> Self { Self { tensor, mask_pad: None, mask_attn: None, } } /// Register the padding mask. pub fn mask_pad(mut self, mask_pad: Tensor) -> Self { self.mask_pad = Some(mask_pad); self } /// Register the attention mask. pub fn mask_attn(mut self, mask_attn: Tensor) -> Self { self.mask_attn = Some(mask_attn); self } } impl TransformerEncoderConfig { /// Initialize a new [transformer encoder](TransformerEncoder) module. pub fn init(&self, device: &B::Device) -> TransformerEncoder { let layers = (0..self.n_layers) .map(|_| TransformerEncoderLayer::new(self, device)) .collect::>(); TransformerEncoder { layers, d_model: self.d_model, d_ff: self.d_ff, n_heads: self.n_heads, n_layers: self.n_layers, dropout: self.dropout, norm_first: self.norm_first, quiet_softmax: self.quiet_softmax, } } } impl TransformerEncoder { /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - tensor: `[batch_size, seq_length, d_model]` /// - output: `[batch_size, seq_length, d_model]` pub fn forward(&self, input: TransformerEncoderInput) -> Tensor { let mut x = input.tensor; for layer in self.layers.iter() { x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone()); } x } /// Applies the forward pass on the input tensor using autoregressive cache. /// /// # Shapes /// /// - tensor: `[batch_size, seq_length, d_model]` /// - output: `[batch_size, seq_length, d_model]` pub fn forward_autoregressive_inference( &self, input: TransformerEncoderInput, cache: &mut TransformerEncoderAutoregressiveCache, ) -> Tensor { let mut x = input.tensor; for i in 0..self.layers.len() { let layer = self.layers.get(i).unwrap(); let cache = cache.layers.get_mut(i).unwrap(); x = layer.forward_autoregressive_inference( x, input.mask_pad.clone(), input.mask_attn.clone(), cache, ); } x } /// Create an empty autoregressive cache. pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache { TransformerEncoderAutoregressiveCache::empty(self.layers.len()) } } /// Transformer encoder layer module. #[derive(Module, Debug)] pub struct TransformerEncoderLayer { /// Multi-head self-attention sub-layer. pub mha: MultiHeadAttention, /// Position-wise feed-forward sub-layer. pub pwff: PositionWiseFeedForward, /// Layer normalization applied around the feed-forward sub-layer. pub norm_1: LayerNorm, /// Layer normalization applied around the attention sub-layer. pub norm_2: LayerNorm, /// Dropout module applied to residual connections. pub dropout: Dropout, /// If `true`, apply layer normalization before sub-layers (pre-norm), /// otherwise apply it after (post-norm). pub norm_first: bool, } impl TransformerEncoderLayer { /// Create a new transformer encoder layer from the given configuration. pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self { let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) .with_quiet_softmax(config.quiet_softmax) .init(device); let norm_1 = LayerNormConfig::new(config.d_model) .with_epsilon(config.layer_norm_eps) .init(device); let norm_2 = LayerNormConfig::new(config.d_model) .with_epsilon(config.layer_norm_eps) .init(device); let dropout = DropoutConfig::new(config.dropout).init(); let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff) .with_initializer(config.initializer.clone()) .with_dropout(config.dropout) .with_activation(config.activation.clone()) .init(device); Self { mha, norm_1, norm_2, pwff, dropout, norm_first: config.norm_first, } } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[batch_size, seq_length, d_model]` /// - output: `[batch_size, seq_length, d_model]` pub fn forward( &self, input: Tensor, mask_pad: Option>, mask_attn: Option>, ) -> Tensor { // Multi-head attention residual path. let x = input; let mut residual_path = x.clone(); // Normalize. if self.norm_first { residual_path = self.norm_2.forward(residual_path) } // Multi-head attention. let mut input_mhs = MhaInput::self_attn(residual_path); if let Some(mask_pad) = mask_pad { input_mhs = input_mhs.mask_pad(mask_pad); } if let Some(mask_attn) = mask_attn { input_mhs = input_mhs.mask_attn(mask_attn); } let residual_path = self.mha.forward(input_mhs).context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Feed forward residual path. // Normalize. let residual_path = if self.norm_first { self.norm_1.forward(x.clone()) } else { x = self.norm_1.forward(x); x.clone() }; // Feed forward. let residual_path = self.pwff.forward(residual_path); let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Main path. // Normalize. if !self.norm_first { x = self.norm_2.forward(x) } x } /// Applies the forward pass using an autoregressive cache. pub fn forward_autoregressive_inference( &self, input: Tensor, mask_pad: Option>, mask_attn: Option>, cache: &mut TransformerEncoderLayerAutoregressiveCache, ) -> Tensor { // Multi-head attention residual path. let x = input; let mut residual_path = x.clone(); // Normalize. if self.norm_first { residual_path = cache .norm_2 .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x)) } // Multi-head attention. let mut input_mhs = MhaInput::self_attn(residual_path); if let Some(mask_pad) = mask_pad { input_mhs = input_mhs.mask_pad(mask_pad); } if let Some(mask_attn) = mask_attn { input_mhs = input_mhs.mask_attn(mask_attn); } let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context; let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Feed forward residual path. // Normalize. let residual_path = if self.norm_first { cache .norm_1 .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x)) } else { x = cache .norm_1 .forward_autoregressive(x, 1, |x| self.norm_1.forward(x)); x.clone() }; // Feed forward. let residual_path = cache .pwff .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x)); let residual_path = self.dropout.forward(residual_path); let mut x = x + residual_path; // Main path. // Normalize. if !self.norm_first { x = cache .norm_2 .forward_autoregressive(x, 1, |x| self.norm_2.forward(x)) } x } } /// Autoregressive cache for a single [Transformer Encoder Layer](TransformerEncoderLayer). pub struct TransformerEncoderLayerAutoregressiveCache { /// Multi-head attention cache. pub mha: MhaCache, /// Position-wise feed-forward cache. pub pwff: TensorCache, /// First layer norm cache. pub norm_1: TensorCache, /// Second layer norm cache. pub norm_2: TensorCache, } impl TransformerEncoderLayerAutoregressiveCache { /// Create an empty cache. pub fn empty() -> Self { Self { mha: MhaCache::autoregressive(), pwff: TensorCache::empty(), norm_1: TensorCache::empty(), norm_2: TensorCache::empty(), } } } /// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer. /// /// To be used during inference when decoding tokens. pub struct TransformerEncoderAutoregressiveCache { layers: Vec>, } impl TransformerEncoderAutoregressiveCache { fn empty(num_layers: usize) -> Self { Self { layers: (0..num_layers) .map(|_| TransformerEncoderLayerAutoregressiveCache::empty()) .collect(), } } } #[cfg(test)] mod tests { use super::*; use crate::{TestBackend, attention::generate_autoregressive_mask}; use burn::tensor::Distribution; use burn::tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn test_autoregressive_norm_last() { let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; test_autoregressive( TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers) .with_norm_first(false), ) } #[test] fn test_autoregressive_norm_first() { let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3]; test_autoregressive( TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true), ) } fn test_autoregressive(config: TransformerEncoderConfig) { let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; let device = Default::default(); let transformer = config.init(&device); let tensor = Tensor::::random( [batch_size, seq_length, d_model], Distribution::Default, &device, ); let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device()); let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn); let output_1 = transformer.forward(input); let mut output_2 = Vec::new(); let mut cache = transformer.new_autoregressive_cache(); for i in 1..seq_length + 1 { let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]); let input = TransformerEncoderInput::new(tensor.clone()); let next_tok = transformer .forward_autoregressive_inference(input, &mut cache) .slice([0..batch_size, i - 1..i, 0..d_model]); output_2.push(next_tok); } let output_2 = Tensor::cat(output_2, 1); output_1 .into_data() .assert_approx_eq::(&output_2.into_data(), Tolerance::permissive()); } #[test] fn display() { let config = TransformerEncoderConfig::new(2, 4, 2, 3); let transformer = config.init::(&Default::default()); assert_eq!( alloc::format!("{transformer}"), "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \ n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}" ); } } ================================================ FILE: crates/burn-nn/src/modules/transformer/mod.rs ================================================ mod decoder; mod encoder; mod pwff; pub use decoder::*; pub use encoder::*; pub use pwff::*; ================================================ FILE: crates/burn-nn/src/modules/transformer/pwff.rs ================================================ use burn_core as burn; use crate::activation::{Activation, ActivationConfig}; use crate::{Dropout, DropoutConfig, Linear, LinearConfig}; use burn::config::Config; use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay}; use burn::tensor::{Tensor, backend::Backend}; /// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init). #[derive(Config, Debug)] pub struct PositionWiseFeedForwardConfig { /// The size of the input and output features. pub d_model: usize, /// The size of the hidden inner features. pub d_ff: usize, /// The dropout rate. Default: 0.1 #[config(default = 0.1)] pub dropout: f64, /// The type of function used to initialize neural network parameters #[config( default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}" )] pub initializer: Initializer, /// The activation function used between the two linear layers. Default: Gelu #[config(default = "ActivationConfig::Gelu")] pub activation: ActivationConfig, } /// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7). /// /// # Params /// /// - linear inner: Linear layer with `d_model` input features and `d_ff` output features. /// - linear outer: Linear layer with `d_ff` input features and `d_model` output features. /// /// `FFN(x) = max(0, xW1 + b1)W2 + b2` /// /// Should be created using [PositionWiseFeedForwardConfig] #[derive(Module, Debug)] #[module(custom_display)] pub struct PositionWiseFeedForward { /// Linear layer with `d_model` input features and `d_ff` output features. pub linear_inner: Linear, /// Linear layer with `d_ff` input features and `d_model` output features. pub linear_outer: Linear, /// Dropout layer. pub dropout: Dropout, /// Activation function. pub activation: Activation, } impl ModuleDisplay for PositionWiseFeedForward { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let [d_model, dff] = self.linear_inner.weight.shape().dims(); content .add("d_model", &d_model) .add("d_ff", &dff) .add("prob", &self.dropout.prob) .optional() } } impl PositionWiseFeedForwardConfig { /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. pub fn init(&self, device: &B::Device) -> PositionWiseFeedForward { PositionWiseFeedForward { linear_inner: LinearConfig::new(self.d_model, self.d_ff) .with_initializer(self.initializer.clone()) .init(device), linear_outer: LinearConfig::new(self.d_ff, self.d_model) .with_initializer(self.initializer.clone()) .init(device), dropout: DropoutConfig::new(self.dropout).init(), activation: self.activation.init(device), } } } impl PositionWiseFeedForward { /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - tensor: `[batch_size, seq_length, d_model]` /// - output: `[batch_size, seq_length, d_model]` pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input); let x = self.activation.forward(x); let x = self.dropout.forward(x); self.linear_outer.forward(x) } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn display() { let config = PositionWiseFeedForwardConfig::new(2, 4); let pwff = config.init::(&Default::default()); assert_eq!( alloc::format!("{pwff}"), "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}" ); } } ================================================ FILE: crates/burn-nn/src/modules/unfold.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn::tensor::module::unfold4d; use burn::tensor::ops::UnfoldOptions; /// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init). #[derive(Config, Debug)] pub struct Unfold4dConfig { /// The size of the kernel. pub kernel_size: [usize; 2], /// The stride of the convolution. #[config(default = "[1, 1]")] pub stride: [usize; 2], /// Spacing between kernel elements. #[config(default = "[1, 1]")] pub dilation: [usize; 2], /// The padding configuration. #[config(default = "[0, 0]")] pub padding: [usize; 2], } /// Four-dimensional unfolding. /// /// Should be created with [Unfold4dConfig]. #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Unfold4d { /// The size of the kernel. pub kernel_size: [usize; 2], /// The stride of the convolution. pub stride: [usize; 2], /// Spacing between kernel elements. pub dilation: [usize; 2], /// The padding configuration. pub padding: [usize; 2], } impl ModuleDisplay for Unfold4d { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) .add("stride", &alloc::format!("{:?}", &self.stride)) .add("dilation", &alloc::format!("{:?}", &self.dilation)) .add("padding", &alloc::format!("{:?}", &self.padding)) .optional() } } impl Unfold4dConfig { /// Initializes a new [Unfold4d] module. pub fn init(&self) -> Unfold4d { Unfold4d { kernel_size: self.kernel_size, stride: self.stride, dilation: self.dilation, padding: self.padding, } } } impl Unfold4d { /// Applies the forward pass on the input tensor. /// /// See [unfold4d](burn::tensor::module::unfold4d) for more information. /// /// # Shapes /// /// input: `[batch_size, channels_in, height, width]` /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]` pub fn forward(&self, input: Tensor) -> Tensor { unfold4d( input, self.kernel_size, UnfoldOptions::new(self.stride, self.padding, self.dilation), ) } } #[cfg(test)] mod tests { use super::*; #[test] fn display() { let config = Unfold4dConfig::new([3, 3]); let unfold = config.init(); assert_eq!( alloc::format!("{unfold}"), "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}" ); } } ================================================ FILE: crates/burn-nn/src/padding.rs ================================================ use burn_core as burn; use burn::config::Config; /// Calculate asymmetric padding for "same" convolution. /// Returns (start_padding, end_padding) where start is applied first (top/left). /// For odd total padding, the extra pad goes to the end (bottom/right) following ONNX convention. fn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) { let size_out = size_in.div_ceil(stride); // ceil division for same padding let total_padding = if size_out > 0 { let needed = (size_out - 1) * stride + kernel_size; needed.saturating_sub(size_in) } else { 0 }; let pad_start = total_padding / 2; let pad_end = total_padding - pad_start; (pad_start, pad_end) } /// Padding configuration for 1D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig1d { /// Dynamically calculates padding to ensure output size matches input size. Same, /// No padding applied. Valid, /// Applies explicit padding values. /// Format: (left, right) /// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`). Explicit(usize, usize), } impl PaddingConfig1d { /// Calculate padding as (left, right) pair for 1D operations. /// For `Same` padding, this computes the actual asymmetric padding if needed. pub(crate) fn calculate_padding_1d_pair( &self, length: usize, kernel_size: usize, stride: usize, ) -> (usize, usize) { match self { Self::Valid => (0, 0), Self::Same => calculate_same_padding(kernel_size, stride, length), Self::Explicit(left, right) => (*left, *right), } } } /// Padding configuration for 2D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig2d { /// Dynamically calculates padding to preserve input dimensions in output. Same, /// No padding applied. Valid, /// Applies explicit padding values. /// Format: (top, left, bottom, right) /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`). Explicit(usize, usize, usize, usize), } impl PaddingConfig2d { /// Calculate padding as ((top, bottom), (left, right)) pairs for 2D operations. /// For `Same` padding, this computes the actual asymmetric padding if needed. pub(crate) fn calculate_padding_2d_pairs( &self, height: usize, width: usize, kernel_size: &[usize; 2], stride: &[usize; 2], ) -> ((usize, usize), (usize, usize)) { match self { Self::Valid => ((0, 0), (0, 0)), Self::Same => { let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height); let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width); ((top, bottom), (left, right)) } Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)), } } /// Calculate symmetric padding for 2D operations. /// Returns padding values [height, width] (same for both sides). /// Panics if asymmetric padding is detected. pub(crate) fn calculate_padding_2d( &self, height: usize, width: usize, kernel_size: &[usize; 2], stride: &[usize; 2], ) -> [usize; 2] { let ((top, bottom), (left, right)) = self.calculate_padding_2d_pairs(height, width, kernel_size, stride); if top != bottom || left != right { panic!("Asymmetric padding should be handled via calculate_padding_2d_pairs()") } [top, left] } } /// Padding configuration for 3D operators. #[derive(Config, Debug, PartialEq)] pub enum PaddingConfig3d { /// Dynamically calculates padding to preserve input dimensions in output. Same, /// No padding applied. Valid, /// Applies explicit symmetric padding values. /// Format: (depth, height, width) — same padding on both sides of each dimension. Explicit(usize, usize, usize), } impl PaddingConfig3d { /// Calculate symmetric padding for 3D operations. /// Returns padding values [depth, height, width] (same for both sides). pub(crate) fn calculate_padding_3d( &self, depth: usize, height: usize, width: usize, kernel_size: &[usize; 3], stride: &[usize; 3], ) -> [usize; 3] { match self { Self::Valid => [0, 0, 0], Self::Same => { let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth); let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height); let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width); if front != back || top != bottom || left != right { panic!( "Asymmetric 3D 'Same' padding is not supported. \ Use odd kernel sizes for symmetric padding." ) } [front, top, left] } Self::Explicit(depth, height, width) => [*depth, *height, *width], } } } #[cfg(test)] mod tests { use super::*; // ==================== PaddingConfig1d Tests ==================== #[test] fn test_padding_config_1d_calculate_pair_valid() { let padding = PaddingConfig1d::Valid; assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (0, 0)); } #[test] fn test_padding_config_1d_calculate_pair_explicit() { let padding = PaddingConfig1d::Explicit(1, 2); assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 2)); } #[test] fn test_padding_config_1d_calculate_pair_same() { let padding = PaddingConfig1d::Same; // kernel=3, stride=1, length=10: total=2, start=1, end=1 assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 1)); } // ==================== PaddingConfig2d Tests ==================== #[test] fn test_padding_config_2d_calculate_pairs_valid() { let padding = PaddingConfig2d::Valid; assert_eq!( padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]), ((0, 0), (0, 0)) ); } #[test] fn test_padding_config_2d_calculate_pairs_explicit() { let padding = PaddingConfig2d::Explicit(1, 2, 3, 4); assert_eq!( padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]), ((1, 3), (2, 4)) ); } #[test] fn test_padding_config_2d_calculate_symmetric_valid() { let padding = PaddingConfig2d::Valid; assert_eq!( padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]), [0, 0] ); } #[test] fn test_padding_config_2d_calculate_symmetric_explicit() { let padding = PaddingConfig2d::Explicit(2, 3, 2, 3); assert_eq!( padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]), [2, 3] ); } #[test] #[should_panic( expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs" )] fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() { let padding = PaddingConfig2d::Explicit(1, 2, 3, 4); let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]); } // ==================== PaddingConfig3d Tests ==================== #[test] fn test_padding_config_3d_calculate_valid() { let padding = PaddingConfig3d::Valid; assert_eq!( padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]), [0, 0, 0] ); } #[test] fn test_padding_config_3d_calculate_explicit() { let padding = PaddingConfig3d::Explicit(1, 2, 3); assert_eq!( padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]), [1, 2, 3] ); } #[test] fn test_padding_config_3d_calculate_same_odd_kernel() { let padding = PaddingConfig3d::Same; // kernel=3, stride=1: total=2, symmetric (1,1) per dim assert_eq!( padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]), [1, 1, 1] ); } } ================================================ FILE: crates/burn-nn/tests/quantize.rs ================================================ use burn_core as burn; use burn::module::{Module, Quantizer}; use burn::tensor::{ Device, Distribution, Tensor, Tolerance, ops::{FloatElem, QuantizedTensor}, quantization::{ Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantScheme, QuantValue, }, }; use burn_nn::{ Linear, LinearConfig, transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, }; #[cfg(all( test, not(feature = "test-wgpu"), not(feature = "test-cuda"), not(feature = "test-rocm") ))] pub type B = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-wgpu"))] /// Backend for test cases pub type B = burn_wgpu::Wgpu; #[cfg(all(test, feature = "test-cuda"))] /// Backend for test cases pub type B = burn_cuda::Cuda; #[cfg(all(test, feature = "test-rocm"))] /// Backend for test cases pub type B = burn_rocm::Rocm; fn should_quantize_module, const D: usize, F: Fn(&M) -> Tensor>( module: M, scheme: QuantScheme, func: F, tolerance: Tolerance>, ) { let result = func(&module); let calibration = Calibration::MinMax; let mut quantizer = Quantizer { calibration, scheme, }; let q_module = module.quantize_weights(&mut quantizer); let q_result = func(&q_module); result .into_data() .assert_approx_eq::(&q_result.into_data(), tolerance); } #[test] fn should_quantize_transformer() { let device: Device = Default::default(); let transformer: TransformerEncoder = TransformerEncoderConfig::new(128, 256, 2, 2).init(&device); let signal = Tensor::random([2, 32, 128], Distribution::Default, &device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([32])) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.forward(TransformerEncoderInput::new(signal.clone())), Tolerance::rel_abs(1e-2, 2e-2), // slightly higher abs tolerance (permissive: 1e-2) ); } #[test] fn should_quantize_linear_128_256() { let device: Device = Default::default(); let transformer: Linear = LinearConfig::new(128, 256).with_bias(false).init(&device); let signal = Tensor::::random([1, 128], Distribution::Default, &device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::Tensor) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.forward(signal.clone()), Tolerance::permissive(), ); } #[test] fn should_quantize_linear() { let device: Device = Default::default(); let transformer: Linear = LinearConfig::new(32, 32).with_bias(false).init(&device); let signal = Tensor::::random([1, 32], Distribution::Default, &device); // Default scheme should select supported QuantStore default // TODO: set native if dtype is supported by the test backend let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::Tensor) // .with_store(QuantStore::Native) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.forward(signal.clone()), Tolerance::permissive(), ); } #[test] fn should_quantize_linear_weights() { let device: Device = Default::default(); let transformer: Linear = LinearConfig::new(32, 32).with_bias(false).init(&device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::Tensor) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.weight.val().dequantize(), Tolerance::permissive(), ); } #[test] fn should_quantize_linear_blocks() { let device: Device = Default::default(); let transformer: Linear = LinearConfig::new(32, 32).with_bias(false).init(&device); let signal = Tensor::::random([1, 32], Distribution::Default, &device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([16])) // .with_store(QuantStore::Native) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.forward(signal.clone()), Tolerance::permissive(), ); } #[test] fn should_quantize_linear_weights_blocks() { let device: Device = Default::default(); let transformer: Linear = LinearConfig::new(32, 32).with_bias(false).init(&device); let scheme = as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([16])) // .with_store(QuantStore::Native) .with_param(QuantParam::F32); should_quantize_module( transformer, scheme, |tr| tr.weight.val().dequantize(), Tolerance::permissive(), ); } ================================================ FILE: crates/burn-no-std-tests/Cargo.toml ================================================ [package] authors = [ "nathanielsimard ", "Dilshod Tadjibaev (@antimora)", ] edition.workspace = true license.workspace = true name = "burn-no-std-tests" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-no-std-tests" version.workspace = true [lints] workspace = true [features] default = [] tracing = [ "burn/tracing", "burn-ndarray/tracing", "burn-store/tracing", ] [dependencies] # ** Please make sure all dependencies support no_std ** burn = { path = "../burn", version = "=0.21.0-pre.2", default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", default-features = false } burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", default-features = false, features = ["safetensors", "burnpack"]} ================================================ FILE: crates/burn-no-std-tests/README.md ================================================ The `burn-no-std-tests` contains integration tests aimed to check `no_std` compatibility of `burn`, `burn-core`, `burn-tensor` and `burn-ndarray` packages. Currently there is only a minimal test that checks if mnist model can be built with `no_std`. More tests should be added to check completeness. The continuous integration (CI) should build with additional targets: * `wasm32-unknown-unknown` - WebAssembly * `thumbv7m-none-eabi` - ARM Cortex-M3 * `thumbv6m-none-eabi` - ARM Cortex-M0+ Shell commands to build and test the package: ```sh # install the new targets if not installed previously rustup target add thumbv6m-none-eabi rustup target add thumbv7m-none-eabi rustup target add wasm32-unknown-unknown # build for various targets cargo build # regular build cargo build --target thumbv7m-none-eabi cargo build --target wasm32-unknown-unknown RUSTFLAGS="--cfg portable_atomic_unsafe_assume_single_core" cargo build --target thumbv6m-none-eabi # test cargo test ``` ================================================ FILE: crates/burn-no-std-tests/src/burnpack.rs ================================================ // Test Burnpack storage in no-std environment use burn::{ module::Module, nn, tensor::{Tensor, backend::Backend}, }; use burn_store::{BurnpackStore, ModuleSnapshot, PathFilter}; /// Simple model for testing Burnpack storage #[derive(Module, Debug)] pub struct TestModel { linear1: nn::Linear, linear2: nn::Linear, batch_norm: nn::BatchNorm, } impl TestModel { pub fn new(device: &B::Device) -> Self { Self { linear1: nn::LinearConfig::new(10, 20).init(device), linear2: nn::LinearConfig::new(20, 10).init(device), batch_norm: nn::BatchNormConfig::new(10).init(device), } } pub fn forward(&self, x: Tensor) -> Tensor { let x = self.linear1.forward(x); let x = self.linear2.forward(x); // Apply batch norm (expand to 3D, apply, then squeeze back) let x: Tensor = x.unsqueeze_dim(2); let x = self.batch_norm.forward(x); x.squeeze_dim(2) } } /// Test basic Burnpack save and load in no-std pub fn test_burnpack_basic(device: &B::Device) { // Create a model let model = TestModel::::new(device); // Save to bytes (no file I/O in no-std) let mut save_store = BurnpackStore::from_bytes(None); model .save_into(&mut save_store) .expect("Failed to save model"); // Get the serialized bytes let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load from bytes let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); let result = loaded_model .load_from(&mut load_store) .expect("Failed to load model"); // Verify all tensors were loaded assert!(result.is_success(), "Should have no errors"); assert!(!result.applied.is_empty(), "Should have loaded tensors"); // Test that the model still works let input = Tensor::::ones([2, 10], device); let _output = loaded_model.forward(input); } /// Test Burnpack with filtering in no-std pub fn test_burnpack_filtering(device: &B::Device) { let model = TestModel::::new(device); // Save only linear1 weights let filter = PathFilter::new() .with_full_path("linear1.weight") .with_full_path("linear1.bias"); let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); model .save_into(&mut save_store) .expect("Failed to save filtered model"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load with partial loading allowed let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); let mut partial_model = TestModel::::new(device); let result = partial_model .load_from(&mut load_store) .expect("Failed to load partial model"); // Verify that only linear1 was loaded assert_eq!(result.applied.len(), 2, "Should have loaded 2 tensors"); assert!(!result.missing.is_empty(), "Should have missing tensors"); } /// Test Burnpack with metadata in no-std pub fn test_burnpack_metadata(device: &B::Device) { let model = TestModel::::new(device); // Save with metadata let mut save_store = BurnpackStore::from_bytes(None) .metadata("version", "1.0.0") .metadata("environment", "no-std") .metadata("model_type", "test"); model .save_into(&mut save_store) .expect("Failed to save model with metadata"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify it works let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); let result = loaded_model .load_from(&mut load_store) .expect("Failed to load model with metadata"); assert!(result.is_success(), "Should load successfully"); } // Note: Key remapping test is omitted as KeyRemapper requires std feature // Note: Regex filtering test is omitted as with_regex requires std feature /// Test Burnpack with match_all in no-std pub fn test_burnpack_match_all(device: &B::Device) { let model = TestModel::::new(device); // Save with match_all (should save everything) let mut save_store = BurnpackStore::from_bytes(None).match_all(); model .save_into(&mut save_store) .expect("Failed to save model"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load everything let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); let result = loaded_model .load_from(&mut load_store) .expect("Failed to load model"); assert!(result.is_success(), "Should load successfully"); // linear1 (weight, bias) + linear2 (weight, bias) + batch_norm (4 params) assert_eq!(result.applied.len(), 8, "Should load all 8 tensors"); assert!(result.missing.is_empty(), "Should have no missing tensors"); assert!(result.unused.is_empty(), "Should have no unused tensors"); } /// Run all Burnpack no-std tests pub fn run_all_tests(device: &B::Device) { test_burnpack_basic::(device); test_burnpack_filtering::(device); test_burnpack_metadata::(device); // test_burnpack_remapping requires KeyRemapper which needs std // test_burnpack_regex_filter requires with_regex which needs std test_burnpack_match_all::(device); } ================================================ FILE: crates/burn-no-std-tests/src/conv.rs ================================================ // Originally copied from the burn/examples/mnist package use burn::{ config::Config, module::Module, nn, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct ConvBlock { conv: nn::conv::Conv2d, pool: nn::pool::MaxPool2d, activation: nn::Gelu, } #[derive(Config, Debug)] pub struct ConvBlockConfig { channels: [usize; 2], #[config(default = "[3, 3]")] kernel_size: [usize; 2], } impl ConvBlock { pub fn new(config: &ConvBlockConfig, device: &B::Device) -> Self { let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size) .with_padding(nn::PaddingConfig2d::Same) .init(device); let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size) .with_strides([1, 1]) .with_padding(nn::PaddingConfig2d::Same) .init(); let activation = nn::Gelu::new(); Self { conv, pool, activation, } } pub fn forward(&self, input: Tensor) -> Tensor { let x = self.conv.forward(input.clone()); let x = self.pool.forward(x); let x = self.activation.forward(x); (x + input) / 2.0 } } ================================================ FILE: crates/burn-no-std-tests/src/lib.rs ================================================ #![no_std] pub mod burnpack; pub mod conv; pub mod mlp; pub mod model; pub mod safetensors; extern crate alloc; ================================================ FILE: crates/burn-no-std-tests/src/mlp.rs ================================================ // Originally copied from the burn/examples/mnist package use alloc::vec::Vec; use burn::{ config::Config, module::Module, nn, tensor::{Tensor, backend::Backend}, }; /// Configuration to create a [Multilayer Perceptron](Mlp) layer. #[derive(Config, Debug)] pub struct MlpConfig { /// The number of layers. #[config(default = 3)] pub num_layers: usize, /// The dropout rate. #[config(default = 0.5)] pub dropout: f64, /// The size of each layer. #[config(default = 256)] pub d_model: usize, } /// Multilayer Perceptron module. #[derive(Module, Debug)] pub struct Mlp { linears: Vec>, dropout: nn::Dropout, activation: nn::Relu, } impl Mlp { /// Create the module from the given configuration. pub fn new(config: &MlpConfig, device: &B::Device) -> Self { let mut linears = Vec::with_capacity(config.num_layers); for _ in 0..config.num_layers { linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init(device)); } Self { linears, dropout: nn::DropoutConfig::new(0.3).init(), activation: nn::Relu::new(), } } /// Applies the forward pass on the input tensor. /// /// # Shapes /// /// - input: `[batch_size, d_model]` /// - output: `[batch_size, d_model]` pub fn forward(&self, input: Tensor) -> Tensor { let mut x = input; for linear in self.linears.iter() { x = linear.forward(x); x = self.dropout.forward(x); x = self.activation.forward(x); } x } } ================================================ FILE: crates/burn-no-std-tests/src/model.rs ================================================ // Originally copied from the burn/examples/mnist package use crate::{ conv::{ConvBlock, ConvBlockConfig}, mlp::{Mlp, MlpConfig}, }; use burn::{ config::Config, module::Module, nn, tensor::{Tensor, backend::Backend}, }; #[derive(Config, Debug)] pub struct MnistConfig { #[config(default = 42)] pub seed: u64, pub mlp: MlpConfig, #[config(default = 784)] pub input_size: usize, #[config(default = 10)] pub output_size: usize, } #[derive(Module, Debug)] pub struct Model { mlp: Mlp, conv: ConvBlock, input: nn::Linear, output: nn::Linear, num_classes: usize, } impl Model { pub fn new(config: &MnistConfig, device: &B::Device) -> Self { let mlp = Mlp::new(&config.mlp, device); let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(device); let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(device); let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1]), device); Self { mlp, conv, output, input, num_classes: config.output_size, } } pub fn forward(&self, input: Tensor) -> Tensor { let [batch_size, height, width] = input.dims(); let x = input.reshape([batch_size, 1, height, width]).detach(); let x = self.conv.forward(x); let x = x.reshape([batch_size, height * width]); let x = self.input.forward(x); let x = self.mlp.forward(x); self.output.forward(x) } } ================================================ FILE: crates/burn-no-std-tests/src/safetensors.rs ================================================ // Test SafeTensors storage in no-std environment use burn::{ module::Module, nn, tensor::{Tensor, backend::Backend}, }; use burn_store::{ModuleSnapshot, SafetensorsStore}; /// Simple model for testing SafeTensors storage #[derive(Module, Debug)] pub struct TestModel { linear1: nn::Linear, linear2: nn::Linear, } impl TestModel { pub fn new(device: &B::Device) -> Self { Self { linear1: nn::LinearConfig::new(10, 20).init(device), linear2: nn::LinearConfig::new(20, 10).init(device), } } pub fn forward(&self, x: Tensor) -> Tensor { let x = self.linear1.forward(x); self.linear2.forward(x) } } /// Test basic SafeTensors save and load in no-std pub fn test_safetensors_basic(device: &B::Device) { // Create a model let model = TestModel::::new(device); // Save to bytes (no file I/O in no-std) let mut save_store = SafetensorsStore::from_bytes(None); model .save_into(&mut save_store) .expect("Failed to save model"); // Get the serialized bytes let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load from bytes let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); loaded_model .load_from(&mut load_store) .expect("Failed to load model"); // Test that the model still works let input = Tensor::::ones([2, 10], device); let _output = loaded_model.forward(input); } /// Test SafeTensors with filtering in no-std pub fn test_safetensors_filtering(device: &B::Device) { let model = TestModel::::new(device); // Save only linear1 weights let mut save_store = SafetensorsStore::from_bytes(None) .with_full_path("linear1.weight") .with_full_path("linear1.bias"); model .save_into(&mut save_store) .expect("Failed to save filtered model"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load with partial loading allowed let mut load_store = SafetensorsStore::from_bytes(Some(bytes)).allow_partial(true); let mut partial_model = TestModel::::new(device); let result = partial_model .load_from(&mut load_store) .expect("Failed to load partial model"); // Verify that only linear1 was loaded assert_eq!(result.applied.len(), 2, "Should have loaded 2 tensors"); assert!(!result.missing.is_empty(), "Should have missing tensors"); } /// Test SafeTensors with metadata in no-std pub fn test_safetensors_metadata(device: &B::Device) { let model = TestModel::::new(device); // Save with metadata let mut save_store = SafetensorsStore::from_bytes(None) .metadata("version", "1.0.0") .metadata("environment", "no-std"); model .save_into(&mut save_store) .expect("Failed to save model with metadata"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify it works let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = TestModel::::new(device); loaded_model .load_from(&mut load_store) .expect("Failed to load model with metadata"); } /// Run all SafeTensors no-std tests pub fn run_all_tests(device: &B::Device) { test_safetensors_basic::(device); test_safetensors_filtering::(device); test_safetensors_metadata::(device); } ================================================ FILE: crates/burn-no-std-tests/tests/burnpack_tests.rs ================================================ extern crate alloc; #[test] fn test_burnpack_no_std() { use burn_ndarray::NdArray; use burn_no_std_tests::burnpack; type Backend = NdArray; let device = Default::default(); // Run all Burnpack tests burnpack::run_all_tests::(&device); } ================================================ FILE: crates/burn-no-std-tests/tests/safetensors_tests.rs ================================================ extern crate alloc; #[test] fn test_safetensors_no_std() { use burn_ndarray::NdArray; use burn_no_std_tests::safetensors; type Backend = NdArray; let device = Default::default(); // Run all SafeTensors tests safetensors::run_all_tests::(&device); } ================================================ FILE: crates/burn-no-std-tests/tests/test_integration.rs ================================================ #![no_std] // Must keep it for testing use burn_no_std_tests::mlp::*; use burn_no_std_tests::model::*; use burn::tensor::{Distribution, Tensor, backend::Backend}; use burn_ndarray::NdArray; #[test] fn test_mnist_model_with_random_input() { type Backend = NdArray; // Model configurations let device = Default::default(); let mlp_config = MlpConfig::new(); let mnist_config = MnistConfig::new(mlp_config); let mnist_model: Model = Model::new(&mnist_config, &device); // Pass a fixed seed for random, otherwise a build generated random seed is used Backend::seed(&device, mnist_config.seed); // Some random input let input_shape = [1, 28, 28]; let input = Tensor::::random(input_shape, Distribution::Default, &device); // Run through the model let output = mnist_model.forward(input); assert_eq!(&*output.shape(), [1, 10]); assert!(output.to_data().iter::().all(|x| x <= 1.0)); } ================================================ FILE: crates/burn-optim/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Optimizer building blocks for the Burn deep learning framework" documentation = "https://docs.rs/burn-optim" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-optim" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-optim" version.workspace = true [lints] workspace = true [features] default = [ "std", "burn-core/default", ] doc = [ "std", # Doc features "burn-core/doc", ] std = [ "burn-core/std", "num-traits/std", "serde/std", "log", ] tracing = [ "burn-collective?/tracing", "burn-core/tracing", "burn-cuda?/tracing", "burn-fusion?/tracing", "burn-remote?/tracing", "burn-rocm?/tracing", "burn-router?/tracing", "burn-tch?/tracing", "burn-wgpu?/tracing", ] collective = ["burn-collective"] test-cuda = [ "burn-cuda/default", ] # To use cuda during testing, default uses ndarray. test-rocm = [ "burn-rocm/default", ] # To use hip during testing, default uses ndarray. test-tch = [ "burn-tch/default", ] # To use tch during testing, default uses ndarray. test-wgpu = [ "burn-wgpu/default", ] # To use wgpu during testing, default uses ndarray. test-vulkan = [ "test-wgpu", "burn-wgpu/vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. test-metal = [ "test-wgpu", "burn-wgpu/metal", ] # To use wgpu-spirv during testing, default uses ndarray. # Memory checks are disabled by default test-memory-checks = ["burn-fusion/memory-checks"] [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false } burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true, default-features = false } num-traits = { workspace = true } derive-new = { workspace = true } log = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } # The same implementation of HashMap in std but with no_std support (only alloc crate is needed) hashbrown = { workspace = true, features = ["serde"] } # no_std compatible # FOR TESTING burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } [dev-dependencies] burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2" } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" } rstest = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-optim/README.md ================================================ # Burn Optimizers Core building blocks for Burn optimizers. ================================================ FILE: crates/burn-optim/src/grad_clipping/base.rs ================================================ use burn_core as burn; use burn::tensor::backend::Backend; use burn::{config::Config, tensor::Tensor}; /// Gradient Clipping provides a way to mitigate exploding gradients #[derive(Config, Debug)] pub enum GradientClippingConfig { /// Clip the gradient by value. Value(f32), /// Clip the gradient by norm. Norm(f32), } impl GradientClippingConfig { /// Initialize the gradient clipping. /// /// # Returns /// /// The gradient clipping. pub fn init(&self) -> GradientClipping { match self { GradientClippingConfig::Value(val) => GradientClipping::Value(*val), GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val), } } } /// Gradient Clipping provides a way to mitigate exploding gradients /// by clipping every component of the gradient by value or by norm during /// backpropagation. #[derive(Clone)] pub enum GradientClipping { /// Clip the gradient by value. Value(f32), /// Clip the gradient by norm. Norm(f32), } impl GradientClipping { /// Clip the gradient. /// /// # Arguments /// /// * `grad` - The gradient to clip. /// /// # Returns /// /// The clipped gradient. pub fn clip_gradient(&self, grad: Tensor) -> Tensor { match self { GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm), } } fn clip_by_value( &self, grad: Tensor, threshold: f32, ) -> Tensor { let greater_mask = grad.clone().greater_elem(threshold); let lower_mask = grad.clone().lower_elem(-threshold); let clipped_grad = grad.mask_fill(greater_mask, threshold); clipped_grad.mask_fill(lower_mask, -threshold) } fn clip_by_norm( &self, grad: Tensor, threshold: f32, ) -> Tensor { let norm = Self::l2_norm(grad.clone()); let clip_coef = threshold / norm.add_scalar(1e-6); // avoid div by zero let clip_coef_clamped = clip_coef.clamp_max(1.0); grad.mul(clip_coef_clamped.unsqueeze()) } fn l2_norm(tensor: Tensor) -> Tensor { let squared = tensor.square(); let sum = squared.sum(); sum.sqrt() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; use burn::tensor::Tensor; #[test] fn test_clip_by_value() { let gradient: Tensor = Tensor::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &Default::default(), ); let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient); let clipped_gradient_data = clipped_gradient.into_data(); for value in clipped_gradient_data.iter::() { assert!(value <= 0.5); } } #[test] fn test_clip_by_norm() { let gradient: Tensor = Tensor::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &Default::default(), ); let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient); let clipped_gradient_data = clipped_gradient.into_data(); for value in clipped_gradient_data.iter::() { assert!(value <= 0.88); } } #[test] fn test_clip_by_norm_no_clipping() { let gradient: Tensor = Tensor::from_floats( [[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]], &Default::default(), ); let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient.clone()); clipped_gradient .into_data() .assert_eq(&gradient.into_data(), true); } } ================================================ FILE: crates/burn-optim/src/grad_clipping/mod.rs ================================================ mod base; pub use base::*; ================================================ FILE: crates/burn-optim/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![recursion_limit = "256"] //! Burn optimizers. #[macro_use] extern crate derive_new; extern crate alloc; /// Optimizer module. pub mod optim; pub use optim::*; /// Gradient clipping module. pub mod grad_clipping; /// Learning rate scheduler module. #[cfg(feature = "std")] pub mod lr_scheduler; /// Type alias for the learning rate. /// /// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it /// can be used for constant learning rate. pub type LearningRate = f64; // We could potentially change the type. /// Backend for test cases #[cfg(all( test, not(feature = "test-tch"), not(feature = "test-wgpu"), not(feature = "test-cuda"), not(feature = "test-rocm") ))] pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] /// Backend for test cases pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] /// Backend for test cases pub type TestBackend = burn_wgpu::Wgpu; #[cfg(all(test, feature = "test-cuda"))] /// Backend for test cases pub type TestBackend = burn_cuda::Cuda; #[cfg(all(test, feature = "test-rocm"))] /// Backend for test cases pub type TestBackend = burn_rocm::Rocm; /// Backend for autodiff test cases #[cfg(test)] pub type TestAutodiffBackend = burn_autodiff::Autodiff; #[cfg(all(test, feature = "test-memory-checks"))] mod tests { burn_fusion::memory_checks!(); } ================================================ FILE: crates/burn-optim/src/lr_scheduler/base.rs ================================================ pub(super) use alloc::string::String; use burn_core as burn; use burn::record::Record; use burn::tensor::backend::Backend; use crate::LearningRate; /// Learning rate scheduler defines how the learning rate will evolve during training. pub trait LrScheduler: Clone + Send + Sync { /// Scheduler associative type to be used when saving and loading the state. type Record: Record; /// Perform the scheduler step, potentially updating its state, and returning the effective /// learning rate. fn step(&mut self) -> LearningRate; /// Get the current state of the scheduler as a [record](Record). fn to_record(&self) -> Self::Record; /// Load the state of the scheduler as a [record](Record). fn load_record(self, record: Self::Record) -> Self; } #[cfg(test)] pub(super) mod test_utils { use super::*; use crate::TestBackend; // A small tolerance for learning rate comparisons. Depending on how learning rates are // computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is // used here. const LOOSE_EPSILON: LearningRate = 1e-10; pub fn check_lr_sequence(mut scheduler: S, expected_lrs: I) where I: IntoIterator, S: LrScheduler, { expected_lrs .into_iter() .enumerate() .for_each(|(i, expected)| { let lr = scheduler.step(); assert!( (lr - expected).abs() < LOOSE_EPSILON, "Scheduled learning rate {lr} is not approximately equal to the expected value \ {expected} at step {i}", ); }); } // save_at_step is the number of steps to run the scheduler before saving and loading back its // state. pub fn check_save_load(mut scheduler: S, save_at_step: usize) where S: Clone + LrScheduler, { let mut truth = scheduler.clone(); // Consume some steps before saving and loading back (0..save_at_step).for_each(|_| { truth.step(); scheduler.step(); }); let rec = scheduler.to_record::(); scheduler = scheduler.load_record::(rec); // Validate that the scheduler resumes from where it left off. compare_steps(&mut scheduler, &mut truth, save_at_step); } // Check if two schedulers produce the same learning rate sequences over the specified number of // steps. pub fn compare_steps(a: &mut S, b: &mut S, num_steps: usize) { (0..num_steps).for_each(|i| { let lr_a = a.step(); let lr_b = b.step(); assert!( (lr_a - lr_b).abs() < LOOSE_EPSILON, "The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \ sequences are not approximately equal", ); }); } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/composed.rs ================================================ use burn_core as burn; use super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig}; use super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig}; use super::linear::{LinearLrScheduler, LinearLrSchedulerConfig}; use super::noam::{NoamLrScheduler, NoamLrSchedulerConfig}; use super::{LrScheduler, String}; use crate::LearningRate; use burn::config::Config; use burn::record::Record; use burn::tensor::backend::Backend; /// Compose multiple [learning rate schedulers](LrScheduler) together. #[derive(Config, Debug)] pub struct ComposedLrSchedulerConfig { #[config(default = "Vec::new()")] schedulers: Vec, #[config(default = "SchedulerReduction::Prod")] reduction: SchedulerReduction, } /// Compose multiple [learning rate schedulers](LrScheduler) together. #[derive(Clone)] pub struct ComposedLrScheduler { schedulers: Vec, reduction: SchedulerReduction, } /// Defines how the learning rates generated by the schedulers are combined. #[derive(Config, Debug, Copy)] pub enum SchedulerReduction { /// All learning rates are averaged. Avg, /// All learning rates are summed. Sum, /// All learning rates are multiplied. Prod, } impl ComposedLrSchedulerConfig { /// Initialize the learning rate scheduler. pub fn init(&self) -> Result { let mut schedulers = Vec::with_capacity(self.schedulers.len()); for config in self.schedulers.iter() { let config = match config { LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?), LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?), LrSchedulerConfig::Exponential(config) => { LrSchedulerItem::Exponential(config.init()?) } LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?), }; schedulers.push(config); } Ok(ComposedLrScheduler { schedulers, reduction: self.reduction, }) } /// Appends a [linear scheduler](LinearLrScheduler). pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self { self.schedulers.push(LrSchedulerConfig::Linear(config)); self } /// Appends a [cosine scheduler](ComposedLrSchedulerConfig). pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self { self.schedulers.push(LrSchedulerConfig::Cosine(config)); self } /// Appends an [exponential scheduler](ExponentialLrScheduler). pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self { self.schedulers.push(LrSchedulerConfig::Exponential(config)); self } /// Appends a [noam scheduler](NoamLrScheduler). pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self { self.schedulers.push(LrSchedulerConfig::Noam(config)); self } } #[derive(Config, Debug)] enum LrSchedulerConfig { Linear(LinearLrSchedulerConfig), Cosine(CosineAnnealingLrSchedulerConfig), Exponential(ExponentialLrSchedulerConfig), Noam(NoamLrSchedulerConfig), } #[derive(Clone)] enum LrSchedulerItem { Linear(LinearLrScheduler), Cosine(CosineAnnealingLrScheduler), Exponential(ExponentialLrScheduler), Noam(NoamLrScheduler), } #[derive(Record)] /// Record item for the [composed learning rate scheduler](ComposedLrScheduler). pub enum LrSchedulerRecord { /// The linear variant. Linear(::Record), /// The cosine variant. Cosine(::Record), /// The exponential variant. Exponential(::Record), /// The noam variant. Noam(::Record), } #[derive(Record)] /// Records for the [composed learning rate scheduler](ComposedLrScheduler). pub struct ComposedLrSchedulerRecord { schedulers: Vec>, } impl LrScheduler for ComposedLrScheduler { type Record = ComposedLrSchedulerRecord; fn step(&mut self) -> LearningRate { let mut step = match self.reduction { SchedulerReduction::Avg => 0.0, SchedulerReduction::Sum => 0.0, SchedulerReduction::Prod => 1.0, }; let num_scheduler = self.schedulers.len() as f64; for lr in self.schedulers.iter_mut().map(|s| match s { LrSchedulerItem::Linear(item) => item.step(), LrSchedulerItem::Cosine(item) => item.step(), LrSchedulerItem::Exponential(item) => item.step(), LrSchedulerItem::Noam(item) => item.step(), }) { step = match self.reduction { SchedulerReduction::Avg => step + (lr / num_scheduler), SchedulerReduction::Sum => step + lr, SchedulerReduction::Prod => step * lr, } } step } fn to_record(&self) -> Self::Record { ComposedLrSchedulerRecord:: { schedulers: self .schedulers .iter() .map(|s| match s { LrSchedulerItem::Linear(item) => { LrSchedulerRecord::Linear(item.to_record::()) } LrSchedulerItem::Cosine(item) => { LrSchedulerRecord::Cosine(item.to_record::()) } LrSchedulerItem::Exponential(item) => { LrSchedulerRecord::Exponential(item.to_record::()) } LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::()), }) .collect(), } } fn load_record(mut self, record: Self::Record) -> Self { self.schedulers = self .schedulers .into_iter() .zip(record.schedulers) .map(|scheduler| match scheduler { (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => { LrSchedulerItem::Linear(item.load_record::(record)) } (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => { LrSchedulerItem::Cosine(item.load_record::(record)) } (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => { LrSchedulerItem::Exponential(item.load_record::(record)) } (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => { LrSchedulerItem::Noam(item.load_record::(record)) } _ => panic!("Invalid state"), }) .collect(); self } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/constant.rs ================================================ use burn_core as burn; use burn::tensor::backend::Backend; use super::LrScheduler; use crate::LearningRate; /// Constant learning rate implementing [learning rate scheduler](LrScheduler). /// /// # Notes /// /// You can also use [learning rate](LearningRate) which the same effect. #[derive(new, Clone, Debug)] pub struct ConstantLr { lr: LearningRate, } impl From for ConstantLr { fn from(lr: LearningRate) -> Self { Self { lr } } } impl LrScheduler for ConstantLr { type Record = (); fn step(&mut self) -> LearningRate { self.lr } fn to_record(&self) -> Self::Record {} fn load_record(self, _record: Self::Record) -> Self { self } } impl LrScheduler for LearningRate { type Record = (); fn step(&mut self) -> LearningRate { *self } fn to_record(&self) -> Self::Record {} fn load_record(self, _record: Self::Record) -> Self { self } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/cosine.rs ================================================ use burn_core as burn; use super::{LrScheduler, String}; use crate::LearningRate; use burn::config::Config; use burn::tensor::backend::Backend; /// The configuration for creating a [Cosine Annealing learning rate scheduler with warm /// restarts](CosineAnnealingLrScheduler). /// /// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by /// following a cosine function. After `num_iters` iterations, the learning rate is reset to /// `initial_lr`. #[derive(Config, Debug)] pub struct CosineAnnealingLrSchedulerConfig { // The initial learning rate. initial_lr: LearningRate, // The final learning rate. #[config(default = 0.0)] min_lr: LearningRate, // The number of iterations between two restarts. The two restart iterations themselves are not // included. num_iters: usize, } impl CosineAnnealingLrSchedulerConfig { /// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler). /// /// # Errors /// /// An error will be returned if any of the following conditions is true: /// /// * `initial_lr` is out of range (0.0, 1.0] /// * `min_lr` is out of range [0.0, `initial_lr`] /// * `num_iters` is 0 pub fn init(&self) -> Result { if self.initial_lr <= 0. || self.initial_lr > 1. { return Err("Initial learning rate must be greater than 0 and at most 1".into()); } if self.min_lr < 0.0 || self.min_lr > self.initial_lr { return Err( "Minimum learning rate must be at least 0 and at most equal to the initial \ learning rate" .into(), ); } if self.num_iters == 0 { return Err("Number of iterations must be at least 1".into()); } Ok(CosineAnnealingLrScheduler { min_lr: self.min_lr, max_lr: self.initial_lr, num_iters: self.num_iters, current_iter: usize::MAX, }) } } /// A Cosine Annealing learning rate scheduler. /// /// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm /// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more /// information. #[derive(Clone, Copy, Debug)] pub struct CosineAnnealingLrScheduler { min_lr: LearningRate, max_lr: LearningRate, num_iters: usize, current_iter: usize, } impl LrScheduler for CosineAnnealingLrScheduler { type Record = usize; fn step(&mut self) -> LearningRate { // Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the // first call. We could've used i64 with an initial value -1, but keeping it in usize saves // us from some type casting here. self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1); self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI) .cos()) } fn to_record(&self) -> Self::Record { self.current_iter } fn load_record(mut self, record: Self::Record) -> Self { self.current_iter = record; self } } #[cfg(test)] mod tests { use super::super::test_utils; use super::*; #[test] fn config_initial_lr_too_low() { let r = CosineAnnealingLrSchedulerConfig::new(0., 10).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_initial_lr_too_high() { let r = CosineAnnealingLrSchedulerConfig::new(1.5, 10).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_min_lr_too_low() { let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10) .with_min_lr(-0.1) .init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Minimum learning rate must be at least 0 and at most equal to the initial learning \ rate", "Error messages should match", ); } #[test] fn config_min_lr_too_high() { let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10) .with_min_lr(0.6) .init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Minimum learning rate must be at least 0 and at most equal to the initial learning \ rate", "Error messages should match", ); } #[test] fn config_num_iters_too_low() { let r = CosineAnnealingLrSchedulerConfig::new(0.5, 0).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Number of iterations must be at least 1", "Error messages should match", ); } #[test] fn test_lr_change() { const INITIAL_LR: LearningRate = 0.5; const MIN_LR: LearningRate = 0.1; let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2) .with_min_lr(MIN_LR) .init() .unwrap(); let expected_lrs = [ INITIAL_LR, // cos(0) (INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2) MIN_LR, // cos(PI) INITIAL_LR, // restart ]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_save_and_load() { const NUM_ITERS: usize = 9; let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS) .init() .unwrap(); test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2); } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/exponential.rs ================================================ use burn_core as burn; use super::{LrScheduler, String}; use crate::LearningRate; use burn::config::Config; use burn::tensor::backend::Backend; /// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler). /// /// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by /// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning /// rate is given by `initial_lr * gamma^i`. #[derive(Config, Debug)] pub struct ExponentialLrSchedulerConfig { // The initial learning rate. initial_lr: LearningRate, // The constant that the learning rate is multiplied by on each iteration. gamma: f64, } impl ExponentialLrSchedulerConfig { /// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler). /// /// # Errors /// /// An error will be returned if any of the following conditions is true: /// /// * `initial_lr` is out of range (0.0, 1.0] /// * `gamma` is out of range (0.0, 1.0] pub fn init(&self) -> Result { if self.initial_lr <= 0. || self.initial_lr > 1. { return Err("Initial learning rate must be greater than 0 and at most 1".into()); } if self.gamma <= 0. || self.gamma > 1. { return Err("Gamma must be greater than 0 and at most 1".into()); } Ok(ExponentialLrScheduler { // Such an initial value eliminates the need for special-case handling of the first // learning rate. previous_lr: self.initial_lr / self.gamma, gamma: self.gamma, }) } } /// A exponential learning rate scheduler. /// /// See [ExponentialLrSchedulerConfig] for more information. #[derive(Clone, Copy, Debug)] pub struct ExponentialLrScheduler { // The previous iteration's learning rate. previous_lr: LearningRate, // The constant that the learning rate is multiplied by on each iteration. gamma: f64, } impl LrScheduler for ExponentialLrScheduler { type Record = LearningRate; fn step(&mut self) -> LearningRate { self.previous_lr *= self.gamma; self.previous_lr } fn to_record(&self) -> Self::Record { self.previous_lr } fn load_record(mut self, record: Self::Record) -> Self { self.previous_lr = record; self } } #[cfg(test)] mod tests { use super::super::test_utils; use super::*; #[test] fn config_initial_lr_too_low() { let r = ExponentialLrSchedulerConfig::new(0., 0.5).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_initial_lr_too_high() { let r = ExponentialLrSchedulerConfig::new(1.5, 0.5).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_gamma_too_low() { let r = ExponentialLrSchedulerConfig::new(0.5, 0.0).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Gamma must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_gamma_too_high() { let r = ExponentialLrSchedulerConfig::new(0.5, 1.5).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Gamma must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn test_lr_change() { let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap(); let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_save_and_load() { let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3) .init() .unwrap(); test_utils::check_save_load(scheduler, 7); } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/linear.rs ================================================ use burn_core as burn; use super::{LrScheduler, String}; use crate::LearningRate; use burn::config::Config; use burn::tensor::backend::Backend; /// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler). /// /// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a /// constant amount on each iteration until reaching a final learning rate `final_lr`. The /// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to /// `final_lr`. #[derive(Config, Debug)] pub struct LinearLrSchedulerConfig { // The initial learning rate. initial_lr: LearningRate, // The final learning rate. final_lr: LearningRate, // The number of iterations before reaching the final learning rate. num_iters: usize, } impl LinearLrSchedulerConfig { /// Initializes a [linear learning rate scheduler](LinearLrScheduler). /// /// # Errors /// /// An error will be returned if any of the following conditions is true: /// /// * `initial_lr` is out of range (0.0, 1.0] /// * `final_lr` is out of range [0.0, 1.0] /// * `num_iters` is 0 pub fn init(&self) -> Result { if self.initial_lr <= 0. || self.initial_lr > 1. { return Err("Initial learning rate must be greater than 0 and at most 1".into()); } if self.final_lr < 0. || self.final_lr > 1. { return Err("Final learning rate must be at least 0 and at most 1".into()); } if self.num_iters == 0 { return Err("Number of iterations must be at least 1".into()); } Ok(LinearLrScheduler { final_lr: self.final_lr, step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64, remaining_iters: self.num_iters + 1, }) } } /// A linear learning rate scheduler. /// /// See [LinearLrSchedulerConfig] for more information. #[derive(Clone, Copy, Debug)] pub struct LinearLrScheduler { // The final learning rate after the linear changing process stops. final_lr: LearningRate, // The amount that the learning rate changes by on each iteration. step_size: f64, // The number of iterations left before reaching the final learning rate. remaining_iters: usize, } impl LrScheduler for LinearLrScheduler { type Record = usize; fn step(&mut self) -> LearningRate { self.remaining_iters -= (self.remaining_iters != 0) as usize; self.final_lr - self.step_size * self.remaining_iters as f64 } fn to_record(&self) -> Self::Record { self.remaining_iters } fn load_record(mut self, record: Self::Record) -> Self { self.remaining_iters = record; self } } #[cfg(test)] mod tests { use super::super::test_utils; use super::*; #[test] fn config_initial_lr_too_low() { let r = LinearLrSchedulerConfig::new(0., 0.5, 100).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_initial_lr_too_high() { let r = LinearLrSchedulerConfig::new(1.5, 0.5, 100).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Initial learning rate must be greater than 0 and at most 1", "Error messages should match", ); } #[test] fn config_final_lr_too_low() { let r = LinearLrSchedulerConfig::new(0.5, -0.5, 100).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Final learning rate must be at least 0 and at most 1", "Error messages should match", ); } #[test] fn config_final_lr_too_high() { let r = LinearLrSchedulerConfig::new(0.5, 1.5, 100).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Final learning rate must be at least 0 and at most 1", "Error messages should match", ); } #[test] fn config_num_iters_too_low() { let r = LinearLrSchedulerConfig::new(0.9, 0.1, 0).init(); assert!(r.is_err(), "Should return an error"); assert_eq!( r.unwrap_err(), "Number of iterations must be at least 1", "Error messages should match", ); } #[test] fn test_lr_decreasing() { let scheduler = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap(); let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_lr_increasing() { let scheduler = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap(); let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_lr_unchanging() { let scheduler = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap(); let expected_lrs = [0.3, 0.3, 0.3, 0.3]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_save_and_load() { const NUM_ITERS: usize = 6; let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS) .init() .unwrap(); test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2); } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/mod.rs ================================================ /// Constant learning rate scheduler pub mod constant; /// Composed learning rate scheduler pub mod composed; /// Linear learning rate scheduler pub mod linear; /// Noam learning rate scheduler pub mod noam; /// Exponential learning rate scheduler pub mod exponential; /// Cosine learning rate scheduler pub mod cosine; /// Step learning rate scheduler pub mod step; mod base; pub use base::*; ================================================ FILE: crates/burn-optim/src/lr_scheduler/noam.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::tensor::backend::Backend; use super::{LrScheduler, String}; use crate::LearningRate; /// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler. #[derive(Config, Debug)] pub struct NoamLrSchedulerConfig { /// The overall scale factor for the learning rate decay. factor: f64, /// The number of steps before the exponential decay stats. #[config(default = 4000)] warmup_steps: usize, /// The size of the model. #[config(default = 512)] model_size: usize, } /// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762). #[derive(Clone, Debug)] pub struct NoamLrScheduler { warmup_steps: f64, embedding_size: f64, factor: f64, step: f64, } impl NoamLrSchedulerConfig { /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler. /// /// # Errors /// /// An error will be returned if any of the following conditions is true: /// /// * `warmup_steps` is 0 /// * `model_size` is 0 pub fn init(&self) -> Result { if self.warmup_steps == 0 { return Err( "Number of steps before exponential decay starts must be greater than 0".into(), ); } if self.model_size == 0 { return Err("Model size must be greater than 0".into()); } Ok(NoamLrScheduler { warmup_steps: self.warmup_steps as f64, embedding_size: self.model_size as f64, factor: self.factor, step: 0.0, }) } } impl LrScheduler for NoamLrScheduler { type Record = usize; fn step(&mut self) -> LearningRate { self.step += 1.0; let arg1 = self.step.powf(-0.5); let arg2 = self.step * self.warmup_steps.powf(-1.5); self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2) } fn to_record(&self) -> Self::Record { self.step as usize } fn load_record(mut self, record: Self::Record) -> Self { self.step = record as f64; self } } #[cfg(test)] mod tests { use super::*; #[test] fn test_config_warmup_steps_invalid() { let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init(); assert!(r.is_err(), "Should return an error"); } #[test] fn test_config_warmup_steps_valid() { let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init(); assert!(r.is_ok(), "Should return a success value"); } #[test] fn test_config_model_size_invalid() { let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init(); assert!(r.is_err(), "Should return an error"); } #[test] fn test_config_model_size_valid() { let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init(); assert!(r.is_ok(), "Should return a success value"); } #[test] fn test_function_increase_and_decrease() { let warmup_steps = 100; let mut scheduler = NoamLrSchedulerConfig::new(10.0) .with_warmup_steps(warmup_steps) .init() .unwrap(); let mut lr_current = 0.0; for _ in 0..warmup_steps { let lr = scheduler.step(); assert!( lr > lr_current, "Learning rate should increase before the warmup_steps is reached." ); lr_current = lr; } for _ in 0..warmup_steps { let lr = scheduler.step(); assert!( lr < lr_current, "Learning rate should decrease after the warmup_steps is reached." ); lr_current = lr; } } } ================================================ FILE: crates/burn-optim/src/lr_scheduler/step.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::tensor::backend::Backend; use super::{LrScheduler, String}; use crate::LearningRate; /// The configuration for create a [step learning rate scheduler](StepLrScheduler). /// /// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until /// the same value has been given for `step_size` times. Then it multiplies the learning rate by /// `gamma` before repeating the process. /// /// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but /// a warning log will be output for such a value in case of mistyping. /// /// ## Notes /// /// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than /// `i32::MAX + 1` times. #[derive(Config, Debug)] pub struct StepLrSchedulerConfig { // The learning rate at the initial step. initial_lr: LearningRate, // The number of iterations over which the learning rate remains unchanged before the next // update. step_size: usize, /// The factor by which the learning rate is multiplied with each update. Default: 0.1. #[config(default = 0.1)] gamma: f64, } impl StepLrSchedulerConfig { /// Initializes a [step learning rate scheduler](StepLrScheduler). /// /// # Errors /// /// An error will be returned if `step_size` is 0. pub fn init(&self) -> Result { if self.step_size == 0 { return Err("Step size must be greater than 0".into()); } // Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful // in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518). if self.initial_lr <= 0.0 { log::warn!( "Initial learning rate value of {} is not a positive number. Ignore this warning \ if it is intended.", self.initial_lr ); } if self.gamma <= 0.0 || self.gamma >= 1.0 { log::warn!( "Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \ intended.", self.gamma ); } Ok(StepLrScheduler { init_lr: self.initial_lr, step_size: self.step_size, gamma: self.gamma, iter_idx: -1, }) } } /// Step learning rate scheduler. #[derive(Clone, Debug)] pub struct StepLrScheduler { init_lr: LearningRate, step_size: usize, gamma: f64, // The index of the current iteration. // `i32` is used for avoiding truncating the exponent when taking powers of `gamma`. iter_idx: i32, } impl LrScheduler for StepLrScheduler { type Record = i32; fn step(&mut self) -> LearningRate { self.iter_idx = self .iter_idx .checked_add(1) .expect("`.step()` should be called no more than `i32::MAX + 1` times"); // Type casting below causes no truncation, as all the values fall within the ranges. self.init_lr * self .gamma .powi((self.iter_idx as usize / self.step_size) as i32) } fn to_record(&self) -> Self::Record { self.iter_idx } fn load_record(mut self, record: Self::Record) -> Self { self.iter_idx = record; self } } #[cfg(test)] mod tests { use super::super::test_utils; use super::*; use crate::TestBackend; // Warning logs for initial LR and gamma are not tested because there seems no straightforward // way to do it. // // Creating a mock logger that collects logs into `String` for later examination seems a possible // solution, but unit tests run in the same process in parallel, where the single logger would // be shared by multiple tests, so logs from different tests would be mixed up with no easy way // to separate them. // Using "--test-threads=1" could prevent mixup, but whether the ability to test logging is // worth the slowdown would be a question. Also, using a primitive provided by `std` to // synchronize the logger across tests is not an option since we need to support `no-std`. // Maybe the mocking approach can be reconsidered after we are given an option to run tests in // separate processes like what the issue below is proposing: // https://github.com/rust-lang/rust/issues/47506 // // As a side note, a helper crate exists for the exact purpose: // https://crates.io/crates/testing_logger // but the crate has been unmaintained and using it would introduce another dependency. #[test] fn test_config_step_size_zero() { let r = StepLrSchedulerConfig::new(1.0, 0).init(); assert!(r.is_err(), "Should return an error"); } #[test] fn test_config_step_size_nonzero() { let r = StepLrSchedulerConfig::new(1.0, 1).init(); assert!(r.is_ok(), "Should return a success value"); } #[test] fn test_config_default_gamma() { const INIT_LR: LearningRate = 0.4; const STEP_SIZE: usize = 2; let mut default = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE) .init() .unwrap(); let mut explicit = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE) .with_gamma(0.1) .init() .unwrap(); test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE); } #[test] fn test_lr_decreasing() { let scheduler = StepLrSchedulerConfig::new(0.5, 3) .with_gamma(0.1) .init() .unwrap(); let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_lr_increasing() { let scheduler = StepLrSchedulerConfig::new(0.1, 2) .with_gamma(2.0) .init() .unwrap(); let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_lr_unchanging() { let scheduler = StepLrSchedulerConfig::new(3.1, 1) .with_gamma(1.0) .init() .unwrap(); let expected_lrs = [3.1, 3.1, 3.1]; test_utils::check_lr_sequence(scheduler, expected_lrs); } #[test] fn test_save_and_load() { const STEP_SIZE: usize = 10; let scheduler = StepLrSchedulerConfig::new(0.007, STEP_SIZE) .with_gamma(0.03) .init() .unwrap(); test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2); } // It's too time consuming to actually run a scheduler `i32::MAX` steps, so an approach that // depends on private fields is used to implement the test. #[test] fn test_number_of_calls_within_limit() { // Create a scheduler that has already run `i32::MAX` steps let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap(); scheduler = scheduler.load_record::(i32::MAX - 1); scheduler.step(); } #[test] #[should_panic = "i32::MAX"] fn test_number_of_calls_over_limit() { // Create a scheduler that has already run `i32::MAX` steps let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap(); scheduler = scheduler.load_record::(i32::MAX - 1); scheduler.step(); scheduler.step(); } } ================================================ FILE: crates/burn-optim/src/optim/adagrad.rs ================================================ use burn_core as burn; use burn::{module::AutodiffModule, record::Record}; use burn::config::Config; use burn::tensor::{Tensor, backend::AutodiffBackend}; use burn::tensor::{backend::Backend, ops::Device}; use super::{ SimpleOptimizer, adaptor::OptimizerAdaptor, decay::{WeightDecay, WeightDecayConfig}, }; use crate::{LearningRate, grad_clipping::GradientClippingConfig}; /// AdaGrad configuration. #[derive(Config, Debug)] pub struct AdaGradConfig { #[config(default = 0.)] lr_decay: f64, #[config(default = 1e-5)] epsilon: f32, /// [Weight decay](WeightDecayConfig) config. weight_decay: Option, /// [Gradient Clipping](GradientClippingConfig) config. grad_clipping: Option, } /// AdaGrad optimizer #[derive(Clone)] pub struct AdaGrad { lr_decay: LrDecay, weight_decay: Option, } /// AdaGrad state. #[derive(Record, Clone, new)] pub struct AdaGradState { lr_decay: LrDecayState, } impl SimpleOptimizer for AdaGrad { type State = AdaGradState; fn step( &self, lr: LearningRate, tensor: Tensor, mut grad: Tensor, state: Option>, ) -> (Tensor, Option>) { let mut state_lr_decay = None; if let Some(state) = state { state_lr_decay = Some(state.lr_decay); } if let Some(weight_decay) = &self.weight_decay { grad = weight_decay.transform(grad, tensor.clone()); } let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay); let state = AdaGradState::new(state_lr_decay); (tensor - grad, Some(state)) } fn to_device(mut state: Self::State, device: &Device) -> Self::State { state.lr_decay = state.lr_decay.to_device(device); state } } impl AdaGradConfig { /// Initialize AdaGrad optimizer. /// /// # Returns /// /// Returns an optimizer that can be used to optimize a module. pub fn init>( &self, ) -> OptimizerAdaptor { let optim = AdaGrad { lr_decay: LrDecay { lr_decay: self.lr_decay, epsilon: self.epsilon, }, weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), }; let mut optim = OptimizerAdaptor::from(optim); if let Some(config) = &self.grad_clipping { optim = optim.with_grad_clipping(config.init()); } optim } } /// Learning rate decay state (also includes sum state). #[derive(Record, new, Clone)] pub struct LrDecayState { time: usize, sum: Tensor, } #[derive(Clone)] struct LrDecay { lr_decay: f64, epsilon: f32, } impl LrDecay { pub fn transform( &self, grad: Tensor, lr: LearningRate, lr_decay_state: Option>, ) -> (Tensor, LrDecayState) { let state = if let Some(mut state) = lr_decay_state { state.sum = state.sum.add(grad.clone().square()); state.time += 1; state } else { LrDecayState::new(1, grad.clone().square()) }; let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay); let grad = grad .div(state.sum.clone().sqrt().add_scalar(self.epsilon)) .mul_scalar(new_lr); (grad, state) } } impl LrDecayState { /// Move state to device. /// /// # Arguments /// /// * `device` - Device to move state to. /// /// # Returns /// /// Returns state moved to device. pub fn to_device(mut self, device: &B::Device) -> Self { self.sum = self.sum.to_device(device); self } } #[cfg(test)] mod tests { use burn::tensor::Tolerance; use burn::tensor::ops::FloatElem; use super::*; use crate::TestAutodiffBackend; use crate::{GradientsParams, Optimizer}; use burn::module::{Module, Param}; use burn::tensor::{Distribution, Tensor, TensorData}; use burn_nn::{Linear, LinearConfig, LinearRecord}; const LEARNING_RATE: LearningRate = 0.01; #[test] fn test_adagrad_optimizer_save_load_state() { let device = Default::default(); let linear = LinearConfig::new(6, 6).init(&device); let x = Tensor::::random([2, 6], Distribution::Default, &device); let mut optimizer = create_adagrad(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let _linear = optimizer.step(LEARNING_RATE, linear, grads); #[cfg(feature = "std")] { use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; BinFileRecorder::::default() .record( optimizer.to_record(), std::env::temp_dir().as_path().join("test_optim_adagrad"), ) .unwrap(); } #[cfg(not(feature = "std"))] { use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder}; let result = BinBytesRecorder::::default() .record(optimizer.to_record(), ()) .unwrap(); assert!(!result.is_empty()); } let state_optim_before = optimizer.to_record(); let state_optim_before_copy = optimizer.to_record(); let optimizer = create_adagrad(); let optimizer = optimizer.load_record(state_optim_before_copy); let state_optim_after = optimizer.to_record(); assert_eq!(state_optim_before.len(), state_optim_after.len()); } #[test] fn test_adagrad_optimizer_with_numbers() { let device = Default::default(); let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &device, ) .require_grad(); let mut optimizer = AdaGradConfig::new() .with_epsilon(1e-8) .with_lr_decay(0.5) .init(); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); let weights_expected = TensorData::from([ [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711], [ 0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756, ], [ -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538, ], [ -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964, ], [ 0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504, ], [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895], ]); let bias_expected = TensorData::from([ -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714, ]); let (weight_updated, bias_updated) = ( state_updated.weight.val().into_data(), state_updated.bias.unwrap().val().into_data(), ); type FT = FloatElem; let tolerance = Tolerance::absolute(1e-6); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: Some(Param::from_data(bias, &device)), }; LinearConfig::new(6, 6).init(&device).load_record(record) } fn create_adagrad() -> OptimizerAdaptor, TestAutodiffBackend> { let config = AdaGradConfig::new(); AdaGrad { lr_decay: LrDecay { lr_decay: config.lr_decay, epsilon: config.epsilon, }, weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), } .into() } } ================================================ FILE: crates/burn-optim/src/optim/adam.rs ================================================ use burn_core as burn; use burn::{module::AutodiffModule, record::Record}; use burn::config::Config; use burn::tensor::{Tensor, backend::AutodiffBackend}; use burn::tensor::{backend::Backend, ops::Device}; use super::{ SimpleOptimizer, adaptor::OptimizerAdaptor, decay::{WeightDecay, WeightDecayConfig}, }; use crate::{LearningRate, grad_clipping::GradientClippingConfig}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Adam configuration. #[derive(Config, Debug)] pub struct AdamConfig { /// Parameter for Adam. #[config(default = 0.9)] beta_1: f32, /// Parameter for Adam. #[config(default = 0.999)] beta_2: f32, /// A value required for numerical stability. #[config(default = 1e-5)] epsilon: f32, /// Whether to use AMSGrad algorithm #[config(default = false)] amsgrad: bool, /// [Weight decay](WeightDecayConfig) config. weight_decay: Option, /// [Gradient Clipping](GradientClippingConfig) config. grad_clipping: Option, } /// Adam optimizer. /// /// See: /// - [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf). /// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ) #[derive(Clone)] pub struct Adam { momentum: AdaptiveMomentum, weight_decay: Option, } /// Adam state. #[derive(Record, Clone, new)] pub struct AdamState { /// The current adaptive momentum. pub momentum: AdaptiveMomentumState, } impl SimpleOptimizer for Adam { type State = AdamState; fn step( &self, lr: LearningRate, tensor: Tensor, mut grad: Tensor, state: Option>, ) -> (Tensor, Option>) { let mut state_momentum = None; if let Some(state) = state { state_momentum = Some(state.momentum); } if let Some(weight_decay) = &self.weight_decay { grad = weight_decay.transform(grad, tensor.clone()); } let (grad, state_momentum) = self.momentum.transform(grad, state_momentum); let state = AdamState::new(state_momentum); let delta = grad.mul_scalar(lr); (tensor - delta, Some(state)) } fn to_device(mut state: Self::State, device: &Device) -> Self::State { state.momentum = state.momentum.to_device(device); state } } impl AdamConfig { /// Initialize Adam optimizer. /// /// # Returns /// /// Returns an optimizer that can be used to optimize a module. pub fn init>(&self) -> OptimizerAdaptor { let optim = Adam { momentum: AdaptiveMomentum { beta_1: self.beta_1, beta_2: self.beta_2, epsilon: self.epsilon, amsgrad: self.amsgrad, }, weight_decay: self.weight_decay.as_ref().map(WeightDecay::new), }; let mut optim = OptimizerAdaptor::from(optim); if let Some(config) = &self.grad_clipping { optim = optim.with_grad_clipping(config.init()); } optim } } /// Adaptive momentum state. #[derive(Record, new, Clone)] pub struct AdaptiveMomentumState { /// The number of iterations aggregated. pub time: usize, /// The first order momentum. pub moment_1: Tensor, /// The second order momentum. pub moment_2: Tensor, /// Max of second order momentum (for AMSGrad) #[new(default)] pub max_moment_2: Option>, } #[derive(Clone)] struct AdaptiveMomentum { beta_1: f32, beta_2: f32, epsilon: f32, amsgrad: bool, } impl AdaptiveMomentum { pub fn transform( &self, grad: Tensor, momentum_state: Option>, ) -> (Tensor, AdaptiveMomentumState) { let state = if let Some(mut state) = momentum_state { let factor = 1.0 - self.beta_1; state.moment_1 = state .moment_1 .mul_scalar(self.beta_1) .add(grad.clone().mul_scalar(factor)); let factor = 1.0 - self.beta_2; state.moment_2 = state .moment_2 .mul_scalar(self.beta_2) .add(grad.square().mul_scalar(factor)); if self.amsgrad { let max_v = state .max_moment_2 .take() .unwrap_or_else(|| state.moment_2.clone()); let new_max = max_v.max_pair(state.moment_2.clone()); state.max_moment_2 = Some(new_max); } state.time += 1; state } else { let factor = 1.0 - self.beta_1; let moment_1 = grad.clone().mul_scalar(factor); let factor = 1.0 - self.beta_2; let moment_2 = grad.square().mul_scalar(factor); let max_moment_2 = self.amsgrad.then(|| moment_2.clone()); AdaptiveMomentumState { time: 1, moment_1, moment_2, max_moment_2, } }; let time = state.time as i32; let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt(); let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time)); let v_to_use = if self.amsgrad { state.max_moment_2.as_ref().unwrap_or(&state.moment_2) } else { &state.moment_2 }; let grad = state.moment_1.clone().mul_scalar(combined_factor).div( v_to_use .clone() .sqrt() .add_scalar(self.epsilon * bias_correction2_sqrt), ); (grad, state) } } impl AdaptiveMomentumState { /// Move state to device. /// /// # Arguments /// /// * `device` - Device to move state to. /// /// # Returns /// /// Returns state moved to device. pub fn to_device(mut self, device: &B::Device) -> Self { self.moment_1 = self.moment_1.to_device(device); self.moment_2 = self.moment_2.to_device(device); self.max_moment_2 = self.max_moment_2.map(|tensor| tensor.to_device(device)); self } } #[cfg(test)] mod tests { use burn::tensor::Tolerance; use burn::tensor::ops::FloatElem; use super::*; use crate::TestAutodiffBackend; use crate::{GradientsParams, Optimizer}; use burn::module::{Module, Param}; use burn::tensor::{Distribution, Tensor, TensorData}; use burn_nn::{Linear, LinearConfig, LinearRecord}; const LEARNING_RATE: LearningRate = 0.01; #[test] fn test_adam_optimizer_save_load_state() { let device = Default::default(); let linear = LinearConfig::new(6, 6).init(&device); let x = Tensor::::random([2, 6], Distribution::Default, &device); let mut optimizer = create_adam(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let _linear = optimizer.step(LEARNING_RATE, linear, grads); #[cfg(feature = "std")] { use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; BinFileRecorder::::default() .record( optimizer.to_record(), std::env::temp_dir().as_path().join("test_optim_adam"), ) .unwrap(); } #[cfg(not(feature = "std"))] { use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder}; let result = BinBytesRecorder::::default() .record(optimizer.to_record(), ()) .unwrap(); assert!(!result.is_empty()); } let state_optim_before = optimizer.to_record(); let state_optim_before_copy = optimizer.to_record(); let optimizer = create_adam(); let optimizer = optimizer.load_record(state_optim_before_copy); let state_optim_after = optimizer.to_record(); assert_eq!(state_optim_before.len(), state_optim_after.len()); } #[test] fn test_adam_optimizer_with_amsgrad_50_steps() { let device = Default::default(); let mut linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let mut optimizer = AdamConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_amsgrad(true) .with_weight_decay(Some(WeightDecayConfig::new(0.5))) .init(); for i in 1..=50 { let x = Tensor::::ones([2, 6], &device) .mul_scalar(i as f32 * 0.1) .require_grad(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); linear = optimizer.step(LEARNING_RATE, linear, grads); } let state_updated = linear.into_record(); let weight_updated = state_updated.weight.to_data(); let bias_updated = state_updated.bias.unwrap().to_data(); let weights_expected = TensorData::from([ [ -0.9125810265541077, -0.45855265855789185, -0.1915993094444275, -0.2759990692138672, -0.5099529027938843, -0.5287043452262878, ], [ -0.5181325674057007, -0.6139854788780212, -0.9574727416038513, -0.34102925658226013, -0.400514155626297, -0.8847861886024475, ], [ -0.614483118057251, -0.5611032247543335, -0.8887064456939697, -0.34762972593307495, -0.8708556890487671, -0.2830044627189636, ], [ -0.8904699683189392, -0.8151527643203735, -0.9621278643608093, -0.8905676603317261, -0.671261191368103, -0.4333854615688324, ], [ -0.26599061489105225, -0.8119961023330688, -0.22424538433551788, -0.7672406435012817, -0.2163349837064743, -0.6258266568183899, ], [ -0.611397922039032, -0.6075160503387451, -0.4701341986656189, -0.4039117991924286, -0.5663845539093018, -0.21262989938259125, ], ]); let bias_expected = TensorData::from([ -0.8817203044891357, -0.4038999378681183, -0.5889149308204651, -0.37475723028182983, -0.3557940721511841, -0.47914788126945496, ]); type FT = FloatElem; let tolerance = Tolerance::absolute(1e-5); weight_updated.assert_approx_eq::(&weights_expected, tolerance); bias_updated.assert_approx_eq::(&bias_expected, tolerance); } #[test] fn test_adam_optimizer_with_numbers() { let device = Default::default(); let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &device, ) .require_grad(); let mut optimizer = AdamConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_weight_decay(Some(WeightDecayConfig::new(0.5))) .init(); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); let weights_expected = TensorData::from([ [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154], [ 0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133, ], [ -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047, ], [ -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651, ], [ 0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343, ], [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346], ]); let bias_expected = TensorData::from([ -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999, ]); let (weight_updated, bias_updated) = ( state_updated.weight.to_data(), state_updated.bias.unwrap().to_data(), ); type FT = FloatElem; let tolerance = Tolerance::absolute(1e-2); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } #[test] fn test_adam_optimizer_no_nan() { let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let x = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &Default::default(), ) .require_grad(); let mut optimizer = AdamConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_weight_decay(Some(WeightDecayConfig::new(0.5))) .init(); let grads = linear.forward(x.clone()).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); assert!(!state_updated.weight.to_data().as_slice::().unwrap()[0].is_nan()); } fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: Some(Param::from_data(bias, &device)), }; LinearConfig::new(6, 6).init(&device).load_record(record) } fn create_adam() -> OptimizerAdaptor, TestAutodiffBackend> { let config = AdamConfig::new(); Adam { momentum: AdaptiveMomentum { beta_1: config.beta_1, beta_2: config.beta_2, epsilon: config.epsilon, amsgrad: config.amsgrad, }, weight_decay: config.weight_decay.as_ref().map(WeightDecay::new), } .into() } } ================================================ FILE: crates/burn-optim/src/optim/adamw.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::tensor::{Tensor, backend::AutodiffBackend}; use burn::tensor::{backend::Backend, ops::Device}; use burn::{module::AutodiffModule, record::Record}; use super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor}; use crate::{LearningRate, grad_clipping::GradientClippingConfig}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// [`AdamW`] Configuration. #[derive(Config, Debug)] pub struct AdamWConfig { /// Parameter for AdamW. #[config(default = 0.9)] beta_1: f32, /// Parameter for AdamW. #[config(default = 0.999)] beta_2: f32, /// A value required for numerical stability. #[config(default = 1e-5)] epsilon: f32, /// Weight decay config. #[config(default = 1e-4)] weight_decay: f32, /// Cautious weight decay config. /// /// See: #[config(default = false)] cautious_weight_decay: bool, /// Whether to use AMSGrad algorithm #[config(default = false)] amsgrad: bool, /// [Gradient Clipping](GradientClippingConfig) config. grad_clipping: Option, } /// AdamW optimizer. /// /// See: /// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101). /// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402) /// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ) /// /// Configured by [`AdamWConfig`]. #[derive(Clone)] pub struct AdamW { momentum: AdaptiveMomentumW, weight_decay: f32, cautious_weight_decay: bool, } /// AdamW state. #[derive(Record, Clone, new)] pub struct AdamWState { /// Th current adaptive momentum state. pub momentum: AdaptiveMomentumState, } impl SimpleOptimizer for AdamW { type State = AdamWState; /// A single optimization step for any tensor that represents the parameters of a model. fn step( &self, // Learning rate. lr: LearningRate, // Any tensor that represents the parameters of a model. tensor: Tensor, // Gradient of the loss w.r.t. the parameters. grad: Tensor, // State of the optimizer. state: Option>, ) -> (Tensor, Option>) { let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum)); let decay_rate = lr * (self.weight_decay as f64); let decayed_tensor = if decay_rate == 0.0 { tensor.clone() } else if self.cautious_weight_decay { // Cautious weight decay. // See: https://arxiv.org/abs/2510.12402 let tensor_pos = tensor.clone().greater_equal_elem(0.0); let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0); let differ = tensor_pos.not_equal(grad_pos); // Zero out the decay where the decay is counter to the update direction. tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0) } else { tensor.clone().mul_scalar(1.0 - decay_rate) }; let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr); let state = AdamWState { momentum: momentum_state, }; (tensor_updated, Some(state)) } fn to_device(mut state: Self::State, device: &Device) -> Self::State { state.momentum = state.momentum.to_device(device); state } } impl AdamWConfig { /// Initialize AdamW optimizer. /// /// # Returns /// /// Returns an optimizer that can be used to optimize a module. pub fn init>(&self) -> OptimizerAdaptor { let optim = AdamW { momentum: AdaptiveMomentumW { beta_1: self.beta_1, beta_2: self.beta_2, epsilon: self.epsilon, amsgrad: self.amsgrad, }, weight_decay: self.weight_decay, cautious_weight_decay: self.cautious_weight_decay, }; let mut optim = OptimizerAdaptor::from(optim); if let Some(config) = &self.grad_clipping { optim = optim.with_grad_clipping(config.init()); } optim } } #[derive(Clone)] struct AdaptiveMomentumW { beta_1: f32, beta_2: f32, epsilon: f32, amsgrad: bool, } impl AdaptiveMomentumW { pub fn transform( &self, grad: Tensor, state: Option>, ) -> (Tensor, AdaptiveMomentumState) { let factor_1 = 1.0 - self.beta_1; let factor_2 = 1.0 - self.beta_2; let state = if let Some(mut state) = state { // Update first moment estimate. state.moment_1 = state .moment_1 .mul_scalar(self.beta_1) .add(grad.clone().mul_scalar(factor_1)); // Update second moment estimate. state.moment_2 = state .moment_2 .mul_scalar(self.beta_2) .add(grad.square().mul_scalar(factor_2)); if self.amsgrad { let max_v = state .max_moment_2 .take() .unwrap_or_else(|| state.moment_2.clone()); state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone())); } // Update time. state.time += 1; state } else { // Initialize first moment estimate. let moment_1 = grad.clone().mul_scalar(factor_1); // Initialize second moment estimate. let moment_2 = grad.square().mul_scalar(factor_2); let max_moment_2 = self.amsgrad.then(|| moment_2.clone()); AdaptiveMomentumState { time: 1, moment_1, moment_2, max_moment_2, } }; let time: i32 = state.time as i32; // Compute bias-corrected first and second moment estimates. let moment_1_corrected = state .moment_1 .clone() .div_scalar(1f32 - self.beta_1.powi(time)); let v_to_use = if self.amsgrad { state.max_moment_2.as_ref().unwrap_or(&state.moment_2) } else { &state.moment_2 }; let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time)); let update_delta = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon)); (update_delta, state) } } #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use crate::{GradientsParams, Optimizer}; use burn::module::{Module, Param}; use burn::tensor::{Distribution, Tensor, TensorData}; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_nn::{Linear, LinearConfig, LinearRecord}; type FT = FloatElem; const LEARNING_RATE: LearningRate = 0.01; #[test] fn test_adamw_optimizer_save_load_state() { let device = Default::default(); let linear = LinearConfig::new(6, 6).init(&device); let x = Tensor::::random([2, 6], Distribution::Default, &device); let mut optimizer = create_adamw(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let _linear = optimizer.step(LEARNING_RATE, linear, grads); #[cfg(feature = "std")] { use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; BinFileRecorder::::default() .record( optimizer.to_record(), std::env::temp_dir().as_path().join("test_optim_adamw"), ) .unwrap(); } #[cfg(not(feature = "std"))] { use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder}; let result = BinBytesRecorder::::default() .record(optimizer.to_record(), ()) .unwrap(); assert!(!result.is_empty()); } let state_optim_before = optimizer.to_record(); let state_optim_before_copy = optimizer.to_record(); let optimizer = create_adamw(); let optimizer = optimizer.load_record(state_optim_before_copy); let state_optim_after = optimizer.to_record(); assert_eq!(state_optim_before.len(), state_optim_after.len()); } #[test] fn test_adamw_optimizer_with_amsgrad_50_steps() { let device = Default::default(); let mut linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let mut optimizer = AdamWConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_amsgrad(true) .with_weight_decay(0.5) .init(); for i in 1..=50 { let x = Tensor::::ones([2, 6], &device) .mul_scalar(i as f32 * 0.1) .require_grad(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); linear = optimizer.step(LEARNING_RATE, linear, grads); } let state_updated = linear.into_record(); let weight_updated = state_updated.weight.to_data(); let bias_updated = state_updated.bias.unwrap().to_data(); let weights_expected = TensorData::from([ [ -0.7822558283805847, -0.42578864097595215, -0.21805696189403534, -0.28366872668266296, -0.46587175130844116, -0.4805040955543518, ], [ -0.4722539782524109, -0.5471276640892029, -0.8181359767913818, -0.33425918221473694, -0.3805687427520752, -0.7601516842842102, ], [ -0.5475167632102966, -0.5057991743087769, -0.763265073299408, -0.3393959403038025, -0.7490996718406677, -0.28911691904067993, ], [ -0.7646660208702087, -0.7050473093986511, -0.8218720555305481, -0.7647438049316406, -0.5919585227966309, -0.40617525577545166, ], [ -0.27588561177253723, -0.7025567889213562, -0.24343004822731018, -0.6672990918159485, -0.23728127777576447, -0.556389570236206, ], [ -0.5451040267944336, -0.5420684814453125, -0.4348171353340149, -0.3832150399684906, -0.5099242925643921, -0.23440153896808624, ], ]); let bias_expected = TensorData::from([ -0.7473056316375732, -0.3745720386505127, -0.5188710689544678, -0.35184532403945923, -0.33705732226371765, -0.4332566559314728, ]); type FT = FloatElem; let tolerance = Tolerance::absolute(1e-5); weight_updated.assert_approx_eq::(&weights_expected, tolerance); bias_updated.assert_approx_eq::(&bias_expected, tolerance); } #[test] fn test_adamw_optimizer_with_numbers() { let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let device = Default::default(); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &device, ) .require_grad(); let mut optimizer = AdamWConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_weight_decay(0.5) .init(); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); let weights_expected = TensorData::from([ [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], [ 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, ], [ -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, ], [ -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, ], [ 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, ], [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580], ]); let bias_expected = TensorData::from([ -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, ]); let (weight_updated, bias_updated) = ( state_updated.weight.to_data(), state_updated.bias.unwrap().to_data(), ); let tolerance = Tolerance::absolute(1e-2); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } #[test] fn test_adamw_optimizer_with_numbers_cautious() { let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let device = Default::default(); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085], ], &device, ) .require_grad(); let mut optimizer = AdamWConfig::new() .with_cautious_weight_decay(true) .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_weight_decay(0.5) .init(); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); let weights_expected = TensorData::from([ [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534], [ 0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182, ], [ -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981, ], [ -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081, ], [ 0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993, ], [ -0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332, ], ]); let bias_expected = TensorData::from([ -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080, ]); let (weight_updated, bias_updated) = ( state_updated.weight.to_data(), state_updated.bias.unwrap().to_data(), ); let tolerance = Tolerance::absolute(1e-2); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } #[test] fn test_adam_optimizer_no_nan() { let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let x = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &Default::default(), ) .require_grad(); let mut optimizer = AdamWConfig::new() .with_epsilon(1e-8) .with_beta_1(0.9) .with_beta_2(0.999) .with_weight_decay(0.5) .init(); let grads = linear.forward(x.clone()).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); assert!(!state_updated.weight.to_data().as_slice::().unwrap()[0].is_nan()); } fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: Some(Param::from_data(bias, &device)), }; LinearConfig::new(6, 6).init(&device).load_record(record) } fn create_adamw() -> OptimizerAdaptor, TestAutodiffBackend> { let config = AdamWConfig::new(); AdamW { momentum: AdaptiveMomentumW { beta_1: config.beta_1, beta_2: config.beta_2, epsilon: config.epsilon, amsgrad: config.amsgrad, }, weight_decay: config.weight_decay, cautious_weight_decay: false, } .into() } } ================================================ FILE: crates/burn-optim/src/optim/base.rs ================================================ use burn_core::{self as burn, Tensor}; use burn_core::module::ParamId; use burn_core::prelude::{Backend, DeviceOps}; use burn_core::tensor::Device; use burn_core::tensor::backend::DeviceId; use super::GradientsParams; use crate::LearningRate; use alloc::vec::Vec; use burn::module::AutodiffModule; use burn::record::Record; use burn::tensor::backend::AutodiffBackend; #[derive(Default)] /// Exposes multiple gradients for each parameter. pub struct MultiGradientsParams { /// Each [GradientsParams] has its associated [DeviceId]. pub grads: Vec<(GradientsParams, DeviceId)>, } impl MultiGradientsParams { /// Removes the gradients for the given [parameter id](ParamId). /// /// Potentially accumulates the gradients from multiple sources using a device associated with /// a parameter id. The same parameter will be accumulated using the same device during /// all training. pub fn remove( &mut self, id: ParamId, ) -> Option<(Tensor, Device)> { let (mut tensor, device, index) = self.select(id)?; for (i, (grads, _)) in self.grads.iter_mut().enumerate() { if i == index { continue; } if let Some(grad) = grads.remove::(id) { tensor = tensor + grad.to_device(&device); } } Some((tensor, device)) } fn select( &mut self, id: ParamId, ) -> Option<(Tensor, Device, usize)> { let id_val = id.val() as usize; for i in 0..self.grads.len() { let selected_device_index = (id_val + i) % self.grads.len(); if let Some(acc) = self.grads[selected_device_index].0.remove::(id) { let device_id = self.grads[selected_device_index].1; let device = ::from_id(device_id); return Some((acc.to_device(&device), device, selected_device_index)); } } None } } /// General trait to optimize [module](AutodiffModule). pub trait Optimizer: Send + Clone where M: AutodiffModule, B: AutodiffBackend, { /// Optimizer associative type to be used when saving and loading the state. type Record: Record; /// Perform the optimizer step using the given learning rate and gradients. /// The updated module is returned. fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M; /// Perform the optimizer step using the given learning rate and gradients. /// The updated module is returned. fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M; /// Get the current state of the optimizer as a [record](Record). fn to_record(&self) -> Self::Record; /// Load the state of the optimizer as a [record](Record). fn load_record(self, record: Self::Record) -> Self; } ================================================ FILE: crates/burn-optim/src/optim/decay.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::record::Record; use burn::tensor::Tensor; use burn::tensor::backend::Backend; /// Configuration to create [weight decay](WeightDecay). #[derive(Config, Debug)] pub struct WeightDecayConfig { /// L2 penalty. pub penalty: f32, } /// State of [weight decay](WeightDecay). #[derive(Record, Clone, new)] pub struct WeightDecayState { pub(crate) grad_last_step: Tensor, } /// Weight decay implementation that transforms gradients. #[derive(Clone)] pub struct WeightDecay { penalty: f32, } impl WeightDecay { /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig). pub fn new(config: &WeightDecayConfig) -> Self { Self { penalty: config.penalty, } } /// Transforms a gradient. /// /// # Arguments /// /// * `grad` - Gradient to transform. /// * `tensor` - Tensor param of the last iteration. /// /// # Returns /// /// * `grad` - Transformed gradient. pub fn transform( &self, grad: Tensor, tensor: Tensor, ) -> Tensor { tensor.mul_scalar(self.penalty).add(grad) } } impl WeightDecayState { /// Moves the state to a device. /// /// # Arguments /// /// * `device` - Device to move the state to. /// /// # Returns /// /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.grad_last_step = self.grad_last_step.to_device(device); self } } ================================================ FILE: crates/burn-optim/src/optim/grad_accum.rs ================================================ use burn_core as burn; use core::marker::PhantomData; use burn::module::{AutodiffModule, ModuleVisitor, Param}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use super::GradientsParams; /// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object. pub struct GradientsAccumulator { grads: GradientsParams, phantom: PhantomData, } impl Default for GradientsAccumulator { fn default() -> Self { Self::new() } } impl GradientsAccumulator { /// Create a new gradients accumulator. pub fn new() -> Self { Self { grads: GradientsParams::new(), phantom: PhantomData, } } } impl GradientsAccumulator { /// Accumulate the given gradients for each parameter in the given module. pub fn accumulate(&mut self, module: &M, grads: GradientsParams) where M: AutodiffModule, { let mut visitor = ModuleGradsAccumulator::::new(&mut self.grads, grads); module.visit(&mut visitor); } /// Return the accumulated gradients and reset the accumulator state. pub fn grads(&mut self) -> GradientsParams { let mut grads = GradientsParams::new(); core::mem::swap(&mut self.grads, &mut grads); grads } } #[derive(new)] struct ModuleGradsAccumulator<'a, M> { grads: &'a mut GradientsParams, grads_new: GradientsParams, phantom: PhantomData, } impl> ModuleVisitor for ModuleGradsAccumulator<'_, M> { fn visit_float(&mut self, param: &Param>) { let grad_updated = match self.grads_new.remove::(param.id) { Some(new) => match self.grads.remove::(param.id) { Some(grad) => grad.add(new), None => new, }, None => match self.grads.remove::(param.id) { Some(grad) => grad, None => return, }, }; self.grads .register::(param.id, grad_updated); } } #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use burn::tensor::{Distribution, backend::Backend}; use burn_nn::{Linear, LinearConfig}; #[test] fn test_accumulate_gradients_one_step() { let device = Default::default(); let mut accumulator = GradientsAccumulator::new(); let layer = layer::(&device); let loss = layer.forward(random_tensor::(&device)); let grads = GradientsParams::from_grads(loss.backward(), &layer); accumulator.accumulate(&layer, grads); let grads = accumulator.grads(); assert!(!grads.is_empty()) } #[test] fn test_accumulate_gradients_two_steps() { let device = Default::default(); let mut accumulator = GradientsAccumulator::new(); let layer = layer::(&device); let loss_1 = layer.forward(random_tensor(&device)); let loss_2 = layer.forward(random_tensor(&device)); let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer); let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer); accumulator.accumulate(&layer, grads_1); accumulator.accumulate(&layer, grads_2); let grads = accumulator.grads(); assert_eq!(grads.len(), 2) } fn layer(device: &B::Device) -> Linear { LinearConfig::new(20, 20).init(device) } fn random_tensor(device: &B::Device) -> Tensor { Tensor::::random([2, 20], Distribution::Default, device) } } ================================================ FILE: crates/burn-optim/src/optim/grads.rs ================================================ use burn_core as burn; #[cfg(feature = "collective")] use burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce}; use burn::{ Tensor, tensor::{ backend::{AutodiffBackend, Backend}, container::TensorContainer, }, }; use burn::module::{AutodiffModule, ParamId}; use super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter}; /// Data type that contains gradients for parameters. #[derive(Default, Debug)] pub struct GradientsParams { container: TensorContainer, } impl GradientsParams { /// Creates a new [GradientsParams](GradientsParams). pub fn new() -> Self { Self::default() } /// Extract each tensor gradients for the given [module](AutodiffModule). /// /// Note: This consumes the gradients. See ['from_module'] to extract gradients only for /// a specific module. pub fn from_grads>( grads: B::Gradients, module: &M, ) -> Self { let mut grads = grads; Self::from_module(&mut grads, module) } /// Extract each tensor gradients for the given [module](AutodiffModule). pub fn from_module>( grads: &mut B::Gradients, module: &M, ) -> Self { let mut grads_params = GradientsParams::new(); let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params, None); module.visit(&mut visitor); grads_params } /// Extract tensor gradients for the given [module](AutodiffModule) and given parameters. pub fn from_params>( grads: &mut B::Gradients, module: &M, params: &[ParamId], ) -> Self { let mut grads_params = GradientsParams::new(); let mut visitor = GradientsParamsConverter::::new(grads, &mut grads_params, Some(params.to_vec())); module.visit(&mut visitor); grads_params } /// Get the gradients for the given [parameter id](ParamId). /// /// # Notes /// /// You should use [remove](GradientsParams::remove) if you want to get the gradients /// only one time. pub fn get(&self, id: ParamId) -> Option> where B: Backend, { self.container.get(&id).map(Tensor::from_primitive) } /// Remove the gradients for the given [parameter id](ParamId). pub fn remove(&mut self, id: ParamId) -> Option> where B: Backend, { self.container.remove(&id).map(Tensor::from_primitive) } /// Register a gradients tensor for the given [parameter id](ParamId). /// /// # Notes /// /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced. pub fn register(&mut self, id: ParamId, value: Tensor) where B: Backend, { self.container.register(id, value.into_primitive()) } /// The number of gradients tensors registered. pub fn len(&self) -> usize { self.container.len() } /// If any tensor is contained. pub fn is_empty(&self) -> bool { self.len() == 0 } /// Change the device of each tensor gradients registered for the given [module](AutodiffModule). pub fn to_device>( mut self, device: &B::Device, module: &M, ) -> Self { let mut visitor = GradientsParamsChangeDevice::::new(device, &mut self); module.visit(&mut visitor); self } /// Syncs the gradient params with the other peers in the collective. #[cfg(feature = "collective")] pub fn all_reduce( mut self, peer_id: PeerId, op: ReduceOperation, ) -> Result { let mut ids = self .container .ids() .into_iter() .copied() .collect::>(); // This is crucial, since the all-reduce operations need to happen in the same order for the same parameters on all nodes! ids.sort(); for id in ids { let Some(grad) = self.container.remove::(&id) else { todo!() }; let grad = match grad { burn::tensor::TensorPrimitive::Float(grad) => { let grad = all_reduce::(peer_id, grad, op)?; burn::tensor::TensorPrimitive::Float(grad) } burn::tensor::TensorPrimitive::QFloat(_grad) => { unimplemented!("quantized all-reduce unimplemented") } }; self.container.register::(id, grad); } Ok(self) } } #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use burn::module::{Module, list_param_ids}; use burn::tensor::{Distribution, backend::Backend}; use burn_nn::{Linear, LinearConfig}; #[test] fn test_convert_grads() { let device = Default::default(); let layer_1 = layer::(&device); let mut layer_2 = layer_1.clone(); layer_2 = layer_2.fork(&device); let loss_1 = layer_1.forward(random_tensor(&device)); let loss_2 = layer_2.forward(random_tensor(&device)); let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2); let param_ids_1 = list_param_ids(&layer_1); let param_ids_2 = list_param_ids(&layer_2); assert_eq!(param_ids_1, param_ids_2); assert_eq!(grads_1.len(), param_ids_1.len()); assert_eq!(grads_2.len(), param_ids_2.len()); } fn layer(device: &B::Device) -> Linear { LinearConfig::new(20, 20).init(device) } fn random_tensor(device: &B::Device) -> Tensor { Tensor::::random([2, 20], Distribution::Default, device) } } ================================================ FILE: crates/burn-optim/src/optim/lbfgs.rs ================================================ #![allow(clippy::excessive_precision)] use burn_core as burn; use super::GradientsParams; use crate::LearningRate; use burn::config::Config; use burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param}; use burn::prelude::ToElement; use burn::record::Record; use burn::tensor::backend::Backend; use burn::tensor::{Tensor, backend::AutodiffBackend}; use serde::{Deserialize, Serialize}; use alloc::vec; use alloc::vec::Vec; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Cubic Interpolate /// /// Uses two points (x1, f1), (x2, f2) and their first derivatives g1,g2 to construct /// a cubic interpolant and return its minimum within the given bounds. fn cubic_interpolate( x1: f64, f1: f64, g1: f64, x2: f64, f2: f64, g2: f64, bounds: Option<(f64, f64)>, ) -> f64 { // Compute bounds of interpolation area let (min_bound, max_bound) = bounds.unwrap_or(if x1 <= x2 { (x1, x2) } else { (x2, x1) }); // Code for most common case: cubic interpolation of 2 points // with function and derivative values for both // Solution in this case (where x2 is the farthest point) // d1 = g1 + g2 - 3*(f1 - f2) / (x1-x2); // d2 = sqrt(d1^2 - g1 * g2); // min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2)); // t_new = min(max(min_pos,min_bound), max_bound); let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2); let d2_square = d1 * d1 - g1 * g2; if d2_square >= 0.0 { let d2 = d2_square.sqrt(); let min_pos = if x1 <= x2 { x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2.0 * d2)) } else { x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2.0 * d2)) }; min_pos.max(min_bound).min(max_bound) } else { (min_bound + max_bound) / 2.0 } } /// Auxiliary Struct For Strong_Wolfe struct LineSearchSample { // step size t: f64, // loss f: f64, // gradient g: Tensor, // directional derivative gtd: f64, } #[allow(clippy::too_many_arguments)] fn strong_wolfe( // obj_func(x,step size,direction) -> (loss,grad) obj_func: &mut F, x: &Tensor, // initial step size mut t: f64, d: &Tensor, f: f64, g: Tensor, gtd: f64, c1: f64, c2: f64, tolerance_change: f64, max_ls: usize, ) -> (f64, Tensor, f64, usize) where F: FnMut(&Tensor, f64, &Tensor) -> (f64, Tensor), { let d_norm = d.clone().abs().max().into_scalar().to_f64(); // evaluate objective and gradient using initial step let (mut f_new, mut g_new) = obj_func(x, t, d); let mut ls_func_evals = 1; let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64(); // bracket an interval [t_prev,t] containing a point satisfying the Wolfe criteria let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd); let mut done = false; let mut ls_iter = 0; // the interval [low,high] using for Zoom phase let mut bracket: Option<[LineSearchSample; 2]> = None; // point which satisfy the wolfe condition let mut wolfe_bracket: Option> = None; while ls_iter < max_ls { // Checking Conditions. // Checking the Armijo Condition and function value increasing condition. // Armijo: f(x+t*d) <= f(x) + c_1 t gtd if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) { bracket = Some([ LineSearchSample { t: t_prev, f: f_prev, g: g_prev, gtd: gtd_prev, }, LineSearchSample { t, f: f_new, g: g_new.clone(), gtd: gtd_new, }, ]); break; } // Checking Strong Wolfe Condition // |gtd_new| <= -c_2 gtd if gtd_new.abs() <= -c2 * gtd { wolfe_bracket = Some(LineSearchSample { t, f: f_new, g: g_new.clone(), gtd: gtd_new, }); done = true; break; } // gtd_new >=0 , there must be a local minimum in the interval. if gtd_new >= 0.0 { bracket = Some([ LineSearchSample { t: t_prev, f: f_prev, g: g_prev, gtd: gtd_prev, }, LineSearchSample { t, f: f_new, g: g_new.clone(), gtd: gtd_new, }, ]); break; } // interpolate let min_step = t + 0.01 * (t - t_prev); let max_step = t * 10.0; let t_next = cubic_interpolate( t_prev, f_prev, gtd_prev, t, f_new, gtd_new, Some((min_step, max_step)), ); t_prev = t; f_prev = f_new; g_prev = g_new; gtd_prev = gtd_new; // next step t = t_next; (f_new, g_new) = obj_func(x, t, d); ls_func_evals += 1; gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64(); ls_iter += 1; } if let Some(sample) = wolfe_bracket { return (sample.f, sample.g, sample.t, ls_func_evals); } let mut bracket = bracket.unwrap_or_else(|| { [ LineSearchSample { t: 0.0, f, g: g.clone(), gtd, }, LineSearchSample { t, f: f_new, g: g_new.clone(), gtd: gtd_new, }, ] }); // zoom phase let mut insuf_progress = false; // find high and low points in bracket let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f { (0, 1) } else { (1, 0) }; while !done && ls_iter < max_ls { let diff = (bracket[1].t - bracket[0].t).abs(); // line-search bracket is so small if diff * d_norm < tolerance_change { break; } // compute new trial value t = cubic_interpolate( bracket[0].t, bracket[0].f, bracket[0].gtd, bracket[1].t, bracket[1].f, bracket[1].gtd, None, ); let b_min = bracket[0].t.min(bracket[1].t); let b_max = bracket[0].t.max(bracket[1].t); let eps = 0.1 * (b_max - b_min); if (b_max - t).min(t - b_min) < eps { // interpolation close to boundary if insuf_progress || t >= b_max || t <= b_min { t = if (t - b_max).abs() < (t - b_min).abs() { b_max - eps } else { b_min + eps }; insuf_progress = false; } else { insuf_progress = true; } } else { insuf_progress = false; } // Evaluate new point (f_new, g_new) = obj_func(x, t, d); ls_func_evals += 1; gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64(); ls_iter += 1; let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f; if !armijo_holds { bracket[high_idx] = LineSearchSample { t, f: f_new, g: g_new, gtd: gtd_new, }; } else { if gtd_new.abs() <= -c2 * gtd { return (f_new, g_new, t, ls_func_evals); } if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 { bracket[high_idx] = LineSearchSample { t: bracket[low_idx].t, f: bracket[low_idx].f, g: bracket[low_idx].g.clone(), gtd: bracket[low_idx].gtd, }; } bracket[low_idx] = LineSearchSample { t, f: f_new, g: g_new, gtd: gtd_new, }; } if bracket[0].f <= bracket[1].f { low_idx = 0; high_idx = 1; } else { low_idx = 1; high_idx = 0; } } // return stuff ( bracket[low_idx].f, bracket[low_idx].g.clone(), bracket[low_idx].t, ls_func_evals, ) } /// Strategy for the line search optimization phase #[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum LineSearchFn { /// No line search performed #[default] None, /// strong wolfe conditions /// /// See: StrongWolfe, } /// LBFGS Configuration. #[derive(Config, Debug)] pub struct LBFGSConfig { /// Maximal number of iterations per optimization step (default: 20) #[config(default = 20)] pub max_iter: usize, /// Update history size (default: 100). #[config(default = 100)] pub history_size: usize, /// Termination tolerance on first order optimality (default: 1e-7). #[config(default = 1e-7)] pub tolerance_grad: f64, /// Termination tolerance on function value/parameter changes (default: 1e-9). #[config(default = 1e-9)] pub tolerance_change: f64, /// Maximal number of function evaluations per optimization step (default: max_iter * 1.25). #[config(default = "None")] pub max_eval: Option, /// Either ‘strong_wolfe’ or None (default: None). #[config(default = "LineSearchFn::None")] pub line_search_fn: LineSearchFn, } impl LBFGSConfig { /// Initialize AdamW optimizer /// /// # Returns /// /// Returns an optimizer that can be used to optimize a module pub fn init(&self) -> LBFGS { // by default max_eval = max_iter * 5/4 let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4); LBFGS { config: LBFGSConfig { max_iter: self.max_iter, history_size: self.history_size, tolerance_grad: self.tolerance_grad, tolerance_change: self.tolerance_change, max_eval: Some(max_eval), line_search_fn: self.line_search_fn, }, state: Default::default(), } } } /// Collects gradients in module visit order. struct FlattenGradsVisitorInner<'a, B: AutodiffBackend> { grads: &'a GradientsParams, tensors: &'a mut Vec>, } impl ModuleVisitor for FlattenGradsVisitorInner<'_, B> { fn visit_float(&mut self, param: &Param>) { if let Some(g) = self.grads.get::(param.id) { let numel = g.shape().num_elements(); self.tensors.push(g.reshape([numel])); } } } /// Flatten params to inner backend 1D tensor. fn flatten_params_inner>( module: &M, ) -> Tensor { let mut tensors = Vec::new(); let mut visitor = FlattenParamsVisitorInner:: { tensors: &mut tensors, }; module.visit(&mut visitor); if tensors.is_empty() { return Tensor::empty([0], &module.devices()[0]); } Tensor::cat(tensors, 0) } struct FlattenParamsVisitorInner<'a, B: AutodiffBackend> { tensors: &'a mut Vec>, } impl ModuleVisitor for FlattenParamsVisitorInner<'_, B> { fn visit_float(&mut self, param: &Param>) { let t = param.val().inner(); let numel = t.shape().num_elements(); self.tensors.push(t.reshape([numel])); } } /// Flatten gradients for a module. fn flatten_grads_inner>( module: &M, grads: &GradientsParams, ) -> Tensor { let mut tensors = Vec::new(); let mut visitor = FlattenGradsVisitorInner { grads, tensors: &mut tensors, }; module.visit(&mut visitor); if tensors.is_empty() { return Tensor::empty([0], &module.devices()[0]); } Tensor::cat(tensors, 0) } /// Mapper that assigns each float param from a flat inner-backend 1D tensor. struct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> { flat: &'a Tensor, offset: &'a mut usize, } impl ParamsFromFlatMapperInner<'_, B> { fn take_slice(&mut self, numel: usize) -> Tensor { let start = *self.offset; *self.offset += numel; self.flat.clone().slice(start..*self.offset) } } impl ModuleMapper for ParamsFromFlatMapperInner<'_, B> { fn map_float(&mut self, param: Param>) -> Param> { let (id, tensor, mapper) = param.consume(); let numel = tensor.shape().num_elements(); let slice_1d = self.take_slice(numel); let new_inner = slice_1d.reshape(tensor.shape()); let new_tensor = Tensor::from_inner(new_inner).require_grad(); Param::from_mapped_value(id, new_tensor, mapper) } } /// Overwrite module parameters from a flat inner-backend 1D tensor fn set_params_from_flat_inner>( module: M, flat: Tensor, ) -> M { let mut offset = 0; let mut mapper = ParamsFromFlatMapperInner { flat: &flat, offset: &mut offset, }; module.map(&mut mapper) } /// L-BFGS optimizer state #[derive(Clone, Record)] pub struct LBFGSState { /// Historical displacement vectors pub history_s: Vec>, /// Historical gradient difference vectors pub history_y: Vec>, /// Search direction pub d: Option>, /// Step size from the previous iteration pub t: Option, /// Flattened gradient from the previous iteration pub prev_flat_grad: Option>, /// Loss value from the previous iteration pub prev_loss: Option, /// Global iteration count pub g_iter: usize, } impl LBFGSState { /// Moves all historical tensors to the target device. pub fn to_device(self, device: &B::Device) -> Self { Self { history_s: self .history_s .into_iter() .map(|t| t.to_device(device)) .collect(), history_y: self .history_y .into_iter() .map(|t| t.to_device(device)) .collect(), d: self.d.map(|t| t.to_device(device)), t: self.t, prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)), prev_loss: self.prev_loss, g_iter: self.g_iter, } } } impl Default for LBFGSState { fn default() -> Self { Self { history_s: Vec::new(), history_y: Vec::new(), d: None, t: Some(1.0), prev_flat_grad: None, prev_loss: None, g_iter: 0, } } } /// L-BFGS optimizer. /// /// Ported from [pytorch](https://github.com/pytorch/pytorch/torch/optim/lbfgs.py). Heavily inspired by [miniFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html) /// /// See also: /// - [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) /// /// # Note /// This optimizer is memory intensive #[derive(Clone)] pub struct LBFGS { config: LBFGSConfig, state: LBFGSState, } impl LBFGS { /// A single optimization step for any tensor that represents the parameters of a model. pub fn step(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64) where M: AutodiffModule + Clone, F: FnMut(M) -> (f64, GradientsParams), { // evaluate initial f(x) and df/dx let (mut loss, grads) = closure(module.clone()); let mut current_evals = 1; let mut flat_grad = flatten_grads_inner::(&module, &grads); let mut x_flat = flatten_params_inner::(&module); let opt_cond = flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad; // optimal condition if opt_cond { return (module, loss); } // tensors cached in state let mut d = self .state .d .take() .unwrap_or_else(|| flat_grad.clone().neg()); let mut t = self.state.t.unwrap_or(lr); let mut prev_flat_grad = self.state.prev_flat_grad.take(); let mut n_iter = 0; // optimize for a max of max_iter iterations while n_iter < self.config.max_iter { // keep track of nb of iterations n_iter += 1; self.state.g_iter += 1; // compute gradient descent direction if self.state.g_iter == 1 { d = flat_grad.clone().neg(); self.state.history_s.clear(); self.state.history_y.clear(); } else { // do lbfgs update (update memory) if let Some(pg) = prev_flat_grad.as_ref() { let y = flat_grad.clone().sub(pg.clone()); let s = d.clone().mul_scalar(t); let ys = y.clone().dot(s.clone()).into_scalar().to_f64(); if ys > 1e-10 { // updating memory if self.state.history_s.len() >= self.config.history_size { // shift history by one (limited-memory) self.state.history_s.remove(0); self.state.history_y.remove(0); } self.state.history_s.push(s); self.state.history_y.push(y); } } // compute the approximate (L-BFGS) inverse Hessian // multiplied by the gradient let num_old = self.state.history_s.len(); let mut q = flat_grad.clone().neg(); let mut alphas: Vec> = vec![Tensor::zeros([1], &flat_grad.device()); num_old]; if num_old > 0 { // multiply by initial Hessian // r/d is the final direction for i in (0..num_old).rev() { let s = &self.state.history_s[i]; let y = &self.state.history_y[i]; let rho = y.clone().dot(s.clone()).powf_scalar(-1.0); let alpha = rho.clone().mul(s.clone().dot(q.clone())); alphas[i] = alpha.clone(); q = q.sub(y.clone().mul(alpha)); } let last_s = &self.state.history_s[num_old - 1]; let last_y = &self.state.history_y[num_old - 1]; let ys = last_y.clone().dot(last_s.clone()); let yy = last_y.clone().dot(last_y.clone()); let h_diag = ys.div(yy); let mut r = q.mul(h_diag); for ((s, y), alpha) in self .state .history_s .iter() .zip(self.state.history_y.iter()) .zip(alphas.into_iter()) .take(num_old) { let rho = y.clone().dot(s.clone()).powf_scalar(-1.0); let beta = rho.mul(y.clone().dot(r.clone())); r = r.add(s.clone().mul(alpha.sub(beta))); } d = r; } else { d = q; } } prev_flat_grad = Some(flat_grad.clone()); let prev_loss_iter = loss; // compute step len if self.state.g_iter == 1 { let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64(); t = (1.0f64 / grad_l1).min(1.0) * lr; } else { t = lr; } // directional derivative let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64(); if gtd > -self.config.tolerance_change { break; } let ls_func_evals; if let LineSearchFn::StrongWolfe = self.config.line_search_fn { // perform line search, using user function let mut obj_func = |current_x: &Tensor, step: f64, dir: &Tensor| { let update = dir.clone().mul_scalar(step); let new_x = current_x.clone().add(update); let tmp_module = set_params_from_flat_inner::(module.clone(), new_x); let (l, g) = closure(tmp_module); (l, flatten_grads_inner::(&module, &g)) }; let (ls_f, ls_g, ls_t, evals) = strong_wolfe( &mut obj_func, &x_flat, t, &d, loss, flat_grad.clone(), gtd, 1e-4, 0.9, self.config.tolerance_change, self.config.max_eval.unwrap() - current_evals, ); loss = ls_f; flat_grad = ls_g; t = ls_t; ls_func_evals = evals; x_flat = x_flat.add(d.clone().mul_scalar(t)); module = set_params_from_flat_inner::(module, x_flat.clone()); } else { // no line search, simply move with fixed-step let step_vec = d.clone().mul_scalar(t); x_flat = x_flat.add(step_vec); module = set_params_from_flat_inner::(module, x_flat.clone()); // re-evaluate function only if not in last iteration // the reason we do this: in a stochastic setting, // no use to re-evaluate that function here let (new_loss, new_grads) = closure(module.clone()); loss = new_loss; flat_grad = flatten_grads_inner::(&module, &new_grads); ls_func_evals = 1; } // update func eval current_evals += ls_func_evals; // check conditions if current_evals >= self.config.max_eval.unwrap() { break; } if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad { break; } if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64() <= self.config.tolerance_change { break; } if (loss - prev_loss_iter).abs() < self.config.tolerance_change { break; } } self.state.d = Some(d); self.state.t = Some(t); self.state.prev_flat_grad = prev_flat_grad; self.state.prev_loss = Some(loss); (module, loss) } /// Moves the optimizer state to the specified device. pub fn to_device(self, device: &B::Device) -> Self { Self { config: self.config, // History tensors reside in InnerBackend, so we convert the device accordingly state: self.state.to_device(device), } } } #[cfg(test)] mod tests { use super::*; use crate::GradientsParams; use crate::TestAutodiffBackend; use burn::module::{Module, Param}; use burn::tensor::{Tensor, TensorData}; use burn_nn::{Linear, LinearConfig, LinearRecord}; fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: Some(Param::from_data(bias, &device)), }; LinearConfig::new(6, 6).init(&device).load_record(record) } #[test] fn test_cubic_interpolate() { let tolerance = 1e-8; // basic let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0); let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None); assert!( (result - 0.00000).abs() < tolerance, "Basic: Result {} should be close to 0.0", result ); // bound let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0); let bounds = Some((0.6, 1.0)); let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds); assert!( (result - 0.6000000000).abs() < tolerance, "Bound: Result {} should be clamped to 0.6", result ); // d2_square < 0,should return mid value let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0); let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0))); assert!( (result - 0.5000000).abs() < tolerance, "Fallback: Result {} should be midpoint 0.5", result ); // asymmetric let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0); let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None); assert!( (result - 0.4606553370833684).abs() < tolerance, "Asymmetric: Result {} should be 0.4606553370833684", result ); // not good value let (x1, f1, g1, x2, f2, g2) = ( 1.231232145, -0.12567458754, 9.1231243007, 8.239105015, -100.9012398021, 123201321.0293982, ); let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None); let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4))); assert!( (result_1 - 5.9031480234724434).abs() < tolerance, "not good value 1: Result {} should be 5.9031480234724434", result ); assert!( (result_2 - 4.4000000000000004).abs() < tolerance, "not good value 2: Result {} should be 4.4000000000000004", result ); } #[test] fn test_strong_wolfe_direct_comparison() { let device = Default::default(); let tol = 1e-8; { let x = Tensor::::from_floats([2.1321912957_f64], &device); let d = Tensor::::from_floats([0.91312321_f64], &device); let t_initial = 1.213132_f64; fn func( x_base: &Tensor, t_val: f64, d_vec: &Tensor, ) -> (f64, Tensor) { let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val)); let x2 = curr_x.clone().mul(curr_x.clone()); let x3 = x2.clone().mul(curr_x.clone()); let x4 = x2.clone().mul(x2.clone()); // f(x) = x^4 - 2*x^2 + x let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone(); let f_val = f_elements.sum().into_scalar().to_f64(); // g(x) = 4*x^3 - 4*x + 1 let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0) + Tensor::ones_like(&curr_x); (f_val, g) } let (f_init, g_init) = func(&x, 0.0, &d); let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64(); println!("Initial State: f={},gtd = {}", f_init, gtd_init); assert!((f_init - 13.7080059052).abs() < tol); assert!((gtd_init - 28.5305728912).abs() < tol); let mut obj_func = |xb: &Tensor, tv: f64, dv: &Tensor| func(xb, tv, dv); let (f_final, _g_final, t_final, evals) = strong_wolfe( &mut obj_func, &x, t_initial, &d, f_init, g_init, gtd_init, 1e-4, // c1 0.9, // c2 1e-9, // tolerance_change 10, // max_ls ); let g_f = _g_final.into_scalar().to_f64(); println!( "f_final:{:?},_g_final:{:?},t_final:{:?},evals:{:?}", f_final, g_f, t_final, evals ); assert!((f_final - 13.708005905151367).abs() < tol); assert!((g_f - 31.2450428009).abs() < tol); assert!((t_final.to_f64() - 0.0).abs() < tol); assert!((evals == 11)); } } #[test] fn test_lbfgs_strong_wolfe_comparison() { let device = Default::default(); let tol = 1e-5; let x_data = Tensor::::from_data([[1.0], [2.0], [3.0]], &device); let y_true = Tensor::::from_data([[3.0], [5.0], [7.0]], &device); let weight = TensorData::from([[0.5f64]]); let bias = TensorData::from([0.1f64]); let module = given_linear_layer(weight, bias); let mut optimizer: LBFGS = LBFGSConfig::new() .with_line_search_fn(LineSearchFn::StrongWolfe) .init(); let mut closure = |mod_in: Linear| { let output = mod_in.forward(x_data.clone()); let loss = burn_nn::loss::MseLoss::new().forward( output, y_true.clone(), burn_nn::loss::Reduction::Sum, ); let grads = loss.backward(); let grads_params = GradientsParams::from_grads(grads, &mod_in); (loss.into_scalar().to_f64(), grads_params) }; let initial_loss = closure(module.clone()).0; assert!((initial_loss - 50.1300048828).abs() < tol); let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure); assert!((final_loss - 0.0234732367).abs() < tol); let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64(); let optimized_bias: f64 = updated_module .bias .as_ref() .unwrap() .val() .into_scalar() .to_f64(); assert!((optimized_data - 2.0570652485).abs() < tol); assert!((optimized_bias - 0.8106800914).abs() < tol); } #[test] fn test_lbfgs_no_strong_wolfe_comparison() { let device = Default::default(); let tol = 1e-5; let x_data = Tensor::::from_data([[1.0], [2.0], [3.0]], &device); let y_true = Tensor::::from_data([[3.0], [5.0], [7.0]], &device); let weight = TensorData::from([[0.5f64]]); let bias = TensorData::from([0.1f64]); let module = given_linear_layer(weight, bias); let mut optimizer: LBFGS = LBFGSConfig::new() .with_line_search_fn(LineSearchFn::None) .init(); let mut closure = |mod_in: Linear| { let output = mod_in.forward(x_data.clone()); let loss = burn_nn::loss::MseLoss::new().forward( output, y_true.clone(), burn_nn::loss::Reduction::Sum, ); let grads = loss.backward(); let grads_params = GradientsParams::from_grads(grads, &mod_in); (loss.into_scalar().to_f64(), grads_params) }; let initial_loss = closure(module.clone()).0; assert!((initial_loss - 50.1300048828).abs() < tol); let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure); assert!((final_loss - 48.2181930542).abs() < tol); let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64(); let optimized_bias: f64 = updated_module .bias .as_ref() .unwrap() .val() .into_scalar() .to_f64(); assert!((optimized_data - 0.5302446192).abs() < tol); assert!((optimized_bias - 0.1142520783).abs() < tol); } } ================================================ FILE: crates/burn-optim/src/optim/mod.rs ================================================ /// Weight decay module for optimizers. pub mod decay; /// Momentum module for optimizers. pub mod momentum; mod adagrad; mod adam; mod adamw; mod base; mod grad_accum; mod grads; mod lbfgs; mod muon; mod rmsprop; mod sgd; mod simple; mod visitor; pub use adagrad::*; pub use adam::*; pub use adamw::*; pub use base::*; pub use grad_accum::*; pub use grads::*; pub use lbfgs::*; pub use muon::*; pub use rmsprop::*; pub use sgd::*; pub use simple::*; ================================================ FILE: crates/burn-optim/src/optim/momentum.rs ================================================ use burn_core as burn; use burn::config::Config; use burn::record::Record; use burn::tensor::backend::Backend; use burn::tensor::{ElementConversion, Tensor}; /// Configuration to create [momentum](Momentum). #[derive(Config, Debug)] pub struct MomentumConfig { /// Momentum factor #[config(default = 0.9)] pub momentum: f64, /// Dampening factor. #[config(default = 0.1)] pub dampening: f64, /// Enables Nesterov momentum, see [On the importance of initialization and /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). #[config(default = false)] pub nesterov: bool, } /// State of [momentum](Momentum). #[derive(Record, Clone, new)] pub struct MomentumState { velocity: Tensor, } /// Momentum implementation that transforms gradients. #[derive(Clone)] pub struct Momentum { momentum: B::FloatElem, dampening: f64, nesterov: bool, } impl Momentum { /// Creates a new [momentum](Momentum) from a [config](MomentumConfig). pub fn new(config: &MomentumConfig) -> Self { Self { momentum: config.momentum.elem(), dampening: config.dampening, nesterov: config.nesterov, } } /// Transforms a gradient. /// /// # Arguments /// /// * `grad` - Gradient to transform. /// * `state` - State of the optimizer. /// /// # Returns /// /// * `grad` - Transformed gradient. /// * `state` - State of the optimizer. pub fn transform( &self, grad: Tensor, state: Option>, ) -> (Tensor, MomentumState) { let velocity = if let Some(state) = state { grad.clone() .mul_scalar(1.0 - self.dampening) .add(state.velocity.mul_scalar(self.momentum)) } else { grad.clone() }; let grad = match self.nesterov { true => velocity.clone().mul_scalar(self.momentum).add(grad), false => velocity.clone(), }; (grad, MomentumState::new(velocity)) } } impl MomentumState { /// Moves the state to a device. /// /// # Arguments /// /// * `device` - Device to move the state to. /// /// # Returns /// /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.velocity = self.velocity.to_device(device); self } } ================================================ FILE: crates/burn-optim/src/optim/muon.rs ================================================ use burn_core as burn; use burn::{module::AutodiffModule, record::Record}; use burn::config::Config; use burn::tensor::{Tensor, backend::AutodiffBackend}; use burn::tensor::{backend::Backend, ops::Device}; use serde::{Deserialize, Serialize}; use super::{ SimpleOptimizer, adaptor::OptimizerAdaptor, decay::WeightDecayConfig, momentum::{Momentum, MomentumConfig, MomentumState}, }; use crate::LearningRate; #[cfg(not(feature = "std"))] #[allow(unused_imports)] use num_traits::Float as _; /// Learning rate adjustment method for Muon optimizer. /// /// Muon adjusts the learning rate based on parameter shape to maintain consistent /// RMS across rectangular matrices. /// /// # References /// /// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/) /// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982) #[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum AdjustLrFn { /// Keller Jordan's original method: `lr * sqrt(max(1, A/B))` /// /// This scales the learning rate based on the aspect ratio of the weight matrix, /// ensuring that tall matrices (more rows than columns) get proportionally larger /// learning rates. /// /// # Example /// /// For a [1024, 512] matrix: `lr * sqrt(1024/512) = lr * 1.414` #[default] Original, /// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))` /// /// This method is designed to match AdamW's RMS, allowing Muon to directly reuse /// learning rates and weight decay values tuned for AdamW without retuning. /// /// # Example /// /// For a [1024, 512] matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4` MatchRmsAdamW, } impl AdjustLrFn { /// Calculate the learning rate adjustment ratio for a given parameter shape. /// /// # Arguments /// /// * `shape` - Parameter shape (uses first two dimensions) /// /// # Returns /// /// Adjustment ratio to multiply with the base learning rate fn adjustment_ratio(&self, shape: &[usize]) -> f64 { if shape.len() < 2 { return 1.0; } let a = shape[0] as f64; let b = shape[1] as f64; match self { Self::Original => { // sqrt(max(1, A/B)) let ratio = a / b; ratio.max(1.0).sqrt() } Self::MatchRmsAdamW => { // 0.2 * sqrt(max(A, B)) 0.2 * a.max(b).sqrt() } } } } /// Muon configuration. /// /// Muon is an optimizer specifically designed for 2D parameters of neural network /// hidden layers (weight matrices). Other parameters such as biases and embeddings /// should be optimized using a standard method such as AdamW. /// /// # Learning Rate Adjustment /// /// Muon adjusts the learning rate based on parameter shape to maintain consistent /// RMS across rectangular matrices. Two methods are available: /// /// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions. /// This is Keller Jordan's method and is the default. /// /// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method /// designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters. /// /// # Example /// /// ```ignore /// use burn_optim::{MuonConfig, AdjustLrFn}; /// /// // Using default (Original) method /// let optimizer = MuonConfig::new().init(); /// /// // Using MatchRmsAdamW for AdamW-compatible hyperparameters /// let optimizer = MuonConfig::new() /// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW) /// .init(); /// ``` /// /// # References /// /// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/) /// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982) /// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py) /// - [Original Implementation](https://github.com/KellerJordan/Muon) #[derive(Config, Debug)] pub struct MuonConfig { /// [Weight decay](WeightDecayConfig) config. weight_decay: Option, /// [Momentum](MomentumConfig) config. /// /// Muon always uses momentum. Default configuration: /// - momentum: 0.95 /// - dampening: 0.0 /// - nesterov: true #[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")] momentum: MomentumConfig, /// Newton-Schulz iteration coefficients (a, b, c). /// /// These coefficients are selected to maximize the slope at zero for the /// quintic iteration. Default values are from Keller Jordan's implementation. #[config(default = "(3.4445, -4.775, 2.0315)")] ns_coefficients: (f32, f32, f32), /// Epsilon for numerical stability. #[config(default = 1e-7)] epsilon: f32, /// Number of Newton-Schulz iteration steps. #[config(default = 5)] ns_steps: usize, /// Learning rate adjustment method. /// /// Controls how the learning rate is adjusted based on parameter shape. /// See [`AdjustLrFn`] for available methods. #[config(default = "AdjustLrFn::Original")] adjust_lr_fn: AdjustLrFn, } impl MuonConfig { /// Initialize Muon optimizer. /// /// # Returns /// /// Returns an optimizer adaptor that can be used to optimize a module. /// /// # Example /// /// ```ignore /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig}; /// /// // Basic configuration with default (Original) LR adjustment /// let optimizer = MuonConfig::new() /// .with_weight_decay(Some(WeightDecayConfig::new(0.01))) /// .init(); /// /// // With AdamW-compatible settings using MatchRmsAdamW /// let optimizer = MuonConfig::new() /// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW) /// .with_weight_decay(Some(WeightDecayConfig::new(0.1))) /// .init(); /// /// // Custom momentum and NS settings /// let optimizer = MuonConfig::new() /// .with_momentum(MomentumConfig { /// momentum: 0.9, /// dampening: 0.1, /// nesterov: false, /// }) /// .with_ns_steps(7) /// .init(); /// ``` pub fn init>( &self, ) -> OptimizerAdaptor, M, B> { let momentum = Momentum::new(&self.momentum); let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty); let optim = Muon { momentum, ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps), weight_decay_penalty, epsilon: self.epsilon, adjust_lr_fn: self.adjust_lr_fn, }; OptimizerAdaptor::from(optim) } } /// Parameters for Newton-Schulz orthogonalization. #[derive(Clone, Copy)] struct NewtonSchulzParams { a: f32, b: f32, c: f32, steps: usize, } impl NewtonSchulzParams { fn new(coefficients: (f32, f32, f32), steps: usize) -> Self { Self { a: coefficients.0, b: coefficients.1, c: coefficients.2, steps, } } } /// Muon optimizer. /// /// Muon internally runs standard SGD-momentum, and then performs an orthogonalization /// post-processing step, in which each 2D parameter's update is replaced with the /// nearest orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz /// iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. /// /// # Important Notes /// /// 1. **Only for 2D+ parameters**: Muon is designed for weight matrices. Use AdamW /// or SGD for biases, embeddings, and layer norms. /// /// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based /// on parameter shape. See [`AdjustLrFn`] for details. /// /// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay /// AFTER orthogonalization but uses the original (unadjusted) learning rate for it. #[derive(Clone)] pub struct Muon { momentum: Momentum, ns_params: NewtonSchulzParams, weight_decay_penalty: Option, epsilon: f32, adjust_lr_fn: AdjustLrFn, } impl Muon { /// Adjust learning rate based on parameter shape. /// /// # Arguments /// /// * `lr` - Base learning rate /// * `shape` - Parameter shape (uses first two dimensions) /// /// # Returns /// /// Adjusted learning rate /// /// ```ignore /// // For a [1024, 512] weight matrix with lr=0.01: /// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414 /// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064 /// ``` fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate { lr * self.adjust_lr_fn.adjustment_ratio(shape) } /// Perform Newton-Schulz orthogonalization on a gradient tensor. /// /// This computes the zeroth power (orthogonalization) of the input matrix G /// using a quintic Newton-Schulz iteration. /// /// # Algorithm /// /// 1. Transpose if tall matrix (A > B) /// 2. Normalize: X = X / ||X|| /// 3. For k steps: /// - A = X @ X^T /// - B = b*A + c*A^2 /// - X = a*X + B@X /// 4. Transpose back if needed /// /// # References /// /// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py /// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py fn zeropower_via_newtonschulz(&self, g: Tensor) -> Tensor { let shape = g.shape(); let dim_m2 = shape[D - 2]; let dim_m1 = shape[D - 1]; // Step 1: Transpose if tall matrix (more rows than columns) let (mut x, needs_transpose) = if dim_m2 > dim_m1 { (g.swap_dims(D - 2, D - 1), true) } else { (g, false) }; // Step 2: Normalize by Frobenius norm // X = X / (||X|| + epsilon) let norm = x .clone() .powf_scalar(2.0) .sum() .sqrt() .clamp_min(self.epsilon) .unsqueeze(); x = x.div(norm); // Step 3: Newton-Schulz iteration // This is the quintic iteration with coefficients (a, b, c) let NewtonSchulzParams { a, b, c, steps } = self.ns_params; for _ in 0..steps { // A = X @ X^T let x_t = x.clone().swap_dims(D - 2, D - 1); let a_matrix = x.clone().matmul(x_t); // B = b*A + c*A@A let a_squared = a_matrix.clone().matmul(a_matrix.clone()); let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c)); // X = a*X + B@X x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone())); } // Step 4: Restore transpose if it was a tall matrix if needs_transpose { x = x.swap_dims(D - 2, D - 1); } x } } /// Muon state. #[derive(Record, Clone, new)] pub struct MuonState { /// Current momentum state pub momentum: MomentumState, } impl SimpleOptimizer for Muon { type State = MuonState; /// Perform a single Muon optimization step. /// /// # Algorithm /// /// 1. Apply momentum to gradient /// 2. Orthogonalize update via Newton-Schulz /// 3. Adjust learning rate based on parameter shape /// 4. Apply weight decay (using original lr) /// 5. Update parameter (using adjusted lr) /// /// # Notes /// /// Unlike typical optimizers, the weight decay and parameter update use /// different learning rates: /// - Weight decay uses the original `lr` /// - Parameter update uses the shape-adjusted `lr` /// /// # Panics /// This function will panic if the input tensors are not 2D. fn step( &self, lr: LearningRate, tensor: Tensor, grad: Tensor, state: Option>, ) -> (Tensor, Option>) { assert!( D == 2, "Newton-Schulz iteration requires 2D tensors, got {}D", D ); // Step 1: Apply momentum let state_momentum = state.map(|s| s.momentum); let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum); // Step 2: Orthogonalize via Newton-Schulz let update = self.zeropower_via_newtonschulz(grad); // Step 3: Adjust learning rate based on parameter shape let adjusted_lr = self.adjust_lr(lr, &tensor.shape()); // Step 4: Apply weight decay (using ORIGINAL lr, not adjusted) // Muon applies weight decay AFTER orthogonalization let tensor = if let Some(penalty) = self.weight_decay_penalty { let decay_factor = 1.0 - lr * penalty as f64; tensor.mul_scalar(decay_factor) } else { tensor }; // Step 5: Update parameter (using ADJUSTED lr) let delta = update.mul_scalar(adjusted_lr); let new_state = MuonState::new(new_momentum_state); (tensor - delta, Some(new_state)) } fn to_device(mut state: Self::State, device: &Device) -> Self::State { state.momentum = state.momentum.to_device(device); state } } #[cfg(test)] mod tests { use super::*; use crate::TestAutodiffBackend; use crate::{GradientsParams, Optimizer}; use burn::module::{Module, Param}; use burn::tensor::{Distribution, Tensor, TensorData}; use burn_nn::{Linear, LinearConfig, LinearRecord}; type TestBackend = burn_ndarray::NdArray; const TOLERANCE: f64 = 1e-8; fn given_linear_layer_no_bias(weight: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: None, //No bias for Muon optimizer }; LinearConfig::new(4, 4) .with_bias(false) .init(&device) .load_record(record) } #[test] fn test_adjust_lr_fn_original() { let method = AdjustLrFn::Original; // Square matrix [512, 512] -> sqrt(1) = 1.0 let ratio = method.adjustment_ratio(&[512, 512]); assert!((ratio - 1.0).abs() < TOLERANCE); // Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414 let ratio = method.adjustment_ratio(&[1024, 512]); let expected = (2.0f64).sqrt(); assert!((ratio - expected).abs() < TOLERANCE); // Wide matrix [512, 1024] -> max(1, 0.5) = 1.0 let ratio = method.adjustment_ratio(&[512, 1024]); assert!((ratio - 1.0).abs() < TOLERANCE); } #[test] fn test_adjust_lr_fn_match_rms_adamw() { let method = AdjustLrFn::MatchRmsAdamW; // [1024, 512] -> 0.2 * sqrt(1024) = 6.4 let ratio = method.adjustment_ratio(&[1024, 512]); let expected = 0.2 * 1024.0f64.sqrt(); assert!((ratio - expected).abs() < TOLERANCE); // [512, 512] -> 0.2 * sqrt(512) ≈ 4.525 let ratio = method.adjustment_ratio(&[512, 512]); let expected = 0.2 * 512.0f64.sqrt(); assert!((ratio - expected).abs() < TOLERANCE); } #[test] #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")] fn test_1d_tensor_panics() { let device = Default::default(); let config = MuonConfig::new(); let optim: Muon = Muon { momentum: Momentum::new(&config.momentum), ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps), weight_decay_penalty: None, epsilon: config.epsilon, adjust_lr_fn: config.adjust_lr_fn, }; let tensor_1d = Tensor::::zeros([512], &device); let grad_1d = Tensor::::ones([512], &device); let _ = optim.step(0.01, tensor_1d, grad_1d, None); } #[test] fn test_muon_optimizer_save_load_state() { let device = Default::default(); // Use Linear layer WITHOUT bias for Muon optimizer let linear = LinearConfig::new(6, 6) .with_bias(false) // No bias - only 2D weight matrix .init::(&device); let x = Tensor::::random([2, 6], Distribution::Default, &device); let mut optimizer = MuonConfig::new().init::>(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let _linear = optimizer.step(0.01, linear, grads); let state_before = optimizer.to_record(); let state_before_copy = optimizer.to_record(); let optimizer_new = MuonConfig::new().init::>(); let optimizer_loaded = optimizer_new.load_record(state_before_copy); let state_after = optimizer_loaded.to_record(); assert_eq!(state_before.len(), state_after.len()); } #[test] fn test_muon_with_weight_decay() { let device = Default::default(); // Create Linear layer WITHOUT bias for Muon let linear = given_linear_layer_no_bias(TensorData::from([ [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], ])); let x = Tensor::::from_floats( [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]], &device, ) .require_grad(); let mut optimizer = MuonConfig::new() .with_weight_decay(Some(WeightDecayConfig::new(0.01))) .init::>(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(0.01, linear, grads); let state = linear.into_record(); let weight = state.weight.to_data(); for val in weight.as_slice::().unwrap() { assert!( *val < 1.0, "Weight should be reduced by weight decay, got {}", val ); } } #[test] fn test_newton_schulz_orthogonalization() { let device = Default::default(); let matrix = Tensor::::from_floats([[1.0, 0.5], [0.5, 1.0]], &device); let config = MuonConfig::new(); let muon: Muon = Muon { momentum: Momentum::new(&config.momentum), ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps), weight_decay_penalty: None, epsilon: config.epsilon, adjust_lr_fn: config.adjust_lr_fn, }; let orthogonalized = muon.zeropower_via_newtonschulz(matrix); let o_t = orthogonalized.clone().transpose(); let product = orthogonalized.matmul(o_t); let data = product.into_data(); let values = data.as_slice::().unwrap(); assert!( (values[0] - 1.0).abs() < 0.1, "Product[0,0] should be ~1.0, got {}", values[0] ); assert!( (values[3] - 1.0).abs() < 0.1, "Product[1,1] should be ~1.0, got {}", values[3] ); } #[test] fn test_tall_matrix_transpose() { // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration // and then transposed back let device = Default::default(); // Create a tall matrix: [8, 4] (more rows than columns) let tall_matrix = Tensor::::from_floats( [ [1.0, 0.5, 0.3, 0.2], [0.5, 1.0, 0.4, 0.1], [0.3, 0.4, 1.0, 0.5], [0.2, 0.1, 0.5, 1.0], [0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.2, 0.4, 0.1, 0.3], [0.3, 0.1, 0.4, 0.2], ], &device, ); let config = MuonConfig::new(); let muon: Muon = Muon { momentum: Momentum::new(&config.momentum), ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps), weight_decay_penalty: None, epsilon: config.epsilon, adjust_lr_fn: config.adjust_lr_fn, }; // Perform Newton-Schulz orthogonalization let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone()); // Verify shape is preserved (should be transposed internally but returned in original shape) let original_shape = tall_matrix.shape(); let result_shape = orthogonalized.shape(); assert_eq!( original_shape.dims::<2>(), result_shape.dims::<2>(), "Shape should be preserved: [8, 4]" ); // Verify output is different from input (orthogonalization happened) let original_data = tall_matrix.into_data(); let result_data = orthogonalized.into_data(); assert_ne!( original_data.as_slice::().unwrap(), result_data.as_slice::().unwrap(), "Orthogonalized matrix should differ from input" ); // For comparison, test a wide matrix [4, 8] should NOT be transposed let wide_matrix = Tensor::::from_floats( [ [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3], [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1], [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4], [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2], ], &device, ); let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone()); // Verify wide matrix shape is also preserved let wide_original_shape = wide_matrix.shape(); let wide_result_shape = orthogonalized_wide.shape(); assert_eq!( wide_original_shape.dims::<2>(), wide_result_shape.dims::<2>(), "Wide matrix shape should be preserved: [4, 8]" ); } #[test] fn test_zero_gradient() { // Test that Muon handles zero gradients gracefully let device = Default::default(); let tensor = Tensor::::from_floats( [ [1.0, 0.5, 0.3, 0.2], [0.5, 1.0, 0.4, 0.1], [0.3, 0.4, 1.0, 0.5], [0.2, 0.1, 0.5, 1.0], ], &device, ); // Zero gradient - all zeros let zero_grad = Tensor::::zeros([4, 4], &device); let config = MuonConfig::new(); let muon: Muon = Muon { momentum: Momentum::new(&config.momentum), ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps), weight_decay_penalty: None, epsilon: config.epsilon, adjust_lr_fn: config.adjust_lr_fn, }; // Should not panic or produce NaN let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None); // Verify state was created assert!(state.is_some()); // With zero gradient and no weight decay, tensor should remain unchanged let original_data = tensor.into_data(); let updated_data = updated_tensor.clone().into_data(); let original_vals = original_data.as_slice::().unwrap(); let updated_vals = updated_data.as_slice::().unwrap(); for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) { assert!( (orig - upd).abs() < 1e-6, "With zero gradient, tensor should remain unchanged (or very close)" ); } // Verify no NaN values for val in updated_vals { assert!( !val.is_nan(), "Result should not contain NaN values with zero gradient" ); } // Test with weight decay - should still work let muon_with_decay: Muon = Muon { momentum: Momentum::new(&config.momentum), ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps), weight_decay_penalty: Some(0.01), epsilon: config.epsilon, adjust_lr_fn: config.adjust_lr_fn, }; let tensor2 = Tensor::::from_floats( [ [1.0, 0.5, 0.3, 0.2], [0.5, 1.0, 0.4, 0.1], [0.3, 0.4, 1.0, 0.5], [0.2, 0.1, 0.5, 1.0], ], &device, ); let zero_grad2 = Tensor::::zeros([4, 4], &device); let (updated_tensor_decay, _) = muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None); // With zero gradient but with weight decay, tensor should be slightly reduced let updated_decay_data = updated_tensor_decay.into_data(); let updated_decay_vals = updated_decay_data.as_slice::().unwrap(); for val in updated_decay_vals { assert!( !val.is_nan(), "Result should not contain NaN with zero gradient and weight decay" ); } // With weight decay, values should be slightly smaller than original let original_vals2 = tensor2.into_data().as_slice::().unwrap().to_vec(); for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) { if orig.abs() > 1e-6 { // Non-zero values should be reduced by weight decay assert!( upd.abs() < orig.abs(), "Weight decay should reduce magnitude: original={}, updated={}", orig, upd ); } } } } ================================================ FILE: crates/burn-optim/src/optim/rmsprop.rs ================================================ use burn_core as burn; use burn::{module::AutodiffModule, record::Record}; use super::{ SimpleOptimizer, adaptor::OptimizerAdaptor, decay::{WeightDecay, WeightDecayConfig}, }; use crate::{LearningRate, grad_clipping::GradientClippingConfig}; use burn::config::Config; use burn::tensor::backend::Backend; use burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device}; /// Configuration to create the [RmsProp](RmsProp) optimizer. #[derive(Config, Debug)] pub struct RmsPropConfig { /// Smoothing constant. #[config(default = 0.99)] alpha: f32, /// momentum for RmsProp. #[config(default = 0.9)] momentum: f32, /// A value required for numerical stability. #[config(default = 1e-5)] epsilon: f32, /// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance #[config(default = false)] centered: bool, /// [Weight decay](WeightDecayConfig) config. weight_decay: Option, /// [Gradient Clipping](GradientClippingConfig) config. grad_clipping: Option, } impl RmsPropConfig { /// Initialize RmsProp optimizer. /// /// # Returns /// /// Returns an optimizer that can be used to optimize a module. pub fn init>( &self, ) -> OptimizerAdaptor { let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); let mut optim = OptimizerAdaptor::from(RmsProp { alpha: self.alpha, centered: self.centered, weight_decay, momentum: RmsPropMomentum { momentum: self.momentum, epsilon: self.epsilon, }, }); if let Some(config) = &self.grad_clipping { optim = optim.with_grad_clipping(config.init()); } optim } } /// Optimizer that implements stochastic gradient descent with momentum. /// The optimizer can be configured with [RmsPropConfig](RmsPropConfig). #[derive(Clone)] pub struct RmsProp { alpha: f32, // epsilon: f32, centered: bool, // momentum: Option>, momentum: RmsPropMomentum, weight_decay: Option, } impl SimpleOptimizer for RmsProp { type State = RmsPropState; fn step( &self, lr: LearningRate, tensor: Tensor, mut grad: Tensor, state: Option>, ) -> (Tensor, Option>) { // fetch state for params let mut state_square_avg = None; let mut state_centered = None; let mut state_momentum = None; if let Some(state) = state { state_square_avg = Some(state.square_avg); state_centered = Some(state.centered); state_momentum = state.momentum; } // weight_decay transform if let Some(weight_decay) = &self.weight_decay { grad = weight_decay.transform(grad, tensor.clone()); } // square_avg transform let (grad, state_square_avg) = SquareAvgState::transform(self.alpha, grad, state_square_avg); // centered transform let (grad, state_square_avg, state_centered) = CenteredState::transform( self.alpha, self.centered, grad, state_square_avg, state_centered, ); // momentum transform let (grad, state_centered, state_momentum) = self.momentum .transform(grad, state_centered, state_momentum); // transition state let state = RmsPropState::new(state_square_avg, state_centered, state_momentum); // tensor param transform let delta = grad.mul_scalar(lr); (tensor - delta, Some(state)) } fn to_device(mut state: Self::State, device: &Device) -> Self::State { state.square_avg = state.square_avg.to_device(device); state.centered = state.centered.to_device(device); state.momentum = state.momentum.map(|momentum| momentum.to_device(device)); state } } /// State of [RmsProp](RmsProp) #[derive(Record, Clone, new)] pub struct RmsPropState { /// Current squared average state. pub square_avg: SquareAvgState, /// Current centered state pub centered: CenteredState, /// Current gradient momentum, if any. pub momentum: Option>, } /// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct SquareAvgState { /// Current squared average. pub square_avg: Tensor, } impl SquareAvgState { /// transform [SquareAvgState] to the next step fn transform(alpha: f32, grad: Tensor, state: Option) -> (Tensor, Self) { match state { Some(state) => { let square_avg = state .square_avg .mul_scalar(alpha) .add(grad.clone().square().mul_scalar(1. - alpha)); (grad, Self { square_avg }) } _ => { let square_avg = grad.clone().square().mul_scalar(1. - alpha); (grad, Self { square_avg }) } } } /// Moves the state to a device. /// /// # Arguments /// /// * `device` - Device to move the state to. /// /// # Returns /// /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.square_avg = self.square_avg.to_device(device); self } } /// [CenteredState](CenteredState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct CenteredState { /// The averaged gradient to calculate the centered gradient, if available. pub grad_avg: Option>, /// The current average value. pub avg: Tensor, } impl CenteredState { /// transform [CenteredState] to the next step fn transform( alpha: f32, centered: bool, grad: Tensor, square_avg_state: SquareAvgState, centered_state: Option, ) -> (Tensor, SquareAvgState, Self) { if centered { let grad_avg_constant = grad.clone().mul_scalar(1. - alpha); let grad_avg = match centered_state { Some(state) => state .grad_avg .map_or(grad_avg_constant.clone(), move |grad_avg| { grad_avg.mul_scalar(alpha).add(grad_avg_constant) }), _ => grad_avg_constant, }; let avg = square_avg_state .square_avg .clone() .sub(grad_avg.clone().square()); ( grad, square_avg_state, Self { grad_avg: Some(grad_avg), avg, }, ) } else { ( grad, square_avg_state.clone(), Self { grad_avg: None, avg: square_avg_state.square_avg, }, ) } } /// Moves the state to a device. /// /// # Arguments /// /// * `device` - Device to move the state to. /// /// # Returns /// /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device)); self.avg = self.avg.to_device(device); self } } /// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer. /// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation) #[derive(Clone)] pub struct RmsPropMomentum { momentum: f32, epsilon: f32, } impl RmsPropMomentum { /// transform [grad](Tensor) and [RmsPropMomentumState] to the next step fn transform( &self, grad: Tensor, centered_state: CenteredState, momentum_state: Option>, ) -> ( Tensor, CenteredState, Option>, ) { let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon)); if self.momentum > 0. { let buf = match momentum_state { Some(state) => state.buf.mul_scalar(self.momentum).add(grad), _ => grad, }; ( buf.clone(), centered_state, Some(RmsPropMomentumState { buf }), ) } else { (grad, centered_state, None) } } } /// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params. #[derive(Record, Clone, new)] pub struct RmsPropMomentumState { buf: Tensor, } impl RmsPropMomentumState { /// Moves the state to a device. /// /// # Arguments /// /// * `device` - Device to move the state to. /// /// # Returns /// /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.buf = self.buf.to_device(device); self } } #[cfg(test)] mod tests { use burn::tensor::ops::FloatElem; use burn::tensor::{Shape, Tolerance}; use super::*; use crate::TestAutodiffBackend; use crate::optim::{GradientsParams, Optimizer}; use burn::module::{Module, Param}; use burn::tensor::{Distribution, Tensor, TensorData}; use burn_nn::{Linear, LinearConfig, LinearRecord}; type FT = FloatElem; const LEARNING_RATE: LearningRate = 0.01; #[test] fn test_rmsprop_optimizer_save_load_state() { let device = Default::default(); let linear = LinearConfig::new(6, 6).init(&device); let x = Tensor::::random([2, 6], Distribution::Default, &device); let mut optimizer = create_rmsprop(); let grads = linear.forward(x).backward(); let grads = GradientsParams::from_grads(grads, &linear); let _linear = optimizer.step(LEARNING_RATE, linear, grads); #[cfg(feature = "std")] { use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; BinFileRecorder::::default() .record( optimizer.to_record(), std::env::temp_dir().as_path().join("test_optim_rmsprop"), ) .unwrap(); } #[cfg(not(feature = "std"))] { use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder}; let result = BinBytesRecorder::::default() .record(optimizer.to_record(), ()) .unwrap(); assert!(!result.is_empty()); } let state_optim_before = optimizer.to_record(); let state_optim_before_copy = optimizer.to_record(); let optimizer = create_rmsprop(); let optimizer = optimizer.load_record(state_optim_before_copy); let state_optim_after = optimizer.to_record(); assert_eq!(state_optim_before.len(), state_optim_after.len()); } /// used for test differences and debug #[test] fn test_rmsprop_optimizer_with_numbers_basic() { let linear = given_linear_layer( TensorData::from([ [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], ]), TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), ); let device = Default::default(); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &device, ) .require_grad(); let mut optimizer = RmsPropConfig::new() .with_alpha(0.99) .with_epsilon(1e-8) .with_weight_decay(WeightDecayConfig::new(0.05).into()) .with_momentum(0.9) .with_centered(false) .init(); // println!("linear is {:?}", linear); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); // println!("linear is {:?}", linear); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); // println!("linear is {:?}", linear); let state_updated = linear.into_record(); let (weight_updated, bias_updated) = ( state_updated.weight.to_data(), state_updated.bias.unwrap().to_data(), ); // println!("\nweight_updated\n{:?}", weight_updated); // println!("\nbias_updated\n{:?}", bias_updated); let weights_expected = TensorData::from([ [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937], [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809], [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881], [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366], [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005], [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710], ]); let bias_expected = TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]); let tolerance = Tolerance::absolute(1e-6); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } #[test] fn test_rmsprop_optimizer_with_numbers() { let linear = given_linear_layer( TensorData::from([ [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671], [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922], [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130], [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626], [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304], [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833], ]), TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]), ); let device = Default::default(); let x_1 = Tensor::::from_floats( [ [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310], [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883], ], &device, ) .require_grad(); let x_2 = Tensor::::from_floats( [ [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528], [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085], ], &device, ) .require_grad(); let mut optimizer = RmsPropConfig::new() .with_alpha(0.99) .with_epsilon(1e-8) .with_weight_decay(WeightDecayConfig::new(0.05).into()) .with_momentum(0.9) .with_centered(false) .init(); let grads = linear.forward(x_1).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let grads = linear.forward(x_2).backward(); let grads = GradientsParams::from_grads(grads, &linear); let linear = optimizer.step(LEARNING_RATE, linear, grads); let state_updated = linear.into_record(); let weights_expected = TensorData::from([ [ -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779, ], [ -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207, ], [ -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967, ], [ -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997, ], [ 0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912, ], [ -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126, ], ]); let bias_expected = TensorData::from([ -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800, ]); let (weight_updated, bias_updated) = ( state_updated.weight.to_data(), state_updated.bias.unwrap().to_data(), ); // println!("\nweight_updated\n{:?}", weight_updated); // println!("\nbias_updated\n{:?}", bias_updated); let tolerance = Tolerance::absolute(1e-6); bias_updated.assert_approx_eq::(&bias_expected, tolerance); weight_updated.assert_approx_eq::(&weights_expected, tolerance); } fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear { let device = Default::default(); let record = LinearRecord { weight: Param::from_data(weight, &device), bias: Some(Param::from_data(bias, &device)), }; LinearConfig::new(6, 6).init(&device).load_record(record) } #[allow(dead_code)] fn create_random_tensor() -> Tensor { Tensor::::random( Shape::new([2, 20]), Distribution::Default, &Default::default(), ) } fn create_rmsprop() -> OptimizerAdaptor, TestAutodiffBackend> { RmsPropConfig { alpha: 0.99, epsilon: 1e-9, centered: false, weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), momentum: 0.9, grad_clipping: None, } .init() } } ================================================ FILE: crates/burn-optim/src/optim/sgd.rs ================================================ use burn_core as burn; use super::SimpleOptimizer; use super::adaptor::OptimizerAdaptor; use super::decay::{WeightDecay, WeightDecayConfig}; use super::momentum::{Momentum, MomentumConfig, MomentumState}; use crate::LearningRate; use crate::grad_clipping::GradientClippingConfig; use burn::config::Config; use burn::module::AutodiffModule; use burn::record::Record; use burn::tensor::Tensor; use burn::tensor::backend::{AutodiffBackend, Backend}; /// Configuration to create the [Sgd](Sgd) optimizer. #[derive(Config, Debug)] pub struct SgdConfig { /// [Weight decay](WeightDecayConfig) config. weight_decay: Option, /// [Momentum](MomentumConfig) config. momentum: Option, /// [Gradient Clipping](GradientClippingConfig) config. gradient_clipping: Option, } /// Optimizer that implements stochastic gradient descent with momentum. /// /// The optimizer can be configured with [SgdConfig](SgdConfig). #[derive(Clone)] pub struct Sgd { momentum: Option>, weight_decay: Option, } /// State of [Sgd](Sgd). #[derive(Record, Clone, new)] pub struct SgdState { /// The current state of the momentum (if any). pub momentum: Option>, } impl SgdConfig { /// Creates a new [SgdConfig](SgdConfig) with default values. pub fn init>( &self, ) -> OptimizerAdaptor, M, B> { let momentum = self.momentum.as_ref().map(Momentum::new); let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new); let mut optim = OptimizerAdaptor::from(Sgd { momentum, weight_decay, }); if let Some(config) = &self.gradient_clipping { optim = optim.with_grad_clipping(config.init()); } optim } } impl SimpleOptimizer for Sgd { type State = SgdState; fn step( &self, lr: LearningRate, tensor: Tensor, mut grad: Tensor, state: Option>, ) -> (Tensor, Option>) { let mut state_momentum = None; if let Some(state) = state { state_momentum = state.momentum; } if let Some(weight_decay) = &self.weight_decay { grad = weight_decay.transform(grad, tensor.clone()); } if let Some(momentum) = &self.momentum { let (grad_out, state) = momentum.transform(grad, state_momentum); state_momentum = Some(state); grad = grad_out; } let state = SgdState::new(state_momentum); let delta = grad.mul_scalar(lr); (tensor - delta, Some(state)) } fn to_device(mut state: Self::State, device: &B::Device) -> Self::State { state.momentum = state.momentum.map(|state| state.to_device(device)); state } } #[cfg(test)] mod tests { use super::*; use crate::{ TestAutodiffBackend, TestBackend, grad_clipping::GradientClipping, optim::{GradientsParams, Optimizer}, }; use burn::tensor::{Distribution, Shape}; use burn_nn::{Linear, LinearConfig}; const LEARNING_RATE: LearningRate = 0.02; #[test] fn with_updated_params_should_have_state() { let device = Default::default(); let layer = layer::(&device); let mut optim = sgd_with_all(); let loss = layer.forward(random_tensor::(&device)); let grads = loss.backward(); let grads = GradientsParams::from_grads(grads, &layer); let _layer = optim.step(LEARNING_RATE, layer, grads); let record = optim.to_record(); assert!(!record.is_empty()); } #[test] fn without_updated_params_should_not_have_state() { let optim = sgd_with_all(); let record = optim.to_record(); assert!(record.is_empty()); } #[test] fn can_attach_gradient_clipping() { let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5)); assert!(optim.has_gradient_clipping()); } #[test] fn should_load_state() { let device = Default::default(); let layer = layer::(&device); let mut optim = sgd_with_all(); let loss = layer.forward(random_tensor(&device)); let grads = loss.backward(); let grads = GradientsParams::from_grads(grads, &layer); let _layer = optim.step(LEARNING_RATE, layer, grads); let record = optim.to_record(); let optim_new = sgd_with_all(); let record_new = optim_new.to_record(); let optim_new = optim_new.load_record(record.clone()); let state_restored = optim_new.to_record(); assert_ne!(record.len(), record_new.len()); assert_eq!(record.len(), state_restored.len()); } fn random_tensor(device: &B::Device) -> Tensor { Tensor::::random(Shape::new([2, 20]), Distribution::Default, device) } fn layer(device: &B::Device) -> Linear { LinearConfig::new(20, 20).init(device) } fn sgd_with_all() -> OptimizerAdaptor, Linear, TestAutodiffBackend> { SgdConfig { weight_decay: Some(WeightDecayConfig { penalty: 0.05 }), momentum: Some(MomentumConfig { momentum: 0.9, dampening: 0.1, nesterov: true, }), gradient_clipping: None, } .init() } } ================================================ FILE: crates/burn-optim/src/optim/simple/adaptor.rs ================================================ use burn_core::{self as burn, prelude::Backend, tensor::Device}; use super::{SimpleOptimizer, record::AdaptorRecord}; use crate::{ LearningRate, MultiGradientsParams, grad_clipping::GradientClipping, optim::{GradientsParams, Optimizer}, }; use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use core::marker::PhantomData; use hashbrown::HashMap; /// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into /// an [optimizer](Optimizer). #[derive(Clone)] pub struct OptimizerAdaptor where O: SimpleOptimizer, M: AutodiffModule, B: AutodiffBackend, { optim: O, records: HashMap>, module: PhantomData, grad_clipping: Option, } impl From for OptimizerAdaptor where B: AutodiffBackend, M: AutodiffModule, O: SimpleOptimizer, { fn from(optim: O) -> Self { Self { optim, records: HashMap::new(), module: PhantomData, grad_clipping: None, } } } impl OptimizerAdaptor where O: SimpleOptimizer, M: AutodiffModule, B: AutodiffBackend, { /// Sets the gradient clipping. /// /// # Arguments /// /// * `gradient_clipping` - The gradient clipping. /// /// # Returns /// /// The optimizer. pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { self.grad_clipping = Some(gradient_clipping); self } #[cfg(test)] pub(crate) fn has_gradient_clipping(&self) -> bool { self.grad_clipping.is_some() } } impl Optimizer for OptimizerAdaptor where B: AutodiffBackend, M: AutodiffModule, O: SimpleOptimizer, { type Record = HashMap>; fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M { let mut grads = GradAdaptor::Single(grads); let mut mapper = SimpleOptimizerMapper::::new( &self.optim, &mut self.records, &mut grads, lr, self.grad_clipping.as_ref(), ); module.map(&mut mapper) } fn step_multi(&mut self, lr: LearningRate, module: M, grads: crate::MultiGradientsParams) -> M { let mut grads = GradAdaptor::Multi(grads); let mut mapper = SimpleOptimizerMapper::::new( &self.optim, &mut self.records, &mut grads, lr, self.grad_clipping.as_ref(), ); module.map(&mut mapper) } fn to_record(&self) -> Self::Record { self.records.clone() } fn load_record(mut self, record: Self::Record) -> Self { self.records = record; self } } enum GradAdaptor { Single(GradientsParams), Multi(MultiGradientsParams), } impl GradAdaptor { fn remove( &mut self, id: ParamId, ) -> Option<(Tensor, Device)> { match self { GradAdaptor::Single(grads) => grads.remove(id).map(|t| { let device = t.device(); (t, device) }), GradAdaptor::Multi(grads) => grads.remove(id), } } } #[derive(new)] struct SimpleOptimizerMapper<'a, M, B, O> where M: AutodiffModule, B: AutodiffBackend, O: SimpleOptimizer, { optimizer: &'a O, records: &'a mut HashMap>, grads: &'a mut GradAdaptor, lr: LearningRate, phantom: PhantomData, grad_clipping: Option<&'a GradientClipping>, } impl ModuleMapper for SimpleOptimizerMapper<'_, M, B, O> where M: AutodiffModule, B: AutodiffBackend, O: SimpleOptimizer, { fn map_float(&mut self, param: Param>) -> Param> { let (id, tensor, mapper) = param.consume(); let grad = self.grads.remove(id); let tensor = if let Some((grad, device)) = grad { let is_require_grad = tensor.is_require_grad(); let (key, record) = self.records.remove_entry(&id).unzip(); let tensor = if tensor.device() != device { tensor.to_device(&device) } else { tensor }; debug_assert_eq!( grad.device(), device, "The gradient is on the provided device" ); let clipped_grad = if let Some(g_clipping) = self.grad_clipping { g_clipping.clip_gradient(grad) } else { grad }; debug_assert_eq!( tensor.device(), device, "Tensor and gradients are on the same device." ); let (tensor, state) = self.optimizer.step( self.lr, tensor.inner(), clipped_grad, record.map(|record| O::to_device(record.into_state(), &device)), ); if let Some(state) = state { self.records .insert(key.unwrap_or(id), AdaptorRecord::from_state(state)); } let mut tensor = Tensor::from_inner(tensor); if is_require_grad { tensor = tensor.require_grad(); } tensor } else { tensor }; Param::from_mapped_value(id, tensor, mapper) } } ================================================ FILE: crates/burn-optim/src/optim/simple/base.rs ================================================ use burn_core as burn; use crate::LearningRate; use burn::record::Record; use burn::tensor::{Tensor, backend::Backend}; /// Simple optimizer is an opinionated trait to simplify the process of implementing an /// optimizer. /// /// Implementations don't have to handle missing gradients, loading and exporting records, navigate the /// module parameter structure, handle tracked and untracked tensors, and the likes. pub trait SimpleOptimizer: Send + Sync + Clone where B: Backend, { /// The state of the optimizer. It also implements [record](Record), so that it can be saved. type State: Record + Clone + 'static; /// The optimizer step is performed for one tensor at a time with its gradient and state. /// /// Note that the state is passed as parameter, so implementations don't have to handle /// the saving and loading of recorded states. fn step( &self, lr: LearningRate, tensor: Tensor, grad: Tensor, state: Option>, ) -> (Tensor, Option>); /// Change the device of the state. /// /// This function will be called accordingly to have the state on the same device as the /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called. fn to_device(state: Self::State, device: &B::Device) -> Self::State; } ================================================ FILE: crates/burn-optim/src/optim/simple/mod.rs ================================================ mod base; pub use base::*; /// Adaptor module for optimizers. pub mod adaptor; /// Record module for optimizers. pub mod record; ================================================ FILE: crates/burn-optim/src/optim/simple/record/base.rs ================================================ use burn_core as burn; use super::{AdaptorRecordItemV1, AdaptorRecordV1}; use crate::optim::SimpleOptimizer; use burn::record::{PrecisionSettings, Record}; use burn::tensor::backend::AutodiffBackend; use serde::{Deserialize, Serialize}; /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record. /// /// Records are versioned for backward compatibility, so old records can be loaded. pub enum AdaptorRecord where O: SimpleOptimizer, B: AutodiffBackend, { /// Version 1. V1(AdaptorRecordV1), } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize, Clone)] #[serde(bound = "")] pub enum AdaptorRecordItem< O: SimpleOptimizer, B: AutodiffBackend, S: PrecisionSettings, > { /// Version 1. V1(AdaptorRecordItemV1), } impl Record for AdaptorRecord where O: SimpleOptimizer, B: AutodiffBackend, { type Item = AdaptorRecordItem; fn into_item(self) -> Self::Item { match self { AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()), } } fn from_item(item: Self::Item, device: &B::Device) -> Self { match item { AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)), } } } impl Clone for AdaptorRecord where O: SimpleOptimizer, B: AutodiffBackend, { fn clone(&self) -> Self { match self { AdaptorRecord::V1(record) => Self::V1(record.clone()), } } } impl AdaptorRecord where O: SimpleOptimizer, B: AutodiffBackend, { /// Converts the record into the optimizer state. /// /// # Returns /// /// The optimizer state. pub fn into_state(self) -> O::State { match self { AdaptorRecord::V1(record) => record.into_state(), } } /// Converts the optimizer state into the record. /// /// # Arguments /// /// * `state`: The optimizer state. /// /// # Returns /// /// The record. pub fn from_state(state: O::State) -> Self { Self::V1(AdaptorRecordV1::from_state(state)) } } ================================================ FILE: crates/burn-optim/src/optim/simple/record/mod.rs ================================================ mod base; mod v1; pub use base::*; pub use v1::*; ================================================ FILE: crates/burn-optim/src/optim/simple/record/v1.rs ================================================ use burn_core as burn; use crate::optim::SimpleOptimizer; use burn::record::{PrecisionSettings, Record}; use burn::tensor::backend::Backend; use core::any::Any; use serde::{Deserialize, Serialize}; #[cfg(not(feature = "std"))] use alloc::boxed::Box; /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. pub enum AdaptorRecordV1, B: Backend> { /// Rank 0. Rank0(O::State<0>), /// Rank 1. Rank1(O::State<1>), /// Rank 2. Rank2(O::State<2>), /// Rank 3. Rank3(O::State<3>), /// Rank 4. Rank4(O::State<4>), /// Rank 5. Rank5(O::State<5>), /// Rank 6. Rank6(O::State<6>), /// Rank 7. Rank7(O::State<7>), /// Rank 8. Rank8(O::State<8>), } impl, B: Backend> Clone for AdaptorRecordV1 { fn clone(&self) -> Self { match self { AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()), AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()), AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()), AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()), AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()), AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()), AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()), AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()), AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()), } } } /// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize, Clone)] #[serde(bound = "")] pub enum AdaptorRecordItemV1, B: Backend, S: PrecisionSettings> { /// Rank 0. Rank0( as Record>::Item), /// Rank 1. Rank1( as Record>::Item), /// Rank 2. Rank2( as Record>::Item), /// Rank 3. Rank3( as Record>::Item), /// Rank 4. Rank4( as Record>::Item), /// Rank 5. Rank5( as Record>::Item), /// Rank 6. Rank6( as Record>::Item), /// Rank 7. Rank7( as Record>::Item), /// Rank 8. Rank8( as Record>::Item), } impl AdaptorRecordV1 where O: SimpleOptimizer, B: Backend, { /// Convert the record into the state. /// /// # Returns /// /// The state. /// /// # Panics /// /// Panics if the state dimension is not supported. pub fn into_state(self) -> O::State { let boxed_state: Box = match self { AdaptorRecordV1::Rank0(s) => Box::new(s), AdaptorRecordV1::Rank1(s) => Box::new(s), AdaptorRecordV1::Rank2(s) => Box::new(s), AdaptorRecordV1::Rank3(s) => Box::new(s), AdaptorRecordV1::Rank4(s) => Box::new(s), AdaptorRecordV1::Rank5(s) => Box::new(s), AdaptorRecordV1::Rank6(s) => Box::new(s), AdaptorRecordV1::Rank7(s) => Box::new(s), AdaptorRecordV1::Rank8(s) => Box::new(s), }; let state = boxed_state .downcast::>() .expect("Unsupported state dimension, dimension up to 8 are supported."); *state } /// Convert the state into the record. /// /// # Arguments /// /// * `state`: The state. /// /// # Returns /// /// The record. pub fn from_state(state: O::State) -> Self { let state: Box = Box::new(state); match D { 0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()), 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()), 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()), 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()), 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()), 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()), 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()), 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()), 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()), _ => panic!("Unsupported state dimension, dimension up to 8 are supported."), } } } impl Record for AdaptorRecordV1 where O: SimpleOptimizer, B: Backend, { type Item = AdaptorRecordItemV1; fn into_item(self) -> Self::Item { match self { AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()), AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()), AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()), AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()), AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()), AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()), AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()), AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()), AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()), } } fn from_item(item: Self::Item, device: &B::Device) -> Self { match item { AdaptorRecordItemV1::Rank0(item) => { AdaptorRecordV1::Rank0( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank1(item) => { AdaptorRecordV1::Rank1( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank2(item) => { AdaptorRecordV1::Rank2( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank3(item) => { AdaptorRecordV1::Rank3( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank4(item) => { AdaptorRecordV1::Rank4( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank5(item) => { AdaptorRecordV1::Rank5( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank6(item) => { AdaptorRecordV1::Rank6( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank7(item) => { AdaptorRecordV1::Rank7( as Record>::from_item(item, device)) } AdaptorRecordItemV1::Rank8(item) => { AdaptorRecordV1::Rank8( as Record>::from_item(item, device)) } } } } ================================================ FILE: crates/burn-optim/src/optim/visitor.rs ================================================ use burn_core as burn; use super::GradientsParams; use burn::module::{AutodiffModule, ModuleVisitor, Param, ParamId}; use burn::tensor::{Tensor, backend::AutodiffBackend}; use core::marker::PhantomData; #[cfg(not(feature = "std"))] use alloc::vec::Vec; #[derive(new)] pub struct GradientsParamsConverter<'a, M: AutodiffModule, B: AutodiffBackend> { grads: &'a mut B::Gradients, grads_params: &'a mut GradientsParams, phatom: PhantomData, filter: Option>, } #[derive(new)] pub struct GradientsParamsChangeDevice<'a, M: AutodiffModule, B: AutodiffBackend> { device: &'a B::Device, grads: &'a mut GradientsParams, phatom: PhantomData, } impl ModuleVisitor for GradientsParamsConverter<'_, M, B> where B: AutodiffBackend, M: AutodiffModule, { fn visit_float(&mut self, param: &Param>) { if let Some(filter) = self.filter.as_ref() && !filter.contains(¶m.id) { return; } let Some(grad) = param.val().grad_remove(self.grads) else { return; }; self.grads_params .register::(param.id, grad); } } impl ModuleVisitor for GradientsParamsChangeDevice<'_, M, B> where B: AutodiffBackend, M: AutodiffModule, { fn visit_float(&mut self, param: &Param>) { let Some(grad) = self.grads.remove::(param.id) else { return; }; self.grads .register::(param.id, grad.to_device(self.device)); } } ================================================ FILE: crates/burn-remote/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Backend router decorator over the network." edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-remote" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote" documentation = "https://docs.rs/burn-router-remote" version.workspace = true [lints] workspace = true [features] default = ["client", "server"] doc = [] tracing = [ "burn-communication/tracing", "burn-ir/tracing", "burn-router/tracing", "burn-std/tracing", "burn-backend/tracing", ] client = ["tokio-tungstenite", "async-channel", "tokio/sync"] server = [ "tokio-tungstenite", "async-channel", "tokio/sync", "axum", "tracing-core/default", "tracing-subscriber/default", ] [dependencies] burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = true } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = true } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = true } burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = true } burn-communication = { path = "../burn-communication", version = "=0.21.0-pre.2", features = [ "data-service", "websocket", ] } bytes = { workspace = true } # Basic dependencies derive-new = { workspace = true } log = { workspace = true } # Shared dependencies tokio = { workspace = true, features = ["rt-multi-thread"] } serde = { workspace = true, features = ["derive"] } serde_bytes = { workspace = true } rmp-serde = { workspace = true } futures-util = { workspace = true } # Client dependencies async-channel = { workspace = true, optional = true } tokio-tungstenite = { workspace = true, optional = true } # Server dependencies axum = { workspace = true, features = ["ws"], optional = true } tracing-core = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } tokio-util = { workspace = true } [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-remote/README.md ================================================ ================================================ FILE: crates/burn-remote/src/client/base.rs ================================================ pub use super::RemoteDevice; use super::worker::{ClientRequest, ClientWorker}; use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent}; use async_channel::{RecvError, SendError, Sender}; use burn_communication::ProtocolClient; use burn_ir::TensorId; use burn_std::id::StreamId; use std::{ future::Future, sync::{Arc, atomic::AtomicU64}, }; #[derive(Clone)] pub struct RemoteClient { pub(crate) device: RemoteDevice, pub(crate) sender: Arc, pub(crate) runtime: Arc, } impl RemoteClient { pub fn init(device: RemoteDevice) -> Self { ClientWorker::::start(device) } pub(crate) fn new( device: RemoteDevice, sender: Sender, runtime: Arc, session_id: SessionId, ) -> Self { Self { device, runtime, sender: Arc::new(RemoteSender { sender, position_counter: AtomicU64::new(0), tensor_id_counter: AtomicU64::new(0), session_id, }), } } } pub(crate) struct RemoteSender { sender: Sender, position_counter: AtomicU64, tensor_id_counter: AtomicU64, session_id: SessionId, } #[allow(unused)] #[derive(Debug)] pub enum RemoteSendError { SendError(SendError), RecvError(RecvError), } impl RemoteSender { /// Generate a new unique (for this [`RemoteSender`] [`TensorId`]. pub(crate) fn new_tensor_id(&self) -> TensorId { TensorId::new( self.tensor_id_counter .fetch_add(1, std::sync::atomic::Ordering::Relaxed), ) } /// Give the next operation sequence number. fn next_position(&self) -> u64 { self.position_counter .fetch_add(1, std::sync::atomic::Ordering::Relaxed) } pub(crate) fn send(&self, task: ComputeTask) { self.sender .send_blocking(ClientRequest::WithoutCallback(Task::Compute( task, ConnectionId::new(self.next_position(), StreamId::current()), ))) .unwrap(); } pub(crate) fn send_async( &self, task: ComputeTask, ) -> impl Future> + Send + use<> { let stream_id = StreamId::current(); let position = self.next_position(); let sender = self.sender.clone(); async move { let (tx, rx) = async_channel::bounded(1); if let Err(e) = sender .send(ClientRequest::WithSyncCallback( Task::Compute(task, ConnectionId::new(position, stream_id)), tx, )) .await { return Err(RemoteSendError::SendError(e)); } match rx.recv().await { Ok(response) => Ok(response), Err(e) => Err(RemoteSendError::RecvError(e)), } } } pub(crate) fn close(&mut self) { let sender = self.sender.clone(); let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id)); sender.send_blocking(close_task).unwrap(); } } impl Drop for RemoteSender { fn drop(&mut self) { self.close(); } } ================================================ FILE: crates/burn-remote/src/client/channel.rs ================================================ use std::marker::PhantomData; use burn_backend::Shape; use burn_communication::ProtocolClient; use burn_ir::TensorIr; use burn_router::{RouterTensor, RunnerChannel, get_client}; use super::{ RemoteClient, runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle}, }; /// A local channel with direct connection to the backend runner clients. pub struct RemoteChannel { _p: PhantomData, } impl RunnerChannel for RemoteChannel { type Device = RemoteDevice; type Bridge = RemoteBridge; type Client = RemoteClient; type FloatElem = f32; type IntElem = i32; type BoolElem = u32; fn name(device: &Self::Device) -> String { format!("remote-{device:?}") } fn init_client(device: &Self::Device) -> Self::Client { RemoteClient::init::(device.clone()) } fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle { RemoteTensorHandle { client: client.clone(), tensor: tensor.clone(), _p: PhantomData, } } fn register_tensor( _client: &Self::Client, _handle: RemoteTensorHandle, _shape: Shape, _dtype: burn_backend::DType, ) -> RouterTensor { // This function is normally only used to move a tensor from a device to another. // // In other words, to change the client. panic!("Can't register manually a tensor on a remote channel."); } fn change_client_backend( tensor: RouterTensor, target_device: &Self::Device, // target device ) -> RouterTensor { // Get tensor handle from current client let original_client = tensor.client.clone(); let desc = tensor.into_ir(); let handle = Self::get_tensor_handle(&desc, &original_client); let handle = handle.change_backend(target_device); let id = handle.tensor.id; let target_client = get_client::(target_device); let router_tensor: RouterTensor = RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client); router_tensor } } impl Clone for RemoteChannel { fn clone(&self) -> Self { RemoteChannel { _p: PhantomData } } } ================================================ FILE: crates/burn-remote/src/client/mod.rs ================================================ mod base; mod channel; mod runner; mod worker; pub use base::*; pub use channel::*; pub use runner::RemoteDevice; ================================================ FILE: crates/burn-remote/src/client/runner.rs ================================================ use super::{RemoteChannel, RemoteClient}; use crate::shared::{ComputeTask, TaskResponseContent, TensorRemote}; use burn_backend::{DeviceId, DeviceOps, ExecutionError, TensorData}; use burn_communication::{Address, ProtocolClient, data_service::TensorTransferId}; use burn_ir::TensorIr; use burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client}; use burn_std::{backtrace::BackTrace, future::DynFut}; use std::sync::OnceLock; use std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex}; // TODO: we should work with the parsed structure of Address, not the string. static ADDRESS_REGISTRY: OnceLock>> = OnceLock::new(); fn get_address_registry() -> &'static Mutex> { ADDRESS_REGISTRY.get_or_init(|| Mutex::new(HashMap::new())) } /// Map a string network address to a (local runtime) global unique u32. /// /// Globally stable over the lifetime of the process, shared between threads, /// If the address has never been seen, a new id will be created. /// If the address has been seen, the previous id will be returned. pub fn address_to_id>(address: S) -> u32 { let registry = get_address_registry(); let mut registry = registry.lock().unwrap(); let next_id = registry.len() as u32; *registry .entry(address.as_ref().to_string()) .or_insert_with(|| next_id) } /// Look up an address by id. /// /// Returns the same address given ids by [`address_to_id`]. pub fn id_to_address(id: u32) -> Option { let registry = get_address_registry(); let registry = registry.lock().unwrap(); for entry in registry.iter() { if entry.1 == &id { return Some(entry.0.clone()); } } None } // It is very important to block on any request made with the sender, since ordering is crucial // when registering operation or creating tensors. // // The overhead is minimal, since we only wait for the task to be sent to the async // channel, but not sent to the server and even less processed by the server. impl RunnerClient for RemoteClient { type Device = RemoteDevice; fn register_op(&self, op: burn_ir::OperationIr) { self.sender .send(ComputeTask::RegisterOperation(Box::new(op))); } fn read_tensor_async( &self, tensor: burn_ir::TensorIr, ) -> DynFut> { // Important for ordering to call the creation of the future sync. let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor)); Box::pin(async move { match fut.await { Ok(response) => match response { TaskResponseContent::ReadTensor(res) => res, _ => panic!("Invalid message type"), }, Err(e) => Err(ExecutionError::Generic { reason: format!("Failed to read tensor: {:?}", e), backtrace: BackTrace::capture(), }), } }) } fn register_tensor_data(&self, data: TensorData) -> RouterTensor { let id = self.sender.new_tensor_id(); let shape = data.shape.clone(); let dtype = data.dtype; self.sender.send(ComputeTask::RegisterTensor(id, data)); RouterTensor::new(id, shape, dtype, self.clone()) } fn device(&self) -> Self::Device { self.device.clone() } fn sync(&self) -> Result<(), ExecutionError> { // Important for ordering to call the creation of the future sync. let fut = self.sender.send_async(ComputeTask::SyncBackend); match self.runtime.block_on(fut) { Ok(response) => match response { TaskResponseContent::SyncBackend(res) => res, _ => panic!("Invalid message type"), }, Err(e) => Err(ExecutionError::Generic { reason: format!("Failed to sync: {:?}", e), backtrace: BackTrace::capture(), }), } } fn seed(&self, seed: u64) { self.sender.send(ComputeTask::Seed(seed)); } fn create_empty_handle(&self) -> burn_ir::TensorId { self.sender.new_tensor_id() } fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet { let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype)); match self.runtime.block_on(fut) { Ok(_response) => panic!("Invalid message type"), Err(e) => panic!("Failed to check dtype support: {:?}", e), } } } #[derive(Clone, PartialEq, Eq, Debug)] /// The device contains the connection information of the server. pub struct RemoteDevice { pub(crate) address: Address, /// The id of the device in the local registry, see [`address_to_id`]. pub(crate) id: u32, } impl RemoteDevice { /// Create a device from an url. pub fn new(address: &str) -> Self { let id = address_to_id(address); Self { address: Address::from_str(address).unwrap(), id, } } } impl Default for RemoteDevice { fn default() -> Self { let address = match std::env::var("BURN_REMOTE_ADDRESS") { Ok(address) => address, Err(_) => String::from("ws://127.0.0.1:3000"), }; Self::new(&address) } } impl burn_std::device::Device for RemoteDevice { fn from_id(device_id: DeviceId) -> Self { if device_id.type_id != 0 { panic!("Invalid device id: {device_id} (expected type 0)"); } let address = id_to_address(device_id.index_id) .unwrap_or_else(|| panic!("Invalid device id: {device_id}")); Self::new(&address) } fn to_id(&self) -> DeviceId { DeviceId { type_id: 0, index_id: self.id, } } fn device_count(_type_id: u16) -> usize { 1 } } impl DeviceOps for RemoteDevice {} pub struct RemoteBridge { _p: PhantomData, } pub struct RemoteTensorHandle { pub(crate) client: RemoteClient, pub(crate) tensor: TensorIr, pub(crate) _p: PhantomData, } static TRANSFER_COUNTER: Mutex> = Mutex::new(None); fn get_next_transfer_id() -> TensorTransferId { let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap(); if transfer_counter.is_none() { *transfer_counter = Some(0.into()); transfer_counter.unwrap() } else { let mut transfer_counter = transfer_counter.unwrap(); transfer_counter.next(); transfer_counter } } impl RemoteTensorHandle { /// Changes the backend of the tensor via a dWebSocket. /// We ask the original server to expose the tensor, then ask the target server to fetch /// the tensor. The target server will open a new network connection to the original server /// to download the data. /// This way the client never sees the tensor's data, and we avoid a bottleneck. pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self { let transfer_id = get_next_transfer_id(); self.client.sender.send(ComputeTask::ExposeTensorRemote { tensor: self.tensor.clone(), count: 1, transfer_id, }); let target_client = get_client::>(target_device); let new_id = target_client.sender.new_tensor_id(); let remote_tensor = TensorRemote { transfer_id, address: self.client.device.address.clone(), }; target_client .sender .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id)); self.tensor.id = new_id; self.client = target_client; self } } impl MultiBackendBridge for RemoteBridge { type TensorHandle = RemoteTensorHandle; type Device = RemoteDevice; fn change_backend_float( tensor: Self::TensorHandle, _shape: burn_backend::Shape, target_device: &Self::Device, ) -> Self::TensorHandle { tensor.change_backend(target_device) } fn change_backend_int( tensor: Self::TensorHandle, _shape: burn_backend::Shape, target_device: &Self::Device, ) -> Self::TensorHandle { tensor.change_backend(target_device) } fn change_backend_bool( tensor: Self::TensorHandle, _shape: burn_backend::Shape, target_device: &Self::Device, ) -> Self::TensorHandle { tensor.change_backend(target_device) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_address_to_id() { let address1 = "ws://127.0.0.1:3000"; let address2 = "ws://127.0.0.1:3001"; let id1 = address_to_id(address1); let id2 = address_to_id(address2); assert_ne!(id1, id2); assert_eq!(address_to_id(address1), id1); assert_eq!(id_to_address(id1), Some(address1.to_string())); assert_eq!(address_to_id(address2), id2); assert_eq!(id_to_address(id2), Some(address2.to_string())); let unused_id = u32::MAX; assert_eq!(id_to_address(unused_id), None); } } ================================================ FILE: crates/burn-remote/src/client/worker.rs ================================================ use super::{RemoteClient, runner::RemoteDevice}; use crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent}; use burn_communication::{CommunicationChannel, Message, ProtocolClient}; use std::{collections::HashMap, marker::PhantomData, sync::Arc}; pub type CallbackSender = async_channel::Sender; #[derive(Debug)] pub enum ClientRequest { WithSyncCallback(Task, CallbackSender), WithoutCallback(Task), } pub(crate) struct ClientWorker { requests: HashMap, _p: PhantomData, } impl ClientWorker { async fn on_response(&mut self, response: TaskResponse) { match self.requests.remove(&response.id) { Some(request) => { request.send(response.content).await.unwrap(); } None => { panic!("Can't ignore message from the server."); } } } fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) { self.requests.insert(id, callback); } } impl ClientWorker { pub fn start(device: RemoteDevice) -> RemoteClient { let runtime = Arc::new( tokio::runtime::Builder::new_multi_thread() .enable_io() .build() .unwrap(), ); let (sender, rec) = async_channel::bounded(10); let session_id = SessionId::new(); let address = device.address.clone(); #[allow(deprecated)] runtime.spawn(async move { log::info!("Connecting to {} ...", address.clone()); let mut stream_request = C::connect(address.clone(), "request") .await .expect("Server to be accessible"); let mut stream_response = C::connect(address, "response") .await .expect("Server to be accessible"); let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::::default())); // Init the connection. let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id)) .expect("Can serialize tasks to bytes.") .into(); stream_request .send(Message::new(bytes.clone())) .await .expect("Can send the message over the comms channel."); stream_response .send(Message::new(bytes)) .await .expect("Can send the message on the websocket."); // Async worker loading callbacks from the server. let state_ws = state.clone(); tokio::spawn(async move { while let Ok(msg) = stream_response.recv().await { let msg = match msg { Some(msg) => msg, None => { log::warn!("Closed connection"); return; } }; let response: TaskResponse = rmp_serde::from_slice(&msg.data) .expect("Can deserialize messages from the websocket."); let mut state = state_ws.lock().await; state.on_response(response).await; } }); // Channel async worker sending operations to the server. tokio::spawn(async move { while let Ok(req) = rec.recv().await { let task = match req { ClientRequest::WithSyncCallback(task, callback) => { if let Task::Compute(_content, id) = &task { let mut state = state.lock().await; state.register_callback(*id, callback); } task } ClientRequest::WithoutCallback(task) => task, }; let bytes = rmp_serde::to_vec(&task) .expect("Can serialize tasks to bytes.") .into(); stream_request .send(Message::new(bytes)) .await .expect("Can send the message on the websocket."); } }); }); RemoteClient::new(device, sender, runtime, session_id) } } impl Default for ClientWorker { fn default() -> Self { Self { requests: Default::default(), _p: PhantomData, } } } ================================================ FILE: crates/burn-remote/src/lib.rs ================================================ #[macro_use] extern crate derive_new; #[cfg(feature = "client")] pub(crate) mod client; #[cfg(feature = "server")] pub mod server; pub(crate) mod shared; #[cfg(feature = "client")] mod __client { use super::*; use crate::{client::RemoteChannel, shared::RemoteProtocol}; use burn_communication::Protocol; use burn_router::BackendRouter; /// The remote backend allows you to run computation on a remote device. /// /// Make sure there is a running server before trying to connect to it. /// /// ```rust, ignore /// fn main() { /// let device = Default::default(); /// let port = 3000; /// /// // You need to activate the `server` feature flag to have access to this function. /// burn::server::start::(device, port); /// } ///``` pub type RemoteBackend = BackendRouter::Client>>; pub use client::RemoteDevice; } #[cfg(feature = "client")] pub use __client::*; #[cfg(all(test, feature = "client", feature = "server"))] mod tests { use crate::RemoteBackend; use burn_ndarray::NdArray; use burn_tensor::{Distribution, Tensor}; #[test] pub fn test_to_device_over_websocket() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_io() .build() .unwrap(); rt.spawn(crate::server::start_websocket_async::( Default::default(), 3000, )); rt.spawn(crate::server::start_websocket_async::( Default::default(), 3010, )); let remote_device_1 = super::RemoteDevice::new("ws://localhost:3000"); let remote_device_2 = super::RemoteDevice::new("ws://localhost:3010"); // Some random input let input_shape = [1, 28, 28]; let input = Tensor::::random( input_shape, Distribution::Default, &remote_device_1, ); let numbers_expected: Vec = input.to_data().to_vec().unwrap(); // Move tensor to device 2 let input = input.to_device(&remote_device_2); let numbers: Vec = input.to_data().to_vec().unwrap(); assert_eq!(numbers, numbers_expected); // Move tensor back to device 1 let input = input.to_device(&remote_device_1); let numbers: Vec = input.to_data().to_vec().unwrap(); assert_eq!(numbers, numbers_expected); rt.shutdown_background(); } } ================================================ FILE: crates/burn-remote/src/server/base.rs ================================================ use burn_communication::{ CommunicationChannel, Message, Protocol, ProtocolServer, data_service::{TensorDataServer, TensorDataService}, util::os_shutdown_signal, websocket::{WebSocket, WsServer}, }; use std::{marker::PhantomData, sync::Arc}; use tokio_util::sync::CancellationToken; use burn_backend::tensor::Device; use burn_ir::BackendIr; use crate::shared::{ComputeTask, Task}; use super::session::SessionManager; pub struct RemoteServer where B: BackendIr, P: Protocol, { _b: PhantomData, _n: PhantomData

, } impl RemoteServer where B: BackendIr, P: Protocol, { /// Start the server on the given address. pub async fn start(device: Device, server: P::Server) { let cancel_token = CancellationToken::new(); let data_service = Arc::new(TensorDataService::::new(cancel_token)); let session_manager = Arc::new(SessionManager::::new(device, data_service.clone())); let _server = server .route("/response", { let session_manager = session_manager.clone(); move |stream| Self::handle_socket_response(session_manager, stream) }) .route("/request", { let session_manager = session_manager.clone(); move |stream| Self::handle_socket_request(session_manager, stream) }) .route_tensor_data_service(data_service) .serve(os_shutdown_signal()) .await; } async fn handle_socket_response( session_manager: Arc>, mut socket: ::Channel, ) { log::info!("[Response Handler] On new connection."); let packet = socket.recv().await; let msg = match packet { Ok(Some(msg)) => msg, Ok(None) => { log::info!("Response stream closed"); return; } Err(e) => { log::info!("Response stream error on init: {e:?}"); return; } }; let id = match rmp_serde::from_slice::(&msg.data) { Ok(Task::Init(session_id)) => session_id, msg => { log::error!("Message is not a valid initialization task {msg:?}"); return; } }; let mut receiver = session_manager.register_responder(id).await; log::info!("Response handler connection active"); while let Some(mut callback) = receiver.recv().await { let response = callback.recv().await.unwrap(); let bytes = rmp_serde::to_vec(&response).unwrap(); socket.send(Message::new(bytes.into())).await.unwrap(); } } async fn handle_socket_request( session_manager: Arc>, mut socket: ::Channel, ) { log::info!("[Request Handler] On new connection."); let mut session_id = None; loop { let packet = socket.recv().await; let msg = match packet { Ok(Some(msg)) => msg, Ok(None) => { log::info!("Request stream closed"); break; } Err(e) => { log::info!("Request stream error: {e:?}, Closing."); break; } }; let task = match rmp_serde::from_slice::(&msg.data) { Ok(val) => val, Err(err) => { log::info!("Only bytes message in the json format are supported {err:?}"); break; } }; if let Task::Close(id) = task { session_id = Some(id); break; } let (stream, connection_id, task) = match session_manager.stream(&mut session_id, task).await { Some(val) => val, None => { log::info!("Ops session activated {session_id:?}"); continue; } }; match task { ComputeTask::RegisterOperation(op) => { stream.register_operation(op).await; } ComputeTask::RegisterTensor(id, data) => { stream.register_tensor(id, data).await; } ComputeTask::ReadTensor(tensor) => { stream.read_tensor(connection_id, tensor).await; } ComputeTask::SyncBackend => { stream.sync(connection_id).await; } ComputeTask::RegisterTensorRemote(tensor, new_id) => { stream.register_tensor_remote(tensor, new_id).await; } ComputeTask::ExposeTensorRemote { tensor, count, transfer_id, } => { stream .expose_tensor_remote(tensor, count, transfer_id) .await; } ComputeTask::Seed(seed) => { stream.seed(seed).await; } ComputeTask::SupportsDType(dtype) => { stream.supports_dtype(connection_id, dtype).await } } } log::info!("Closing session {session_id:?}"); session_manager.close(session_id).await; } } /// Start the server on the given port and [device](Device). pub async fn start_websocket_async(device: Device, port: u16) { let server = WsServer::new(port); RemoteServer::::start(device, server).await; } #[tokio::main] /// Start the server on the given port and [device](Device). pub async fn start_websocket(device: Device, port: u16) { start_websocket_async::(device, port).await; } ================================================ FILE: crates/burn-remote/src/server/mod.rs ================================================ pub(crate) mod processor; pub(crate) mod session; pub(crate) mod stream; mod base; pub use base::{start_websocket, start_websocket_async}; ================================================ FILE: crates/burn-remote/src/server/processor.rs ================================================ use burn_backend::TensorData; use burn_communication::{ Protocol, data_service::{TensorDataService, TensorTransferId}, }; use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr}; use burn_router::{Runner, RunnerClient}; use burn_std::DType; use core::marker::PhantomData; use std::sync::Arc; use tokio::sync::mpsc::Sender; use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent, TensorRemote}; /// The goal of the processor is to asynchronously process compute tasks on it own thread. pub struct Processor where B: BackendIr, P: Protocol, { p: PhantomData, n: PhantomData

, } pub type Callback = Sender; pub enum ProcessorTask { RegisterOperation(Box), RegisterTensor(TensorId, TensorData), RegisterTensorRemote(TensorRemote, TensorId), ExposeTensorRemote { tensor: TensorIr, transfer_id: TensorTransferId, count: u32, }, ReadTensor(ConnectionId, TensorIr, Callback), Sync(ConnectionId, Callback), Seed(u64), SupportsDType(ConnectionId, DType, Callback), Close, } impl Processor where B: BackendIr, P: Protocol, { pub async fn start( runner: Runner, data_service: Arc>, ) -> Sender { // channel for tasks to execute let (task_sender, mut task_rec) = tokio::sync::mpsc::channel(1); tokio::spawn(async move { while let Some(item) = task_rec.recv().await { match item { ProcessorTask::RegisterOperation(op) => { runner.register_op(*op); } ProcessorTask::Sync(id, callback) => { let result = runner.sync(); callback .send(TaskResponse { content: TaskResponseContent::SyncBackend(result), id, }) .await .unwrap(); } ProcessorTask::RegisterTensor(id, data) => { runner.register_tensor_data_id(id, data); } ProcessorTask::RegisterTensorRemote(remote_tensor, new_id) => { log::info!( "Registering remote tensor...(id: {:?})", remote_tensor.transfer_id ); let data = data_service .download_tensor(remote_tensor.address, remote_tensor.transfer_id) .await .expect("Can't download tensor: error"); // TODO all these panics should be server errors runner.register_tensor_data_id(new_id, data); } ProcessorTask::ExposeTensorRemote { tensor, transfer_id, count, } => { log::info!("Exposing tensor: (id: {transfer_id:?})"); let data = runner.read_tensor_async(tensor).await; data_service .expose_data(data.unwrap(), count, transfer_id) .await; } ProcessorTask::ReadTensor(id, tensor, callback) => { let tensor = runner.read_tensor_async(tensor).await; callback .send(TaskResponse { content: TaskResponseContent::ReadTensor(tensor), id, }) .await .unwrap(); } ProcessorTask::Close => { let device = runner.device(); runner.sync().unwrap(); core::mem::drop(runner); B::sync(&device).unwrap(); break; } ProcessorTask::Seed(seed) => runner.seed(seed), ProcessorTask::SupportsDType(id, dtype, callback) => { let _result = runner.dtype_usage(dtype); callback .send(TaskResponse { // content: TaskResponseContent::SupportsDType(result), // TODO: Update to result. content: TaskResponseContent::SupportsDType(()), id, }) .await .unwrap(); } } } }); task_sender } } ================================================ FILE: crates/burn-remote/src/server/session.rs ================================================ use burn_backend::tensor::Device; use burn_communication::{Protocol, data_service::TensorDataService}; use burn_ir::BackendIr; use burn_router::Runner; use burn_std::id::StreamId; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{ Mutex, mpsc::{Receiver, Sender}, }; use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse}; use super::stream::Stream; /// A session manager control the creation of sessions. /// /// Each session manages its own stream, spawning one thread per stream to mimic the same behavior /// a native backend would have. pub struct SessionManager where B: BackendIr, P: Protocol, { runner: Runner, sessions: Mutex>>, data_service: Arc>, } struct Session where B: BackendIr, P: Protocol, { runner: Runner, streams: HashMap>, sender: Sender>, receiver: Option>>, data_service: Arc>, } impl SessionManager where B: BackendIr, P: Protocol, { pub fn new(device: Device, data_service: Arc>) -> Self { Self { runner: Runner::new(device), sessions: Mutex::new(Default::default()), data_service, } } /// Register a new responder for the session. Only one responder can exist for a session for /// now. pub async fn register_responder( &self, session_id: SessionId, ) -> Receiver> { log::info!("Register responder for session {session_id}"); let mut sessions = self.sessions.lock().await; self.register_session(&mut sessions, session_id); let session = sessions.get_mut(&session_id).unwrap(); session.init_responder() } /// Get the stream for the current session and task. pub async fn stream( &self, session_id: &mut Option, task: Task, ) -> Option<(Stream, ConnectionId, ComputeTask)> { let mut sessions = self.sessions.lock().await; let session_id = match session_id { Some(id) => *id, None => match task { Task::Init(id) => { log::info!("Init requester for session {id}"); *session_id = Some(id); self.register_session(&mut sessions, id); return None; } _ => panic!("The first message should initialize the session"), }, }; match sessions.get_mut(&session_id) { Some(session) => { let (task, connection_id) = match task { Task::Compute(task, connection_id) => (task, connection_id), _ => panic!("Only support compute tasks."), }; let stream = session.select(connection_id.stream_id).await; Some((stream, connection_id, task)) } None => panic!("To be initialized"), } } /// Close the session with the given id. pub async fn close(&self, session_id: Option) { if let Some(id) = session_id { let mut sessions = self.sessions.lock().await; if let Some(session) = sessions.get_mut(&id) { session.close().await; } } } fn register_session(&self, sessions: &mut HashMap>, id: SessionId) { sessions.entry(id).or_insert_with(|| { log::info!("Creating a new session {id}"); Session::new(self.runner.clone(), self.data_service.clone()) }); } } impl Session where B: BackendIr, P: Protocol, { fn new(runner: Runner, data_service: Arc>) -> Self { let (sender, receiver) = tokio::sync::mpsc::channel(1); Self { runner, streams: Default::default(), sender, receiver: Some(receiver), data_service, } } fn init_responder(&mut self) -> Receiver> { let mut receiver = None; core::mem::swap(&mut receiver, &mut self.receiver); receiver.expect("Only one responder per session is possible.") } /// Select the current [stream](Stream) based on the given task. async fn select(&mut self, stream_id: StreamId) -> Stream { // We return the stream. match self.streams.get(&stream_id) { Some(stream) => stream.clone(), None => { let stream = Stream::::new( self.runner.clone(), self.sender.clone(), self.data_service.clone(), ) .await; self.streams.insert(stream_id, stream.clone()); stream } } } // Close all streams created in the session. async fn close(&mut self) { for (id, stream) in self.streams.drain() { log::info!("Closing stream {id}"); stream.close().await; } } } ================================================ FILE: crates/burn-remote/src/server/stream.rs ================================================ use core::marker::PhantomData; use std::sync::Arc; use crate::shared::{ConnectionId, TaskResponse, TensorRemote}; use super::processor::{Processor, ProcessorTask}; use burn_backend::TensorData; use burn_communication::{ Protocol, data_service::{TensorDataService, TensorTransferId}, }; use burn_ir::{BackendIr, OperationIr, TensorId, TensorIr}; use burn_router::Runner; use burn_std::DType; use tokio::sync::mpsc::{Receiver, Sender}; /// A stream makes sure all operations registered are executed in the order they were sent to the /// server, potentially waiting to reconstruct consistency. #[derive(Clone)] pub struct Stream where B: BackendIr, P: Protocol, { compute_sender: Sender, writer_sender: Sender>, _p: PhantomData, _n: PhantomData

, } impl Stream where B: BackendIr, P: Protocol, { pub async fn new( runner: Runner, writer_sender: Sender>, data_service: Arc>, ) -> Self { let sender = Processor::::start(runner, data_service).await; Self { compute_sender: sender, writer_sender, _p: PhantomData, _n: PhantomData, } } pub async fn register_operation(&self, op: Box) { self.compute_sender .send(ProcessorTask::RegisterOperation(op)) .await .unwrap(); } pub async fn register_tensor(&self, tensor_id: TensorId, data: TensorData) { self.compute_sender .send(ProcessorTask::RegisterTensor(tensor_id, data)) .await .unwrap(); } pub async fn register_tensor_remote(&self, tensor: TensorRemote, new_id: TensorId) { self.compute_sender .send(ProcessorTask::RegisterTensorRemote(tensor, new_id)) .await .unwrap(); } pub async fn expose_tensor_remote( &self, tensor: TensorIr, count: u32, transfer_id: TensorTransferId, ) { self.compute_sender .send(ProcessorTask::ExposeTensorRemote { tensor, count, transfer_id, }) .await .unwrap(); } pub async fn read_tensor(&self, id: ConnectionId, desc: TensorIr) { let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1); self.compute_sender .send(ProcessorTask::ReadTensor(id, desc, callback_sender)) .await .unwrap(); self.writer_sender.send(callback_rec).await.unwrap(); } pub async fn sync(&self, id: ConnectionId) { let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1); self.compute_sender .send(ProcessorTask::Sync(id, callback_sender)) .await .unwrap(); self.writer_sender.send(callback_rec).await.unwrap(); } pub async fn close(&self) { self.compute_sender .send(ProcessorTask::Close) .await .unwrap(); } pub async fn seed(&self, seed: u64) { self.compute_sender .send(ProcessorTask::Seed(seed)) .await .unwrap(); } pub async fn supports_dtype(&self, id: ConnectionId, dtype: DType) { let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1); self.compute_sender .send(ProcessorTask::SupportsDType(id, dtype, callback_sender)) .await .unwrap(); self.writer_sender.send(callback_rec).await.unwrap(); } } ================================================ FILE: crates/burn-remote/src/shared/mod.rs ================================================ mod task; #[allow(unused_imports)] pub(crate) use task::*; /// We define the communication protocol here pub(crate) type RemoteProtocol = burn_communication::websocket::WebSocket; ================================================ FILE: crates/burn-remote/src/shared/task.rs ================================================ use burn_backend::{ExecutionError, TensorData}; use burn_communication::{Address, data_service::TensorTransferId}; use burn_ir::{OperationIr, TensorId, TensorIr}; use burn_std::{ DType, id::{IdGenerator, StreamId}, }; use serde::{Deserialize, Serialize}; use std::fmt::Display; #[allow(missing_docs)] #[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] pub struct ConnectionId { pub position: u64, pub stream_id: StreamId, } /// Unique identifier that can represent a session. #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct SessionId { id: u64, } impl Display for SessionId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "SessionId({})", self.id) } } impl SessionId { /// Create a new [session id](SessionId). #[allow(dead_code)] pub fn new() -> Self { Self { id: IdGenerator::generate(), } } } #[allow(missing_docs)] #[derive(Serialize, Deserialize, Debug)] pub enum Task { Compute(ComputeTask, ConnectionId), Init(SessionId), Close(SessionId), } #[allow(missing_docs)] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct TensorRemote { pub transfer_id: TensorTransferId, pub address: Address, } #[allow(missing_docs)] #[derive(Serialize, Deserialize, Debug)] pub enum ComputeTask { Seed(u64), RegisterOperation(Box), RegisterTensor(TensorId, TensorData), RegisterTensorRemote(TensorRemote, TensorId), ExposeTensorRemote { tensor: TensorIr, count: u32, transfer_id: TensorTransferId, }, ReadTensor(TensorIr), SyncBackend, SupportsDType(DType), } #[allow(missing_docs)] #[derive(Serialize, Deserialize, Debug)] pub struct TaskResponse { pub content: TaskResponseContent, pub id: ConnectionId, } #[allow(missing_docs)] #[derive(Serialize, Deserialize, Debug)] pub enum TaskResponseContent { ReadTensor(Result), SyncBackend(Result<(), ExecutionError>), // SupportsDType(DTypeUsageSet), // TODO: Update to `DTypeUsageSet` when it implements `serde`. SupportsDType(()), } ================================================ FILE: crates/burn-rl/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "RL crate for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-rl" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-rl" documentation = "https://docs.rs/burn-rl" version.workspace = true [dependencies] burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [ "dataset", "std", ], default-features = false } burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [ "std", ], default-features = false } derive-new.workspace = true log = { workspace = true } rand.workspace = true [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } [lints] workspace = true ================================================ FILE: crates/burn-rl/README.md ================================================ # Burn RL ================================================ FILE: crates/burn-rl/src/environment/base.rs ================================================ /// The result of taking a step in an environment. pub struct StepResult { /// The updated state. pub next_state: S, /// The reward. pub reward: f64, /// If the environment reached a terminal state. pub done: bool, /// If the environment reached its max length. pub truncated: bool, } /// Trait to be implemented for a RL environment. pub trait Environment { /// The type of the state. type State; /// The type of actions. type Action; /// The maximum number of step for one episode. const MAX_STEPS: usize; /// Returns the current state. fn state(&self) -> Self::State; /// Take a step in the environment given an action. fn step(&mut self, action: Self::Action) -> StepResult; /// Reset the environment to an initial state. fn reset(&mut self); } /// Trait to define how to initialize an environment. /// By default, any function returning an environment implements it. pub trait EnvironmentInit: Clone { /// Initialize the environment. fn init(&self) -> E; } impl EnvironmentInit for F where F: Fn() -> E + Clone, E: Environment, { fn init(&self) -> E { (self)() } } ================================================ FILE: crates/burn-rl/src/environment/mod.rs ================================================ mod base; pub use base::*; ================================================ FILE: crates/burn-rl/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for training reinforcement learning agents. /// Module for implementing an environment. pub mod environment; /// Module for implementing a policy. pub mod policy; /// Transition buffer. pub mod transition_buffer; pub use environment::*; pub use policy::*; pub use transition_buffer::*; #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; #[cfg(test)] pub(crate) mod tests { use crate::{Batchable, Policy, PolicyState, TestBackend}; use burn_core::record::Record; use burn_core::{self as burn}; /// Mock policy for testing /// /// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution) /// containing a list of 0s of the same length as the observation. /// /// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockAction](MockAction) with a list of actions of the same length as the observation. /// The actions are all 1 if the call is requested as deterministic, or else 0. #[derive(Clone)] pub(crate) struct MockPolicy {} impl MockPolicy { pub fn new() -> Self { Self {} } } impl Policy for MockPolicy { type Observation = MockObservation; type ActionDistribution = MockActionDistribution; type Action = MockAction; type ActionContext = MockActionContext; type PolicyState = MockPolicyState; fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution { let mut dists = vec![]; for _ in obs.0 { dists.push(MockActionDistribution(vec![0.])); } MockActionDistribution::batch(dists) } fn action( &mut self, obs: Self::Observation, deterministic: bool, ) -> (Self::Action, Vec) { let mut actions = vec![]; let mut contexts = vec![]; for _ in obs.0 { if deterministic { actions.push(MockAction(vec![1])); } else { actions.push(MockAction(vec![0])); } contexts.push(MockActionContext); } (MockAction::batch(actions), contexts) } fn update(&mut self, _update: Self::PolicyState) {} fn state(&self) -> Self::PolicyState { MockPolicyState } fn load_record( self, _record: >::Record, ) -> Self { self } } /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockObservation(pub Vec); /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockAction(pub Vec); /// Mock action distribution for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockActionDistribution(Vec); #[derive(Clone)] pub(crate) struct MockActionContext; /// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy. #[derive(Clone)] pub(crate) struct MockPolicyState; #[derive(Clone, Record)] pub(crate) struct MockRecord { item: usize, } impl PolicyState for MockPolicyState { type Record = MockRecord; fn into_record(self) -> Self::Record { MockRecord { item: 0 } } fn load_record(&self, _record: Self::Record) -> Self { self.clone() } } impl Batchable for MockObservation { fn batch(items: Vec) -> Self { MockObservation(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { vec![MockObservation(self.0)] } } impl Batchable for MockAction { fn batch(items: Vec) -> Self { MockAction(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { let mut actions = vec![]; for a in self.0 { actions.push(MockAction(vec![a])); } actions } } impl Batchable for MockActionDistribution { fn batch(items: Vec) -> Self { MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { let mut dists = vec![]; for _ in self.0 { dists.push(MockActionDistribution(vec![0.])); } dists } } } ================================================ FILE: crates/burn-rl/src/policy/async_policy.rs ================================================ use std::{ sync::{ Arc, atomic::{AtomicUsize, Ordering}, mpsc::{self, Sender}, }, thread::spawn, }; use burn_core::prelude::Backend; use crate::{ActionContext, Batchable, Policy, PolicyState}; #[derive(Clone)] struct PolicyInferenceServer> { // `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size. num_agents: Arc, max_autobatch_size: usize, inner_policy: P, batch_action: Vec>, batch_logits: Vec>, } impl PolicyInferenceServer where B: Backend, P: Policy, P::Observation: Clone + Batchable, P::ActionDistribution: Clone + Batchable, P::Action: Clone + Batchable, P::ActionContext: Clone, { pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self { Self { num_agents: Arc::new(AtomicUsize::new(0)), max_autobatch_size, inner_policy, batch_action: vec![], batch_logits: vec![], } } pub fn push_action(&mut self, item: ActionItem) { self.batch_action.push(item); if self.len_actions() >= self .num_agents .load(Ordering::Relaxed) .min(self.max_autobatch_size) { self.flush_actions(); } } pub fn push_logits(&mut self, item: ForwardItem) { self.batch_logits.push(item); if self.len_logits() >= self .num_agents .load(Ordering::Relaxed) .min(self.max_autobatch_size) { self.flush_logits(); } } pub fn len_actions(&self) -> usize { self.batch_action.len() } pub fn len_logits(&self) -> usize { self.batch_logits.len() } pub fn flush_actions(&mut self) { if self.len_actions() == 0 { return; } let input: Vec<_> = self .batch_action .iter() .map(|m| m.inference_state.clone()) .collect(); // Only deterministic if all actions are requested as deterministic. let deterministic = self.batch_action.iter().all(|item| item.deterministic); let (actions, context) = self .inner_policy .action(P::Observation::batch(input), deterministic); let actions: Vec<_> = actions.unbatch(); for (i, item) in self.batch_action.iter().enumerate() { item.sender .send(ActionContext { context: vec![context[i].clone()], action: actions[i].clone(), }) .expect("Autobatcher should be able to send resulting actions."); } self.batch_action.clear(); } pub fn flush_logits(&mut self) { if self.len_logits() == 0 { return; } let input: Vec<_> = self .batch_logits .iter() .map(|m| m.inference_state.clone()) .collect(); let output = self.inner_policy.forward(P::Observation::batch(input)); let logits: Vec<_> = output.unbatch(); for (i, item) in self.batch_logits.iter().enumerate() { item.sender .send(logits[i].clone()) .expect("Autobatcher should be able to send resulting probabilities."); } self.batch_logits.clear(); } pub fn update_policy(&mut self, policy_update: P::PolicyState) { if self.len_actions() > 0 { self.flush_actions(); } if self.len_logits() > 0 { self.flush_logits(); } self.inner_policy.update(policy_update); } pub fn state(&self) -> P::PolicyState { self.inner_policy.state() } pub fn increment_agents(&mut self, num: usize) { self.num_agents.fetch_add(num, Ordering::Relaxed); } pub fn decrement_agents(&mut self, num: usize) { self.num_agents.fetch_sub(num, Ordering::Relaxed); if self.len_actions() >= self .num_agents .load(Ordering::Relaxed) .min(self.max_autobatch_size) { self.flush_actions(); } if self.len_logits() >= self .num_agents .load(Ordering::Relaxed) .min(self.max_autobatch_size) { self.flush_logits(); } } } enum InferenceMessage> { ActionMessage(ActionItem), ForwardMessage(ForwardItem), PolicyUpdate(P::PolicyState), PolicyRequest(Sender), IncrementAgents(usize), DecrementAgents(usize), } #[derive(Clone)] struct ActionItem { sender: Sender>>, inference_state: S, deterministic: bool, } #[derive(Clone)] struct ForwardItem { sender: Sender, inference_state: S, } /// An asynchronous policy using an inference server with autobatching. #[derive(Clone)] pub struct AsyncPolicy> { inference_state_sender: Sender>, } impl AsyncPolicy where B: Backend, P: Policy + Clone + Send + 'static, P::ActionContext: Clone + Send, P::PolicyState: Send, P::Observation: Clone + Send + Batchable, P::ActionDistribution: Clone + Send + Batchable, P::Action: Clone + Send + Batchable, { /// Create the policy. /// /// # Arguments /// /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference. /// * `inner_policy` - The policy used to take actions. pub fn new(autobatch_size: usize, inner_policy: P) -> Self { let (sender, receiver) = std::sync::mpsc::channel(); let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone()); spawn(move || { loop { match receiver.recv() { Ok(msg) => match msg { InferenceMessage::ActionMessage(item) => autobatcher.push_action(item), InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item), InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update), InferenceMessage::PolicyRequest(sender) => sender .send(autobatcher.state()) .expect("Autobatcher should be able to send current policy state."), InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num), InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num), }, Err(err) => { log::error!("Error in AsyncPolicy : {}", err); break; } } } }); Self { inference_state_sender: sender, } } /// Increment the number of agents using the inference server. pub fn increment_agents(&self, num: usize) { self.inference_state_sender .send(InferenceMessage::IncrementAgents(num)) .expect("Can send message to autobatcher.") } /// Decrement the number of agents using the inference server. pub fn decrement_agents(&self, num: usize) { self.inference_state_sender .send(InferenceMessage::DecrementAgents(num)) .expect("Can send message to autobatcher.") } } impl Policy for AsyncPolicy where B: Backend, P: Policy + Send + 'static, { type ActionContext = P::ActionContext; type PolicyState = P::PolicyState; type Observation = P::Observation; type ActionDistribution = P::ActionDistribution; type Action = P::Action; fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution { let (action_sender, action_receiver) = std::sync::mpsc::channel(); let item = ForwardItem { sender: action_sender, inference_state: states, }; self.inference_state_sender .send(InferenceMessage::ForwardMessage(item)) .expect("Should be able to send message to inference_server"); action_receiver .recv() .expect("AsyncPolicy should receive queued probabilities.") } fn action( &mut self, states: Self::Observation, deterministic: bool, ) -> (Self::Action, Vec) { let (action_sender, action_receiver) = std::sync::mpsc::channel(); let item = ActionItem { sender: action_sender, inference_state: states, deterministic, }; self.inference_state_sender .send(InferenceMessage::ActionMessage(item)) .expect("should be able to send message to inference_server."); let action = action_receiver .recv() .expect("AsyncPolicy should receive queued actions."); (action.action, action.context) } fn update(&mut self, update: Self::PolicyState) { self.inference_state_sender .send(InferenceMessage::PolicyUpdate(update)) .expect("AsyncPolicy should be able to send policy state.") } fn state(&self) -> Self::PolicyState { let (sender, receiver) = mpsc::channel(); self.inference_state_sender .send(InferenceMessage::PolicyRequest(sender)) .expect("should be able to send message to inference_server."); receiver .recv() .expect("AsyncPolicy should be able to receive policy state.") } fn load_record(self, _record: >::Record) -> Self { // Not needed for now todo!() } } #[cfg(test)] #[allow(clippy::needless_range_loop)] mod tests { use std::thread::JoinHandle; use std::time::Duration; use crate::TestBackend; use crate::tests::{MockAction, MockObservation, MockPolicy}; use super::*; #[test] fn test_multiple_actions_before_flush() { fn launch_thread( policy: &AsyncPolicy, handles: &mut Vec>, ) { let mut thread_policy = policy.clone(); let handle = spawn(move || { thread_policy.action(MockObservation(vec![0.]), false); }); handles.push(handle); } let policy = AsyncPolicy::new(8, MockPolicy::new()); policy.increment_agents(1000); let mut handles = vec![]; launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); for _ in 0..6 { launch_thread(&policy, &mut handles); } std::thread::sleep(Duration::from_millis(10)); for i in 0..7 { assert!(!handles[i].is_finished()); } launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); for i in 0..8 { assert!(handles[i].is_finished()); } let mut handles = vec![]; launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); } #[test] fn test_multiple_forward_before_flush() { fn launch_thread( policy: &AsyncPolicy, handles: &mut Vec>, ) { let mut thread_policy = policy.clone(); let handle = spawn(move || { thread_policy.forward(MockObservation(vec![0.])); }); handles.push(handle); } let policy = AsyncPolicy::new(8, MockPolicy::new()); policy.increment_agents(1000); let mut handles = vec![]; launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); for _ in 0..6 { launch_thread(&policy, &mut handles); } std::thread::sleep(Duration::from_millis(10)); for i in 0..7 { assert!(!handles[i].is_finished()); } launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); for i in 0..8 { assert!(handles[i].is_finished()); } let mut handles = vec![]; launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); } #[test] fn test_async_policy_deterministic_behaviour() { fn launch_thread( policy: &AsyncPolicy, handles: &mut Vec>, deterministic: bool, ) { let mut thread_policy = policy.clone(); let handle = spawn(move || { let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic); action }); handles.push(handle); } let policy = AsyncPolicy::new(2, MockPolicy::new()); policy.increment_agents(1000); let mut handles = vec![]; launch_thread(&policy, &mut handles, true); launch_thread(&policy, &mut handles, false); for _ in 0..2 { let action = handles.pop().unwrap().join().unwrap(); assert_eq!(action.0, vec![0]); } let mut handles = vec![]; launch_thread(&policy, &mut handles, true); launch_thread(&policy, &mut handles, true); for _ in 0..2 { let action = handles.pop().unwrap().join().unwrap(); assert_eq!(action.0, vec![1]); } } #[test] fn flush_when_running_agents_smaller_than_autobatch_size() { fn launch_thread( policy: &AsyncPolicy, handles: &mut Vec>, ) { let mut thread_policy = policy.clone(); let handle = spawn(move || { thread_policy.action(MockObservation(vec![0.]), false); }); handles.push(handle); } let policy = AsyncPolicy::new(8, MockPolicy::new()); policy.increment_agents(3); let mut handles = vec![]; launch_thread(&policy, &mut handles); launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); assert!(!handles[1].is_finished()); launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); for i in 0..3 { assert!(handles[i].is_finished()); } let mut handles = vec![]; launch_thread(&policy, &mut handles); launch_thread(&policy, &mut handles); std::thread::sleep(Duration::from_millis(10)); assert!(!handles[0].is_finished()); assert!(!handles[1].is_finished()); policy.decrement_agents(1); std::thread::sleep(Duration::from_millis(10)); assert!(handles[0].is_finished()); assert!(handles[1].is_finished()); } } ================================================ FILE: crates/burn-rl/src/policy/base.rs ================================================ use derive_new::new; use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend}; use crate::TransitionBatch; /// An action along with additional context about the decision. #[derive(Clone, new)] pub struct ActionContext { /// The context. pub context: C, /// The action. pub action: A, } /// The state of a policy. pub trait PolicyState { /// The type of the record. type Record: Record; /// Convert the state to a record. fn into_record(self) -> Self::Record; /// Load the state from a record. fn load_record(&self, record: Self::Record) -> Self; } /// Trait for a RL policy. pub trait Policy: Clone { /// The observation given as input to the policy. type Observation; /// The action distribution parameters defining how the action will be sampled. type ActionDistribution; /// The action. type Action; /// Additional context on the policy's decision. type ActionContext; /// The current parameterization of the policy. type PolicyState: PolicyState; /// Produces the action distribution from a batch of observations. fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution; /// Gives the action from a batch of observations. fn action( &mut self, obs: Self::Observation, deterministic: bool, ) -> (Self::Action, Vec); /// Update the policy's parameters. fn update(&mut self, update: Self::PolicyState); /// Returns the current parameterization. fn state(&self) -> Self::PolicyState; /// Loads the policy parameters from a record. fn load_record(self, record: >::Record) -> Self; } /// Trait for a type that can be batched and unbatched (split). pub trait Batchable: Sized { /// Create a batch from a list of items. fn batch(value: Vec) -> Self; /// Create a list from batched items. fn unbatch(self) -> Vec; } /// A training output. pub struct RLTrainOutput { /// The policy. pub policy: P, /// The item. pub item: TO, } /// Batched transitions for a PolicyLearner. pub type LearnerTransitionBatch = TransitionBatch>::Observation,

>::Action>; /// Learner for a policy. pub trait PolicyLearner where B: AutodiffBackend, >::Observation: Clone + Batchable, >::ActionDistribution: Clone + Batchable, >::Action: Clone + Batchable, { /// Additional context of a training step. type TrainContext; /// The policy to train. type InnerPolicy: Policy; /// The record of the learner. type Record: Record; /// Execute a training step on the policy. fn train( &mut self, input: LearnerTransitionBatch, ) -> RLTrainOutput>::PolicyState>; /// Returns the learner's current policy. fn policy(&self) -> Self::InnerPolicy; /// Update the learner's policy. fn update_policy(&mut self, update: Self::InnerPolicy); /// Convert the learner's state into a record. fn record(&self) -> Self::Record; /// Load the learner's state from a record. fn load_record(self, record: Self::Record) -> Self; } ================================================ FILE: crates/burn-rl/src/policy/mod.rs ================================================ mod async_policy; mod base; pub use async_policy::*; pub use base::*; ================================================ FILE: crates/burn-rl/src/transition_buffer/base.rs ================================================ use burn_core::{Tensor, prelude::Backend, tensor::Distribution}; use derive_new::new; use super::SliceAccess; /// A state transition in an environment. #[derive(Clone, new)] pub struct Transition { /// The initial state. pub state: S, /// The state after the step was taken. pub next_state: S, /// The action taken in the step. pub action: A, /// The reward. pub reward: Tensor, /// If the environment has reached a terminal state. pub done: Tensor, } /// A batch of transitions. pub struct TransitionBatch { /// Batched initial states. pub states: SB, /// Batched resulting states. pub next_states: SB, /// Batched actions. pub actions: AB, /// Batched rewards. pub rewards: Tensor, /// Batched flags for terminal states. pub dones: Tensor, } /// A tensor-backed circular buffer for transitions. /// /// Uses [`SliceAccess`] to store state and action batches in contiguous /// tensor storage, enabling efficient random sampling via `select`. /// The buffer lazily initializes its storage on the first `push` call. pub struct TransitionBuffer, AB: SliceAccess> { states: Option, next_states: Option, actions: Option, rewards: Option>, dones: Option>, capacity: usize, write_head: usize, len: usize, device: B::Device, } impl, AB: SliceAccess> TransitionBuffer { /// Creates a new buffer. Storage is lazily allocated on the first `push`. pub fn new(capacity: usize, device: &B::Device) -> Self { Self { states: None, next_states: None, actions: None, rewards: None, dones: None, capacity, write_head: 0, len: 0, device: device.clone(), } } fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) { if self.states.is_none() { self.states = Some(SB::zeros_like(state, self.capacity, &self.device)); self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device)); self.actions = Some(AB::zeros_like(action, self.capacity, &self.device)); self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device)); self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device)); } } /// Add a transition, overwriting the oldest if full. pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) { self.ensure_init(&state, &next_state, &action); let idx = self.write_head % self.capacity; self.states .as_mut() .unwrap() .slice_assign_inplace(idx, state); self.next_states .as_mut() .unwrap() .slice_assign_inplace(idx, next_state); self.actions .as_mut() .unwrap() .slice_assign_inplace(idx, action); let reward = Tensor::from_data([[reward]], &self.device); self.rewards .as_mut() .unwrap() .inplace(|r| r.slice_assign(idx..idx + 1, reward)); let done_val = if done { 1.0f32 } else { 0.0 }; let done = Tensor::from_data([[done_val]], &self.device); self.dones .as_mut() .unwrap() .inplace(|d| d.slice_assign(idx..idx + 1, done)); self.write_head += 1; if self.len < self.capacity { self.len += 1; } } /// Sample a random batch of transitions. pub fn sample(&self, batch_size: usize) -> TransitionBatch { assert!(batch_size <= self.len, "batch_size exceeds buffer length"); let indices = Tensor::::random( [batch_size], Distribution::Uniform(0.0, self.len as f64), &self.device, ) .int(); TransitionBatch { states: self .states .as_ref() .unwrap() .clone() .select(0, indices.clone()), next_states: self .next_states .as_ref() .unwrap() .clone() .select(0, indices.clone()), actions: self .actions .as_ref() .unwrap() .clone() .select(0, indices.clone()), rewards: self .rewards .as_ref() .unwrap() .clone() .select(0, indices.clone()), dones: self.dones.as_ref().unwrap().clone().select(0, indices), } } /// Current number of stored transitions. pub fn len(&self) -> usize { self.len } /// Whether the buffer is empty. pub fn is_empty(&self) -> bool { self.len == 0 } /// Buffer capacity. pub fn capacity(&self) -> usize { self.capacity } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; type TB = Tensor; fn push_transition( buffer: &mut TransitionBuffer, device: &::Device, val: f32, ) { let state = Tensor::::from_data([[val, val]], device); let next_state = Tensor::::from_data([[val + 1.0, val + 1.0]], device); let action = Tensor::::from_data([[val]], device); buffer.push(state, next_state, action, val, false); } #[test] fn push_increment_len() { let device = Default::default(); let mut buffer = TransitionBuffer::::new(5, &device); assert_eq!(buffer.len(), 0); assert!(buffer.is_empty()); push_transition(&mut buffer, &device, 1.0); assert_eq!(buffer.len(), 1); push_transition(&mut buffer, &device, 2.0); assert_eq!(buffer.len(), 2); } #[test] fn push_overwrites_when_full() { let device = Default::default(); let mut buffer = TransitionBuffer::::new(3, &device); for i in 0..5 { push_transition(&mut buffer, &device, i as f32); } assert_eq!(buffer.len(), 3); assert_eq!(buffer.capacity(), 3); } #[test] fn sample_returns_correct_shapes() { let device = Default::default(); let mut buffer = TransitionBuffer::::new(10, &device); for i in 0..5 { push_transition(&mut buffer, &device, i as f32); } let batch = buffer.sample(3); assert_eq!(batch.states.dims(), [3, 2]); assert_eq!(batch.next_states.dims(), [3, 2]); assert_eq!(batch.actions.dims(), [3, 1]); assert_eq!(batch.rewards.dims(), [3, 1]); assert_eq!(batch.dones.dims(), [3, 1]); } #[test] #[should_panic(expected = "batch_size exceeds buffer length")] fn sample_panics_when_batch_too_large() { let device = Default::default(); let mut buffer = TransitionBuffer::::new(5, &device); push_transition(&mut buffer, &device, 1.0); buffer.sample(5); } } ================================================ FILE: crates/burn-rl/src/transition_buffer/mod.rs ================================================ mod base; mod slice_access; pub use base::*; pub use slice_access::*; ================================================ FILE: crates/burn-rl/src/transition_buffer/slice_access.rs ================================================ use burn_core::prelude::*; /// Trait for types that support tensor-like slice operations, /// enabling storage in a [`TransitionBuffer`](super::TransitionBuffer). /// /// Implement this trait for any type that wraps tensors and can be stored /// in a replay buffer. The buffer uses these operations for: /// - Pre-allocating storage (`zeros_like`) /// - Writing transitions (`slice_assign_inplace`) /// - Sampling batches (`select`) pub trait SliceAccess: Clone + Sized { /// Create zeroed storage matching the shape of `sample` but with `capacity` rows /// along the first dimension. fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self; /// Select rows at the given indices along the specified dimension. fn select(self, dim: usize, indices: Tensor) -> Self; /// Assign `value` at row `index` along the first dimension, in place. fn slice_assign_inplace(&mut self, index: usize, value: Self); } impl SliceAccess for Tensor { fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self { let feature_dim = sample.dims()[1]; Tensor::zeros([capacity, feature_dim], device) } fn select(self, dim: usize, indices: Tensor) -> Self { Tensor::select(self, dim, indices) } fn slice_assign_inplace(&mut self, index: usize, value: Self) { self.inplace(|t| t.slice_assign(index..index + 1, value)); } } ================================================ FILE: crates/burn-rocm/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "ROCm HIP backend for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu", "rocm", "hip"] license.workspace = true name = "burn-rocm" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-rocm" documentation = "https://docs.rs/burn-rocm" version.workspace = true [lints] workspace = true [features] default = ["fusion", "burn-cubecl/default", "cubecl/default"] tracing = [ "cubecl/tracing", "burn-cubecl/tracing", "burn-backend/tracing", "burn-fusion?/tracing", ] fusion = ["burn-fusion", "burn-cubecl/fusion"] autotune = ["burn-cubecl/autotune"] autotune-checks = ["burn-cubecl/autotune-checks"] doc = ["burn-cubecl/doc"] std = ["burn-cubecl/std", "cubecl/std"] [dependencies] cubecl = { workspace = true, features = ["hip"] } burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", default-features = true } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", features = [ "cubecl-hip", ] } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-rocm/README.md ================================================ # burn-rocm Backend using ROCm HIP runtime. To execute the tests for this backend set an environment variable called `ROCM_PATH` or `CUBECL_ROCM_PATH` to the installation path of ROCm. It is often `/opt/rocm`. For now this backend requires the version `6.2.2` of ROCm or a compatible version. ================================================ FILE: crates/burn-rocm/src/lib.rs ================================================ #![cfg_attr(docsrs, feature(doc_cfg))] extern crate alloc; use burn_cubecl::CubeBackend; pub use cubecl::hip::AmdDevice as RocmDevice; use cubecl::hip::HipRuntime; #[cfg(not(feature = "fusion"))] pub type Rocm = CubeBackend; #[cfg(feature = "fusion")] pub type Rocm = burn_fusion::Fusion>; ================================================ FILE: crates/burn-router/Cargo.toml ================================================ [package] authors = [ "laggui ", "nathanielsimard ", ] categories = ["science"] description = "Multi-backend router decorator for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-router" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-router" documentation = "https://docs.rs/burn-router" version.workspace = true [lints] workspace = true [features] default = ["std"] std = ["burn-backend/std", "burn-std/std", "burn-ir/std"] doc = ["default"] tracing = [ "burn-backend/tracing", "burn-ir/tracing", "burn-std/tracing", ] [dependencies] burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } hashbrown = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false, features = [ "std", ] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-router/README.md ================================================ # Burn Router A multi-backend extension that forwards the tensor operations to the appropriate backend. ================================================ FILE: crates/burn-router/src/backend.rs ================================================ use super::{RouterTensor, RunnerChannel, RunnerClient, get_client}; use alloc::{format, string::String}; use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme}; use core::marker::PhantomData; /// A backend that forwards the tensor operations to the appropriate backend (given multiple backends). pub struct BackendRouter { r: PhantomData, } impl core::fmt::Debug for BackendRouter { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!("router")) } } impl Clone for BackendRouter { fn clone(&self) -> Self { Self { r: PhantomData } } } impl Default for BackendRouter { fn default() -> Self { Self { r: PhantomData } } } impl QTensorPrimitive for RouterTensor { fn scheme(&self) -> &QuantScheme { if let DType::QFloat(scheme) = &self.dtype { scheme } else { // TODO: maybe `tensor.scheme()` should return an option panic!("Expected quantized float dtype, got {:?}", self.dtype) } } } impl Backend for BackendRouter { type Device = R::Device; type FloatTensorPrimitive = RouterTensor; type FloatElem = R::FloatElem; type IntTensorPrimitive = RouterTensor; type IntElem = R::IntElem; type BoolTensorPrimitive = RouterTensor; type BoolElem = R::BoolElem; type QuantizedTensorPrimitive = RouterTensor; fn name(device: &Self::Device) -> String { format!("router<{}>", R::name(device)) } fn seed(device: &Self::Device, seed: u64) { let client = get_client::(device); client.seed(seed); } fn sync(device: &Self::Device) -> Result<(), ExecutionError> { let client = get_client::(device); client.sync() } fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { let client = get_client::(device); client.dtype_usage(dtype) } } ================================================ FILE: crates/burn-router/src/bridge/base.rs ================================================ use burn_backend::{Shape, backend::DeviceOps}; /// Allows tensors to be transferred between multiple backends. pub trait MultiBackendBridge: Send + Sync + 'static { /// The type that can be used to point to a tensor of any kind. type TensorHandle; /// Device type used by the backends. type Device: DeviceOps; /// Change the backend of the given float tensor. fn change_backend_float( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle; /// Change the backend of the given int tensor. fn change_backend_int( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle; /// Change the backend of the given bool tensor. fn change_backend_bool( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle; // TODO: change_backend_quantized } ================================================ FILE: crates/burn-router/src/bridge/byte.rs ================================================ use core::marker::PhantomData; /// Simply transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData). pub struct ByteBridge { backends: PhantomData, } ================================================ FILE: crates/burn-router/src/bridge/mod.rs ================================================ mod base; mod byte; pub use base::*; pub use byte::*; ================================================ FILE: crates/burn-router/src/channel/base.rs ================================================ use alloc::string::String; use burn_backend::{DType, Element, Shape, backend::DeviceOps}; use burn_ir::TensorIr; use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client}; /// Type alias for `
::TensorHandle`. pub type TensorHandle
=
::TensorHandle; /// Defines the connection channel and operations for a setup with multiple backend runner clients. pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized { /// Device type. type Device: DeviceOps; /// A bridge that can transfer tensors between multiple backends. type Bridge: MultiBackendBridge; /// Client type. type Client: RunnerClient; /// Float element type. type FloatElem: Element; /// Int element type. type IntElem: Element; /// Bool element type. type BoolElem: Element; /// Name of the channel. fn name(device: &Self::Device) -> String; /// Initialize a new client for the given device. fn init_client(device: &Self::Device) -> Self::Client; /// Get the tensor handle corresponding to the [tensor representation](TensorIr). fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle; /// Create a tensor with the given handle and shape. fn register_tensor( client: &Self::Client, handle: TensorHandle, shape: Shape, dtype: DType, ) -> RouterTensor; /// Change the tensor to a different client backend. fn change_client_backend( tensor: RouterTensor, device: &Self::Device, // target device ) -> RouterTensor { // Get tensor handle from current client let original_client = tensor.client.clone(); let desc = tensor.into_ir(); let mut handle = Self::get_tensor_handle(&desc, &original_client); if desc.dtype.is_float() { handle = Self::Bridge::change_backend_float(handle, desc.shape.clone(), device); } else if desc.dtype.is_int() { handle = Self::Bridge::change_backend_int(handle, desc.shape.clone(), device); } else if desc.dtype.is_bool() { handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device); } else { unimplemented!() } // Register tensor handle on target client let target_client = get_client::(device); Self::register_tensor(&target_client, handle, desc.shape, desc.dtype) } } ================================================ FILE: crates/burn-router/src/channel/direct.rs ================================================ use core::marker::PhantomData; /// A local channel with direct connection to the backend runner clients. pub struct DirectChannel { backends: PhantomData, bridge: PhantomData, } impl Clone for DirectChannel { fn clone(&self) -> Self { Self { backends: self.backends, bridge: self.bridge, } } } ================================================ FILE: crates/burn-router/src/channel/mod.rs ================================================ mod base; mod direct; pub use base::*; pub use direct::*; ================================================ FILE: crates/burn-router/src/client/base.rs ================================================ use crate::{RouterTensor, RunnerChannel}; use alloc::boxed::Box; use alloc::vec::Vec; use burn_backend::{ DType, TensorData, backend::{DeviceId, DeviceOps, ExecutionError}, }; use burn_ir::{OperationIr, TensorId, TensorIr}; use burn_std::future::DynFut; use core::ops::DerefMut; use hashbrown::HashMap; use spin::Mutex; /// Type alias for `::Client`. pub type Client = ::Client; pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new(); type Key = (core::any::TypeId, DeviceId); /// Define how to interact with the runner. pub trait RunnerClient: Clone + Send + Sync + Sized { /// Device type. type Device: DeviceOps; /// Register a new tensor operation to be executed by the (runner) server. fn register_op(&self, op: OperationIr); /// Register a new tensor operation to be executed by the (runner) server. /// /// Returns the new (uninitialized) output tensor(s) generated by the registered operation. fn register(&self, op: OperationIr) -> Vec> { let out = op .outputs() .map(|output| { RouterTensor::new(output.id, output.shape.clone(), output.dtype, self.clone()) }) .collect(); self.register_op(op); out } /// Read the values contained by a tensor. fn read_tensor_async(&self, tensor: TensorIr) -> DynFut>; /// Sync the runner, ensure that all computations are finished. fn sync(&self) -> Result<(), ExecutionError>; /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId). fn create_empty_handle(&self) -> TensorId; /// Create a new [RouterTensor] from the tensor data. fn register_tensor_data(&self, data: TensorData) -> RouterTensor; /// Get the current device used by all operations handled by this client. fn device(&self) -> Self::Device; /// Seed the runner. fn seed(&self, seed: u64); /// Returns the supported data type usage set fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet; } pub(crate) struct RunnerClientLocator { clients: Mutex>>>, } /// Get the client for the given device pub fn get_client(device: &R::Device) -> Client { CLIENTS.client::(device) } /// Initialize a new client for the given device. /// /// If a (global) seed was previously set, the client seed is set. fn new_client(device: &R::Device) -> Client { R::init_client(device) } impl RunnerClientLocator { /// Create a new client locator. pub const fn new() -> Self { Self { clients: Mutex::new(None), } } /// Get the runner client for the given device. /// /// If a client isn't already initialized, it is created. pub fn client(&self, device: &R::Device) -> Client { let device_id = device.id(); let client_id = (core::any::TypeId::of::(), device_id); let mut clients = self.clients.lock(); if clients.is_none() { let client = new_client::(device); Self::register_inner::(client_id, client, &mut clients); } match clients.deref_mut() { Some(clients) => match clients.get(&client_id) { Some(client) => { let client: &Client = client.downcast_ref().unwrap(); client.clone() } None => { let client = new_client::(device); let any = Box::new(client.clone()); clients.insert(client_id, any); client } }, _ => unreachable!(), } } fn register_inner( key: Key, client: Client, clients: &mut Option>>, ) { if clients.is_none() { *clients = Some(HashMap::new()); } if let Some(clients) = clients { if clients.contains_key(&key) { panic!("Client already created for device {key:?}"); } clients.insert(key, Box::new(client)); } } } ================================================ FILE: crates/burn-router/src/client/mod.rs ================================================ mod base; pub use base::*; ================================================ FILE: crates/burn-router/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![recursion_limit = "138"] //! Burn multi-backend router. mod backend; mod bridge; mod channel; mod client; mod ops; mod runner; mod tensor; mod types; pub use backend::*; pub use bridge::*; pub use channel::*; pub use client::*; pub use runner::*; pub use tensor::*; pub use types::*; /// A local channel with a simple byte bridge between backends. /// It transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData). pub type DirectByteChannel = DirectChannel>; /// Router backend. /// /// # Example /// /// ```ignore /// type MyBackend = Router<(NdArray, Wgpu)>; /// ``` pub type Router = BackendRouter>; extern crate alloc; #[cfg(test)] #[allow(unused)] mod tests { use crate::BackendRouter; use crate::DirectByteChannel; pub type TestBackend1 = burn_ndarray::NdArray; pub type TestBackend2 = burn_wgpu::Wgpu; pub type TestBackend = BackendRouter>; } ================================================ FILE: crates/burn-router/src/ops/activation.rs ================================================ use crate::{BackendRouter, RunnerChannel}; use burn_backend::ops::ActivationOps; impl ActivationOps for BackendRouter {} ================================================ FILE: crates/burn-router/src/ops/binary.rs ================================================ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let rhs = $handles.get_float_tensor::(&$desc.rhs); let output = $ops(lhs, rhs); $handles.register_float_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_float_cmp_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let rhs = $handles.get_float_tensor::(&$desc.rhs); let output = $ops(lhs, rhs); $handles.register_bool_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.lhs); let rhs = $handles.get_int_tensor::(&$desc.rhs); let output = $ops(lhs, rhs); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_int_cmp_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.lhs); let rhs = $handles.get_int_tensor::(&$desc.rhs); let output = $ops(lhs, rhs); $handles.register_bool_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! binary_bool_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_bool_tensor::(&$desc.lhs); let rhs = $handles.get_bool_tensor::(&$desc.rhs); let output = $ops(lhs, rhs); $handles.register_bool_tensor::(&$desc.out.id, output); }}; } ================================================ FILE: crates/burn-router/src/ops/bool_tensor.rs ================================================ use alloc::vec::Vec; use burn_backend::backend::ExecutionError; use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client}; use burn_backend::ops::BoolTensorOps; use burn_backend::tensor::{ BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor, }; use burn_backend::{Element, Scalar, Shape, Slice, TensorData}; use burn_ir::{ BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr, }; impl BoolTensorOps for BackendRouter { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc))) .output() } fn bool_zeros(shape: Shape, device: &Device) -> BoolTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc))) .output() } fn bool_ones(shape: Shape, device: &Device) -> BoolTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc))) .output() } async fn bool_into_data(tensor: BoolTensor) -> Result { tensor.into_data().await } fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); let out = client.register_tensor_data(data); let desc = InitOperationIr { out: out.to_ir_out(), }; // Call register op when output is already initialized client.register_op(OperationIr::Init(desc)); out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), IntElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc))) .output() } fn bool_into_float(tensor: BoolTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), FloatElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc))) .output() } fn bool_device(tensor: &BoolTensor) -> Device { tensor.client.device() } fn bool_to_device(tensor: BoolTensor, device: &Device) -> BoolTensor { if &tensor.client.device() == device { return tensor; } R::change_client_backend(tensor, device) } fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc))) .output() } fn bool_slice(tensor: BoolTensor, slices: &[Slice]) -> BoolTensor { let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc))) .output() } fn bool_slice_assign( tensor: BoolTensor, slices: &[burn_backend::Slice], value: BoolTensor, ) -> BoolTensor { let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc))) .output() } fn bool_equal(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc))) .output() } fn bool_not(tensor: BoolTensor) -> BoolTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Bool(BoolOperationIr::Not(desc))) .output() } fn bool_and(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Bool(BoolOperationIr::And(desc))) .output() } fn bool_or(lhs: BoolTensor, rhs: BoolTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Bool(BoolOperationIr::Or(desc))) .output() } fn bool_swap_dims(tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc))) .output() } fn bool_permute(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc))) .output() } fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc))) .output() } fn bool_expand(tensor: BoolTensor, shape: Shape) -> BoolTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc))) .output() } fn bool_cat(tensors: Vec>, dim: usize) -> BoolTensor { let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc))) .output() } fn bool_repeat_dim(tensor: BoolTensor, dim: usize, times: usize) -> BoolTensor { let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc))) .output() } fn bool_unfold( tensor: BoolTensor, dim: usize, size: usize, step: usize, ) -> BoolTensor { let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc))) .output() } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc))) .output() } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc))) .output() } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc))) .output() } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc))) .output() } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc))) .output() } } ================================================ FILE: crates/burn-router/src/ops/int_tensor.rs ================================================ use alloc::vec::Vec; use burn_backend::backend::{Backend, ExecutionError}; use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client}; use burn_backend::tensor::{ BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor, }; use burn_backend::{ Distribution, Element, IntDType, Scalar, Shape, Slice, TensorData, ops::IntTensorOps, }; use burn_ir::{ BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr, GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr, }; impl IntTensorOps for BackendRouter { fn int_empty(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc))) .output() } async fn int_into_data(tensor: IntTensor) -> Result { Ok(tensor .into_data() .await? // Since underlying backends can have different data types, we convert to the current elem .convert::<::IntElem>()) } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); let out = client.register_tensor_data(data); let desc = InitOperationIr { out: out.to_ir_out(), }; // Call register op when output is already initialized client.register_op(OperationIr::Init(desc)); out } fn int_device(tensor: &IntTensor) -> Device { tensor.client.device() } fn int_to_device(tensor: IntTensor, device: &Device) -> IntTensor { if &tensor.client.device() == device { return tensor; } R::change_client_backend(tensor, device) } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc))) .output() } fn int_slice(tensor: IntTensor, slices: &[Slice]) -> IntTensor { let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc))) .output() } fn int_slice_assign( tensor: IntTensor, slices: &[burn_backend::Slice], value: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc))) .output() } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::Matmul(desc))) .output() } fn int_mask_where( tensor: IntTensor, mask: BoolTensor, value: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc))) .output() } fn int_mask_fill( tensor: IntTensor, mask: BoolTensor, value: Scalar, ) -> IntTensor { let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc))) .output() } fn int_gather( dim: usize, tensor: IntTensor, indices: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc))) .output() } fn int_scatter_add( dim: usize, tensor: IntTensor, indices: IntTensor, value: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc))) .output() } fn int_select( tensor: IntTensor, dim: usize, indices: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Select(desc))) .output() } fn int_select_add( tensor: IntTensor, dim: usize, indices: IntTensor, value: IntTensor, ) -> IntTensor { let client = tensor.client.clone(); let desc = SelectAssignOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc))) .output() } fn int_cat(tensors: Vec>, dim: usize) -> IntTensor { let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc))) .output() } fn int_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc))) .output() } fn int_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc))) .output() } fn int_greater(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::Greater(desc), )) .output() } fn int_greater_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterElem(desc), )) .output() } fn int_greater_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterEqual(desc), )) .output() } fn int_greater_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::GreaterEqualElem(desc), )) .output() } fn int_lower(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::Lower(desc), )) .output() } fn int_lower_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerElem(desc), )) .output() } fn int_lower_equal(lhs: IntTensor, rhs: IntTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerEqual(desc), )) .output() } fn int_lower_equal_elem(lhs: IntTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.lhs.dtype, NumericOperationIr::LowerEqualElem(desc), )) .output() } fn int_add(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Add(desc), )) .output() } fn int_add_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::AddScalar(desc), )) .output() } fn int_sub(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Sub(desc), )) .output() } fn int_sub_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::SubScalar(desc), )) .output() } fn int_mul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Mul(desc), )) .output() } fn int_mul_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MulScalar(desc), )) .output() } fn int_div(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Div(desc), )) .output() } fn int_div_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::DivScalar(desc), )) .output() } fn int_remainder(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Rem(desc), )) .output() } fn int_remainder_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::RemScalar(desc), )) .output() } fn int_zeros(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc))) .output() } fn int_ones(shape: Shape, device: &Device, dtype: IntDType) -> IntTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc))) .output() } fn int_sum(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Sum(desc), )) .output() } fn int_sum_dim(tensor: IntTensor, axis: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::SumDim(desc), )) .output() } fn int_prod(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Prod(desc), )) .output() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::ProdDim(desc), )) .output() } fn int_mean(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Mean(desc), )) .output() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MeanDim(desc), )) .output() } fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::CumSum(desc), )) .output() } fn int_cumprod(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::CumProd(desc), )) .output() } fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::CumMin(desc), )) .output() } fn int_cummax(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::CumMax(desc), )) .output() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::ArgMax(desc), )) .output() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::ArgMin(desc), )) .output() } fn int_clamp(tensor: IntTensor, min: Scalar, max: Scalar) -> IntTensor { let client = tensor.client.clone(); let min = min.into(); let max = max.into(); let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Clamp(desc), )) .output() } fn int_abs(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Abs(desc), )) .output() } fn int_into_float(tensor: IntTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), FloatElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::IntoFloat(desc))) .output() } fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc))) .output() } fn int_max(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Max(desc), )) .output() } fn int_max_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MaxDim(desc), )) .output() } fn int_max_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create( tensor.into_ir(), dim, IntElem::::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.tensor.dtype, NumericOperationIr::MaxDimWithIndices(desc), )) .outputs() .into() } fn int_max_abs(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MaxAbs(desc), )) .output() } fn int_max_abs_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MaxAbsDim(desc), )) .output() } fn int_min(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::Min(desc), )) .output() } fn int_min_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MinDim(desc), )) .output() } fn int_min_dim_with_indices( tensor: IntTensor, dim: usize, ) -> (IntTensor, IntTensor) { let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create( tensor.into_ir(), dim, IntElem::::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericInt( desc.out.dtype, NumericOperationIr::MinDimWithIndices(desc), )) .outputs() .into() } fn int_random( shape: Shape, distribution: Distribution, device: &Device, ) -> IntTensor { let client = get_client::(device); let dtype = IntElem::::dtype(); let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle()); client .register(OperationIr::NumericInt( dtype, NumericOperationIr::IntRandom(desc), )) .output() } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc))) .output() } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc))) .output() } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc))) .output() } fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor { let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc))) .output() } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc))) .output() } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc))) .output() } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc))) .output() } fn bitwise_not(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc))) .output() } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc))) .output() } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc))) .output() } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc))) .output() } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc))) .output() } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar( desc, ))) .output() } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc))) .output() } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar( desc, ))) .output() } fn int_cast(tensor: IntTensor, dtype: burn_backend::IntDType) -> IntTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc))) .output() } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc))) .output() } } ================================================ FILE: crates/burn-router/src/ops/mod.rs ================================================ mod activation; mod binary; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; mod unary; ================================================ FILE: crates/burn-router/src/ops/module.rs ================================================ use alloc::boxed::Box; use burn_backend::Element; use burn_backend::ops::{ AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; use burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor}; use burn_ir::*; use crate::{BackendRouter, RunnerChannel, RunnerClient}; impl ModuleOps for BackendRouter { fn conv1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<1>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv1dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc))) .output() } fn conv1d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv1dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv1dXBackward( desc, ))) .output() } fn conv1d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv1dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module( ModuleOperationIr::Conv1dWeightBackward(desc), )) .output() } fn conv1d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv1dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward( desc, ))) .output() } fn conv2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv2dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc))) .output() } fn conv2d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv2dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv2dXBackward( desc, ))) .output() } fn conv2d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv2dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module( ModuleOperationIr::Conv2dWeightBackward(desc), )) .output() } fn conv2d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv2dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward( desc, ))) .output() } fn conv3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvOptions<3>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv3dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc))) .output() } fn conv3d_x_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv3dXBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv3dXBackward( desc, ))) .output() } fn conv3d_weight_backward( x: FloatTensor, weight: FloatTensor, output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv3dWeightBackwardOpIr::create( x.into_ir(), weight.into_ir(), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module( ModuleOperationIr::Conv3dWeightBackward(desc), )) .output() } fn conv3d_bias_backward( x: FloatTensor, bias: FloatTensor, output_grad: FloatTensor, ) -> FloatTensor { let client = x.client.clone(); let desc = Conv3dBiasBackwardOpIr::create( x.into_ir(), bias.into_ir(), output_grad.into_ir(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward( desc, ))) .output() } fn conv_transpose1d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> FloatTensor { let client = x.client.clone(); let desc = ConvTranspose1dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d( desc, ))) .output() } fn conv_transpose2d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { let client = x.client.clone(); let desc = ConvTranspose2dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d( desc, ))) .output() } fn conv_transpose3d( x: FloatTensor, weight: FloatTensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> FloatTensor { let client = x.client.clone(); let desc = ConvTranspose3dOpIr::create( x.into_ir(), weight.into_ir(), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d( desc, ))) .output() } fn avg_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = AvgPool1dOpIr::create( x.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc))) .output() } fn avg_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = AvgPool2dOpIr::create( x.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc))) .output() } fn avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = AvgPool1dBackwardOpIr::create( x.into_ir(), grad.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward( desc, ))) .output() } fn avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = AvgPool2dBackwardOpIr::create( x.into_ir(), grad.into_ir(), kernel_size, stride, padding, count_include_pad, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward( desc, ))) .output() } fn max_pool1d( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = MaxPool1dOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc))) .output() } fn max_pool2d( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> FloatTensor { let client = x.client.clone(); let desc = MaxPool2dOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc))) .output() } fn max_pool1d_with_indices( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { let client = x.client.clone(); let desc = MaxPool1dWithIndicesOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, IntElem::::dtype(), || client.create_empty_handle(), ); let [out, out_indices] = client .register(OperationIr::Module( ModuleOperationIr::MaxPool1dWithIndices(desc), )) .outputs(); MaxPool1dWithIndices::new(out, out_indices) } fn max_pool2d_with_indices( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { let client = x.client.clone(); let desc = MaxPool2dWithIndicesOpIr::create( x.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, IntElem::::dtype(), || client.create_empty_handle(), ); let [out, out_indices] = client .register(OperationIr::Module( ModuleOperationIr::MaxPool2dWithIndices(desc), )) .outputs(); MaxPool2dWithIndices::new(out, out_indices) } fn max_pool1d_with_indices_backward( x: FloatTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool1dBackward { let client = x.client.clone(); let desc = MaxPool1dWithIndicesBackwardOpIr::create( x.into_ir(), output_grad.into_ir(), indices.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); let out = client .register(OperationIr::Module( ModuleOperationIr::MaxPool1dWithIndicesBackward(desc), )) .output(); MaxPool1dBackward::new(out) } fn max_pool2d_with_indices_backward( x: FloatTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: FloatTensor, indices: IntTensor, ) -> MaxPool2dBackward { let client = x.client.clone(); let desc = MaxPool2dWithIndicesBackwardOpIr::create( x.into_ir(), output_grad.into_ir(), indices.into_ir(), kernel_size, stride, padding, dilation, ceil_mode, || client.create_empty_handle(), ); let out = client .register(OperationIr::Module( ModuleOperationIr::MaxPool2dWithIndicesBackward(desc), )) .output(); MaxPool2dBackward::new(out) } fn adaptive_avg_pool1d(x: FloatTensor, output_size: usize) -> FloatTensor { let client = x.client.clone(); let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || { client.create_empty_handle() }); client .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d( desc, ))) .output() } fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { let client = x.client.clone(); let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || { client.create_empty_handle() }); client .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d( desc, ))) .output() } fn adaptive_avg_pool1d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { let client = x.client.clone(); let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Module( ModuleOperationIr::AdaptiveAvgPool1dBackward(desc), )) .output() } fn adaptive_avg_pool2d_backward( x: FloatTensor, grad: FloatTensor, ) -> FloatTensor { let client = x.client.clone(); let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Module( ModuleOperationIr::AdaptiveAvgPool2dBackward(desc), )) .output() } fn interpolate( x: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { let client = x.client.clone(); let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || { client.create_empty_handle() }); client .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc))) .output() } fn interpolate_backward( x: FloatTensor, grad: FloatTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> FloatTensor { let client = x.client.clone(); let desc = InterpolateBackwardOpIr::create( x.into_ir(), grad.into_ir(), output_size, options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward( desc, ))) .output() } fn deform_conv2d( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { let client = x.client.clone(); let desc = DeformConv2dOpIr::create( x.into_ir(), offset.into_ir(), weight.into_ir(), mask.map(|mask| mask.into_ir()), bias.map(|bias| bias.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d( Box::new(desc), ))) .output() } fn deform_conv2d_backward( x: FloatTensor, offset: FloatTensor, weight: FloatTensor, mask: Option>, bias: Option>, output_grad: FloatTensor, options: DeformConvOptions<2>, ) -> DeformConv2dBackward { let client = x.client.clone(); let has_bias = bias.is_some(); let has_mask = mask.is_some(); let desc = DeformConv2dBackwardOpIr::create( x.into_ir(), offset.into_ir(), weight.into_ir(), mask.map(|mask| mask.into_ir()), bias.map(|bias| bias.into_ir()), output_grad.into_ir(), options.into(), || client.create_empty_handle(), ); let mut outputs = client .register(OperationIr::Module( ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)), )) .into_iter(); // When the number of outputs is variable, the order is important let input_grad = outputs.next().unwrap(); let offset_grad = outputs.next().unwrap(); let weight_grad = outputs.next().unwrap(); let mask_grad = has_mask.then(|| outputs.next().unwrap()); let bias_grad = has_bias.then(|| outputs.next().unwrap()); DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad) } fn attention( query: FloatTensor, key: FloatTensor, value: FloatTensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> FloatTensor { let client = query.client.clone(); let desc = AttentionOpIr::create( query.into_ir(), key.into_ir(), value.into_ir(), mask.map(|m: BoolTensor| m.into_ir()), attn_bias.map(|ab| ab.into_ir()), options.into(), || client.create_empty_handle(), ); client .register(OperationIr::Module(ModuleOperationIr::Attention(desc))) .output() } } ================================================ FILE: crates/burn-router/src/ops/qtensor.rs ================================================ use burn_backend::{ ExecutionError, Shape, Slice, TensorData, ops::QTensorOps, quantization::{QuantScheme, QuantizationParametersPrimitive}, tensor::{Device, FloatTensor, IntTensor, QuantizedTensor}, }; use crate::{BackendRouter, RunnerChannel}; impl QTensorOps for BackendRouter { fn q_from_data(_data: TensorData, _device: &Device) -> QuantizedTensor { unimplemented!() } fn quantize( _tensor: FloatTensor, _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { unimplemented!() } fn quantize_dynamic( _tensor: FloatTensor, _scheme: &QuantScheme, ) -> QuantizedTensor { unimplemented!() } fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { unimplemented!() } fn q_device(_tensor: &QuantizedTensor) -> Device { unimplemented!() } fn q_to_device( _tensor: QuantizedTensor, _device: &Device, ) -> QuantizedTensor { unimplemented!() } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } async fn q_into_data(_tensor: QuantizedTensor) -> Result { unimplemented!() } fn q_swap_dims( _tensor: QuantizedTensor, _dim1: usize, _dim2: usize, ) -> QuantizedTensor { unimplemented!() } fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_gather( _dim: usize, _tensor: QuantizedTensor, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_select( _tensor: QuantizedTensor, _dim: usize, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_slice(_tensor: QuantizedTensor, _slices: &[Slice]) -> QuantizedTensor { unimplemented!() } fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } } ================================================ FILE: crates/burn-router/src/ops/tensor.rs ================================================ use alloc::vec::Vec; use burn_backend::Scalar; use burn_backend::backend::{Backend, ExecutionError}; use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client}; use burn_backend::tensor::{ BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor, }; use burn_backend::{ Distribution, Element, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps, }; use burn_ir::{ BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr, FlipOpIr, FloatOperationIr, FullOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr, }; impl FloatTensorOps for BackendRouter { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); let out = client.register_tensor_data(data); let desc = InitOperationIr { out: out.to_ir_out(), }; // Call register op when output is already initialized client.register_op(OperationIr::Init(desc)); out } fn float_random( shape: Shape, distribution: Distribution, device: &Device, ) -> FloatTensor { let client = get_client::(device); let dtype = FloatElem::::dtype(); let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle()); client .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc))) .output() } fn float_zeros(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc))) .output() } fn float_ones(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc))) .output() } fn float_full( shape: Shape, fill_value: Scalar, device: &Device, dtype: FloatDType, ) -> FloatTensor { let client = get_client::(device); let dtype = dtype.into(); let value = fill_value.into(); let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Full(desc), )) .output() } async fn float_into_data(tensor: FloatTensor) -> Result { Ok(tensor .into_data() .await? // Since underlying backends can have different data types, we convert to the current elem .convert::<::FloatElem>()) } fn float_device(tensor: &FloatTensor) -> Device { tensor.client.device() } fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { if &tensor.client.device() == device { return tensor; } R::change_client_backend(tensor, device) } fn float_into_int(tensor: FloatTensor) -> IntTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), IntElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::Float( desc.input.dtype, FloatOperationIr::IntoInt(desc), )) .output() } fn float_empty(shape: Shape, device: &Device, dtype: FloatDType) -> FloatTensor { let client = get_client::(device); let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc))) .output() } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Add(desc), )) .output() } fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::AddScalar(desc), )) .output() } fn float_clamp(tensor: FloatTensor, min: Scalar, max: Scalar) -> FloatTensor { let client = tensor.client.clone(); let min = min.into(); let max = max.into(); let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Clamp(desc), )) .output() } fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Sub(desc), )) .output() } fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::SubScalar(desc), )) .output() } fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Mul(desc), )) .output() } fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MulScalar(desc), )) .output() } fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Div(desc), )) .output() } fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::DivScalar(desc), )) .output() } fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Rem(desc), )) .output() } fn float_remainder_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::RemScalar(desc), )) .output() } fn float_matmul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Matmul(desc), )) .output() } fn float_cross( lhs: FloatTensor, rhs: FloatTensor, dim: usize, ) -> FloatTensor { let client = lhs.client.clone(); let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || { client.create_empty_handle() }); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Cross(desc), )) .output() } fn float_swap_dims(tensor: FloatTensor, dim1: usize, dim2: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc))) .output() } fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc))) .output() } fn float_gather( dim: usize, tensor: FloatTensor, indices: IntTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc))) .output() } fn float_scatter_add( dim: usize, tensor: FloatTensor, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = ScatterOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc))) .output() } fn float_select( tensor: FloatTensor, dim: usize, indices: IntTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc))) .output() } fn float_select_add( tensor: FloatTensor, dim: usize, indices: IntTensor, value: FloatTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = SelectAssignOpIr::create( tensor.into_ir(), dim, indices.into_ir(), value.into_ir(), IndexingUpdateOp::Add, || client.create_empty_handle(), ); client .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc))) .output() } fn float_slice(tensor: FloatTensor, slices: &[Slice]) -> FloatTensor { let client = tensor.client.clone(); let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc))) .output() } fn float_slice_assign( tensor: FloatTensor, slices: &[burn_backend::Slice], value: FloatTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc))) .output() } fn float_mask_where( tensor: FloatTensor, mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { let client = tensor.client.clone(); let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc))) .output() } fn float_mask_fill( tensor: FloatTensor, mask: BoolTensor, value: Scalar, ) -> FloatTensor { let client = tensor.client.clone(); let value = value.into(); let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc))) .output() } fn float_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc))) .output() } fn float_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc))) .output() } fn float_greater(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::Greater(desc), )) .output() } fn float_greater_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterElem(desc), )) .output() } fn float_greater_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterEqual(desc), )) .output() } fn float_greater_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::GreaterEqualElem(desc), )) .output() } fn float_lower(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::Lower(desc), )) .output() } fn float_lower_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerElem(desc), )) .output() } fn float_lower_equal(lhs: FloatTensor, rhs: FloatTensor) -> BoolTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create_comparison( lhs.into_ir(), rhs.into_ir(), R::BoolElem::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerEqual(desc), )) .output() } fn float_lower_equal_elem(lhs: FloatTensor, rhs: Scalar) -> BoolTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.lhs.dtype, NumericOperationIr::LowerEqualElem(desc), )) .output() } fn float_sum(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Sum(desc), )) .output() } fn float_sum_dim(tensor: FloatTensor, axis: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::SumDim(desc), )) .output() } fn float_prod(tensor: IntTensor) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Prod(desc), )) .output() } fn float_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::ProdDim(desc), )) .output() } fn float_mean(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Mean(desc), )) .output() } fn float_mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MeanDim(desc), )) .output() } fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::CumSum(desc), )) .output() } fn float_cumprod(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::CumProd(desc), )) .output() } fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::CumMin(desc), )) .output() } fn float_cummax(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::CumMax(desc), )) .output() } fn float_exp(lhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Exp(desc), )) .output() } fn float_log(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Log(desc), )) .output() } fn float_log1p(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Log1p(desc), )) .output() } fn float_powf_scalar_impl(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { let client = lhs.client.clone(); let rhs = rhs.into(); let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::PowfScalar(desc), )) .output() } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Sqrt(desc), )) .output() } fn float_abs(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Abs(desc), )) .output() } fn float_cos(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Cos(desc), )) .output() } fn float_cosh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Cosh(desc), )) .output() } fn float_sin(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Sin(desc), )) .output() } fn float_sinh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Sinh(desc), )) .output() } fn float_tan(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Tan(desc), )) .output() } fn float_tanh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Tanh(desc), )) .output() } fn float_acos(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcCos(desc), )) .output() } fn float_acosh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcCosh(desc), )) .output() } fn float_asin(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcSin(desc), )) .output() } fn float_asinh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcSinh(desc), )) .output() } fn float_atan(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcTan(desc), )) .output() } fn float_atanh(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcTanh(desc), )) .output() } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::ArcTan2(desc), )) .output() } fn float_round(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Round(desc), )) .output() } fn float_floor(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Floor(desc), )) .output() } fn float_ceil(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Ceil(desc), )) .output() } fn float_trunc(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Trunc(desc), )) .output() } fn float_recip(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Recip(desc), )) .output() } fn float_erf(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Erf(desc), )) .output() } fn float_cat(tensors: Vec>, dim: usize) -> FloatTensor { let client = tensors.first().unwrap().client.clone(); let tensors = tensors.into_iter().map(|t| t.into_ir()).collect(); let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc))) .output() } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.input.dtype, NumericOperationIr::ArgMax(desc), )) .output() } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc))) .output() } fn float_argmin(tensor: FloatTensor, dim: usize) -> IntTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::::dtype(), || { client.create_empty_handle() }); client .register(OperationIr::NumericFloat( desc.input.dtype, NumericOperationIr::ArgMin(desc), )) .output() } fn float_max(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Max(desc), )) .output() } fn float_max_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MaxDim(desc), )) .output() } fn float_max_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create( tensor.into_ir(), dim, IntElem::::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.tensor.dtype, NumericOperationIr::MaxDimWithIndices(desc), )) .outputs() .into() } fn float_min(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::Min(desc), )) .output() } fn float_min_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { let client = tensor.client.clone(); let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle()); client .register(OperationIr::NumericFloat( desc.out.dtype, NumericOperationIr::MinDim(desc), )) .output() } fn float_min_dim_with_indices( tensor: FloatTensor, dim: usize, ) -> (FloatTensor, IntTensor) { let client = tensor.client.clone(); let desc = ReduceDimWithIndicesOpIr::create( tensor.into_ir(), dim, IntElem::::dtype(), || client.create_empty_handle(), ); client .register(OperationIr::NumericFloat( desc.tensor.dtype, NumericOperationIr::MinDimWithIndices(desc), )) .outputs() .into() } fn float_powf(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { let client = lhs.client.clone(); let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || { client.create_empty_handle() }); client .register(OperationIr::Float( desc.out.dtype, FloatOperationIr::Powf(desc), )) .output() } fn float_permute(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { let client = tensor.client.clone(); let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc))) .output() } fn float_expand(tensor: FloatTensor, shape: Shape) -> FloatTensor { let client = tensor.client.clone(); let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle()); client .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc))) .output() } fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor { let client = tensor.client.clone(); let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc))) .output() } fn float_cast(tensor: FloatTensor, dtype: burn_backend::FloatDType) -> FloatTensor { let client = tensor.client.clone(); let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc))) .output() } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { let client = tensor.client.clone(); let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || { client.create_empty_handle() }); client .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc))) .output() } } ================================================ FILE: crates/burn-router/src/ops/transaction.rs ================================================ use burn_backend::ops::TransactionOps; use crate::{BackendRouter, RunnerChannel}; impl TransactionOps for BackendRouter {} ================================================ FILE: crates/burn-router/src/ops/unary.rs ================================================ #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs.into()); $handles.register_float_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_dim_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs); $handles.register_float_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_float_dim_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let input = $handles.get_float_tensor::(&$desc.input); let output = $ops(input, $desc.axis); $handles.register_float_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_float2int_dim_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let input = $handles.get_float_tensor::(&$desc.input); let output = $ops(input, $desc.axis); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! reduce_int_dim_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let input = $handles.get_int_tensor::(&$desc.input); let output = $ops(input, $desc.axis); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float2int_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_float_cmp_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs.into()); $handles.register_bool_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_float_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_float_tensor::(&$desc.input); let output = $ops(lhs); $handles.register_float_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs.into()); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_dim_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs); $handles.register_int_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! scalar_int_cmp_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.lhs); let output = $ops(lhs, $desc.rhs.into()); $handles.register_bool_tensor::(&$desc.out.id, output); }}; } #[allow(missing_docs)] #[macro_export(local_inner_macros)] macro_rules! unary_int_ops { ( $handles:expr, $desc:expr, $ops:expr ) => {{ let lhs = $handles.get_int_tensor::(&$desc.input); let output = $ops(lhs); $handles.register_int_tensor::(&$desc.out.id, output); }}; } ================================================ FILE: crates/burn-router/src/runner.rs ================================================ use core::sync::atomic::{AtomicU64, Ordering}; use super::{RouterTensor, RunnerClient}; use crate::{ binary_bool_ops, binary_float_cmp_ops, binary_float_ops, binary_int_cmp_ops, binary_int_ops, reduce_float_dim_ops, reduce_float2int_dim_ops, reduce_int_dim_ops, scalar_float_cmp_ops, scalar_float_ops, scalar_int_cmp_ops, scalar_int_ops, unary_float_ops, unary_int_ops, }; use alloc::boxed::Box; use alloc::sync::Arc; use burn_backend::{Backend, DType, ExecutionError, Shape, TensorData, tensor::IndexingUpdateOp}; use burn_ir::{ BackendIr, BaseOperationIr, BoolOperationIr, FloatOperationIr, HandleContainer, IntOperationIr, ModuleOperationIr, NumericOperationIr, OperationIr, TensorId, TensorIr, TensorStatus, }; use burn_std::{future::DynFut, stub::Mutex}; /// A runner's context contains a [handle container](HandleContainer) to manage /// (i.e., fetch and update) existing tensors. pub struct RunnerContext { /// Handle container to retrieve tensors based on their intermediate representation. handles: HandleContainer, } static COUNTER: AtomicU64 = AtomicU64::new(0); impl RunnerContext { /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId). fn create_empty_handle(&mut self) -> TensorId { let value = COUNTER.fetch_add(1, Ordering::Relaxed); TensorId::new(value) } } /// A runner is responsible for executing tensor operations for a given [intermediate backend](BackendIr). #[derive(Clone)] pub struct Runner { // Mutex for the mutable handles context: Arc>>, device: B::Device, } impl core::fmt::Debug for Runner { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Runner") .field("device", &self.device) .finish() } } impl Runner { /// Create a new runner. pub fn new(device: B::Device) -> Self { Self { context: Arc::new(Mutex::new(RunnerContext { handles: HandleContainer::new(), })), device, } } /// Get the tensor handle for the given [tensor representation](TensorIr). pub fn get_tensor_handle(&self, tensor: &TensorIr) -> B::Handle { let handles = &mut self.context.lock().unwrap().handles; handles.get_tensor_handle(tensor).handle } /// Create a tensor with the given handle and shape. pub fn register_tensor( &self, handle: B::Handle, shape: Shape, dtype: DType, client: C, ) -> RouterTensor { let mut ctx = self.context.lock().unwrap(); let id = ctx.create_empty_handle(); ctx.handles.register_handle(id, handle); core::mem::drop(ctx); RouterTensor::new(id, shape, dtype, client) } /// Register a tensor from its data and id. pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) { let mut ctx = self.context.lock().unwrap(); let dtype = data.dtype; if dtype.is_float() { let tensor = B::float_from_data(data, &self.device); ctx.handles.register_float_tensor::(&id, tensor) } else if dtype.is_int() { let tensor = B::int_from_data(data, &self.device); ctx.handles.register_int_tensor::(&id, tensor) } else if dtype.is_bool() { let tensor = B::bool_from_data(data, &self.device); ctx.handles.register_bool_tensor::(&id, tensor) } else if let DType::QFloat(_) = dtype { todo!(); } core::mem::drop(ctx); } /// Register a tensor and returns its intermediate representation. pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorIr { let mut ctx = self.context.lock().unwrap(); let id = ctx.create_empty_handle(); let shape = data.shape.clone(); let dtype = data.dtype; if dtype.is_float() { let tensor = B::float_from_data(data, &self.device); ctx.handles.register_float_tensor::(&id, tensor) } else if dtype.is_int() { let tensor = B::int_from_data(data, &self.device); ctx.handles.register_int_tensor::(&id, tensor) } else if dtype.is_bool() { let tensor = B::bool_from_data(data, &self.device); ctx.handles.register_bool_tensor::(&id, tensor) } else if let DType::QFloat(_) = dtype { todo!(); } core::mem::drop(ctx); TensorIr { id, shape, status: TensorStatus::ReadWrite, dtype, } } } // This is a Remote Runner impl RunnerClient for Runner { type Device = B::Device; /// Execute a tensor operation. fn register_op(&self, op: OperationIr) { // Remove unused tensor handles let mut ctx = self.context.lock().unwrap(); let handles = &mut ctx.handles; match &op { // For every op: get the input(s), execute the operation and register the output(s) OperationIr::BaseFloat(op) => match op { BaseOperationIr::Reshape(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_reshape(tensor, desc.out.shape.clone()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::SwapDims(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_swap_dims(tensor, desc.dim1, desc.dim2); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Permute(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_permute(tensor, &desc.axes); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Flip(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_flip(tensor, &desc.axes); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Expand(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_expand(tensor, desc.out.shape.clone()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Unfold(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_unfold(tensor, desc.dim, desc.size, desc.step); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Slice(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let output = B::float_slice(tensor, &desc.ranges); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::SliceAssign(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let value = handles.get_float_tensor::(&desc.value); let output = B::float_slice_assign(tensor, &desc.ranges, value); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Gather(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::float_gather(desc.dim, tensor, indices); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Scatter(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_float_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::float_scatter_add(desc.dim, tensor, indices, value) } }; handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Select(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::float_select(tensor, desc.dim, indices); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::SelectAssign(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_float_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::float_select_add(tensor, desc.dim, indices, value) } }; handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::MaskWhere(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let value = handles.get_float_tensor::(&desc.value); let output = B::float_mask_where(tensor, mask, value); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::MaskFill(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let output = B::float_mask_fill(tensor, mask, desc.value.into()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Equal(desc) => { binary_float_cmp_ops!(handles, desc, B::float_equal) } BaseOperationIr::EqualElem(desc) => { scalar_float_cmp_ops!(handles, desc, B::float_equal_elem) } BaseOperationIr::RepeatDim(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let output = B::float_repeat_dim(tensor, desc.dim, desc.times); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Cat(desc) => { let tensors = desc .tensors .iter() .map(|tensor| handles.get_float_tensor::(tensor)) .collect(); let output = B::float_cat(tensors, desc.dim); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Cast(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_cast(tensor, desc.out.dtype.into()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Empty(desc) => { let shape = desc.out.shape.clone(); let output = B::float_empty(shape, &self.device, desc.out.dtype.into()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Ones(desc) => { let shape = desc.out.shape.clone(); let output = B::float_ones(shape, &self.device, desc.out.dtype.into()); handles.register_float_tensor::(&desc.out.id, output); } BaseOperationIr::Zeros(desc) => { let shape = desc.out.shape.clone(); let output = B::float_zeros(shape, &self.device, desc.out.dtype.into()); handles.register_float_tensor::(&desc.out.id, output); } }, OperationIr::BaseInt(op) => match op { BaseOperationIr::Reshape(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_reshape(tensor, desc.out.shape.clone()); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::SwapDims(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_swap_dims(tensor, desc.dim1, desc.dim2); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Permute(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_permute(tensor, &desc.axes); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Flip(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_flip(tensor, &desc.axes); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Expand(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_expand(tensor, desc.out.shape.clone()); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Unfold(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_unfold(tensor, desc.dim, desc.size, desc.step); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Slice(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let output = B::int_slice(tensor, &desc.ranges); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::SliceAssign(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let value = handles.get_int_tensor::(&desc.value); let output = B::int_slice_assign(tensor, &desc.ranges, value); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Gather(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::int_gather(desc.dim, tensor, indices); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Scatter(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_int_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::int_scatter_add(desc.dim, tensor, indices, value) } }; handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Select(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::int_select(tensor, desc.dim, indices); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::SelectAssign(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_int_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::int_select_add(tensor, desc.dim, indices, value) } }; handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::MaskWhere(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let value = handles.get_int_tensor::(&desc.value); let output = B::int_mask_where(tensor, mask, value); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::MaskFill(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let output = B::int_mask_fill(tensor, mask, desc.value.into()); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Equal(desc) => { binary_int_cmp_ops!(handles, desc, B::int_equal) } BaseOperationIr::EqualElem(desc) => { scalar_int_cmp_ops!(handles, desc, B::int_equal_elem) } BaseOperationIr::RepeatDim(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let output = B::int_repeat_dim(tensor, desc.dim, desc.times); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Cat(desc) => { let tensors = desc .tensors .iter() .map(|tensor| handles.get_int_tensor::(tensor)) .collect(); let output = B::int_cat(tensors, desc.dim); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Cast(_) => unreachable!(), BaseOperationIr::Empty(desc) => { let shape = desc.out.shape.clone(); let output = B::int_empty(shape, &self.device, desc.out.dtype.into()); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Ones(desc) => { let shape = desc.out.shape.clone(); let output = B::int_ones(shape, &self.device, desc.out.dtype.into()); handles.register_int_tensor::(&desc.out.id, output); } BaseOperationIr::Zeros(desc) => { let shape = desc.out.shape.clone(); let output = B::int_zeros(shape, &self.device, desc.out.dtype.into()); handles.register_int_tensor::(&desc.out.id, output); } }, OperationIr::BaseBool(op) => match op { BaseOperationIr::Reshape(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_reshape(tensor, desc.out.shape.clone()); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::SwapDims(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_swap_dims(tensor, desc.dim1, desc.dim2); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Permute(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_permute(tensor, &desc.axes); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Flip(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_flip(tensor, &desc.axes); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Expand(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_expand(tensor, desc.out.shape.clone()); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Unfold(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_unfold(tensor, desc.dim, desc.size, desc.step); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Slice(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let output = B::bool_slice(tensor, &desc.ranges); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::SliceAssign(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let value = handles.get_bool_tensor::(&desc.value); let output = B::bool_slice_assign(tensor, &desc.ranges, value); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Gather(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::bool_gather(desc.dim, tensor, indices); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Scatter(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_bool_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::bool_scatter_or(desc.dim, tensor, indices, value) } }; handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Select(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let output = B::bool_select(tensor, desc.dim, indices); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::SelectAssign(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let indices = handles.get_int_tensor::(&desc.indices); let value = handles.get_bool_tensor::(&desc.value); let output = match desc.update { IndexingUpdateOp::Add => { B::bool_select_or(tensor, desc.dim, indices, value) } }; handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::MaskWhere(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let value = handles.get_bool_tensor::(&desc.value); let output = B::bool_mask_where(tensor, mask, value); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::MaskFill(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let mask = handles.get_bool_tensor::(&desc.mask); let output = B::bool_mask_fill(tensor, mask, desc.value.into()); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Equal(desc) => { let lhs = handles.get_bool_tensor::(&desc.lhs); let rhs = handles.get_bool_tensor::(&desc.rhs); let output = B::bool_equal(lhs, rhs); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::EqualElem(desc) => { let lhs = handles.get_bool_tensor::(&desc.lhs); let output = B::bool_equal_elem(lhs, desc.rhs.into()); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::RepeatDim(desc) => { let tensor = handles.get_bool_tensor::(&desc.tensor); let output = B::bool_repeat_dim(tensor, desc.dim, desc.times); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Cat(desc) => { let tensors = desc .tensors .iter() .map(|tensor| handles.get_bool_tensor::(tensor)) .collect(); let output = B::bool_cat(tensors, desc.dim); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Cast(_) => unreachable!(), BaseOperationIr::Empty(desc) => { let shape = desc.out.shape.clone(); let output = B::bool_empty(shape, &self.device); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Zeros(desc) => { let shape = desc.out.shape.clone(); let output = B::bool_zeros(shape, &self.device); handles.register_bool_tensor::(&desc.out.id, output); } BaseOperationIr::Ones(desc) => { let shape = desc.out.shape.clone(); let output = B::bool_ones(shape, &self.device); handles.register_bool_tensor::(&desc.out.id, output); } }, OperationIr::NumericFloat(_dtype, op) => match op { NumericOperationIr::Add(desc) => { binary_float_ops!(handles, desc, B::float_add) } NumericOperationIr::AddScalar(desc) => { scalar_float_ops!(handles, desc, B::float_add_scalar) } NumericOperationIr::Sub(desc) => { binary_float_ops!(handles, desc, B::float_sub) } NumericOperationIr::SubScalar(desc) => { scalar_float_ops!(handles, desc, B::float_sub_scalar) } NumericOperationIr::Div(desc) => { binary_float_ops!(handles, desc, B::float_div) } NumericOperationIr::DivScalar(desc) => { scalar_float_ops!(handles, desc, B::float_div_scalar) } NumericOperationIr::Rem(desc) => { binary_float_ops!(handles, desc, B::float_remainder) } NumericOperationIr::RemScalar(desc) => { scalar_float_ops!(handles, desc, B::float_remainder_scalar) } NumericOperationIr::Mul(desc) => { binary_float_ops!(handles, desc, B::float_mul) } NumericOperationIr::MulScalar(desc) => { scalar_float_ops!(handles, desc, B::float_mul_scalar) } NumericOperationIr::Abs(desc) => { unary_float_ops!(handles, desc, B::float_abs) } NumericOperationIr::Full(desc) => { let shape = desc.out.shape.clone(); let output = B::float_full( shape, desc.value.into(), &self.device, desc.out.dtype.into(), ); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::MeanDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_mean_dim) } NumericOperationIr::Mean(desc) => { unary_float_ops!(handles, desc, B::float_mean) } NumericOperationIr::Sum(desc) => { unary_float_ops!(handles, desc, B::float_sum) } NumericOperationIr::SumDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_sum_dim) } NumericOperationIr::Prod(desc) => { unary_float_ops!(handles, desc, B::float_prod) } NumericOperationIr::ProdDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_prod_dim) } NumericOperationIr::Greater(desc) => { binary_float_cmp_ops!(handles, desc, B::float_greater) } NumericOperationIr::GreaterElem(desc) => { scalar_float_cmp_ops!(handles, desc, B::float_greater_elem) } NumericOperationIr::GreaterEqual(desc) => { binary_float_cmp_ops!(handles, desc, B::float_greater_equal) } NumericOperationIr::GreaterEqualElem(desc) => { scalar_float_cmp_ops!(handles, desc, B::float_greater_equal_elem) } NumericOperationIr::Lower(desc) => { binary_float_cmp_ops!(handles, desc, B::float_lower) } NumericOperationIr::LowerElem(desc) => { scalar_float_cmp_ops!(handles, desc, B::float_lower_elem) } NumericOperationIr::LowerEqual(desc) => { binary_float_cmp_ops!(handles, desc, B::float_lower_equal) } NumericOperationIr::LowerEqualElem(desc) => { scalar_float_cmp_ops!(handles, desc, B::float_lower_equal_elem) } NumericOperationIr::ArgMax(desc) => { reduce_float2int_dim_ops!(handles, desc, B::float_argmax) } NumericOperationIr::ArgMin(desc) => { reduce_float2int_dim_ops!(handles, desc, B::float_argmin) } NumericOperationIr::Max(desc) => { unary_float_ops!(handles, desc, B::float_max) } NumericOperationIr::MaxDimWithIndices(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let (output, output_idx) = B::float_max_dim_with_indices(tensor, desc.dim); handles.register_float_tensor::(&desc.out.id, output); handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::MinDimWithIndices(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let (output, output_idx) = B::float_min_dim_with_indices(tensor, desc.dim); handles.register_float_tensor::(&desc.out.id, output); handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::Min(desc) => { unary_float_ops!(handles, desc, B::float_min) } NumericOperationIr::MaxDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_max_dim) } NumericOperationIr::MinDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_min_dim) } NumericOperationIr::MaxAbs(desc) => { unary_float_ops!(handles, desc, B::float_max_abs) } NumericOperationIr::MaxAbsDim(desc) => { reduce_float_dim_ops!(handles, desc, B::float_max_abs_dim) } NumericOperationIr::Clamp(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let output = B::float_clamp(tensor, desc.min.into(), desc.max.into()); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::IntRandom(_) => unreachable!(), NumericOperationIr::Powi(desc) => { let lhs = handles.get_float_tensor::(&desc.lhs); let rhs = handles.get_int_tensor::(&desc.rhs); let output = (B::float_powi)(lhs, rhs); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::CumSum(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_cumsum(tensor, desc.axis); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::CumProd(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_cumprod(tensor, desc.axis); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::CumMin(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_cummin(tensor, desc.axis); handles.register_float_tensor::(&desc.out.id, output); } NumericOperationIr::CumMax(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_cummax(tensor, desc.axis); handles.register_float_tensor::(&desc.out.id, output); } }, OperationIr::NumericInt(_dtype, op) => match op { NumericOperationIr::Add(desc) => { binary_int_ops!(handles, desc, B::int_add) } NumericOperationIr::AddScalar(desc) => { scalar_int_ops!(handles, desc, B::int_add_scalar) } NumericOperationIr::Sub(desc) => { binary_int_ops!(handles, desc, B::int_sub) } NumericOperationIr::SubScalar(desc) => { scalar_int_ops!(handles, desc, B::int_sub_scalar) } NumericOperationIr::Div(desc) => { binary_int_ops!(handles, desc, B::int_div) } NumericOperationIr::DivScalar(desc) => { scalar_int_ops!(handles, desc, B::int_div_scalar) } NumericOperationIr::Rem(desc) => { binary_int_ops!(handles, desc, B::int_remainder) } NumericOperationIr::RemScalar(desc) => { scalar_int_ops!(handles, desc, B::int_remainder_scalar) } NumericOperationIr::Mul(desc) => { binary_int_ops!(handles, desc, B::int_mul) } NumericOperationIr::MulScalar(desc) => { scalar_int_ops!(handles, desc, B::int_mul_scalar) } NumericOperationIr::Abs(desc) => { unary_int_ops!(handles, desc, B::int_abs) } NumericOperationIr::Full(desc) => { let shape = desc.out.shape.clone(); let output = B::int_full( shape, desc.value.into(), &self.device, desc.out.dtype.into(), ); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::MeanDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_mean_dim) } NumericOperationIr::Mean(desc) => { unary_int_ops!(handles, desc, B::int_mean) } NumericOperationIr::Sum(desc) => { unary_int_ops!(handles, desc, B::int_sum) } NumericOperationIr::SumDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_sum_dim) } NumericOperationIr::Prod(desc) => { unary_int_ops!(handles, desc, B::int_prod) } NumericOperationIr::ProdDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_prod_dim) } NumericOperationIr::Greater(desc) => { binary_int_cmp_ops!(handles, desc, B::int_greater) } NumericOperationIr::GreaterElem(desc) => { scalar_int_cmp_ops!(handles, desc, B::int_greater_elem) } NumericOperationIr::GreaterEqual(desc) => { binary_int_cmp_ops!(handles, desc, B::int_greater_equal) } NumericOperationIr::GreaterEqualElem(desc) => { scalar_int_cmp_ops!(handles, desc, B::int_greater_equal_elem) } NumericOperationIr::Lower(desc) => { binary_int_cmp_ops!(handles, desc, B::int_lower) } NumericOperationIr::LowerElem(desc) => { scalar_int_cmp_ops!(handles, desc, B::int_lower_elem) } NumericOperationIr::LowerEqual(desc) => { binary_int_cmp_ops!(handles, desc, B::int_lower_equal) } NumericOperationIr::LowerEqualElem(desc) => { scalar_int_cmp_ops!(handles, desc, B::int_lower_equal_elem) } NumericOperationIr::ArgMax(desc) => { reduce_int_dim_ops!(handles, desc, B::int_argmax) } NumericOperationIr::ArgMin(desc) => { reduce_int_dim_ops!(handles, desc, B::int_argmin) } NumericOperationIr::Max(desc) => { unary_int_ops!(handles, desc, B::int_max) } NumericOperationIr::MaxDimWithIndices(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let (output, output_idx) = B::int_max_dim_with_indices(tensor, desc.dim); handles.register_int_tensor::(&desc.out.id, output); handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::MinDimWithIndices(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let (output, output_idx) = B::int_min_dim_with_indices(tensor, desc.dim); handles.register_int_tensor::(&desc.out.id, output); handles.register_int_tensor::(&desc.out_indices.id, output_idx); } NumericOperationIr::Min(desc) => { unary_int_ops!(handles, desc, B::int_min) } NumericOperationIr::MaxDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_max_dim) } NumericOperationIr::MinDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_min_dim) } NumericOperationIr::MaxAbs(desc) => { unary_int_ops!(handles, desc, B::int_max_abs) } NumericOperationIr::MaxAbsDim(desc) => { reduce_int_dim_ops!(handles, desc, B::int_max_abs_dim) } NumericOperationIr::Clamp(desc) => { let tensor = handles.get_int_tensor::(&desc.tensor); let output = B::int_clamp(tensor, desc.min.into(), desc.max.into()); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::IntRandom(desc) => { let shape = desc.out.shape.clone(); let output = B::int_random(shape, desc.distribution, &self.device); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::Powi(desc) => { let lhs = handles.get_int_tensor::(&desc.lhs); let rhs = handles.get_int_tensor::(&desc.rhs); let output = B::int_powi(lhs, rhs); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::CumSum(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_cumsum(tensor, desc.axis); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::CumProd(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_cumprod(tensor, desc.axis); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::CumMin(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_cummin(tensor, desc.axis); handles.register_int_tensor::(&desc.out.id, output); } NumericOperationIr::CumMax(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_cummax(tensor, desc.axis); handles.register_int_tensor::(&desc.out.id, output); } }, OperationIr::Bool(op) => match op { BoolOperationIr::IntoFloat(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } BoolOperationIr::IntoInt(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_into_int(tensor); handles.register_int_tensor::(&desc.out.id, output); } BoolOperationIr::Not(desc) => { let tensor = handles.get_bool_tensor::(&desc.input); let output = B::bool_not(tensor); handles.register_bool_tensor::(&desc.out.id, output); } BoolOperationIr::And(desc) => { binary_bool_ops!(handles, desc, B::bool_and) } BoolOperationIr::Or(desc) => { binary_bool_ops!(handles, desc, B::bool_or) } }, OperationIr::Int(op) => match op { IntOperationIr::IntoFloat(desc) => { let tensor = handles.get_int_tensor::(&desc.input); let output = B::int_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } IntOperationIr::Matmul(desc) => { binary_int_ops!(handles, desc, B::int_matmul) } IntOperationIr::BitwiseAnd(desc) => { binary_int_ops!(handles, desc, B::bitwise_and) } IntOperationIr::BitwiseAndScalar(desc) => { scalar_int_ops!(handles, desc, B::bitwise_and_scalar) } IntOperationIr::BitwiseOr(desc) => { binary_int_ops!(handles, desc, B::bitwise_or) } IntOperationIr::BitwiseOrScalar(desc) => { scalar_int_ops!(handles, desc, B::bitwise_or_scalar) } IntOperationIr::BitwiseXor(desc) => { binary_int_ops!(handles, desc, B::bitwise_xor) } IntOperationIr::BitwiseXorScalar(desc) => { scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) } IntOperationIr::BitwiseNot(desc) => { unary_int_ops!(handles, desc, B::bitwise_not) } IntOperationIr::BitwiseLeftShift(desc) => { binary_int_ops!(handles, desc, B::bitwise_left_shift) } IntOperationIr::BitwiseRightShift(desc) => { binary_int_ops!(handles, desc, B::bitwise_right_shift) } IntOperationIr::BitwiseLeftShiftScalar(desc) => { scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) } IntOperationIr::BitwiseRightShiftScalar(desc) => { scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) } }, OperationIr::Float(_dtype, op) => match op { FloatOperationIr::Exp(desc) => { unary_float_ops!(handles, desc, B::float_exp) } FloatOperationIr::Powf(desc) => { binary_float_ops!(handles, desc, B::float_powf) } FloatOperationIr::Log(desc) => { unary_float_ops!(handles, desc, B::float_log) } FloatOperationIr::Log1p(desc) => { unary_float_ops!(handles, desc, B::float_log1p) } FloatOperationIr::Erf(desc) => { unary_float_ops!(handles, desc, B::float_erf) } FloatOperationIr::PowfScalar(desc) => { scalar_float_ops!(handles, desc, B::float_powf_scalar) } FloatOperationIr::Sqrt(desc) => { unary_float_ops!(handles, desc, B::float_sqrt) } FloatOperationIr::Cos(desc) => { unary_float_ops!(handles, desc, B::float_cos) } FloatOperationIr::Sin(desc) => { unary_float_ops!(handles, desc, B::float_sin) } FloatOperationIr::Tanh(desc) => { unary_float_ops!(handles, desc, B::float_tanh) } FloatOperationIr::Tan(desc) => unary_float_ops!(handles, desc, B::float_tan), FloatOperationIr::Cosh(desc) => unary_float_ops!(handles, desc, B::float_cosh), FloatOperationIr::Sinh(desc) => unary_float_ops!(handles, desc, B::float_sinh), FloatOperationIr::ArcCos(desc) => unary_float_ops!(handles, desc, B::float_acos), FloatOperationIr::ArcCosh(desc) => unary_float_ops!(handles, desc, B::float_acosh), FloatOperationIr::ArcSin(desc) => unary_float_ops!(handles, desc, B::float_asin), FloatOperationIr::ArcSinh(desc) => unary_float_ops!(handles, desc, B::float_asinh), FloatOperationIr::ArcTan(desc) => unary_float_ops!(handles, desc, B::float_atan), FloatOperationIr::ArcTanh(desc) => unary_float_ops!(handles, desc, B::float_atanh), FloatOperationIr::ArcTan2(desc) => binary_float_ops!(handles, desc, B::float_atan2), FloatOperationIr::Round(desc) => { unary_float_ops!(handles, desc, B::float_round) } FloatOperationIr::Floor(desc) => { unary_float_ops!(handles, desc, B::float_floor) } FloatOperationIr::Ceil(desc) => { unary_float_ops!(handles, desc, B::float_ceil) } FloatOperationIr::Trunc(desc) => { unary_float_ops!(handles, desc, B::float_trunc) } FloatOperationIr::IntoInt(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_into_int(tensor); handles.register_int_tensor::(&desc.out.id, output); } FloatOperationIr::Matmul(desc) => { binary_float_ops!(handles, desc, B::float_matmul) } FloatOperationIr::Cross(desc) => { let lhs = handles.get_float_tensor::(&desc.lhs); let rhs = handles.get_float_tensor::(&desc.rhs); let output = B::float_cross(lhs, rhs, desc.dim); handles.register_float_tensor::(&desc.out.id, output); } FloatOperationIr::Random(desc) => { let shape = desc.out.shape.clone(); let output = B::float_random(shape, desc.distribution, &self.device); handles.register_float_tensor::(&desc.out.id, output); } FloatOperationIr::Recip(desc) => { unary_float_ops!(handles, desc, B::float_recip) } FloatOperationIr::Quantize(_) => todo!(), FloatOperationIr::Dequantize(_) => todo!(), FloatOperationIr::IsNan(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_is_nan(tensor); handles.register_bool_tensor::(&desc.out.id, output); } FloatOperationIr::IsInf(desc) => { let tensor = handles.get_float_tensor::(&desc.input); let output = B::float_is_inf(tensor); handles.register_bool_tensor::(&desc.out.id, output); } FloatOperationIr::GridSample2d(desc) => { let tensor = handles.get_float_tensor::(&desc.tensor); let grid = handles.get_float_tensor::(&desc.grid); let output = B::float_grid_sample_2d(tensor, grid, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } }, OperationIr::Module(op) => match op { ModuleOperationIr::Embedding(desc) => { let weights = handles.get_float_tensor::(&desc.weights); let indices = handles.get_int_tensor::(&desc.indices); let output = B::embedding(weights, indices); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::EmbeddingBackward(desc) => { let weights = handles.get_float_tensor::(&desc.weights); let indices = handles.get_int_tensor::(&desc.indices); let output_grad = handles.get_float_tensor::(&desc.out_grad); let output = B::embedding_backward(weights, output_grad, indices); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv1d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv1d(x, weight, bias, desc.clone().options.into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv1dXBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_x_backward(x, weight, output_grad, desc.clone().options.into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv1dWeightBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_weight_backward( x, weight, output_grad, desc.clone().options.into(), ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv1dBiasBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv1d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv2d(x, weight, bias, desc.clone().options.into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv2dXBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_x_backward(x, weight, output_grad, desc.clone().options.into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv2dWeightBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_weight_backward( x, weight, output_grad, desc.clone().options.into(), ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv2dBiasBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv2d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv3d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv3d(x, weight, bias, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv3dXBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_x_backward(x, weight, output_grad, desc.clone().options.into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv3dWeightBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_weight_backward( x, weight, output_grad, desc.clone().options.into(), ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Conv3dBiasBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let bias = handles.get_float_tensor::(&desc.bias); let output_grad = handles.get_float_tensor::(&desc.output_grad); let output = B::conv3d_bias_backward(x, bias, output_grad); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::DeformableConv2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let offset = handles.get_float_tensor::(&desc.offset); let mask = desc .mask .as_ref() .map(|mask| handles.get_float_tensor::(mask)); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::deform_conv2d( x, offset, weight, mask, bias, desc.options.clone().into(), ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::DeformableConv2dBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let offset = handles.get_float_tensor::(&desc.offset); let mask = desc .mask .as_ref() .map(|mask| handles.get_float_tensor::(mask)); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output_grad = handles.get_float_tensor::(&desc.out_grad); let output = B::deform_conv2d_backward( x, offset, weight, mask, bias, output_grad, desc.options.clone().into(), ); handles.register_float_tensor::(&desc.input_grad.id, output.x_grad); handles.register_float_tensor::(&desc.offset_grad.id, output.offset_grad); handles.register_float_tensor::(&desc.weight_grad.id, output.weight_grad); if let Some((mask_grad, field)) = output.mask_grad.zip(desc.mask_grad.as_ref()) { handles.register_float_tensor::(&field.id, mask_grad); } if let Some((bias_grad, field)) = output.bias_grad.zip(desc.bias_grad.as_ref()) { handles.register_float_tensor::(&field.id, bias_grad); } } ModuleOperationIr::ConvTranspose1d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose1d(x, weight, bias, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::ConvTranspose2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose2d(x, weight, bias, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::ConvTranspose3d(desc) => { let x = handles.get_float_tensor::(&desc.x); let weight = handles.get_float_tensor::(&desc.weight); let bias = desc .bias .as_ref() .map(|bias| handles.get_float_tensor::(bias)); let output = B::conv_transpose3d(x, weight, bias, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AvgPool1d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::avg_pool1d( x, desc.kernel_size, desc.stride, desc.padding, desc.count_include_pad, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AvgPool2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::avg_pool2d( x, desc.kernel_size, desc.stride, desc.padding, desc.count_include_pad, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AvgPool1dBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let grad = handles.get_float_tensor::(&desc.grad); let output = B::avg_pool1d_backward( x, grad, desc.kernel_size, desc.stride, desc.padding, desc.count_include_pad, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AvgPool2dBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let grad = handles.get_float_tensor::(&desc.grad); let output = B::avg_pool2d_backward( x, grad, desc.kernel_size, desc.stride, desc.padding, desc.count_include_pad, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AdaptiveAvgPool1d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::adaptive_avg_pool1d(x, desc.output_size); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AdaptiveAvgPool2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::adaptive_avg_pool2d(x, desc.output_size); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AdaptiveAvgPool1dBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let grad = handles.get_float_tensor::(&desc.grad); let output = B::adaptive_avg_pool1d_backward(x, grad); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::AdaptiveAvgPool2dBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let grad = handles.get_float_tensor::(&desc.grad); let output = B::adaptive_avg_pool2d_backward(x, grad); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::MaxPool1d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::max_pool1d( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::MaxPool1dWithIndices(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::max_pool1d_with_indices( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output.output); handles.register_int_tensor::(&desc.out_indices.id, output.indices); } ModuleOperationIr::MaxPool1dWithIndicesBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let output_grad = handles.get_float_tensor::(&desc.grad); let indices = handles.get_int_tensor::(&desc.indices); let output = B::max_pool1d_with_indices_backward( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, output_grad, indices, ); handles.register_float_tensor::(&desc.out.id, output.x_grad); } ModuleOperationIr::MaxPool2d(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::max_pool2d( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::MaxPool2dWithIndices(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::max_pool2d_with_indices( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, ); handles.register_float_tensor::(&desc.out.id, output.output); handles.register_int_tensor::(&desc.out_indices.id, output.indices); } ModuleOperationIr::MaxPool2dWithIndicesBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let output_grad = handles.get_float_tensor::(&desc.grad); let indices = handles.get_int_tensor::(&desc.indices); let output = B::max_pool2d_with_indices_backward( x, desc.kernel_size, desc.stride, desc.padding, desc.dilation, desc.ceil_mode, output_grad, indices, ); handles.register_float_tensor::(&desc.out.id, output.x_grad); } ModuleOperationIr::Interpolate(desc) => { let x = handles.get_float_tensor::(&desc.x); let output = B::interpolate(x, desc.output_size, desc.options.clone().into()); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::InterpolateBackward(desc) => { let x = handles.get_float_tensor::(&desc.x); let grad = handles.get_float_tensor::(&desc.grad); let output = B::interpolate_backward( x, grad, desc.output_size, desc.options.clone().into(), ); handles.register_float_tensor::(&desc.out.id, output); } ModuleOperationIr::Attention(desc) => { let query = handles.get_float_tensor::(&desc.query); let key = handles.get_float_tensor::(&desc.key); let value = handles.get_float_tensor::(&desc.value); let mask = desc.mask.as_ref().map(|m| handles.get_bool_tensor::(m)); let attn_bias = desc .attn_bias .as_ref() .map(|ab| handles.get_float_tensor::(ab)); let output = B::attention( query, key, value, mask, attn_bias, desc.options.clone().into(), ); handles.register_float_tensor::(&desc.out.id, output); } }, OperationIr::Custom(_) => { panic!("Can't execute custom operation here") } OperationIr::Init(_) => { // Nothing to do. } OperationIr::Drop(repr) => { handles.remove_handle(repr.id); } } } fn read_tensor_async(&self, tensor: TensorIr) -> DynFut> { let mut ctx = self.context.lock().unwrap(); enum Output { Float(B::FloatTensorPrimitive), Int(B::IntTensorPrimitive), Bool(B::BoolTensorPrimitive), } let tensor = if tensor.dtype.is_float() { let tensor = ctx.handles.get_float_tensor::(&tensor); Output::::Float(tensor) } else if tensor.dtype.is_int() { let tensor = ctx.handles.get_int_tensor::(&tensor); Output::Int(tensor) } else if tensor.dtype.is_bool() { let tensor = ctx.handles.get_bool_tensor::(&tensor); Output::Bool(tensor) } else if let DType::QFloat(_) = tensor.dtype { todo!() } else { unimplemented!() }; match tensor { Output::Float(val) => Box::pin(B::float_into_data(val)), Output::Int(val) => Box::pin(B::int_into_data(val)), Output::Bool(val) => Box::pin(B::bool_into_data(val)), } } fn register_tensor_data(&self, data: TensorData) -> RouterTensor { let desc = self.register_tensor_data_desc(data); RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone()) } fn device(&self) -> Self::Device { self.device.clone() } fn sync(&self) -> Result<(), ExecutionError> { B::sync(&self.device) } fn seed(&self, seed: u64) { B::seed(&self.device, seed) } fn create_empty_handle(&self) -> TensorId { let mut ctx = self.context.lock().unwrap(); ctx.create_empty_handle() } fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet { B::dtype_usage(&self.device, dtype) } } ================================================ FILE: crates/burn-router/src/tensor.rs ================================================ use core::sync::atomic::{AtomicU32, Ordering}; use alloc::format; use alloc::{sync::Arc, vec::Vec}; use super::RunnerClient; use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError}; use burn_ir::{TensorId, TensorIr, TensorStatus}; /// Tensor primitive for the [router backend](crate::BackendRouter). pub struct RouterTensor { pub(crate) id: TensorId, pub(crate) shape: Shape, pub(crate) dtype: DType, /// The client that has this tensor pub client: C, pub(crate) count: Arc, } impl TensorMetadata for RouterTensor { fn dtype(&self) -> DType { self.dtype } fn shape(&self) -> Shape { self.shape.clone() } fn rank(&self) -> usize { self.shape.num_dims() } } impl RouterTensor { /// Create a new router tensor. pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self { Self { id, shape, dtype, client, count: Arc::new(AtomicU32::new(1)), } } pub(crate) async fn into_data(self) -> Result { self.client.clone().read_tensor_async(self.into_ir()).await } /// Get the ir for this tensor pub fn into_ir(mut self) -> TensorIr { let count = self.count.load(Ordering::Relaxed); let status = self.status(count); let mut shape_out = Shape::from(Vec::::new()); core::mem::swap(&mut self.shape, &mut shape_out); if let TensorStatus::ReadWrite = status { // Avoids an unwanted drop on the same thread. // // Since `drop` is called after `into_ir`, we must not register a drop if the tensor // was consumed with a `ReadWrite` status. self.count.fetch_add(1, Ordering::Relaxed); } TensorIr { status, shape: shape_out, id: self.id, dtype: self.dtype, } } pub(crate) fn to_ir_out(&self) -> TensorIr { TensorIr { status: TensorStatus::NotInit, shape: self.shape.clone(), id: self.id, dtype: self.dtype, } } pub(crate) fn status(&self, count: u32) -> TensorStatus { if count <= 1 { TensorStatus::ReadWrite } else { TensorStatus::ReadOnly } } } impl core::fmt::Debug for RouterTensor { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str( format!( "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}", self.id, self.shape, self.dtype, self.client.device().clone(), ) .as_str(), ) } } impl Clone for RouterTensor { fn clone(&self) -> Self { self.count.fetch_add(1, Ordering::Relaxed); Self { id: self.id, shape: self.shape.clone(), client: self.client.clone(), dtype: self.dtype, count: self.count.clone(), } } } impl Drop for RouterTensor { fn drop(&mut self) { let count = self.count.fetch_sub(1, Ordering::Relaxed); match self.status(count) { TensorStatus::ReadWrite => { let id = self.id; let mut shape = Shape::from(Vec::::new()); core::mem::swap(&mut shape, &mut self.shape); let ir = TensorIr { id, shape, status: TensorStatus::ReadWrite, dtype: self.dtype, }; self.client.register_op(burn_ir::OperationIr::Drop(ir)); } TensorStatus::ReadOnly => {} TensorStatus::NotInit => {} } } } ================================================ FILE: crates/burn-router/src/types.rs ================================================ use alloc::format; use alloc::string::String; use burn_backend::{ DType, Shape, TensorData, backend::{Backend, DeviceId, DeviceOps, ExecutionError}, try_read_sync, }; use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr}; use burn_std::future::DynFut; use crate::{ ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel, RunnerClient, }; /// Implement multi backend types, with enums having one variant per backend. macro_rules! impl_multi_backend_types { // Match the default backend and at least one other backend, with rest being optional ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => { /// Module containing the essential types for multi-backend operations. /// /// - `Handle`: the type used to point to a tensor (defined for all backends). /// - `MultiRunnerClient`: a client for multiple runners (each responsible to execute tensor operations on a given backend). /// - `DirectChannel`: a local channel with direct connection to the backend runner clients. /// - `ByteBridge`: a simple multi-backend bridge that transfers tensors via the underlying [tensor data](burn_backend::TensorData). /// /// Each enum type is defined with backend identifiers as variant names (e.g., `B1` and `B2` for dual backends). pub mod $module_name { use super::*; /// The type that can be used to point to a tensor of any kind. /// Each backend has its own variant. pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> { #[allow(missing_docs)] $DefaultBackend($DefaultBackend::Handle), $( #[allow(missing_docs)] $OtherBackend($OtherBackend::Handle), )+ } /// The device type used by a backend. /// Each backend has its own variant. #[derive(Clone, Debug)] pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> { #[allow(missing_docs)] $DefaultBackend($DefaultBackend::Device), $( #[allow(missing_docs)] $OtherBackend($OtherBackend::Device), )+ } impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs, $( (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs, )+ _ => false, } } } // Default implementation always returns the first backend's device impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> { fn default() -> Self { Self::$DefaultBackend($DefaultBackend::Device::default()) } } impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> { fn from_id(_device_id: DeviceId) -> Self { // TODO: Should be fix with the new router backend. Default::default() } fn to_id(&self) -> DeviceId { match self { Self::$DefaultBackend(device) => device.id(), $( Self::$OtherBackend(device) => device.id(), )+ } } fn device_count(_type_id: u16) -> usize { 1 } } impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {} /// A local client with multiple runners (each responsible to execute tensor operations on a given backend). #[derive(Clone)] pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> { #[allow(missing_docs)] $DefaultBackend(Runner<$DefaultBackend>), $( #[allow(missing_docs)] $OtherBackend(Runner<$OtherBackend>), )+ } impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+> { type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>; fn register_op(&self, op: OperationIr) { match self { Self::$DefaultBackend(runner) => runner.register_op(op), $( Self::$OtherBackend(runner) => runner.register_op(op), )+ } } fn read_tensor_async(&self, tensor: TensorIr) -> DynFut> { match self { Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor), $( Self::$OtherBackend(runner) => runner.read_tensor_async(tensor), )+ } } fn register_tensor_data(&self, data: TensorData) -> RouterTensor { match self { Self::$DefaultBackend(runner) => { let desc = runner.register_tensor_data_desc(data); RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone()) } $( Self::$OtherBackend(runner) => { let desc = runner.register_tensor_data_desc(data); RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone()) } )+ } } fn device(&self) -> Self::Device { match self { Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()), $( Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()), )+ } } fn sync(&self) -> Result<(), ExecutionError> { match self { Self::$DefaultBackend(runner) => runner.sync(), $( Self::$OtherBackend(runner) => runner.sync(), )+ } } fn seed(&self, seed: u64) { match self { Self::$DefaultBackend(runner) => runner.seed(seed), $( Self::$OtherBackend(runner) => runner.seed(seed), )+ } } fn create_empty_handle(&self) -> TensorId { match self { Self::$DefaultBackend(runner) => runner.create_empty_handle(), $( Self::$OtherBackend(runner) => runner.create_empty_handle(), )+ } } fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet { match self { Self::$DefaultBackend(runner) => runner.dtype_usage(dtype), $( Self::$OtherBackend(runner) => runner.dtype_usage(dtype), )+ } } } impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br> where Br: MultiBackendBridge, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>, { type Device = Br::Device; type Bridge = Br; type FloatElem = $DefaultBackend::FloatElem; type IntElem = $DefaultBackend::IntElem; type BoolElem = $DefaultBackend::BoolElem; type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>; fn init_client(device: &Self::Device) -> Self::Client { match device { MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())), $( MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())), )+ } } fn get_tensor_handle( tensor: &TensorIr, client: &Self::Client, ) -> ::TensorHandle { match client { MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)), $( MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)), )+ } } fn register_tensor( client: &Self::Client, handle: ::TensorHandle, shape: Shape, dtype: DType, ) -> RouterTensor { match client { MultiRunnerClient::$DefaultBackend(runner) => match handle { Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()), _ => unreachable!("Can't register tensor handle for another backend."), }, $( MultiRunnerClient::$OtherBackend(runner) => match handle { Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()), _ => unreachable!("Can't register tensor handle for another backend."), }, )+ } } fn name(_device: &Self::Device) -> String { let mut name = format!("{}", $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default())); $( name.push_str(&format!(", {}", $OtherBackend::name(&<$OtherBackend::Device as Default>::default()))); )+ format!("direct<({})>", name) } } impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> { type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>; type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>; fn change_backend_float( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle { multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+) } fn change_backend_int( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle { multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+) } fn change_backend_bool( tensor: Self::TensorHandle, shape: Shape, target_device: &Self::Device, ) -> Self::TensorHandle { multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+) } } } }; } macro_rules! bridge { ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{ // Bridge for the same backend let tensor = $Backend::float_tensor(TensorHandle { handle: $handle, shape: $shape, }); let tensor = $Backend::float_to_device(tensor, $device); let handle = $Backend::float_tensor_handle(tensor); Handle::$Backend(handle) }}; ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{ // Byte bridge between two backends let tensor = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape }); let data = try_read_sync($BackendA::float_into_data(tensor)).unwrap().expect( "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM." ); let tensor = $BackendB::float_from_data(data, $device); let handle = $BackendB::float_tensor_handle(tensor); Handle::$BackendB(handle) }}; } macro_rules! multi_backend_match { ($shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => { multi_backend_match! ( @step $shape, ($handle, $device); { (Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($DefaultBackend, handle, device, $shape), $( (Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($DefaultBackend, $OtherBackend, handle, device, $shape), (Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($OtherBackend, $DefaultBackend, handle, device, $shape), (Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($OtherBackend, handle, device, $shape), )+ }; $($OtherBackend),+ ) }; (@step $shape:expr, $pats:tt; { $($arms:tt)* }; $BackendA:ident, $($OtherBackend:ident),+ ) => { multi_backend_match! ( @step $shape, $pats; { $($arms)* $( (Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($BackendA, $OtherBackend, handle, device, $shape), (Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($OtherBackend, $BackendA, handle, device, $shape), )* }; $($OtherBackend),* ) }; (@step $shape:expr, ($handle:expr, $device:expr); { $($arms:tt)* }; $($BackendA:ident)? ) => { match ($handle, $device) { $($arms)* } }; } // Implement multi-backend types and byte bridge for up to 4 backends impl_multi_backend_types!(duo, B1, B2); impl_multi_backend_types!(trio, B1, B2, B3); impl_multi_backend_types!(quad, B1, B2, B3, B4); #[cfg(not(target_os = "windows"))] // cannot find a wgpu adapter on windows CI #[cfg(test)] mod tests { use burn_tensor::{Tensor, backend::Backend}; use super::*; use crate::tests::{TestBackend, TestBackend1, TestBackend2}; #[test] fn should_support_dual_byte_bridge() { let device1 = duo::MultiDevice::B1(::Device::default()); let device2 = duo::MultiDevice::B2(::Device::default()); let tensor1 = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device1); let tensor2 = Tensor::::from_floats([5.0, 6.0, 7.0, 8.0], &device2); let tensor1_2 = tensor1.clone().to_device(&device2); tensor1.into_data().assert_eq(&tensor1_2.into_data(), true); let tensor2_1 = tensor2.clone().to_device(&device1); tensor2.into_data().assert_eq(&tensor2_1.into_data(), true); } } ================================================ FILE: crates/burn-std/Cargo.toml ================================================ [package] authors = ["Dilshod Tadjibaev (@antimora)"] categories = [] description = "Core types and utilities shared across the Burn ecosystem." documentation = "https://docs.rs/burn-std" edition.workspace = true keywords = [] license.workspace = true name = "burn-std" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-std" version.workspace = true [lints] workspace = true [features] cubecl = ["dep:cubecl"] default = ["std", "cubecl-common/default"] doc = ["default"] std = ["cubecl-common/std", "num-traits/std"] tracing = ["cubecl?/tracing", "cubecl-common/tracing"] network = ["dep:indicatif", "dep:reqwest", "dep:tokio"] [dependencies] bytemuck = { workspace = true, features = ["extern_crate_alloc"] } half = { workspace = true, features = ["bytemuck"] } num-traits = { workspace = true } serde = { workspace = true } smallvec = { workspace = true, features = ["serde"] } cubecl = { workspace = true, optional = true, default-features = false } cubecl-common = { workspace = true, default-features = false, features = [ "serde", "shared-bytes", ] } cubecl-zspace = { workspace = true, default-features = false } # Enable extra-platforms for portable-atomic support on targets without native atomics (e.g., thumbv6m) # This is needed because cubecl-common's shared-bytes feature pulls in bytes bytes = { workspace = true } # Network downloader indicatif = { workspace = true, optional = true } reqwest = { workspace = true, optional = true } tokio = { workspace = true, optional = true } [dev-dependencies] dashmap = { workspace = true } # Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi) [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] bytes = { workspace = true, features = ["extra-platforms"] } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-std/README.md ================================================ # Burn Standard Library `burn-std` provides the core types and utilities shared across the Burn ecosystem. It includes foundational definitions for shapes, indexing, and data types. This crate supports both `std` and `no_std` environments and must compile with `cargo build --no-default-features` as well. ================================================ FILE: crates/burn-std/src/id.rs ================================================ //! # Unique Identifiers use crate::rand::gen_random; /// Simple ID generator. pub struct IdGenerator {} impl IdGenerator { /// Generates a new ID. pub fn generate() -> u64 { // Generate a random u64 (18,446,744,073,709,551,615 combinations) let random_bytes: [u8; 8] = gen_random(); u64::from_le_bytes(random_bytes) } } pub use cubecl_common::stream_id::StreamId; #[cfg(test)] mod tests { use super::*; use alloc::collections::BTreeSet; #[cfg(feature = "std")] use dashmap::DashSet; //Concurrent HashMap #[cfg(feature = "std")] use std::{sync::Arc, thread}; #[test] fn uniqueness_test() { const IDS_CNT: usize = 10_000; let mut set: BTreeSet = BTreeSet::new(); for _i in 0..IDS_CNT { assert!(set.insert(IdGenerator::generate())); } assert_eq!(set.len(), IDS_CNT); } #[cfg(feature = "std")] #[test] fn thread_safety_test() { const NUM_THREADS: usize = 10; const NUM_REPEATS: usize = 1_000; const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS; let set: Arc> = Arc::new(DashSet::new()); let mut handles = vec![]; for _ in 0..NUM_THREADS { let set = set.clone(); let handle = thread::spawn(move || { for _i in 0..NUM_REPEATS { assert!(set.insert(IdGenerator::generate())); } }); handles.push(handle); } for handle in handles { handle.join().unwrap(); } assert_eq!(set.len(), EXPECTED_TOTAL_IDS); } } ================================================ FILE: crates/burn-std/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! # Burn Standard Library //! //! This library contains core types and utilities shared across Burn, including shapes, indexing, //! and data types. extern crate alloc; /// Id module contains types for unique identifiers. pub mod id; /// Tensor utilities. pub mod tensor; pub use tensor::*; /// Common Errors. pub use cubecl_zspace::errors::{self, *}; /// Network utilities. #[cfg(feature = "network")] pub mod network; // Re-exported types pub use cubecl_common::bytes::*; pub use cubecl_common::device_handle::DeviceHandle; pub use cubecl_common::*; pub use half::{bf16, f16}; #[cfg(feature = "cubecl")] pub use cubecl::flex32; #[cfg(feature = "cubecl")] mod cube { use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind}; use cubecl_common::quant::scheme::QuantScheme; use crate::tensor::DType; use crate::tensor::quantization::{QuantStore, QuantValue}; impl From for cubecl::ir::ElemType { fn from(dtype: DType) -> Self { match dtype { DType::F64 => ElemType::Float(FloatKind::F64), DType::F32 => ElemType::Float(FloatKind::F32), DType::Flex32 => ElemType::Float(FloatKind::Flex32), DType::F16 => ElemType::Float(FloatKind::F16), DType::BF16 => ElemType::Float(FloatKind::BF16), DType::I64 => ElemType::Int(IntKind::I64), DType::I32 => ElemType::Int(IntKind::I32), DType::I16 => ElemType::Int(IntKind::I16), DType::I8 => ElemType::Int(IntKind::I8), DType::U64 => ElemType::UInt(UIntKind::U64), DType::U32 => ElemType::UInt(UIntKind::U32), DType::U16 => ElemType::UInt(UIntKind::U16), DType::U8 => ElemType::UInt(UIntKind::U8), DType::Bool(store) => match store { crate::BoolStore::Native => ElemType::Bool, crate::BoolStore::U8 => ElemType::UInt(UIntKind::U8), crate::BoolStore::U32 => ElemType::UInt(UIntKind::U32), }, DType::QFloat(scheme) => match scheme.store { QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8), QuantValue::E4M3 => Self::Float(FloatKind::E4M3), QuantValue::E5M2 => Self::Float(FloatKind::E5M2), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E2M1 => { panic!("Can't store native sub-byte values") } }, QuantStore::PackedU32(_) => Self::UInt(UIntKind::U32), QuantStore::PackedNative(_) => match scheme.value { QuantValue::E2M1 => panic!("Can't store native sub-byte values"), other => panic!("{other:?} doesn't support native packing"), }, }, } } } impl From for cubecl::ir::StorageType { fn from(dtype: DType) -> cubecl::ir::StorageType { match dtype { DType::QFloat(QuantScheme { store: QuantStore::PackedNative(_), value: QuantValue::E2M1, .. }) => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2), _ => { let elem: ElemType = dtype.into(); elem.into() } } } } } ================================================ FILE: crates/burn-std/src/network.rs ================================================ //! # Common Network Utilities /// Network download utilities. pub mod downloader { use indicatif::{ProgressBar, ProgressState, ProgressStyle}; use reqwest::Client; use std::io::Write; /// Download the file at the specified url. /// File download progress is reported with the help of a [progress bar](indicatif). /// /// # Arguments /// /// * `url` - The file URL to download. /// * `message` - The message to display on the progress bar during download. /// /// # Returns /// /// A vector of bytes containing the downloaded file data. #[tokio::main(flavor = "current_thread")] pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec { // Get file from web let mut response = Client::new().get(url).send().await.unwrap(); let total_size = response.content_length().unwrap(); // Pretty progress bar let pb = ProgressBar::new(total_size); let msg = message.to_owned(); pb.set_style( ProgressStyle::with_template( "{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})", ) .unwrap() .with_key( "eta", |state: &ProgressState, w: &mut dyn std::fmt::Write| { write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() }, ) .progress_chars("▬ "), ); pb.set_message(msg.clone()); // Read stream into bytes let mut downloaded: u64 = 0; let mut bytes: Vec = Vec::with_capacity(total_size as usize); while let Some(chunk) = response.chunk().await.unwrap() { let num_bytes = bytes.write(&chunk).unwrap(); let new = std::cmp::min(downloaded + (num_bytes as u64), total_size); downloaded = new; pb.set_position(new); } pb.finish_with_message(msg); bytes } } ================================================ FILE: crates/burn-std/src/tensor/dtype.rs ================================================ //! Tensor data type. use serde::{Deserialize, Serialize}; use crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue}; use crate::{bf16, f16}; #[allow(missing_docs)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum DType { F64, F32, Flex32, F16, BF16, I64, I32, I16, I8, U64, U32, U16, U8, Bool(BoolStore), QFloat(QuantScheme), } #[cfg(feature = "cubecl")] impl From for DType { fn from(value: cubecl::ir::ElemType) -> Self { match value { cubecl::ir::ElemType::Float(float_kind) => match float_kind { cubecl::ir::FloatKind::F16 => DType::F16, cubecl::ir::FloatKind::BF16 => DType::BF16, cubecl::ir::FloatKind::Flex32 => DType::Flex32, cubecl::ir::FloatKind::F32 => DType::F32, cubecl::ir::FloatKind::F64 => DType::F64, cubecl::ir::FloatKind::TF32 => panic!("Not a valid DType for tensors."), cubecl::ir::FloatKind::E2M1 | cubecl::ir::FloatKind::E2M3 | cubecl::ir::FloatKind::E3M2 | cubecl::ir::FloatKind::E4M3 | cubecl::ir::FloatKind::E5M2 | cubecl::ir::FloatKind::UE8M0 => { unimplemented!("Not yet supported, will be used for quantization") } }, cubecl::ir::ElemType::Int(int_kind) => match int_kind { cubecl::ir::IntKind::I8 => DType::I8, cubecl::ir::IntKind::I16 => DType::I16, cubecl::ir::IntKind::I32 => DType::I32, cubecl::ir::IntKind::I64 => DType::I64, }, cubecl::ir::ElemType::UInt(uint_kind) => match uint_kind { cubecl::ir::UIntKind::U8 => DType::U8, cubecl::ir::UIntKind::U16 => DType::U16, cubecl::ir::UIntKind::U32 => DType::U32, cubecl::ir::UIntKind::U64 => DType::U64, }, _ => panic!("Not a valid DType for tensors."), } } } impl DType { /// Returns the size of a type in bytes. pub const fn size(&self) -> usize { match self { DType::F64 => core::mem::size_of::(), DType::F32 => core::mem::size_of::(), DType::Flex32 => core::mem::size_of::(), DType::F16 => core::mem::size_of::(), DType::BF16 => core::mem::size_of::(), DType::I64 => core::mem::size_of::(), DType::I32 => core::mem::size_of::(), DType::I16 => core::mem::size_of::(), DType::I8 => core::mem::size_of::(), DType::U64 => core::mem::size_of::(), DType::U32 => core::mem::size_of::(), DType::U16 => core::mem::size_of::(), DType::U8 => core::mem::size_of::(), DType::Bool(store) => match store { BoolStore::Native => core::mem::size_of::(), BoolStore::U8 => core::mem::size_of::(), BoolStore::U32 => core::mem::size_of::(), }, DType::QFloat(scheme) => match scheme.store { QuantStore::Native => match scheme.value { QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::(), // e2m1 native is automatically packed by the kernels, so the actual storage is // 8 bits wide. QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => { core::mem::size_of::() } QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => { // Sub-byte values have fractional size 0 } }, QuantStore::PackedU32(_) => core::mem::size_of::(), QuantStore::PackedNative(_) => match scheme.value { QuantValue::E2M1 => core::mem::size_of::(), _ => 0, }, }, } } /// Returns true if the data type is a floating point type. pub fn is_float(&self) -> bool { matches!( self, DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16 ) } /// Returns true if the data type is a signed integer type. pub fn is_int(&self) -> bool { matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8) } /// Returns true if the data type is an unsigned integer type. pub fn is_uint(&self) -> bool { matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8) } /// Returns true if the data type is a boolean type pub fn is_bool(&self) -> bool { matches!(self, DType::Bool(_)) } /// Returns the data type name. pub fn name(&self) -> &'static str { match self { DType::F64 => "f64", DType::F32 => "f32", DType::Flex32 => "flex32", DType::F16 => "f16", DType::BF16 => "bf16", DType::I64 => "i64", DType::I32 => "i32", DType::I16 => "i16", DType::I8 => "i8", DType::U64 => "u64", DType::U32 => "u32", DType::U16 => "u16", DType::U8 => "u8", DType::Bool(store) => match store { BoolStore::Native => "bool", BoolStore::U8 => "bool(u8)", BoolStore::U32 => "bool(u32)", }, DType::QFloat(_) => "qfloat", } } } #[allow(missing_docs)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum FloatDType { F64, F32, Flex32, F16, BF16, } impl From for FloatDType { fn from(value: DType) -> Self { match value { DType::F64 => FloatDType::F64, DType::F32 => FloatDType::F32, DType::Flex32 => FloatDType::Flex32, DType::F16 => FloatDType::F16, DType::BF16 => FloatDType::BF16, _ => panic!("Expected float data type, got {value:?}"), } } } impl From for DType { fn from(value: FloatDType) -> Self { match value { FloatDType::F64 => DType::F64, FloatDType::F32 => DType::F32, FloatDType::Flex32 => DType::Flex32, FloatDType::F16 => DType::F16, FloatDType::BF16 => DType::BF16, } } } #[allow(missing_docs)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum IntDType { I64, I32, I16, I8, U64, U32, U16, U8, } impl From for IntDType { fn from(value: DType) -> Self { match value { DType::I64 => IntDType::I64, DType::I32 => IntDType::I32, DType::I16 => IntDType::I16, DType::I8 => IntDType::I8, DType::U64 => IntDType::U64, DType::U32 => IntDType::U32, DType::U16 => IntDType::U16, DType::U8 => IntDType::U8, _ => panic!("Expected int data type, got {value:?}"), } } } impl From for DType { fn from(value: IntDType) -> Self { match value { IntDType::I64 => DType::I64, IntDType::I32 => DType::I32, IntDType::I16 => DType::I16, IntDType::I8 => DType::I8, IntDType::U64 => DType::U64, IntDType::U32 => DType::U32, IntDType::U16 => DType::U16, IntDType::U8 => DType::U8, } } } #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Data type used to store boolean values. pub enum BoolStore { /// Stored as native boolean type (e.g. `bool`). Native, /// Stored as 8-bit unsigned integer. U8, /// Stored as 32-bit unsigned integer. U32, } /// Boolean dtype. /// /// This is currently an alias to [`BoolStore`], since it only varies by the storage representation. pub type BoolDType = BoolStore; #[allow(deprecated)] impl From for BoolDType { fn from(value: DType) -> Self { match value { DType::Bool(store) => match store { BoolStore::Native => BoolDType::Native, BoolStore::U8 => BoolDType::U8, BoolStore::U32 => BoolDType::U32, }, _ => panic!("Expected bool data type, got {value:?}"), } } } impl From for DType { fn from(value: BoolDType) -> Self { match value { BoolDType::Native => DType::Bool(BoolStore::Native), BoolDType::U8 => DType::Bool(BoolStore::U8), BoolDType::U32 => DType::Bool(BoolStore::U32), } } } ================================================ FILE: crates/burn-std/src/tensor/mod.rs ================================================ pub mod dtype; pub mod quantization; pub mod shape; pub mod slice; pub use dtype::*; pub use quantization::*; pub use shape::*; pub use slice::*; pub use cubecl_zspace::indexing::{self, *}; pub use cubecl_zspace::{Strides, metadata::Metadata, strides}; /// Check if the current tensor is contiguous. /// /// A tensor is considered contiguous if its elements are stored in memory /// such that the stride at position `k` is equal to the product of the shapes /// of all dimensions greater than `k`. /// /// This means that strides increase as you move from the rightmost to the leftmost dimension. pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { if shape.is_empty() { return true; } for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) { if expected != stride { return false; } } true } /// Computes the strides for a contiguous tensor with the given shape. /// /// In a contiguous row-major tensor, the stride for each dimension /// equals the product of all dimension sizes to its right. pub fn contiguous_strides(shape: &[usize]) -> Strides { let mut strides = strides![0; shape.len()]; let mut current = 1; for (i, &dim) in shape.iter().enumerate().rev() { strides[i] = current; current *= dim; } strides } /// The action to take for a reshape operation. #[derive(Debug)] pub enum ReshapeAction { /// Updating the strides is sufficient to handle the reshape. UpdateStrides { /// The new strides. strides: Strides, }, /// The strides are not compatible, we should recompute the buffer. Recompute, /// The strides are already correct. NoChange, } /// The reshape kind. #[derive(Debug)] pub enum ReshapeAnalysis { /// Original tensor is contiguous, can update the strides. IsContiguous, /// Original tensor is highly permutated, can't update the strides. HighlyPermuted, /// Only batch dimensions are added, can update the strides. Broadcasted, /// Dimensions are only split, can update the strides. Split, /// Original tensor is bigger than output shape. SmallerRank, /// New shape is the same. NoChange, } impl ReshapeAnalysis { /// Returns the proper action to take for the current analysis. fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction { match self { ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides { strides: contiguous_strides(shape_new), }, ReshapeAnalysis::NoChange => ReshapeAction::NoChange, ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => { ReshapeAction::Recompute } ReshapeAnalysis::Broadcasted => { let shape_rank = shape.len(); let shape_new_rank = shape_new.len(); let n_new_batch = shape_new_rank - shape_rank; let num_elems = shape.iter().product::(); let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides); ReshapeAction::UpdateStrides { strides: strides_new, } } ReshapeAnalysis::Split => { let strides_new = split_strides(shape, strides, shape_new); ReshapeAction::UpdateStrides { strides: strides_new, } } } } } /// Returns the proper action to take when reshaping a tensor. pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction { reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new) } /// Calculate the new strides given added batch dimensions. pub fn broadcast_strides( n_new_batch: usize, rank_prev: usize, num_elems: usize, strides: &[usize], ) -> Strides { let mut strides_new = strides![num_elems; rank_prev + n_new_batch]; for (i, s) in strides.iter().enumerate() { strides_new[i + n_new_batch] = *s; } strides_new } /// Calculate the new strides given added split dimensions. pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides { let mut strides_new = strides![1; shape_new.len()]; let mut old_idx = shape.len() - 1; let mut current_stride = strides[old_idx]; let mut dim_prod = 1; for (i, dim) in shape_new.iter().enumerate().rev() { dim_prod *= *dim; strides_new[i] = current_stride; if *dim == 1 { continue; } else if dim_prod == shape[old_idx] { old_idx = old_idx.saturating_sub(1); current_stride = strides[old_idx]; dim_prod = 1; } else { current_stride *= *dim; } } strides_new } /// Returns the analysis of a reshape operation. pub fn reshape_analysis( shape: &[usize], strides: Option<&[usize]>, shape_new: &[usize], ) -> ReshapeAnalysis { let shape_rank = shape.len(); let shape_new_rank = shape_new.len(); let is_contiguous = match strides { Some(strides) => is_contiguous(shape, strides), None => false, }; if is_contiguous { return ReshapeAnalysis::IsContiguous; } if shape_new_rank < shape_rank { return ReshapeAnalysis::SmallerRank; } let n_new_batch = shape_new_rank - shape_rank; match n_new_batch > 0 { true => { if shape == &shape_new[n_new_batch..shape_new_rank] && shape_new[0..n_new_batch].iter().all(|it| *it == 1) { return ReshapeAnalysis::Broadcasted; } else { let mut dim_prod = 1; let mut old_idx = 0; for dim in shape_new { dim_prod *= *dim; // We need to ignore unit dims because they don't affect analysis and break // things because they match the default `dim_prod`. If we don't do this, // reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access. if *dim == 1 { continue; } else if dim_prod == shape[old_idx] { dim_prod = 1; old_idx += 1; } else if dim_prod > shape[old_idx] { return ReshapeAnalysis::HighlyPermuted; } } return ReshapeAnalysis::Split; } } false => { if shape == shape_new { return ReshapeAnalysis::NoChange; } } }; ReshapeAnalysis::HighlyPermuted } ================================================ FILE: crates/burn-std/src/tensor/quantization.rs ================================================ //! Quantization data representation. // Re-exported types pub use cubecl_common::quant::scheme::{ BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue, }; /// Alignment (in bytes) for quantization parameters in serialized tensor data. /// /// NOTE: This is currently f32-based since scales were originally always f32. /// With `QuantParam` now supporting different precisions (F16, BF16, etc.), /// this alignment may need to be revisited in the future. pub const QPARAM_ALIGN: usize = core::mem::align_of::(); use alloc::vec::Vec; use core::any::TypeId; use num_traits::PrimInt; use serde::{Deserialize, Serialize}; use crate::{DType, Metadata, Shape, bytes::Bytes}; #[derive( Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default, )] /// The precision of accumulating elements. pub enum QuantAcc { /// Full precision. #[default] F32, /// Half precision. F16, /// bfloat16 precision. BF16, } /// Specify if the output of an operation is quantized using the scheme of the input /// or returned unquantized. #[derive( Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default, )] pub enum QuantPropagation { /// The output is quantized using the scheme of the input. Propagate, /// The output is not quantized. #[default] Inhibit, } /// The quantization tensor data parameters. #[derive(Clone, Debug)] pub struct QParams { /// The scaling factor. pub scales: S, } /// A quantization parameter tensor descriptor. #[derive(Debug, Clone, PartialEq, Eq)] pub struct QParamTensor { /// Start of the tensor in the buffer pub offset_start: usize, /// Offset of tensor end from the end of the buffer pub offset_end: usize, /// Metadata of the tensor pub metadata: Metadata, /// Data type of the tensor pub dtype: DType, } /// Calculate the shape of the quantization parameters for a given tensor and level pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape { match level { QuantLevel::Tensor => Shape::new([1]), QuantLevel::Block(block_size) => { let mut params_shape = data_shape.clone(); let block_size = block_size.to_dim_vec(data_shape.num_dims()); for (shape, block_size) in params_shape.iter_mut().zip(block_size) { *shape = (*shape).div_ceil(block_size as usize); } params_shape } } } /// Quantized data bytes representation. /// /// # Notes /// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8 /// quantized values pack 4 grouped values into a single `u32`. When unpacking these values, /// we make sure to retrieve only the meaningful values (and ignore the alignment padding). /// 2) Quantization parameters are appended to the tensor data. /// As such, the last bytes always correspond to the scale parameter. /// If the quantization scheme includes an offset (zero-point) parameter, it is next to last. pub struct QuantizedBytes { /// The quantized values and quantization parameters represented as bytes. pub bytes: Bytes, /// The quantization scheme. pub scheme: QuantScheme, /// The number of quantized elements. pub num_elements: usize, } impl QuantizedBytes { /// Creates a new quantized bytes representation. pub fn new( value: Vec, scheme: QuantScheme, scales: &[f32], ) -> Self { let num_elements = value.len(); // Only used for 8-bit quantization data comparison in tests if TypeId::of::() != TypeId::of::() { panic!("Invalid quantized type"); } // Re-interpret `Vec` as `Vec` with `Vec::from_raw_parts` let i8s: Vec = bytemuck::allocation::cast_vec(value); let mut bytes = Bytes::from_elems(i8s); match scheme.level { QuantLevel::Tensor => { let scale_bytes = bytemuck::bytes_of(&scales[0]); bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN); } QuantLevel::Block(_block_size) => { let mut scale_bytes = Vec::with_capacity(size_of_val(scales)); for scale in scales { scale_bytes.extend_from_slice(bytemuck::bytes_of(scale)); } bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN); } } Self { bytes, scheme, num_elements, } } /// Returns the int8 quantized values with the quantization parameters. pub fn into_vec_i8(self) -> (Vec, QParams>) { let (values, (qparams, num_params)) = self.split_values_off(); // Quantization parameters are added at the end of the tensor data. // As such, the last bytes always correspond to the scale parameter(s). // For example, per-block quantization can have multiple parameters for a single tensor: // [scale, scale, scale, ...] let scale_size = core::mem::size_of::(); // scale is stored as f32 let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams); let total_bytes = qparams_bytes.len(); let scales_size = scale_size * num_params; let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec(); (values, QParams { scales }) } fn split_i8_values(self, num_params: usize) -> (Vec, Vec) { let mut values = read_bytes_to_i8(self.bytes); let scale_size = num_params * size_of::(); let values_end = values.len() - scale_size; let qparams = values.split_off(values_end); let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) { let mut qparams = core::mem::ManuallyDrop::new(qparams); unsafe { Vec::::from_raw_parts( qparams.as_mut_ptr() as _, qparams.len() / 4, qparams.capacity() / 4, ) } } else { #[cfg(target_endian = "little")] { // SAFETY: quantized bytes representation is created from packed u32 values in little endian bytemuck::cast_vec(qparams) } #[cfg(target_endian = "big")] { crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams)) } }; (values, qparams) } /// Splits the quantized values of the tensor from the quantization parameters. /// /// Returns the values in i8 and a newly allocated vector containing the quantization parameters. fn split_values_off(self) -> (Vec, (Vec, usize)) { let num_params = match self.scheme.level { QuantLevel::Tensor => 1, QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(), }; if let QuantStore::PackedU32(packed_dim) = self.scheme.store { assert_eq!( packed_dim, 0, "Packing must be on innermost dimension for splitting off values" ); } let (values, qparams) = match self.scheme.store { QuantStore::Native => self.split_i8_values(num_params), QuantStore::PackedU32(_) => match self.scheme.value { QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params), QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => { let mut values = self.bytes.try_into_vec::().unwrap(); let scale_size = num_params; // size of f32 same as u32 let values_end = values.len() - scale_size; let qparams = values.split_off(values_end); // Sub-byte values are unpacked as i8s for value equality tests let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value); (values, qparams) } QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => { unimplemented!("Not yet supported") } }, QuantStore::PackedNative(_) => unimplemented!("Not yet supported"), }; (values, (qparams, num_params)) } } fn read_bytes_to_i8(bytes: Bytes) -> Vec { match bytes.try_into_vec::() { Ok(val) => val, // Safety, // // `Vec` can be Re-interpreted as `Vec` since they share the same alignment. Err(bytes) => unsafe { core::mem::transmute::, Vec>(bytes.to_vec()) }, } } /// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers. pub fn pack_i8s_to_u32s(values: Vec) -> Vec { // Shift and combine groups of four 8-bit values into a u32. // Same as doing this: // let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF); #[cfg(target_endian = "big")] { values .chunks(4) .map(|x| { x.iter() .enumerate() .fold(0u32, |acc, (i, x)| acc | (*x as u32 & 0xFF) << (i * 8)) }) .collect() } // The order of bytes in little endian matches the above description, we just need to // handle padding when the number of values is not a factor of 4 #[cfg(target_endian = "little")] { let mut values = values; let remainder = values.len() % 4; if remainder != 0 { // Pad with zeros values.extend(core::iter::repeat_n(0, 4 - remainder)); } let len = values.len() / 4; let capacity = values.capacity() / 4; // Pre-forget the old vec and re-interpret as u32 let mut values = core::mem::ManuallyDrop::new(values); let ptr = values.as_mut_ptr() as *mut u32; unsafe { Vec::from_raw_parts(ptr, len, capacity) } } } /// Unpack integer values into a sequence of signed 8-bit integers. pub(crate) fn unpack_q_to_i8s( values: &[Q], numel: usize, value: &QuantValue, ) -> Vec { let size_store = size_of::() * 8; let size_quant = value.size_bits(); let num_quants = size_store / size_quant; let mask = Q::from((1 << size_quant) - 1).unwrap(); let sign_shift = 8 - size_quant; // sign extension for sub-byte values values .iter() .enumerate() .flat_map(|(i, &packed)| { // A single u32 could contain less than four 8-bit values... let n = core::cmp::min(num_quants, numel - i * num_quants); // Extract each 8-bit segment from u32 and cast back to i8 // Same as doing this (when 4 values are fully packed): // let a = (packed & 0xFF) as i8; // let b = ((packed >> 8) & 0xFF) as i8; // let c = ((packed >> 16) & 0xFF) as i8; // let d = ((packed >> 24) & 0xFF) as i8; (0..n).map(move |i| { let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap(); ((raw << sign_shift) as i8) >> sign_shift }) }) .collect() } #[cfg(test)] mod tests { use super::*; use alloc::vec; #[test] fn should_pack_i8s_to_u32() { let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]); assert_eq!(packed, vec![2147287680]); } #[test] fn should_pack_i8s_to_u32_padded() { let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]); let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]); assert_eq!(packed, vec![2147287680, 55]); assert_eq!(packed, packed_padded); } #[test] fn should_unpack_u32s_to_i8s() { let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S); assert_eq!(unpacked, vec![-128, 2, -3, 127]); } #[test] fn should_unpack_u32s_to_i8s_padded() { let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S); assert_eq!(unpacked, vec![55]); } #[test] fn should_unpack_u32s_to_i8s_arange() { let unpacked = unpack_q_to_i8s( &[ 0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459, 1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590, 2004318071, ], 128, &QuantValue::Q4S, ); assert_eq!( unpacked, vec![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 ] ); } #[test] fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() { // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] let scale = 0.03937008; let values = vec![0i8, 25, 51, 76, 102, 127]; let q_bytes = QuantizedBytes::new( values.clone(), QuantScheme::default() .with_value(QuantValue::Q8S) .with_store(QuantStore::Native), &[scale], ); let (q_values, qparams) = q_bytes.into_vec_i8(); assert_eq!(qparams.scales, vec![scale]); assert_eq!(q_values, values); } } ================================================ FILE: crates/burn-std/src/tensor/shape.rs ================================================ //! Tensor shape definition. use super::{Slice, SliceArg}; use alloc::vec::Vec; use core::ops::Range; pub use crate::errors::ExpressionError; pub use cubecl_zspace::{MetadataError, Shape, SmallVec, calculate_matmul_output, shape}; /// Slice-related ops on [`Shape`] pub trait SliceOps: Sized { /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. fn into_ranges(self) -> Vec>; /// Converts slice arguments into an array of slice specifications for the shape. /// /// This method returns an array of `Slice` objects that can be used for slicing operations. /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but /// allows custom slice specifications instead of full ranges. /// For creating complex slice specifications, use the [`s!`] macro. /// /// # Arguments /// /// * `slices` - An array of slice specifications, where each element can be: /// - A range (e.g., `2..5`) /// - An index /// - A `Slice` object /// - The output of the [`s!`] macro for advanced slicing /// /// # Behavior /// /// - Supports partial and full slicing in any number of dimensions. /// - Missing ranges are treated as full slices if D > D2. /// - Handles negative indices by wrapping around from the end of the dimension. /// - Clamps ranges to the shape's dimensions if they exceed the bounds. /// /// # Returns /// /// An array of `Slice` objects corresponding to the provided slice specifications, /// clamped to the shape's actual dimensions. /// /// # Examples /// /// ```rust /// use burn_std::{Shape, Slice, s, SliceOps}; /// /// fn example() { /// // 1D slicing /// let slices = Shape::new([4]).into_slices(1..4); /// assert_eq!(slices[0].to_range(4), 1..3); /// /// // 2D slicing /// let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]); /// assert_eq!(slices[0].to_range(3), 1..3); /// assert_eq!(slices[1].to_range(4), 0..2); /// /// // Using negative indices /// let slices = Shape::new([3]).into_slices(..-2); /// assert_eq!(slices[0].to_range(3), 0..1); /// /// // Using the slice macro to select different ranges /// let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]); /// assert_eq!(slices[0].to_range(2), 0..2); /// assert_eq!(slices[1].to_range(3), 1..2); /// } /// ``` /// /// # See Also /// /// - [`s!`] - The recommended macro for creating slice specifications /// - [`Shape::into_ranges`] - Convert to full covering ranges /// /// [`s!`]: crate::s! fn into_slices(self, slices: S) -> Vec where S: SliceArg; /// Compute the output shape from the given slices. fn slice(self, slices: &[Slice]) -> Result; } impl SliceOps for Shape { fn into_ranges(self) -> Vec> { self.iter().map(|&d| 0..d).collect() } fn into_slices(self, slices: S) -> Vec where S: SliceArg, { slices.into_slices(&self) } fn slice(mut self, slices: &[Slice]) -> Result { if slices.len() > self.rank() { return Err(MetadataError::RankMismatch { left: self.rank(), right: slices.len(), }); } slices .iter() .zip(self.iter_mut()) .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size)); Ok(self) } } #[cfg(test)] #[allow(clippy::identity_op, reason = "useful for clarity")] mod tests { use super::*; use crate::s; use alloc::vec; #[test] fn test_into_ranges() { let dims = [2, 3, 4, 5]; let shape = Shape::new(dims); assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]); } #[allow(clippy::single_range_in_vec_init)] #[test] fn test_into_slices() { let slices = Shape::new([3]).into_slices(1..4); assert_eq!(slices[0].to_range(3), 1..3); let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]); assert_eq!(slices[0].to_range(3), 1..3); assert_eq!(slices[1].to_range(4), 0..2); let slices = Shape::new([3]).into_slices(..-2); assert_eq!(slices[0].to_range(3), 0..1); let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]); assert_eq!(slices[0].to_range(2), 0..2); assert_eq!(slices[1].to_range(3), 1..2); let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]); assert_eq!(slices[0].to_range(2), 0..2); assert_eq!(slices[1].to_range(3), 2..3); } #[test] fn test_shape_as_slice() { let dims = [2, 3, 4, 5]; let shape = Shape::new(dims); assert_eq!(shape.as_slice(), dims.as_slice()); // Deref coercion let shape_slice: &[usize] = &shape; assert_eq!(shape_slice, *&[2, 3, 4, 5]); } #[test] fn test_shape_as_mut_slice() { let mut dims = [2, 3, 4, 5]; let mut shape = Shape::new(dims); let shape_mut = shape.as_mut_slice(); assert_eq!(shape_mut, dims.as_mut_slice()); shape_mut[1] = 6; assert_eq!(shape_mut, &[2, 6, 4, 5]); let mut shape = Shape::new(dims); let shape = &mut shape[..]; shape[1] = 6; assert_eq!(shape, shape_mut) } #[test] fn test_shape_slice_output_shape_basic() { // Test basic slicing with step=1 let slices = [ Slice::new(0, Some(5), 1), // 5 elements Slice::new(2, Some(8), 1), // 6 elements ]; let original_shape = Shape::new([10, 10, 10]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([5, 6, 10])); } #[test] fn test_shape_slice_output_shape_with_positive_steps() { // Test slicing with various positive steps let slices = [ Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements Slice::new(1, Some(9), 3), // [1,4,7] -> 3 elements Slice::new(0, Some(7), 4), // [0,4] -> 2 elements ]; let original_shape = Shape::new([20, 20, 20, 30]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([5, 3, 2, 30])); } #[test] fn test_shape_slice_output_shape_with_negative_steps() { // Test slicing with negative steps (backward iteration) let slices = [ Slice::new(0, Some(10), -1), // 10 elements traversed backward Slice::new(2, Some(8), -2), // [7,5,3] -> 3 elements ]; let original_shape = Shape::new([20, 20, 20]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([10, 3, 20])); } #[test] fn test_shape_slice_output_shape_mixed_steps() { // Test with a mix of positive, negative, and unit steps let slices = [ Slice::from_range_stepped(1..6, 1), // 5 elements Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements Slice::from_range_stepped(2..14, 4), // [2,6,10] -> 3 elements ]; let original_shape = Shape::new([20, 20, 20]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([5, 4, 3])); } #[test] fn test_shape_slice_output_shape_partial_dims() { // Test when slices has fewer dimensions than original shape let slices = [ Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements ]; let original_shape = Shape::new([10, 20, 30, 40]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([3, 20, 30, 40])); } #[test] fn test_shape_slice_output_shape_edge_cases() { // Test edge cases with small ranges and large steps let slices = [ Slice::from_range_stepped(0..1, 1), // Single element Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element Slice::from_range_stepped(5..5, 1), // Empty range -> 0 elements ]; let original_shape = Shape::new([10, 20, 30]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([1, 1, 0])); } #[test] fn test_shape_slice_output_shape_empty() { // Test with no slice infos (should return original shape) let slices = []; let original_shape = Shape::new([10, 20, 30]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([10, 20, 30])); } #[test] fn test_shape_slice_output_shape_uneven_division() { // Test cases where range size doesn't divide evenly by step let slices = [ Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6] Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8] Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6] ]; let original_shape = Shape::new([20, 20, 20]); let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([3, 3, 2])); } } ================================================ FILE: crates/burn-std/src/tensor/slice.rs ================================================ //! Tensor slice utilities. use crate::Shape; use crate::indexing::AsIndex; use alloc::format; use alloc::vec::Vec; use core::fmt::{Display, Formatter}; use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use core::str::FromStr; /// Trait for slice arguments that can be converted into an array of slices. /// This allows the `slice` method to accept both single slices (from `s![..]`) /// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`). pub trait SliceArg { /// Convert to an vec of slices with clamping to shape dimensions. /// /// Returns a [Slice] for each dimension in `shape`. fn into_slices(self, shape: &Shape) -> Vec; } impl + Clone> SliceArg for &[S] { fn into_slices(self, shape: &Shape) -> Vec { assert!( self.len() <= shape.num_dims(), "Too many slices provided for shape, got {} but expected at most {}", self.len(), shape.num_dims() ); shape .iter() .enumerate() .map(|(i, dim_size)| { let slice = if i >= self.len() { Slice::full() } else { self[i].clone().into() }; // Apply shape clamping by converting to range and back let clamped_range = slice.to_range(*dim_size); Slice::new( clamped_range.start as isize, Some(clamped_range.end as isize), slice.step(), ) }) .collect::>() } } impl SliceArg for &Vec { fn into_slices(self, shape: &Shape) -> Vec { self.as_slice().into_slices(shape) } } impl SliceArg for [T; R] where T: Into + Clone, { fn into_slices(self, shape: &Shape) -> Vec { self.as_slice().into_slices(shape) } } impl SliceArg for T where T: Into, { fn into_slices(self, shape: &Shape) -> Vec { let slice: Slice = self.into(); [slice].as_slice().into_slices(shape) } } /// Slice argument constructor for tensor indexing. /// /// The `s![]` macro is used to create multi-dimensional slice specifications for tensors. /// It converts various range syntax forms into a `&[Slice]` that can be used with /// `tensor.slice()` and `tensor.slice_assign()` operations. /// /// # Syntax Overview /// /// ## Basic Forms /// /// * **`s![index]`** - Index a single element (produces a subview with that axis removed) /// * **`s![range]`** - Slice a range of elements /// * **`s![range;step]`** - Slice a range with a custom step /// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms /// /// ## Range Types /// /// All standard Rust range types are supported: /// * **`a..b`** - From `a` (inclusive) to `b` (exclusive) /// * **`a..=b`** - From `a` to `b` (both inclusive) /// * **`a..`** - From `a` to the end /// * **`..b`** - From the beginning to `b` (exclusive) /// * **`..=b`** - From the beginning to `b` (inclusive) /// * **`..`** - The full range (all elements) /// /// ## Negative Indices /// /// Negative indices count from the end of the axis: /// * **`-1`** refers to the last element /// * **`-2`** refers to the second-to-last element /// * And so on... /// /// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]` /// /// ## Step Syntax /// /// Steps control the stride between selected elements: /// * **`;step`** after a range specifies the step /// * **Positive steps** select every nth element going forward /// * **Negative steps** select every nth element going backward /// * Default step is `1` when not specified /// * Step cannot be `0` /// /// ### Negative Step Behavior /// /// With negative steps, the range bounds still specify *which* elements to include, /// but the traversal order is reversed: /// /// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`) /// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2) /// * `s![..;-1]` reverses the entire axis /// /// This matches the semantics of NumPy and the ndarray crate. /// /// # Examples /// /// ## Basic Slicing /// /// ```rust,ignore /// use burn_tensor::{Tensor, s}; /// /// # fn example(tensor: Tensor) { /// // Select rows 0-5 (exclusive) /// let subset = tensor.slice(s![0..5, .., ..]); /// /// // Select the last row /// let last_row = tensor.slice(s![-1, .., ..]); /// /// // Select columns 2, 3, 4 /// let cols = tensor.slice(s![.., 2..5, ..]); /// /// // Select a single element at position [1, 2, 3] /// let element = tensor.slice(s![1, 2, 3]); /// # } /// ``` /// /// ## Slicing with Steps /// /// ```rust,ignore /// use burn_tensor::{Tensor, s}; /// /// # fn example(tensor: Tensor) { /// // Select every 2nd row /// let even_rows = tensor.slice(s![0..10;2, ..]); /// /// // Select every 3rd column /// let cols = tensor.slice(s![.., 0..9;3]); /// /// // Select every 2nd element in reverse order /// let reversed_even = tensor.slice(s![10..0;-2, ..]); /// # } /// ``` /// /// ## Reversing Dimensions /// /// ```rust,ignore /// use burn_tensor::{Tensor, s}; /// /// # fn example(tensor: Tensor) { /// // Reverse the first dimension /// let reversed = tensor.slice(s![..;-1, ..]); /// /// // Reverse both dimensions /// let fully_reversed = tensor.slice(s![..;-1, ..;-1]); /// /// // Reverse a specific range /// let range_reversed = tensor.slice(s![2..8;-1, ..]); /// # } /// ``` /// /// ## Complex Multi-dimensional Slicing /// /// ```rust,ignore /// use burn_tensor::{Tensor, s}; /// /// # fn example(tensor: Tensor) { /// // Mix of different slice types /// let complex = tensor.slice(s![ /// 0..10;2, // Every 2nd element from 0 to 10 /// .., // All elements in dimension 1 /// 5..15;-3, // Every 3rd element from 14 down to 5 /// -1 // Last element in dimension 3 /// ]); /// /// // Using inclusive ranges /// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]); /// /// // Negative indices with steps /// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]); /// # } /// ``` /// /// ## Slice Assignment /// /// ```rust,ignore /// use burn_tensor::{Tensor, s}; /// /// # fn example(tensor: Tensor, values: Tensor) { /// // Assign to every 2nd row /// let tensor = tensor.slice_assign(s![0..10;2, ..], values); /// /// // Assign to a reversed slice /// let tensor = tensor.slice_assign(s![..;-1, 0..5], values); /// # } /// ``` #[macro_export] macro_rules! s { // Empty - should not happen [] => { compile_error!("Empty slice specification") }; // Single expression with step [$range:expr; $step:expr] => { { #[allow(clippy::reversed_empty_ranges)] { $crate::tensor::Slice::from_range_stepped($range, $step) } } }; // Single expression without step (no comma after) [$range:expr] => { { #[allow(clippy::reversed_empty_ranges)] { $crate::tensor::Slice::from($range) } } }; // Two or more expressions with first having step [$range:expr; $step:expr, $($rest:tt)*] => { { #[allow(clippy::reversed_empty_ranges)] { $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step)] $($rest)*) } } }; // Two or more expressions with first not having step [$range:expr, $($rest:tt)*] => { { #[allow(clippy::reversed_empty_ranges)] { $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*) } } }; // Internal: finished parsing (@internal [$($acc:expr),*]) => { [$($acc),*] }; // Internal: parse range with step followed by comma (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => { $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*) }; // Internal: parse range with step at end (@internal [$($acc:expr),*] $range:expr; $step:expr) => { $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)]) }; // Internal: parse range without step followed by comma (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => { $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*) }; // Internal: parse range without step at end (@internal [$($acc:expr),*] $range:expr) => { $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)]) }; } /// A slice specification for a single tensor dimension. /// /// This struct represents a range with an optional step, used for advanced indexing /// operations on tensors. It is typically created using the [`s!`] macro rather than /// constructed directly. /// /// # Fields /// /// * `start` - The starting index (inclusive). Negative values count from the end. /// * `end` - The ending index (exclusive). `None` means to the end of the dimension. /// * `step` - The stride between elements. Must be non-zero. /// /// # Index Interpretation /// /// - **Positive indices**: Count from the beginning (0-based) /// - **Negative indices**: Count from the end (-1 is the last element) /// - **Bounds checking**: Indices are clamped to valid ranges /// /// # Step Behavior /// /// - **Positive step**: Traverse forward through the range /// - **Negative step**: Traverse backward through the range /// - **Step size**: Determines how many elements to skip /// /// # Examples /// /// While you typically use the [`s!`] macro, you can also construct slices directly: /// /// ```rust,ignore /// use burn_tensor::Slice; /// /// // Equivalent to s![2..8] /// let slice1 = Slice::new(2, Some(8), 1); /// /// // Equivalent to s![0..10;2] /// let slice2 = Slice::new(0, Some(10), 2); /// /// // Equivalent to s![..;-1] (reverse) /// let slice3 = Slice::new(0, None, -1); /// ``` /// /// See also the [`s!`] macro for the preferred way to create slices. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Slice { /// Slice start index. pub start: isize, /// Slice end index (exclusive). pub end: Option, /// Step between elements (default: 1). pub step: isize, } /// Defines an [`Iterator`] over a [`Slice`]. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct SliceIter { slice: Slice, current: isize, } impl Iterator for SliceIter { type Item = isize; fn next(&mut self) -> Option { let next = self.current; self.current += self.slice.step; if let Some(end) = self.slice.end { if self.slice.is_reversed() { if next <= end { return None; } } else if next >= end { return None; } } Some(next) } } /// Note: Unbounded [`Slice`]s produce infinite iterators. impl IntoIterator for Slice { type Item = isize; type IntoIter = SliceIter; fn into_iter(self) -> Self::IntoIter { SliceIter { slice: self, current: self.start, } } } impl Default for Slice { fn default() -> Self { Self::full() } } impl Slice { /// Creates a new slice with start, end, and step pub const fn new(start: isize, end: Option, step: isize) -> Self { assert!(step != 0, "Step cannot be zero"); Self { start, end, step } } /// Creates a slice that represents the full range. pub const fn full() -> Self { Self::new(0, None, 1) } /// Creates a slice that represents a single index pub fn index(idx: isize) -> Self { Self { start: idx, end: handle_signed_inclusive_end(idx), step: 1, } } /// Converts the slice to a vector. pub fn into_vec(self) -> Vec { assert!( self.end.is_some(), "Slice must have an end to convert to a vector: {self:?}" ); self.into_iter().collect() } /// Clips the slice to a maximum size. /// /// # Example /// /// ```rust,ignore /// assert_eq!( /// Slice::new(0, None, 1).bound_to(10), /// Slice::new(0, Some(10), 1)); /// assert_eq!( /// Slice::new(0, Some(5), 1).bound_to(10), /// Slice::new(0, Some(5), 1)); /// assert_eq!( /// Slice::new(0, None, -1).bound_to(10), /// Slice::new(0, Some(-11), -1)); /// assert_eq!( /// Slice::new(0, Some(-5), -1).bound_to(10), /// Slice::new(0, Some(-5), -1)); /// ``` pub fn bound_to(self, size: usize) -> Self { let mut bounds = size as isize; if let Some(end) = self.end { if end > 0 { bounds = end.min(bounds); } else { bounds = end.max(-(bounds + 1)); } } else if self.is_reversed() { bounds = -(bounds + 1); } Self { end: Some(bounds), ..self } } /// Creates a slice with a custom step pub fn with_step(start: isize, end: Option, step: isize) -> Self { assert!(step != 0, "Step cannot be zero"); Self { start, end, step } } /// Creates a slice from a range with a specified step pub fn from_range_stepped>(range: R, step: isize) -> Self { assert!(step != 0, "Step cannot be zero"); let mut slice = range.into(); slice.step = step; slice } /// Returns the step of the slice pub fn step(&self) -> isize { self.step } /// Returns the range for this slice given a dimension size pub fn range(&self, size: usize) -> Range { self.to_range(size) } /// Convert this slice to a range for a dimension of the given size. /// /// # Arguments /// /// * `size` - The size of the dimension to slice. /// /// # Returns /// /// A `Range` representing the slice bounds. pub fn to_range(&self, size: usize) -> Range { // Always return a valid range with start <= end // The step information will be handled separately let start = convert_signed_index(self.start, size); let end = match self.end { Some(end) => convert_signed_index(end, size), None => size, }; start..end } /// Converts the slice into a range and step tuple pub fn to_range_and_step(&self, size: usize) -> (Range, isize) { let range = self.to_range(size); (range, self.step) } /// Returns true if the step is negative pub fn is_reversed(&self) -> bool { self.step < 0 } /// Calculates the output size for this slice operation pub fn output_size(&self, dim_size: usize) -> usize { let range = self.to_range(dim_size); // Handle empty slices (start >= end) if range.start >= range.end { return 0; } let len = range.end - range.start; if self.step.unsigned_abs() == 1 { len } else { len.div_ceil(self.step.unsigned_abs()) } } } fn convert_signed_index(index: isize, size: usize) -> usize { if index < 0 { (size as isize + index).max(0) as usize } else { (index as usize).min(size) } } fn handle_signed_inclusive_end(end: isize) -> Option { match end { -1 => None, end => Some(end + 1), } } impl From> for Slice { fn from(r: Range) -> Self { Self { start: r.start.as_index(), end: Some(r.end.as_index()), step: 1, } } } impl From> for Slice { fn from(r: RangeInclusive) -> Self { Self { start: r.start().as_index(), end: handle_signed_inclusive_end(r.end().as_index()), step: 1, } } } impl From> for Slice { fn from(r: RangeFrom) -> Self { Self { start: r.start.as_index(), end: None, step: 1, } } } impl From> for Slice { fn from(r: RangeTo) -> Self { Self { start: 0, end: Some(r.end.as_index()), step: 1, } } } impl From> for Slice { fn from(r: RangeToInclusive) -> Self { Self { start: 0, end: handle_signed_inclusive_end(r.end.as_index()), step: 1, } } } impl From for Slice { fn from(_: RangeFull) -> Self { Self { start: 0, end: None, step: 1, } } } impl From for Slice { fn from(i: usize) -> Self { Slice::index(i as isize) } } impl From for Slice { fn from(i: isize) -> Self { Slice::index(i) } } impl From for Slice { fn from(i: i32) -> Self { Slice::index(i as isize) } } impl From for Slice { fn from(i: i64) -> Self { Slice::index(i as isize) } } impl Display for Slice { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { if self.step == 1 && let Some(end) = self.end && self.start == end - 1 { f.write_fmt(format_args!("{}", self.start)) } else { if self.start != 0 { f.write_fmt(format_args!("{}", self.start))?; } f.write_str("..")?; if let Some(end) = self.end { f.write_fmt(format_args!("{}", end))?; } if self.step != 1 { f.write_fmt(format_args!(";{}", self.step))?; } Ok(()) } } } impl FromStr for Slice { type Err = crate::ExpressionError; fn from_str(source: &str) -> Result { let mut s = source.trim(); let parse_int = |v: &str| -> Result { v.parse::().map_err(|e| { crate::ExpressionError::parse_error( format!("Invalid integer: '{v}': {}", e), source, ) }) }; let mut start: isize = 0; let mut end: Option = None; let mut step: isize = 1; if let Some((head, tail)) = s.split_once(";") { step = parse_int(tail)?; s = head; } if s.is_empty() { return Err(crate::ExpressionError::parse_error( "Empty expression", source, )); } if let Some((start_s, end_s)) = s.split_once("..") { if !start_s.is_empty() { start = parse_int(start_s)?; } if !end_s.is_empty() { if let Some(end_s) = end_s.strip_prefix('=') { end = Some(parse_int(end_s)? + 1); } else { end = Some(parse_int(end_s)?); } } } else { start = parse_int(s)?; end = Some(start + 1); } if step == 0 { return Err(crate::ExpressionError::invalid_expression( "Step cannot be zero", source, )); } Ok(Slice::new(start, end, step)) } } #[cfg(test)] mod tests { use super::*; use alloc::string::ToString; use alloc::vec; #[test] fn test_slice_to_str() { assert_eq!(Slice::new(0, None, 1).to_string(), ".."); assert_eq!(Slice::new(0, Some(1), 1).to_string(), "0"); assert_eq!(Slice::new(0, Some(10), 1).to_string(), "..10"); assert_eq!(Slice::new(1, Some(10), 1).to_string(), "1..10"); assert_eq!(Slice::new(-3, Some(10), -2).to_string(), "-3..10;-2"); } #[test] fn test_slice_from_str() { assert_eq!("1".parse::(), Ok(Slice::new(1, Some(2), 1))); assert_eq!("..".parse::(), Ok(Slice::new(0, None, 1))); assert_eq!("..3".parse::(), Ok(Slice::new(0, Some(3), 1))); assert_eq!("..=3".parse::(), Ok(Slice::new(0, Some(4), 1))); assert_eq!("-12..3".parse::(), Ok(Slice::new(-12, Some(3), 1))); assert_eq!("..;-1".parse::(), Ok(Slice::new(0, None, -1))); assert_eq!("..=3;-2".parse::(), Ok(Slice::new(0, Some(4), -2))); assert_eq!( "..;0".parse::(), Err(crate::ExpressionError::invalid_expression( "Step cannot be zero", "..;0" )) ); assert_eq!( "".parse::(), Err(crate::ExpressionError::parse_error("Empty expression", "")) ); assert_eq!( "a".parse::(), Err(crate::ExpressionError::parse_error( "Invalid integer: 'a': invalid digit found in string", "a" )) ); assert_eq!( "..a".parse::(), Err(crate::ExpressionError::parse_error( "Invalid integer: 'a': invalid digit found in string", "..a" )) ); assert_eq!( "a:b:c".parse::(), Err(crate::ExpressionError::parse_error( "Invalid integer: 'a:b:c': invalid digit found in string", "a:b:c" )) ); } #[test] fn test_slice_output_size() { // Test the output_size method directly assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10); assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5); assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3) assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10); assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5); assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3) assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range } #[test] fn test_bound_to() { assert_eq!( Slice::new(0, None, 1).bound_to(10), Slice::new(0, Some(10), 1) ); assert_eq!( Slice::new(0, Some(5), 1).bound_to(10), Slice::new(0, Some(5), 1) ); assert_eq!( Slice::new(0, None, -1).bound_to(10), Slice::new(0, Some(-11), -1) ); assert_eq!( Slice::new(0, Some(-5), -1).bound_to(10), Slice::new(0, Some(-5), -1) ); } #[test] fn test_slice_iter() { assert_eq!( Slice::new(2, Some(3), 1).into_iter().collect::>(), vec![2] ); assert_eq!( Slice::new(3, Some(-1), -1).into_iter().collect::>(), vec![3, 2, 1, 0] ); assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]); assert_eq!( Slice::new(3, None, 2) .into_iter() .take(3) .collect::>(), vec![3, 5, 7] ); assert_eq!( Slice::new(3, None, 2) .bound_to(8) .into_iter() .collect::>(), vec![3, 5, 7] ); } #[test] #[should_panic( expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }" )] fn test_unbound_slice_into_vec() { Slice::new(0, None, 1).into_vec(); } #[test] fn into_slices_should_return_for_all_shape_dims() { let slice = s![1]; let shape = Shape::new([2, 3, 1]); let slices = slice.into_slices(&shape); assert_eq!(slices.len(), shape.len()); assert_eq!(slices[0], Slice::new(1, Some(2), 1)); assert_eq!(slices[1], Slice::new(0, Some(3), 1)); assert_eq!(slices[2], Slice::new(0, Some(1), 1)); let slice = s![1, 0..2]; let slices = slice.into_slices(&shape); assert_eq!(slices.len(), shape.len()); assert_eq!(slices[0], Slice::new(1, Some(2), 1)); assert_eq!(slices[1], Slice::new(0, Some(2), 1)); assert_eq!(slices[2], Slice::new(0, Some(1), 1)); let slice = s![..]; let slices = slice.into_slices(&shape); assert_eq!(slices.len(), shape.len()); assert_eq!(slices[0], Slice::new(0, Some(2), 1)); assert_eq!(slices[1], Slice::new(0, Some(3), 1)); assert_eq!(slices[2], Slice::new(0, Some(1), 1)); } #[test] fn into_slices_all_dimensions() { let slice = s![1, ..2, ..]; let shape = Shape::new([2, 3, 1]); let slices = slice.into_slices(&shape); assert_eq!(slices.len(), shape.len()); assert_eq!(slices[0], Slice::new(1, Some(2), 1)); assert_eq!(slices[1], Slice::new(0, Some(2), 1)); assert_eq!(slices[2], Slice::new(0, Some(1), 1)); } #[test] fn into_slices_supports_empty_dimensions() { let slice = s![.., 1, ..]; let shape = Shape::new([0, 3, 1]); let slices = slice.into_slices(&shape); assert_eq!(slices.len(), shape.len()); assert_eq!(slices[0], Slice::new(0, Some(0), 1)); assert_eq!(slices[1], Slice::new(1, Some(2), 1)); assert_eq!(slices[2], Slice::new(0, Some(1), 1)); } #[test] #[should_panic = "Too many slices provided for shape"] fn into_slices_should_match_shape_rank() { let slice = s![.., 1, ..]; let shape = Shape::new([3, 1]); let _ = slice.into_slices(&shape); } #[test] fn should_support_const_and_full() { static SLICES: [Slice; 2] = [Slice::full(), Slice::new(2, None, 1)]; assert_eq!(SLICES[0], Slice::new(0, None, 1)); assert_eq!(SLICES[1], Slice::new(2, None, 1)); } #[test] fn should_support_default() { assert_eq!(Slice::default(), Slice::new(0, None, 1)); } #[test] fn should_support_copy() { let mut slice = Slice::new(1, Some(3), 2); let slice_copy = slice; slice.end = Some(4); assert_eq!(slice, Slice::new(1, Some(4), 2)); assert_eq!(slice_copy, Slice::new(1, Some(3), 2)); } } ================================================ FILE: crates/burn-store/Cargo.toml ================================================ [package] authors = ["Dilshod Tadjibaev (@antimora)"] categories = ["science", "no-std", "embedded", "wasm"] description = "Storage and serialization infrastructure for Burn" documentation = "https://docs.rs/burn-store" edition.workspace = true keywords = [ "deep-learning", "machine-learning", "tensor", "storage", "serialization", ] license.workspace = true name = "burn-store" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-store" version.workspace = true [lints] workspace = true [features] default = ["std", "pytorch", "safetensors", "burnpack", "memmap"] memmap = ["std", "dep:memmap2"] std = [ "dep:memmap2", "safetensors/std", "burn-core/std", "burn-tensor/std", "dep:regex", "byteorder/std", ] tracing = [ "burn-core/tracing", "burn-cuda?/tracing", "burn-nn/tracing", "burn-tch?/tracing", "burn-tensor/tracing", "burn-wgpu?/tracing", ] burnpack = ["serde", "ciborium"] cuda = ["burn-cuda"] metal = ["wgpu", "burn-wgpu/metal"] tch = ["burn-tch"] wgpu = ["burn-wgpu"] safetensors = ["dep:safetensors"] pytorch = ["burn-core/record-item-custom-serde", "zip", "serde", "tar"] [dependencies] burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false } burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2", default-features = false } # External dependencies byteorder = { workspace = true, default-features = false } bytes = { workspace = true } ciborium = { workspace = true, optional = true } half = { workspace = true } hashbrown = { workspace = true, features = ["serde"] } memmap2 = { workspace = true, optional = true } regex = { workspace = true, optional = true } serde = { workspace = true, optional = true } textdistance = { workspace = true } zip = { workspace = true, optional = true } tar = { workspace = true, optional = true } # Workaround to force broken minor version to update lzma-rust2 = { workspace = true, optional = true } safetensors = { workspace = true, optional = true } # Optional backend dependencies for benchmarks burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true } [dev-dependencies] # burn-import = { path = "../burn-import", version = "=0.21.0-pre.2" } # disabled (circular dep in publish, only for bench) burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", default-features = false } divan = "0.1" tempfile = { workspace = true } [[bench]] harness = false name = "resnet18_loading" [[bench]] harness = false name = "unified_loading" [[bench]] harness = false name = "unified_saving" [[bench]] harness = false name = "zero_copy_loading" # Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi) [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] bytes = { workspace = true, features = ["extra-platforms"] } ================================================ FILE: crates/burn-store/MIGRATION.md ================================================ # Migration Guide: burn-import to burn-store This guide helps you migrate from the deprecated `burn-import` recorders (`PyTorchFileRecorder`, `SafetensorsFileRecorder`) to the new `burn-store` API (`PytorchStore`, `SafetensorsStore`). ## Overview The new `burn-store` API provides: - **Simpler API**: Load directly into models instead of records - **Fluent builder pattern**: Chain configuration methods - **Better error handling**: Detailed load results with applied/missing/errors info - **Bidirectional support**: Both load and save operations - **More features**: Filtering, partial loading, metadata, zero-copy loading ## Quick Migration ### PyTorch Files (.pt/.pth) **Before (burn-import):** ```rust use burn::record::{FullPrecisionSettings, Recorder}; use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; // Load into a record, then create model from record let record: ModelRecord = PyTorchFileRecorder::::default() .load("model.pt".into(), &device) .expect("Failed to load"); let model = Model::init(&device).load_record(record); ``` **After (burn-store):** ```rust use burn_store::{ModuleSnapshot, PytorchStore}; // Initialize model, then load weights directly let mut model = Model::init(&device); let mut store = PytorchStore::from_file("model.pt"); model.load_from(&mut store).expect("Failed to load"); ``` ### SafeTensors Files (.safetensors) **Before (burn-import):** ```rust use burn::record::{FullPrecisionSettings, Recorder}; use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder}; let record: ModelRecord = SafetensorsFileRecorder::::default() .load("model.safetensors".into(), &device) .expect("Failed to load"); let model = Model::init(&device).load_record(record); ``` **After (burn-store):** ```rust use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore}; let mut model = Model::init(&device); // For SafeTensors exported from PyTorch, use the adapter let mut store = SafetensorsStore::from_file("model.safetensors") .with_from_adapter(PyTorchToBurnAdapter); model.load_from(&mut store).expect("Failed to load"); // For native Burn SafeTensors, no adapter needed let mut store = SafetensorsStore::from_file("model.safetensors"); model.load_from(&mut store).expect("Failed to load"); ``` ## API Mapping ### PyTorchFileRecorder Options | burn-import | burn-store | | ---------------------------------------------- | ------------------------------------------- | | `LoadArgs::new(path)` | `PytorchStore::from_file(path)` | | `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` | | `.with_top_level_key(key)` | `.with_top_level_key(key)` | | `.with_debug_print()` | _(use tracing/logging instead)_ | | `PyTorchFileRecorder::` | _(precision handled automatically)_ | ### SafetensorsFileRecorder Options | burn-import | burn-store | | -------------------------------------------------- | ------------------------------------------- | | `LoadArgs::new(path)` | `SafetensorsStore::from_file(path)` | | `.with_key_remap(pattern, replacement)` | `.with_key_remapping(pattern, replacement)` | | `.with_adapter_type(AdapterType::PyTorch)` | `.with_from_adapter(PyTorchToBurnAdapter)` | | `.with_adapter_type(AdapterType::NoAdapter)` | _(default, no adapter)_ | | `.with_debug_print()` | _(use tracing/logging instead)_ | | `SafetensorsFileRecorder::` | _(precision handled automatically)_ | ## Detailed Examples ### Key Remapping **Before:** ```rust let args = LoadArgs::new("model.pt".into()) .with_key_remap("conv\\.(.*)", "$1") .with_key_remap("^old_prefix\\.", "new_prefix."); let record: ModelRecord = PyTorchFileRecorder::::default() .load(args, &device)?; ``` **After:** ```rust let mut store = PytorchStore::from_file("model.pt") .with_key_remapping("conv\\.(.*)", "$1") .with_key_remapping("^old_prefix\\.", "new_prefix."); model.load_from(&mut store)?; ``` ### Top-Level Key Access **Before:** ```rust let args = LoadArgs::new("checkpoint.pt".into()) .with_top_level_key("state_dict"); let record: ModelRecord = PyTorchFileRecorder::::default() .load(args, &device)?; ``` **After:** ```rust let mut store = PytorchStore::from_file("checkpoint.pt") .with_top_level_key("state_dict"); model.load_from(&mut store)?; ``` ### PyTorch Adapter for SafeTensors **Before:** ```rust use burn_import::safetensors::{AdapterType, LoadArgs}; let args = LoadArgs::new("pytorch_model.safetensors".into()) .with_adapter_type(AdapterType::PyTorch); let record: ModelRecord = SafetensorsFileRecorder::::default() .load(args, &device)?; ``` **After:** ```rust use burn_store::{PyTorchToBurnAdapter, SafetensorsStore}; let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") .with_from_adapter(PyTorchToBurnAdapter); model.load_from(&mut store)?; ``` ## New Features in burn-store ### Partial Loading Handle missing tensors gracefully: ```rust let mut store = PytorchStore::from_file("model.pt") .allow_partial(true); let result = model.load_from(&mut store)?; println!("Loaded: {:?}", result.applied); println!("Missing: {:?}", result.missing); ``` ### Filtering Load only specific tensors: ```rust let mut store = SafetensorsStore::from_file("model.safetensors") .with_regex(r"^encoder\..*") // Only encoder layers .allow_partial(true); model.load_from(&mut store)?; ``` ### Saving Models Save models (not supported by old recorders): ```rust // Save to SafeTensors let mut store = SafetensorsStore::from_file("output.safetensors") .metadata("version", "1.0"); model.save_into(&mut store)?; // Save to Burnpack (native format) let mut store = BurnpackStore::from_file("output.bpk"); model.save_into(&mut store)?; ``` ### Load Results Get detailed information about loading: ```rust let result = model.load_from(&mut store)?; // Print the full result for debugging - shows applied, skipped, missing, and errors println!("{}", result); // Or access individual fields println!("Applied: {} tensors", result.applied.len()); println!("Skipped: {} tensors", result.skipped.len()); println!("Missing: {:?}", result.missing); println!("Errors: {:?}", result.errors); // Check if fully successful if result.is_success() { println!("All tensors loaded successfully"); } ``` The `LoadResult` implements `Display`, so printing it shows a formatted summary with suggestions for common issues (e.g., using `allow_partial(true)` for missing tensors). ## Updating Cargo.toml **Before:** ```toml [dependencies] burn-import = { version = "0.x", features = ["pytorch", "safetensors"] } ``` **After:** ```toml [dependencies] burn-store = { version = "0.x", features = ["pytorch", "safetensors"] } ``` ## Common Migration Issues ### 1. Model vs Record The new API loads directly into models, not records. Update your model initialization: ```rust // Before: Create record, then model from record let record = recorder.load(...)?; let model = Model::init(&device).load_record(record); // After: Create model, then load into it let mut model = Model::init(&device); model.load_from(&mut store)?; ``` ### 2. Inference Functions If you had functions that took `ModelRecord`, update them to take `Model`: ```rust // Before fn infer(record: ModelRecord) { let model = Model::init(&device).load_record(record); // ... } // After fn infer(model: Model) { // Model already has weights loaded // ... } ``` ### 3. Precision Settings The old API required explicit precision settings. The new API handles this automatically: ```rust // Before: Had to specify FullPrecisionSettings or HalfPrecisionSettings PyTorchFileRecorder::::default() // After: Precision handled automatically based on tensor dtype PytorchStore::from_file("model.pt") ``` ### 4. Error Handling The new API provides richer error information: ```rust // Before: Simple Result let record = recorder.load(args, &device)?; // After: LoadResult with detailed info let result = model.load_from(&mut store)?; // Print the result to see a helpful summary with suggestions println!("{}", result); // Or handle specific issues programmatically if !result.errors.is_empty() { for (path, error) in &result.errors { eprintln!("Error loading {}: {}", path, error); } } ``` ## See Also - [burn-store README](README.md) - Full documentation - [import-model-weights example](../../examples/import-model-weights/) - Working example ================================================ FILE: crates/burn-store/README.md ================================================ # Burn Store > Advanced model storage and serialization for the Burn deep learning framework [![Current Crates.io Version](https://img.shields.io/crates/v/burn-store.svg)](https://crates.io/crates/burn-store) [![Documentation](https://docs.rs/burn-store/badge.svg)](https://docs.rs/burn-store) A comprehensive storage library for Burn that enables efficient model serialization, cross-framework interoperability, and advanced tensor management. > **Migrating from burn-import?** See the [Migration Guide](MIGRATION.md) for help moving from > `PyTorchFileRecorder`/`SafetensorsFileRecorder` to the new Store API. ## Features - **Burnpack Format** - Native Burn format with CBOR metadata, memory-mapped loading, ParamId persistence for stateful training, and no-std support - **SafeTensors Format** - Industry-standard format for secure and efficient tensor serialization - **PyTorch Support** - Direct loading of PyTorch .pth/.pt files with automatic weight transformation - **Zero-Copy Loading** - Memory-mapped files and lazy tensor materialization for optimal performance - **Flexible Filtering** - Load/save specific model subsets with regex, exact paths, or custom predicates - **Tensor Remapping** - Rename tensors during load/save for framework compatibility - **Half-Precision Storage** - Automatic F32/F16 conversion with smart defaults for reduced model file size - **No-std Support** - Burnpack and SafeTensors formats available in embedded and WASM environments ## Quick Start ```rust use burn_store::{ModuleSnapshot, PytorchStore, SafetensorsStore, BurnpackStore, HalfPrecisionAdapter}; // Load from PyTorch let mut store = PytorchStore::from_file("model.pt"); model.load_from(&mut store)?; // Load from SafeTensors (with PyTorch adapter) let mut store = SafetensorsStore::from_file("model.safetensors") .with_from_adapter(PyTorchToBurnAdapter); model.load_from(&mut store)?; // Save to Burnpack let mut store = BurnpackStore::from_file("model.bpk"); model.save_into(&mut store)?; // Save with half-precision (F32 -> F16, ~50% smaller files) let adapter = HalfPrecisionAdapter::new(); let mut store = BurnpackStore::from_file("model_f16.bpk") .with_to_adapter(adapter.clone()); model.save_into(&mut store)?; // Load half-precision back (F16 -> F32, same adapter) let mut store = BurnpackStore::from_file("model_f16.bpk") .with_from_adapter(adapter); model.load_from(&mut store)?; ``` ## Documentation For comprehensive documentation including: - Exporting weights from PyTorch - Loading weights into Burn models - Saving models to various formats - Advanced features (filtering, remapping, partial loading, zero-copy) - API reference and troubleshooting See the **[Burn Book - Saving and Loading](../../burn-book/src/saving-and-loading.md)** chapter. ## Running Benchmarks ```bash # Generate model files (one-time setup) uv run benches/generate_unified_models.py # Run loading benchmarks cargo bench --bench unified_loading # Run saving benchmarks cargo bench --bench unified_saving # With specific backend cargo bench --bench unified_loading --features metal ``` ## License This project is dual-licensed under MIT and Apache-2.0. ================================================ FILE: crates/burn-store/benches/download_resnet18.py ================================================ #!/usr/bin/env python3 # /// script # requires-python = ">=3.8" # dependencies = [ # "torch", # "torchvision", # ] # /// """ Download ResNet18 PyTorch model for benchmarking. This script downloads a pre-trained ResNet18 model from PyTorch Hub and saves it in a format suitable for benchmarking. """ import os import sys import tempfile from pathlib import Path import torch import torchvision.models as models def download_resnet18(): """Download ResNet18 model and save to temp directory.""" # Create a temporary directory for the model temp_dir = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark" temp_dir.mkdir(parents=True, exist_ok=True) output_path = temp_dir / "resnet18.pth" # Check if already downloaded if output_path.exists(): file_size_mb = output_path.stat().st_size / (1024 * 1024) print(f"✅ ResNet18 already exists at: {output_path}") print(f" Size: {file_size_mb:.1f} MB") return str(output_path) print("📥 Downloading ResNet18 model...") try: # Download pre-trained ResNet18 model model = models.resnet18(pretrained=True) # Save the model state dict (this is what burn-store reads) # Using the legacy format for compatibility torch.save(model.state_dict(), output_path, _use_new_zipfile_serialization=False) file_size_mb = output_path.stat().st_size / (1024 * 1024) print(f"✅ Successfully downloaded ResNet18 to: {output_path}") print(f" Size: {file_size_mb:.1f} MB") print(f" Format: PyTorch legacy format") # Verify it's readable state_dict = torch.load(output_path, map_location='cpu') print(f" Tensors: {len(state_dict)} tensors") # Print a few tensor names and shapes for verification print("\n Sample tensors:") for i, (name, tensor) in enumerate(state_dict.items()): if i < 3: print(f" - {name}: {list(tensor.shape)}") return str(output_path) except Exception as e: print(f"❌ Failed to download ResNet18: {e}") sys.exit(1) def main(): """Main entry point.""" path = download_resnet18() # Write the path to a file that the benchmark can read bench_config = Path(tempfile.gettempdir()) / "burn_resnet18_benchmark" / "path.txt" bench_config.write_text(path) print(f"\n💡 Model ready for benchmarking") print(f" Run: cargo bench --bench resnet18_loading") if __name__ == "__main__": main() ================================================ FILE: crates/burn-store/benches/generate_unified_models.py ================================================ #!/usr/bin/env python3 # /// script # requires-python = ">=3.8" # dependencies = [ # "torch", # "safetensors", # "packaging", # "numpy", # ] # /// """ Generate a large model (~312MB) in both PyTorch and SafeTensors formats for unified benchmarking. Usage: uv run benches/generate_unified_models.py The script will create model files in /tmp/simple_bench_models/ directory. """ import torch import torch.nn as nn import os from pathlib import Path import tempfile from safetensors.torch import save_file def get_temp_dir(): """Get the appropriate temp directory.""" temp_dir = Path(tempfile.gettempdir()) / "simple_bench_models" temp_dir.mkdir(parents=True, exist_ok=True) return temp_dir class LargeModel(nn.Module): """Large model with 20 layers to match Rust benchmark.""" def __init__(self): super().__init__() self.layers = nn.ModuleList() # Create a model with 20 layers matching the Rust LargeModel for i in range(20): in_size = 1024 if i == 0 else 2048 out_size = 2048 self.layers.append(nn.Linear(in_size, out_size)) print(f"Created model with {len(self.layers)} layers") def forward(self, x): for layer in self.layers: x = layer(x) return x def calculate_model_size(model): """Calculate the size of the model in MB.""" total_params = sum(p.numel() for p in model.parameters()) size_mb = (total_params * 4) / (1024 * 1024) # 4 bytes per float32 return total_params, size_mb def initialize_weights(model): """Initialize model weights with random values.""" for param in model.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) def save_pytorch_format(model, output_dir): """Save model in PyTorch format.""" pt_path = output_dir / "large_model.pt" # Save as checkpoint with model_state_dict (common format) checkpoint = { 'model_state_dict': model.state_dict(), 'metadata': { 'model_type': 'large_benchmark_model', 'num_layers': len(model.layers), } } torch.save(checkpoint, pt_path) return pt_path def save_safetensors_format(model, output_dir): """Save model in SafeTensors format.""" st_path = output_dir / "large_model.safetensors" # Convert state dict to safetensors format state_dict = model.state_dict() # Ensure all tensors are contiguous and on CPU state_dict = {k: v.contiguous().cpu() for k, v in state_dict.items()} # Save with metadata metadata = { 'model_type': 'large_benchmark_model', 'num_layers': str(len(model.layers)), } save_file(state_dict, st_path, metadata=metadata) return st_path def verify_files(pt_path, st_path): """Verify the saved files can be loaded.""" # Verify PyTorch file checkpoint = torch.load(pt_path, map_location='cpu') pt_keys = set(checkpoint['model_state_dict'].keys()) print(f" PyTorch file: {len(pt_keys)} tensors") # Verify SafeTensors file from safetensors import safe_open with safe_open(st_path, framework="pt", device="cpu") as f: st_keys = set(f.keys()) print(f" SafeTensors file: {len(st_keys)} tensors") # Check keys match if pt_keys != st_keys: print(" ⚠️ Warning: Keys don't match between formats!") else: print(" ✓ Keys match between formats") def main(): print("🔧 Generating unified benchmark model files...") print("") output_dir = get_temp_dir() print(f"📁 Output directory: {output_dir}") print("") # Set random seed for reproducibility torch.manual_seed(42) # Create the large model print("📝 Creating large model...") model = LargeModel() # Calculate and display model size total_params, size_mb = calculate_model_size(model) print(f" Total parameters: {total_params:,}") print(f" Model size: {size_mb:.2f} MB") print("") # Initialize weights print("🎲 Initializing weights...") initialize_weights(model) # Save in PyTorch format print("💾 Saving PyTorch format...") pt_path = save_pytorch_format(model, output_dir) pt_size_mb = pt_path.stat().st_size / (1024 * 1024) print(f" Saved: {pt_path}") print(f" File size: {pt_size_mb:.2f} MB") print("") # Save in SafeTensors format print("💾 Saving SafeTensors format...") st_path = save_safetensors_format(model, output_dir) st_size_mb = st_path.stat().st_size / (1024 * 1024) print(f" Saved: {st_path}") print(f" File size: {st_size_mb:.2f} MB") print("") # Verify files print("🔍 Verifying saved files...") verify_files(pt_path, st_path) print("") print(f"✅ Model files generated successfully!") print("") print("📊 Summary:") print(f" PyTorch file: {pt_path.name} ({pt_size_mb:.2f} MB)") print(f" SafeTensors file: {st_path.name} ({st_size_mb:.2f} MB)") print("") print("💡 To run the unified benchmark:") print(" cargo bench --bench unified_loading") if __name__ == "__main__": main() ================================================ FILE: crates/burn-store/benches/resnet18_loading.rs ================================================ //! Benchmark for ResNet18 loading to verify lazy loading memory usage. //! //! resnet18.pth is pytorch's legacy file format. //! //! This benchmark loads a ResNet18 model and materializes all tensors //! to ensure memory usage stays reasonable with lazy loading. //! //! Run the benchmark: //! ```bash //! cargo bench --bench resnet18_loading //! ``` use burn_store::pytorch::PytorchReader; use divan::{AllocProfiler, Bencher}; use std::path::PathBuf; #[global_allocator] static ALLOC: AllocProfiler = AllocProfiler::system(); #[allow(clippy::manual_range_contains)] fn main() { // Check if ResNet18 file exists let path = resnet18_path(); if !path.exists() { eprintln!("❌ ResNet18 model not found!"); eprintln!(); eprintln!("Please download it first by running:"); eprintln!(" python benches/download_resnet18.py"); eprintln!(); eprintln!("Or if you don't have Python/PyTorch installed:"); eprintln!(" uv run benches/download_resnet18.py"); eprintln!(); eprintln!("Expected location: {}", path.display()); std::process::exit(1); } // Verify file size is reasonable let metadata = std::fs::metadata(&path).expect("Failed to read file metadata"); let size_mb = metadata.len() as f64 / 1_048_576.0; if size_mb < 40.0 || size_mb > 50.0 { eprintln!( "⚠️ Warning: ResNet18 file size ({:.1} MB) seems unusual", size_mb ); eprintln!("Expected size is around 45 MB"); } println!("✅ Found ResNet18 model at: {}", path.display()); println!("📦 File size: {:.1} MB", size_mb); println!("📊 Running ResNet18 loading benchmarks...\n"); // Run divan benchmarks divan::main(); } /// Get the path to ResNet18 model file fn resnet18_path() -> PathBuf { // First try to read from the path file created by download script let temp_dir = std::env::temp_dir(); let config_file = temp_dir.join("burn_resnet18_benchmark").join("path.txt"); if config_file.exists() && let Ok(path_str) = std::fs::read_to_string(&config_file) { let path = PathBuf::from(path_str.trim()); if path.exists() { return path; } } // Fallback to default location temp_dir .join("burn_resnet18_benchmark") .join("resnet18.pth") } #[divan::bench(sample_count = 10)] fn load_resnet18_metadata(bencher: Bencher) { let path = resnet18_path(); bencher.bench_local(|| { let reader = PytorchReader::new(&path).expect("Failed to load ResNet18"); let metadata = reader.metadata(); // Just access metadata without materializing tensors assert_eq!(metadata.tensor_count, 122); }); } #[divan::bench(sample_count = 5)] fn load_resnet18_materialize_all(bencher: Bencher) { let path = resnet18_path(); bencher.bench_local(|| { let reader = PytorchReader::new(&path).expect("Failed to load ResNet18"); let keys = reader.keys(); let mut total_bytes = 0usize; // Materialize all tensors one by one for key in &keys { let tensor = reader.get(key).expect("Failed to get tensor"); // Materialize the tensor data let _data = tensor.to_data().expect("Failed to materialize tensor data"); total_bytes += tensor.data_len(); } // Verify we processed all the data assert!(total_bytes > 40_000_000); // Should be ~45MB }); } #[divan::bench(sample_count = 5)] fn load_resnet18_materialize_sequential(bencher: Bencher) { let path = resnet18_path(); bencher.bench_local(|| { let reader = PytorchReader::new(&path).expect("Failed to load ResNet18"); let keys = reader.keys(); // Materialize tensors one at a time, letting previous ones be dropped // This simulates processing tensors sequentially without keeping all in memory for key in &keys { let tensor = reader.get(key).expect("Failed to get tensor"); let data = tensor.to_data().expect("Failed to materialize tensor data"); // Do minimal work with the data to prevent optimization let sum = match data.dtype { burn_tensor::DType::F32 => data .as_slice::() .map(|s| s.iter().sum::()) .unwrap_or(0.0) as f64, burn_tensor::DType::F64 => data .as_slice::() .map(|s| s.iter().sum::()) .unwrap_or(0.0), _ => 0.0, }; // Use the sum to prevent dead code elimination std::hint::black_box(sum); } }); } #[divan::bench(sample_count = 10)] fn load_resnet18_largest_tensor(bencher: Bencher) { let path = resnet18_path(); bencher.bench_local(|| { let reader = PytorchReader::new(&path).expect("Failed to load ResNet18"); // Find and materialize only the largest tensor // This tests peak memory for a single tensor operation let keys = reader.keys(); let mut largest_key = String::new(); let mut largest_size = 0usize; for key in &keys { let tensor = reader.get(key).expect("Failed to get tensor"); let size = tensor.data_len(); if size > largest_size { largest_size = size; largest_key = key.clone(); } } // Materialize the largest tensor let tensor = reader .get(&largest_key) .expect("Failed to get largest tensor"); let _data = tensor.to_data().expect("Failed to materialize tensor data"); assert!(largest_size > 9_000_000); // Should be ~9MB for layer4.0.conv2.weight }); } #[divan::bench(sample_count = 10)] fn load_resnet18_memory_profile(bencher: Bencher) { let path = resnet18_path(); bencher .with_inputs(|| path.clone()) .bench_local_values(|path| { let reader = PytorchReader::new(&path).expect("Failed to load ResNet18"); let keys = reader.keys(); let mut peak_single_tensor = 0usize; let mut total_data = 0usize; // Process each tensor and track memory for key in &keys { let tensor = reader.get(key).expect("Failed to get tensor"); let tensor_size = tensor.data_len(); // Track largest single tensor if tensor_size > peak_single_tensor { peak_single_tensor = tensor_size; } // Materialize the tensor let data = tensor.to_data().expect("Failed to materialize tensor data"); total_data += tensor_size; // Drop data immediately to test lazy loading memory efficiency drop(data); } // Return stats for verification (peak_single_tensor, total_data) }); } ================================================ FILE: crates/burn-store/benches/unified_loading.rs ================================================ #![recursion_limit = "256"] //! Unified benchmark comparing all loading methods: //! - BurnpackStore (new native format) //! - NamedMpkFileRecorder (old native format) //! - SafetensorsStore (new) //! - SafetensorsFileRecorder (old) //! - PytorchStore (new) //! - PyTorchFileRecorder (old) //! //! Before running this benchmark, generate the model files: //! ```bash //! cd crates/burn-store //! uv run benches/generate_unified_models.py //! ``` //! //! Then run the benchmark: //! ```bash //! cargo bench --bench unified_loading //! ``` use burn_core as burn; use burn_core::module::Module; use burn_core::prelude::*; use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}; // use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; // use burn_import::safetensors::SafetensorsFileRecorder; use burn_nn as nn; use burn_store::{ BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore, }; use divan::{AllocProfiler, Bencher}; use std::fs; use std::path::{Path, PathBuf}; #[global_allocator] static ALLOC: AllocProfiler = AllocProfiler::system(); // Backend type aliases type NdArrayBackend = burn_ndarray::NdArray; #[cfg(feature = "wgpu")] type WgpuBackend = burn_wgpu::Wgpu; #[cfg(feature = "cuda")] type CudaBackend = burn_cuda::Cuda; #[cfg(feature = "tch")] type TchBackend = burn_tch::LibTorch; #[cfg(feature = "metal")] type MetalBackend = burn_wgpu::Metal; // Use the same LargeModel as other benchmarks for fair comparison #[derive(Module, Debug)] struct LargeModel { layers: Vec>, } impl LargeModel { fn new(device: &B::Device) -> Self { let mut layers = Vec::new(); // Create a model with 20 layers - same as safetensor_loading benchmark for i in 0..20 { let in_size = if i == 0 { 1024 } else { 2048 }; layers.push(nn::LinearConfig::new(in_size, 2048).init(device)); } Self { layers } } } /// Get the path to the model files fn get_model_dir() -> PathBuf { std::env::temp_dir().join("simple_bench_models") } /// Generate Burnpack and NamedMpk files from existing SafeTensors file fn generate_burn_formats(st_path: &Path, bp_path: &Path, mpk_path: &Path) { type TestBackend = NdArrayBackend; let device = Default::default(); // Load the model from SafeTensors let mut model = LargeModel::::new(&device); let mut store = SafetensorsStore::from_file(st_path).with_from_adapter(PyTorchToBurnAdapter); model .load_from(&mut store) .expect("Failed to load from SafeTensors"); // Save as Burnpack if !bp_path.exists() { println!(" Creating Burnpack file..."); let mut burnpack_store = BurnpackStore::from_file(bp_path); model .save_into(&mut burnpack_store) .expect("Failed to save as Burnpack"); } // Save as NamedMpk if !mpk_path.exists() { println!(" Creating NamedMpk file..."); let recorder = NamedMpkFileRecorder::::default(); model .save_file(mpk_path, &recorder) .expect("Failed to save as NamedMpk"); } } /// Get paths to the model files fn get_model_paths() -> (PathBuf, PathBuf, PathBuf, PathBuf) { let dir = get_model_dir(); ( dir.join("large_model.bpk"), dir.join("large_model.mpk"), dir.join("large_model.safetensors"), dir.join("large_model.pt"), ) } /// Check if model files exist fn check_model_files() -> Result<(), String> { let (_, _, st_path, pt_path) = get_model_paths(); // For now, only check safetensors and pytorch files (will generate burnpack/mpk later) if !st_path.exists() || !pt_path.exists() { return Err(format!( "\n❌ Model files not found!\n\ \n\ Please generate the model files first by running:\n\ \n\ cd crates/burn-store\n\ uv run benches/generate_unified_models.py\n\ \n\ Expected files:\n\ - {}\n\ - {}\n", st_path.display(), pt_path.display() )); } Ok(()) } fn main() { // Check if model files exist before running benchmarks match check_model_files() { Ok(()) => { let (bp_path, mpk_path, st_path, pt_path) = get_model_paths(); // First, generate Burnpack and MPK files if they don't exist if !bp_path.exists() || !mpk_path.exists() { println!("⏳ Generating Burnpack and NamedMpk files from SafeTensors..."); generate_burn_formats(&st_path, &bp_path, &mpk_path); } let bp_size = fs::metadata(&bp_path) .ok() .map(|m| m.len() as f64 / 1_048_576.0); let mpk_size = fs::metadata(&mpk_path) .ok() .map(|m| m.len() as f64 / 1_048_576.0); let st_size = fs::metadata(&st_path).unwrap().len() as f64 / 1_048_576.0; let pt_size = fs::metadata(&pt_path).unwrap().len() as f64 / 1_048_576.0; println!("✅ Found model files:"); if let Some(size) = bp_size { println!(" Burnpack: {} ({:.1} MB)", bp_path.display(), size); } if let Some(size) = mpk_size { println!(" NamedMpk: {} ({:.1} MB)", mpk_path.display(), size); } println!(" SafeTensors: {} ({:.1} MB)", st_path.display(), st_size); println!(" PyTorch: {} ({:.1} MB)", pt_path.display(), pt_size); println!(); println!("🚀 Running unified loading benchmarks..."); println!(); println!("Comparing 6 loading methods:"); println!(" 1. BurnpackStore (new native format - lazy loading)"); println!(" 2. NamedMpkFileRecorder (old native format - loads all to memory)"); println!(" 3. SafetensorsStore (new)"); println!(" 4. SafetensorsFileRecorder (old)"); println!(" 5. PytorchStore (new)"); println!(" 6. PyTorchFileRecorder (old)"); println!(); println!("Available backends:"); println!(" - NdArray (CPU)"); #[cfg(feature = "wgpu")] println!(" - WGPU (GPU)"); #[cfg(feature = "cuda")] println!(" - CUDA (NVIDIA GPU)"); #[cfg(feature = "tch")] println!(" - LibTorch"); #[cfg(feature = "metal")] println!(" - Metal (Apple GPU)"); println!(); divan::main(); } Err(msg) => { eprintln!("{}", msg); std::process::exit(1); } } } // Macro to generate benchmarks for each backend macro_rules! bench_backend { ($backend:ty, $mod_name:ident, $backend_name:literal) => { #[divan::bench_group(name = $backend_name, sample_count = 10)] mod $mod_name { use super::*; type TestBackend = $backend; type TestDevice = ::Device; #[divan::bench] fn burnpack_store(bencher: Bencher) { let (bp_path, _, _, _) = get_model_paths(); let file_size = fs::metadata(&bp_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_file(bp_path.clone()); model.load_from(&mut store).expect("Failed to load"); }); } #[divan::bench] fn namedmpk_recorder(bencher: Bencher) { let (_, mpk_path, _, _) = get_model_paths(); let file_size = fs::metadata(&mpk_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let recorder = NamedMpkFileRecorder::::default(); let record = recorder .load(mpk_path.clone().into(), &device) .expect("Failed to load"); let _model = LargeModel::::new(&device).load_record(record); }); } #[divan::bench] fn safetensors_store(bencher: Bencher) { let (_, _, st_path, _) = get_model_paths(); let file_size = fs::metadata(&st_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); let mut store = SafetensorsStore::from_file(st_path.clone()) .with_from_adapter(PyTorchToBurnAdapter); model.load_from(&mut store).expect("Failed to load"); }); } // #[divan::bench] // fn safetensors_recorder(bencher: Bencher) { // let (_, _, st_path, _) = get_model_paths(); // let file_size = fs::metadata(&st_path).unwrap().len(); // bencher // .counter(divan::counter::BytesCount::new(file_size)) // .bench(|| { // let device: TestDevice = Default::default(); // let recorder = SafetensorsFileRecorder::::default(); // let record = recorder // .load(st_path.clone().into(), &device) // .expect("Failed to load"); // let _model = LargeModel::::new(&device).load_record(record); // }); // } #[divan::bench] fn pytorch_store(bencher: Bencher) { let (_, _, _, pt_path) = get_model_paths(); let file_size = fs::metadata(&pt_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); let mut store = PytorchStore::from_file(pt_path.clone()) .with_top_level_key("model_state_dict") .allow_partial(true); model.load_from(&mut store).expect("Failed to load"); }); } // #[divan::bench] // fn pytorch_recorder(bencher: Bencher) { // let (_, _, _, pt_path) = get_model_paths(); // let file_size = fs::metadata(&pt_path).unwrap().len(); // bencher // .counter(divan::counter::BytesCount::new(file_size)) // .bench(|| { // let device: TestDevice = Default::default(); // let recorder = PyTorchFileRecorder::::default(); // let load_args = // LoadArgs::new(pt_path.clone()).with_top_level_key("model_state_dict"); // let record = recorder.load(load_args, &device).expect("Failed to load"); // let _model = LargeModel::::new(&device).load_record(record); // }); // } } }; } // Generate benchmarks for each backend bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)"); #[cfg(feature = "wgpu")] bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)"); #[cfg(feature = "cuda")] bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)"); #[cfg(feature = "tch")] bench_backend!(TchBackend, tch_backend, "LibTorch Backend"); #[cfg(feature = "metal")] bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)"); ================================================ FILE: crates/burn-store/benches/unified_saving.rs ================================================ #![recursion_limit = "256"] //! Unified benchmark comparing all saving methods: //! - BurnpackStore (new native format) //! - NamedMpkFileRecorder (old native format) //! - SafetensorsStore (new) //! //! Before running this benchmark, ensure the directory exists: //! ```bash //! mkdir -p /tmp/simple_bench_models //! ``` //! //! Then run the benchmark: //! ```bash //! cargo bench --bench unified_saving //! ``` use burn_core as burn; use burn_core::module::Module; use burn_core::prelude::*; use burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder}; use burn_nn as nn; use burn_store::{BurnpackStore, ModuleSnapshot, SafetensorsStore}; use divan::{AllocProfiler, Bencher}; use std::fs; use std::path::PathBuf; #[global_allocator] static ALLOC: AllocProfiler = AllocProfiler::system(); // Backend type aliases type NdArrayBackend = burn_ndarray::NdArray; #[cfg(feature = "wgpu")] type WgpuBackend = burn_wgpu::Wgpu; #[cfg(feature = "cuda")] type CudaBackend = burn_cuda::Cuda; #[cfg(feature = "tch")] type TchBackend = burn_tch::LibTorch; #[cfg(feature = "metal")] type MetalBackend = burn_wgpu::Metal; // Use the same LargeModel as other benchmarks for fair comparison #[derive(Module, Debug)] struct LargeModel { layers: Vec>, } impl LargeModel { fn new(device: &B::Device) -> Self { let mut layers = Vec::new(); // Create a model with 20 layers - same as loading benchmarks for i in 0..20 { let in_size = if i == 0 { 1024 } else { 2048 }; layers.push(nn::LinearConfig::new(in_size, 2048).init(device)); } Self { layers } } } /// Get the path to the output directory fn get_output_dir() -> PathBuf { std::env::temp_dir().join("simple_bench_models_saving") } /// Ensure output directory exists fn ensure_output_dir() -> Result<(), String> { let dir = get_output_dir(); if !dir.exists() { fs::create_dir_all(&dir) .map_err(|e| format!("Failed to create output directory: {}", e))?; } Ok(()) } fn main() { match ensure_output_dir() { Ok(()) => { println!("✅ Output directory ready: {}", get_output_dir().display()); println!(); println!("🚀 Running unified saving benchmarks..."); println!(); println!("Comparing 3 saving methods:"); println!(" 1. BurnpackStore (new native format)"); println!(" 2. NamedMpkFileRecorder (old native format)"); println!(" 3. SafetensorsStore (new)"); println!(); println!("Available backends:"); println!(" - NdArray (CPU)"); #[cfg(feature = "wgpu")] println!(" - WGPU (GPU)"); #[cfg(feature = "cuda")] println!(" - CUDA (NVIDIA GPU)"); #[cfg(feature = "tch")] println!(" - LibTorch"); #[cfg(feature = "metal")] println!(" - Metal (Apple GPU)"); println!(); divan::main(); } Err(msg) => { eprintln!("❌ {}", msg); std::process::exit(1); } } } // Macro to generate benchmarks for each backend macro_rules! bench_backend { ($backend:ty, $mod_name:ident, $backend_name:literal) => { #[divan::bench_group(name = $backend_name, sample_count = 10)] mod $mod_name { use super::*; type TestBackend = $backend; type TestDevice = ::Device; #[divan::bench] fn burnpack_store(bencher: Bencher) { bencher.bench(|| { let device: TestDevice = Default::default(); let model = LargeModel::::new(&device); let output_path = get_output_dir().join("test_burnpack.bpk"); let mut store = BurnpackStore::from_file(output_path.clone()).overwrite(true); model .save_into(&mut store) .expect("Failed to save with BurnpackStore"); // Clean up let _ = fs::remove_file(output_path); }); } #[divan::bench] fn namedmpk_recorder(bencher: Bencher) { bencher.bench(|| { let device: TestDevice = Default::default(); let model = LargeModel::::new(&device); let output_path = get_output_dir().join("test_namedmpk.mpk"); let recorder = NamedMpkFileRecorder::::default(); model .save_file(output_path.clone(), &recorder) .expect("Failed to save with NamedMpkFileRecorder"); // Clean up let _ = fs::remove_file(output_path); }); } #[divan::bench] fn safetensors_store(bencher: Bencher) { bencher.bench(|| { let device: TestDevice = Default::default(); let model = LargeModel::::new(&device); let output_path = get_output_dir().join("test_safetensors_store.safetensors"); let mut store = SafetensorsStore::from_file(output_path.clone()); model .save_into(&mut store) .expect("Failed to save with SafetensorsStore"); // Clean up let _ = fs::remove_file(output_path); }); } } }; } // Generate benchmarks for each backend bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)"); #[cfg(feature = "wgpu")] bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)"); #[cfg(feature = "cuda")] bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)"); #[cfg(feature = "tch")] bench_backend!(TchBackend, tch_backend, "LibTorch Backend"); #[cfg(feature = "metal")] bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)"); ================================================ FILE: crates/burn-store/benches/zero_copy_loading.rs ================================================ #![recursion_limit = "256"] //! Benchmark comparing zero-copy vs copy loading modes for BurnpackStore. //! //! This benchmark measures the performance difference between: //! - `zero_copy(false)` - Default mode, copies tensor data into new allocations //! - `zero_copy(true)` - Zero-copy mode, slices tensor data without copying //! //! ## Understanding the Results //! //! **IMPORTANT**: For NdArray backend, you'll see similar allocation numbers because: //! - NdArray uses `ndarray::ArrayD` which MUST own data as `Vec` //! - Even with zero-copy, the backend eventually copies data into its own format //! //! The zero-copy benefit is: //! - **Without zero-copy**: File → Copy to heap (Bytes) → Copy to Vec (backend) //! - **With zero-copy**: File → Zero-copy slice → Copy to Vec (backend) //! //! So zero-copy saves ONE memory copy at the store level. The `store_only_*` benchmarks //! show the raw store performance without backend allocation overhead. //! //! GPU backends that can consume `Bytes` directly will show larger benefits. //! //! ## Running the benchmark //! //! Before running this benchmark, generate the model files: //! ```bash //! cd crates/burn-store //! uv run benches/generate_unified_models.py //! ``` //! //! Then run the benchmark: //! ```bash //! cargo bench --bench zero_copy_loading //! ``` use burn_core as burn; use burn_core::module::Module; use burn_core::prelude::*; use burn_nn as nn; use burn_store::{ BurnpackStore, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore, }; use burn_tensor::{AllocationProperty, Bytes}; use divan::{AllocProfiler, Bencher}; use std::fs; use std::path::PathBuf; use std::sync::OnceLock; #[global_allocator] static ALLOC: AllocProfiler = AllocProfiler::system(); // Static storage for embedded model bytes (simulating include_bytes!) static STATIC_MODEL_BYTES: OnceLock<&'static [u8]> = OnceLock::new(); // Backend type aliases type NdArrayBackend = burn_ndarray::NdArray; #[cfg(feature = "wgpu")] type WgpuBackend = burn_wgpu::Wgpu; #[cfg(feature = "cuda")] type CudaBackend = burn_cuda::Cuda; #[cfg(feature = "tch")] type TchBackend = burn_tch::LibTorch; #[cfg(feature = "metal")] type MetalBackend = burn_wgpu::Metal; // Use the same LargeModel as other benchmarks for fair comparison #[derive(Module, Debug)] struct LargeModel { layers: Vec>, } impl LargeModel { fn new(device: &B::Device) -> Self { let mut layers = Vec::new(); // Create a model with 20 layers - same as unified_loading benchmark for i in 0..20 { let in_size = if i == 0 { 1024 } else { 2048 }; layers.push(nn::LinearConfig::new(in_size, 2048).init(device)); } Self { layers } } } /// Get the path to the model files fn get_model_dir() -> PathBuf { std::env::temp_dir().join("simple_bench_models") } /// Get path to Burnpack model file fn get_burnpack_path() -> PathBuf { get_model_dir().join("large_model.bpk") } /// Generate Burnpack file from existing SafeTensors file if needed fn ensure_burnpack_file() { let bp_path = get_burnpack_path(); let st_path = get_model_dir().join("large_model.safetensors"); if bp_path.exists() { return; } if !st_path.exists() { panic!( "\n❌ SafeTensors model file not found!\n\ \n\ Please generate the model files first by running:\n\ \n\ cd crates/burn-store\n\ uv run benches/generate_unified_models.py\n\ \n\ Expected file: {}\n", st_path.display() ); } println!("⏳ Generating Burnpack file from SafeTensors..."); type TestBackend = NdArrayBackend; let device = Default::default(); // Load from SafeTensors let mut model = LargeModel::::new(&device); let mut store = SafetensorsStore::from_file(&st_path).with_from_adapter(PyTorchToBurnAdapter); model .load_from(&mut store) .expect("Failed to load from SafeTensors"); // Save as Burnpack let mut burnpack_store = BurnpackStore::from_file(&bp_path); model .save_into(&mut burnpack_store) .expect("Failed to save as Burnpack"); println!("✅ Created Burnpack file: {}", bp_path.display()); } /// Initialize static model bytes (simulating include_bytes! at runtime for benchmarks) fn get_static_model_bytes() -> &'static [u8] { STATIC_MODEL_BYTES.get_or_init(|| { let bp_path = get_burnpack_path(); let bytes = fs::read(&bp_path).expect("Failed to read Burnpack file"); // Leak the bytes to get a 'static lifetime (acceptable for benchmarks) Box::leak(bytes.into_boxed_slice()) }) } fn main() { // Ensure Burnpack file exists ensure_burnpack_file(); let bp_path = get_burnpack_path(); let file_size = fs::metadata(&bp_path).unwrap().len() as f64 / 1_048_576.0; println!("✅ Found Burnpack model file:"); println!(" Path: {}", bp_path.display()); println!(" Size: {:.1} MB", file_size); println!(); println!("🚀 Running zero-copy loading benchmarks..."); println!(); println!("Comparing loading modes:"); println!(" 1. file_copy - from_file().zero_copy(false) - copies tensor data"); println!(" 2. file_zero_copy - from_file().zero_copy(true) - zero-copy via mmap"); println!(" 3. static_copy - from_bytes() with Vec copy - copies from static"); println!(" 4. static_zero_copy - from_static() - zero-copy from static"); println!(); println!("Available backends:"); println!(" - NdArray (CPU)"); #[cfg(feature = "wgpu")] println!(" - WGPU (GPU)"); #[cfg(feature = "cuda")] println!(" - CUDA (NVIDIA GPU)"); #[cfg(feature = "tch")] println!(" - LibTorch"); #[cfg(feature = "metal")] println!(" - Metal (Apple GPU)"); println!(); // Pre-initialize static bytes before benchmarks let _ = get_static_model_bytes(); divan::main(); } // Macro to generate benchmarks for each backend macro_rules! bench_backend { ($backend:ty, $mod_name:ident, $backend_name:literal) => { #[divan::bench_group(name = $backend_name, sample_count = 10)] mod $mod_name { use super::*; type TestBackend = $backend; type TestDevice = ::Device; /// File-based loading with copy mode (default) #[divan::bench] fn file_copy(bencher: Bencher) { let bp_path = get_burnpack_path(); let file_size = fs::metadata(&bp_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false); model.load_from(&mut store).expect("Failed to load"); }); } /// File-based loading with zero-copy mode (mmap + bytes::Bytes) #[divan::bench] fn file_zero_copy(bencher: Bencher) { let bp_path = get_burnpack_path(); let file_size = fs::metadata(&bp_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true); model.load_from(&mut store).expect("Failed to load"); }); } /// Static bytes with copy mode (simulating old behavior) #[divan::bench] fn static_copy(bencher: Bencher) { let static_bytes = get_static_model_bytes(); let file_size = static_bytes.len() as u64; bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); // Simulate old behavior: copy static bytes to Vec, then load let bytes = Bytes::from_bytes_vec(static_bytes.to_vec()); let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false); model.load_from(&mut store).expect("Failed to load"); }); } /// Static bytes with zero-copy mode (new from_static) #[divan::bench] fn static_zero_copy(bencher: Bencher) { let static_bytes = get_static_model_bytes(); let file_size = static_bytes.len() as u64; bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); // Zero-copy: use from_static which keeps data in .rodata let mut store = BurnpackStore::from_static(static_bytes); model.load_from(&mut store).expect("Failed to load"); }); } /// In-memory shared bytes with zero-copy #[divan::bench] fn memory_shared_zero_copy(bencher: Bencher) { let static_bytes = get_static_model_bytes(); let file_size = static_bytes.len() as u64; // Pre-create shared bytes outside the benchmark loop let shared = bytes::Bytes::from_static(static_bytes); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let device: TestDevice = Default::default(); let mut model = LargeModel::::new(&device); // Create Bytes from shared (cheap clone of Arc) let bytes = Bytes::from_shared(shared.clone(), AllocationProperty::Other); let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(true); model.load_from(&mut store).expect("Failed to load"); }); } } }; } // ============================================================================= // Zero-copy verification (proves operations use static region data) // ============================================================================= /// Verify that zero-copy loading actually uses data from the static region. /// This runs once at startup to prove correctness before benchmarking. #[divan::bench_group(name = "Zero-Copy Verification", sample_count = 1)] mod verification { use super::*; use burn_ndarray::NdArray; type B = NdArray; /// Verify zero-copy: tensor storage is borrowed (not owned) #[divan::bench] fn verify_storage_is_borrowed() { let static_bytes = get_static_model_bytes(); // Load model with zero-copy from static bytes let device = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_static(static_bytes); model.load_from(&mut store).expect("Failed to load"); // Get the first layer's weight tensor and verify it uses borrowed storage let weight = model.layers[0].weight.val(); // .into_primitive() returns TensorPrimitive, .tensor() extracts B::FloatTensorPrimitive let ndarray_tensor = weight.into_primitive().tensor(); // Verify the storage is borrowed (zero-copy from static region) assert!( ndarray_tensor.is_borrowed(), "ZERO-COPY FAILURE: Tensor storage is NOT borrowed. \ Data was copied instead of being zero-copy!" ); println!("✅ Verified: Tensor storage is borrowed (zero-copy from static region)"); } /// Verify ALL layers use borrowed (zero-copy) storage. /// This is the key proof that loaded weights point to static memory. #[divan::bench] fn verify_all_layers_borrowed() { let static_bytes = get_static_model_bytes(); // Load model with zero-copy let device = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_static(static_bytes); model.load_from(&mut store).expect("Failed to load"); // Check ALL layers have borrowed storage let mut total_elements = 0usize; for (i, layer) in model.layers.iter().enumerate() { let weight = layer.weight.val(); total_elements += weight.shape().num_elements(); assert!( weight.into_primitive().tensor().is_borrowed(), "Layer {} weight should be borrowed (zero-copy)", i ); } let total_mb = (total_elements * 4) as f64 / 1_048_576.0; println!( "✅ Verified: All {} layers use borrowed storage", model.layers.len() ); println!( " - Model size: {:.2} MB - all pointing to static region", total_mb ); } /// Verify data is readable and correct using sum().into_scalar(). /// Note: sum() triggers COW copy, so this shows ops work correctly on zero-copy data. #[divan::bench] fn verify_ops_produce_correct_results() { let static_bytes = get_static_model_bytes(); let device = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_static(static_bytes); model.load_from(&mut store).expect("Failed to load"); // Compute sum of first layer weight - proves data is valid let weight = model.layers[0].weight.val(); let sum: f32 = weight.sum().into_scalar(); assert!(sum.is_finite(), "Sum should be finite"); println!("✅ Verified: Operations on zero-copy data produce valid results"); println!(" - First layer sum: {:.4}", sum); } /// Verify operations produce correct results on zero-copy data #[divan::bench] fn verify_operations_on_static_data() { let static_bytes = get_static_model_bytes(); // Load model with zero-copy let device = Default::default(); let mut model = LargeModel::::new(&device); let mut store = BurnpackStore::from_static(static_bytes); model.load_from(&mut store).expect("Failed to load"); // Perform operations on the loaded weights let weight = model.layers[0].weight.val(); let shape = weight.shape(); // Test 1: Sum should be finite (not NaN or Inf) let sum: f32 = weight.clone().sum().to_data().to_vec().unwrap()[0]; assert!( sum.is_finite(), "Operation failed: sum is not finite ({})", sum ); // Test 2: Matrix multiply with itself transposed (W @ W.T) let transposed = weight.clone().transpose(); let matmul_result = weight.clone().matmul(transposed); let matmul_sum: f32 = matmul_result.sum().to_data().to_vec().unwrap()[0]; assert!( matmul_sum.is_finite(), "Matmul failed: result sum is not finite ({})", matmul_sum ); // Test 3: Element-wise operations let doubled = weight.clone() * 2.0; let doubled_sum: f32 = doubled.sum().to_data().to_vec().unwrap()[0]; assert!( (doubled_sum - sum * 2.0).abs() < 1e-3, "Element-wise op failed: doubled_sum ({}) != sum*2 ({})", doubled_sum, sum * 2.0 ); println!("✅ Verified: Operations on zero-copy data produce correct results"); println!(" - Weight shape: {:?}", shape.as_slice()); println!(" - Sum: {:.4}", sum); println!(" - Matmul result sum: {:.4}", matmul_sum); } /// Compare zero-copy vs copy: verify both produce identical results #[divan::bench] fn verify_copy_vs_zero_copy_equality() { let static_bytes = get_static_model_bytes(); let device: ::Device = Default::default(); // Load with zero-copy let mut model_zc = LargeModel::::new(&device); let mut store_zc = BurnpackStore::from_static(static_bytes); model_zc .load_from(&mut store_zc) .expect("Failed to load zero-copy"); // Load with copy (simulate old behavior) let mut model_copy = LargeModel::::new(&device); let bytes = Bytes::from_bytes_vec(static_bytes.to_vec()); let mut store_copy = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false); model_copy .load_from(&mut store_copy) .expect("Failed to load copy"); // Compare weights from both models for (i, (layer_zc, layer_copy)) in model_zc .layers .iter() .zip(model_copy.layers.iter()) .enumerate() { let weight_zc = layer_zc.weight.val(); let weight_copy = layer_copy.weight.val(); // Check shapes match assert_eq!( weight_zc.shape(), weight_copy.shape(), "Layer {} weight shapes don't match", i ); // Check values match (using sum as a proxy) let sum_zc: f32 = weight_zc.clone().sum().to_data().to_vec().unwrap()[0]; let sum_copy: f32 = weight_copy.clone().sum().to_data().to_vec().unwrap()[0]; assert!( (sum_zc - sum_copy).abs() < 1e-6, "Layer {} weight sums don't match: zero-copy={}, copy={}", i, sum_zc, sum_copy ); } println!( "✅ Verified: Zero-copy and copy loading produce identical results for all {} layers", model_zc.layers.len() ); } } // ============================================================================= // Store-only benchmarks (no backend allocation overhead) // These show the TRUE zero-copy benefit at the store level // ============================================================================= #[divan::bench_group(name = "Store Only (no backend)", sample_count = 10)] mod store_only { use super::*; /// File-based store with copy mode - measures store overhead only #[divan::bench] fn file_copy(bencher: Bencher) { let bp_path = get_burnpack_path(); let file_size = fs::metadata(&bp_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false); // Just iterate through all tensor snapshots, calling to_data() on each // This forces the store to read and materialize all tensor data let snapshots = store.get_all_snapshots().expect("Failed to get snapshots"); for snapshot in snapshots.values() { let _data = snapshot.to_data().expect("Failed to get tensor data"); } }); } /// File-based store with zero-copy mode - measures store overhead only #[divan::bench] fn file_zero_copy(bencher: Bencher) { let bp_path = get_burnpack_path(); let file_size = fs::metadata(&bp_path).unwrap().len(); bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true); let snapshots = store.get_all_snapshots().expect("Failed to get snapshots"); for snapshot in snapshots.values() { let _data = snapshot.to_data().expect("Failed to get tensor data"); } }); } /// Static bytes with copy mode - measures store overhead only #[divan::bench] fn static_copy(bencher: Bencher) { let static_bytes = get_static_model_bytes(); let file_size = static_bytes.len() as u64; bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { // Simulate old behavior: copy static bytes to Vec let bytes = Bytes::from_bytes_vec(static_bytes.to_vec()); let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false); let snapshots = store.get_all_snapshots().expect("Failed to get snapshots"); for snapshot in snapshots.values() { let _data = snapshot.to_data().expect("Failed to get tensor data"); } }); } /// Static bytes with zero-copy mode - measures store overhead only #[divan::bench] fn static_zero_copy(bencher: Bencher) { let static_bytes = get_static_model_bytes(); let file_size = static_bytes.len() as u64; bencher .counter(divan::counter::BytesCount::new(file_size)) .bench(|| { let mut store = BurnpackStore::from_static(static_bytes); let snapshots = store.get_all_snapshots().expect("Failed to get snapshots"); for snapshot in snapshots.values() { let _data = snapshot.to_data().expect("Failed to get tensor data"); } }); } } // ============================================================================= // Full model loading benchmarks (includes backend allocation) // ============================================================================= // Generate benchmarks for each backend bench_backend!(NdArrayBackend, ndarray_backend, "NdArray Backend (CPU)"); #[cfg(feature = "wgpu")] bench_backend!(WgpuBackend, wgpu_backend, "WGPU Backend (GPU)"); #[cfg(feature = "cuda")] bench_backend!(CudaBackend, cuda_backend, "CUDA Backend (NVIDIA GPU)"); #[cfg(feature = "tch")] bench_backend!(TchBackend, tch_backend, "LibTorch Backend"); #[cfg(feature = "metal")] bench_backend!(MetalBackend, metal_backend, "Metal Backend (Apple GPU)"); ================================================ FILE: crates/burn-store/examples/burnpack_inspect.rs ================================================ //! Example: Generate a Burnpack file for inspection //! //! This example creates a simple Burnpack file that you can examine to understand the format. //! //! Usage: //! cargo run --example burnpack-inspect [output_path] //! //! Example: //! cargo run --example burnpack-inspect sample.bpk //! cargo run --example burnpack-inspect /tmp/test.bpk //! //! After generating the file, examine it with: //! hexdump -C sample.bpk | head -100 //! xxd sample.bpk | head -100 //! hexyl sample.bpk use burn_core as burn; use burn_core::module::Module; use burn_ndarray::NdArray; use burn_nn::{Linear, LinearConfig}; use burn_store::{BurnpackStore, ModuleSnapshot}; use burn_tensor::backend::Backend; use std::env; // Simple model with a few layers #[derive(Module, Debug)] struct SampleModel { linear1: Linear, linear2: Linear, linear3: Linear, } impl SampleModel { fn new(device: &B::Device) -> Self { Self { linear1: LinearConfig::new(128, 64).init(device), linear2: LinearConfig::new(64, 32).init(device), linear3: LinearConfig::new(32, 10).init(device), } } } fn main() { type Backend = NdArray; // Get output path from command line or use default let output_path = env::args() .nth(1) .unwrap_or_else(|| "sample.bpk".to_string()); println!("Creating sample Burnpack file: {}", output_path); println!(); // Create a simple model let device = Default::default(); let model = SampleModel::::new(&device); // Save to Burnpack format with metadata let mut store = BurnpackStore::from_file(&output_path) .overwrite(true) .metadata("format", "burnpack") .metadata("description", "Sample file for examining Burnpack format") .metadata("version", env!("CARGO_PKG_VERSION")) .metadata("author", "Burn Example"); model.save_into(&mut store).expect("Failed to save model"); println!("✅ Successfully created: {}", output_path); println!(); println!("📋 File Structure:"); println!(" ┌─────────────────────────────────────┐"); println!(" │ Header (10 bytes) │"); println!(" ├─────────────────────────────────────┤"); println!(" │ - Magic: 0x4E525542 (BURN in LE) │"); println!(" │ - Version: 0x0001 (2 bytes) │"); println!(" │ - Metadata size: (4 bytes, u32 LE) │"); println!(" ├─────────────────────────────────────┤"); println!(" │ Metadata (CBOR format) │"); println!(" ├─────────────────────────────────────┤"); println!(" │ - Tensor descriptors │"); println!(" │ * name, dtype, shape, offsets │"); println!(" │ - User metadata │"); println!(" ├─────────────────────────────────────┤"); println!(" │ Tensor Data (raw bytes, LE) │"); println!(" ├─────────────────────────────────────┤"); println!(" │ - linear1.weight [64, 128] │"); println!(" │ - linear1.bias [64] │"); println!(" │ - linear2.weight [32, 64] │"); println!(" │ - linear2.bias [32] │"); println!(" │ - linear3.weight [10, 32] │"); println!(" │ - linear3.bias [10] │"); println!(" └─────────────────────────────────────┘"); println!(); println!("📊 Model Contents:"); println!(" - linear1.weight: [64, 128] = 8,192 params → 32,768 bytes"); println!(" - linear1.bias: [64] = 64 params → 256 bytes"); println!(" - linear2.weight: [32, 64] = 2,048 params → 8,192 bytes"); println!(" - linear2.bias: [32] = 32 params → 128 bytes"); println!(" - linear3.weight: [10, 32] = 320 params → 1,280 bytes"); println!(" - linear3.bias: [10] = 10 params → 40 bytes"); println!(" ───────────────────────────────────────────────────────"); let total_params = 8192 + 64 + 2048 + 32 + 320 + 10; let total_bytes = total_params * 4; println!( " Total: {} parameters = {} KB", total_params, total_bytes / 1024 ); println!(); // Get actual file size if let Ok(metadata) = std::fs::metadata(&output_path) { let file_size = metadata.len(); println!( "📦 File size: {} bytes ({:.2} KB)", file_size, file_size as f64 / 1024.0 ); } println!(); println!("🔍 Inspection Commands:"); println!(); println!(" # View first 100 bytes in hex:"); println!(" hexdump -C {} | head -20", output_path); println!(); println!(" # View header only (10 bytes):"); println!(" head -c 10 {} | hexdump -C", output_path); println!(); println!(" # View with prettier hex viewer (if installed):"); println!(" hexyl {} | head -50", output_path); println!(); println!(" # View in binary format:"); println!(" xxd -b {} | head -20", output_path); println!(); println!(" # Extract and examine header:"); println!(" # Magic (bytes 0-3): Should be 42 55 52 4E (BURN)"); println!(" # Version (bytes 4-5): Should be 01 00"); println!(" # Metadata size (bytes 6-9): u32 little-endian"); println!(); println!(" # Load back the model:"); println!( " # let mut store = BurnpackStore::from_file(\"{}\");", output_path ); println!(" # model.load_from(&mut store)?;"); } ================================================ FILE: crates/burn-store/examples/half_precision.rs ================================================ //! Example: Save and load a model with half-precision (F32 <-> F16) //! //! Demonstrates using HalfPrecisionAdapter to automatically convert between //! F32 and F16 during saving/loading. The same adapter instance handles both //! directions. //! //! Usage: //! cargo run -p burn-store --example half_precision use burn_core as burn; use burn_core::module::Module; use burn_ndarray::NdArray; use burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig}; use burn_store::{BurnpackStore, HalfPrecisionAdapter, ModuleSnapshot}; use burn_tensor::backend::Backend; // A model with mixed layer types to show selective conversion #[derive(Module, Debug)] struct DemoModel { linear1: Linear, norm: LayerNorm, linear2: Linear, } impl DemoModel { fn new(device: &B::Device) -> Self { Self { linear1: LinearConfig::new(128, 64).init(device), norm: LayerNormConfig::new(64).init(device), linear2: LinearConfig::new(64, 10).init(device), } } } fn main() { type B = NdArray; let device = Default::default(); let model = DemoModel::::new(&device); // 1) Save at full F32 precision (baseline) let dir = tempfile::tempdir().expect("Failed to create temp dir"); let path_f32 = dir.path().join("model_f32"); let path_f16 = dir.path().join("model_f16"); let path_mixed = dir.path().join("model_mixed"); let mut store = BurnpackStore::from_file(path_f32.to_str().unwrap()).overwrite(true); model.save_into(&mut store).expect("Failed to save F32"); let size_f32 = std::fs::metadata(format!("{}.bpk", path_f32.display())) .map(|m| m.len()) .unwrap_or(0); // 2) Save with default half-precision (all default modules get F16) let adapter = HalfPrecisionAdapter::new(); let mut store = BurnpackStore::from_file(path_f16.to_str().unwrap()) .overwrite(true) .with_to_adapter(adapter.clone()); model.save_into(&mut store).expect("Failed to save F16"); let size_f16 = std::fs::metadata(format!("{}.bpk", path_f16.display())) .map(|m| m.len()) .unwrap_or(0); // 3) Save with without_module: keep LayerNorm at F32 let adapter_no_norm = HalfPrecisionAdapter::new().without_module("LayerNorm"); let mut store = BurnpackStore::from_file(path_mixed.to_str().unwrap()) .overwrite(true) .with_to_adapter(adapter_no_norm); model.save_into(&mut store).expect("Failed to save mixed"); let size_mixed = std::fs::metadata(format!("{}.bpk", path_mixed.display())) .map(|m| m.len()) .unwrap_or(0); println!("F32 (full precision): {} bytes", size_f32); println!("F16 (default modules): {} bytes", size_f16); println!("Mixed (norm stays F32): {} bytes", size_mixed); println!( "F16 savings: {:.1}%", (1.0 - size_f16 as f64 / size_f32 as f64) * 100.0 ); // 4) Round-trip: load the F16 file back to F32 with the same adapter let mut load_store = BurnpackStore::from_file(path_f16.to_str().unwrap()).with_from_adapter(adapter); let mut model2 = DemoModel::::new(&device); let result = model2.load_from(&mut load_store).expect("Failed to load"); println!( "\nRound-trip loaded {} tensors successfully", result.applied.len() ); } ================================================ FILE: crates/burn-store/pytorch-tests/Cargo.toml ================================================ [package] name = "pytorch-tests" version.workspace = true edition.workspace = true license.workspace = true [dev-dependencies] burn = { path = "../../burn" } burn-ndarray = { path = "../../burn-ndarray" } burn-autodiff = { path = "../../burn-autodiff" } burn-store = { path = "../", features = ["std", "pytorch"] } serde = { workspace = true } float-cmp = { workspace = true } ================================================ FILE: crates/burn-store/pytorch-tests/src/lib.rs ================================================ ================================================ FILE: crates/burn-store/pytorch-tests/tests/backend.rs ================================================ pub type TestBackend = burn_ndarray::NdArray; ================================================ FILE: crates/burn-store/pytorch-tests/tests/batch_norm/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.norm1 = nn.BatchNorm2d(5) def forward(self, x): x = self.norm1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) # Condition batch norm (each forward will affect the running stats) x1 = torch.ones(1, 5, 2, 2) - 0.5 _ = model(x1) model.eval() # Set to eval mode to freeze running stats # Save the model after the first forward torch.save(model.state_dict(), "batch_norm2d.pt") x2 = torch.ones(1, 5, 2, 2) - 0.3 print("Input shape: {}", x2.shape) output = model(x2) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/batch_norm/mod.rs ================================================ use burn::{ module::Module, nn::{BatchNorm, BatchNormConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { norm1: BatchNorm, } impl Net { pub fn new(device: &B::Device) -> Self { Self { norm1: BatchNormConfig::new(5).init(device), // Python model uses BatchNorm2d(5) } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.norm1.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; #[test] fn batch_norm2d() { let device = Default::default(); let mut model = Net::::new(&device); let mut store = PytorchStore::from_file("tests/batch_norm/batch_norm2d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([1, 5, 2, 2], &device) - 0.3; let output = model.forward(input); let expected = Tensor::::from_data( [[ [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/boolean/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() buffer = torch.tensor([True, False, True]) self.register_buffer("buffer", buffer, persistent=True) def forward(self, x): x = self.buffer return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "boolean.pt") input = torch.ones(3, 3) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/boolean/mod.rs ================================================ use burn::{ module::{Module, Param, ParamId}, tensor::{Bool, Tensor, TensorData, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { buffer: Param>, } impl Net { /// Create a new model with placeholder values. pub fn init(device: &B::Device) -> Self { Self { buffer: Param::initialized( ParamId::new(), Tensor::from_bool(TensorData::from([false, false, false]), device), ), } } /// Forward pass of the model. pub fn forward(&self, _x: Tensor) -> Tensor { self.buffer.val() } } #[cfg(test)] mod tests { use burn::tensor::TensorData; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; use crate::backend::TestBackend; #[test] fn boolean() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/boolean/boolean.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([3, 3], &device); let output = model.forward(input); let expected = Tensor::::from_bool( TensorData::from([true, false, true]), &device, ); assert_eq!(output.to_data(), expected.to_data()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/buffer/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() buffer = torch.ones(3, 3) self.register_buffer("buffer", buffer, persistent=True) def forward(self, x): x = self.buffer + x return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "buffer.pt") input = torch.ones(3, 3) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/buffer/mod.rs ================================================ use burn::{ module::{Module, Param}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { buffer: Param>, } impl Net { /// Create a new model with placeholder values. pub fn init(device: &B::Device) -> Self { Self { buffer: Param::from_tensor(Tensor::zeros([3, 3], device)), } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.buffer.val() + x } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; #[test] fn buffer() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/buffer/buffer.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([3, 3], &device); let output = model.forward(input); let expected = Tensor::::ones([3, 3], &device) * 2.0; output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/complex_nested/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class ConvBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size): super(ConvBlock, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size) self.norm = nn.BatchNorm2d(out_channels) def forward(self, x): x = self.conv(x) x = self.norm(x) return x class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv_blocks = nn.Sequential( ConvBlock(2, 4, (3, 2)), ConvBlock(4, 6, (3, 2)), ) self.norm1 = nn.BatchNorm2d(6) self.fc1 = nn.Linear(120, 12) self.fc2 = nn.Linear(12, 10) def forward(self, x): x = self.conv_blocks(x) x = self.norm1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.fc2(x) x = F.log_softmax(x, dim=1) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(2) model = Net().to(torch.device("cpu")) # Condition the model (batch norm requires a forward pass to compute the mean and variance) x1 = torch.ones(1, 2, 9, 6) - 0.1 x2 = torch.ones(1, 2, 9, 6) - 0.3 output = model(x1) output = model(x2) model.eval() # set to eval mode torch.save(model.state_dict(), "complex_nested.pt") # feed test data x = torch.ones(1, 2, 9, 6) - 0.5 output = model(x) print("Input shape: {}", x.shape) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/complex_nested/mod.rs ================================================ use burn::tensor::Tolerance; use burn::tensor::ops::FloatElem; use burn::{ module::Module, nn::{ BatchNorm, BatchNormConfig, Linear, LinearConfig, conv::{Conv2d, Conv2dConfig}, }, tensor::{ Tensor, activation::{log_softmax, relu}, backend::Backend, }, }; use burn_autodiff::Autodiff; use burn_store::{ModuleSnapshot, PytorchStore}; #[derive(Module, Debug)] pub struct ConvBlock { conv: Conv2d, norm: BatchNorm, } #[derive(Module, Debug)] pub struct Net { conv_blocks: Vec>, norm1: BatchNorm, fc1: Linear, fc2: Linear, } impl Net { pub fn init(device: &B::Device) -> Self { let conv_blocks = vec![ ConvBlock { conv: Conv2dConfig::new([2, 4], [3, 2]).init(device), norm: BatchNormConfig::new(4).init(device), // matches conv output channels }, ConvBlock { conv: Conv2dConfig::new([4, 6], [3, 2]).init(device), norm: BatchNormConfig::new(6).init(device), // matches conv output channels }, ]; let norm1 = BatchNormConfig::new(6).init(device); let fc1 = LinearConfig::new(120, 12).init(device); let fc2 = LinearConfig::new(12, 10).init(device); Self { conv_blocks, norm1, fc1, fc2, } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv_blocks[0].forward(x); let x = self.conv_blocks[1].forward(x); let x = self.norm1.forward(x); let x = x.reshape([0, -1]); let x = self.fc1.forward(x); let x = relu(x); let x = self.fc2.forward(x); log_softmax(x, 1) } } impl ConvBlock { pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv.forward(x); self.norm.forward(x) } } /// Partial model to test loading of partial records. #[derive(Module, Debug)] pub struct PartialNet { conv1: ConvBlock, } impl PartialNet { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = ConvBlock { conv: Conv2dConfig::new([2, 4], [3, 2]).init(device), norm: BatchNormConfig::new(4).init(device), // matches conv output channels }; Self { conv1 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.conv1.forward(x) } } /// Model with extra fields to test loading of records (e.g. from a different model). #[derive(Module, Debug)] pub struct PartialWithExtraNet { conv1: ConvBlock, extra_field: bool, // This field is not present in the pytorch model } impl PartialWithExtraNet { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = ConvBlock { conv: Conv2dConfig::new([2, 4], [3, 2]).init(device), norm: BatchNormConfig::new(4).init(device), // matches conv output channels }; Self { conv1, extra_field: true, } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.conv1.forward(x) } } type TestBackend = burn_ndarray::NdArray; fn model_test(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::ones([1, 2, 9, 6], &device) - 0.5; let output = model.forward(input); let expected = Tensor::::from_data( [[ -2.306_613, -2.058_945_4, -2.298_372_7, -2.358_294, -2.296_395_5, -2.416_090_5, -2.107_669, -2.428_420_8, -2.526_469, -2.319_918_6, ]], &device, ); output.to_data().assert_approx_eq::>( &expected.to_data(), Tolerance::absolute(precision), ); } #[test] fn full_record() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); model_test(model, 1e-8); } #[test] fn full_record_autodiff() { let device = Default::default(); let mut model = Net::>::init(&device); let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); } #[test] fn half_record() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); model_test(model, 1e-4); } #[test] fn partial_model_loading() { let device = Default::default(); let mut model = PartialNet::::init(&device); // Load the full model but rename "conv_blocks.0.*" to "conv1.*" let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt") .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1") .allow_partial(true); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([1, 2, 9, 6], &device) - 0.5; let output = model.forward(input); // get the sum of all elements in the output tensor for quick check let sum = output.sum(); assert!((sum.into_scalar() - 4.871538).abs() < 0.000002); } #[test] fn extra_field_model_loading() { let device = Default::default(); let mut model = PartialWithExtraNet::::init(&device); // Load the full model but rename "conv_blocks.0.*" to "conv1.*" let mut store = PytorchStore::from_file("tests/complex_nested/complex_nested.pt") .with_key_remapping("conv_blocks\\.0\\.(.*)", "conv1.$1") .allow_partial(true); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([1, 2, 9, 6], &device) - 0.5; let output = model.forward(input); // get the sum of all elements in the output tensor for quick check let sum = output.sum(); assert!((sum.into_scalar() - 4.871538).abs() < 0.000002); assert!(model.extra_field); } ================================================ FILE: crates/burn-store/pytorch-tests/tests/config/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.fc1 = nn.Linear(2, 3) self.fc2 = nn.Linear(3, 4, bias=False) def forward(self, x): x = self.fc1(x) x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2 x = self.fc2(x) return x CONFIG = { "n_head": 2, "n_layer": 3, "d_model": 512, "some_float": 0.1, "some_int": 1, "some_bool": True, "some_str": "hello", "some_list_int": [1, 2, 3], "some_list_str": ["hello", "world"], "some_list_float": [0.1, 0.2, 0.3], "some_dict": { "some_key": "some_value" } } class ModelWithBias(nn.Module): def __init__(self): super(ModelWithBias, self).__init__() self.fc1 = nn.Linear(2, 3) def forward(self, x): x = self.fc1(x) return x def main(): model = Model().to(torch.device("cpu")) weights_with_config = { "my_model": model.state_dict(), "my_config": CONFIG } torch.save(weights_with_config, "weights_with_config.pt") if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/config/mod.rs ================================================ #![allow(clippy::too_many_arguments)] // To mute derive Config warning use std::collections::HashMap; use burn::config::Config; #[allow(clippy::too_many_arguments)] #[derive(Debug, PartialEq, Config)] struct NetConfig { n_head: usize, n_layer: usize, d_model: usize, some_float: f64, some_int: i32, some_bool: bool, some_str: String, some_list_int: Vec, some_list_str: Vec, some_list_float: Vec, some_dict: HashMap, } #[cfg(test)] mod tests { use burn_store::pytorch::PytorchReader; use super::*; #[test] fn test_net_config() { let config_expected = NetConfig { n_head: 2, n_layer: 3, d_model: 512, some_float: 0.1, some_int: 1, some_bool: true, some_str: "hello".to_string(), some_list_int: vec![1, 2, 3], some_list_str: vec!["hello".to_string(), "world".to_string()], some_list_float: vec![0.1, 0.2, 0.3], some_dict: { let mut map = HashMap::new(); map.insert("some_key".to_string(), "some_value".to_string()); map }, }; let path = "tests/config/weights_with_config.pt"; let top_level_key = Some("my_config"); let config: NetConfig = PytorchReader::load_config(path, top_level_key).unwrap(); assert_eq!(config, config_expected); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv1d/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv1d(2, 2, 2) self.conv2 = nn.Conv1d(2, 2, 2, bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "conv1d.pt") input = torch.rand(1, 2, 6) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv1d/mod.rs ================================================ use burn::{ module::Module, nn::conv::{Conv1d, Conv1dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: Conv1d, conv2: Conv1d, } impl Net { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = Conv1dConfig::new(2, 2, 2).init(device); let conv2 = Conv1dConfig::new(2, 2, 2).with_bias(false).init(device); Self { conv1, conv2 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); self.conv2.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; fn conv1d(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [ 0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965, ], [ 0.33977574, 0.523_940_8, 0.798_063_9, 0.77176833, 0.01122457, 0.80996025, ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [0.02987457, 0.03134188, 0.04234261, -0.02437721], [-0.03788019, -0.02972012, -0.00806090, -0.01981254], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn conv1d_full_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv1d(model, 1e-7); } #[test] fn conv1d_half_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv1d/conv1d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv1d(model, 1e-4); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv2d/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(2, 2, (2,2)) self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "conv2d.pt") input = torch.rand(1, 2, 5, 5) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv2d/mod.rs ================================================ use burn::{ module::Module, nn::conv::{Conv2d, Conv2dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: Conv2d, conv2: Conv2d, } impl Net { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device); let conv2 = Conv2dConfig::new([2, 2], [2, 2]) .with_bias(false) .init(device); Self { conv1, conv2 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); self.conv2.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn conv2d(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [ [ 0.024_595_8, 0.25883394, 0.93905586, 0.416_715_5, 0.713_979_7, ], [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8], [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4], [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136], [ 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845, 0.804_481_1, ], ], [ [ 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, 0.943_447_5, ], [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086], [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497], [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397], [ 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, 0.414_796_4, ], ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [-0.02502128, 0.00250649, 0.04841233], [0.04589614, -0.00296854, 0.01991477], [0.02920526, 0.059_497_3, 0.04326791], ], [ [-0.04825336, 0.080_190_9, -0.02375088], [0.02885434, 0.09638263, -0.07460806], [0.02004079, 0.06244051, 0.035_887_1], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn conv2d_full_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv2d(model, 1e-7); } #[test] fn conv2d_half_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv2d/conv2d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv2d(model, 1e-4); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv_transpose1d/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.ConvTranspose1d(2, 2, 2) self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "conv_transpose1d.pt") input = torch.rand(1, 2, 2) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv_transpose1d/mod.rs ================================================ use burn::{ module::Module, nn::conv::{ConvTranspose1d, ConvTranspose1dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: ConvTranspose1d, conv2: ConvTranspose1d, } impl Net { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init(device); let conv2 = ConvTranspose1dConfig::new([2, 2], 2) .with_bias(false) .init(device); Self { conv1, conv2 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); self.conv2.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn conv_transpose1d(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [0.02935525, 0.01119324, -0.01356167, -0.00682688], [0.01644749, -0.01429807, 0.00083987, 0.00279229], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn conv_transpose1d_full() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv_transpose1d(model, 1e-8); } #[test] fn conv_transpose1d_half() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv_transpose1d/conv_transpose1d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv_transpose1d(model, 1e-4); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv_transpose2d/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2)) self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "conv_transpose2d.pt") input = torch.rand(1, 2, 2, 2) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/conv_transpose2d/mod.rs ================================================ use burn::{ module::Module, nn::conv::{ConvTranspose2d, ConvTranspose2dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: ConvTranspose2d, conv2: ConvTranspose2d, } impl Net { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device); let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2]) .with_bias(false) .init(device); Self { conv1, conv2 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); self.conv2.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn conv_transpose2d(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [0.04547675, 0.01879685, -0.01636661, 0.00310803], [0.02090115, 0.01192738, -0.048_240_2, 0.02252235], [0.03249975, -0.00460748, 0.05003899, 0.04029131], [0.02185687, -0.10226749, -0.06508022, -0.01267705], ], [ [0.00277598, -0.00513832, -0.059_048_3, 0.00567626], [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269], [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108], [-0.03174751, 0.02963913, -0.02703723, -0.01860938], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn conv_transpose2d_full() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv_transpose2d(model, 1e-7); } #[test] fn conv_transpose2d_half() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/conv_transpose2d/conv_transpose2d.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); conv_transpose2d(model, 1e-4); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/embedding/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.embed = nn.Embedding(10, 3) def forward(self, x): x = self.embed(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "embedding.pt") input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/embedding/mod.rs ================================================ use burn::{ module::Module, nn::{Embedding, EmbeddingConfig}, tensor::{Int, Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { embed: Embedding, } impl Net { /// Create a new model. pub fn init(device: &B::Device) -> Self { let embed = EmbeddingConfig::new(10, 3).init(device); Self { embed } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.embed.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn embedding(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device); let output = model.forward(input); let expected = Tensor::::from_data( [ [ [-1.609_484_9, -0.10016718, -0.609_188_9], [-0.97977227, -1.609_096_3, -0.712_144_6], [-0.22227049, 1.687_113_4, -0.32062083], [-0.29934573, 1.879_345_7, -0.07213178], ], [ [-0.22227049, 1.687_113_4, -0.32062083], [0.303_722, -0.777_314_3, -0.25145486], [-0.97977227, -1.609_096_3, -0.712_144_6], [-0.02878714, 2.357_111, -1.037_338_7], ], ], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn embedding_full_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/embedding/embedding.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); embedding(model, 1e-3); } #[test] fn embedding_half_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/embedding/embedding.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); embedding(model, 1e-3); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/enum_module/export_weights.py ================================================ #!/usr/bin/env python3 import torch from torch import nn, Tensor class DwsConv(nn.Module): """Depthwise separable convolution.""" def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None: super().__init__() # Depthwise conv self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels) # Pointwise conv self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1) def forward(self, x: Tensor) -> Tensor: x = self.dconv(x) return self.pconv(x) class Model(nn.Module): def __init__(self, depthwise: bool = False) -> None: super().__init__() self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3) def forward(self, x: Tensor) -> Tensor: return self.conv(x) def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "enum_depthwise_false.pt") input = torch.rand(1, 2, 5, 5) print("Depthwise is False") print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) print("Depthwise is True") model = Model(depthwise=True).to(torch.device("cpu")) torch.save(model.state_dict(), "enum_depthwise_true.pt") print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/enum_module/mod.rs ================================================ use burn::{ module::Module, nn::conv::{Conv2d, Conv2dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] #[allow(clippy::large_enum_variant)] pub enum Conv { DwsConv(DwsConv), Conv(Conv2d), } #[derive(Module, Debug)] pub struct DwsConv { dconv: Conv2d, pconv: Conv2d, } #[derive(Module, Debug)] pub struct Net { conv: Conv, } impl Net { /// Create a new model with DwsConv variant. pub fn init_dws_conv(device: &B::Device) -> Self { let dconv = Conv2dConfig::new([2, 2], [3, 3]) .with_groups(2) .init(device); let pconv = Conv2dConfig::new([2, 2], [1, 1]) .with_groups(1) .init(device); Net { conv: Conv::DwsConv(DwsConv { dconv, pconv }), } } /// Create a new model with Conv variant. pub fn init_conv(device: &B::Device) -> Self { let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]); Net { conv: Conv::Conv(conv2d_config.init(device)), } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { match &self.conv { Conv::DwsConv(dws_conv) => { let x = dws_conv.dconv.forward(x); dws_conv.pconv.forward(x) } Conv::Conv(conv) => conv.forward(x), } } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; #[test] fn depthwise_false() { let device = Default::default(); let mut model = Net::::init_conv(&device); let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_false.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [ [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], [ 0.804_481_1, 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, ], ], [ [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], [ 0.03694397, 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, ], [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [0.35449377, -0.02832414, 0.490_976_1], [0.29709217, 0.332_586_3, 0.30594018], [0.18101373, 0.30932188, 0.30558896], ], [ [-0.17683622, -0.13244139, -0.05608707], [0.23467252, -0.07038684, 0.255_044_1], [-0.241_931_3, -0.20476191, -0.14468731], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } #[test] fn depthwise_true() { let device = Default::default(); let mut model = Net::::init_dws_conv(&device); let mut store = PytorchStore::from_file("tests/enum_module/enum_depthwise_true.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [ [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], [ 0.804_481_1, 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, ], ], [ [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], [ 0.03694397, 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, ], [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [0.77874625, 0.859_017_6, 0.834_283_5], [0.773_056_4, 0.73817325, 0.78292674], [0.710_775_2, 0.747_187_2, 0.733_264_4], ], [ [-0.44891885, -0.49027523, -0.394_170_7], [-0.43836114, -0.33961445, -0.387_311_5], [-0.581_134_3, -0.34197026, -0.535_035_7], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/group_norm/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.norm1 = nn.GroupNorm(2, 6) def forward(self, x): x = self.norm1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "group_norm.pt") x2 = torch.rand(1, 6, 2, 2) print("Input shape: {}", x2.shape) print("Input: {}", x2) output = model(x2) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/group_norm/mod.rs ================================================ use burn::{ module::Module, nn::{GroupNorm, GroupNormConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { norm1: GroupNorm, } impl Net { /// Create a new model from the given record. pub fn init(device: &B::Device) -> Self { let norm1 = GroupNormConfig::new(2, 6).init(device); Self { norm1 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.norm1.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn group_norm(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]], [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]], [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]], [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]], [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]], [[0.31379688, 0.19801933], [0.41619217, 0.28432965]], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]], [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]], [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]], [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]], [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]], [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn group_norm_full() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/group_norm/group_norm.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); group_norm(model, 1e-3); } #[test] fn group_norm_half() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/group_norm/group_norm.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); group_norm(model, 1e-3); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/integer/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() buffer = torch.tensor([1, 2, 3]) self.register_buffer("buffer", buffer, persistent=True) def forward(self, x): x = self.buffer return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "integer.pt") input = torch.ones(3, 3) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/integer/mod.rs ================================================ use burn::{ module::{Module, Param, ParamId}, tensor::{Int, Tensor, TensorData, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { buffer: Param>, } impl Net { /// Create a new model with placeholder values. pub fn init(device: &B::Device) -> Self { Self { buffer: Param::initialized( ParamId::new(), Tensor::::from_data(TensorData::from([0, 0, 0]), device), ), } } /// Forward pass of the model. pub fn forward(&self, _x: Tensor) -> Tensor { self.buffer.val() } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::TensorData; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; fn integer(model: Net) { let device = Default::default(); let input = Tensor::::ones([3, 3], &device); let output = model.forward(input); let expected = Tensor::::from_data(TensorData::from([1, 2, 3]), &device); assert_eq!(output.to_data(), expected.to_data()); } #[test] fn integer_full_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/integer/integer.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); integer(model); } #[test] fn integer_half_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/integer/integer.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); integer(model); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/key_remap/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class ConvModule(nn.Module): def __init__(self): super(ConvModule, self).__init__() self.conv1 = nn.Conv2d(2, 2, (2,2)) self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = ConvModule() def forward(self, x): x = self.conv(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "key_remap.pt") input = torch.rand(1, 2, 5, 5) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/key_remap/mod.rs ================================================ use burn::{ module::Module, nn::conv::{Conv2d, Conv2dConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: Conv2d, conv2: Conv2d, } impl Net { /// Create a new model. pub fn init(device: &B::Device) -> Self { let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device); let conv2 = Conv2dConfig::new([2, 2], [2, 2]) .with_bias(false) .init(device); Self { conv1, conv2 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); self.conv2.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; #[test] fn key_remap() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/key_remap/key_remap.pt") .with_key_remapping("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [ [ 0.024_595_8, 0.25883394, 0.93905586, 0.416_715_5, 0.713_979_7, ], [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8], [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4], [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136], [ 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845, 0.804_481_1, ], ], [ [ 0.65517855, 0.17679012, 0.824_772_3, 0.803_550_9, 0.943_447_5, ], [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086], [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497], [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397], [ 0.751_675_7, 0.148_438_4, 0.12274551, 0.530_407_2, 0.414_796_4, ], ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [-0.02502128, 0.00250649, 0.04841233], [0.04589614, -0.00296854, 0.01991477], [0.02920526, 0.059_497_3, 0.04326791], ], [ [-0.04825336, 0.080_190_9, -0.02375088], [0.02885434, 0.09638263, -0.07460806], [0.02004079, 0.06244051, 0.035_887_1], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/key_remap_chained/export_weights.py ================================================ #!/usr/bin/env python3 import torch from torch import nn, Tensor class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.block = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), ) def forward(self, x: Tensor) -> Tensor: return self.block(x) class Model(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 6, 3, bias=False) self.bn = nn.BatchNorm2d(6) self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6)) def forward(self, x: Tensor) -> Tensor: x = self.conv(x) x = self.bn(x) x = self.layer(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(42) model = Model() input = torch.rand(1, 3, 4, 4) model(input) # condition batch norm model.eval() with torch.no_grad(): print(f"Input shape: {input.shape}") print("Input type: {}", input.dtype) print(f"Input: {input}") output = model(input) print(f"Output: {output}") print(f"Output Shape: {output.shape}") torch.save(model.state_dict(), "key_remap.pt") if __name__ == "__main__": main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/key_remap_chained/mod.rs ================================================ use std::marker::PhantomData; use burn::{ module::Module, nn::{ BatchNorm, BatchNormConfig, conv::{Conv2d, Conv2dConfig}, }, tensor::{Device, Tensor, backend::Backend}, }; /// Some module that implements a specific method so it can be used in a sequential block. pub trait ForwardModule { fn forward(&self, input: Tensor) -> Tensor; } /// Conv2d + BatchNorm block. #[derive(Module, Debug)] pub struct ConvBlock { conv: Conv2d, bn: BatchNorm, } impl ForwardModule for ConvBlock { fn forward(&self, input: Tensor) -> Tensor { let out = self.conv.forward(input); self.bn.forward(out) } } impl ConvBlock { pub fn new(in_channels: usize, out_channels: usize, device: &Device) -> Self { let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1]) .with_bias(false) .init(device); let bn = BatchNormConfig::new(out_channels).init(device); Self { conv, bn } } } /// Collection of sequential blocks. #[derive(Module, Debug)] pub struct ModuleBlock { blocks: Vec, _backend: PhantomData, } impl> ModuleBlock { pub fn forward(&self, input: Tensor) -> Tensor { let mut out = input; for block in &self.blocks { out = block.forward(out); } out } } impl ModuleBlock> { pub fn new(device: &Device) -> Self { let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)]; Self { blocks, _backend: PhantomData, } } } #[derive(Module, Debug)] pub struct Model { conv: Conv2d, bn: BatchNorm, layer: ModuleBlock, } impl Model> { pub fn new(device: &Device) -> Self { let conv = Conv2dConfig::new([3, 6], [3, 3]) .with_bias(false) .init(device); let bn = BatchNormConfig::new(6).init(device); let layer = ModuleBlock::new(device); Self { conv, bn, layer } } pub fn forward(&self, input: Tensor) -> Tensor { let out = self.conv.forward(input); let out = self.bn.forward(out); self.layer.forward(out) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; #[test] #[should_panic] fn key_remap_chained_missing_pattern() { // Loading record should fail due to missing pattern to map the layer.blocks let device = Default::default(); let mut model: Model = Model::new(&device); let mut store = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt") // Map *.block.0.* -> *.conv.* .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2") // Map *.block.1.* -> *.bn.* .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2"); model .load_from(&mut store) .expect("Should decode state successfully"); } #[test] fn key_remap_chained() { let device = Default::default(); let mut model: Model = Model::new(&device); let mut store = PytorchStore::from_file("tests/key_remap_chained/key_remap.pt") // Map *.block.0.* -> *.conv.* .with_key_remapping("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2") // Map *.block.1.* -> *.bn.* .with_key_remapping("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2") // Map layer.[i].* -> layer.blocks.[i].* .with_key_remapping("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [ [0.76193494, 0.626_546_1, 0.49510366, 0.11974698], [0.07161391, 0.03232569, 0.704_681, 0.254_516], [0.399_373_7, 0.21224737, 0.40888822, 0.14808255], [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6], ], [ [0.33959562, 0.13321638, 0.41178054, 0.257_626_3], [0.347_029_2, 0.02400219, 0.77974546, 0.15189773], [0.75130886, 0.726_892_1, 0.85721636, 0.11647397], [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734], ], [ [0.42948407, 0.49613327, 0.38488472, 0.08250773], [0.73995143, 0.00364107, 0.81039995, 0.87411255], [0.972_853_2, 0.38206023, 0.08917904, 0.61241513], [0.77621365, 0.00234562, 0.38650817, 0.20027226], ], ]], &device, ); let expected = Tensor::::from_data( [[ [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]], [[0.17582723, 0.11344293], [0.05444185, 0.13307181]], [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]], [[0.00230906, -0.02177845], [0.01129148, 0.00925517]], [[0.14751078, 0.14433631], [0.05498439, 0.29049855]], [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]], ]], &device, ); let output = model.forward(input); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/layer_norm/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.norm1 = nn.LayerNorm(2) def forward(self, x): x = self.norm1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "layer_norm.pt") x2 = torch.rand(1, 2, 2, 2) print("Input shape: {}", x2.shape) print("Input: {}", x2) output = model(x2) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/layer_norm/mod.rs ================================================ use burn::{ module::Module, nn::{LayerNorm, LayerNormConfig}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { norm1: LayerNorm, } impl Net { /// Create a new model. pub fn init(device: &B::Device) -> Self { let norm1 = LayerNormConfig::new(2).init(device); Self { norm1 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.norm1.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; fn layer_norm(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]], [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]], [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn layer_norm_full() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/layer_norm/layer_norm.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); layer_norm(model, 1e-3); } #[test] fn layer_norm_half() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/layer_norm/layer_norm.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); layer_norm(model, 1e-3); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/linear/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.fc1 = nn.Linear(2, 3) self.fc2 = nn.Linear(3, 4, bias=False) def forward(self, x): x = self.fc1(x) x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2 x = self.fc2(x) return x class ModelWithBias(nn.Module): def __init__(self): super(ModelWithBias, self).__init__() self.fc1 = nn.Linear(2, 3) def forward(self, x): x = self.fc1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) model_with_bias = ModelWithBias().to(torch.device("cpu")) torch.save(model.state_dict(), "linear.pt") torch.save(model_with_bias.state_dict(), "linear_with_bias.pt") input = torch.rand(1, 2, 2, 2) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) print("Model with bias") output = model_with_bias(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/linear/mod.rs ================================================ use burn::{ module::Module, nn::{Linear, LinearConfig, Relu}, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { fc1: Linear, fc2: Linear, relu: Relu, } impl Net { /// Create a new model. pub fn init(device: &B::Device) -> Self { let fc1 = LinearConfig::new(2, 3).init(device); let fc2 = LinearConfig::new(3, 4).with_bias(false).init(device); let relu = Relu; Self { fc1, fc2, relu } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.fc1.forward(x); let x = self.relu.forward(x); self.fc2.forward(x) } } #[derive(Module, Debug)] struct NetWithBias { fc1: Linear, } impl NetWithBias { /// Create a new model. pub fn init(device: &B::Device) -> Self { let fc1 = LinearConfig::new(2, 3).init(device); Self { fc1 } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.fc1.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; fn linear_test(model: Net, precision: f32) { let device = Default::default(); let input = Tensor::::from_data( [[ [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]], [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [0.09778349, -0.13756673, 0.04962806, 0.08856435], [0.03163241, -0.02848549, 0.01437942, 0.11905234], ], [ [0.07628226, -0.10757702, 0.03656857, 0.03824598], [0.05443089, -0.06904714, 0.02744314, 0.09997337], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(precision)); } #[test] fn linear_full_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/linear/linear.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); linear_test(model, 1e-7); } #[test] fn linear_half_precision() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/linear/linear.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); linear_test(model, 1e-4); } #[test] fn linear_with_bias() { let device = Default::default(); let mut model = NetWithBias::::init(&device); let mut store = PytorchStore::from_file("tests/linear/linear_with_bias.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]], [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [-0.00432095, -1.107_101_2, 0.870_691_4], [0.024_595_5, -0.954_462_9, 0.48518157], ], [ [0.34315687, -0.757_384_2, 0.548_288], [-0.06608963, -1.072_072_7, 0.645_800_5], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/missing_module_field/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(2, 2, (2,2)) def forward(self, x): x = self.conv1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "missing_module_field.pt") if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/missing_module_field/mod.rs ================================================ use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend}; #[derive(Module, Debug)] #[allow(unused)] pub struct Net { do_not_exist_in_pt: Conv2d, } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::nn::conv::Conv2dConfig; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; impl Net { pub fn init(device: &B::Device) -> Self { Self { do_not_exist_in_pt: Conv2dConfig::new([2, 2], [2, 2]).init(device), } } } #[test] #[should_panic(expected = "do_not_exist_in_pt")] fn should_fail_if_struct_field_is_missing() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/missing_module_field/missing_module_field.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/non_contiguous_indexes/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() num_layers = 5 # Number of repeated convolutional layers # Create a list to store the layers layers = [] for _ in range(num_layers): layers.append(nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True)) layers.append(nn.ReLU(inplace=True)) # Use nn.Sequential to create a single module from the layers self.fc = nn.Sequential(*layers) def forward(self, x): x = self.fc(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save(model.state_dict(), "non_contiguous_indexes.pt") input = torch.rand(1, 2, 5, 5) print("Input shape: {}", input.shape) print("Input: {}", input) output = model(input) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/non_contiguous_indexes/mod.rs ================================================ use burn::{ module::Module, nn::{ PaddingConfig2d, conv::{Conv2d, Conv2dConfig}, }, tensor::{Tensor, activation::relu, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { fc: Vec>, } impl Net { /// Create a new model with placeholder values. pub fn init(device: &B::Device) -> Self { let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same); // The PyTorch file has 5 Conv2d layers at non-contiguous indices (0, 2, 4, 6, 8) // in the Sequential (alternating with ReLU layers) let fc = vec![ conv2d_config.init(device), conv2d_config.init(device), conv2d_config.init(device), conv2d_config.init(device), conv2d_config.init(device), ]; Net { fc } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { self.fc.iter().fold(x, |x_i, conv| relu(conv.forward(x_i))) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::{Tolerance, ops::FloatElem}; use burn_store::{ModuleSnapshot, PytorchStore}; type FT = FloatElem; use super::*; #[test] fn non_contiguous_indexes() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/non_contiguous_indexes/non_contiguous_indexes.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::from_data( [[ [ [ 0.67890584, 0.307_537_2, 0.265_156_2, 0.528_318_8, 0.86194897, ], [0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455], [0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637], [0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804], [0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932], ], [ [0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266], [0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439], [0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786], [0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244], [0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938], ], ]], &device, ); let output = model.forward(input); let expected = Tensor::::from_data( [[ [ [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3], ], [ [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], ], ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::absolute(1e-7)); } } ================================================ FILE: crates/burn-store/pytorch-tests/tests/test_mod.rs ================================================ mod backend; mod batch_norm; mod boolean; mod buffer; mod complex_nested; mod config; mod conv1d; mod conv2d; mod conv_transpose1d; mod conv_transpose2d; mod embedding; mod enum_module; mod group_norm; mod integer; mod key_remap; mod key_remap_chained; mod layer_norm; mod linear; mod missing_module_field; mod non_contiguous_indexes; mod top_level_key; ================================================ FILE: crates/burn-store/pytorch-tests/tests/top_level_key/export_weights.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(2, 2, (2,2)) def forward(self, x): x = self.conv1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) torch.save({"my_state_dict": model.state_dict()}, "top_level_key.pt") if __name__ == '__main__': main() ================================================ FILE: crates/burn-store/pytorch-tests/tests/top_level_key/mod.rs ================================================ use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend}; #[derive(Module, Debug)] #[allow(unused)] pub struct Net { conv1: Conv2d, } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::nn::conv::Conv2dConfig; use burn_store::{ModuleSnapshot, PytorchStore}; use super::*; impl Net { pub fn init(device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([2, 2], [2, 2]).init(device), } } } #[test] #[should_panic] fn should_fail_if_not_found() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt"); model .load_from(&mut store) .expect("Should decode state successfully"); } #[test] fn should_load() { let device = Default::default(); let mut model = Net::::init(&device); let mut store = PytorchStore::from_file("tests/top_level_key/top_level_key.pt") .with_top_level_key("my_state_dict"); model .load_from(&mut store) .expect("Should decode state successfully"); } } ================================================ FILE: crates/burn-store/safetensors-tests/Cargo.toml ================================================ [package] name = "safetensors-tests" version.workspace = true edition.workspace = true license.workspace = true [dev-dependencies] burn = { path = "../../burn" } burn-ndarray = { path = "../../burn-ndarray" } burn-autodiff = { path = "../../burn-autodiff" } burn-store = { path = "../", features = ["std", "safetensors"] } serde = { workspace = true } float-cmp = { workspace = true } ================================================ FILE: crates/burn-store/safetensors-tests/src/lib.rs ================================================ ================================================ FILE: crates/burn-store/safetensors-tests/tests/backend.rs ================================================ pub type TestBackend = burn_ndarray::NdArray; ================================================ FILE: crates/burn-store/safetensors-tests/tests/multi_layer/mod.rs ================================================ use burn::{ module::Module, nn::{ BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu, conv::{Conv2d, Conv2dConfig}, }, tensor::{Tensor, backend::Backend}, }; #[derive(Module, Debug)] pub struct Net { conv1: Conv2d, norm1: BatchNorm, fc1: Linear, relu: Relu, } impl Net { pub fn new(device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([3, 4], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), norm1: BatchNormConfig::new(4).init(device), fc1: LinearConfig::new(4 * 8 * 8, 16).init(device), relu: Relu::new(), } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); let x = self.norm1.forward(x); let x = self.relu.forward(x); // Flatten all dimensions except the batch dimension let x = x.flatten(1, 3); self.fc1.forward(x) } } #[cfg(test)] mod tests { use crate::backend::TestBackend; use burn::tensor::Tolerance; use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore}; use super::*; #[test] fn multi_layer_model() { let device = Default::default(); let mut model = Net::::new(&device); let mut store = SafetensorsStore::from_file("tests/multi_layer/multi_layer.safetensors") .with_from_adapter(PyTorchToBurnAdapter); model .load_from(&mut store) .expect("Should decode state successfully"); let input = Tensor::::ones([1, 3, 8, 8], &device); let output = model.forward(input); // Note: Expected values should be updated based on the actual output from the PyTorch model let expected = Tensor::::from_data( [[ 0.04971555, -0.16849735, 0.05182848, -0.18032673, 0.23138367, 0.05041867, 0.13005908, -0.32202929, -0.07915690, -0.03232457, -0.19790289, -0.17476529, -0.19627589, -0.21757686, -0.31376451, 0.08377837, ]], &device, ); output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } } ================================================ FILE: crates/burn-store/safetensors-tests/tests/multi_layer/multi_layer.py ================================================ #!/usr/bin/env python3 import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import save_file class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1) self.norm1 = nn.BatchNorm2d(4) self.flatten = nn.Flatten() self.fc1 = nn.Linear(4 * 8 * 8, 16) # Changed for smaller input size def forward(self, x): x = self.conv1(x) x = self.norm1(x) x = F.relu(x) x = self.flatten(x) x = self.fc1(x) return x def main(): torch.set_printoptions(precision=8) torch.manual_seed(1) model = Model().to(torch.device("cpu")) # Use a smaller input size # 1 batch, 3 channels (RGB), 8x8 image (small input) x1 = torch.ones(1, 3, 8, 8) _ = model(x1) model.eval() # Set to eval mode to freeze running stats # Save the model to safetensors after the first forward save_file(model.state_dict(), "multi_layer.safetensors") x2 = torch.ones(1, 3, 8, 8) print("Input shape: {}", x2.shape) output = model(x2) print("Output: {}", output) print("Output Shape: {}", output.shape) if __name__ == "__main__": main() ================================================ FILE: crates/burn-store/safetensors-tests/tests/test_mod.rs ================================================ mod backend; mod multi_layer; ================================================ FILE: crates/burn-store/src/adapter.rs ================================================ //! Module adapters for transforming tensor snapshots during save/load //! //! This module provides adapters for: //! - PyTorch/Burn format conversion (weight transposition, parameter renaming) //! - Mixed-precision storage (F32/F16 dtype casting via [`HalfPrecisionAdapter`]) //! - Adapter chaining for composing multiple transformations use crate::TensorSnapshot; use alloc::boxed::Box; use alloc::format; use alloc::rc::Rc; use alloc::string::String; use alloc::string::ToString; use alloc::vec; use burn_tensor::shape; use burn_tensor::{DType, TensorData}; use hashbrown::HashSet; // Module type names as they appear in the container_type field // These come from the Module derive macro which uses stringify! on the struct name // Format: "Struct:TypeName" for user-defined structs mod module_names { // The actual string constants that match what the Module derive macro produces pub const LINEAR: &str = "Struct:Linear"; pub const BATCH_NORM: &str = "Struct:BatchNorm"; pub const LAYER_NORM: &str = "Struct:LayerNorm"; pub const GROUP_NORM: &str = "Struct:GroupNorm"; pub const EMBEDDING: &str = "Struct:Embedding"; pub const CONV1D: &str = "Struct:Conv1d"; pub const CONV2D: &str = "Struct:Conv2d"; pub const CONV3D: &str = "Struct:Conv3d"; pub const CONV_TRANSPOSE1D: &str = "Struct:ConvTranspose1d"; pub const CONV_TRANSPOSE2D: &str = "Struct:ConvTranspose2d"; pub const CONV_TRANSPOSE3D: &str = "Struct:ConvTranspose3d"; pub const DEFORM_CONV2D: &str = "Struct:DeformConv2d"; pub const INSTANCE_NORM: &str = "Struct:InstanceNorm"; pub const RMS_NORM: &str = "Struct:RmsNorm"; pub const PRELU: &str = "Struct:PRelu"; } /// Trait for adapting tensor snapshots between different module formats pub trait ModuleAdapter: Send + Sync { /// Adapt a tensor snapshot based on its container type and parameter name fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot; /// Get alternative parameter name to try during matching /// /// When looking for a parameter in a module, this method provides an alternative /// name to try if the direct name doesn't match. This enables matching parameters /// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma"). /// /// # Arguments /// * `param_name` - The parameter name we're looking for /// * `container_type` - The type of container module (e.g., "BatchNorm") /// /// # Returns /// Alternative parameter name to try, or None if no alternative exists fn get_alternative_param_name( &self, _param_name: &str, _container_type: &str, ) -> Option { None } /// Clone the adapter into a boxed trait object fn clone_box(&self) -> Box; /// Chain adapters together, applying `self` first and then `next`. /// /// This is useful when multiple transformations are required when importing model weights /// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping). /// /// The semantics follow a simple pipeline: /// - `adapt`: `next.adapt(&self.adapt(snapshot))` /// - `get_alternative_param_name`: try `self` first; if it returns an alternative name, /// try `next` with that name, otherwise return the first alternative name. fn chain
(self, next: A) -> ChainAdapter where Self: Sized + 'static, A: ModuleAdapter + 'static, { ChainAdapter::new(self, next) } } impl Clone for Box { fn clone(&self) -> Self { self.clone_box() } } /// Adapter that applies two adapters in sequence. /// /// This allows composing smaller adapters instead of creating one large monolithic adapter. #[derive(Clone)] pub struct ChainAdapter { first: Box, second: Box, } impl ChainAdapter { /// Create a new adapter chain. pub fn new(first: A, second: B) -> Self where A: ModuleAdapter + 'static, B: ModuleAdapter + 'static, { Self { first: Box::new(first), second: Box::new(second), } } } impl ModuleAdapter for ChainAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { let snapshot = self.first.adapt(snapshot); self.second.adapt(&snapshot) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { if let Some(name) = self .first .get_alternative_param_name(param_name, container_type) { self.second .get_alternative_param_name(&name, container_type) .or(Some(name)) } else { self.second .get_alternative_param_name(param_name, container_type) } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Identity adapter that passes tensors through unchanged #[derive(Debug, Clone, Default)] pub struct IdentityAdapter; impl ModuleAdapter for IdentityAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { snapshot.clone() } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Returns the default set of module types that `HalfPrecisionAdapter` converts. /// /// Includes: Linear, Embedding, all Conv variants, LayerNorm, GroupNorm, /// InstanceNorm, RmsNorm, PRelu. /// /// Excludes BatchNorm by default because `running_var` underflows in F16. fn default_half_precision_modules() -> HashSet { let modules = [ module_names::LINEAR, module_names::EMBEDDING, module_names::CONV1D, module_names::CONV2D, module_names::CONV3D, module_names::CONV_TRANSPOSE1D, module_names::CONV_TRANSPOSE2D, module_names::CONV_TRANSPOSE3D, module_names::DEFORM_CONV2D, module_names::LAYER_NORM, module_names::GROUP_NORM, module_names::INSTANCE_NORM, module_names::RMS_NORM, module_names::PRELU, ]; modules.iter().map(|s| s.to_string()).collect() } /// Adapter for mixed-precision (F32/F16) model storage. /// /// Auto-detects conversion direction from the snapshot's dtype: /// - F32 source -> cast to F16 (typical for saving) /// - F16 source -> cast to F32 (typical for loading) /// - Other dtypes -> passed through unchanged /// /// The same instance works for both `with_to_adapter` (save) and `with_from_adapter` (load). /// /// By default, converts weights in: Linear, Embedding, Conv*, LayerNorm, GroupNorm, /// InstanceNorm, RmsNorm, PRelu. BatchNorm is excluded because `running_var` underflows in F16. /// /// # Examples /// /// Default usage (same adapter for save and load): /// ```rust /// # use burn_store::HalfPrecisionAdapter; /// let adapter = HalfPrecisionAdapter::new(); /// // store.with_to_adapter(adapter.clone()); // F32 -> F16 on save /// // store.with_from_adapter(adapter); // F16 -> F32 on load /// ``` /// /// Exclude a module type: /// ```rust /// # use burn_store::HalfPrecisionAdapter; /// let adapter = HalfPrecisionAdapter::new() /// .without_module("LayerNorm"); /// ``` /// /// Add a custom module type: /// ```rust /// # use burn_store::HalfPrecisionAdapter; /// let adapter = HalfPrecisionAdapter::new() /// .with_module("CustomLayer"); /// ``` #[derive(Debug, Clone)] pub struct HalfPrecisionAdapter { modules: HashSet, } impl HalfPrecisionAdapter { /// Create a new adapter with the default set of modules. pub fn new() -> Self { Self { modules: default_half_precision_modules(), } } /// Add a module type to convert. Accepts both short (`"MyLayer"`) and /// qualified (`"Struct:MyLayer"`) forms. /// /// Note: short names are mapped to `"Struct:Name"`. If you have an Enum-based /// module, use the qualified form `"Enum:MyModule"` explicitly. pub fn with_module(mut self, module_type: impl Into) -> Self { let name = module_type.into(); if name.contains(':') { self.modules.insert(name); } else { self.modules.insert(format!("Struct:{}", name)); } self } /// Remove a module type from conversion. Accepts both short and qualified forms. pub fn without_module(mut self, module_type: impl Into) -> Self { let name = module_type.into(); let key = if name.contains(':') { name } else { format!("Struct:{}", name) }; assert!( self.modules.contains(&key), "without_module called with '{}' which is not in the module set", key ); self.modules.remove(&key); self } /// Check whether the tensor belongs to a module that should be converted. fn should_convert(&self, snapshot: &TensorSnapshot) -> bool { snapshot .module_type() .is_some_and(|mt| self.modules.contains(&mt)) } } impl Default for HalfPrecisionAdapter { fn default() -> Self { Self::new() } } impl ModuleAdapter for HalfPrecisionAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { // Determine target dtype from source: F32 -> F16, F16 -> F32, anything else -> skip let target_dtype = match snapshot.dtype { DType::F32 => DType::F16, DType::F16 => DType::F32, _ => return snapshot.clone(), }; if !self.should_convert(snapshot) { return snapshot.clone(); } let original_data_fn = snapshot.clone_data_fn(); let cast_data_fn = Rc::new(move || { let data = original_data_fn()?; Ok(data.convert_dtype(target_dtype)) }); TensorSnapshot::from_closure( cast_data_fn, target_dtype, snapshot.shape.clone(), snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Adapter for converting from PyTorch format to Burn format /// /// Handles: /// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out]) /// - Normalization parameter renaming (weight → gamma, bias → beta) #[derive(Debug, Clone, Default)] pub struct PyTorchToBurnAdapter; impl ModuleAdapter for PyTorchToBurnAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { // For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias) if is_normalization_layer(container_type) { burn_norm_param_to_pytorch(param_name).map(|s| s.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Adapter for converting from Burn format to PyTorch format /// /// Handles: /// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in]) /// - Normalization parameter renaming (gamma → weight, beta → bias) #[derive(Debug, Clone, Default)] pub struct BurnToPyTorchAdapter; impl ModuleAdapter for BurnToPyTorchAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch) } fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option { // For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta) if is_normalization_layer(container_type) { pytorch_norm_param_to_burn(param_name).map(|s| s.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Direction of PyTorch conversion for parameter naming #[derive(Debug, Clone, Copy)] enum PyTorchConversionDirection { PyTorchToBurn, BurnToPyTorch, } /// Check if container type is a normalization layer fn is_normalization_layer(container_type: &str) -> bool { matches!( container_type, module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM ) } /// Map PyTorch normalization parameter name to Burn fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> { match param_name { "weight" => Some("gamma"), "bias" => Some("beta"), _ => None, } } /// Map Burn normalization parameter name to PyTorch fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> { match param_name { "gamma" => Some("weight"), "beta" => Some("bias"), _ => None, } } /// Core tensor adaptation logic for PyTorch format conversions fn adapt_pytorch_tensor( snapshot: &TensorSnapshot, direction: PyTorchConversionDirection, ) -> TensorSnapshot { // Extract path and parameter name let (path_stack, param_name) = match get_path_and_param(snapshot) { Some(result) => result, None => return snapshot.clone(), }; // Get module type for matching (ignores Vec/Array wrappers) let module_type = match snapshot.module_type() { Some(mt) => mt, None => return snapshot.clone(), // No user-defined module found }; // Linear: transpose weight (bidirectional - same operation both ways) if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 { return transpose_2d_tensor(snapshot); } // Normalization layers: rename parameters based on direction if is_normalization_layer(&module_type) { let new_name = match direction { PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name), PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name), }; if let Some(new_name) = new_name { return rename_parameter(snapshot, path_stack, new_name); } } snapshot.clone() } /// Extract path stack and parameter name from snapshot fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> { let path_stack = snapshot.path_stack.as_ref()?; let param_name = path_stack.last()?.as_str(); Some((path_stack.as_slice(), param_name)) } /// Rename a parameter in the snapshot fn rename_parameter( snapshot: &TensorSnapshot, path_stack: &[String], new_name: &str, ) -> TensorSnapshot { let mut new_path = path_stack.to_vec(); *new_path.last_mut().unwrap() = new_name.to_string(); TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), new_path, snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } /// Transpose a 2D tensor fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot { if snapshot.shape.len() != 2 { return snapshot.clone(); } let original_data_fn = snapshot.clone_data_fn(); let dtype = snapshot.dtype; let transposed_shape = shape![snapshot.shape[1], snapshot.shape[0]]; // Create a lazy closure that transposes when called let transposed_data_fn = Rc::new(move || { let data = original_data_fn()?; Ok(transpose_tensor_data(data)) }); TensorSnapshot::from_closure( transposed_data_fn, dtype, transposed_shape, snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } /// Transpose tensor data (assumes 2D shape is already validated) fn transpose_tensor_data(data: TensorData) -> TensorData { let shape = &data.shape; let rows = shape[0]; let cols = shape[1]; let transposed_shape = vec![cols, rows]; // Get the raw bytes and element size let bytes = data.as_bytes(); let element_size = data.dtype.size(); // Create a new buffer for transposed data let mut transposed_bytes = vec![0u8; bytes.len()]; // Transpose at the byte level - works for any data type for i in 0..rows { for j in 0..cols { let src_idx = (i * cols + j) * element_size; let dst_idx = (j * rows + i) * element_size; // Copy the bytes for this element transposed_bytes[dst_idx..dst_idx + element_size] .copy_from_slice(&bytes[src_idx..src_idx + element_size]); } } // Create new TensorData from transposed bytes TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype) } #[cfg(test)] mod tests { use super::*; use alloc::rc::Rc; use alloc::sync::Arc; use burn_tensor::{DType, Shape, TensorData}; use core::sync::atomic::{AtomicUsize, Ordering}; #[test] fn test_module_names_match_burn_nn() { // If these types are renamed or moved in `burn-nn`, this test will fail to compile. #[allow(unused_imports)] use burn_nn::{ BatchNorm, Embedding, GroupNorm, InstanceNorm, LayerNorm, Linear, PRelu, RmsNorm, conv::{ Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, DeformConv2d, }, }; assert_eq!(module_names::LINEAR, "Struct:Linear"); assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm"); assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm"); assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm"); assert_eq!(module_names::EMBEDDING, "Struct:Embedding"); assert_eq!(module_names::CONV1D, "Struct:Conv1d"); assert_eq!(module_names::CONV2D, "Struct:Conv2d"); assert_eq!(module_names::CONV3D, "Struct:Conv3d"); assert_eq!(module_names::CONV_TRANSPOSE1D, "Struct:ConvTranspose1d"); assert_eq!(module_names::CONV_TRANSPOSE2D, "Struct:ConvTranspose2d"); assert_eq!(module_names::CONV_TRANSPOSE3D, "Struct:ConvTranspose3d"); assert_eq!(module_names::DEFORM_CONV2D, "Struct:DeformConv2d"); assert_eq!(module_names::INSTANCE_NORM, "Struct:InstanceNorm"); assert_eq!(module_names::RMS_NORM, "Struct:RmsNorm"); assert_eq!(module_names::PRELU, "Struct:PRelu"); } fn create_test_snapshot(path: &str, shape: Shape, container_type: &str) -> TensorSnapshot { let path_parts: Vec = path.split('.').map(|s| s.to_string()).collect(); let values = vec![1.0f32; shape.iter().product()]; let data = TensorData::new(values, shape.clone()); TensorSnapshot::from_closure( Rc::new(move || Ok(data.clone())), DType::F32, shape, path_parts, vec![container_type.to_string()], burn_core::module::ParamId::new(), ) } #[test] fn test_pytorch_to_burn_linear_weight() { let adapter = PyTorchToBurnAdapter; // Linear layer weight should be transposed let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, shape![5, 10]); // Linear layer bias should not be transposed let snapshot = create_test_snapshot("fc.bias", shape![10], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, shape![10]); } #[test] fn test_pytorch_to_burn_norm_params() { let adapter = PyTorchToBurnAdapter; // BatchNorm weight -> gamma let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.gamma"); // BatchNorm bias -> beta let snapshot = create_test_snapshot("norm.bias", shape![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.beta"); } #[test] fn test_burn_to_pytorch_linear_weight() { let adapter = BurnToPyTorchAdapter; // Linear layer weight should be transposed let snapshot = create_test_snapshot("fc.weight", shape![5, 10], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, shape![10, 5]); } #[test] fn test_burn_to_pytorch_norm_params() { let adapter = BurnToPyTorchAdapter; // BatchNorm gamma -> weight let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.weight"); // BatchNorm beta -> bias let snapshot = create_test_snapshot("norm.beta", shape![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.full_path(), "norm.bias"); } #[test] fn test_transpose_different_dtypes() { // Test that transpose works for different data types // Test with F32 let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]); let transposed = transpose_tensor_data(f32_data); assert_eq!(transposed.shape, shape![3, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); // Test with I32 let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], [2, 3]); let transposed = transpose_tensor_data(i32_data); assert_eq!(transposed.shape, shape![3, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1, 4, 2, 5, 3, 6]); // Test with F64 let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]); let transposed = transpose_tensor_data(f64_data); assert_eq!(transposed.shape, shape![2, 2]); let values = transposed.to_vec::().unwrap(); assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]); } #[test] fn test_no_container_info() { let adapter = PyTorchToBurnAdapter; // Without container info, adapter returns unchanged for non-norm parameters let mut snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR); snapshot.container_stack = None; // Without container info, no transformation occurs for linear layers let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.shape, shape![10, 5]); // No transposition without container info // Test a non-linear, non-norm parameter - should pass through unchanged let mut snapshot2 = create_test_snapshot("other.weight", shape![10, 5], "Struct:Other"); snapshot2.container_stack = None; let adapted2 = adapter.adapt(&snapshot2); assert_eq!(adapted2.shape, shape![10, 5]); // No transposition } #[derive(Clone)] struct RenameParamAdapter { from: &'static str, to: &'static str, called: Arc, } impl ModuleAdapter for RenameParamAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { self.called.fetch_add(1, Ordering::Relaxed); let path_stack = match snapshot.path_stack.as_ref() { Some(stack) => stack, None => return snapshot.clone(), }; let param = match path_stack.last() { Some(p) => p.as_str(), None => return snapshot.clone(), }; if param != self.from { return snapshot.clone(); } let mut new_path = path_stack.to_vec(); *new_path.last_mut().unwrap() = self.to.to_string(); TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), new_path, snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } fn get_alternative_param_name( &self, _param_name: &str, _container_type: &str, ) -> Option { None } fn clone_box(&self) -> Box { Box::new(self.clone()) } } #[derive(Clone)] struct AltNameAdapter { from: &'static str, to: &'static str, called: Arc, } impl ModuleAdapter for AltNameAdapter { fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot { TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ) } fn get_alternative_param_name( &self, param_name: &str, _container_type: &str, ) -> Option { self.called.fetch_add(1, Ordering::Relaxed); if param_name == self.from { Some(self.to.to_string()) } else { None } } fn clone_box(&self) -> Box { Box::new(self.clone()) } } #[test] fn test_chain_adapter_pipes_adapt() { let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = RenameParamAdapter { from: "weight", to: "a", called: called1.clone(), }; let b = RenameParamAdapter { from: "a", to: "b", called: called2.clone(), }; let chain = a.chain(b); let snapshot = create_test_snapshot("fc.weight", shape![2, 2], module_names::LINEAR); let adapted = chain.adapt(&snapshot); assert_eq!(adapted.full_path(), "fc.b"); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); } #[test] fn test_chain_adapter_alternative_name_pipes_and_fallbacks() { let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "gamma", to: "weight", called: called1.clone(), }; let b = AltNameAdapter { from: "weight", to: "scale", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("scale")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // If the second adapter doesn't have a mapping for the first alternative, // fall back to the first alternative name. let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "gamma", to: "weight", called: called1.clone(), }; let b = AltNameAdapter { from: "something-else", to: "unused", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // If the first adapter doesn't provide an alternative, try the second with the original name. let called1 = Arc::new(AtomicUsize::new(0)); let called2 = Arc::new(AtomicUsize::new(0)); let a = AltNameAdapter { from: "something-else", to: "unused", called: called1.clone(), }; let b = AltNameAdapter { from: "gamma", to: "weight", called: called2.clone(), }; let chain = a.chain(b); let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); assert_eq!(called1.load(Ordering::Relaxed), 1); assert_eq!(called2.load(Ordering::Relaxed), 1); // clone_box must preserve behavior. let boxed = chain.clone_box(); let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM); assert_eq!(alt.as_deref(), Some("weight")); } #[test] fn test_half_precision_f32_to_f16() { let adapter = HalfPrecisionAdapter::new(); let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.dtype, DType::F16); assert_eq!(adapted.shape, shape![2, 3]); let data = adapted.to_data().unwrap(); assert_eq!(data.dtype, DType::F16); } #[test] fn test_half_precision_f16_to_f32() { let adapter = HalfPrecisionAdapter::new(); // Create an F16 snapshot let values = vec![1.0f32; 6]; let data = TensorData::new(values, shape![2, 3]).convert_dtype(DType::F16); let path_parts = vec!["fc".to_string(), "weight".to_string()]; let snapshot = TensorSnapshot::from_closure( Rc::new(move || Ok(data.clone())), DType::F16, shape![2, 3], path_parts, vec![module_names::LINEAR.to_string()], burn_core::module::ParamId::new(), ); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.dtype, DType::F32); } #[test] fn test_half_precision_skips_batch_norm() { let adapter = HalfPrecisionAdapter::new(); // BatchNorm is excluded by default let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM); let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.dtype, DType::F32); // unchanged } #[test] fn test_half_precision_converts_default_modules() { let adapter = HalfPrecisionAdapter::new(); // Linear let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); // Embedding let snapshot = create_test_snapshot("emb.weight", shape![100, 64], module_names::EMBEDDING); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); // Conv2d let snapshot = create_test_snapshot("conv.weight", shape![3, 3, 3, 3], module_names::CONV2D); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); // LayerNorm (included by default) let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); // GroupNorm let snapshot = create_test_snapshot("gn.gamma", shape![10], module_names::GROUP_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); // RmsNorm let snapshot = create_test_snapshot("rms.weight", shape![10], module_names::RMS_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); } #[test] fn test_half_precision_without_module() { let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm"); // LayerNorm removed from conversion set let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32); // Linear still converted let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); } #[test] fn test_half_precision_with_module() { let adapter = HalfPrecisionAdapter::new().with_module("CustomLayer"); // Custom module should now be converted let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer"); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); } #[test] fn test_half_precision_with_qualified_name() { let adapter = HalfPrecisionAdapter::new().with_module("Struct:CustomLayer"); let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer"); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); } #[test] fn test_half_precision_chain() { let adapter = PyTorchToBurnAdapter.chain(HalfPrecisionAdapter::new()); let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR); let adapted = adapter.adapt(&snapshot); // Should be both transposed and cast assert_eq!(adapted.shape, shape![5, 10]); assert_eq!(adapted.dtype, DType::F16); } #[test] fn test_half_precision_skips_no_container() { let adapter = HalfPrecisionAdapter::new(); let mut snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR); snapshot.container_stack = None; // No module type info: skip let adapted = adapter.adapt(&snapshot); assert_eq!(adapted.dtype, DType::F32); } #[test] fn test_half_precision_skips_non_float() { use burn_tensor::quantization::QuantScheme; let adapter = HalfPrecisionAdapter::new(); // QFloat source: skip let qfloat_dtype = DType::QFloat(QuantScheme::default()); let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR); let qfloat_snapshot = TensorSnapshot::from_closure( snapshot.clone_data_fn(), qfloat_dtype, snapshot.shape.clone(), snapshot.path_stack.clone().unwrap_or_default(), snapshot.container_stack.clone().unwrap_or_default(), snapshot.tensor_id.unwrap_or_default(), ); let adapted = adapter.adapt(&qfloat_snapshot); assert_eq!(adapted.dtype, qfloat_dtype); } #[test] fn test_half_precision_default_module_count() { let adapter = HalfPrecisionAdapter::new(); // 14 modules: Linear, Embedding, Conv1d-3d, ConvTranspose1d-3d, // DeformConv2d, LayerNorm, GroupNorm, InstanceNorm, RmsNorm, PRelu assert_eq!(adapter.modules.len(), 14); } #[test] fn test_half_precision_without_module_qualified() { let adapter = HalfPrecisionAdapter::new().without_module("Struct:LayerNorm"); let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32); } #[test] fn test_half_precision_with_module_batch_norm_opt_in() { let adapter = HalfPrecisionAdapter::new().with_module("BatchNorm"); let snapshot = create_test_snapshot("bn.weight", shape![10], module_names::BATCH_NORM); assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16); } } ================================================ FILE: crates/burn-store/src/applier.rs ================================================ //! Applier that correctly applies tensor snapshots with adapter support use alloc::boxed::Box; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec::Vec; use hashbrown::{HashMap, HashSet}; use burn_core::module::{ModuleMapper, Param}; use burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend}; use crate::apply_result::{ApplyError, ApplyResult}; use crate::{ModuleAdapter, PathFilter, TensorSnapshot}; /// Applier that applies tensor snapshots to module parameters /// with proper adapter support using container type information pub struct Applier { /// Map of tensor paths to their snapshots snapshots: HashMap, /// Current path in the module hierarchy path_stack: Vec, /// Current container type stack in the module hierarchy container_stack: Vec, /// Optional filter for selective application filter: Option, /// Optional adapter to transform tensors based on container types adapter: Option>, /// Successfully applied tensor paths applied: Vec, /// Skipped tensor paths skipped: HashSet, /// Errors encountered during application errors: Vec, /// Track visited paths with their container stacks (in dot notation) to find missing tensors visited_paths: HashMap, /// Skip enum variant names when matching paths /// When true, "feature.BaseConv.weight" will also try to match "feature.weight" skip_enum_variants: bool, /// Phantom data for backend type _backend: core::marker::PhantomData, } impl Applier { /// Create a new applier with snapshots, optional filter, and optional adapter /// /// # Arguments /// /// * `views` - A vector of TensorSnapshot objects to apply /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply. /// When `None`, all available tensors are applied. /// * `adapter` - Optional adapter to transform tensors based on container types /// * `skip_enum_variants` - Skip enum variant names when matching paths pub fn new( views: Vec, filter: Option, adapter: Option>, skip_enum_variants: bool, ) -> Self { let views_map: HashMap = views .into_iter() .map(|view| (view.full_path(), view)) .collect(); Self { snapshots: views_map, path_stack: Vec::new(), container_stack: Vec::new(), filter, adapter, applied: Vec::new(), skipped: HashSet::new(), errors: Vec::new(), visited_paths: HashMap::new(), skip_enum_variants, _backend: core::marker::PhantomData, } } /// Get the current path in the module hierarchy fn current_path(&self) -> String { self.path_stack.join(".") } /// Get the current module type (last Struct/Enum in container stack) fn current_module_type(&self) -> Option<&str> { self.container_stack .iter() .rev() .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:")) .map(|s| s.as_str()) } /// Check if a tensor should be applied based on filter fn should_apply(&self) -> bool { match &self.filter { None => true, Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack), } } /// Convert the applier into a result pub fn into_result(self) -> ApplyResult { let mut unused: Vec = self .snapshots .keys() .filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path)) .cloned() .collect(); // Sort for stable output order unused.sort(); // Create a set of successfully applied paths for efficient lookup let applied_set: HashSet = self.applied.iter().cloned().collect(); // Extract paths that have errors - these are not "missing", they were found but had issues let errored_paths: HashSet = self .errors .iter() .map(|e| match e { ApplyError::ShapeMismatch { path, .. } => path.clone(), ApplyError::DTypeMismatch { path, .. } => path.clone(), ApplyError::AdapterError { path, .. } => path.clone(), ApplyError::LoadError { path, .. } => path.clone(), }) .collect(); // A path is missing if it was visited but not successfully applied, not skipped, and didn't have an error // Store both the path and its container stack (in dot notation) let mut missing: Vec<(String, String)> = self .visited_paths .into_iter() .filter(|(p, _)| { !applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p) }) .collect(); // Sort for stable output order (by path) missing.sort_by(|a, b| a.0.cmp(&b.0)); // Convert skipped HashSet to sorted Vec for stable output let mut skipped: Vec = self.skipped.into_iter().collect(); skipped.sort(); ApplyResult { applied: self.applied, skipped, missing, unused, errors: self.errors, } } /// Apply a tensor snapshot with shape validation and optional adapter transformation /// Returns None if snapshot not found, filtered, or validation fails fn apply_tensor( &mut self, target_device: &B::Device, target_shape: Shape, ) -> Option> where K: burn_tensor::TensorKind, K: burn_tensor::BasicOps, { let path = self.current_path(); let container_stack_str = self.container_stack.join("."); self.visited_paths.insert(path.clone(), container_stack_str); // Try to get snapshot with original path first let mut snapshot = self.snapshots.get(&path).cloned(); // If not found and we have an adapter, try alternative parameter names if snapshot.is_none() && let Some(ref adapter) = self.adapter && let Some(module_type) = self.current_module_type() { // Get alternative name based on current module type (user-defined module only) let param_name = self.path_stack.last()?; if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) { // Build alternative path with parameter name substitution let mut alt_path_stack = self.path_stack.clone(); *alt_path_stack.last_mut().unwrap() = alt_name.clone(); let alt_path = alt_path_stack.join("."); // Try to get snapshot with alternative name snapshot = self.snapshots.get(&alt_path).cloned(); // Don't mark the alternative path as visited - only the original Burn path // should be tracked. The alternative path is just for lookup. } } let mut snapshot = snapshot?; // Apply adapter transformation using current container_stack context (for data transformation like transpose) if let Some(ref adapter) = self.adapter { // Create a temporary snapshot with current context for adaptation let snapshot_with_context = TensorSnapshot::from_closure( snapshot.clone_data_fn(), snapshot.dtype, snapshot.shape.clone(), self.path_stack.clone(), self.container_stack.clone(), snapshot.tensor_id.unwrap_or_default(), ); // Transform using adapter (handles transpose) snapshot = adapter.adapt(&snapshot_with_context); } // Check if we should apply based on filter if !self.should_apply() { self.skipped.insert(path.clone()); return None; } // Load tensor data let data = match snapshot.to_data() { Ok(data) => data, Err(e) => { self.errors.push(ApplyError::LoadError { path: path.clone(), message: format!("Failed to load tensor data: {:?}", e), }); return None; // Signal caller to fall back to initialization } }; // Validate shape if data.shape != target_shape { self.errors.push(ApplyError::ShapeMismatch { path: path.clone(), expected: target_shape, found: data.shape, }); return None; // Signal caller to fall back to initialization } self.applied.push(path); Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype)) } } impl ModuleMapper for Applier { fn enter_module(&mut self, name: &str, container_type: &str) { // Always track the container type for proper module type detection self.container_stack.push(container_type.to_string()); // Only add to path if it's not an enum variant (when skip_enum_variants is enabled) // This ensures paths are built without enum variant names from the start if !self.skip_enum_variants || !container_type.starts_with("Enum:") { self.path_stack.push(name.to_string()); } } fn exit_module(&mut self, _name: &str, container_type: &str) { self.container_stack.pop(); // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled) if !self.skip_enum_variants || !container_type.starts_with("Enum:") { self.path_stack.pop(); } } fn map_float(&mut self, param: Param>) -> Param> { let param_id = param.id; let target_device = param.lazy_device(); let target_shape = param.lazy_shape(); // Try to apply snapshot with shape validation match self.apply_tensor(&target_device, target_shape) { Some(tensor) => { // We have a tensor to apply - load it param.transform_for_load(tensor, param_id) } None => { // No snapshot, filtered, or validation failed - return param unchanged param } } } fn map_int( &mut self, param: Param>, ) -> Param> { let param_id = param.id; let target_device = param.lazy_device(); let target_shape = param.lazy_shape(); // Try to apply snapshot with shape validation match self.apply_tensor(&target_device, target_shape) { Some(tensor) => { // We have a tensor to apply - load it param.transform_for_load(tensor, param_id) } None => { // No snapshot, filtered, or validation failed - return param unchanged param } } } fn map_bool( &mut self, param: Param>, ) -> Param> { let param_id = param.id; let target_device = param.lazy_device(); let target_shape = param.lazy_shape(); // Try to apply snapshot with shape validation match self.apply_tensor(&target_device, target_shape) { Some(tensor) => { // We have a tensor to apply - load it param.transform_for_load(tensor, param_id) } None => { // No snapshot, filtered, or validation failed - return param unchanged param } } } } #[cfg(all(test, feature = "std", target_has_atomic = "ptr"))] mod tests { use super::*; use burn_core::module::{ModuleMapper, Param, ParamId}; use burn_tensor::{DType, Tensor, TensorData}; type TestBackend = burn_ndarray::NdArray; #[test] fn root_level_parameters() { let device = Default::default(); // Create root-level parameters (not inside any module) let weight = Param::>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let bias = Param::>::from_data([5.0, 6.0], &device); // Create snapshots with root-level paths (single-element path, no nested modules) let weight_snapshot = crate::TensorSnapshot::from_data( weight.val().to_data(), vec!["weight".to_string()], // root-level parameter name vec![], // no container ParamId::new(), ); let bias_snapshot = crate::TensorSnapshot::from_data( bias.val().to_data(), vec!["bias".to_string()], // root-level parameter name vec![], // no container ParamId::new(), ); // Create applier with root-level snapshots let mut applier = Applier::::new(vec![weight_snapshot, bias_snapshot], None, None, false); // Create new params to load into let weight_target = Param::initialized( ParamId::new(), Tensor::::zeros([2, 2], &device), ); let bias_target = Param::initialized( ParamId::new(), Tensor::::zeros([2], &device), ); // Apply using the ModuleMapper interface - simulate module traversal // Enter "weight" path (as if we're visiting a field named "weight") applier.enter_module("weight", ""); let weight_loaded = applier.map_float(weight_target); applier.exit_module("weight", ""); // Enter "bias" path (as if we're visiting a field named "bias") applier.enter_module("bias", ""); let bias_loaded = applier.map_float(bias_target); applier.exit_module("bias", ""); // Verify values were loaded let weight_data = weight_loaded.val().to_data().to_vec::().unwrap(); let bias_data = bias_loaded.val().to_data().to_vec::().unwrap(); assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); assert_eq!(bias_data, vec![5.0, 6.0]); // Verify applier result let result = applier.into_result(); assert_eq!(result.applied.len(), 2); assert_eq!(result.errors.len(), 0); } /// Test that the applier preserves dtype when loading tensor data. /// This is a regression test for the bug where F16 tensors were being /// loaded as F32 because `Tensor::from_data` was used instead of /// `Tensor::from_data_dtype`. #[test] fn dtype_preservation_f64() { // Use NdArray backend to properly test F64 dtype preservation type TestBackendF64 = burn_ndarray::NdArray; let device = Default::default(); // Create TensorData with F64 dtype explicitly let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]); assert_eq!(f64_data.dtype, DType::F64, "Test setup: data should be F64"); // Create a snapshot with F64 data let snapshot = crate::TensorSnapshot::from_data( f64_data.clone(), vec!["weight".to_string()], vec![], ParamId::new(), ); assert_eq!( snapshot.dtype, DType::F64, "Snapshot should preserve F64 dtype" ); // Create applier with the F64 snapshot let mut applier = Applier::::new(vec![snapshot], None, None, false); // Create target parameter let target = Param::initialized( ParamId::new(), Tensor::::zeros([2, 2], &device), ); // Apply the snapshot applier.enter_module("weight", ""); let loaded = applier.map_float(target); applier.exit_module("weight", ""); // Verify dtype is preserved - this would fail before the fix // because the data would be converted to the backend's default FloatElem assert_eq!( loaded.val().dtype(), DType::F64, "Loaded tensor should have F64 dtype" ); // Verify data values are correct let loaded_data = loaded.val().to_data().to_vec::().unwrap(); assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]); // Verify applier result let result = applier.into_result(); assert_eq!(result.applied.len(), 1); assert_eq!(result.errors.len(), 0); } /// Test that F32 dtype is preserved when loading (verifies we didn't break F32 handling) #[test] fn dtype_preservation_f32() { let device = Default::default(); // Create TensorData with F32 dtype let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]); assert_eq!(f32_data.dtype, DType::F32); // Create a snapshot with F32 data let snapshot = crate::TensorSnapshot::from_data( f32_data.clone(), vec!["weight".to_string()], vec![], ParamId::new(), ); assert_eq!(snapshot.dtype, DType::F32); // Create applier with the F32 snapshot let mut applier = Applier::::new(vec![snapshot], None, None, false); // Create target parameter let target = Param::initialized( ParamId::new(), Tensor::::zeros([2, 2], &device), ); // Apply the snapshot applier.enter_module("weight", ""); let loaded = applier.map_float(target); applier.exit_module("weight", ""); // Verify dtype is F32 assert_eq!(loaded.val().dtype(), DType::F32); // Verify data values let loaded_data = loaded.val().to_data().to_vec::().unwrap(); assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]); } /// Test that F16 dtype is correctly preserved in TensorSnapshot. /// /// Note: Full F16 tensor loading requires a backend that supports F16 /// (e.g., CUDA, WebGPU). The NdArray backend does not support F16. /// This test verifies that the snapshot correctly preserves F16 dtype, /// which is the key part of the dtype preservation fix. #[test] fn dtype_preservation_f16_snapshot() { use half::f16; // Create TensorData with F16 dtype using the half crate let f16_values: Vec = vec![ f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0), f16::from_f32(4.0), ]; let f16_data = TensorData::new(f16_values.clone(), [2, 2]); assert_eq!( f16_data.dtype, DType::F16, "TensorData should have F16 dtype" ); // Create a snapshot with F16 data let snapshot = crate::TensorSnapshot::from_data( f16_data.clone(), vec!["weight".to_string()], vec![], ParamId::new(), ); // Verify snapshot preserves F16 dtype assert_eq!( snapshot.dtype, DType::F16, "TensorSnapshot should preserve F16 dtype" ); // Verify the data can be retrieved with correct dtype let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data"); assert_eq!( retrieved_data.dtype, DType::F16, "Retrieved data should have F16 dtype" ); // Verify the actual values are preserved let retrieved_values: Vec = retrieved_data .to_vec() .expect("Should be able to convert to f16 vec"); assert_eq!( retrieved_values, f16_values, "F16 values should be preserved" ); // Note: To fully test F16 tensor creation, you would need a backend // that supports F16 (like CUDA or WebGPU). The applier fix ensures // that `Tensor::from_data_dtype(data, device, snapshot.dtype)` is // called with DType::F16, which will correctly create an F16 tensor // on backends that support it. } /// Test that BF16 dtype is correctly preserved in TensorSnapshot. #[test] fn dtype_preservation_bf16_snapshot() { use half::bf16; // Create TensorData with BF16 dtype let bf16_values: Vec = vec![ bf16::from_f32(1.0), bf16::from_f32(2.0), bf16::from_f32(3.0), bf16::from_f32(4.0), ]; let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]); assert_eq!( bf16_data.dtype, DType::BF16, "TensorData should have BF16 dtype" ); // Create a snapshot with BF16 data let snapshot = crate::TensorSnapshot::from_data( bf16_data.clone(), vec!["weight".to_string()], vec![], ParamId::new(), ); // Verify snapshot preserves BF16 dtype assert_eq!( snapshot.dtype, DType::BF16, "TensorSnapshot should preserve BF16 dtype" ); // Verify the data can be retrieved with correct dtype let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data"); assert_eq!( retrieved_data.dtype, DType::BF16, "Retrieved data should have BF16 dtype" ); // Verify the actual values are preserved let retrieved_values: Vec = retrieved_data .to_vec() .expect("Should be able to convert to bf16 vec"); assert_eq!( retrieved_values, bf16_values, "BF16 values should be preserved" ); } } ================================================ FILE: crates/burn-store/src/apply_result.rs ================================================ //! Result types and diagnostics for tensor application operations use alloc::string::String; use alloc::vec; use alloc::vec::Vec; use burn_tensor::{DType, Shape}; /// Error types that can occur during tensor application #[derive(Debug, Clone)] pub enum ApplyError { /// Shape mismatch between expected and actual tensor ShapeMismatch { /// Path of the tensor path: String, /// Expected shape expected: Shape, /// Found shape found: Shape, }, /// Data type mismatch between expected and actual tensor DTypeMismatch { /// Path of the tensor path: String, /// Expected data type expected: DType, /// Found data type found: DType, }, /// Error from adapter transformation AdapterError { /// Path of the tensor path: String, /// Error message message: String, }, /// Error loading tensor data LoadError { /// Path of the tensor path: String, /// Error message message: String, }, } impl core::fmt::Display for ApplyError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::ShapeMismatch { path, expected, found, } => { write!( f, "Shape mismatch for '{}': expected {:?}, found {:?}", path, expected, found ) } Self::DTypeMismatch { path, expected, found, } => { write!( f, "DType mismatch for '{}': expected {:?}, found {:?}", path, expected, found ) } Self::AdapterError { path, message } => { write!(f, "Adapter error for '{}': {}", path, message) } Self::LoadError { path, message } => { write!(f, "Load error for '{}': {}", path, message) } } } } impl core::error::Error for ApplyError {} /// Result of applying tensor snapshots to a module #[derive(Clone)] pub struct ApplyResult { /// Successfully applied tensor paths pub applied: Vec, /// Skipped tensor paths (due to filter) pub skipped: Vec, /// Missing tensor paths with their container stacks in dot notation (path, container_stack) /// Container stack shows the hierarchy: "Struct:Model.Struct:Linear" or "Struct:Model.Enum:ConvType.Struct:Linear" pub missing: Vec<(String, String)>, /// Unused tensor paths (in snapshots but not in module) pub unused: Vec, /// Errors encountered during application pub errors: Vec, } impl ApplyResult { /// Try to strip enum variant from a path /// e.g., "field.BaseConv.weight" -> "field.weight" fn strip_enum_variant(path: &str) -> Option { let segments: Vec<&str> = path.split('.').collect(); // Find segments that look like enum variants (CamelCase in middle of path) let variant_indices: Vec = segments .iter() .enumerate() .filter(|(i, segment)| { *i > 0 && *i < segments.len() - 1 // Not first or last && !segment.is_empty() && segment.chars().next().map(|c| c.is_uppercase()).unwrap_or(false) && segment.len() > 1 && segment.chars().skip(1).any(|c| c.is_lowercase()) }) .map(|(i, _)| i) .collect(); if variant_indices.is_empty() { return None; } // Remove the first found variant and return the modified path let mut result_segments = segments.clone(); result_segments.remove(variant_indices[0]); Some(result_segments.join(".")) } /// Find similar paths for a given missing path (for "Did you mean?" suggestions) fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec { // First, try exact match with enum variant stripped if let Some(stripped) = Self::strip_enum_variant(missing_path) && self.unused.contains(&stripped) { return vec![stripped]; } // Fall back to Jaro similarity (used by Elixir for "did you mean?" suggestions) // Jaro gives higher weight to matching prefixes, ideal for hierarchical tensor paths let mut similarities: Vec<(String, f64)> = self .unused .iter() .map(|available| { let similarity = textdistance::nstr::jaro(missing_path, available); (available.clone(), similarity) }) .collect(); // Sort by similarity (higher = more similar) similarities .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal)); // Only suggest paths with >= 70% similarity const SIMILARITY_THRESHOLD: f64 = 0.7; similarities .into_iter() .filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD) .take(max_suggestions) .map(|(path, _)| path) .collect() } } impl ApplyResult { /// Check if the apply operation was successful (no errors) /// Note: Missing tensors are not considered errors when allow_partial is true pub fn is_success(&self) -> bool { self.errors.is_empty() } } impl core::fmt::Debug for ApplyResult { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { // Delegate to Display for comprehensive output core::fmt::Display::fmt(self, f) } } impl core::fmt::Display for ApplyResult { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { writeln!(f, "┌─ Tensor Loading Summary ─────────────────────────")?; writeln!(f, "│")?; writeln!( f, "│ ✓ Successfully applied: {} tensors", self.applied.len() )?; writeln!(f, "│ ⊘ Skipped (filtered): {} tensors", self.skipped.len())?; writeln!( f, "│ ✗ Missing in source: {} tensors", self.missing.len() )?; writeln!(f, "│ ? Unused in target: {} tensors", self.unused.len())?; writeln!(f, "│ ! Errors: {} errors", self.errors.len())?; if !self.missing.is_empty() { writeln!(f, "│")?; writeln!( f, "├─ Missing Tensors (requested by model but not found in source)" )?; writeln!(f, "│")?; // Use actual container stack data to detect enum variants // Count how many missing paths have "Enum:" in their container stack let enum_variant_missing: Vec<_> = self .missing .iter() .filter(|(_, stack)| stack.contains("Enum:")) .collect(); if !enum_variant_missing.is_empty() { writeln!( f, "│ ⚠️ {} paths contain enum variants (detected from container stack)", enum_variant_missing.len() )?; writeln!( f, "│ Burn includes enum variant names in paths, but PyTorch doesn't." )?; writeln!( f, "│ Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'" )?; writeln!(f, "│")?; writeln!( f, "│ 💡 Solution 1: Enable skip_enum_variants flag (simplest):" )?; writeln!(f, "│")?; writeln!( f, "│ let mut store = PytorchStore::from_file(\"model.pth\")" )?; writeln!(f, "│ .skip_enum_variants(true); // ← Add this")?; writeln!(f, "│")?; writeln!( f, "│ 💡 Solution 2: Remap enum keys in source (most precise):" )?; writeln!(f, "│")?; writeln!( f, "│ let mut store = SafetensorsStore::from_file(\"model.safetensors\")" )?; writeln!( f, "│ .with_key_remapping(r\"field\\.(\\w+)\", \"field.BaseConv.$1\");" )?; writeln!(f, "│")?; } writeln!(f, "│ First 10 missing tensors:")?; for (path, _) in self.missing.iter().take(10) { writeln!(f, "│ • {}", path)?; // Show "Did you mean?" suggestions for this path let suggestions = self.find_similar_paths(path, 1); if !suggestions.is_empty() { writeln!(f, "│ Did you mean: '{}'?", suggestions[0])?; } } if self.missing.len() > 10 { writeln!(f, "│ ... and {} more", self.missing.len() - 10)?; } } if !self.unused.is_empty() { writeln!(f, "│")?; writeln!(f, "├─ Unused Tensors (in source but not used by model)")?; writeln!(f, "│")?; writeln!(f, "│ First 10 unused tensors:")?; for path in self.unused.iter().take(10) { writeln!(f, "│ • {}", path)?; } if self.unused.len() > 10 { writeln!(f, "│ ... and {} more", self.unused.len() - 10)?; } } if !self.errors.is_empty() { writeln!(f, "│")?; writeln!(f, "├─ Errors")?; writeln!(f, "│")?; for error in self.errors.iter().take(10) { writeln!(f, "│ ⚠️ {}", error)?; } if self.errors.len() > 10 { writeln!(f, "│ ... and {} more", self.errors.len() - 10)?; } } writeln!(f, "│")?; write!(f, "└───────────────────────────────────────────────────")?; Ok(()) } } ================================================ FILE: crates/burn-store/src/burnpack/base.rs ================================================ //! Core types and constants for the Burnpack file format. //! //! See the [parent module](crate::burnpack) for the complete file format specification. use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use burn_tensor::DType; use byteorder::{ByteOrder, LittleEndian}; use serde::{Deserialize, Serialize}; /// Magic number identifying a Burnpack file: "BURN" in ASCII (0x4255524E) /// When written to file in little-endian format, appears as "NRUB" bytes pub const MAGIC_NUMBER: u32 = 0x4255524E; /// Current format version pub const FORMAT_VERSION: u16 = 0x0001; /// Size of the magic number in bytes pub const MAGIC_SIZE: usize = 4; /// Size of the format version in bytes pub const VERSION_SIZE: usize = 2; /// Size of the metadata size field in bytes pub const METADATA_SIZE_FIELD_SIZE: usize = 4; /// Total header size (computed from components) pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE; /// Alignment for tensor data in bytes. /// /// All tensor data is aligned to 256-byte boundaries to enable efficient /// memory-mapped (mmap) zero-copy loading. This alignment ensures: /// - Proper pointer alignment for all tensor element types (f64 requires 8-byte alignment) /// - Cache-line friendly access (most CPUs use 64-byte cache lines) /// - GPU memory alignment (CUDA prefers 256-byte for coalesced access) /// - Future-proofing for wider SIMD (AVX-512 = 64 bytes, future AVX-1024 = 128 bytes) /// /// Industry alignment choices: /// - 256-byte: GGUF, MLX, ncnn, MNN, TNN, vLLM-AWQ, Marlin (15+ formats) /// - 64-byte: SafeTensors (minimum for AVX-512) /// - 4096-byte: Core ML /// /// 256-byte alignment has negligible overhead for typical tensor sizes while /// providing maximum compatibility with current and future hardware. pub const TENSOR_ALIGNMENT: u64 = 256; /// Calculate the byte offset where the tensor data section starts. /// /// The data section is padded to start at a 256-byte aligned position /// so that all tensor offsets (which are relative to data section) result /// in properly aligned absolute file positions for mmap zero-copy access. /// /// This function must be used consistently by both writer and reader. #[inline] pub fn aligned_data_section_start(metadata_size: usize) -> usize { let unaligned_start = (HEADER_SIZE + metadata_size) as u64; // Keep multiplication in u64 space to avoid overflow on 32-bit systems (unaligned_start.div_ceil(TENSOR_ALIGNMENT) * TENSOR_ALIGNMENT) as usize } // Security limits to prevent DoS attacks via resource exhaustion // These can be adjusted based on your use case /// Maximum allowed metadata size (100 MB) /// Prevents memory exhaustion attacks via oversized metadata claims pub const MAX_METADATA_SIZE: u32 = 100 * 1024 * 1024; /// Maximum allowed tensor size per tensor /// Prevents memory exhaustion attacks via oversized tensor claims /// 32-bit platforms: 2 GB limit (to fit within usize range) /// 64-bit platforms: 10 GB limit #[cfg(target_pointer_width = "32")] pub const MAX_TENSOR_SIZE: usize = 2 * 1024 * 1024 * 1024; #[cfg(not(target_pointer_width = "32"))] pub const MAX_TENSOR_SIZE: usize = 10 * 1024 * 1024 * 1024; /// Maximum allowed number of tensors (100,000) /// Prevents resource exhaustion via excessive tensor counts pub const MAX_TENSOR_COUNT: usize = 100_000; /// Maximum CBOR deserialization recursion depth (128 levels) /// Prevents stack overflow attacks via deeply nested CBOR structures pub const MAX_CBOR_RECURSION_DEPTH: usize = 128; /// Maximum allowed file size (100 GB) /// Prevents resource exhaustion from extremely large files /// This limit applies to file-based loading (mmap and buffered) #[cfg(feature = "std")] pub const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024 * 1024; /// Byte range for magic number in header pub const fn magic_range() -> core::ops::Range { let start = 0; let end = start + MAGIC_SIZE; start..end } /// Byte range for format version in header pub const fn version_range() -> core::ops::Range { let start = MAGIC_SIZE; let end = start + VERSION_SIZE; start..end } /// Byte range for metadata size field in header pub const fn metadata_size_range() -> core::ops::Range { let start = MAGIC_SIZE + VERSION_SIZE; let end = start + METADATA_SIZE_FIELD_SIZE; start..end } // Compile-time validation that ranges are correct const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE); /// Header structure for Burnpack files #[derive(Debug, Clone, Copy)] pub struct BurnpackHeader { /// Magic number (4 bytes): 0x4255524E ("BURN") pub magic: u32, /// Format version (2 bytes) pub version: u16, /// Size of CBOR metadata in bytes (4 bytes) pub metadata_size: u32, } impl BurnpackHeader { /// Create a new header with the given metadata size #[allow(dead_code)] pub fn new(metadata_size: u32) -> Self { Self { magic: MAGIC_NUMBER, version: FORMAT_VERSION, metadata_size, } } /// Serialize header into bytes pub fn into_bytes(self) -> [u8; HEADER_SIZE] { let mut bytes = [0u8; HEADER_SIZE]; LittleEndian::write_u32(&mut bytes[magic_range()], self.magic); LittleEndian::write_u16(&mut bytes[version_range()], self.version); LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size); bytes } /// Deserialize header from bytes pub fn from_bytes(bytes: &[u8]) -> Result { if bytes.len() < HEADER_SIZE { return Err(BurnpackError::InvalidHeader); } let magic = LittleEndian::read_u32(&bytes[magic_range()]); if magic != MAGIC_NUMBER { return Err(BurnpackError::InvalidMagicNumber); } let version = LittleEndian::read_u16(&bytes[version_range()]); let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]); Ok(Self { magic, version, metadata_size, }) } } /// Metadata structure serialized with CBOR #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BurnpackMetadata { /// Tensor descriptors mapped by name for efficient lookup pub tensors: BTreeMap, /// Optional additional metadata #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] pub metadata: BTreeMap, } /// Individual tensor descriptor #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TensorDescriptor { /// Data type of the tensor pub dtype: DType, /// Tensor shape dimensions pub shape: Vec, /// Byte offsets in data section (start, end) pub data_offsets: (u64, u64), /// Parameter ID for training state persistence matching. /// Generated automatically if not present during loading. #[serde(default, skip_serializing_if = "Option::is_none")] pub param_id: Option, } /// Error types for Burnpack operations #[derive(Debug)] pub enum BurnpackError { InvalidHeader, InvalidMagicNumber, InvalidVersion, MetadataSerializationError(String), MetadataDeserializationError(String), IoError(String), TensorNotFound(String), TensorBytesSizeMismatch(String), ValidationError(String), } impl core::fmt::Display for BurnpackError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { BurnpackError::InvalidHeader => write!(f, "Invalid header: insufficient bytes"), BurnpackError::InvalidMagicNumber => write!(f, "Invalid magic number"), BurnpackError::InvalidVersion => write!(f, "Unsupported version"), BurnpackError::MetadataSerializationError(e) => { write!(f, "Metadata serialization error: {}", e) } BurnpackError::MetadataDeserializationError(e) => { write!(f, "Metadata deserialization error: {}", e) } BurnpackError::IoError(e) => write!(f, "I/O error: {}", e), BurnpackError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name), BurnpackError::TensorBytesSizeMismatch(e) => { write!(f, "Tensor bytes size mismatch: {}", e) } BurnpackError::ValidationError(e) => write!(f, "Validation error: {}", e), } } } impl core::error::Error for BurnpackError {} ================================================ FILE: crates/burn-store/src/burnpack/mod.rs ================================================ //! # Burnpack - Native Burn Model Storage Format //! //! Burnpack is the native binary storage format for Burn models, designed for efficient //! serialization, fast loading, and cross-platform compatibility. //! //! ## Key Features //! //! - **CBOR Metadata**: Structured metadata with efficient binary encoding //! - **Memory-Mapped Loading**: Zero-copy loading for optimal performance //! - **256-byte Tensor Alignment**: Enables efficient mmap zero-copy access //! - **No-std Support**: Works in embedded and WASM environments //! - **ParamId Persistence**: Preserves parameter identities for stateful training //! - **Lazy Tensor Loading**: Deferred data materialization for efficient memory usage //! //! ## File Format Structure //! //! ```text //! ┌──────────────────────────────────┐ //! │ Header (10 bytes) │ //! ├──────────────────────────────────┤ //! │ - Magic number (4 bytes) │ 0x4E525542 ("NRUB" in LE) //! │ - Version (2 bytes) │ Format version (0x0001) //! │ - Metadata size (4 bytes) │ Size of CBOR metadata (u32) //! ├──────────────────────────────────┤ //! │ Metadata (CBOR) │ //! ├──────────────────────────────────┤ //! │ - Tensor descriptors (BTreeMap) │ Order-preserving map of tensor metadata //! │ Key: tensor name (string) │ e.g., "model.layer1.weight" //! │ Value: TensorDescriptor │ //! │ - dtype: DType │ Data type (F32, F64, I32, etc.) //! │ - shape: Vec │ Tensor dimensions //! │ - data_offsets: (u64, u64) │ (start, end) byte offsets (256-byte aligned) //! │ - param_id: Option │ Parameter ID (for training state) //! │ - Additional metadata(BTreeMap) │ User-defined key-value pairs //! ├──────────────────────────────────┤ //! │ Tensor Data Section │ //! ├──────────────────────────────────┤ //! │ [padding][tensor1][padding]... │ Each tensor aligned to 256-byte boundary //! │ Raw tensor bytes (little-endian)│ Enables mmap zero-copy loading //! └──────────────────────────────────┘ //! ``` //! //! ## Tensor Alignment //! //! All tensor data is aligned to 256-byte boundaries to enable efficient memory-mapped //! (mmap) zero-copy loading. This alignment ensures: //! //! - Proper pointer alignment for all tensor element types (f64 requires 8 bytes) //! - Cache-line friendly access (most CPUs use 64-byte cache lines) //! - GPU memory alignment (CUDA prefers 256-byte for coalesced access) //! - Future-proofing for wider SIMD instructions (AVX-512, future AVX-1024) //! //! The 256-byte alignment matches industry standards used by GGUF, MLX, ncnn, MNN, //! and other major model formats. pub mod base; pub mod reader; pub mod store; pub mod writer; #[cfg(test)] mod tests; ================================================ FILE: crates/burn-store/src/burnpack/reader.rs ================================================ #[cfg(feature = "std")] use super::base::MAX_FILE_SIZE; use super::base::{ BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE, aligned_data_section_start, }; use crate::TensorSnapshot; use alloc::format; use alloc::rc::Rc; use alloc::string::ToString; use alloc::vec; use alloc::vec::Vec; use burn_core::module::ParamId; use burn_tensor::{Bytes, Shape, TensorData}; #[cfg(feature = "std")] use std::cell::RefCell; #[cfg(feature = "std")] use std::fs::File; #[cfg(feature = "std")] use std::io::{Read, Seek}; #[cfg(feature = "std")] use std::path::Path; /// Storage backend for BurnpackReader pub(crate) enum StorageBackend { /// Memory-based storage (also used for memory-mapped files converted to bytes::Bytes) Memory(Rc), /// File-based storage with buffered reading #[cfg(feature = "std")] #[allow(dead_code)] FileBuffered { file: Rc> }, } impl StorageBackend { /// Read data from storage into the provided buffer at the given offset. /// /// # Arguments /// * `bytes` - The buffer to read into (caller-allocated) /// * `offset` - Absolute file/data position to start reading from /// /// # Errors /// /// Returns an error if: /// - The requested data range is out of bounds /// - Less data is available than requested (indicates corruption or incorrect offset) /// - File I/O fails /// /// # Notes /// /// The caller allocates the buffer, which allows for buffer reuse and future optimizations /// like memory pools and pinned memory. /// /// This method ensures all backends have consistent behavior: if the exact number of /// requested bytes cannot be read, an error is returned to prevent data corruption. pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> { match self { StorageBackend::Memory(data) => { let data_bytes = data.as_ref(); let end = offset.checked_add(bytes.len()).ok_or_else(|| { BurnpackError::IoError(format!( "Offset overflow: offset {} + length {} exceeds maximum", offset, bytes.len() )) })?; if end > data_bytes.len() { return Err(BurnpackError::IoError(format!( "Read out of bounds: requested {}..{} but data length is {}", offset, end, data_bytes.len() ))); } bytes.copy_from_slice(&data_bytes[offset..end]); Ok(()) } #[cfg(feature = "std")] StorageBackend::FileBuffered { file } => { use std::io::SeekFrom; let mut file = file.borrow_mut(); file.seek(SeekFrom::Start(offset as u64)).map_err(|e| { BurnpackError::IoError(format!("Failed to seek in file: {}", e)) })?; file.read_exact(bytes).map_err(|e| { BurnpackError::IoError(format!("Failed to read from file: {}", e)) })?; Ok(()) } } } /// Get full data reference for raw access #[allow(dead_code)] pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> { match self { StorageBackend::Memory(data) => Ok(data.as_ref()), #[cfg(feature = "std")] StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError( "Cannot get full bytes reference for FileBuffered backend".into(), )), } } /// Attempt to slice bytes without copying (zero-copy). /// /// This uses `Bytes::clone()` + `split()` which is zero-copy when the underlying /// `Bytes` was created via `Bytes::from_shared()` (backed by `bytes::Bytes`). /// /// # Returns /// - `Ok(bytes)` - Successfully created a zero-copy slice /// - `Err(_)` - Backend doesn't support zero-copy or split failed pub(crate) fn slice_bytes(&self, start: usize, end: usize) -> Result { if end < start { return Err(BurnpackError::IoError(format!( "Invalid slice range: end ({}) < start ({})", end, start ))); } match self { StorageBackend::Memory(data) => { // Clone the Bytes - cheap if backed by SharedBytesAllocationController let cloned = (**data).clone(); // Split at start offset to get (_, right) let (_, right) = cloned.split(start).map_err(|(_, e)| { BurnpackError::IoError(format!("Failed to split at start {}: {:?}", start, e)) })?; // Split right at (end - start) to get (middle, _) let slice_len = end - start; let (middle, _) = right.split(slice_len).map_err(|(_, e)| { BurnpackError::IoError(format!( "Failed to split at length {}: {:?}", slice_len, e )) })?; Ok(middle) } #[cfg(feature = "std")] StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError( "Zero-copy not supported for buffered file reading. Use from_file() with memmap feature for zero-copy loading.".into(), )), } } } /// Reader for loading Burnpack files pub struct BurnpackReader { /// Parsed metadata pub(crate) metadata: BurnpackMetadata, /// Storage backend pub(crate) storage: StorageBackend, /// Offset to the start of tensor data pub(crate) data_offset: usize, } impl BurnpackReader { /// Load from bytes pub fn from_bytes(bytes: Bytes) -> Result { // Validate minimum size if bytes.len() < HEADER_SIZE { return Err(BurnpackError::InvalidHeader); } // Parse header let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?; // Verify magic number if header.magic != MAGIC_NUMBER { return Err(BurnpackError::InvalidMagicNumber); } // Verify version compatibility if header.version > FORMAT_VERSION { return Err(BurnpackError::InvalidVersion); } // Validate metadata size against security limit if header.metadata_size > MAX_METADATA_SIZE { return Err(BurnpackError::ValidationError(format!( "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", header.metadata_size, MAX_METADATA_SIZE ))); } // Parse metadata let metadata_start = HEADER_SIZE; let metadata_end = metadata_start .checked_add(header.metadata_size as usize) .ok_or_else(|| { BurnpackError::IoError(format!( "Metadata size overflow: {} + {}", metadata_start, header.metadata_size )) })?; if bytes.len() < metadata_end { return Err(BurnpackError::InvalidHeader); } let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( &bytes[metadata_start..metadata_end], MAX_CBOR_RECURSION_DEPTH, ) .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; // Validate tensor count against security limit if metadata.tensors.len() > MAX_TENSOR_COUNT { return Err(BurnpackError::ValidationError(format!( "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", metadata.tensors.len(), MAX_TENSOR_COUNT ))); } // Validate total file size - ensure file is large enough for all claimed tensor data if !metadata.tensors.is_empty() { let max_data_offset = metadata .tensors .values() .map(|t| t.data_offsets.1) .max() .unwrap_or(0); let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Data offset {} exceeds platform maximum", max_data_offset )) })?; let min_file_size = metadata_end .checked_add(max_data_offset_usize) .ok_or_else(|| { BurnpackError::ValidationError("File size calculation overflow".into()) })?; if bytes.len() < min_file_size { return Err(BurnpackError::ValidationError(format!( "File truncated: expected at least {} bytes, got {} bytes", min_file_size, bytes.len() ))); } } Ok(Self { metadata, storage: StorageBackend::Memory(Rc::new(bytes)), data_offset: aligned_data_section_start(header.metadata_size as usize), }) } /// Load from file with memory mapping (most efficient for large files) #[cfg(all(feature = "std", feature = "memmap"))] pub(crate) fn from_file_mmap>(path: P) -> Result { let file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?; // Validate maximum file size to prevent resource exhaustion let file_size = file .metadata() .map_err(|e| BurnpackError::IoError(e.to_string()))? .len(); if file_size > MAX_FILE_SIZE { return Err(BurnpackError::ValidationError(format!( "File size {} bytes exceeds maximum allowed size of {} bytes", file_size, MAX_FILE_SIZE ))); } // Memory map the file let mmap = unsafe { memmap2::MmapOptions::new() .map(&file) .map_err(|e| BurnpackError::IoError(e.to_string()))? }; // Parse header if mmap.len() < HEADER_SIZE { return Err(BurnpackError::InvalidHeader); } let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?; // Verify magic number and version if header.magic != MAGIC_NUMBER { return Err(BurnpackError::InvalidMagicNumber); } if header.version > FORMAT_VERSION { return Err(BurnpackError::InvalidVersion); } // Validate metadata size against security limit if header.metadata_size > MAX_METADATA_SIZE { return Err(BurnpackError::ValidationError(format!( "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", header.metadata_size, MAX_METADATA_SIZE ))); } // Parse metadata let metadata_start = HEADER_SIZE; let metadata_end = metadata_start .checked_add(header.metadata_size as usize) .ok_or_else(|| { BurnpackError::IoError(format!( "Metadata size overflow: {} + {}", metadata_start, header.metadata_size )) })?; if mmap.len() < metadata_end { return Err(BurnpackError::InvalidHeader); } let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( &mmap[metadata_start..metadata_end], MAX_CBOR_RECURSION_DEPTH, ) .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; // Validate tensor count against security limit if metadata.tensors.len() > MAX_TENSOR_COUNT { return Err(BurnpackError::ValidationError(format!( "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", metadata.tensors.len(), MAX_TENSOR_COUNT ))); } // Validate total file size - ensure file is large enough for all claimed tensor data if !metadata.tensors.is_empty() { let max_data_offset = metadata .tensors .values() .map(|t| t.data_offsets.1) .max() .unwrap_or(0); let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Data offset {} exceeds platform maximum", max_data_offset )) })?; let min_file_size = metadata_end .checked_add(max_data_offset_usize) .ok_or_else(|| { BurnpackError::ValidationError("File size calculation overflow".into()) })?; if mmap.len() < min_file_size { return Err(BurnpackError::ValidationError(format!( "File truncated: expected at least {} bytes, got {} bytes", min_file_size, mmap.len() ))); } } // Convert mmap to bytes::Bytes for zero-copy slicing support // bytes::Bytes::from_owner takes ownership and enables efficient slicing let shared_bytes = bytes::Bytes::from_owner(mmap); let bytes = Bytes::from_shared(shared_bytes, burn_tensor::AllocationProperty::File); Ok(Self { metadata, storage: StorageBackend::Memory(Rc::new(bytes)), data_offset: aligned_data_section_start(header.metadata_size as usize), }) } /// Load from file - automatically uses memory mapping if available, otherwise uses buffered reading #[cfg(feature = "std")] pub fn from_file>(path: P) -> Result { #[cfg(feature = "memmap")] { // Use memory mapping for efficient access Self::from_file_mmap(path) } #[cfg(not(feature = "memmap"))] { // Fall back to buffered reading for memory efficiency Self::from_file_buffered(path) } } /// Load from file with buffered reading (memory efficient but slower) /// This is less efficient than memory mapping but works everywhere #[cfg(feature = "std")] #[allow(dead_code)] pub(crate) fn from_file_buffered>(path: P) -> Result { let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?; // Validate maximum file size to prevent resource exhaustion let file_size = file .metadata() .map_err(|e| BurnpackError::IoError(e.to_string()))? .len(); if file_size > MAX_FILE_SIZE { return Err(BurnpackError::ValidationError(format!( "File size {} bytes exceeds maximum allowed size of {} bytes", file_size, MAX_FILE_SIZE ))); } // Read header let mut header_bytes = [0u8; HEADER_SIZE]; file.read_exact(&mut header_bytes) .map_err(|e| BurnpackError::IoError(e.to_string()))?; let header = BurnpackHeader::from_bytes(&header_bytes)?; // Verify version if header.version > FORMAT_VERSION { return Err(BurnpackError::InvalidVersion); } // Validate metadata size against security limit if header.metadata_size > MAX_METADATA_SIZE { return Err(BurnpackError::ValidationError(format!( "Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", header.metadata_size, MAX_METADATA_SIZE ))); } // Read metadata let mut metadata_bytes = vec![0u8; header.metadata_size as usize]; file.read_exact(&mut metadata_bytes) .map_err(|e| BurnpackError::IoError(e.to_string()))?; let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit( metadata_bytes.as_slice(), MAX_CBOR_RECURSION_DEPTH, ) .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?; // Validate tensor count against security limit if metadata.tensors.len() > MAX_TENSOR_COUNT { return Err(BurnpackError::ValidationError(format!( "File contains {} tensors, exceeding maximum of {} (potential DoS attack)", metadata.tensors.len(), MAX_TENSOR_COUNT ))); } // Calculate metadata end offset let metadata_end = HEADER_SIZE .checked_add(header.metadata_size as usize) .ok_or_else(|| { BurnpackError::IoError(format!( "Metadata size overflow: {} + {}", HEADER_SIZE, header.metadata_size )) })?; // Validate total file size - ensure file is large enough for all claimed tensor data if !metadata.tensors.is_empty() { let max_data_offset = metadata .tensors .values() .map(|t| t.data_offsets.1) .max() .unwrap_or(0); let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Data offset {} exceeds platform maximum", max_data_offset )) })?; let min_file_size = metadata_end .checked_add(max_data_offset_usize) .ok_or_else(|| { BurnpackError::ValidationError("File size calculation overflow".into()) })?; // Get actual file size let file_size = file .metadata() .map_err(|e| BurnpackError::IoError(e.to_string()))? .len() as usize; if file_size < min_file_size { return Err(BurnpackError::ValidationError(format!( "File truncated: expected at least {} bytes, got {} bytes", min_file_size, file_size ))); } } Ok(Self { metadata, storage: StorageBackend::FileBuffered { file: Rc::new(RefCell::new(file)), }, data_offset: aligned_data_section_start(header.metadata_size as usize), }) } /// Get all tensor snapshots at once for efficient loading (always copies data) pub fn get_snapshots(&self) -> Result, BurnpackError> { self.get_snapshots_internal(false) } /// Get all tensor snapshots with optional zero-copy loading. /// /// When `zero_copy` is true and the backend supports it (Memory backend with /// `Bytes::from_shared()`), tensor data is sliced without copying. This keeps /// the original data alive as long as any tensor holds a reference. /// /// When `zero_copy` is false or the backend doesn't support it, data is copied /// into newly allocated buffers (default behavior). pub fn get_snapshots_zero_copy( &self, zero_copy: bool, ) -> Result, BurnpackError> { self.get_snapshots_internal(zero_copy) } /// Internal implementation with optional zero-copy support fn get_snapshots_internal( &self, zero_copy: bool, ) -> Result, BurnpackError> { let mut snapshots = Vec::new(); for (name, descriptor) in &self.metadata.tensors { // Clone metadata for use in closure // Convert shape dimensions with overflow checking let shape: Shape = Shape::from(descriptor .shape .iter() .map(|&s| { s.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum", name, s )) }) }) .collect::, BurnpackError>>()?); let dtype = descriptor.dtype; // Clone storage reference for the closure let storage = match &self.storage { StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()), #[cfg(feature = "std")] StorageBackend::FileBuffered { file } => { StorageBackend::FileBuffered { file: file.clone() } } }; // Always use absolute positions for all backends // Convert offsets with overflow checking let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum", name, descriptor.data_offsets.0 )) })?; let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| { BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum", name, descriptor.data_offsets.1 )) })?; let start = self.data_offset.checked_add(offset_start).ok_or_else(|| { BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted offset data: start offset overflow {} + {}", name, self.data_offset, offset_start )) })?; let end = self.data_offset.checked_add(offset_end).ok_or_else(|| { BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted offset data: end offset overflow {} + {}", name, self.data_offset, offset_end )) })?; // Clone shape for the closure (TensorSnapshot::from_closure will also need it) let shape_for_closure = shape.clone(); // Validate offset range if end < start { return Err(BurnpackError::ValidationError(format!( "Tensor '{}' has corrupted offset data: end offset {} < start offset {}", name, end, start ))); } // Validate tensor size against security limit let tensor_size = end - start; if tensor_size > MAX_TENSOR_SIZE { return Err(BurnpackError::ValidationError(format!( "Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)", name, tensor_size, MAX_TENSOR_SIZE ))); } // Restore param_id if it was saved, otherwise generate let tensor_id = descriptor .param_id .map(ParamId::from) .unwrap_or_else(ParamId::new); // Create the data-loading closure based on zero_copy flag let data_fn: Rc Result> = if zero_copy { // Zero-copy closure: slice without copying, error if not supported Rc::new(move || { let bytes = storage.slice_bytes(start, end).map_err(|e| { crate::TensorSnapshotError::IoError(format!( "Zero-copy slice failed: {}", e )) })?; Ok(TensorData::from_bytes( bytes, shape_for_closure.clone(), dtype, )) }) } else { // Copying closure: always allocate and copy Rc::new(move || { let len = end - start; // TODO Should be allocated by the backend in the future // See https://github.com/tracel-ai/burn/pull/3792#discussion_r2416812091 let mut data_bytes = vec![0u8; len]; storage.read_into(&mut data_bytes, start).map_err(|e| { crate::TensorSnapshotError::IoError(format!( "Failed to read tensor data: {}", e )) })?; Ok(TensorData::from_bytes_vec( data_bytes, shape_for_closure.clone(), dtype, )) }) }; // Create lazy TensorSnapshot let snapshot = TensorSnapshot::from_closure( data_fn, dtype, shape, name.split('.').map(|s| s.to_string()).collect(), vec![], // empty container_stack tensor_id, // restored or newly generated param id ); snapshots.push(snapshot); } Ok(snapshots) } // Legacy methods for test compatibility - will be removed /// Get tensor as TensorSnapshot with lazy loading #[allow(dead_code)] pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result { let snapshots = self.get_snapshots()?; snapshots .into_iter() .find(|s| s.full_path() == name) .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string())) } /// Get list of tensor names #[allow(dead_code)] pub(crate) fn tensor_names(&self) -> Vec<&str> { self.metadata .tensors .keys() .map(|name| name.as_str()) .collect() } /// Get metadata #[allow(dead_code)] pub(crate) fn metadata(&self) -> &BurnpackMetadata { &self.metadata } /// Get tensor data as raw bytes #[allow(dead_code)] pub(crate) fn get_tensor_data(&self, name: &str) -> Result, BurnpackError> { let descriptor = self .metadata .tensors .get(name) .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?; // Always use absolute positions for all backends // Convert offsets with overflow checking let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| { BurnpackError::IoError(format!( "Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum", name, descriptor.data_offsets.0 )) })?; let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| { BurnpackError::IoError(format!( "Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum", name, descriptor.data_offsets.1 )) })?; let start = self.data_offset.checked_add(offset_start).ok_or_else(|| { BurnpackError::IoError(format!( "Tensor '{}' has corrupted offset data: start offset overflow {} + {}", name, self.data_offset, offset_start )) })?; let end = self.data_offset.checked_add(offset_end).ok_or_else(|| { BurnpackError::IoError(format!( "Tensor '{}' has corrupted offset data: end offset overflow {} + {}", name, self.data_offset, offset_end )) })?; // Validate offset range if end < start { return Err(BurnpackError::IoError(format!( "Tensor '{}' has corrupted offset data: end offset {} < start offset {}", name, end, start ))); } let len = end - start; let mut buffer = vec![0u8; len]; self.storage.read_into(&mut buffer, start)?; Ok(buffer) } } ================================================ FILE: crates/burn-store/src/burnpack/store.rs ================================================ #[cfg(feature = "std")] use std::path::PathBuf; use super::reader::BurnpackReader; use super::writer::BurnpackWriter; #[cfg(feature = "std")] use crate::KeyRemapper; use crate::burnpack::base::BurnpackError; use crate::{ IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot, }; use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::String; use alloc::vec::Vec; use burn_core::prelude::Backend; use burn_tensor::Bytes; /// Store mode for BurnpackStore enum StoreMode { #[cfg(feature = "std")] File(PathBuf), Bytes(Option), } /// BurnpackStore - A Burn-specific file format store using CBOR for metadata pub struct BurnpackStore { /// Store mode - either file path or bytes mode: StoreMode, /// Optional filter for selective loading/saving filter: Option, /// Additional metadata metadata: BTreeMap, /// Allow partial loading (ignore missing tensors) allow_partial: bool, /// Validate tensors during loading (check shapes and dtypes) validate: bool, /// Allow overwriting existing files (default: false) overwrite: bool, /// Enable zero-copy tensor loading (default: false) /// /// When enabled and the backend supports it, tensor data is sliced from /// the source without copying. This requires keeping the source data alive. zero_copy: bool, /// Automatically append .bpk extension if not present (default: true) #[cfg(feature = "std")] auto_extension: bool, /// Key remapper for tensor name transformations #[cfg(feature = "std")] remapper: KeyRemapper, /// Adapter applied when loading (source -> Burn) from_adapter: Box, /// Adapter applied when saving (Burn -> target) to_adapter: Box, /// Writer for saving writer: Option, /// Reader for loading reader: Option, /// Cached tensor snapshots (parsed once, reused) snapshots_cache: Option>, } impl BurnpackStore { /// Get the default metadata that includes Burn framework information. /// /// This includes: /// - `format`: "burnpack" /// - `producer`: "burn" /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION) /// /// These metadata fields are automatically added to all saved models. pub fn default_metadata() -> BTreeMap { let mut metadata = BTreeMap::new(); metadata.insert("format".into(), "burnpack".into()); metadata.insert("producer".into(), "burn".into()); metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into()); metadata } /// Create a new store from a file path /// /// By default, automatically appends `.bpk` extension if the path doesn't have one. /// Use `.auto_extension(false)` to disable this behavior. /// /// # Examples /// /// ```no_run /// # use burn_store::BurnpackStore; /// // Automatically appends .bpk /// let store = BurnpackStore::from_file("model"); // creates "model.bpk" /// /// // Already has extension, no append /// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk" /// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext" /// /// // Disable auto-extension /// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model" /// ``` #[cfg(feature = "std")] pub fn from_file>(path: P) -> Self { Self { mode: StoreMode::File(path.as_ref().to_path_buf()), filter: None, metadata: Self::default_metadata(), allow_partial: false, validate: true, overwrite: false, zero_copy: false, #[cfg(feature = "std")] auto_extension: true, #[cfg(feature = "std")] remapper: KeyRemapper::new(), from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), writer: None, reader: None, snapshots_cache: None, } } /// Create a new store from bytes (for reading) or empty (for writing) pub fn from_bytes(bytes: Option) -> Self { Self { mode: StoreMode::Bytes(bytes), filter: None, metadata: Self::default_metadata(), allow_partial: false, validate: true, overwrite: false, zero_copy: false, #[cfg(feature = "std")] auto_extension: false, // Not used for bytes mode #[cfg(feature = "std")] remapper: KeyRemapper::new(), from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), writer: None, reader: None, snapshots_cache: None, } } /// Create a new store from static bytes with zero-copy loading enabled. /// /// This is optimized for embedded model weights where the data lives in the /// binary's `.rodata` section. Tensor data is sliced without copying, keeping /// the static reference alive. /// /// # Example /// /// ```ignore /// static MODEL_DATA: &[u8] = include_bytes!("model.bpk"); /// let store = BurnpackStore::from_static(MODEL_DATA); /// ``` pub fn from_static(data: &'static [u8]) -> Self { use burn_tensor::AllocationProperty; // Create bytes::Bytes from static data (zero-copy, stays in .rodata) let shared = bytes::Bytes::from_static(data); // Wrap in cubecl Bytes with shared-bytes allocation controller let bytes = Bytes::from_shared(shared, AllocationProperty::Other); Self { mode: StoreMode::Bytes(Some(bytes)), filter: None, metadata: Self::default_metadata(), allow_partial: false, validate: true, overwrite: false, zero_copy: true, // Enable zero-copy by default for static data #[cfg(feature = "std")] auto_extension: false, #[cfg(feature = "std")] remapper: KeyRemapper::new(), from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), writer: None, reader: None, snapshots_cache: None, } } /// Add metadata key-value pair pub fn metadata(mut self, key: impl Into, value: impl Into) -> Self { self.metadata.insert(key.into(), value.into()); self } /// Clear all metadata (including defaults) /// /// This removes all metadata including the default format, producer, and version fields. /// Use with caution as some tools may expect these fields to be present. pub fn clear_metadata(mut self) -> Self { self.metadata.clear(); self } /// Allow partial loading (ignore missing tensors) /// /// When set to `true`, the store will not fail if some tensors are missing /// during loading. This is useful when loading a subset of a model's parameters. /// /// Default: `false` pub fn allow_partial(mut self, allow: bool) -> Self { self.allow_partial = allow; self } /// Enable or disable validation during loading /// /// When validation is enabled, the store will check that loaded tensors /// match the expected shapes and data types. Disabling validation can /// improve performance but may lead to runtime errors if data is corrupted. /// /// Default: `true` pub fn validate(mut self, validate: bool) -> Self { self.validate = validate; self } /// Allow overwriting existing files when saving /// /// When set to `false`, attempting to save to an existing file will result in an error. /// When set to `true`, existing files will be overwritten without warning. /// /// Default: `false` pub fn overwrite(mut self, overwrite: bool) -> Self { self.overwrite = overwrite; self } /// Enable or disable zero-copy tensor loading. /// /// When enabled and the backend supports it (memory-backed with shared bytes), /// tensor data is sliced from the source without copying. This keeps the source /// data alive as long as any tensor holds a reference. /// /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static). /// Use this method to enable it for other memory-backed stores created with /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`. /// /// Default: `false` (except for `from_static` which defaults to `true`) pub fn zero_copy(mut self, enable: bool) -> Self { self.zero_copy = enable; self } /// Enable or disable automatic .bpk extension appending /// /// When enabled (default), automatically appends `.bpk` to the file path /// if no extension is detected. If an extension is already present, it is preserved. /// /// When disabled, uses the exact path provided without modification. /// /// Default: `true` /// /// # Examples /// /// ```no_run /// # use burn_store::BurnpackStore; /// // With auto_extension enabled (default) /// let store = BurnpackStore::from_file("model"); // -> "model.bpk" /// /// // With auto_extension disabled /// let store = BurnpackStore::from_file("model") /// .auto_extension(false); // -> "model" /// ``` #[cfg(feature = "std")] pub fn auto_extension(mut self, enable: bool) -> Self { self.auto_extension = enable; self } /// Set the adapter for loading tensors (converting from source format to Burn). pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self { self.from_adapter = Box::new(adapter); self } /// Set the adapter for saving tensors (converting from Burn to target format). pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self { self.to_adapter = Box::new(adapter); self } /// Set path filter for selective loading/saving pub fn with_filter(mut self, filter: PathFilter) -> Self { self.filter = Some(filter); self } /// Add regex pattern to filter #[cfg(feature = "std")] pub fn with_regex(mut self, pattern: &str) -> Self { let filter = self.filter.unwrap_or_default(); self.filter = Some(filter.with_regex(pattern)); self } /// Add exact path to filter pub fn with_full_path(mut self, path: impl Into) -> Self { let filter = self.filter.unwrap_or_default(); self.filter = Some(filter.with_full_path(path)); self } /// Match all tensors (no filtering) pub fn match_all(mut self) -> Self { self.filter = Some(PathFilter::new().match_all()); self } /// Set key remapper for tensor name transformations during loading #[cfg(feature = "std")] pub fn remap(mut self, remapper: KeyRemapper) -> Self { self.remapper = remapper; self } /// Add a single regex pattern for key remapping #[cfg(feature = "std")] pub fn with_remap_pattern(mut self, from: S1, to: S2) -> Self where S1: AsRef, S2: Into, { self.remapper = self .remapper .add_pattern(from.as_ref(), to.into()) .expect("Invalid regex pattern"); self } /// Set the path filter pub fn filter(mut self, filter: PathFilter) -> Self { self.filter = Some(filter); self } /// Get the bytes after writing (only valid for bytes mode after collecting) pub fn get_bytes(&self) -> Result { if let Some(writer) = &self.writer { return writer.to_bytes(); } match &self.mode { StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()), _ => Err(BurnpackError::IoError("No bytes available".into())), } } /// Process the file path with auto-extension logic #[cfg(feature = "std")] fn process_path(&self, path: &std::path::Path) -> PathBuf { if !self.auto_extension { return path.to_path_buf(); } // Check if path already has an extension if path.extension().is_some() { // Has extension, use as-is return path.to_path_buf(); } // No extension, append .bpk let mut new_path = path.to_path_buf(); new_path.set_extension("bpk"); new_path } /// Ensure the reader is initialized, loading from storage if needed fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> { if self.reader.is_none() { let reader = match &self.mode { #[cfg(feature = "std")] StoreMode::File(path) => { let final_path = self.process_path(path); BurnpackReader::from_file(&final_path)? } StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?, StoreMode::Bytes(None) => { return Err(BurnpackError::IoError("No bytes to read from".into())); } }; self.reader = Some(reader); } self.reader .as_ref() .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into())) } } impl ModuleStore for BurnpackStore { type Error = BurnpackError; fn collect_from>( &mut self, module: &M, ) -> Result<(), Self::Error> { // Invalidate cache since we're writing new data self.snapshots_cache = None; self.reader = None; // Collect snapshots from module with adapter let snapshots = module.collect(self.filter.clone(), Some(self.to_adapter.clone()), false); // Initialize writer with snapshots let mut writer = BurnpackWriter::new(snapshots); // Add metadata using builder pattern for (key, value) in &self.metadata { writer = writer.with_metadata(key.as_str(), value.as_str()); } // Store the writer for finalization self.writer = Some(writer); // Write to storage based on mode if let Some(writer) = &self.writer { match &self.mode { #[cfg(feature = "std")] StoreMode::File(path) => { // Process path with auto-extension logic let final_path = self.process_path(path); // Check if file exists and overwrite is disabled if final_path.exists() && !self.overwrite { return Err(BurnpackError::IoError(format!( "File already exists: {}. Use .overwrite(true) to overwrite.", final_path.display() ))); } writer.write_to_file(&final_path)?; } StoreMode::Bytes(_) => { // Generate and store the bytes let bytes_data = writer.to_bytes()?; // Update mode with bytes - this pattern is irrefutable in no-std mode #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))] let StoreMode::Bytes(bytes_ref) = &mut self.mode else { unreachable!("We just matched Bytes variant"); }; *bytes_ref = Some(bytes_data); } } } Ok(()) } fn apply_to>( &mut self, module: &mut M, ) -> Result { // Get all snapshots using the cached method let snapshots: Vec = self.get_all_snapshots()?.values().cloned().collect(); // Apply all snapshots at once to the module // Burnpack is Burn's native format, so no enum variant skipping needed // Filter is applied here during apply, not during cache population let result = module.apply( snapshots, self.filter.clone(), Some(self.from_adapter.clone()), false, ); // Validate if needed if self.validate && !result.errors.is_empty() { return Err(BurnpackError::ValidationError(format!( "Import errors: {:?}", result.errors ))); } // Check for missing tensors if partial loading is not allowed if !self.allow_partial && !result.missing.is_empty() { return Err(BurnpackError::ValidationError(format!( "Missing tensors: {:?}", result.missing ))); } Ok(result) } fn get_snapshot(&mut self, name: &str) -> Result, Self::Error> { // Ensure cache is populated self.ensure_snapshots_cache()?; Ok(self.snapshots_cache.as_ref().unwrap().get(name)) } fn get_all_snapshots(&mut self) -> Result<&BTreeMap, Self::Error> { // Ensure cache is populated self.ensure_snapshots_cache()?; Ok(self.snapshots_cache.as_ref().unwrap()) } fn keys(&mut self) -> Result, Self::Error> { // Always use the cache to ensure remapping is applied consistently Ok(self.get_all_snapshots()?.keys().cloned().collect()) } } impl BurnpackStore { /// Ensure the snapshots cache is populated fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> { if self.snapshots_cache.is_some() { return Ok(()); } // Ensure reader is loaded self.ensure_reader()?; // Get snapshots from reader with zero-copy if enabled let reader = self.reader.as_ref().unwrap(); let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?; // Apply remapping if configured (but NOT filtering - that's done at apply time) #[cfg(feature = "std")] let snapshots = if !self.remapper.patterns.is_empty() { let (remapped, _remapped_names) = self.remapper.remap(snapshots); remapped } else { snapshots }; // Build the cache as BTreeMap let cache: BTreeMap = snapshots.into_iter().map(|s| (s.full_path(), s)).collect(); self.snapshots_cache = Some(cache); Ok(()) } } ================================================ FILE: crates/burn-store/src/burnpack/tests/alignment.rs ================================================ //! Tests for tensor data alignment in burnpack format. //! //! These tests verify that tensor data is properly aligned for mmap zero-copy access. use crate::TensorSnapshot; use crate::burnpack::{ base::{ BurnpackHeader, BurnpackMetadata, HEADER_SIZE, TENSOR_ALIGNMENT, aligned_data_section_start, }, reader::BurnpackReader, writer::BurnpackWriter, }; use burn_core::module::ParamId; use burn_tensor::{DType, TensorData}; /// Verify that aligned_data_section_start always returns 256-byte aligned values #[test] fn test_aligned_data_section_start_is_always_aligned() { // Test various metadata sizes for metadata_size in 0..1024 { let result = aligned_data_section_start(metadata_size); assert_eq!( result % TENSOR_ALIGNMENT as usize, 0, "aligned_data_section_start({}) = {} is not aligned to {}", metadata_size, result, TENSOR_ALIGNMENT ); } } /// Verify data section starts at correct aligned position #[test] fn test_data_section_alignment() { // Create a tensor let data = [1.0f32, 2.0, 3.0, 4.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![4], DType::F32), vec!["tensor".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); // Parse header to get metadata size let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); // Verify data section starts at 256-byte aligned position assert_eq!( data_section_start % TENSOR_ALIGNMENT as usize, 0, "Data section start {} is not 256-byte aligned", data_section_start ); // Verify the file is large enough assert!( file_bytes.len() >= data_section_start, "File too small: {} < {}", file_bytes.len(), data_section_start ); } /// Verify that first tensor's absolute file position is 256-byte aligned #[test] fn test_first_tensor_absolute_position_aligned() { let data: Vec = vec![1, 2, 3, 4, 5, 6, 7, 8]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![8], DType::U8), vec!["first".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); let tensor_desc = metadata.tensors.get("first").unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); // Absolute file position of first tensor let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize; assert_eq!( absolute_pos % TENSOR_ALIGNMENT as usize, 0, "First tensor absolute position {} is not 256-byte aligned", absolute_pos ); } /// Verify that all tensors in a multi-tensor file have 256-byte aligned absolute positions #[test] fn test_all_tensors_absolute_positions_aligned() { // Create multiple tensors of different sizes (all U8 to simplify shape calculation) let tensors = vec![ ("tensor_a", vec![1u8, 2, 3]), // 3 bytes ("tensor_b", vec![0u8; 16]), // 16 bytes ("tensor_c", vec![0u8; 64]), // 64 bytes ("tensor_d", vec![42u8]), // 1 byte ("tensor_e", vec![0u8; 400]), // 400 bytes ]; let snapshots: Vec = tensors .into_iter() .map(|(name, data)| { let len = data.len(); TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![len], DType::U8), vec![name.to_string()], vec![], ParamId::new(), ) }) .collect(); let writer = BurnpackWriter::new(snapshots); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); // Check every tensor has aligned absolute position for (name, desc) in &metadata.tensors { let absolute_pos = data_section_start + desc.data_offsets.0 as usize; assert_eq!( absolute_pos % TENSOR_ALIGNMENT as usize, 0, "Tensor '{}' at absolute position {} is not 256-byte aligned (offset in data section: {})", name, absolute_pos, desc.data_offsets.0 ); } } /// Test edge case: metadata size that results in no padding needed #[test] fn test_alignment_with_minimal_padding() { // We can't control metadata size directly, but we can verify the math works // When HEADER_SIZE + metadata_size is already a multiple of 256, no padding needed let aligned_metadata_size = TENSOR_ALIGNMENT as usize - HEADER_SIZE; // 256 - 10 = 246 let result = aligned_data_section_start(aligned_metadata_size); assert_eq!(result, TENSOR_ALIGNMENT as usize); // Should be exactly 256 // One byte more should still round up to 256 let result_plus_one = aligned_data_section_start(aligned_metadata_size + 1); assert_eq!(result_plus_one, 2 * TENSOR_ALIGNMENT as usize); // Should be 512 } /// Verify padding bytes in the file are zeros #[test] fn test_padding_bytes_are_zeros() { let data: Vec = vec![0xAA; 16]; // Distinctive pattern let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![16], DType::U8), vec!["tensor".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let data_section_start = aligned_data_section_start(header.metadata_size as usize); // Check padding between metadata and data section if data_section_start > metadata_end { let padding = &file_bytes[metadata_end..data_section_start]; assert!( padding.iter().all(|&b| b == 0), "Padding bytes between metadata and data section contain non-zero values" ); } } /// Verify alignment is sufficient for all primitive types /// 256-byte alignment is a multiple of all primitive type alignments: /// - f64/i64/u64: 8 bytes /// - f32/i32/u32: 4 bytes /// - f16/bf16/i16/u16: 2 bytes /// - i8/u8/bool: 1 byte #[test] #[allow(clippy::modulo_one)] fn test_alignment_covers_all_primitive_types() { // 256 must be divisible by all common alignments assert_eq!( TENSOR_ALIGNMENT % 8, 0, "256 not divisible by 8 (f64 alignment)" ); assert_eq!( TENSOR_ALIGNMENT % 4, 0, "256 not divisible by 4 (f32 alignment)" ); assert_eq!( TENSOR_ALIGNMENT % 2, 0, "256 not divisible by 2 (f16 alignment)" ); assert_eq!( TENSOR_ALIGNMENT % 1, 0, "256 not divisible by 1 (u8 alignment)" ); } /// Verify that tensor data can be read correctly after alignment #[test] fn test_aligned_tensor_data_readable() { // Create f32 tensor let f32_data = vec![1.0f32, 2.0, 3.0, 4.0]; let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes.clone(), vec![4], DType::F32), vec!["floats".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); let tensor_desc = metadata.tensors.get("floats").unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); let start = data_section_start + tensor_desc.data_offsets.0 as usize; let end = data_section_start + tensor_desc.data_offsets.1 as usize; let tensor_bytes = &file_bytes[start..end]; // Verify the bytes match what we wrote assert_eq!(tensor_bytes, f32_bytes.as_slice()); // Verify we can interpret them as floats let mut floats = Vec::new(); for chunk in tensor_bytes.chunks_exact(4) { floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); } assert_eq!(floats, f32_data); } /// Verify alignment works with f64 data #[test] fn test_aligned_f64_tensor_data_readable() { let f64_data = vec![1.0f64, 2.0, 3.0, 4.0]; let f64_bytes: Vec = f64_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f64_bytes.clone(), vec![4], DType::F64), vec!["doubles".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); let tensor_desc = metadata.tensors.get("doubles").unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); let start = data_section_start + tensor_desc.data_offsets.0 as usize; let end = data_section_start + tensor_desc.data_offsets.1 as usize; let tensor_bytes = &file_bytes[start..end]; // Verify the bytes match assert_eq!(tensor_bytes, f64_bytes.as_slice()); // Verify we can interpret them as doubles let mut doubles = Vec::new(); for chunk in tensor_bytes.chunks_exact(8) { doubles.push(f64::from_le_bytes([ chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7], ])); } assert_eq!(doubles, f64_data); } /// Test round-trip preserves alignment (write then read) #[test] fn test_round_trip_maintains_alignment() { let f32_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes, vec![2, 4], DType::F32), vec!["matrix".to_string()], vec![], ParamId::new(), ); // Write let writer = BurnpackWriter::new(vec![snapshot]); let file_bytes = writer.to_bytes().unwrap(); // Read back let reader = BurnpackReader::from_bytes(file_bytes.clone()).unwrap(); let snapshots = reader.get_snapshots().unwrap(); assert_eq!(snapshots.len(), 1); let loaded = &snapshots[0]; assert_eq!(loaded.full_path(), "matrix"); // Verify the loaded data is correct let tensor_data = loaded.to_data().unwrap(); let mut loaded_floats = Vec::new(); for chunk in tensor_data.bytes.chunks_exact(4) { loaded_floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); } assert_eq!(loaded_floats, f32_data); } /// Test that tensor offsets within data section are also aligned #[test] fn test_tensor_relative_offsets_are_aligned() { // Create several small tensors to force multiple alignment padding let tensors: Vec<_> = (0..5) .map(|i| { let data = vec![i as u8; 7]; // 7 bytes each - not aligned TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![7], DType::U8), vec![format!("tensor_{}", i)], vec![], ParamId::new(), ) }) .collect(); let writer = BurnpackWriter::new(tensors); let file_bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); // All tensor start offsets within data section should be multiples of 256 for (name, desc) in &metadata.tensors { assert_eq!( desc.data_offsets.0 % TENSOR_ALIGNMENT, 0, "Tensor '{}' relative offset {} is not 256-byte aligned", name, desc.data_offsets.0 ); } } #[cfg(feature = "std")] mod file_tests { use super::*; use std::fs; use tempfile::tempdir; /// Test alignment is preserved when writing to and reading from file #[test] fn test_file_io_preserves_alignment() { let dir = tempdir().unwrap(); let file_path = dir.path().join("aligned.bpk"); let f32_data = [1.0f32, 2.0, 3.0, 4.0]; let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes, vec![4], DType::F32), vec!["floats".to_string()], vec![], ParamId::new(), ); // Write to file let writer = BurnpackWriter::new(vec![snapshot]); writer.write_to_file(&file_path).unwrap(); // Read file bytes directly let file_bytes = fs::read(&file_path).unwrap(); let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap(); let tensor_desc = metadata.tensors.get("floats").unwrap(); let data_section_start = aligned_data_section_start(header.metadata_size as usize); let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize; assert_eq!( absolute_pos % TENSOR_ALIGNMENT as usize, 0, "Tensor absolute position in file {} is not 256-byte aligned", absolute_pos ); // Verify data is correct let start = data_section_start + tensor_desc.data_offsets.0 as usize; let end = data_section_start + tensor_desc.data_offsets.1 as usize; let tensor_bytes = &file_bytes[start..end]; let mut floats = Vec::new(); for chunk in tensor_bytes.chunks_exact(4) { floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); } assert_eq!(floats, vec![1.0f32, 2.0, 3.0, 4.0]); } } ================================================ FILE: crates/burn-store/src/burnpack/tests/edge_cases.rs ================================================ use crate::TensorSnapshot; use crate::burnpack::{ base::{BurnpackHeader, HEADER_SIZE}, reader::BurnpackReader, writer::BurnpackWriter, }; use burn_core::module::ParamId; use burn_tensor::{BoolStore, DType, TensorData, shape}; #[test] fn test_maximum_metadata_size() { // Create metadata that approaches u32::MAX (4GB limit) // In practice, we'll test with a reasonably large metadata let large_key = "x".repeat(1000); let large_value = "y".repeat(10000); let mut writer = BurnpackWriter::new(vec![]); for i in 0..100 { writer = writer.with_metadata(&format!("{}_{}", large_key, i), &large_value); } let result = writer.to_bytes(); assert!(result.is_ok()); let bytes = result.unwrap(); let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); // Metadata size should be large but within u32 bounds assert!(header.metadata_size > 1000000); // At least 1MB of metadata assert!(header.metadata_size < u32::MAX); } #[test] fn test_zero_size_tensor_shapes() { // Test various zero-dimensional shapes let test_cases = [ (vec![0], vec![]), // Empty 1D (vec![0, 10], vec![]), // Zero rows (vec![10, 0], vec![]), // Zero columns (vec![0, 0], vec![]), // Zero both dimensions (vec![5, 0, 10], vec![]), // Zero in middle dimension ]; let mut snapshots = vec![]; for (i, (shape, data)) in test_cases.iter().enumerate() { let name = format!("zero_tensor_{}", i); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::F32), vec![name.clone()], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); // Read back and verify let reader = BurnpackReader::from_bytes(bytes).unwrap(); let names = reader.tensor_names(); assert_eq!(names.len(), 5); } #[test] fn test_extremely_long_tensor_names() { // Create a tensor with an extremely long name let long_name = "a".repeat(10000); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), vec![long_name.clone()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let names = reader.tensor_names(); assert_eq!(names[0].len(), 10000); } #[test] fn test_unicode_in_names_and_metadata() { // Test various Unicode characters in tensor names and metadata let unicode_names = vec![ "测试_tensor", // Chinese "тест_tensor", // Cyrillic "テスト_tensor", // Japanese "🔥_burn_tensor", // Emoji "αβγδ_tensor", // Greek "한글_tensor", // Korean ]; let mut snapshots = vec![]; for name in &unicode_names { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1], vec![1], DType::U8), vec![name.to_string()], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots) .with_metadata("模型名称", "测试模型") .with_metadata("מודל", "בדיקה") .with_metadata("🔥", "fire"); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Verify all Unicode names are preserved let names = reader.tensor_names(); assert_eq!(names.len(), unicode_names.len()); // Verify metadata assert_eq!( reader.metadata().metadata.get("模型名称"), Some(&"测试模型".to_string()) ); assert_eq!( reader.metadata().metadata.get("🔥"), Some(&"fire".to_string()) ); } #[test] fn test_all_supported_dtypes() { // Test all DTypes with their boundary values let dtypes_with_data = [ ( DType::F32, [ f32::MIN.to_le_bytes().to_vec(), f32::MAX.to_le_bytes().to_vec(), ] .concat(), ), ( DType::F64, [ f64::MIN.to_le_bytes().to_vec(), f64::MAX.to_le_bytes().to_vec(), ] .concat(), ), ( DType::I32, [ i32::MIN.to_le_bytes().to_vec(), i32::MAX.to_le_bytes().to_vec(), ] .concat(), ), ( DType::I64, [ i64::MIN.to_le_bytes().to_vec(), i64::MAX.to_le_bytes().to_vec(), ] .concat(), ), ( DType::U32, [ u32::MIN.to_le_bytes().to_vec(), u32::MAX.to_le_bytes().to_vec(), ] .concat(), ), ( DType::U64, [ u64::MIN.to_le_bytes().to_vec(), u64::MAX.to_le_bytes().to_vec(), ] .concat(), ), (DType::U8, vec![u8::MIN, u8::MAX]), (DType::Bool(BoolStore::Native), vec![0, 1]), ]; let mut snapshots = vec![]; for (i, (dtype, data)) in dtypes_with_data.iter().enumerate() { let name = format!("dtype_test_{}", i); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![2], *dtype), vec![name], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); assert_eq!(reader.tensor_names().len(), dtypes_with_data.len()); // Verify dtypes are preserved for (i, (expected_dtype, _)) in dtypes_with_data.iter().enumerate() { let name = format!("dtype_test_{}", i); let snapshot = reader.get_tensor_snapshot(&name).unwrap(); assert_eq!(snapshot.dtype, *expected_dtype); } } #[test] fn test_special_float_values() { // Test special floating-point values (NaN, Inf, -Inf) let special_values = [ f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0_f32, -0.0_f32, ]; let data: Vec = special_values .iter() .flat_map(|f| f.to_le_bytes()) .collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![5], DType::F32), vec!["special_floats".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let tensor_data = reader.get_tensor_data("special_floats").unwrap(); // Check data is preserved exactly (bit-for-bit) assert_eq!(tensor_data, data); } #[test] fn test_metadata_with_empty_values() { let writer = BurnpackWriter::new(vec![]) .with_metadata("empty_value", "") .with_metadata("", "empty_key") .with_metadata("normal", "value"); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let metadata = &reader.metadata().metadata; assert_eq!(metadata.get("empty_value"), Some(&"".to_string())); assert_eq!(metadata.get(""), Some(&"empty_key".to_string())); assert_eq!(metadata.get("normal"), Some(&"value".to_string())); } #[test] fn test_single_byte_tensor() { // Test the smallest possible tensor (1 byte) let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![42], vec![1], DType::U8), vec!["single_byte".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let data = reader.get_tensor_data("single_byte").unwrap(); assert_eq!(data, vec![42]); } #[test] fn test_high_dimensional_tensor() { // Test a tensor with many dimensions (10D) let shape = shape![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; // 10 dimensions, 1024 elements total let data = vec![1u8; 1024]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::U8), vec!["high_dim".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let loaded_snapshot = reader.get_tensor_snapshot("high_dim").unwrap(); assert_eq!(loaded_snapshot.shape, shape); } #[test] fn test_metadata_key_collision() { // Test that later values override earlier ones for the same key let writer = BurnpackWriter::new(vec![]) .with_metadata("key", "value1") .with_metadata("key", "value2") .with_metadata("key", "value3"); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); assert_eq!( reader.metadata().metadata.get("key"), Some(&"value3".to_string()) ); } #[test] fn test_tensor_name_with_path_separators() { // Test tensor names that look like file paths let path_like_names = vec![ "model/encoder/layer1/weights", "model\\decoder\\layer1\\bias", "model::module::param", "model.submodule.weight", ]; let mut snapshots = vec![]; for name in &path_like_names { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), vec![name.to_string()], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let names = reader.tensor_names(); // All names should be preserved exactly for expected_name in &path_like_names { assert!(names.contains(expected_name)); } } // The following tests are commented out as they test error conditions // that might be handled differently in the new API // #[test] // fn test_data_overflow_protection() { // // Test that we handle potential integer overflows in offset calculations // ... // } // #[test] // fn test_reading_corrupted_header() { // // Test reading files with corrupted headers // ... // } ================================================ FILE: crates/burn-store/src/burnpack/tests/header.rs ================================================ use crate::burnpack::base::*; #[test] fn test_header_serialization() { let header = BurnpackHeader::new(12345); // Check fields assert_eq!(header.magic, MAGIC_NUMBER); assert_eq!(header.version, FORMAT_VERSION); assert_eq!(header.metadata_size, 12345); // Serialize to bytes let bytes = header.into_bytes(); assert_eq!(bytes.len(), HEADER_SIZE); // Deserialize back let header2 = BurnpackHeader::from_bytes(&bytes).unwrap(); assert_eq!(header2.magic, header.magic); assert_eq!(header2.version, header.version); assert_eq!(header2.metadata_size, header.metadata_size); } #[test] fn test_header_invalid_magic() { let mut bytes = [0u8; HEADER_SIZE]; // Write wrong magic number bytes[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]); let result = BurnpackHeader::from_bytes(&bytes); match result { Err(BurnpackError::InvalidMagicNumber) => {} _ => panic!("Expected InvalidMagicNumber error"), } } #[test] fn test_header_insufficient_bytes() { let bytes = [0u8; 5]; // Too short let result = BurnpackHeader::from_bytes(&bytes); match result { Err(BurnpackError::InvalidHeader) => {} _ => panic!("Expected InvalidHeader error"), } } #[test] fn test_version_compatibility() { // Create a header with current version let header = BurnpackHeader::new(100); let bytes = header.into_bytes(); // Should succeed with current version let result = BurnpackHeader::from_bytes(&bytes); assert!(result.is_ok()); // Test with future version (should fail in real implementation) // For now, we just verify the version field is correctly set let header = result.unwrap(); assert_eq!(header.version, FORMAT_VERSION); } ================================================ FILE: crates/burn-store/src/burnpack/tests/helpers.rs ================================================ use crate::TensorSnapshot; use burn_core::module::ParamId; use burn_tensor::{DType, TensorData}; /// Helper to create a test TensorSnapshot #[allow(dead_code)] pub fn create_test_snapshot( name: String, data: Vec, shape: Vec, dtype: DType, ) -> TensorSnapshot { TensorSnapshot::from_data( TensorData::from_bytes_vec(data, shape, dtype), vec![name], vec![], ParamId::new(), ) } ================================================ FILE: crates/burn-store/src/burnpack/tests/mod.rs ================================================ use crate::TensorSnapshot; mod alignment; mod edge_cases; mod header; mod helpers; mod reader; mod round_trip; mod store; mod writer; mod zero_copy; ================================================ FILE: crates/burn-store/src/burnpack/tests/reader.rs ================================================ use crate::burnpack::{ base::{ BurnpackError, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, metadata_size_range, version_range, }, reader::BurnpackReader, writer::BurnpackWriter, }; use super::*; use burn_tensor::{BoolStore, Bytes, DType, TensorData, shape}; #[test] fn test_reader_from_bytes_empty() { // Create empty burnpack data let writer = BurnpackWriter::new(Vec::new()); let bytes = writer.to_bytes().unwrap(); // Read it back let reader = BurnpackReader::from_bytes(bytes).unwrap(); assert_eq!(reader.metadata().tensors.len(), 0); assert!(reader.metadata().metadata.is_empty()); } #[test] fn test_reader_from_bytes_with_data() { // Create test data let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); let bytes = writer.to_bytes().unwrap(); // Read it back let reader = BurnpackReader::from_bytes(bytes).unwrap(); assert_eq!(reader.metadata().tensors.len(), 1); assert_eq!( reader.metadata().metadata.get("test"), Some(&"value".to_string()) ); // Get tensor data let tensor_data = reader.get_tensor_data("test_tensor").unwrap(); assert_eq!(tensor_data, &[1, 2, 3, 4]); } #[test] fn test_reader_invalid_magic_number() { let mut bytes = vec![0u8; 100]; // Write invalid magic number bytes[magic_range()].copy_from_slice(b"NOPE"); let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); assert!(matches!(result, Err(BurnpackError::InvalidMagicNumber))); } #[test] fn test_reader_invalid_version() { let mut bytes = vec![0u8; 100]; // Write correct magic but invalid version bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); bytes[version_range()].copy_from_slice(&999u16.to_le_bytes()); // Invalid version bytes[metadata_size_range()].copy_from_slice(&10u32.to_le_bytes()); // Metadata size let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); assert!(matches!(result, Err(BurnpackError::InvalidVersion))); } #[test] fn test_reader_header_too_short() { let bytes = vec![0u8; 5]; // Less than HEADER_SIZE let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); assert!(matches!(result, Err(BurnpackError::InvalidHeader))); } #[test] fn test_reader_metadata_truncated() { let mut bytes = vec![0u8; HEADER_SIZE + 10]; // Write valid header bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes()); bytes[metadata_size_range()].copy_from_slice(&100u32.to_le_bytes()); // Claims 100 bytes of metadata // But only provide 10 bytes after header let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); assert!(matches!(result, Err(BurnpackError::InvalidHeader))); } #[test] fn test_reader_get_tensor_not_found() { let writer = BurnpackWriter::new(Vec::new()); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let result = reader.get_tensor_data("non_existent"); assert!(matches!(result, Err(BurnpackError::TensorNotFound(_)))); } #[test] fn test_reader_get_tensor_snapshot() { let data = [1.0f32, 2.0, 3.0, 4.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), vec!["weights".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let writer_bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(writer_bytes).unwrap(); // Get tensor as snapshot let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap(); // Verify snapshot metadata assert_eq!(loaded_snapshot.full_path(), "weights"); assert_eq!(loaded_snapshot.dtype, DType::F32); assert_eq!(loaded_snapshot.shape, shape![2, 2]); // Verify data through closure let tensor_data = loaded_snapshot.to_data().unwrap(); assert_eq!(tensor_data.shape, shape![2, 2]); } #[test] fn test_reader_multiple_tensors() { // Add multiple tensors let mut snapshots = Vec::new(); for i in 0..10 { let name = format!("tensor_{}", i); let data = vec![i as u8; 100]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, shape![100], DType::U8), vec![name.clone()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Verify all tensors can be read for i in 0..10 { let name = format!("tensor_{}", i); let data = reader.get_tensor_data(&name).unwrap(); assert_eq!(data.len(), 100); assert!(data.iter().all(|&b| b == i as u8)); } } #[test] fn test_reader_lazy_loading() { // Create large tensor let size = 1024 * 1024; // 1MB let data = vec![42u8; size]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![size], DType::U8), vec!["large".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Get snapshot (should be lazy) let snapshot = reader.get_tensor_snapshot("large").unwrap(); // Data should only be accessed when to_data is called let tensor_data = snapshot.to_data().unwrap(); assert_eq!(tensor_data.bytes.len(), size); assert!(tensor_data.bytes.iter().all(|&b| b == 42)); } #[test] fn test_reader_all_dtypes() { // Test all data types let test_data = [ (DType::F32, [1.0f32.to_le_bytes().to_vec()].concat()), (DType::F64, [2.0f64.to_le_bytes().to_vec()].concat()), (DType::I32, [3i32.to_le_bytes().to_vec()].concat()), (DType::I64, [4i64.to_le_bytes().to_vec()].concat()), (DType::U32, [5u32.to_le_bytes().to_vec()].concat()), (DType::U64, [6u64.to_le_bytes().to_vec()].concat()), (DType::U8, vec![7u8]), (DType::Bool(BoolStore::Native), vec![1u8]), ]; let mut snapshots = Vec::new(); for (i, (dtype, data)) in test_data.iter().enumerate() { let name = format!("tensor_{}", i); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![1], *dtype), vec![name.clone()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Verify all dtypes are preserved for (i, (expected_dtype, expected_data)) in test_data.iter().enumerate() { let name = format!("tensor_{}", i); let snapshot = reader.get_tensor_snapshot(&name).unwrap(); assert_eq!(snapshot.dtype, *expected_dtype); let data = reader.get_tensor_data(&name).unwrap(); assert_eq!(data, expected_data.as_slice()); } } #[test] fn test_reader_empty_tensor() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![], vec![0], DType::F32), vec!["empty".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let data = reader.get_tensor_data("empty").unwrap(); assert_eq!(data.len(), 0); let snapshot = reader.get_tensor_snapshot("empty").unwrap(); assert_eq!(snapshot.shape, shape![0]); } #[cfg(feature = "std")] #[test] fn test_reader_from_file() { use tempfile::tempdir; let dir = tempdir().unwrap(); let file_path = dir.path().join("test.bpk"); // Create test file let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![10, 20, 30], vec![3], DType::U8), vec!["file_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("from_file_test", "true"); writer.write_to_file(&file_path).unwrap(); // Read from file let reader = BurnpackReader::from_file(&file_path).unwrap(); assert_eq!( reader.metadata().metadata.get("from_file_test"), Some(&"true".to_string()) ); let data = reader.get_tensor_data("file_tensor").unwrap(); assert_eq!(data, &[10, 20, 30]); } #[cfg(all(feature = "std", feature = "memmap"))] #[test] fn test_reader_from_file_mmap() { use tempfile::tempdir; let dir = tempdir().unwrap(); let file_path = dir.path().join("test_mmap.bpk"); // Create large test file let size = 1024 * 1024; // 1MB let data = vec![99u8; size]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![size], DType::U8), vec!["large_mmap".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); writer.write_to_file(&file_path).unwrap(); // Read using mmap let reader = BurnpackReader::from_file_mmap(&file_path).unwrap(); let data = reader.get_tensor_data("large_mmap").unwrap(); assert_eq!(data.len(), size); assert!(data.iter().all(|&b| b == 99)); } #[cfg(feature = "std")] #[test] fn test_reader_from_file_buffered() { use tempfile::tempdir; let dir = tempdir().unwrap(); let file_path = dir.path().join("test_buffered.bpk"); // Create test file let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![5, 10, 15], vec![3], DType::U8), vec!["buffered_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); writer.write_to_file(&file_path).unwrap(); // Read using buffered reader let reader = BurnpackReader::from_file_buffered(&file_path).unwrap(); let data = reader.get_tensor_data("buffered_tensor").unwrap(); assert_eq!(data, &[5, 10, 15]); } #[test] fn test_reader_metadata_access() { // Add various metadata using builder pattern let writer = BurnpackWriter::new(Vec::new()) .with_metadata("model_name", "test_model") .with_metadata("version", "1.2.3") .with_metadata("author", "test_author") .with_metadata("description", "A test model"); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); let metadata = reader.metadata(); assert_eq!(metadata.metadata.len(), 4); assert_eq!( metadata.metadata.get("model_name"), Some(&"test_model".to_string()) ); assert_eq!(metadata.metadata.get("version"), Some(&"1.2.3".to_string())); assert_eq!( metadata.metadata.get("author"), Some(&"test_author".to_string()) ); assert_eq!( metadata.metadata.get("description"), Some(&"A test model".to_string()) ); } #[test] fn test_reader_tensor_iteration() { // Add tensors let tensor_names = vec!["weights", "bias", "running_mean", "running_var"]; let mut snapshots = Vec::new(); for name in &tensor_names { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), vec![name.to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Iterate through all tensors let metadata = reader.metadata(); assert_eq!(metadata.tensors.len(), 4); // Check that all expected tensor names are present for name in &tensor_names { let tensor_desc = metadata.tensors.get(*name).unwrap(); assert_eq!(tensor_desc.shape, vec![4u64]); assert_eq!(tensor_desc.dtype, DType::U8); } // Verify the keys match the expected names let mut actual_names: Vec<_> = metadata.tensors.keys().cloned().collect(); actual_names.sort(); let mut expected_names = tensor_names .iter() .map(|s| s.to_string()) .collect::>(); expected_names.sort(); assert_eq!(actual_names, expected_names); } #[test] fn test_reader_corrupt_metadata() { let mut bytes = vec![0u8; 100]; // Write valid header bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes()); bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes()); bytes[metadata_size_range()].copy_from_slice(&50u32.to_le_bytes()); // 50 bytes of metadata // Write garbage as metadata #[allow(clippy::needless_range_loop)] for i in HEADER_SIZE..HEADER_SIZE + 50 { bytes[i] = 0xFF; } let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes)); assert!(result.is_err()); } #[test] fn test_reader_data_offsets_validation() { // Add two tensors let snapshot1 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), vec!["tensor1".to_string()], vec![], burn_core::module::ParamId::new(), ); let snapshot2 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8), vec!["tensor2".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]); let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Verify offsets don't overlap and are properly aligned let metadata = reader.metadata(); let tensor1_desc = metadata.tensors.get("tensor1").unwrap(); let tensor2_desc = metadata.tensors.get("tensor2").unwrap(); // First tensor starts at offset 0 (already aligned to 256 bytes) assert_eq!(tensor1_desc.data_offsets, (0, 4)); // Second tensor starts at next 256-byte aligned offset assert_eq!(tensor2_desc.data_offsets, (256, 260)); } #[test] fn test_reader_out_of_bounds_error() { use crate::burnpack::reader::StorageBackend; use alloc::rc::Rc; // Create a small data buffer let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]); let backend = StorageBackend::Memory(Rc::new(data)); // Try to read beyond the available data let mut buffer = vec![0u8; 10]; let result = backend.read_into(&mut buffer, 0); // Should return an error assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.to_string().contains("out of bounds")); } #[test] fn test_reader_offset_overflow_error() { use crate::burnpack::reader::StorageBackend; use alloc::rc::Rc; let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]); let backend = StorageBackend::Memory(Rc::new(data)); // Try to read with an offset that would overflow let mut buffer = vec![0u8; 10]; let result = backend.read_into(&mut buffer, usize::MAX - 5); // Should return an error about overflow assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.to_string().contains("overflow")); } #[test] fn test_reader_corrupted_shape_returns_error() { // Only test this on platforms where usize is smaller than u64 // On 64-bit platforms, u64 values can fit in usize #[cfg(target_pointer_width = "32")] { use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; use alloc::collections::BTreeMap; use alloc::rc::Rc; use burn_tensor::DType; // Create metadata with a shape dimension that exceeds usize::MAX on 32-bit platforms let mut tensors = BTreeMap::new(); tensors.insert( "corrupted_tensor".to_string(), TensorDescriptor { dtype: DType::F32, shape: vec![u64::MAX, 2, 3], // First dimension exceeds usize::MAX on 32-bit data_offsets: (0, 100), param_id: None, }, ); let metadata = BurnpackMetadata { tensors, metadata: BTreeMap::new(), }; // Create a small data buffer let data = Bytes::from_bytes_vec(vec![0u8; 1000]); let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); let reader = BurnpackReader { metadata, storage: backend, data_offset: 0, }; // This should return an error, not panic let result = reader.get_snapshots(); assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, BurnpackError::ValidationError(_))); assert!( err.to_string().contains("corrupted shape data") || err.to_string().contains("exceeds platform maximum") ); } #[cfg(not(target_pointer_width = "32"))] { // On 64-bit platforms, just pass the test // The conversion logic is still correct, but u64 fits in usize } } #[test] fn test_reader_corrupted_offsets_returns_error() { // Only test this on platforms where usize is smaller than u64 #[cfg(target_pointer_width = "32")] { use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; use alloc::collections::BTreeMap; use alloc::rc::Rc; use burn_tensor::DType; // Create metadata with offsets that would overflow let mut tensors = BTreeMap::new(); tensors.insert( "tensor_bad_offset".to_string(), TensorDescriptor { dtype: DType::F32, shape: vec![2, 2], data_offsets: (u64::MAX - 10, u64::MAX), // Offsets that exceed usize::MAX on 32-bit param_id: None, }, ); let metadata = BurnpackMetadata { tensors, metadata: BTreeMap::new(), }; let data = Bytes::from_bytes_vec(vec![0u8; 1000]); let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); let reader = BurnpackReader { metadata, storage: backend, data_offset: 0, }; // This should return an error, not panic let result = reader.get_snapshots(); assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, BurnpackError::ValidationError(_))); assert!( err.to_string().contains("corrupted offset data") || err.to_string().contains("exceeds platform maximum") ); } #[cfg(not(target_pointer_width = "32"))] { use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; use alloc::collections::BTreeMap; use alloc::rc::Rc; use burn_tensor::DType; // On 64-bit platforms, test offset overflow during addition let mut tensors = BTreeMap::new(); tensors.insert( "tensor_overflow".to_string(), TensorDescriptor { dtype: DType::F32, shape: vec![2, 2], data_offsets: (0, 100), param_id: None, }, ); let metadata = BurnpackMetadata { tensors, metadata: BTreeMap::new(), }; let data = Bytes::from_bytes_vec(vec![0u8; 1000]); let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); // Use a data_offset that will overflow when added to the tensor offset let reader = BurnpackReader { metadata, storage: backend, data_offset: usize::MAX - 50, // Will overflow when added to 100 }; // This should return an error, not panic let result = reader.get_snapshots(); assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, BurnpackError::ValidationError(_))); assert!(err.to_string().contains("overflow")); } } #[test] fn test_reader_inverted_offsets_returns_error() { use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; use alloc::collections::BTreeMap; use alloc::rc::Rc; use burn_tensor::DType; // Create metadata with end offset < start offset (corrupted) let mut tensors = BTreeMap::new(); tensors.insert( "inverted_tensor".to_string(), TensorDescriptor { dtype: DType::F32, shape: vec![2, 2], data_offsets: (100, 50), // End offset < start offset param_id: None, }, ); let metadata = BurnpackMetadata { tensors, metadata: BTreeMap::new(), }; let data = Bytes::from_bytes_vec(vec![0u8; 1000]); let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data)); let reader = BurnpackReader { metadata, storage: backend, data_offset: 0, }; // This should return an error, not panic let result = reader.get_snapshots(); assert!(result.is_err()); let err = result.unwrap_err(); assert!(matches!(err, BurnpackError::ValidationError(_))); assert!(err.to_string().contains("end offset") && err.to_string().contains("start offset")); } #[test] fn test_reader_truncated_file_from_bytes() { // Create a valid burnpack with tensor data let tensor_size = 1024; // 1KB of data let data = vec![42u8; tensor_size]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8), vec!["large_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let full_bytes = writer.to_bytes().unwrap(); // Truncate the bytes - remove the last 512 bytes of tensor data let truncated_len = full_bytes.len() - 512; let truncated_bytes = Bytes::from_bytes_vec(full_bytes.to_vec()[..truncated_len].to_vec()); // This should fail with a validation error indicating file truncation let result = BurnpackReader::from_bytes(truncated_bytes); assert!(result.is_err()); if let Err(err) = result { assert!(matches!(err, BurnpackError::ValidationError(_))); assert!(err.to_string().contains("File truncated")); assert!(err.to_string().contains("expected at least")); } } #[cfg(feature = "std")] #[test] fn test_reader_truncated_file_from_file() { use std::fs::OpenOptions; use tempfile::tempdir; let dir = tempdir().unwrap(); let file_path = dir.path().join("truncated.bpk"); // Create a valid burnpack file with tensor data let tensor_size = 2048; // 2KB of data let data = vec![99u8; tensor_size]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8), vec!["data_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); writer.write_to_file(&file_path).unwrap(); // Read the full file to get its size let full_size = std::fs::metadata(&file_path).unwrap().len(); // Truncate the file - remove the last 1KB let truncated_size = full_size - 1024; let truncated_file = OpenOptions::new().write(true).open(&file_path).unwrap(); truncated_file.set_len(truncated_size).unwrap(); drop(truncated_file); // Try to read the truncated file - should fail with validation error let result = BurnpackReader::from_file(&file_path); assert!(result.is_err()); if let Err(err) = result { assert!(matches!(err, BurnpackError::ValidationError(_))); assert!(err.to_string().contains("File truncated")); assert!(err.to_string().contains("expected at least")); } } #[test] fn test_reader_file_size_exactly_correct() { // Test that a file with exactly the right size passes validation let tensor_size = 100; let data = vec![77u8; tensor_size]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8), vec!["exact_size".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); // This should succeed - file is exactly the right size let reader = BurnpackReader::from_bytes(bytes); assert!(reader.is_ok()); // Verify we can read the data let reader = reader.unwrap(); let tensor_data = reader.get_tensor_data("exact_size").unwrap(); assert_eq!(tensor_data.len(), tensor_size); assert!(tensor_data.iter().all(|&b| b == 77)); } ================================================ FILE: crates/burn-store/src/burnpack/tests/round_trip.rs ================================================ use crate::burnpack::{reader::BurnpackReader, writer::BurnpackWriter}; use super::*; use alloc::collections::BTreeMap; use alloc::string::String; use burn_tensor::{BoolStore, DType, TensorData, shape}; /// Helper function to perform round-trip test fn round_trip_test(setup: F) where F: FnOnce(&mut Vec, &mut BTreeMap), { // Collect snapshots and metadata let mut snapshots = Vec::new(); let mut metadata = BTreeMap::new(); setup(&mut snapshots, &mut metadata); // Sort snapshots by name to ensure consistent ordering // This is necessary because BTreeMap will store them sorted snapshots.sort_by_key(|a| a.full_path()); // Create writer with snapshots and metadata let mut writer = BurnpackWriter::new(snapshots); for (key, value) in &metadata { writer = writer.with_metadata(key, value); } let bytes = writer.to_bytes().unwrap(); let reader = BurnpackReader::from_bytes(bytes.clone()).unwrap(); // Write to bytes again from reader data let mut snapshots2 = Vec::new(); // Copy tensors (metadata.tensors is now BTreeMap) // They will come out in sorted order from tensor_names() for tensor_name in reader.tensor_names() { let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap(); snapshots2.push(snapshot); } // Create writer2 with collected snapshots and metadata let mut writer2 = BurnpackWriter::new(snapshots2); for (key, value) in &reader.metadata().metadata { writer2 = writer2.with_metadata(key, value); } let bytes2 = writer2.to_bytes().unwrap(); // Both byte representations should be identical assert_eq!(bytes, bytes2, "Round-trip produced different bytes"); } #[test] fn test_round_trip_empty() { round_trip_test(|_snapshots, _metadata| { // Empty writer }); } #[test] fn test_round_trip_metadata_only() { round_trip_test(|_snapshots, metadata| { metadata.insert("key1".to_string(), "value1".to_string()); metadata.insert("key2".to_string(), "value2".to_string()); metadata.insert("key3".to_string(), "value3".to_string()); }); } #[test] fn test_round_trip_f32() { round_trip_test(|snapshots, _metadata| { let data = [1.0f32, 2.0, 3.0, 4.0, 5.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![5], DType::F32), vec!["f32_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_f64() { round_trip_test(|snapshots, _metadata| { let data = [1.0f64, 2.0, 3.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![3], DType::F64), vec!["f64_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_i32() { round_trip_test(|snapshots, _metadata| { let data = [-10i32, 0, 10, 20]; let bytes: Vec = data.iter().flat_map(|i| i.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![4], DType::I32), vec!["i32_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_i64() { round_trip_test(|snapshots, _metadata| { let data = [i64::MIN, 0, i64::MAX]; let bytes: Vec = data.iter().flat_map(|i| i.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![3], DType::I64), vec!["i64_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_u32() { round_trip_test(|snapshots, _metadata| { let data = [0u32, 100, 1000, u32::MAX]; let bytes: Vec = data.iter().flat_map(|u| u.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![4], DType::U32), vec!["u32_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_u64() { round_trip_test(|snapshots, _metadata| { let data = [0u64, u64::MAX / 2, u64::MAX]; let bytes: Vec = data.iter().flat_map(|u| u.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![3], DType::U64), vec!["u64_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_u8() { round_trip_test(|snapshots, _metadata| { let data = vec![0u8, 127, 255]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![3], DType::U8), vec!["u8_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_bool() { round_trip_test(|snapshots, _metadata| { let data = vec![0u8, 1, 0, 1, 1]; let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data, vec![5], DType::Bool(BoolStore::Native)), vec!["bool_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[test] fn test_round_trip_mixed_dtypes() { round_trip_test(|snapshots, _metadata| { // F32 let f32_data = [1.0f32, 2.0]; let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let f32_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes, vec![2], DType::F32), vec!["f32".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(f32_snapshot); // I64 let i64_data = [100i64, 200]; let i64_bytes: Vec = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect(); let i64_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(i64_bytes, vec![2], DType::I64), vec!["i64".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(i64_snapshot); // Bool let bool_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 0, 1], vec![3], DType::Bool(BoolStore::Native)), vec!["bool".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(bool_snapshot); }); } #[test] fn test_round_trip_multidimensional() { round_trip_test(|snapshots, _metadata| { // 2D tensor let data_2d = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let bytes_2d: Vec = data_2d.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot_2d = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes_2d, vec![2, 3], DType::F32), vec!["tensor_2d".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot_2d); // 3D tensor let data_3d = [1.0f32; 24]; let bytes_3d: Vec = data_3d.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot_3d = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes_3d, vec![2, 3, 4], DType::F32), vec!["tensor_3d".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot_3d); // 4D tensor (common for CNNs) let data_4d = vec![1.0f32; 120]; let bytes_4d: Vec = data_4d.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot_4d = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes_4d, vec![2, 3, 4, 5], DType::F32), vec!["tensor_4d".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot_4d); }); } #[test] fn test_round_trip_with_metadata_and_tensors() { round_trip_test(|snapshots, metadata| { // Add metadata metadata.insert("model_name".to_string(), "test_model".to_string()); metadata.insert("version".to_string(), "1.0.0".to_string()); metadata.insert( "description".to_string(), "A test model for round-trip testing".to_string(), ); // Add tensors let weights = [0.1f32, 0.2, 0.3, 0.4]; let weights_bytes: Vec = weights.iter().flat_map(|f| f.to_le_bytes()).collect(); let weights_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(weights_bytes, vec![2, 2], DType::F32), vec!["layer1".to_string(), "weights".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(weights_snapshot); let bias = [0.5f32, 0.6]; let bias_bytes: Vec = bias.iter().flat_map(|f| f.to_le_bytes()).collect(); let bias_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bias_bytes, vec![2], DType::F32), vec!["layer1".to_string(), "bias".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(bias_snapshot); }); } #[test] fn test_round_trip_special_values() { round_trip_test(|snapshots, _metadata| { // Test special float values let special_f32 = [ 0.0f32, -0.0, f32::INFINITY, f32::NEG_INFINITY, f32::NAN, f32::MIN, f32::MAX, f32::EPSILON, ]; let f32_bytes: Vec = special_f32.iter().flat_map(|f| f.to_le_bytes()).collect(); let f32_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes, vec![8], DType::F32), vec!["special_f32".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(f32_snapshot); // Test special f64 values let special_f64 = [ 0.0f64, -0.0, f64::INFINITY, f64::NEG_INFINITY, f64::NAN, f64::MIN, f64::MAX, f64::EPSILON, ]; let f64_bytes: Vec = special_f64.iter().flat_map(|f| f.to_le_bytes()).collect(); let f64_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(f64_bytes, vec![8], DType::F64), vec!["special_f64".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(f64_snapshot); }); } #[test] fn test_round_trip_large_tensors() { round_trip_test(|snapshots, _metadata| { // Large tensor (100KB) let size = 25600; // 100KB / 4 bytes per f32 let data: Vec = (0..size).map(|i| i as f32).collect(); let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![size], DType::F32), vec!["large_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); }); } #[cfg(feature = "std")] #[test] fn test_round_trip_file_io() { use std::fs; use tempfile::tempdir; use crate::burnpack::writer::BurnpackWriter; let dir = tempdir().unwrap(); let file_path = dir.path().join("round_trip.bpk"); // Create original data let data = [1.0f32, 2.0, 3.0, 4.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), vec!["weights".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "round_trip"); // Write to file writer.write_to_file(&file_path).unwrap(); // Read from file let reader = BurnpackReader::from_file(&file_path).unwrap(); // Write to another file let file_path2 = dir.path().join("round_trip2.bpk"); // Collect snapshots from reader let mut snapshots2 = Vec::new(); for tensor_name in reader.tensor_names() { let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap(); snapshots2.push(snapshot); } // Create writer2 with snapshots and metadata let mut writer2 = BurnpackWriter::new(snapshots2); for (key, value) in &reader.metadata().metadata { writer2 = writer2.with_metadata(key, value); } writer2.write_to_file(&file_path2).unwrap(); // Compare files let bytes1 = fs::read(&file_path).unwrap(); let bytes2 = fs::read(&file_path2).unwrap(); assert_eq!( bytes1, bytes2, "Round-trip through files produced different content" ); } #[test] fn test_round_trip_empty_shapes() { round_trip_test(|snapshots, _metadata| { // Scalar (0-dimensional) let scalar = [42.0f32]; let scalar_bytes: Vec = scalar.iter().flat_map(|f| f.to_le_bytes()).collect(); let scalar_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(scalar_bytes, shape![], DType::F32), vec!["scalar".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(scalar_snapshot); // Empty tensor let empty_snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![], shape![0], DType::F32), vec!["empty".to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(empty_snapshot); }); } #[test] fn test_param_id_persistence() { use burn_core::module::ParamId; // Create a specific ParamId with a known value let original_param_id = ParamId::from(123456789u64); let data = [1.0f32, 2.0, 3.0, 4.0]; let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32), vec!["weights".to_string()], vec![], original_param_id, ); // Write to burnpack let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); // Read back from burnpack let reader = BurnpackReader::from_bytes(bytes).unwrap(); let loaded_snapshot = reader.get_tensor_snapshot("weights").unwrap(); // Verify ParamId was preserved assert!( loaded_snapshot.tensor_id.is_some(), "ParamId should be present" ); let loaded_param_id = loaded_snapshot.tensor_id.unwrap(); assert_eq!( loaded_param_id.val(), original_param_id.val(), "ParamId value should be preserved: expected {}, got {}", original_param_id.val(), loaded_param_id.val() ); } #[test] fn test_param_id_backward_compatibility() { use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor}; use alloc::collections::BTreeMap; // Create metadata without param_id (simulating old burnpack format) let mut tensors = BTreeMap::new(); tensors.insert( "old_tensor".to_string(), TensorDescriptor { dtype: DType::F32, shape: vec![2, 2], data_offsets: (0, 16), param_id: None, // No param_id stored (old format) }, ); let metadata = BurnpackMetadata { tensors, metadata: BTreeMap::new(), }; // Serialize metadata let mut metadata_bytes = Vec::new(); ciborium::ser::into_writer(&metadata, &mut metadata_bytes).unwrap(); // Create a complete burnpack with header and data use crate::burnpack::base::{BurnpackHeader, FORMAT_VERSION, MAGIC_NUMBER}; let metadata_size = metadata_bytes.len() as u32; let header = BurnpackHeader { magic: MAGIC_NUMBER, version: FORMAT_VERSION, metadata_size, }; let mut full_bytes = Vec::new(); full_bytes.extend_from_slice(&header.into_bytes()); full_bytes.extend_from_slice(&metadata_bytes); // Add tensor data (4 f32 values = 16 bytes) let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0]; for value in tensor_data { full_bytes.extend_from_slice(&value.to_le_bytes()); } // Read the old format burnpack let reader = BurnpackReader::from_bytes(burn_tensor::Bytes::from_bytes_vec(full_bytes)).unwrap(); let loaded_snapshot = reader.get_tensor_snapshot("old_tensor").unwrap(); // Verify that a new ParamId was generated (backward compatibility) assert!( loaded_snapshot.tensor_id.is_some(), "ParamId should be generated for old format" ); // The generated ParamId should be different each time (it's new), but we can't test the exact value // We just verify it exists and has a valid u64 value let generated_param_id = loaded_snapshot.tensor_id.unwrap(); assert!( generated_param_id.val() > 0, "Generated ParamId should have a valid value" ); } #[test] fn test_multiple_tensors_preserve_distinct_param_ids() { use burn_core::module::ParamId; // Create multiple tensors with distinct ParamIds let param_id_1 = ParamId::from(111111u64); let param_id_2 = ParamId::from(222222u64); let param_id_3 = ParamId::from(333333u64); let mut snapshots = Vec::new(); let data1 = [1.0f32, 2.0]; let bytes1: Vec = data1.iter().flat_map(|f| f.to_le_bytes()).collect(); snapshots.push(TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes1, vec![2], DType::F32), vec!["tensor1".to_string()], vec![], param_id_1, )); let data2 = [3.0f32, 4.0]; let bytes2: Vec = data2.iter().flat_map(|f| f.to_le_bytes()).collect(); snapshots.push(TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes2, vec![2], DType::F32), vec!["tensor2".to_string()], vec![], param_id_2, )); let data3 = [5.0f32, 6.0]; let bytes3: Vec = data3.iter().flat_map(|f| f.to_le_bytes()).collect(); snapshots.push(TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes3, vec![2], DType::F32), vec!["tensor3".to_string()], vec![], param_id_3, )); // Write to burnpack let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); // Read back let reader = BurnpackReader::from_bytes(bytes).unwrap(); let snapshot1 = reader.get_tensor_snapshot("tensor1").unwrap(); let snapshot2 = reader.get_tensor_snapshot("tensor2").unwrap(); let snapshot3 = reader.get_tensor_snapshot("tensor3").unwrap(); // Verify each ParamId was preserved correctly assert_eq!(snapshot1.tensor_id.unwrap().val(), param_id_1.val()); assert_eq!(snapshot2.tensor_id.unwrap().val(), param_id_2.val()); assert_eq!(snapshot3.tensor_id.unwrap().val(), param_id_3.val()); // Verify they are distinct let id1 = snapshot1.tensor_id.unwrap().val(); let id2 = snapshot2.tensor_id.unwrap().val(); let id3 = snapshot3.tensor_id.unwrap().val(); assert_ne!(id1, id2, "ParamIds should be distinct"); assert_ne!(id2, id3, "ParamIds should be distinct"); assert_ne!(id1, id3, "ParamIds should be distinct"); } ================================================ FILE: crates/burn-store/src/burnpack/tests/store.rs ================================================ #[cfg(feature = "std")] use crate::KeyRemapper; use crate::burnpack::store::BurnpackStore; use crate::{ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter}; use burn_core as burn; use burn_core::module::{Module, Param}; use burn_tensor::shape; use burn_tensor::{Tensor, backend::Backend}; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] struct TestModule { weight: Param>, bias: Param>, nested: NestedModule, } #[derive(Module, Debug)] struct NestedModule { gamma: Param>, beta: Param>, } impl TestModule { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), bias: Param::from_data([0.1, 0.2], device), nested: NestedModule { gamma: Param::from_data([1.0, 1.0], device), beta: Param::from_data([0.0, 0.0], device), }, } } fn new_zeros(device: &B::Device) -> Self { Self { weight: Param::from_tensor(Tensor::zeros([2, 2], device)), bias: Param::from_tensor(Tensor::zeros([2], device)), nested: NestedModule { gamma: Param::from_tensor(Tensor::zeros([2], device)), beta: Param::from_tensor(Tensor::zeros([2], device)), }, } } fn new_uninitialized(device: &B::Device) -> Self { use burn_core::module::ParamId; let device_clone = device.clone(); let device_clone2 = device.clone(); let device_clone3 = device.clone(); let device_clone4 = device.clone(); Self { weight: Param::uninitialized( ParamId::new(), move |d, _| Tensor::zeros([2, 2], d), device_clone, true, [2, 2].into(), ), bias: Param::uninitialized( ParamId::new(), move |d, _| Tensor::zeros([2], d), device_clone2, true, [2].into(), ), nested: NestedModule { gamma: Param::uninitialized( ParamId::new(), move |d, _| Tensor::zeros([2], d), device_clone3, true, [2].into(), ), beta: Param::uninitialized( ParamId::new(), move |d, _| Tensor::zeros([2], d), device_clone4, true, [2].into(), ), }, } } } #[test] fn test_store_from_bytes_round_trip() { let device = Default::default(); let module = TestModule::::new(&device); // Save to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load from bytes let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); // Verify success assert!(result.is_success()); assert_eq!(result.applied.len(), 4); // weight, bias, nested.gamma, nested.beta assert!(result.errors.is_empty()); // Verify data was loaded correctly let weight1 = module.weight.val().to_data().to_vec::().unwrap(); let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); assert_eq!(weight1, weight2); } #[test] fn test_store_with_metadata() { let device = Default::default(); let module = TestModule::::new(&device); // Save with metadata let mut save_store = BurnpackStore::from_bytes(None) .metadata("version", "1.0.0") .metadata("model_name", "test_model") .metadata("author", "burn_team"); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load and verify metadata is preserved let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); } #[test] #[cfg(feature = "std")] fn test_store_with_path_filter() { let device = Default::default(); let module = TestModule::::new(&device); // Save all tensors let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with filter - only load weight and bias (not nested) let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_regex("^(weight|bias)$"); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 2); // Only weight and bias assert_eq!(result.skipped.len(), 2); // nested.gamma and nested.beta skipped // Verify weight and bias were loaded let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]); // Verify nested module was NOT loaded (should still be zeros) let gamma2 = module2 .nested .gamma .val() .to_data() .to_vec::() .unwrap(); assert_eq!(gamma2, vec![0.0, 0.0]); } #[test] #[cfg(feature = "std")] fn test_store_with_key_remapping() { let device = Default::default(); let module = TestModule::::new(&device); // Save with original names let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with remapping: nested.gamma -> nested.new_gamma, nested.beta -> nested.new_beta let remapper = KeyRemapper::new() .add_pattern(r"nested\.gamma", "nested.new_gamma") .unwrap() .add_pattern(r"nested\.beta", "nested.new_beta") .unwrap(); let mut load_store = BurnpackStore::from_bytes(Some(bytes)) .remap(remapper) .allow_partial(true); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); // The remapping should cause missing tensors since names don't match assert_eq!(result.applied.len(), 2); // Only weight and bias match assert_eq!(result.unused.len(), 2); // nested.new_gamma and nested.new_beta are unused assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing } #[test] fn test_store_allow_partial() { let device = Default::default(); let module = TestModule::::new(&device); // Save only weight and bias let filter = PathFilter::new() .with_full_path("weight") .with_full_path("bias"); let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with allow_partial let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 2); assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing but that's OK // Verify loaded tensors let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]); } #[test] fn test_store_match_all() { let device = Default::default(); let module = TestModule::::new(&device); // Save with match_all filter (should save everything) let mut save_store = BurnpackStore::from_bytes(None).match_all(); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load everything let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); assert!(result.errors.is_empty()); assert!(result.missing.is_empty()); assert!(result.unused.is_empty()); } #[test] fn test_store_with_full_path() { let device = Default::default(); let module = TestModule::::new(&device); // Save everything let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load only specific tensors by full path let mut load_store = BurnpackStore::from_bytes(Some(bytes)) .with_full_path("weight") .with_full_path("nested.gamma"); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 2); // Only weight and nested.gamma assert_eq!(result.skipped.len(), 2); // bias and nested.beta skipped } #[test] #[cfg(feature = "std")] fn test_store_chain_multiple_patterns() { let device = Default::default(); let module = TestModule::::new(&device); // Save with chained metadata and filters let mut save_store = BurnpackStore::from_bytes(None) .metadata("version", "1.0") .metadata("format", "burnpack") .with_regex(r"^(weight|nested\.)") .match_all(); // This overrides the previous filter save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load everything since match_all was called last let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); // All tensors loaded } #[test] #[cfg(feature = "std")] fn test_store_with_remap_pattern() { let device = Default::default(); let module = TestModule::::new(&device); // Save normally let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with single remap pattern using the convenience method let mut load_store = BurnpackStore::from_bytes(Some(bytes)) .with_remap_pattern(r"^nested\.", "sub_module.") .allow_partial(true); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); // After remapping, nested.* becomes sub_module.*, which won't match assert_eq!(result.applied.len(), 2); // Only weight and bias assert_eq!(result.unused.len(), 2); // sub_module.gamma and sub_module.beta unused } #[test] fn test_store_default_metadata() { let device = Default::default(); let module = TestModule::::new(&device); // Save without adding custom metadata let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Verify default metadata is included // We can't directly inspect metadata from bytes, but we can verify // that the model loads successfully which means metadata was written correctly let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] fn test_store_default_metadata_with_custom() { let device = Default::default(); let module = TestModule::::new(&device); // Save with custom metadata (should preserve defaults) let mut save_store = BurnpackStore::from_bytes(None) .metadata("custom_field", "custom_value") .metadata("author", "test_author"); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load and verify it works (metadata including defaults was saved) let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] fn test_store_clear_metadata() { let device = Default::default(); let module = TestModule::::new(&device); // Save with cleared metadata (no defaults) let mut save_store = BurnpackStore::from_bytes(None).clear_metadata(); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Verify it still loads correctly let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] fn test_store_validate_enabled() { let device = Default::default(); let module = TestModule::::new(&device); // Save normally let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with validation enabled (default) let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert!(result.errors.is_empty()); } #[test] fn test_store_validate_disabled() { let device = Default::default(); let module = TestModule::::new(&device); // Save normally let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load with validation disabled let mut load_store = BurnpackStore::from_bytes(Some(bytes)).validate(false); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); // Should still succeed assert!(result.is_success()); } #[test] fn test_store_allow_partial_missing_tensors() { let device = Default::default(); let module = TestModule::::new(&device); // Save only weight (not bias or nested) let filter = PathFilter::new().with_full_path("weight"); let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Try to load without allow_partial - should fail due to missing tensors let mut load_store = BurnpackStore::from_bytes(Some(bytes.clone())); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2); // Should fail because of missing tensors assert!(result.is_err()); // Now try with allow_partial - should succeed let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true); let mut module3 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module3).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 1); // Only weight assert!(!result.missing.is_empty()); // Has missing tensors } #[test] #[cfg(feature = "std")] fn test_store_file_round_trip() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory and file path let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_file_round_trip.bpk"); // Save to file let mut save_store = BurnpackStore::from_file(&path).metadata("test", "value"); save_store.collect_from(&module).unwrap(); // Verify file exists assert!(path.exists()); // Load from file let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); // Verify data let weight1 = module.weight.val().to_data().to_vec::().unwrap(); let weight2 = module2.weight.val().to_data().to_vec::().unwrap(); assert_eq!(weight1, weight2); } #[test] #[cfg(feature = "std")] fn test_store_overwrite_protection() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory and file path (file doesn't exist yet) let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_model.bpk"); // First save - should succeed let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); assert!(path.exists()); // Second save without overwrite flag - should fail let mut save_store2 = BurnpackStore::from_file(&path); let result = save_store2.collect_from(&module); assert!(result.is_err()); assert!( result .unwrap_err() .to_string() .contains("File already exists") ); // Third save with overwrite flag - should succeed let mut save_store3 = BurnpackStore::from_file(&path).overwrite(true); save_store3.collect_from(&module).unwrap(); // Verify file still exists and is valid let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_store_overwrite_with_metadata() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory and file path let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_model_metadata.bpk"); // First save with v1 metadata let mut save_store = BurnpackStore::from_file(&path) .metadata("version", "1.0") .overwrite(true); save_store.collect_from(&module).unwrap(); // Second save with v2 metadata and overwrite enabled let mut save_store2 = BurnpackStore::from_file(&path) .metadata("version", "2.0") .overwrite(true); save_store2.collect_from(&module).unwrap(); // Verify file loads correctly let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_store_auto_extension_default() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("model"); // Save without extension - should auto-append .bpk let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Verify that model.bpk was created let expected_path = temp_dir.path().join("model.bpk"); assert!(expected_path.exists()); assert!(!path.exists()); // Original path without extension should not exist // Load using the path without extension - should work let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_store_auto_extension_with_existing_extension() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("model.bpk"); // Save with .bpk extension - should not double append let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Verify that only model.bpk was created assert!(path.exists()); let double_ext_path = temp_dir.path().join("model.bpk.bpk"); assert!(!double_ext_path.exists()); // Load and verify let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_store_auto_extension_with_custom_extension() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("model.mpk"); // Save with .mpk extension - should preserve it let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Verify that model.mpk was created (not model.mpk.bpk) assert!(path.exists()); let burnpack_path = temp_dir.path().join("model.mpk.bpk"); assert!(!burnpack_path.exists()); // Load and verify let mut load_store = BurnpackStore::from_file(&path); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_store_auto_extension_disabled() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Create temp directory let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("model"); // Save with auto_extension disabled - should use exact path let mut save_store = BurnpackStore::from_file(&path).auto_extension(false); save_store.collect_from(&module).unwrap(); // Verify that "model" (without extension) was created assert!(path.exists()); let burnpack_path = temp_dir.path().join("model.bpk"); assert!(!burnpack_path.exists()); // Load with auto_extension disabled let mut load_store = BurnpackStore::from_file(&path).auto_extension(false); let mut module2 = TestModule::::new_zeros(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] #[cfg(feature = "std")] fn test_partial_loading_preserves_lazy_initialization() { use tempfile::tempdir; let device = Default::default(); // Create and save a full module let module = TestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("model.bpk"); let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Create an uninitialized module (all params lazy) let mut load_module = TestModule::::new_uninitialized(&device); // Before loading: verify ALL params are uninitialized (lazy) assert!( !load_module.weight.is_initialized(), "weight should be uninitialized before loading" ); assert!( !load_module.bias.is_initialized(), "bias should be uninitialized before loading" ); assert!( !load_module.nested.gamma.is_initialized(), "nested.gamma should be uninitialized before loading" ); assert!( !load_module.nested.beta.is_initialized(), "nested.beta should be uninitialized before loading" ); // Partial load: only load weight and bias (skip nested.*) let filter = PathFilter::new().with_regex("^(weight|bias)$"); let mut load_store = BurnpackStore::from_file(&path).filter(filter); let result = load_module.load_from(&mut load_store).unwrap(); // Verify only weight and bias were loaded assert_eq!(result.applied.len(), 2); assert!(result.applied.contains(&"weight".to_string())); assert!(result.applied.contains(&"bias".to_string())); assert_eq!(result.skipped.len(), 2); assert!(result.skipped.contains(&"nested.gamma".to_string())); assert!(result.skipped.contains(&"nested.beta".to_string())); // After loading: verify loaded params are initialized, skipped remain lazy assert!( load_module.weight.is_initialized(), "weight should be initialized after loading" ); assert!( load_module.bias.is_initialized(), "bias should be initialized after loading" ); assert!( !load_module.nested.gamma.is_initialized(), "nested.gamma should remain uninitialized (was skipped)" ); assert!( !load_module.nested.beta.is_initialized(), "nested.beta should remain uninitialized (was skipped)" ); // Verify the loaded values are correct let weight_data = load_module.weight.val().to_data().to_vec::().unwrap(); assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); let bias_data = load_module.bias.val().to_data().to_vec::().unwrap(); assert_eq!(bias_data, vec![0.1, 0.2]); // Now check that nested params can still be initialized on first access let gamma_data = load_module .nested .gamma .val() .to_data() .to_vec::() .unwrap(); assert_eq!(gamma_data, vec![0.0, 0.0]); // Initialized to zeros via the init function // After accessing, they should be initialized assert!( load_module.nested.gamma.is_initialized(), "nested.gamma should be initialized after first access" ); } // Model with forward pass for testing weight preservation #[derive(Module, Debug)] struct ForwardTestModel { linear1: burn_nn::Linear, linear2: burn_nn::Linear, } impl ForwardTestModel { /// Forward pass: input -> linear1 -> gelu -> linear2 fn forward(&self, input: Tensor) -> Tensor { let x = self.linear1.forward(input); let x = burn::tensor::activation::gelu(x); self.linear2.forward(x) } } #[derive(burn::config::Config, Debug)] struct ForwardTestModelConfig { input_size: usize, hidden_size: usize, output_size: usize, } impl ForwardTestModelConfig { fn init(&self, device: &B::Device) -> ForwardTestModel { ForwardTestModel { linear1: burn_nn::LinearConfig::new(self.input_size, self.hidden_size) .with_bias(true) .init(device), linear2: burn_nn::LinearConfig::new(self.hidden_size, self.output_size) .with_bias(true) .init(device), } } } #[test] #[cfg(feature = "std")] fn test_forward_pass_preservation_after_save_load() { use tempfile::tempdir; let device = Default::default(); // Create model config let config = ForwardTestModelConfig { input_size: 4, hidden_size: 8, output_size: 2, }; // Initialize model1 with random weights let model1 = config.init::(&device); // Create random input let input = Tensor::::random( [1, 4], burn_tensor::Distribution::Uniform(-1.0, 1.0), &device, ); // Forward pass with model1 -> output1 let output1 = model1.forward(input.clone()); // Save model1 weights let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("forward_test_model.bpk"); let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&model1).unwrap(); // Initialize model2 with different random weights let mut model2 = config.init::(&device); // Forward pass with model2 -> output2 (should differ from output1) let output2 = model2.forward(input.clone()); // Verify output2 differs from output1 (different random weights) assert!( !output1 .clone() .all_close(output2.clone(), Some(1e-6), Some(1e-6)), "output2 should differ from output1 (different random initializations)" ); // Load model1 weights into model2 let mut load_store = BurnpackStore::from_file(&path); let result = load_store.apply_to(&mut model2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases // Forward pass with model2 (now has model1 weights) -> output3 let output3 = model2.forward(input.clone()); // Verify output3 equals output1 (same weights) assert!( output1.all_close(output3, Some(1e-6), Some(1e-6)), "output3 should equal output1 after loading weights" ); } #[test] fn test_store_get_all_snapshots() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Get all snapshots (returns &BTreeMap) let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshots = load_store.get_all_snapshots().unwrap(); // Should have 4 tensors assert_eq!(snapshots.len(), 4); // Verify tensor names exist (BTreeMap keys) assert!(snapshots.contains_key("weight")); assert!(snapshots.contains_key("bias")); assert!(snapshots.contains_key("nested.gamma")); assert!(snapshots.contains_key("nested.beta")); } #[test] fn test_store_get_snapshot_existing() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Get a specific snapshot (returns Option<&TensorSnapshot>) let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshot = load_store.get_snapshot("weight").unwrap(); // Should find the tensor assert!(snapshot.is_some()); let snapshot = snapshot.unwrap(); assert_eq!(snapshot.full_path(), "weight"); assert_eq!(snapshot.shape, shape![2, 2]); // Verify data can be loaded let data = snapshot.to_data().unwrap(); assert_eq!(data.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0]); } #[test] fn test_store_get_snapshot_nested() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Get a nested snapshot let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshot = load_store.get_snapshot("nested.gamma").unwrap(); assert!(snapshot.is_some()); let snapshot = snapshot.unwrap(); assert_eq!(snapshot.full_path(), "nested.gamma"); assert_eq!(snapshot.shape, shape![2]); } #[test] fn test_store_get_snapshot_not_found() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Try to get a non-existent snapshot let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshot = load_store.get_snapshot("nonexistent").unwrap(); // Should return None assert!(snapshot.is_none()); } #[test] fn test_store_keys() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Get all keys let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let keys = load_store.keys().unwrap(); // Should have 4 keys assert_eq!(keys.len(), 4); assert!(keys.contains(&"weight".to_string())); assert!(keys.contains(&"bias".to_string())); assert!(keys.contains(&"nested.gamma".to_string())); assert!(keys.contains(&"nested.beta".to_string())); } #[test] #[cfg(feature = "std")] fn test_store_get_all_snapshots_from_file() { use tempfile::tempdir; let device = Default::default(); let module = TestModule::::new(&device); // Save to file let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_get_all_snapshots.bpk"); let mut save_store = BurnpackStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Get snapshots from file (returns &BTreeMap) let mut load_store = BurnpackStore::from_file(&path); let snapshots = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots.len(), 4); // Verify we can load data from a snapshot (use get() on BTreeMap) let weight_snapshot = snapshots.get("weight").unwrap(); let data = weight_snapshot.to_data().unwrap(); assert_eq!(data.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0]); } #[test] fn test_store_caching_behavior() { let device = Default::default(); let module = TestModule::::new(&device); // Save module to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Create store and call get_snapshots multiple times let mut load_store = BurnpackStore::from_bytes(Some(bytes)); // First call should populate cache let snapshots1 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots1.len(), 4); // Second call should return cached data (same reference) let snapshots2 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots2.len(), 4); // get_snapshot should also use the cache let weight = load_store.get_snapshot("weight").unwrap(); assert!(weight.is_some()); } #[test] fn test_store_cache_invalidation_on_save() { let device = Default::default(); // Create first module with specific weights let module1 = TestModule::::new(&device); // Save module1 to bytes store let mut store = BurnpackStore::from_bytes(None); store.collect_from(&module1).unwrap(); // Populate cache by calling get_snapshots let snapshots1 = store.get_all_snapshots().unwrap(); assert_eq!(snapshots1.len(), 4); let weight1_data = snapshots1.get("weight").unwrap().to_data().unwrap(); let weight1_values: Vec = weight1_data.to_vec().unwrap(); // Create a different module with different weights let module2 = TestModule:: { weight: Param::from_tensor(Tensor::from_data([[10.0, 20.0], [30.0, 40.0]], &device)), bias: Param::from_tensor(Tensor::from_data([100.0, 200.0], &device)), nested: NestedModule { gamma: Param::from_tensor(Tensor::from_data([1000.0, 2000.0], &device)), beta: Param::from_tensor(Tensor::from_data([3000.0, 4000.0], &device)), }, }; // Save module2 - this should invalidate the cache store.collect_from(&module2).unwrap(); // Get snapshots again - should return NEW data, not cached old data let snapshots2 = store.get_all_snapshots().unwrap(); assert_eq!(snapshots2.len(), 4); let weight2_data = snapshots2.get("weight").unwrap().to_data().unwrap(); let weight2_values: Vec = weight2_data.to_vec().unwrap(); // Verify the data changed (cache was invalidated) assert_ne!(weight1_values, weight2_values); assert_eq!(weight2_values, vec![10.0, 20.0, 30.0, 40.0]); } /// Test storing and loading quantized weights with BurnpackStore. /// Regression test for https://github.com/tracel-ai/burn/issues/4179 #[test] fn test_store_quantized_module_round_trip() { use burn_core::module::Quantizer; use burn_nn::LinearConfig; use burn_tensor::quantization::{ Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue, }; let device = Default::default(); // Create a simple linear module (512x512 as in the bug report) let linear = LinearConfig::new(512, 512) .with_bias(false) .init::(&device); // Define quantization scheme (Q8S with tensor-level quantization) let scheme = <::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::Tensor) .with_param(QuantParam::F32); // Quantize the module let calibration = Calibration::MinMax; let mut quantizer = Quantizer { calibration, scheme, }; let quantized_linear = linear.quantize_weights(&mut quantizer); // Save the quantized module let mut save_store = BurnpackStore::from_bytes(None); let result = save_store.collect_from(&quantized_linear); assert!( result.is_ok(), "Failed to save quantized module: {:?}", result.err() ); // Get the bytes let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load the bytes and verify we can read the tensor metadata let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshots = load_store .get_all_snapshots() .expect("Failed to get snapshots"); // Verify we have the weight tensor assert_eq!(snapshots.len(), 1, "Expected 1 tensor (weight)"); assert!(snapshots.contains_key("weight"), "Expected 'weight' tensor"); // Verify the tensor metadata let weight_snapshot = snapshots.get("weight").unwrap(); assert_eq!(weight_snapshot.shape, shape![512, 512]); // Verify we can load the tensor data let weight_data = weight_snapshot .to_data() .expect("Failed to load tensor data"); assert_eq!(weight_data.shape, shape![512, 512]); } /// Test HalfPrecisionAdapter bidirectional round-trip: same adapter for save and load. #[test] fn test_store_half_precision_round_trip() { use crate::HalfPrecisionAdapter; use burn_nn::{Linear, LinearConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct HalfModel { linear: Linear, } let device = Default::default(); let model = HalfModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), }; // Save with HalfPrecisionAdapter (F32 -> F16) let adapter = HalfPrecisionAdapter::new(); let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter.clone()); save_store.collect_from(&model).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Verify stored tensors are F16 let mut inspect_store = BurnpackStore::from_bytes(Some(bytes.clone())); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (_, snapshot) in snapshots.iter() { assert_eq!(snapshot.dtype, DType::F16, "Expected F16 in stored data"); } // Load back with same adapter instance (F16 -> F32) let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_from_adapter(adapter); let mut model2 = HalfModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), }; let result = load_store.apply_to(&mut model2).unwrap(); assert!(result.is_success()); // Verify values are close (F32 -> F16 -> F32 has rounding) let w1 = model.linear.weight.val().to_data().to_vec::().unwrap(); let w2 = model2 .linear .weight .val() .to_data() .to_vec::() .unwrap(); for (a, b) in w1.iter().zip(w2.iter()) { assert!( (a - b).abs() < 0.01, "Weight values differ too much after F16 round-trip: {} vs {}", a, b ); } } /// Test HalfPrecisionAdapter: BatchNorm excluded by default. #[test] fn test_store_half_precision_batch_norm_excluded() { use crate::HalfPrecisionAdapter; use burn_nn::{BatchNorm, BatchNormConfig, Linear, LinearConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct BnModel { linear: Linear, bn: BatchNorm, } let device = Default::default(); let model = BnModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), bn: BatchNormConfig::new(2).init(&device), }; let adapter = HalfPrecisionAdapter::new(); let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter); save_store.collect_from(&model).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Verify: Linear tensors are F16, BatchNorm tensors remain F32 let mut inspect_store = BurnpackStore::from_bytes(Some(bytes)); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (name, snapshot) in snapshots.iter() { if name.starts_with("linear") { assert_eq!( snapshot.dtype, DType::F16, "Linear tensor '{}' should be F16", name ); } else if name.starts_with("bn") { assert_eq!( snapshot.dtype, DType::F32, "BatchNorm tensor '{}' should stay F32", name ); } } } /// Test HalfPrecisionAdapter with without_module customization. #[test] fn test_store_half_precision_without_module() { use crate::HalfPrecisionAdapter; use burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct MixedModel { linear: Linear, norm: LayerNorm, } let device = Default::default(); let model = MixedModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), norm: LayerNormConfig::new(2).init(&device), }; // Remove LayerNorm from half-precision conversion let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm"); let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter); save_store.collect_from(&model).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut inspect_store = BurnpackStore::from_bytes(Some(bytes)); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (name, snapshot) in snapshots.iter() { if name.starts_with("linear") { assert_eq!( snapshot.dtype, DType::F16, "Linear tensor '{}' should be F16", name ); } else if name.starts_with("norm") { assert_eq!( snapshot.dtype, DType::F32, "LayerNorm tensor '{}' should stay F32", name ); } } } /// Test HalfPrecisionAdapter chained with PyTorch adapter. #[test] fn test_store_half_precision_chained_with_pytorch() { use crate::{HalfPrecisionAdapter, PyTorchToBurnAdapter}; use burn_nn::{Linear, LinearConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct ChainModel { linear: Linear, } let device = Default::default(); let model = ChainModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), }; // Save with chained adapter: BurnToPyTorch then half-precision let adapter = crate::BurnToPyTorchAdapter.chain(HalfPrecisionAdapter::new()); let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter); save_store.collect_from(&model).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Verify stored tensors are F16 and transposed let mut inspect_store = BurnpackStore::from_bytes(Some(bytes.clone())); let snapshots = inspect_store.get_all_snapshots().unwrap(); let weight = snapshots.get("linear.weight").unwrap(); assert_eq!(weight.dtype, DType::F16); // Weight should be transposed: [4, 2] original -> [2, 4] after BurnToPyTorch assert_eq!(weight.shape, shape![2, 4]); // Load back with reverse chain: half-precision (F16 -> F32) then PyTorchToBurn let adapter = HalfPrecisionAdapter::new().chain(PyTorchToBurnAdapter); let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_from_adapter(adapter); let mut model2 = ChainModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), }; let result = load_store.apply_to(&mut model2).unwrap(); assert!(result.is_success()); } /// Test storing quantized weights with block-level quantization. #[test] fn test_store_quantized_module_block_level() { use burn_core::module::Quantizer; use burn_nn::LinearConfig; use burn_tensor::quantization::{ Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue, }; let device = Default::default(); // Create a linear module let linear = LinearConfig::new(128, 128) .with_bias(false) .init::(&device); // Define quantization scheme with block-level quantization let scheme = <::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme() .with_value(QuantValue::Q8S) .with_level(QuantLevel::block([32])) // Block size of 32 .with_param(QuantParam::F32); // Quantize the module let calibration = Calibration::MinMax; let mut quantizer = Quantizer { calibration, scheme, }; let quantized_linear = linear.quantize_weights(&mut quantizer); // Save the quantized module let mut save_store = BurnpackStore::from_bytes(None); let result = save_store.collect_from(&quantized_linear); assert!( result.is_ok(), "Failed to save quantized module with block-level quantization: {:?}", result.err() ); // Get the bytes and verify round-trip let bytes = save_store.get_bytes().expect("Failed to get bytes"); let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let snapshots = load_store .get_all_snapshots() .expect("Failed to get snapshots"); assert_eq!(snapshots.len(), 1); let weight_snapshot = snapshots.get("weight").unwrap(); assert_eq!(weight_snapshot.shape, shape![128, 128]); } ================================================ FILE: crates/burn-store/src/burnpack/tests/writer.rs ================================================ use crate::burnpack::{ base::{ BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, aligned_data_section_start, magic_range, }, writer::BurnpackWriter, }; use super::*; use burn_core::module::ParamId; use burn_tensor::{BoolStore, DType, TensorData, shape}; use std::rc::Rc; #[test] fn test_writer_new() { let writer = BurnpackWriter::new(vec![]); assert_eq!(writer.snapshots.len(), 0); assert!(writer.metadata.is_empty()); } #[test] fn test_writer_add_metadata() { let writer = BurnpackWriter::new(vec![]) .with_metadata("model_name", "test_model") .with_metadata("version", "1.0.0") .with_metadata("author", "test_author"); assert_eq!(writer.metadata.len(), 3); assert_eq!( writer.metadata.get("model_name"), Some(&"test_model".to_string()) ); assert_eq!(writer.metadata.get("version"), Some(&"1.0.0".to_string())); assert_eq!( writer.metadata.get("author"), Some(&"test_author".to_string()) ); } #[test] fn test_writer_add_tensor_snapshot() { // Create test tensor snapshots let snapshot1 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["layer1".to_string(), "weights".to_string()], vec![], burn_core::module::ParamId::new(), ); let snapshot2 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8), vec!["layer1".to_string(), "bias".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]); assert_eq!(writer.snapshots.len(), 2); assert_eq!(writer.snapshots[0].full_path(), "layer1.weights"); assert_eq!(writer.snapshots[1].full_path(), "layer1.bias"); } #[test] fn test_writer_to_bytes_empty() { let writer = BurnpackWriter::new(vec![]); let bytes = writer.to_bytes().unwrap(); // Verify header assert!(bytes.len() >= HEADER_SIZE); assert_eq!(&bytes[magic_range()], &MAGIC_NUMBER.to_le_bytes()); // Parse header let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); assert_eq!(header.magic, MAGIC_NUMBER); assert_eq!(header.version, FORMAT_VERSION); // Verify metadata let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata_bytes = &bytes[HEADER_SIZE..metadata_end]; let metadata: BurnpackMetadata = ciborium::de::from_reader(metadata_bytes).unwrap(); assert_eq!(metadata.tensors.len(), 0); assert!(metadata.metadata.is_empty()); } #[test] fn test_writer_to_bytes_with_tensors() { // Add tensors with different data types let f32_data = [1.0f32, 2.0, 3.0, 4.0]; let f32_bytes: Vec = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot_f32 = TensorSnapshot::from_data( TensorData::from_bytes_vec(f32_bytes.clone(), vec![2, 2], DType::F32), vec!["weights".to_string()], vec![], burn_core::module::ParamId::new(), ); let i64_data = [10i64, 20, 30]; let i64_bytes: Vec = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect(); let snapshot_i64 = TensorSnapshot::from_data( TensorData::from_bytes_vec(i64_bytes.clone(), vec![3], DType::I64), vec!["bias".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot_f32, snapshot_i64]) .with_metadata("test_key", "test_value"); let bytes = writer.to_bytes().unwrap(); // Parse and verify let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap(); // Verify metadata assert_eq!( metadata.metadata.get("test_key"), Some(&"test_value".to_string()) ); // Verify tensors assert_eq!(metadata.tensors.len(), 2); let weights = metadata.tensors.get("weights").unwrap(); assert_eq!(weights.dtype, DType::F32); assert_eq!(weights.shape, vec![2, 2]); assert_eq!(weights.data_offsets.1 - weights.data_offsets.0, 16); // 4 * 4 bytes let bias = metadata.tensors.get("bias").unwrap(); assert_eq!(bias.dtype, DType::I64); assert_eq!(bias.shape, vec![3]); assert_eq!(bias.data_offsets.1 - bias.data_offsets.0, 24); // 3 * 8 bytes // Verify actual tensor data // Data section starts at aligned position after metadata let data_section_start = aligned_data_section_start(header.metadata_size as usize); let weights = metadata.tensors.get("weights").unwrap(); let bias = metadata.tensors.get("bias").unwrap(); let weights_data = &bytes[data_section_start + weights.data_offsets.0 as usize ..data_section_start + weights.data_offsets.1 as usize]; assert_eq!(weights_data, f32_bytes); let bias_data = &bytes[data_section_start + bias.data_offsets.0 as usize ..data_section_start + bias.data_offsets.1 as usize]; assert_eq!(bias_data, i64_bytes); } #[test] fn test_writer_all_dtypes() { use half::{bf16, f16}; // Test all supported data types (excluding QFloat which is tested separately) // Format: (DType, expected_size_per_element, sample_data_bytes) let test_cases = vec![ // Floating point types (DType::F64, 8, 1.0f64.to_le_bytes().to_vec()), (DType::F32, 4, 1.0f32.to_le_bytes().to_vec()), (DType::F16, 2, f16::from_f32(1.0).to_le_bytes().to_vec()), (DType::BF16, 2, bf16::from_f32(1.0).to_le_bytes().to_vec()), // Signed integers (DType::I64, 8, 1i64.to_le_bytes().to_vec()), (DType::I32, 4, 1i32.to_le_bytes().to_vec()), (DType::I16, 2, 1i16.to_le_bytes().to_vec()), (DType::I8, 1, 1i8.to_le_bytes().to_vec()), // Unsigned integers (DType::U64, 8, 255u64.to_le_bytes().to_vec()), (DType::U32, 4, 255u32.to_le_bytes().to_vec()), (DType::U16, 2, 255u16.to_le_bytes().to_vec()), (DType::U8, 1, vec![255u8]), // Boolean (DType::Bool(BoolStore::Native), 1, vec![1u8]), ]; let mut snapshots = vec![]; let mut expected_data = vec![]; for (i, (dtype, expected_size, data)) in test_cases.into_iter().enumerate() { let name = format!("tensor_{}", i); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), vec![1], dtype), vec![name.clone()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); expected_data.push((name, dtype, expected_size, data)); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); // Parse and verify metadata let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) .unwrap(); assert_eq!( metadata.tensors.len(), 13, "Expected 13 dtypes to be tested" ); // Verify each tensor's metadata and data let data_section_start = aligned_data_section_start(header.metadata_size as usize); for (name, expected_dtype, expected_size, expected_bytes) in expected_data { let tensor = metadata .tensors .get(&name) .unwrap_or_else(|| panic!("Missing tensor: {}", name)); assert_eq!(tensor.dtype, expected_dtype, "DType mismatch for {}", name); assert_eq!(tensor.shape, vec![1], "Shape mismatch for {}", name); // Verify data size matches expected let data_size = (tensor.data_offsets.1 - tensor.data_offsets.0) as usize; assert_eq!( data_size, expected_size, "Data size mismatch for {} ({:?})", name, expected_dtype ); // Verify actual data bytes match let actual_bytes = &bytes[data_section_start + tensor.data_offsets.0 as usize ..data_section_start + tensor.data_offsets.1 as usize]; assert_eq!( actual_bytes, expected_bytes.as_slice(), "Data mismatch for {} ({:?})", name, expected_dtype ); } } #[test] fn test_writer_all_dtypes_round_trip() { use crate::burnpack::reader::BurnpackReader; use half::{bf16, f16}; // Test all dtypes can be written and read back correctly let test_cases = vec![ // Floating point types - use multiple elements to better test ( "f64_tensor", DType::F64, [1.0f64, 2.0, 3.0, 4.0] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![4], ), ( "f32_tensor", DType::F32, [1.0f32, 2.0, 3.0, 4.0] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2, 2], ), ( "f16_tensor", DType::F16, [f16::from_f32(1.0), f16::from_f32(2.0)] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2], ), ( "bf16_tensor", DType::BF16, [bf16::from_f32(1.0), bf16::from_f32(2.0)] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2], ), // Signed integers ( "i64_tensor", DType::I64, [1i64, -2, 3, -4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![4], ), ( "i32_tensor", DType::I32, [1i32, -2, 3, -4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2, 2], ), ( "i16_tensor", DType::I16, [1i16, -2, 3, -4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![4], ), ( "i8_tensor", DType::I8, [1i8, -2, 3, -4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2, 2], ), // Unsigned integers ( "u64_tensor", DType::U64, [1u64, 2, 3, 4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![4], ), ( "u32_tensor", DType::U32, [1u32, 2, 3, 4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![2, 2], ), ( "u16_tensor", DType::U16, [1u16, 2, 3, 4] .iter() .flat_map(|v| v.to_le_bytes()) .collect::>(), shape![4], ), ("u8_tensor", DType::U8, vec![1u8, 2, 3, 4], shape![2, 2]), // Boolean ( "bool_tensor", DType::Bool(BoolStore::Native), vec![1u8, 0, 1, 0], shape![4], ), ]; let mut snapshots = vec![]; let mut expected_results: Vec<(&str, DType, Vec, _)> = vec![]; for (name, dtype, data, shape) in test_cases.into_iter() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(data.clone(), shape.clone(), dtype), vec![name.to_string()], vec![], burn_core::module::ParamId::new(), ); snapshots.push(snapshot); expected_results.push((name, dtype, data, shape)); } // Write to bytes let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); // Read back using BurnpackReader let reader = BurnpackReader::from_bytes(bytes).unwrap(); // Verify each tensor can be read back with correct data for (name, expected_dtype, expected_data, expected_shape) in expected_results { let snapshot = reader .get_tensor_snapshot(name) .unwrap_or_else(|e| panic!("Failed to get tensor snapshot {}: {}", name, e)); let tensor_data = snapshot .to_data() .unwrap_or_else(|e| panic!("Failed to read tensor data {}: {}", name, e)); assert_eq!( tensor_data.dtype, expected_dtype, "DType mismatch for {}", name ); assert_eq!( tensor_data.shape, expected_shape, "Shape mismatch for {}", name ); assert_eq!( &tensor_data.bytes[..], expected_data.as_slice(), "Data mismatch for {}", name ); } } #[test] fn test_writer_large_tensor() { // Create a large tensor (1MB) let size = 256 * 1024; // 256K floats = 1MB let data: Vec = (0..size).map(|i| i as f32).collect(); let bytes: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(bytes.clone(), vec![size], DType::F32), vec!["large_tensor".to_string()], vec![], burn_core::module::ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let result = writer.to_bytes().unwrap(); // Verify the large tensor is correctly stored let header = BurnpackHeader::from_bytes(&result[..HEADER_SIZE]).unwrap(); let metadata: BurnpackMetadata = ciborium::de::from_reader( &result[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize], ) .unwrap(); assert_eq!(metadata.tensors.len(), 1); let tensor = metadata.tensors.get("large_tensor").unwrap(); assert_eq!(tensor.shape, vec![size as u64]); assert_eq!( tensor.data_offsets.1 - tensor.data_offsets.0, (size * 4) as u64 ); } #[test] fn test_writer_empty_tensors() { // Add tensor with empty data let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![], vec![0], DType::F32), vec!["empty".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); let bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) .unwrap(); assert_eq!(metadata.tensors.len(), 1); let tensor = metadata.tensors.get("empty").unwrap(); assert_eq!(tensor.shape, vec![0]); assert_eq!(tensor.data_offsets.1 - tensor.data_offsets.0, 0); } #[test] fn test_writer_special_characters_in_names() { // Test various special characters in tensor names let special_names = vec![ "layer.0.weight", "model/encoder/layer1", "model::layer::weight", "layer[0].bias", "layer_1_weight", "layer-1-bias", "layer@1#weight", "emoji_😀_tensor", "unicode_测试_tensor", "spaces in name", ]; let mut snapshots = vec![]; for name in &special_names { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8), vec![name.to_string()], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) .unwrap(); assert_eq!(metadata.tensors.len(), 10); for (tensor_name, _tensor) in metadata.tensors.iter() { assert!(!tensor_name.is_empty()); // Names should be preserved exactly assert!( tensor_name.contains("layer") || tensor_name.contains("model") || tensor_name.contains("emoji") || tensor_name.contains("unicode") || tensor_name.contains("spaces") ); } } #[test] fn test_writer_metadata_overwrite() { let writer = BurnpackWriter::new(vec![]) .with_metadata("key", "value1") .with_metadata("key", "value2"); assert_eq!(writer.metadata.get("key"), Some(&"value2".to_string())); assert_eq!(writer.metadata.len(), 1); } #[test] fn test_writer_tensor_order_preserved() { // Add tensors in specific order let names = vec!["z_tensor", "a_tensor", "m_tensor", "b_tensor"]; let mut snapshots = vec![]; for name in &names { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1], vec![1], DType::U8), vec![name.to_string()], vec![], ParamId::new(), ); snapshots.push(snapshot); } let writer = BurnpackWriter::new(snapshots); let bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize]) .unwrap(); // Verify all tensors are present (BTreeMap stores in sorted order by key) let expected_sorted = vec!["a_tensor", "b_tensor", "m_tensor", "z_tensor"]; let actual_names: Vec<_> = metadata.tensors.keys().collect(); assert_eq!(actual_names, expected_sorted); } #[test] fn test_writer_lazy_snapshot_evaluation() { // Create a lazy snapshot using closure let data = Rc::new(vec![1.0f32, 2.0, 3.0, 4.0]); let data_clone = data.clone(); let snapshot = TensorSnapshot::from_closure( Rc::new(move || { let bytes: Vec = data_clone.iter().flat_map(|f| f.to_le_bytes()).collect(); Ok(TensorData::from_bytes_vec(bytes, shape![2, 2], DType::F32)) }), DType::F32, shape![2, 2], vec!["lazy".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); // The closure should only be evaluated when to_bytes is called let bytes = writer.to_bytes().unwrap(); let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap(); let metadata_end = HEADER_SIZE + header.metadata_size as usize; let metadata: BurnpackMetadata = ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap(); assert_eq!(metadata.tensors.len(), 1); let tensor = metadata.tensors.get("lazy").unwrap(); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, vec![2, 2]); // Verify the data was correctly written // Data section starts at aligned position after metadata let data_section_start = aligned_data_section_start(header.metadata_size as usize); let tensor_data = &bytes[data_section_start..data_section_start + 16]; let expected: Vec = [1.0f32, 2.0, 3.0, 4.0] .iter() .flat_map(|f| f.to_le_bytes()) .collect(); assert_eq!(tensor_data, expected.as_slice()); } #[cfg(feature = "std")] #[test] fn test_writer_write_to_file() { use std::fs; use tempfile::tempdir; let dir = tempdir().unwrap(); let file_path = dir.path().join("test.bpk"); let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("file_test", "true"); writer.write_to_file(&file_path).unwrap(); // Verify file exists and has correct content assert!(file_path.exists()); let file_bytes = fs::read(&file_path).unwrap(); let memory_bytes = writer.to_bytes().unwrap(); assert_eq!(file_bytes.as_slice(), &*memory_bytes); } #[test] fn test_writer_size() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); let size = writer.size().unwrap(); let bytes = writer.to_bytes().unwrap(); // Size should match actual bytes length assert_eq!(size, bytes.len()); } #[test] fn test_writer_write_into() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]).with_metadata("test", "value"); // Get size and allocate buffer let size = writer.size().unwrap(); let mut buffer = vec![0u8; size]; // Write into buffer writer.write_into(&mut buffer).unwrap(); // Compare with to_bytes() let bytes = writer.to_bytes().unwrap(); assert_eq!(buffer.as_slice(), &*bytes); } #[test] fn test_writer_write_into_buffer_too_small() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); // Allocate a buffer that's too small let mut buffer = vec![0u8; 10]; // Should fail with buffer too small error let result = writer.write_into(&mut buffer); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("Buffer too small")); } #[test] fn test_writer_write_into_buffer_larger_than_needed() { let snapshot = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["test".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot]); // Allocate a larger buffer let size = writer.size().unwrap(); let mut buffer = vec![0u8; size + 100]; // Extra 100 bytes // Should succeed and only write the necessary bytes writer.write_into(&mut buffer).unwrap(); // Compare the written portion with to_bytes() let bytes = writer.to_bytes().unwrap(); assert_eq!(&buffer[..size], &*bytes); } #[test] fn test_writer_write_into_multiple_tensors() { let snapshot1 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8), vec!["tensor1".to_string()], vec![], ParamId::new(), ); let snapshot2 = TensorSnapshot::from_data( TensorData::from_bytes_vec(vec![5, 6, 7, 8, 9, 10], vec![2, 3], DType::U8), vec!["tensor2".to_string()], vec![], ParamId::new(), ); let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]).with_metadata("test", "multiple"); let size = writer.size().unwrap(); let mut buffer = vec![0u8; size]; writer.write_into(&mut buffer).unwrap(); let bytes = writer.to_bytes().unwrap(); assert_eq!(buffer.as_slice(), &*bytes); } #[test] fn test_writer_write_into_empty() { let writer = BurnpackWriter::new(vec![]); let size = writer.size().unwrap(); let mut buffer = vec![0u8; size]; writer.write_into(&mut buffer).unwrap(); let bytes = writer.to_bytes().unwrap(); assert_eq!(buffer.as_slice(), &*bytes); } ================================================ FILE: crates/burn-store/src/burnpack/tests/zero_copy.rs ================================================ //! Tests for zero-copy tensor loading functionality. use crate::ModuleStore; use crate::burnpack::store::BurnpackStore; use burn_core as burn; use burn_core::module::{Module, Param}; use burn_tensor::{AllocationProperty, Bytes, Tensor, backend::Backend}; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] struct SimpleModule { weight: Param>, bias: Param>, } impl SimpleModule { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[1.0f32, 2.0], [3.0, 4.0]], device), bias: Param::from_data([0.5f32, 1.5], device), } } fn new_zeros(device: &B::Device) -> Self { Self { weight: Param::from_tensor(Tensor::zeros([2, 2], device)), bias: Param::from_tensor(Tensor::zeros([2], device)), } } } /// Test that from_static creates a store with zero_copy enabled by default. #[test] fn test_from_static_enables_zero_copy() { let device = Default::default(); let module = SimpleModule::::new(&device); // Save to bytes first let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Convert to Vec and then leak to get &'static [u8] let bytes_vec: Vec = bytes.to_vec(); let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice()); // Create store from static - zero_copy should be enabled let mut load_store = BurnpackStore::from_static(static_bytes); // Load into a new module let mut loaded_module = SimpleModule::::new_zeros(&device); load_store.apply_to(&mut loaded_module).unwrap(); // Verify data is correct let loaded_weight = loaded_module.weight.val().to_data(); let loaded_bias = loaded_module.bias.val().to_data(); assert_eq!( loaded_weight.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0] ); assert_eq!(loaded_bias.to_vec::().unwrap(), vec![0.5, 1.5]); } /// Test that zero_copy builder method works. #[test] fn test_zero_copy_builder_method() { let device = Default::default(); let module = SimpleModule::::new(&device); // Save to bytes first let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Create shared bytes for zero-copy let shared = bytes::Bytes::from(bytes.to_vec()); let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other); // Create store with zero_copy enabled let mut load_store = BurnpackStore::from_bytes(Some(cubecl_bytes)).zero_copy(true); // Load into a new module let mut loaded_module = SimpleModule::::new_zeros(&device); load_store.apply_to(&mut loaded_module).unwrap(); // Verify data is correct let loaded_weight = loaded_module.weight.val().to_data(); assert_eq!( loaded_weight.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0] ); } /// Test that zero_copy(false) uses copying even with shared bytes. #[test] fn test_zero_copy_disabled_uses_copy() { let device = Default::default(); let module = SimpleModule::::new(&device); // Save to bytes first let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Convert to Vec and then leak to get &'static [u8] let bytes_vec: Vec = bytes.to_vec(); let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice()); // Create store from static but disable zero_copy let mut load_store = BurnpackStore::from_static(static_bytes).zero_copy(false); // Load into a new module let mut loaded_module = SimpleModule::::new_zeros(&device); load_store.apply_to(&mut loaded_module).unwrap(); // Verify data is correct (copied, not zero-copy) let loaded_weight = loaded_module.weight.val().to_data(); assert_eq!( loaded_weight.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0] ); } /// Test that from_bytes with regular Bytes uses copying by default. #[test] fn test_from_bytes_uses_copy_by_default() { let device = Default::default(); let module = SimpleModule::::new(&device); // Save to bytes let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Load from bytes (default: zero_copy = false) let mut load_store = BurnpackStore::from_bytes(Some(bytes)); let mut loaded_module = SimpleModule::::new_zeros(&device); load_store.apply_to(&mut loaded_module).unwrap(); // Verify data is correct let loaded_weight = loaded_module.weight.val().to_data(); assert_eq!( loaded_weight.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0] ); } /// Test that slice_bytes works correctly on StorageBackend. #[test] fn test_storage_backend_slice_bytes() { use crate::burnpack::reader::BurnpackReader; let device = Default::default(); let module = SimpleModule::::new(&device); // Save to bytes first let mut save_store = BurnpackStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); // Create shared bytes let shared = bytes::Bytes::from(bytes.to_vec()); let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other); // Create reader and get snapshots with zero-copy let reader = BurnpackReader::from_bytes(cubecl_bytes).unwrap(); let snapshots = reader.get_snapshots_zero_copy(true).unwrap(); // Verify we got the expected number of tensors assert_eq!(snapshots.len(), 2); // Load the tensor data for snapshot in &snapshots { let data = snapshot.to_data().unwrap(); // Just verify we can access the data - the actual content depends on tensor order assert!(!data.bytes.is_empty()); } } /// Test that zero_copy=true with file-based loading works (via mmap + bytes::Bytes). #[test] fn test_zero_copy_file_based_works() { use tempfile::NamedTempFile; let device = Default::default(); let module = SimpleModule::::new(&device); // Save to a temporary file let temp_file = NamedTempFile::new().unwrap(); let path = temp_file.path(); let mut save_store = BurnpackStore::from_file(path).overwrite(true); save_store.collect_from(&module).unwrap(); // Load with zero_copy=true - should work because mmap is converted to bytes::Bytes let mut load_store = BurnpackStore::from_file(path).zero_copy(true); let mut loaded_module = SimpleModule::::new_zeros(&device); // The apply should succeed - mmap now supports zero-copy via bytes::Bytes::from_owner() load_store.apply_to(&mut loaded_module).unwrap(); // Verify data is correct let loaded_weight = loaded_module.weight.val().to_data(); assert_eq!( loaded_weight.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0] ); } ================================================ FILE: crates/burn-store/src/burnpack/writer.rs ================================================ use super::base::{ BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start, }; use crate::TensorSnapshot; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; use alloc::vec::Vec; use burn_tensor::Bytes; #[cfg(feature = "std")] use std::fs::File; #[cfg(feature = "std")] use std::io::Write; #[cfg(feature = "std")] use std::path::Path; /// Align an offset to the specified alignment boundary. /// /// Returns the smallest value >= `offset` that is a multiple of `alignment`. #[inline] const fn align_offset(offset: u64, alignment: u64) -> u64 { offset.div_ceil(alignment) * alignment } /// Writer for creating Burnpack files pub struct BurnpackWriter { /// Tensors to write pub(crate) snapshots: Vec, /// Metadata key-value pairs pub(crate) metadata: BTreeMap, } impl BurnpackWriter { /// Create a new writer pub fn new(snapshots: Vec) -> Self { Self { snapshots, metadata: BTreeMap::new(), } } /// Builder pattern: add metadata and return self pub fn with_metadata(mut self, key: &str, value: &str) -> Self { self.metadata.insert(key.to_string(), value.to_string()); self } /// Build tensor descriptors and metadata fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec), BurnpackError> { // Build tensor descriptors and calculate offsets with alignment let mut tensors = BTreeMap::new(); let mut current_offset = 0u64; for snapshot in &self.snapshots { let data_len = snapshot.data_len() as u64; // Align the start offset for mmap zero-copy support let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT); let end = aligned_start.checked_add(data_len).ok_or_else(|| { BurnpackError::IoError(format!( "Tensor offset overflow: {} + {} exceeds maximum", aligned_start, data_len )) })?; tensors.insert( snapshot.full_path(), TensorDescriptor { dtype: snapshot.dtype, shape: snapshot.shape.iter().map(|&s| s as u64).collect(), data_offsets: (aligned_start, end), param_id: snapshot.tensor_id.map(|id| id.val()), }, ); current_offset = end; } // Create metadata structure let metadata = BurnpackMetadata { tensors, metadata: self.metadata.clone(), }; // Serialize metadata with CBOR let mut metadata_bytes = Vec::new(); ciborium::ser::into_writer(&metadata, &mut metadata_bytes) .map_err(|e| BurnpackError::IoError(e.to_string()))?; Ok((metadata, metadata_bytes)) } /// Calculate the total size needed for the burnpack data /// /// This is useful when you want to pre-allocate a buffer for `write_into()`. /// The size includes padding bytes for both metadata alignment and tensor alignment. pub fn size(&self) -> Result { let (metadata, metadata_bytes) = self.build_metadata()?; // Data section starts at aligned position after header + metadata let data_section_start = aligned_data_section_start(metadata_bytes.len()); // Calculate total data section size from aligned offsets // The last tensor's end offset gives us the total data section size let data_size = metadata .tensors .values() .map(|t| t.data_offsets.1) .max() .unwrap_or(0) as usize; Ok(data_section_start + data_size) } /// Write burnpack data into a caller-provided buffer /// /// The buffer must be large enough to hold all data. Use `size()` to determine /// the required buffer size. If the buffer is too small, this will return an error. /// /// This allows the caller to control buffer allocation, enabling optimizations like: /// - Buffer reuse across multiple writes /// - Custom allocators /// - Pinned memory for GPU transfers /// /// # Arguments /// /// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes. pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> { let (metadata, metadata_bytes) = self.build_metadata()?; // Check metadata size fits in u32 let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| { BurnpackError::IoError(format!( "Metadata size {} exceeds maximum of {} bytes", metadata_bytes.len(), u32::MAX )) })?; // Create header let header = BurnpackHeader { magic: MAGIC_NUMBER, version: FORMAT_VERSION, metadata_size, }; // Data section starts at aligned position after header + metadata let data_section_start = aligned_data_section_start(metadata_bytes.len()); // Calculate required size from aligned offsets let data_size = metadata .tensors .values() .map(|t| t.data_offsets.1) .max() .unwrap_or(0) as usize; let total_size = data_section_start + data_size; // Check buffer size if buffer.len() < total_size { return Err(BurnpackError::IoError(format!( "Buffer too small: need {} bytes, got {} bytes", total_size, buffer.len() ))); } let mut offset = 0; // Write header let header_bytes = header.into_bytes(); buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes); offset += HEADER_SIZE; // Write metadata buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes); offset += metadata_bytes.len(); // Write padding to align data section start if data_section_start > offset { buffer[offset..data_section_start].fill(0); offset = data_section_start; } // Write tensor data with alignment padding for snapshot in &self.snapshots { // Get the aligned offset from metadata let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| { BurnpackError::IoError(format!( "Internal error: tensor '{}' not found in metadata", snapshot.full_path() )) })?; let aligned_offset = descriptor.data_offsets.0 as usize; let target_offset = data_section_start + aligned_offset; // Write padding zeros if needed if target_offset > offset { buffer[offset..target_offset].fill(0); offset = target_offset; } let expected_len = snapshot.data_len(); let data = snapshot.to_data().map_err(|e| { BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e)) })?; let actual_len = data.bytes.len(); // Validate data length consistency if actual_len != expected_len { return Err(BurnpackError::IoError(format!( "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})", snapshot.full_path(), expected_len, actual_len ))); } buffer[offset..offset + actual_len].copy_from_slice(&data.bytes); offset += actual_len; } Ok(()) } /// Write to a byte buffer (convenience method) /// /// This allocates a buffer internally and writes the burnpack data. /// For more control over buffer allocation, use `size()` + `write_into()`. pub fn to_bytes(&self) -> Result { let size = self.size()?; let mut buffer = vec![0u8; size]; self.write_into(&mut buffer)?; Ok(Bytes::from_bytes_vec(buffer)) } /// Write directly to a file (more memory efficient for large models) #[cfg(feature = "std")] pub fn write_to_file>(&self, path: P) -> Result<(), BurnpackError> { let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?; let (metadata, metadata_bytes) = self.build_metadata()?; // Check metadata size fits in u32 let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| { BurnpackError::IoError(format!( "Metadata size {} exceeds maximum of {} bytes", metadata_bytes.len(), u32::MAX )) })?; // Create and write header let header = BurnpackHeader { magic: MAGIC_NUMBER, version: FORMAT_VERSION, metadata_size, }; file.write_all(&header.into_bytes()) .map_err(|e| BurnpackError::IoError(e.to_string()))?; // Write metadata file.write_all(&metadata_bytes) .map_err(|e| BurnpackError::IoError(e.to_string()))?; // Data section starts at aligned position after header + metadata let data_section_start = aligned_data_section_start(metadata_bytes.len()); let current_file_pos = HEADER_SIZE + metadata_bytes.len(); // Write padding to align data section start if data_section_start > current_file_pos { let padding_size = data_section_start - current_file_pos; let padding = vec![0u8; padding_size]; file.write_all(&padding) .map_err(|e| BurnpackError::IoError(e.to_string()))?; } // Track current position within data section (relative to data_section_start) let mut data_offset = 0usize; // Stream tensor data directly to file with alignment padding for snapshot in &self.snapshots { // Get the aligned offset from metadata let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| { BurnpackError::IoError(format!( "Internal error: tensor '{}' not found in metadata", snapshot.full_path() )) })?; let aligned_offset = descriptor.data_offsets.0 as usize; // Write padding zeros if needed if aligned_offset > data_offset { let padding_size = aligned_offset - data_offset; let padding = vec![0u8; padding_size]; file.write_all(&padding) .map_err(|e| BurnpackError::IoError(e.to_string()))?; data_offset = aligned_offset; } let expected_len = snapshot.data_len(); let data = snapshot.to_data().map_err(|e| { BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e)) })?; let actual_len = data.bytes.len(); // Validate data length consistency if actual_len != expected_len { return Err(BurnpackError::IoError(format!( "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})", snapshot.full_path(), expected_len, actual_len ))); } file.write_all(&data.bytes) .map_err(|e| BurnpackError::IoError(e.to_string()))?; data_offset += actual_len; } file.flush() .map_err(|e| BurnpackError::IoError(e.to_string()))?; Ok(()) } } ================================================ FILE: crates/burn-store/src/collector.rs ================================================ use alloc::boxed::Box; use alloc::string::{String, ToString}; use alloc::vec::Vec; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; use crate::{ModuleAdapter, PathFilter, TensorSnapshot}; use burn_core::module::{ModuleVisitor, Param, ParamId}; /// Collects tensor views from modules without copying data. /// /// This collector traverses a module hierarchy and creates lightweight views /// of tensors that can be materialized to `TensorData` on demand. /// /// # Examples /// /// ## Collect all tensors /// ```rust,no_run /// # use burn_store::Collector; /// let collector = Collector::new(None, None, false); /// // Use with module.visit(&mut collector); /// let all_tensors = collector.tensors; /// ``` /// /// ## Filter with single pattern /// ```rust,no_run /// # use burn_store::{Collector, PathFilter}; /// let filter = PathFilter::new().with_regex(r"^encoder\..*"); /// let collector = Collector::new(Some(filter), None, false); /// // Use with module.visit(&mut collector); /// // Only collects tensors starting with "encoder." /// ``` /// /// ## Filter with multiple patterns (OR union) /// ```rust,no_run /// # use burn_store::{Collector, PathFilter}; /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.bias$"); // OR match any bias tensors /// let collector = Collector::new(Some(filter), None, false); /// // Use with module.visit(&mut collector); /// // Collects tensors matching ANY of the patterns /// ``` pub struct Collector { /// Collection of tensor views pub tensors: Vec, path_stack: Vec, container_stack: Vec, filter: Option, adapter: Option>, /// Skip enum variant names when building paths /// When true, enum variant names are not included in tensor paths skip_enum_variants: bool, } impl Default for Collector { fn default() -> Self { Self::new(None, None, false) } } impl Collector { /// Create a new tensor view collector with an optional filter and adapter. /// /// # Arguments /// /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect. /// When `None`, all tensors are collected. /// * `adapter` - Optional adapter to transform tensors based on container types. /// Applied to all collected tensors before returning. /// * `skip_enum_variants` - Skip enum variant names when building paths. /// When true, paths will not include enum variant names (e.g., "feature.weight" /// instead of "feature.BaseConv.weight"). Useful when exporting to formats /// like PyTorch that don't use enum variants. /// /// # Examples /// /// ```rust,no_run /// # use burn_store::{Collector, PathFilter}; /// // Collect all tensors without adapter /// let collector = Collector::new(None, None, false); /// /// // Use PathFilter builder /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") /// .with_full_path("decoder.weight"); /// let collector = Collector::new(Some(filter), None, false); /// /// // Skip enum variants for PyTorch export /// let collector = Collector::new(None, None, true); /// ``` pub fn new( filter: Option, adapter: Option>, skip_enum_variants: bool, ) -> Self { Self { tensors: Vec::new(), path_stack: Vec::new(), container_stack: Vec::new(), filter, adapter, skip_enum_variants, } } /// Apply the adapter to collected tensors and return the result. pub fn into_tensors(self) -> Vec { if let Some(adapter) = self.adapter { self.tensors .into_iter() .map(|snapshot| adapter.adapt(&snapshot)) .collect() } else { self.tensors } } fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool { // If filter is present, use it; otherwise collect all match &self.filter { None => true, Some(f) => f.matches_with_container_path(path, container_stack), } } } impl ModuleVisitor for Collector { fn enter_module(&mut self, name: &str, container_type: &str) { // Always track the container type for proper filtering and module type detection self.container_stack.push(container_type.to_string()); // Only add to path if it's not an enum variant (when skip_enum_variants is enabled) // This ensures paths are built without enum variant names from the start if !self.skip_enum_variants || !container_type.starts_with("Enum:") { self.path_stack.push(name.to_string()); } } fn exit_module(&mut self, _name: &str, container_type: &str) { self.container_stack.pop(); // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled) if !self.skip_enum_variants || !container_type.starts_with("Enum:") { self.path_stack.pop(); } } fn visit_float(&mut self, param: &Param>) { if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_float( ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), param.id, )); } } fn visit_int(&mut self, param: &Param>) { if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_int( ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), param.id, )); } } fn visit_bool(&mut self, param: &Param>) { if self.should_collect(&self.path_stack, &self.container_stack) { self.tensors.push(TensorSnapshot::from_bool( ¶m.transform_for_save().val(), self.path_stack.clone(), self.container_stack.clone(), param.id, )); } } fn visit_float_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { // For path-based visits, we use the current container stack for filtering if self.should_collect(path, &self.container_stack) { self.tensors.push(TensorSnapshot::from_float( tensor, path.to_vec(), self.container_stack.clone(), id, )); } } fn visit_int_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { if self.should_collect(path, &self.container_stack) { self.tensors.push(TensorSnapshot::from_int( tensor, path.to_vec(), self.container_stack.clone(), id, )); } } fn visit_bool_with_path( &mut self, path: &[String], id: ParamId, tensor: &Tensor, ) { if self.should_collect(path, &self.container_stack) { self.tensors.push(TensorSnapshot::from_bool( tensor, path.to_vec(), self.container_stack.clone(), id, )); } } } #[cfg(all(test, feature = "std"))] mod tests { use super::*; use burn_core as burn; type TestBackend = burn_ndarray::NdArray; use alloc::collections::BTreeMap; use alloc::string::String; use burn_core::module::{Module, Param}; use burn_nn::LinearConfig; use burn_tensor::shape; #[test] fn tensor_snapshot_collector() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let mut collector = Collector::new(None, None, false); let id = ParamId::new(); // Collect a tensor collector.visit_float_with_path(&["model".to_string(), "weight".to_string()], id, &tensor); assert_eq!(collector.tensors.len(), 1); assert_eq!(collector.tensors[0].full_path(), "model.weight"); // Verify the tensor can be converted to data let view = &collector.tensors[0]; let data = view.to_data().unwrap(); assert_eq!(data.shape, shape![2, 2]); } #[test] fn root_level_parameters() { use burn_core::module::ModuleVisitor; let device = Default::default(); // Create root-level parameters (single-element path, not nested in modules) let weight = Param::>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let bias = Param::>::from_data([5.0, 6.0], &device); let mut collector = Collector::new(None, None, false); // Simulate module traversal for root-level parameters // Enter "weight" path (as if we're visiting a field named "weight") ModuleVisitor::::enter_module(&mut collector, "weight", ""); ModuleVisitor::::visit_float(&mut collector, &weight); ModuleVisitor::::exit_module(&mut collector, "weight", ""); // Enter "bias" path (as if we're visiting a field named "bias") ModuleVisitor::::enter_module(&mut collector, "bias", ""); ModuleVisitor::::visit_float(&mut collector, &bias); ModuleVisitor::::exit_module(&mut collector, "bias", ""); // Verify both parameters were collected assert_eq!(collector.tensors.len(), 2); // Verify paths are correct (single-element paths) assert_eq!(collector.tensors[0].full_path(), "weight"); assert_eq!(collector.tensors[1].full_path(), "bias"); // Verify data is correct let weight_data = collector.tensors[0] .to_data() .unwrap() .to_vec::() .unwrap(); let bias_data = collector.tensors[1] .to_data() .unwrap() .to_vec::() .unwrap(); assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]); assert_eq!(bias_data, vec![5.0, 6.0]); } #[test] #[cfg(target_has_atomic = "ptr")] fn tensor_snapshot_collector_with_filter() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let filter = PathFilter::new().with_regex(r"^encoder\..*"); let mut collector = Collector::new(Some(filter), None, false); let id = ParamId::new(); // This should be collected collector.visit_float_with_path( &["encoder".to_string(), "weight".to_string()], id, &tensor, ); // This should NOT be collected collector.visit_float_with_path( &["decoder".to_string(), "weight".to_string()], id, &tensor, ); assert_eq!(collector.tensors.len(), 1); assert_eq!(collector.tensors[0].full_path(), "encoder.weight"); } #[test] #[cfg(target_has_atomic = "ptr")] fn tensor_snapshot_collector_with_multiple_filters() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); // Multiple patterns - collect if matches ANY (OR union) let filter = PathFilter::new() .with_regex(r"^encoder\..*") // Match encoder.* .with_regex(r".*\.bias$"); // Match *.bias let mut collector = Collector::new(Some(filter), None, false); let id = ParamId::new(); // These should be collected collector.visit_float_with_path( &["encoder".to_string(), "weight".to_string()], id, &tensor, ); // matches first pattern collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); // matches second pattern collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); // matches both patterns // This should NOT be collected collector.visit_float_with_path( &["decoder".to_string(), "weight".to_string()], id, &tensor, ); // matches neither assert_eq!(collector.tensors.len(), 3); let paths: Vec = collector.tensors.iter().map(|v| v.full_path()).collect(); assert!(paths.contains(&"encoder.weight".to_string())); assert!(paths.contains(&"decoder.bias".to_string())); assert!(paths.contains(&"encoder.bias".to_string())); assert!(!paths.contains(&"decoder.weight".to_string())); } #[test] fn tensor_snapshot_collector_with_predicate() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); // Use predicate function for filtering fn filter_fn(path: &str, _container_path: &str) -> bool { path.starts_with("encoder.") || path == "decoder.bias" } let filter = PathFilter::new().with_predicate(filter_fn); let mut collector = Collector::new(Some(filter), None, false); let id = ParamId::new(); // These should be collected collector.visit_float_with_path( &["encoder".to_string(), "weight".to_string()], id, &tensor, ); collector.visit_float_with_path(&["encoder".to_string(), "bias".to_string()], id, &tensor); collector.visit_float_with_path(&["decoder".to_string(), "bias".to_string()], id, &tensor); // This should NOT be collected collector.visit_float_with_path( &["decoder".to_string(), "weight".to_string()], id, &tensor, ); collector.visit_float_with_path(&["other".to_string(), "tensor".to_string()], id, &tensor); assert_eq!(collector.tensors.len(), 3); let paths: Vec = collector.tensors.iter().map(|v| v.full_path()).collect(); assert!(paths.contains(&"encoder.weight".to_string())); assert!(paths.contains(&"encoder.bias".to_string())); assert!(paths.contains(&"decoder.bias".to_string())); assert!(!paths.contains(&"decoder.weight".to_string())); assert!(!paths.contains(&"other.tensor".to_string())); } #[test] fn tensor_snapshot_collector_predicate_with_complex_logic() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); // Complex predicate with multiple conditions fn complex_filter(path: &str, _container_path: &str) -> bool { let parts: Vec<&str> = path.split('.').collect(); if parts.len() != 3 { return false; } // Only collect if it's layer1 or layer2, and it's a weight tensor (parts[1] == "layer1" || parts[1] == "layer2") && parts[2] == "weight" } let filter = PathFilter::new().with_predicate(complex_filter); let mut collector = Collector::new(Some(filter), None, false); let id = ParamId::new(); // These should be collected collector.visit_float_with_path( &[ "model".to_string(), "layer1".to_string(), "weight".to_string(), ], id, &tensor, ); collector.visit_float_with_path( &[ "model".to_string(), "layer2".to_string(), "weight".to_string(), ], id, &tensor, ); // These should NOT be collected collector.visit_float_with_path( &[ "model".to_string(), "layer1".to_string(), "bias".to_string(), ], id, &tensor, ); collector.visit_float_with_path( &[ "model".to_string(), "layer3".to_string(), "weight".to_string(), ], id, &tensor, ); collector.visit_float_with_path( &["encoder".to_string(), "weight".to_string()], id, &tensor, ); // wrong structure assert_eq!(collector.tensors.len(), 2); let paths: Vec = collector.tensors.iter().map(|v| v.full_path()).collect(); assert!(paths.contains(&"model.layer1.weight".to_string())); assert!(paths.contains(&"model.layer2.weight".to_string())); assert!(!paths.contains(&"model.layer1.bias".to_string())); assert!(!paths.contains(&"model.layer3.weight".to_string())); assert!(!paths.contains(&"encoder.weight".to_string())); } // Test visitor that collects tensor paths struct TensorPathCollector { pub paths: BTreeMap)>, path_stack: Vec, } impl TensorPathCollector { fn new() -> Self { Self { paths: BTreeMap::new(), path_stack: Vec::new(), } } fn current_path(&self) -> String { self.path_stack.join(".") } } impl ModuleVisitor for TensorPathCollector { fn enter_module(&mut self, name: &str, _container_type: &str) { self.path_stack.push(name.to_string()); } fn exit_module(&mut self, _name: &str, _container_type: &str) { self.path_stack.pop(); } fn visit_float(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { self.paths.insert( path, (param.id, param.transform_for_save().val().shape().to_vec()), ); } } fn visit_int(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { self.paths.insert( path, (param.id, param.transform_for_save().val().shape().to_vec()), ); } } fn visit_bool(&mut self, param: &Param>) { let path = self.current_path(); if !path.is_empty() { self.paths.insert( path, (param.id, param.transform_for_save().val().shape().to_vec()), ); } } } // Simple nested module for testing #[derive(Module, Debug)] struct InnerModule { weight: Param>, bias: Param>, } #[derive(Module, Debug)] struct OuterModule { layer1: InnerModule, layer2: InnerModule, } impl InnerModule { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), bias: Param::from_data([5.0, 6.0], device), } } } impl OuterModule { fn new(device: &B::Device) -> Self { Self { layer1: InnerModule::new(device), layer2: InnerModule::new(device), } } } #[test] fn nested_module_path_tracking() { let device = Default::default(); let module = OuterModule::::new(&device); let mut collector = TensorPathCollector::new(); module.visit(&mut collector); let paths = collector.paths; // Verify we have the expected paths // Note: Param fields are themselves modules, so we get an extra level assert!(paths.contains_key("layer1.weight"), "Missing layer1.weight"); assert!(paths.contains_key("layer1.bias"), "Missing layer1.bias"); assert!(paths.contains_key("layer2.weight"), "Missing layer2.weight"); assert!(paths.contains_key("layer2.bias"), "Missing layer2.bias"); // Verify the shapes are correct assert_eq!(paths.get("layer1.weight").unwrap().1, vec![2, 2]); assert_eq!(paths.get("layer1.bias").unwrap().1, vec![2]); assert_eq!(paths.get("layer2.weight").unwrap().1, vec![2, 2]); assert_eq!(paths.get("layer2.bias").unwrap().1, vec![2]); } #[test] fn linear_module_paths() { let device = Default::default(); let config = LinearConfig::new(10, 20).with_bias(true); let linear = config.init::(&device); let mut collector = TensorPathCollector::new(); linear.visit(&mut collector); let paths = collector.paths; // Linear module has weight and optional bias assert!(paths.contains_key("weight")); assert!(paths.contains_key("bias")); // Check dimensions assert_eq!(paths.get("weight").unwrap().1, vec![10, 20]); assert_eq!(paths.get("bias").unwrap().1, vec![20]); } // Deep nesting test structures (4+ levels) #[derive(Module, Debug)] struct Level4Module { weight: Param>, bias: Param>, } #[derive(Module, Debug)] struct Level3Module { layer: Level4Module, extra: Level4Module, } #[derive(Module, Debug)] struct Level2Module { block1: Level3Module, block2: Level3Module, } #[derive(Module, Debug)] struct Level1Module { encoder: Level2Module, decoder: Level2Module, } #[derive(Module, Debug)] struct DeepModel { backbone: Level1Module, head: Level4Module, } impl Level4Module { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), bias: Param::from_data([5.0, 6.0], device), } } } impl Level3Module { fn new(device: &B::Device) -> Self { Self { layer: Level4Module::new(device), extra: Level4Module::new(device), } } } impl Level2Module { fn new(device: &B::Device) -> Self { Self { block1: Level3Module::new(device), block2: Level3Module::new(device), } } } impl Level1Module { fn new(device: &B::Device) -> Self { Self { encoder: Level2Module::new(device), decoder: Level2Module::new(device), } } } impl DeepModel { fn new(device: &B::Device) -> Self { Self { backbone: Level1Module::new(device), head: Level4Module::new(device), } } } #[test] fn deep_module_path_tracking() { let device = Default::default(); let model = DeepModel::::new(&device); let mut collector = Collector::new(None, None, false); model.visit(&mut collector); let views = collector.tensors; let paths: Vec = views.iter().map(|v| v.full_path()).collect(); // Test 5-level deep paths assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string())); assert!(paths.contains(&"backbone.encoder.block1.layer.bias".to_string())); assert!(paths.contains(&"backbone.encoder.block1.extra.weight".to_string())); assert!(paths.contains(&"backbone.encoder.block1.extra.bias".to_string())); assert!(paths.contains(&"backbone.encoder.block2.layer.weight".to_string())); assert!(paths.contains(&"backbone.encoder.block2.layer.bias".to_string())); assert!(paths.contains(&"backbone.encoder.block2.extra.weight".to_string())); assert!(paths.contains(&"backbone.encoder.block2.extra.bias".to_string())); assert!(paths.contains(&"backbone.decoder.block1.layer.weight".to_string())); assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string())); assert!(paths.contains(&"backbone.decoder.block1.extra.weight".to_string())); assert!(paths.contains(&"backbone.decoder.block1.extra.bias".to_string())); assert!(paths.contains(&"backbone.decoder.block2.layer.weight".to_string())); assert!(paths.contains(&"backbone.decoder.block2.layer.bias".to_string())); assert!(paths.contains(&"backbone.decoder.block2.extra.weight".to_string())); assert!(paths.contains(&"backbone.decoder.block2.extra.bias".to_string())); // Test 2-level paths assert!(paths.contains(&"head.weight".to_string())); assert!(paths.contains(&"head.bias".to_string())); // Total should be 18 tensors (16 from backbone + 2 from head) assert_eq!(views.len(), 18); // Verify data can be materialized let view = views .iter() .find(|v| v.full_path() == "backbone.encoder.block1.layer.weight") .unwrap(); let data = view.to_data().unwrap(); assert_eq!(data.shape, shape![2, 2]); } #[test] fn deep_module_filtered_export() { let device = Default::default(); let model = DeepModel::::new(&device); // Test filtering at different depths #[cfg(target_has_atomic = "ptr")] { let filter = PathFilter::new().with_regex(r"^backbone\.encoder\..*"); let mut collector = Collector::new(Some(filter), None, false); model.visit(&mut collector); assert_eq!(collector.tensors.len(), 8); // Only encoder tensors } // Test filtering specific blocks #[cfg(target_has_atomic = "ptr")] { let filter = PathFilter::new().with_regex(r".*\.block1\..*"); let mut collector = Collector::new(Some(filter), None, false); model.visit(&mut collector); assert_eq!(collector.tensors.len(), 8); // block1 in both encoder and decoder } // Test filtering by tensor type at any depth #[cfg(target_has_atomic = "ptr")] { let filter = PathFilter::new().with_regex(r".*\.weight$"); let mut collector = Collector::new(Some(filter), None, false); model.visit(&mut collector); assert_eq!(collector.tensors.len(), 9); // All weight tensors } // Test complex multi-pattern filtering #[cfg(target_has_atomic = "ptr")] { let filter = PathFilter::new() .with_regex(r"^backbone\.encoder\.block1\..*") // All encoder.block1 tensors .with_regex(r"^backbone\.decoder\..*\.bias$") // All decoder biases .with_regex(r"^head\.weight$"); // Head weight only let mut collector = Collector::new(Some(filter), None, false); model.visit(&mut collector); // Should have: // - 4 from encoder.block1 (2 weights + 2 biases) // - 4 decoder biases // - 1 head weight assert_eq!(collector.tensors.len(), 9); let paths: Vec = collector.tensors.iter().map(|v| v.full_path()).collect(); assert!(paths.contains(&"backbone.encoder.block1.layer.weight".to_string())); assert!(paths.contains(&"backbone.decoder.block1.layer.bias".to_string())); assert!(paths.contains(&"head.weight".to_string())); assert!(!paths.contains(&"head.bias".to_string())); // Not included } } use crate::traits::ModuleSnapshot; use burn_nn::Linear; use hashbrown::HashMap; // Test module with Option fields #[derive(Module, Debug)] struct OptionalFieldModule { required: Param>, optional: Option>>, } impl OptionalFieldModule { fn new_with_optional(device: &B::Device) -> Self { Self { required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), optional: Some(Param::from_data([5.0, 6.0], device)), } } fn new_without_optional(device: &B::Device) -> Self { Self { required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), optional: None, } } } #[test] fn optional_field_module_with_value() { let device = Default::default(); let module = OptionalFieldModule::::new_with_optional(&device); let views: HashMap = module .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); assert_eq!(views.len(), 2); assert!(views.contains_key("required")); assert!(views.contains_key("optional")); } #[test] fn optional_field_module_without_value() { let device = Default::default(); let module = OptionalFieldModule::::new_without_optional(&device); let views: HashMap = module .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); assert_eq!(views.len(), 1); assert!(views.contains_key("required")); assert!(!views.contains_key("optional")); } // Test Vec of modules #[derive(Module, Debug)] struct VecModule { layers: Vec>, } impl VecModule { fn new(device: &B::Device, num_layers: usize) -> Self { Self { layers: (0..num_layers) .map(|_| LinearConfig::new(10, 10).init(device)) .collect(), } } } // Test tuple of modules #[derive(Module, Debug)] struct TupleModule { layers: (Linear, Linear, Linear), } impl TupleModule { fn new(device: &B::Device) -> Self { Self { layers: ( LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), ), } } } #[test] fn vec_module_collect() { let device = Default::default(); let module = VecModule::::new(&device, 3); let views: HashMap = module .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); // With the fix, all Vec items should now be properly indexed and visited assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors // Check that all indexed paths exist assert!(views.contains_key("layers.0.weight")); assert!(views.contains_key("layers.0.bias")); assert!(views.contains_key("layers.1.weight")); assert!(views.contains_key("layers.1.bias")); assert!(views.contains_key("layers.2.weight")); assert!(views.contains_key("layers.2.bias")); } #[test] fn tuple_module_collect() { let device = Default::default(); let module = TupleModule::::new(&device); let snapshots = module.collect(None, None, false); assert_eq!(snapshots.len(), 6); let views: HashMap = snapshots.into_iter().map(|v| (v.full_path(), v)).collect(); assert_eq!(views.len(), 6); assert!(views.contains_key("layers.0.weight")); assert!(views.contains_key("layers.0.bias")); assert!(views.contains_key("layers.1.weight")); assert!(views.contains_key("layers.1.bias")); assert!(views.contains_key("layers.2.weight")); assert!(views.contains_key("layers.2.bias")); } // Test array of modules #[derive(Module, Debug)] struct ArrayModule { layers: [Linear; 3], } impl ArrayModule { fn new(device: &B::Device) -> Self { Self { layers: [ LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), ], } } } #[test] fn array_module_collect() { let device = Default::default(); let module = ArrayModule::::new(&device); let views: HashMap = module .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); // All array items should be properly indexed assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors // Check indexed paths for i in 0..3 { assert!(views.contains_key(&format!("layers.{}.weight", i))); assert!(views.contains_key(&format!("layers.{}.bias", i))); } } // Test enum modules #[derive(Module, Debug)] enum EnumModule { LayerA(Linear), LayerB(Linear), LayerC(Linear), } #[test] fn enum_module_collect() { let device = Default::default(); // Test variant A let module_a = EnumModule::::LayerA(LinearConfig::new(10, 20).init(&device)); let views_a: HashMap = module_a .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); // Should have the variant name in the path assert_eq!(views_a.len(), 2); assert!(views_a.contains_key("LayerA.weight")); assert!(views_a.contains_key("LayerA.bias")); // Test variant B let module_b = EnumModule::::LayerB(LinearConfig::new(10, 20).init(&device)); let views_b: HashMap = module_b .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); assert_eq!(views_b.len(), 2); assert!(views_b.contains_key("LayerB.weight")); assert!(views_b.contains_key("LayerB.bias")); } // Container type tracking tests #[test] fn linear_container_type() { let device = Default::default(); #[derive(Module, Debug)] struct ModelWithLinear { linear: Linear, } impl ModelWithLinear { fn new(device: &B::Device) -> Self { Self { linear: LinearConfig::new(10, 20).init(device), } } } let model = ModelWithLinear::::new(&device); let views: HashMap = model .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); // Check that tensors inside Linear layers have "Struct:Linear" as their module type for (path, view) in views.iter() { if path == "linear.weight" || path == "linear.bias" { assert_eq!( view.module_type(), Some("Struct:Linear".to_string()), "Tensor '{}' should have module type 'Struct:Linear'", path ); } } } #[test] fn complex_model_container_types() { let device = Default::default(); #[derive(Module, Debug)] struct ComplexModel { linear_layers: [Linear; 2], vec_layers: Vec>, single_linear: Linear, } impl ComplexModel { fn new(device: &B::Device) -> Self { Self { linear_layers: [ LinearConfig::new(100, 50).init(device), LinearConfig::new(50, 10).init(device), ], vec_layers: vec![ LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), ], single_linear: LinearConfig::new(10, 1).init(device), } } } let model = ComplexModel::::new(&device); let views: HashMap = model .collect(None, None, false) .into_iter() .map(|v| (v.full_path(), v)) .collect(); // Should have 10 tensors total assert_eq!(views.len(), 10); // Verify different module types for (_path, view) in views.iter() { assert_eq!(view.module_type(), Some("Struct:Linear".to_string())); } } #[test] fn collect_with_container_filter() { let device = Default::default(); #[derive(Module, Debug)] struct FilterTestModel { layers: Vec>, } impl FilterTestModel { fn new(device: &B::Device) -> Self { Self { layers: vec![ LinearConfig::new(10, 10).init(device), LinearConfig::new(10, 10).init(device), ], } } } let model = FilterTestModel::::new(&device); // Filter to only collect tensors from Linear modules let filter = PathFilter::new().with_predicate(|_path, container_path| { container_path.split('.').next_back() == Some("Struct:Linear") }); let linear_views: Vec = model.collect(Some(filter), None, false); // All collected tensors should be from Linear modules for view in linear_views.iter() { assert_eq!( view.module_type(), Some("Struct:Linear".to_string()), "All tensors should be from Linear modules" ); } // Should have collected all Linear tensors assert_eq!(linear_views.len(), 4); } } ================================================ FILE: crates/burn-store/src/filter.rs ================================================ use alloc::format; use alloc::string::String; use alloc::vec::Vec; use core::fmt; #[cfg(feature = "std")] use regex::Regex; /// A sophisticated path filter that supports multiple matching strategies. /// /// The filter uses an OR logic - a path is included if it matches ANY of the configured criteria. /// This allows for flexible and powerful filtering configurations. /// /// # Examples /// /// ```rust,no_run /// # use burn_store::PathFilter; /// // Create a filter that matches encoder paths or any weight path /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") /// .with_regex(r".*\.weight$") /// .with_full_path("special_tensor"); /// /// // Check if a path should be included /// if filter.matches("encoder.layer1.weight") { /// // This will match due to both regex patterns /// } /// ``` #[derive(Debug, Clone, Default)] pub struct PathFilter { /// Compiled regex patterns for matching paths #[cfg(feature = "std")] regex_patterns: Vec, /// Exact full paths to match exact_paths: Vec, /// Predicate functions for custom matching logic based on path and container path /// Note: These cannot be cloned, so we store them separately predicates: Vec bool>, /// If true, matches all paths (overrides other filters) match_all: bool, } impl PathFilter { /// Create a new empty filter (matches nothing by default) pub fn new() -> Self { Self::default() } /// Create a filter that matches all paths pub fn all() -> Self { Self { match_all: true, ..Default::default() } } /// Create a filter that matches nothing pub fn none() -> Self { Self::default() } /// Add a regex pattern for matching paths #[cfg(feature = "std")] pub fn with_regex>(mut self, pattern: S) -> Self { if let Ok(regex) = Regex::new(pattern.as_ref()) { self.regex_patterns.push(regex); } // TODO: Consider returning Result to handle regex compilation errors self } /// Add multiple regex patterns #[cfg(feature = "std")] pub fn with_regexes(mut self, patterns: I) -> Self where I: IntoIterator, S: AsRef, { for pattern in patterns { if let Ok(regex) = Regex::new(pattern.as_ref()) { self.regex_patterns.push(regex); } } self } /// Add an exact full path to match pub fn with_full_path>(mut self, path: S) -> Self { self.exact_paths.push(path.into()); self } /// Add multiple exact full paths pub fn with_full_paths(mut self, paths: I) -> Self where I: IntoIterator, S: Into, { self.exact_paths.extend(paths.into_iter().map(|p| p.into())); self } /// Add a predicate function for custom matching based on path and container path pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self { self.predicates.push(predicate); self } /// Add multiple predicates pub fn with_predicates(mut self, predicates: I) -> Self where I: IntoIterator bool>, { self.predicates.extend(predicates); self } /// Set to match all paths pub fn match_all(mut self) -> Self { self.match_all = true; self } /// Check if a path matches this filter (assumes empty container path for backward compatibility) pub fn matches(&self, path: &str) -> bool { self.matches_with_container_path_str(path, "") } /// Check if a path and container type match this filter (for backward compatibility) pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool { // For backward compatibility, treat single container type as the full path self.matches_with_container_path_str(path, container_type) } /// Check if a path and container path match this filter pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool { let path_str = path.join("."); let container_path = container_stack.join("."); self.matches_with_container_path_str(&path_str, &container_path) } /// Check if a path and container path (dot-notated strings) match this filter pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool { // If match_all is set, always return true if self.match_all { return true; } // If no filters are configured, match nothing if self.is_empty() { return false; } // Check exact path matches if self.exact_paths.iter().any(|p| p == path) { return true; } // Check regex patterns (on the path) #[cfg(feature = "std")] { for regex in &self.regex_patterns { if regex.is_match(path) { return true; } } } // Check predicates with container path if self .predicates .iter() .any(|pred| pred(path, container_path)) { return true; } false } /// Check if the filter is empty (matches nothing) pub fn is_empty(&self) -> bool { if self.match_all { return false; } #[cfg(feature = "std")] let regex_empty = self.regex_patterns.is_empty(); #[cfg(not(feature = "std"))] let regex_empty = true; self.exact_paths.is_empty() && self.predicates.is_empty() && regex_empty } /// Get the number of filter criteria configured pub fn criteria_count(&self) -> usize { if self.match_all { return 1; } #[allow(unused_mut)] let mut count = self.exact_paths.len() + self.predicates.len(); #[cfg(feature = "std")] { count += self.regex_patterns.len(); } count } /// Clear all regex patterns #[cfg(feature = "std")] pub fn clear_regex(&mut self) -> &mut Self { self.regex_patterns.clear(); self } /// Clear all exact paths pub fn clear_paths(&mut self) -> &mut Self { self.exact_paths.clear(); self } /// Clear all predicates pub fn clear_predicates(&mut self) -> &mut Self { self.predicates.clear(); self } /// Clear all filters pub fn clear(&mut self) -> &mut Self { #[cfg(feature = "std")] self.clear_regex(); self.clear_paths().clear_predicates(); self.match_all = false; self } /// Create a filter from regex patterns only #[cfg(feature = "std")] pub fn from_regex_patterns(patterns: I) -> Self where I: IntoIterator, S: AsRef, { Self::new().with_regexes(patterns) } /// Create a filter from exact paths only pub fn from_paths(paths: I) -> Self where I: IntoIterator, S: Into, { Self::new().with_full_paths(paths) } /// Create a filter from a single predicate pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self { Self::new().with_predicate(predicate) } /// Combine with another filter using OR logic pub fn or(mut self, other: Self) -> Self { if self.match_all || other.match_all { return Self::all(); } #[cfg(feature = "std")] { self.regex_patterns.extend(other.regex_patterns); } self.exact_paths.extend(other.exact_paths); self.predicates.extend(other.predicates); self } } impl fmt::Display for PathFilter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.match_all { return write!(f, "PathFilter::all()"); } if self.is_empty() { return write!(f, "PathFilter::none()"); } write!(f, "PathFilter[")?; let mut parts = Vec::new(); #[cfg(feature = "std")] if !self.regex_patterns.is_empty() { parts.push(format!("regex: {:?}", self.regex_patterns)); } if !self.exact_paths.is_empty() { parts.push(format!("paths: {:?}", self.exact_paths)); } if !self.predicates.is_empty() { parts.push(format!("predicates: {}", self.predicates.len())); } write!(f, "{}]", parts.join(", ")) } } #[cfg(test)] mod tests { use super::*; #[test] fn empty_filter() { let filter = PathFilter::new(); assert!(filter.is_empty()); assert!(!filter.matches("encoder.weight")); assert!(!filter.matches("decoder.bias")); } #[test] fn match_all() { let filter = PathFilter::all(); assert!(!filter.is_empty()); assert!(filter.matches("encoder.weight")); assert!(filter.matches("decoder.bias")); assert!(filter.matches("anything")); } #[test] fn exact_paths() { let filter = PathFilter::new() .with_full_path("encoder.weight") .with_full_path("decoder.bias"); assert!(filter.matches("encoder.weight")); assert!(filter.matches("decoder.bias")); assert!(!filter.matches("encoder.bias")); assert!(!filter.matches("decoder.weight")); } #[test] #[cfg(feature = "std")] fn regex_patterns() { let filter = PathFilter::new() .with_regex(r"^encoder\..*") .with_regex(r".*\.weight$"); assert!(filter.matches("encoder.layer1.bias")); assert!(filter.matches("decoder.weight")); assert!(filter.matches("encoder.weight")); assert!(!filter.matches("decoder.bias")); } #[test] fn predicates() { fn contains_norm(path: &str, _container_path: &str) -> bool { path.contains("norm") } fn is_short(path: &str, _container_path: &str) -> bool { path.len() < 10 } let filter = PathFilter::new() .with_predicate(contains_norm) .with_predicate(is_short); assert!(filter.matches("norm.weight")); assert!(filter.matches("layer.norm.bias")); assert!(filter.matches("bias")); assert!(!filter.matches("encoder.decoder.weight.long.name")); } #[test] fn combined_filters() { let filter = PathFilter::new() .with_full_path("special.tensor") .with_predicate(|path, _container_path| path.contains("attention")); #[cfg(feature = "std")] let filter = filter.with_regex(r"^encoder\..*"); assert!(filter.matches("special.tensor")); assert!(filter.matches("self_attention.query")); #[cfg(feature = "std")] assert!(filter.matches("encoder.anything")); assert!(!filter.matches("decoder.weight")); } #[test] fn or_combination() { let encoder_filter = PathFilter::new().with_full_path("encoder.weight"); let decoder_filter = PathFilter::new().with_full_path("decoder.bias"); let combined = encoder_filter.or(decoder_filter); assert!(combined.matches("encoder.weight")); assert!(combined.matches("decoder.bias")); assert!(!combined.matches("model.head.weight")); } #[test] #[cfg(feature = "std")] fn common_patterns() { // Test encoder pattern let encoder = PathFilter::new().with_regex(r"^encoder\..*"); assert!(encoder.matches("encoder.weight")); assert!(!encoder.matches("decoder.weight")); // Test weights-only pattern let weights = PathFilter::new().with_regex(r".*\.weight$"); assert!(weights.matches("encoder.weight")); assert!(weights.matches("decoder.weight")); assert!(!weights.matches("encoder.bias")); // Test layer-specific patterns let layers = PathFilter::new() .with_regex(r"(^|.*\.)layers\.0\.") .with_regex(r"(^|.*\.)layers\.2\.") .with_regex(r"(^|.*\.)layers\.4\."); assert!(layers.matches("model.layers.0.weight")); assert!(layers.matches("layers.2.bias")); assert!(!layers.matches("layers.1.weight")); } #[test] fn criteria_count() { let filter = PathFilter::new() .with_full_path("path1") .with_full_path("path2") .with_predicate(|_, _| true); #[cfg(feature = "std")] let filter = filter.with_regex(".*"); #[cfg(feature = "std")] assert_eq!(filter.criteria_count(), 4); #[cfg(not(feature = "std"))] assert_eq!(filter.criteria_count(), 3); } #[test] fn clear_operations() { let mut filter = PathFilter::new().with_full_path("test"); filter.clear_paths(); assert!(!filter.matches("test")); filter.clear(); assert!(filter.is_empty()); } #[test] fn container_predicates() { // Filter that matches only Linear module weights let linear_weights = PathFilter::new().with_predicate(|path, container_path| { container_path.split('.').next_back() == Some("Linear") && path.ends_with(".weight") }); assert!(linear_weights.matches_with_container("layer1.weight", "Linear")); assert!(!linear_weights.matches_with_container("layer1.weight", "Conv2d")); assert!(!linear_weights.matches_with_container("layer1.bias", "Linear")); // Filter for specific container types let conv_only = PathFilter::new().with_predicate(|_path, container_path| { let last = container_path.split('.').next_back(); last == Some("Conv2d") || last == Some("ConvTranspose2d") }); assert!(conv_only.matches_with_container("encoder.weight", "Conv2d")); assert!(conv_only.matches_with_container("decoder.weight", "ConvTranspose2d")); assert!(!conv_only.matches_with_container("fc.weight", "Linear")); // Combine path and container predicates let combined = PathFilter::new() .with_predicate(|path, _container_path| path.starts_with("encoder.")) .with_predicate(|_path, container_path| { container_path.split('.').next_back() == Some("BatchNorm2d") }); // Should match either condition (OR logic) assert!(combined.matches_with_container("encoder.layer1", "Linear")); assert!(combined.matches_with_container("decoder.bn", "BatchNorm2d")); assert!(!combined.matches_with_container("decoder.layer", "Linear")); } #[test] fn container_predicate_with_regex() { // Combine regex patterns with container predicates #[cfg(feature = "std")] { let filter = PathFilter::new() .with_regex(r"^encoder\..*") .with_predicate(|path, container_path| { container_path.split('.').next_back() == Some("Linear") && path.contains(".bias") }); // Matches due to regex assert!(filter.matches_with_container("encoder.layer1.weight", "Conv2d")); // Matches due to container predicate assert!(filter.matches_with_container("decoder.fc.bias", "Linear")); // Doesn't match either assert!(!filter.matches_with_container("decoder.conv.weight", "Conv2d")); } } #[test] fn container_stack_predicates() { // Filter using full container path - only tensors nested in a specific hierarchy let nested_filter = PathFilter::new().with_predicate(|_path, container_path| { // Check if tensor is nested within: Model -> TransformerBlock -> Linear let parts: Vec<&str> = container_path.split('.').collect(); parts.len() >= 3 && parts[0] == "Model" && parts[1] == "TransformerBlock" && parts[2] == "Linear" }); assert!(nested_filter.matches_with_container_path_str( "encoder.weight", "Model.TransformerBlock.Linear.Param" )); assert!( !nested_filter .matches_with_container_path_str("decoder.weight", "Model.Decoder.Linear.Param") ); assert!(!nested_filter.matches_with_container_path_str( "encoder.weight", "Model.TransformerBlock.Conv2d.Param" )); // Filter that checks for specific depth in hierarchy let depth_filter = PathFilter::new().with_predicate(|_path, container_path| { let parts: Vec<&str> = container_path.split('.').collect(); parts.len() == 4 && parts.get(2) == Some(&"Linear") }); assert!(depth_filter.matches_with_container_path_str( "model.layer.weight", "Model.TransformerBlock.Linear.Param" )); assert!( !depth_filter .matches_with_container_path_str("model.weight", "Model.TransformerBlock.Conv2d") ); // Too shallow // Filter that checks any Linear in the path (not just the last) let any_linear = PathFilter::new() .with_predicate(|_path, container_path| container_path.contains("Linear")); assert!( any_linear.matches_with_container_path_str( "some.path", "Model.TransformerBlock.Linear.Param" ) ); assert!( any_linear.matches_with_container_path_str("other.path", "Model.Decoder.Linear.Param") ); assert!( !any_linear.matches_with_container_path_str( "conv.path", "Model.TransformerBlock.Conv2d.Param" ) ); } #[test] fn container_path_dot_notation() { // Filter using dot-notated container path let dot_filter = PathFilter::new().with_predicate(|_path, container_path| { container_path.starts_with("Model.TransformerBlock") }); // Test with matches_with_container_path assert!( dot_filter.matches_with_container_path_str("weight", "Model.TransformerBlock.Linear") ); assert!(!dot_filter.matches_with_container_path_str("weight", "Model.Decoder.Linear")); // Filter that checks for specific patterns in container path let pattern_filter = PathFilter::new().with_predicate(|_path, container_path| { // Match any path that has Linear after a block container_path.contains("Block.Linear") || container_path.contains("Block.Conv") }); assert!( pattern_filter .matches_with_container_path_str("weight", "Model.TransformerBlock.Linear") ); assert!(pattern_filter.matches_with_container_path_str("weight", "Model.ResBlock.Conv2d")); assert!(!pattern_filter.matches_with_container_path_str("weight", "Model.Linear.Param")); // Filter combining path and container path patterns let combined = PathFilter::new().with_predicate(|path, container_path| { // Only weights in Linear layers that are inside blocks path.ends_with(".weight") && container_path.contains("Block") && container_path.split('.').next_back() == Some("Linear") }); assert!( combined .matches_with_container_path_str("layer.weight", "Model.TransformerBlock.Linear") ); assert!( !combined .matches_with_container_path_str("layer.bias", "Model.TransformerBlock.Linear") ); assert!(!combined.matches_with_container_path_str("layer.weight", "Model.Decoder.Linear")); } } ================================================ FILE: crates/burn-store/src/keyremapper.rs ================================================ use alloc::collections::BTreeMap; use alloc::string::{String, ToString}; use alloc::vec::Vec; use regex::{self, Regex}; use crate::TensorSnapshot; /// Key remapper for transforming tensor names. /// /// This allows mapping tensor names from one naming convention to another, /// which is useful for loading models from different frameworks or versions. /// /// # Examples /// /// ```rust /// # use burn_store::KeyRemapper; /// // Create a key remapper /// let remapper = KeyRemapper::new() /// .add_pattern(r"^pytorch\.(.*)", "burn.$1").expect("valid regex") // pytorch.layer -> burn.layer /// .add_pattern(r"\.gamma$", ".weight").expect("valid regex"); // layer.gamma -> layer.weight /// /// // Use remapper with stores /// // store.remap(remapper) /// ``` #[derive(Debug, Clone, Default)] pub struct KeyRemapper { /// Pattern-based remapping rules (regex pattern, replacement string) pub patterns: Vec<(Regex, String)>, } impl KeyRemapper { /// Create a new empty key remapper pub fn new() -> Self { Self::default() } /// Add a remapping pattern (compiles regex) /// /// # Arguments /// /// * `from` - Source pattern (regex string) /// * `to` - Replacement string (can include capture groups like `$1`) /// /// # Returns /// /// * `Ok(Self)` - Updated remapping configuration /// * `Err(regex::Error)` - If regex compilation fails pub fn add_pattern(mut self, from: S1, to: S2) -> Result where S1: AsRef, S2: Into, { let regex = Regex::new(from.as_ref())?; self.patterns.push((regex, to.into())); Ok(self) } /// Create from a list of compiled regex patterns pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self { Self { patterns } } /// Create from string patterns (will compile to regex) /// /// # Arguments /// /// * `patterns` - Vector of (pattern, replacement) tuples /// /// # Returns /// /// * `Ok(Self)` - New remapping configuration /// * `Err(regex::Error)` - If any regex compilation fails pub fn from_patterns(patterns: Vec<(S1, S2)>) -> Result where S1: AsRef, S2: Into, { let mut compiled_patterns = Vec::new(); for (pattern, replacement) in patterns { let regex = Regex::new(pattern.as_ref())?; compiled_patterns.push((regex, replacement.into())); } Ok(Self { patterns: compiled_patterns, }) } /// Create from an iterator of patterns /// /// # Arguments /// /// * `iter` - Iterator yielding (pattern, replacement) tuples /// /// # Returns /// /// * `Ok(Self)` - New remapping configuration /// * `Err(regex::Error)` - If any regex compilation fails pub fn from_pattern_iter(iter: I) -> Result where I: IntoIterator, S1: AsRef, S2: Into, { let patterns: Result, _> = iter .into_iter() .map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into()))) .collect(); Ok(Self { patterns: patterns?, }) } /// Check if the remapping is empty pub fn is_empty(&self) -> bool { self.patterns.is_empty() } /// Convert to the format expected by remap_tensor_paths_with_patterns pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> { self.patterns.clone() } /// Remap tensor paths using the configured patterns. /// /// # Arguments /// /// * `tensors` - Vec of TensorSnapshots to remap /// /// # Returns /// /// A tuple containing: /// * The remapped Vec of TensorSnapshots with updated paths /// * A vector of (new_path, original_path) showing the transformations pub fn remap( &self, mut tensors: Vec, ) -> (Vec, Vec<(String, String)>) { if self.patterns.is_empty() { let remapped_names = tensors .iter() .map(|v| { let path = v.full_path(); (path.clone(), path) }) .collect(); return (tensors, remapped_names); } let mut remapped_snapshots = Vec::new(); let mut remapped_names = Vec::new(); for mut snapshot in tensors.drain(..) { let original_path = snapshot.full_path(); let mut new_path = original_path.clone(); // Apply all patterns to get the new path for (pattern, replacement) in &self.patterns { if pattern.is_match(&new_path) { new_path = pattern .replace_all(&new_path, replacement.as_str()) .to_string(); } } // Update the snapshot's internal path_stack if the path changed if new_path != original_path && let Some(ref mut path_stack) = snapshot.path_stack { *path_stack = new_path.split('.').map(|s| s.to_string()).collect(); } remapped_names.push((new_path.clone(), original_path)); remapped_snapshots.push(snapshot); } (remapped_snapshots, remapped_names) } } /// Map tensor paths to have contiguous numeric indices. /// /// This function detects numeric indices in tensor paths and renumbers them /// to be contiguous (0, 1, 2, ...) while preserving their relative order. /// It handles nested sequential structures by processing ALL numeric indices /// in each path independently based on their position context. /// /// This is useful when loading PyTorch models that have gaps in layer numbering, /// such as when using `nn.Sequential` with mixed layer types (e.g., Conv2d + ReLU /// where only Conv2d has parameters). /// /// # Example /// /// Simple case - input paths: /// - `fc.0.weight`, `fc.0.bias` /// - `fc.2.weight`, `fc.2.bias` /// - `fc.4.weight`, `fc.4.bias` /// /// Output paths: /// - `fc.0.weight`, `fc.0.bias` /// - `fc.1.weight`, `fc.1.bias` /// - `fc.2.weight`, `fc.2.bias` /// /// Nested case - input paths: /// - `feature.layers.0.conv_block.0.weight` /// - `feature.layers.0.conv_block.2.weight` /// - `feature.layers.2.conv_block.0.weight` /// - `feature.layers.2.conv_block.2.weight` /// /// Output paths: /// - `feature.layers.0.conv_block.0.weight` /// - `feature.layers.0.conv_block.1.weight` /// - `feature.layers.1.conv_block.0.weight` /// - `feature.layers.1.conv_block.1.weight` /// /// # Arguments /// /// * `tensors` - Vec of TensorSnapshots to map /// /// # Returns /// /// A tuple containing: /// * The mapped Vec of TensorSnapshots with updated paths /// * A vector of (new_path, original_path) showing the transformations pub fn map_indices_contiguous( mut tensors: Vec, ) -> (Vec, Vec<(String, String)>) { if tensors.is_empty() { return (tensors, Vec::new()); } // Step 1: Collect all paths and find all index positions // For each index position (identified by prefix using ORIGINAL indices), // collect all indices seen at that position. // // Key: prefix using original path (e.g., "feature.layers." or "feature.layers.0.conv_block.") // Value: BTreeMap of original_index -> new_index let mut index_maps: BTreeMap> = BTreeMap::new(); // First pass: collect all indices at each position using original prefixes for snapshot in &tensors { let path = snapshot.full_path(); let parts: Vec<&str> = path.split('.').collect(); // Check each part for numeric indices for (i, part) in parts.iter().enumerate() { if let Ok(index) = part.parse::() { // The prefix is everything before this index (using original path) let prefix = if i > 0 { format!("{}.", parts[..i].join(".")) } else { String::new() }; index_maps .entry(prefix) .or_default() .entry(index) .or_insert(usize::MAX); // Placeholder } } } // Second pass: assign contiguous indices for each position for indices in index_maps.values_mut() { let mut sorted_indices: Vec = indices.keys().cloned().collect(); sorted_indices.sort(); for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() { indices.insert(old_idx, new_idx); } } // Third pass: apply the remapping to all tensors // We use original prefixes for lookup since that's how we collected indices let mut mapped_snapshots = Vec::new(); let mut transformations = Vec::new(); for mut snapshot in tensors.drain(..) { let original_path = snapshot.full_path(); let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps); // Update the snapshot's internal path_stack if the path changed if new_path != original_path && let Some(ref mut path_stack) = snapshot.path_stack { *path_stack = new_path.split('.').map(|s| s.to_string()).collect(); } transformations.push((new_path, original_path)); mapped_snapshots.push(snapshot); } (mapped_snapshots, transformations) } /// Remap all numeric indices in a path using the provided index maps. /// Uses original path prefixes for lookup. fn remap_all_indices_with_original_prefix( path: &str, index_maps: &BTreeMap>, ) -> String { let parts: Vec<&str> = path.split('.').collect(); let mut result_parts: Vec = Vec::with_capacity(parts.len()); for (i, part) in parts.iter().enumerate() { if let Ok(index) = part.parse::() { // Build the prefix from ORIGINAL parts (not remapped) let prefix = if i > 0 { format!("{}.", parts[..i].join(".")) } else { String::new() }; // Look up the new index using original prefix if let Some(index_map) = index_maps.get(&prefix) && let Some(&new_index) = index_map.get(&index) { result_parts.push(new_index.to_string()); continue; } } // Not a numeric index or no mapping found, keep as-is result_parts.push((*part).to_string()); } result_parts.join(".") } #[cfg(all(test, feature = "std"))] mod tests { use super::*; use burn_core::module::ParamId; use burn_tensor::{TensorData, shape}; fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot { let data = TensorData { bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]), shape: shape![2, 2], dtype: burn_tensor::DType::F32, }; let path_parts: Vec = name.split('.').map(|s| s.to_string()).collect(); TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new()) } #[test] fn test_key_remapper_basic() { let remapper = KeyRemapper::new() .add_pattern(r"^encoder\.", "transformer.encoder.") .expect("valid regex"); let tensors = vec![ create_test_tensor_snapshot("encoder.layer1.weight"), create_test_tensor_snapshot("decoder.layer1.weight"), ]; let (remapped, transformations) = remapper.remap(tensors); // Check that remapped views exist with correct paths assert!( remapped .iter() .any(|v| v.full_path() == "transformer.encoder.layer1.weight") ); assert!( remapped .iter() .any(|v| v.full_path() == "decoder.layer1.weight") ); assert_eq!(remapped.len(), 2); // Check transformations let encoder_transform = transformations .iter() .find(|(_new, old)| old == "encoder.layer1.weight") .expect("should find encoder transformation"); assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight"); } #[test] fn test_key_remapper_multiple_patterns() { let remapper = KeyRemapper::new() .add_pattern(r"^encoder\.", "transformer.encoder.") .expect("valid regex") .add_pattern(r"\.gamma$", ".weight") .expect("valid regex"); let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")]; let (remapped, _) = remapper.remap(tensors); assert!( remapped .iter() .any(|v| v.full_path() == "transformer.encoder.layer1.weight") ); assert_eq!(remapped.len(), 1); } #[test] fn test_key_remapper_from_patterns() { let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")]; let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns"); let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")]; let (remapped, _) = remapper.remap(tensors); assert!( remapped .iter() .any(|v| v.full_path() == "burn.linear.bias_param") ); } #[test] fn test_key_remapper_empty() { let remapper = KeyRemapper::new(); assert!(remapper.is_empty()); let tensors = vec![create_test_tensor_snapshot("test.weight")]; let (remapped, transformations) = remapper.remap(tensors); assert!(remapped.iter().any(|v| v.full_path() == "test.weight")); assert_eq!(remapped.len(), 1); assert_eq!(transformations.len(), 1); assert_eq!( transformations[0], ("test.weight".to_string(), "test.weight".to_string()) ); } #[test] fn test_map_indices_contiguous_basic() { // Simulate PyTorch nn.Sequential with Conv2d (0, 2, 4) and ReLU (1, 3, 5) // Only Conv2d layers have parameters let tensors = vec![ create_test_tensor_snapshot("fc.0.weight"), create_test_tensor_snapshot("fc.0.bias"), create_test_tensor_snapshot("fc.2.weight"), create_test_tensor_snapshot("fc.2.bias"), create_test_tensor_snapshot("fc.4.weight"), create_test_tensor_snapshot("fc.4.bias"), ]; let (reindexed, transformations) = map_indices_contiguous(tensors); // Check that indices are now contiguous assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.bias")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.bias")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.bias")); assert_eq!(reindexed.len(), 6); // Check transformations let transform_2_to_1 = transformations .iter() .find(|(_, old)| old == "fc.2.weight") .expect("should find fc.2.weight transformation"); assert_eq!(transform_2_to_1.0, "fc.1.weight"); let transform_4_to_2 = transformations .iter() .find(|(_, old)| old == "fc.4.weight") .expect("should find fc.4.weight transformation"); assert_eq!(transform_4_to_2.0, "fc.2.weight"); } #[test] fn test_map_indices_contiguous_already_contiguous() { // Already contiguous indices should remain unchanged let tensors = vec![ create_test_tensor_snapshot("fc.0.weight"), create_test_tensor_snapshot("fc.1.weight"), create_test_tensor_snapshot("fc.2.weight"), ]; let (reindexed, transformations) = map_indices_contiguous(tensors); assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight")); assert_eq!(reindexed.len(), 3); // All transformations should have same old and new paths for (new, old) in &transformations { assert_eq!(new, old); } } #[test] fn test_map_indices_contiguous_multiple_prefixes() { // Different prefixes should be mapped independently let tensors = vec![ create_test_tensor_snapshot("encoder.0.weight"), create_test_tensor_snapshot("encoder.2.weight"), create_test_tensor_snapshot("decoder.1.weight"), create_test_tensor_snapshot("decoder.5.weight"), ]; let (reindexed, _) = map_indices_contiguous(tensors); // encoder: 0, 2 -> 0, 1 assert!( reindexed .iter() .any(|v| v.full_path() == "encoder.0.weight") ); assert!( reindexed .iter() .any(|v| v.full_path() == "encoder.1.weight") ); // decoder: 1, 5 -> 0, 1 assert!( reindexed .iter() .any(|v| v.full_path() == "decoder.0.weight") ); assert!( reindexed .iter() .any(|v| v.full_path() == "decoder.1.weight") ); } #[test] fn test_map_indices_contiguous_no_indices() { // Paths without indices should remain unchanged let tensors = vec![ create_test_tensor_snapshot("encoder.weight"), create_test_tensor_snapshot("decoder.bias"), ]; let (reindexed, transformations) = map_indices_contiguous(tensors); assert!(reindexed.iter().any(|v| v.full_path() == "encoder.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "decoder.bias")); for (new, old) in &transformations { assert_eq!(new, old); } } #[test] fn test_map_indices_contiguous_empty() { let tensors: Vec = vec![]; let (reindexed, transformations) = map_indices_contiguous(tensors); assert!(reindexed.is_empty()); assert!(transformations.is_empty()); } #[test] fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() { // Mix of indexed and non-indexed paths let tensors = vec![ create_test_tensor_snapshot("fc.0.weight"), create_test_tensor_snapshot("fc.2.weight"), create_test_tensor_snapshot("output.weight"), // no index ]; let (reindexed, _) = map_indices_contiguous(tensors); assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight")); // 2 -> 1 assert!(reindexed.iter().any(|v| v.full_path() == "output.weight")); // unchanged } #[test] fn test_map_indices_contiguous_nested_sequential() { // Test nested sequential structures like: // feature = nn.Sequential(ConvBlock, ReLU, ConvBlock, ReLU, ConvBlock) // where ConvBlock = nn.Sequential(Conv2d, ReLU, Conv2d) // // This produces paths like: // feature.layers.0.conv_block.0.weight (layer 0, conv 0) // feature.layers.0.conv_block.2.weight (layer 0, conv 2 - skipping ReLU at 1) // feature.layers.2.conv_block.0.weight (layer 2 - skipping ReLU at 1, conv 0) // feature.layers.2.conv_block.2.weight (layer 2, conv 2) let tensors = vec![ create_test_tensor_snapshot("feature.layers.0.conv_block.0.weight"), create_test_tensor_snapshot("feature.layers.0.conv_block.2.weight"), create_test_tensor_snapshot("feature.layers.2.conv_block.0.weight"), create_test_tensor_snapshot("feature.layers.2.conv_block.2.weight"), ]; let (mapped, transformations) = map_indices_contiguous(tensors); // Expected mapping: // feature.layers: 0, 2 -> 0, 1 // feature.layers.0.conv_block: 0, 2 -> 0, 1 // feature.layers.2.conv_block: 0, 2 -> 0, 1 // // Result: // feature.layers.0.conv_block.0.weight -> feature.layers.0.conv_block.0.weight // feature.layers.0.conv_block.2.weight -> feature.layers.0.conv_block.1.weight // feature.layers.2.conv_block.0.weight -> feature.layers.1.conv_block.0.weight // feature.layers.2.conv_block.2.weight -> feature.layers.1.conv_block.1.weight assert!( mapped .iter() .any(|v| v.full_path() == "feature.layers.0.conv_block.0.weight"), "0.0 should stay as 0.0" ); assert!( mapped .iter() .any(|v| v.full_path() == "feature.layers.0.conv_block.1.weight"), "0.2 should become 0.1" ); assert!( mapped .iter() .any(|v| v.full_path() == "feature.layers.1.conv_block.0.weight"), "2.0 should become 1.0" ); assert!( mapped .iter() .any(|v| v.full_path() == "feature.layers.1.conv_block.1.weight"), "2.2 should become 1.1" ); // Verify specific transformations let t1 = transformations .iter() .find(|(_, old)| old == "feature.layers.2.conv_block.2.weight"); assert_eq!( t1.map(|(new, _)| new.as_str()), Some("feature.layers.1.conv_block.1.weight"), "2.2 should map to 1.1" ); } #[test] fn test_map_indices_contiguous_deeply_nested() { // Test with three levels of nesting let tensors = vec![ create_test_tensor_snapshot("a.0.b.0.c.0.weight"), create_test_tensor_snapshot("a.0.b.0.c.2.weight"), create_test_tensor_snapshot("a.0.b.2.c.0.weight"), create_test_tensor_snapshot("a.2.b.0.c.0.weight"), ]; let (mapped, _) = map_indices_contiguous(tensors); // a: 0, 2 -> 0, 1 // a.0.b: 0, 2 -> 0, 1 // a.2.b: 0 -> 0 // a.0.b.0.c: 0, 2 -> 0, 1 // a.0.b.2.c: 0 -> 0 // a.2.b.0.c: 0 -> 0 assert!(mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.0.weight")); assert!( mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.1.weight"), "a.0.b.0.c.2 should become a.0.b.0.c.1" ); assert!( mapped.iter().any(|v| v.full_path() == "a.0.b.1.c.0.weight"), "a.0.b.2.c.0 should become a.0.b.1.c.0" ); assert!( mapped.iter().any(|v| v.full_path() == "a.1.b.0.c.0.weight"), "a.2.b.0.c.0 should become a.1.b.0.c.0" ); } } ================================================ FILE: crates/burn-store/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] //! # Burn Store //! //! Advanced model storage and serialization infrastructure for the Burn deep learning framework. //! //! This crate provides comprehensive functionality for storing and loading Burn modules //! and their tensor data, with support for cross-framework interoperability, flexible filtering, //! and efficient memory management through lazy materialization. //! //! ## Key Features //! //! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support //! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization //! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation //! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance //! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates //! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility //! - **No-std Support**: Core functionality available in embedded and WASM environments //! //! ## Quick Start //! //! ### Basic Save and Load //! //! ```rust,ignore //! use burn_store::{ModuleSnapshot, SafetensorsStore}; //! //! // Save a model //! let mut store = SafetensorsStore::from_file("model.safetensors"); //! model.save_into(&mut store)?; //! //! // Load a model //! let mut store = SafetensorsStore::from_file("model.safetensors"); //! model.load_from(&mut store)?; //! ``` //! //! ### Loading PyTorch Models //! //! ```rust,ignore //! use burn_store::PytorchStore; //! //! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter) //! let mut store = PytorchStore::from_file("pytorch_model.pth") //! .with_top_level_key("state_dict") // Access nested state dict if needed //! .allow_partial(true); // Skip unknown tensors //! //! model.load_from(&mut store)?; //! ``` //! //! ### Filtering and Remapping //! //! ```rust,no_run //! # use burn_store::SafetensorsStore; //! // Save only specific layers with renaming //! let mut store = SafetensorsStore::from_file("encoder.safetensors") //! .with_regex(r"^encoder\..*") // Filter: only encoder layers //! .with_key_remapping(r"^encoder\.", "transformer.") // Rename: encoder.X -> transformer.X //! .metadata("subset", "encoder_only"); //! //! // Use store with model.save_into(&mut store)?; //! ``` //! //! ## Core Components //! //! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods //! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows //! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format //! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files //! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving //! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns //! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility //! //! ## Feature Flags //! //! - `std`: Enables file I/O and other std-only features (default) //! - `safetensors`: Enables SafeTensors format support (default) extern crate alloc; mod adapter; mod applier; mod apply_result; mod collector; mod filter; mod tensor_snapshot; mod traits; pub use adapter::{ BurnToPyTorchAdapter, ChainAdapter, HalfPrecisionAdapter, IdentityAdapter, ModuleAdapter, PyTorchToBurnAdapter, }; pub use applier::Applier; pub use apply_result::{ApplyError, ApplyResult}; pub use collector::Collector; pub use filter::PathFilter; pub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError}; pub use traits::{ModuleSnapshot, ModuleStore}; #[cfg(feature = "std")] mod keyremapper; #[cfg(feature = "std")] pub use keyremapper::{KeyRemapper, map_indices_contiguous}; #[cfg(feature = "pytorch")] pub mod pytorch; #[cfg(feature = "pytorch")] pub use pytorch::{PytorchStore, PytorchStoreError}; #[cfg(feature = "safetensors")] mod safetensors; #[cfg(feature = "safetensors")] pub use safetensors::{SafetensorsStore, SafetensorsStoreError}; #[cfg(feature = "burnpack")] mod burnpack; #[cfg(feature = "burnpack")] pub use burnpack::writer::BurnpackWriter; #[cfg(feature = "burnpack")] pub use burnpack::{base::BurnpackError, store::BurnpackStore}; ================================================ FILE: crates/burn-store/src/pytorch/lazy_data.rs ================================================ //! Lazy data loading support for PyTorch files. //! //! This module provides abstractions for lazy loading of tensor data from PyTorch files, //! avoiding the need to load all data into memory upfront. use alloc::string::String; use alloc::vec::Vec; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Seek}; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex, RwLock}; use zip::ZipArchive; /// A data source that can lazily load tensor data. #[derive(Clone)] pub enum LazyDataSource { /// ZIP archive with lazy loading Zip(Arc>), /// TAR archive format (older torchvision models) Tar(Arc>), /// Legacy format with multiple storages in single blob LegacyMultiStorage(Arc>), } /// ZIP archive source for lazy loading pub struct ZipSource { path: PathBuf, // Cache the file list to avoid reopening archive repeatedly file_list: Vec<(String, u64, u64)>, // (name, offset, compressed_size) } /// TAR archive source for lazy loading (older torchvision models like AlexNet, SqueezeNet) /// /// Older PyTorch/torchvision models (pre-1.6) use TAR format instead of ZIP. /// The TAR archive contains: /// - `sys_info`: System info pickle (endianness, type sizes) /// - `pickle`: OrderedDict mapping tensor names to storage keys /// - `tensors`: Tensor metadata pickles (unused, metadata is embedded in pickle) /// - `storages`: Storage count + sequential (metadata pickle, element count, raw data) pub struct TarSource { /// Cached storage map: storage_key -> (offset_in_storages, size_bytes) storage_map: HashMap, /// The raw storages data (kept in memory for TAR format) storages_data: Vec, } /// Legacy multi-storage source for old PyTorch format (0.1.10 - 1.5) /// /// Legacy format stores tensor data as concatenated raw binary without explicit /// storage boundaries. This source tracks storage usage during tensor parsing /// to build a storage map for lazy loading. /// /// ## Storage Layout /// - Pickle metadata with tensor definitions /// - List of storage keys (determines concatenation order) /// - Raw binary blob with all storages concatenated pub struct LegacyMultiStorageSource { path: PathBuf, data_offset: u64, #[allow(dead_code)] data_size: u64, // Map of storage_key -> (offset_in_blob, size) storage_map: RwLock>>, // Storage keys in order (for boundary calculation) storage_keys: RwLock>>, // Track storage usage as tensors are accessed storage_usage: RwLock>, // key -> max_bytes_needed } impl ZipSource { /// Create a new ZIP source pub fn new(path: PathBuf) -> std::io::Result { let file = File::open(&path)?; let reader = BufReader::new(file); let mut archive = ZipArchive::new(reader)?; // Cache file metadata let mut file_list = Vec::new(); for i in 0..archive.len() { let file = archive.by_index(i)?; let name = file.name().to_string(); let offset = file.data_start(); let compressed_size = file.compressed_size(); file_list.push(( name, offset.expect("should have an offset"), compressed_size, )); } Ok(Self { path, file_list }) } /// Check if a file exists in the archive pub fn contains(&self, name: &str) -> bool { self.file_list.iter().any(|(n, _, _)| n == name) } /// Get list of data files (excluding pickle files) pub fn data_files(&self) -> Vec { self.file_list .iter() .filter(|(name, _, _)| name.starts_with("data/") || name.contains("/data/")) .filter(|(name, _, _)| !name.ends_with(".pkl") && !name.ends_with("/")) .map(|(name, _, _)| name.clone()) .collect() } /// Read a specific file from the archive pub fn read_file(&self, name: &str) -> std::io::Result> { let file = File::open(&self.path)?; let reader = BufReader::new(file); let mut archive = ZipArchive::new(reader)?; let mut file = archive.by_name(name)?; let mut contents = Vec::with_capacity(file.size() as usize); file.read_to_end(&mut contents)?; Ok(contents) } /// Read a portion of a file pub fn read_file_range( &self, name: &str, offset: usize, length: usize, ) -> std::io::Result> { let file = File::open(&self.path)?; let reader = BufReader::new(file); let mut archive = ZipArchive::new(reader)?; let mut file = archive.by_name(name)?; let mut buffer = vec![0u8; length]; // Skip to offset let mut skip_buffer = vec![0u8; offset.min(8192)]; let mut skipped = 0; while skipped < offset { let to_skip = (offset - skipped).min(skip_buffer.len()); file.read_exact(&mut skip_buffer[..to_skip])?; skipped += to_skip; } // Read the requested data file.read_exact(&mut buffer)?; Ok(buffer) } } impl LegacyMultiStorageSource { /// Create a new legacy multi-storage source pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self { Self { path, data_offset, data_size, storage_map: RwLock::new(None), storage_keys: RwLock::new(None), storage_usage: RwLock::new(HashMap::new()), } } /// Set the ordered storage keys from the pickle pub fn set_storage_keys(&self, keys: Vec) { let mut storage_keys = self .storage_keys .write() .unwrap_or_else(|poisoned| poisoned.into_inner()); *storage_keys = Some(keys); } /// Track storage usage from tensor access /// This is called from within tensor loading closures pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) { let mut usage = self .storage_usage .write() .unwrap_or_else(|poisoned| poisoned.into_inner()); let max_extent = offset + size; usage .entry(storage_key.to_string()) .and_modify(|current| *current = (*current).max(max_extent)) .or_insert(max_extent); // Try to build storage map if we have enough information drop(usage); self.try_build_storage_map(); } /// Try to build the storage map from tracked usage fn try_build_storage_map(&self) { // Only build if we don't already have a map if self .storage_map .read() .unwrap_or_else(|poisoned| poisoned.into_inner()) .is_some() { return; } // Check if we have storage keys let keys_guard = self .storage_keys .read() .unwrap_or_else(|poisoned| poisoned.into_inner()); if let Some(ref keys) = *keys_guard { let usage = self .storage_usage .read() .unwrap_or_else(|poisoned| poisoned.into_inner()); // Only build if we have usage info for all storages if keys.iter().all(|k| usage.contains_key(k)) { let mut map = HashMap::new(); let mut current_offset = 0u64; for key in keys { if let Some(&size) = usage.get(key) { map.insert(key.clone(), (current_offset, size as u64)); current_offset += size as u64; } } // Set the storage map drop(keys_guard); drop(usage); let mut storage_map = self .storage_map .write() .unwrap_or_else(|poisoned| poisoned.into_inner()); *storage_map = Some(map); } } } /// Read data for a specific storage key /// Only loads the specific storage portion, never the entire blob pub fn read(&self, key: &str) -> std::io::Result> { // Extract numeric key from paths like "data/0" or just "0" let storage_key = key.split('/').next_back().unwrap_or(key); // Get storage map - must be available for lazy loading to work let storage_map = self .storage_map .read() .unwrap_or_else(|poisoned| poisoned.into_inner()); if let Some(ref map) = *storage_map && let Some(&(offset, size)) = map.get(storage_key) { // Load only this specific storage let mut file = File::open(&self.path)?; file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?; let mut buffer = vec![0u8; size as usize]; file.read_exact(&mut buffer)?; return Ok(buffer); } // NO FALLBACK! If we don't have storage boundaries, we cannot load data lazily // The storage map MUST be built from tensor metadata for lazy loading to work Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Storage boundaries not available for key '{}'. Cannot perform lazy loading.", storage_key ), )) } } impl TarSource { /// Create a new TAR source by parsing storages data. /// /// # Arguments /// * `storages_data` - Raw storages blob with structure: /// - Count pickle (number of storages) /// - For each storage: metadata pickle + u64 num_elements + raw binary data pub fn new(storages_data: Vec) -> std::io::Result { use super::pickle_reader::{read_pickle, storage_type_to_element_size}; use std::io::Cursor; let mut storage_map = HashMap::new(); let mut pos = 0usize; // First, read the count of storages let mut cursor = Cursor::new(&storages_data[pos..]); let storage_count = if let Ok(super::pickle_reader::Object::Int(count)) = read_pickle(&mut cursor) { pos += cursor.position() as usize; count as usize } else { 0 }; // Parse each storage entry for _i in 0..storage_count { if pos >= storages_data.len() { break; } // Read the storage metadata pickle: (storage_key, device, storage_type) let mut cursor = Cursor::new(&storages_data[pos..]); if let Ok(obj) = read_pickle(&mut cursor) { let pickle_size = cursor.position() as usize; pos += pickle_size; // Extract storage info from pickle tuple let (storage_key, storage_type) = match obj { super::pickle_reader::Object::Tuple(tuple) if tuple.len() >= 3 => { let key = match &tuple[0] { super::pickle_reader::Object::Int(i) => i.to_string(), super::pickle_reader::Object::String(s) => s.clone(), _ => continue, }; // tuple[1] is device (e.g., "cpu") // tuple[2] is storage type class let stype = match &tuple[2] { super::pickle_reader::Object::Class { name, .. } => name.clone(), other => { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!("Expected Class for storage type, got {:?}", other), )); } }; (key, stype) } _ => continue, }; // Read the number of elements (u64 little-endian) if pos + 8 > storages_data.len() { break; } let num_elements = u64::from_le_bytes([ storages_data[pos], storages_data[pos + 1], storages_data[pos + 2], storages_data[pos + 3], storages_data[pos + 4], storages_data[pos + 5], storages_data[pos + 6], storages_data[pos + 7], ]) as usize; pos += 8; // Determine element size from storage type let element_size = storage_type_to_element_size(&storage_type) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; let data_size = num_elements * element_size; // Store the offset to raw data and its size storage_map.insert(storage_key, (pos, data_size)); // Skip the raw binary data pos += data_size; } else { break; } } Ok(Self { storage_map, storages_data, }) } /// Read data for a specific storage key pub fn read_file(&self, key: &str) -> std::io::Result> { // Extract the storage key from paths like "data/0" let storage_key = key.split('/').next_back().unwrap_or(key); if let Some(&(offset, size)) = self.storage_map.get(storage_key) && offset + size <= self.storages_data.len() { return Ok(self.storages_data[offset..offset + size].to_vec()); } Err(std::io::Error::new( std::io::ErrorKind::NotFound, format!("Storage key '{}' not found in TAR archive", storage_key), )) } /// Read a range of data for a specific storage key (avoids double allocation) pub fn read_file_range( &self, key: &str, offset: usize, length: usize, ) -> std::io::Result> { let storage_key = key.split('/').next_back().unwrap_or(key); if let Some(&(storage_offset, storage_size)) = self.storage_map.get(storage_key) && storage_offset + storage_size <= self.storages_data.len() { let start = storage_offset + offset; let end = (storage_offset + offset + length).min(storage_offset + storage_size); return Ok(self.storages_data[start..end].to_vec()); } Err(std::io::Error::new( std::io::ErrorKind::NotFound, format!("Storage key '{}' not found in TAR archive", storage_key), )) } /// Check if a storage key exists pub fn contains(&self, key: &str) -> bool { let storage_key = key.split('/').next_back().unwrap_or(key); self.storage_map.contains_key(storage_key) } /// Get list of storage keys pub fn keys(&self) -> Vec { self.storage_map.keys().cloned().collect() } } impl LazyDataSource { /// Create from a ZIP file pub fn from_zip(path: impl AsRef) -> std::io::Result { Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new( path.as_ref().to_path_buf(), )?)))) } /// Create from a TAR archive's storages data pub fn from_tar(storages_data: &[u8]) -> std::io::Result { Ok(Self::Tar(Arc::new(Mutex::new(TarSource::new( storages_data.to_vec(), )?)))) } /// Create from a legacy multi-storage file pub fn from_legacy_multi_storage( path: impl AsRef, data_offset: u64, data_size: u64, ) -> Self { Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new( path.as_ref().to_path_buf(), data_offset, data_size, )))) } /// Read data for a specific key pub fn read(&self, key: &str) -> std::io::Result> { match self { Self::Zip(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.read_file(key) } Self::Tar(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.read_file(key) } Self::LegacyMultiStorage(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.read(key) } } } /// Read a portion of data for a specific key pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result> { match self { Self::Zip(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.read_file_range(key, offset, length) } Self::Tar(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.read_file_range(key, offset, length) } Self::LegacyMultiStorage(source) => { // For legacy format, read only the requested range let storage_key = key.split('/').next_back().unwrap_or(key); let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); // Get storage boundaries let storage_map = source .storage_map .read() .unwrap_or_else(|poisoned| poisoned.into_inner()); if let Some(ref map) = *storage_map && let Some(&(storage_offset, storage_size)) = map.get(storage_key) { // Calculate actual file position let file_offset = source.data_offset + storage_offset + offset as u64; let read_length = length.min((storage_size as usize).saturating_sub(offset)); // Read only the requested range let mut file = File::open(&source.path)?; file.seek(std::io::SeekFrom::Start(file_offset))?; let mut buffer = vec![0u8; read_length]; file.read_exact(&mut buffer)?; Ok(buffer) } else { Err(std::io::Error::new( std::io::ErrorKind::InvalidData, format!( "Storage boundaries not available for key '{}'. Cannot perform lazy loading.", storage_key ), )) } } } } /// Check if a key exists pub fn contains(&self, key: &str) -> bool { match self { Self::Zip(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.contains(key) } Self::Tar(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.contains(key) } Self::LegacyMultiStorage(_) => true, // Legacy format has all data } } /// Get list of available keys (for ZIP sources) pub fn keys(&self) -> Vec { match self { Self::Zip(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.data_files() } Self::Tar(source) => { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.keys() } Self::LegacyMultiStorage(_) => vec![], // Legacy format doesn't have distinct keys } } } ================================================ FILE: crates/burn-store/src/pytorch/mod.rs ================================================ //! PyTorch format support for burn-store. //! //! This module provides comprehensive support for loading PyTorch model files (.pth, .pt) //! into Burn, with automatic weight transformation and flexible configuration options. //! //! ## Features //! //! - **Direct .pth/.pt file loading**: Load PyTorch checkpoint and state dict files //! - **Automatic weight transformation**: `PyTorchToBurnAdapter` is applied by default: //! - Linear layer weights are automatically transposed //! - Normalization parameters are renamed (gamma → weight, beta → bias) //! - Conv2d weights maintain their format //! - **Flexible filtering**: Load only specific layers or parameters //! - **Key remapping**: Rename tensors during loading to match your model structure //! - **Partial loading**: Continue even when some tensors are missing //! //! ## Example //! //! ```rust,ignore //! use burn_store::PytorchStore; //! //! // Load a PyTorch model (PyTorchToBurnAdapter is applied automatically) //! let mut store = PytorchStore::from_file("model.pth") //! .with_top_level_key("state_dict") // Access nested state dict //! .with_regex(r"^encoder\..*") // Only load encoder layers //! .with_key_remapping(r"^fc\.", "linear.") // Rename fc -> linear //! .allow_partial(true); // Skip missing tensors //! //! let mut model = MyModel::new(&device); //! let result = model.load_from(&mut store)?; //! //! println!("Loaded {} tensors", result.applied.len()); //! if !result.missing.is_empty() { //! println!("Missing tensors: {:?}", result.missing); //! } //! ``` pub mod lazy_data; pub mod pickle_reader; pub mod reader; pub mod store; #[cfg(test)] pub mod tests; // Main public interface pub use reader::{PytorchError, PytorchReader}; pub use store::{PytorchStore, PytorchStoreError}; ================================================ FILE: crates/burn-store/src/pytorch/pickle_reader.rs ================================================ //! Just enough pickle support to be able to read PyTorch checkpoints. //! //! This implementation is based on the candle project's pickle loader with significant //! modifications for improved separation of concerns and extended PyTorch compatibility. //! //! Original source: //! //! Modifications include: //! - Lazy tensor data loading for memory efficiency //! - Extended PyTorch version compatibility (0.1.10 - 2.x) //! - Better separation of pickle parsing and tensor extraction //! - Support for both legacy and modern PyTorch formats use crate::TensorSnapshot; use crate::pytorch::lazy_data::LazyDataSource; use alloc::rc::Rc; use alloc::string::{String, ToString}; use alloc::vec::Vec; use burn_core::module::ParamId; use burn_tensor::{BoolStore, DType, TensorData}; use byteorder::{LittleEndian, ReadBytesExt}; use half::{bf16, f16}; use std::collections::HashMap; use std::io::{self, BufRead}; use std::sync::Arc; /// Error type for pickle operations #[derive(Debug)] pub enum PickleError { Io(io::Error), InvalidOpCode(u8), InvalidProtocol(u8), UnexpectedOpCode(OpCode), UnsupportedType(String), InvalidData(String), StackUnderflow, MemoNotFound(u32), InvalidShapeOrType, } impl From for PickleError { fn from(e: io::Error) -> Self { PickleError::Io(e) } } impl std::fmt::Display for PickleError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { PickleError::Io(e) => write!(f, "IO error: {}", e), PickleError::InvalidOpCode(code) => write!( f, "Invalid pickle opcode: 0x{:02x}. The file may be corrupted or use an unsupported pickle protocol.", code ), PickleError::InvalidProtocol(proto) => write!( f, "Invalid or unsupported pickle protocol version: {}. Supported versions are 2-5.", proto ), PickleError::UnexpectedOpCode(op) => { write!(f, "Unexpected pickle opcode {:?} in current context", op) } PickleError::UnsupportedType(ty) => write!( f, "Unsupported Python type '{}'. This may indicate a full model save rather than a state_dict.", ty ), PickleError::InvalidData(msg) => write!(f, "Invalid data in pickle file: {}", msg), PickleError::StackUnderflow => { write!(f, "Pickle stack underflow - the file may be corrupted") } PickleError::MemoNotFound(idx) => write!( f, "Pickle memo reference {} not found - the file may be corrupted", idx ), PickleError::InvalidShapeOrType => { write!(f, "Invalid tensor shape or data type in PyTorch file") } } } } impl std::error::Error for PickleError {} type Result = std::result::Result; /// Convert PyTorch storage type name to element size in bytes. /// /// This is used to calculate storage sizes for lazy loading. /// The storage type names follow PyTorch's naming convention (e.g., "FloatStorage", "BFloat16Storage"). /// /// Returns an error for unknown storage types to avoid silently loading garbage data. pub fn storage_type_to_element_size(storage_type: &str) -> std::result::Result { match storage_type { "DoubleStorage" | "LongStorage" | "ComplexFloatStorage" => Ok(8), "FloatStorage" | "IntStorage" | "ComplexHalfStorage" => Ok(4), "HalfStorage" | "BFloat16Storage" | "ShortStorage" => Ok(2), "ByteStorage" | "CharStorage" | "BoolStorage" => Ok(1), _ => Err(format!("Unknown storage type: {}", storage_type)), } } // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/ #[repr(u8)] #[derive(Debug, Eq, PartialEq, Clone)] pub enum OpCode { // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123 Proto = 0x80, Global = b'c', BinPut = b'q', LongBinPut = b'r', EmptyTuple = b')', Reduce = b'R', Mark = b'(', BinUnicode = b'X', ShortBinString = b'U', BinInt = b'J', Int = b'I', Tuple = b't', BinPersId = b'Q', BinInt1 = b'K', BinInt2 = b'M', Tuple1 = 0x85, Tuple2 = 0x86, Tuple3 = 0x87, NewTrue = 0x88, NewFalse = 0x89, None = b'N', BinGet = b'h', LongBinGet = b'j', SetItem = b's', SetItems = b'u', EmptyDict = b'}', Dict = b'd', Build = b'b', Stop = b'.', NewObj = 0x81, EmptyList = b']', List = b'l', BinFloat = b'G', Append = b'a', Appends = b'e', Long1 = 0x8a, Memoize = 0x94, } // Avoid using FromPrimitive so as not to drag another dependency. impl TryFrom for OpCode { type Error = u8; fn try_from(value: u8) -> std::result::Result { match value { 0x80 => Ok(Self::Proto), b'c' => Ok(Self::Global), b'q' => Ok(Self::BinPut), b'r' => Ok(Self::LongBinPut), b')' => Ok(Self::EmptyTuple), b'R' => Ok(Self::Reduce), b'(' => Ok(Self::Mark), b'X' => Ok(Self::BinUnicode), b'U' => Ok(Self::ShortBinString), b'J' => Ok(Self::BinInt), b'I' => Ok(Self::Int), b't' => Ok(Self::Tuple), b'Q' => Ok(Self::BinPersId), b'K' => Ok(Self::BinInt1), b'M' => Ok(Self::BinInt2), b'N' => Ok(Self::None), 0x85 => Ok(Self::Tuple1), 0x86 => Ok(Self::Tuple2), 0x87 => Ok(Self::Tuple3), 0x88 => Ok(Self::NewTrue), 0x89 => Ok(Self::NewFalse), b'h' => Ok(Self::BinGet), b'j' => Ok(Self::LongBinGet), b's' => Ok(Self::SetItem), b'u' => Ok(Self::SetItems), b'}' => Ok(Self::EmptyDict), b'd' => Ok(Self::Dict), b'b' => Ok(Self::Build), b'.' => Ok(Self::Stop), 0x81 => Ok(Self::NewObj), b']' => Ok(Self::EmptyList), b'l' => Ok(Self::List), b'G' => Ok(Self::BinFloat), b'a' => Ok(Self::Append), b'e' => Ok(Self::Appends), 0x8a => Ok(Self::Long1), 0x94 => Ok(Self::Memoize), value => Err(value), } } } fn read_to_newline(r: &mut R) -> Result> { let mut data: Vec = Vec::with_capacity(32); r.read_until(b'\n', &mut data)?; data.pop(); if data.last() == Some(&b'\r') { data.pop(); } Ok(data) } fn buf_to_str(buf: &[u8]) -> Result { String::from_utf8(buf.to_vec()) .map_err(|e| PickleError::InvalidData(format!("Invalid UTF-8: {}", e))) } #[derive(Debug, Clone)] pub enum Object { Class { module_name: String, name: String, }, String(String), Int(i64), Float(f64), Bool(bool), None, Tuple(Vec), List(Vec), Dict(HashMap), Persistent(Vec), PersistentTuple(Vec), Reduce { callable: Box, args: Box, }, Build { callable: Box, args: Box, }, TorchParam(TensorSnapshot), } fn rebuild_from_type_v2( o: Object, memo: &mut HashMap, data_source: &Option>, ) -> Result { let args = if let Object::Tuple(args) = o { if args.is_empty() { return Err(PickleError::InvalidData( "rebuild_from_type_v2: empty args".to_string(), )); } args } else { return Err(PickleError::InvalidData(format!( "rebuild_from_type_v2: expected tuple got {:?}", o ))); }; let func = &args[0]; match func { Object::Class { module_name, name } => { let module_name = module_name.as_str(); let name = name.as_str(); // For rebuild_tensor_v2, the args might already be in a tuple let actual_args = if args.len() == 2 && matches!(&args[1], Object::Tuple(_)) { // If there's only one arg and it's a tuple, use it directly args[1].clone() } else { // Otherwise, wrap the remaining args in a tuple Object::Tuple(args[1..].to_vec()) }; if module_name == "torch._utils" && name == "_rebuild_tensor_v2" { rebuild_tensor_v2(actual_args, memo, data_source) } else if module_name == "torch._utils" && name == "_rebuild_tensor" { // Legacy _rebuild_tensor (PyTorch < 1.6) // Same as v2 but with fewer arguments: (storage, storage_offset, size, stride) rebuild_tensor(actual_args, memo, data_source) } else if module_name == "torch._tensor" && name == "_rebuild_from_type_v2" { rebuild_from_type_v2(actual_args, memo, data_source) } else if module_name == "torch._utils" && name == "_rebuild_parameter" { rebuild_parameter(actual_args, memo, data_source) } else if module_name == "collections" && name == "OrderedDict" { // OrderedDict is treated as a regular Dict in our implementation Ok(Object::Dict(HashMap::new())) } else { Err(PickleError::UnsupportedType(format!( "{}.{}", module_name, name ))) } } _ => Err(PickleError::InvalidData(format!( "rebuild_from_type_v2: expected class got {:?}", func ))), } } fn rebuild_parameter( args: Object, memo: &mut HashMap, data_source: &Option>, ) -> Result { let args = if let Object::Tuple(args) = args { if args.is_empty() { return Err(PickleError::InvalidData( "rebuild_parameter: empty args".to_string(), )); } args } else { return Err(PickleError::InvalidData(format!( "rebuild_parameter: expected tuple got {:?}", args ))); }; let data = &args[0]; let tensor = match data { Object::Reduce { callable: _, args: _, } => rebuild_from_type_v2(data.clone(), memo, data_source)?, _ => data.clone(), }; Ok(tensor) } /// Parse storage argument and extract storage info and tuple. fn parse_storage_arg(arg: &Object, fn_name: &str) -> Result<(Vec, Option>)> { match arg { Object::Persistent(data) => Ok((data.clone(), None)), Object::PersistentTuple(tuple) => Ok((vec![], Some(tuple.clone()))), // Also accept regular Tuple for TAR format compatibility Object::Tuple(tuple) => Ok((vec![], Some(tuple.clone()))), _ => Err(PickleError::InvalidData(format!( "{}: expected persistent id got {:?}", fn_name, arg ))), } } /// Parse shape argument. fn parse_shape_arg(arg: &Object, fn_name: &str) -> Result> { match arg { Object::Tuple(shape) => shape .iter() .map(|x| match x { Object::Int(i) => Ok(*i as usize), _ => Err(PickleError::InvalidData( "shape must contain ints".to_string(), )), }) .collect::>>(), _ => Err(PickleError::InvalidData(format!( "{}: expected shape tuple got {:?}", fn_name, arg ))), } } /// Legacy _rebuild_tensor function for PyTorch < 1.6. /// Thin wrapper that parses 4 arguments and calls rebuild_tensor_impl. fn rebuild_tensor( args: Object, _memo: &mut HashMap, data_source: &Option>, ) -> Result { let args = if let Object::Tuple(args) = args { args } else { return Err(PickleError::InvalidData(format!( "rebuild_tensor: expected tuple got {:?}", args ))); }; if args.len() < 4 { return Err(PickleError::InvalidData(format!( "rebuild_tensor: expected at least 4 args, got {}", args.len() ))); } let (storage_info, storage_tuple) = parse_storage_arg(&args[0], "rebuild_tensor")?; let storage_offset = match &args[1] { Object::Int(offset) => *offset as usize, _ => 0, }; let shape = parse_shape_arg(&args[2], "rebuild_tensor")?; rebuild_tensor_impl( storage_info, storage_tuple, storage_offset, shape, data_source, ) } /// Modern _rebuild_tensor_v2 function for PyTorch >= 1.6. /// Thin wrapper that parses 5+ arguments and calls rebuild_tensor_impl. fn rebuild_tensor_v2( args: Object, _memo: &mut HashMap, data_source: &Option>, ) -> Result { let args = if let Object::Tuple(args) = args { args } else { return Err(PickleError::InvalidData(format!( "rebuild_tensor_v2: expected tuple got {:?}", args ))); }; if args.len() < 5 { return Err(PickleError::InvalidData(format!( "rebuild_tensor_v2: expected at least 5 args, got {}", args.len() ))); } let (storage_info, storage_tuple) = parse_storage_arg(&args[0], "rebuild_tensor_v2")?; let storage_offset = match &args[1] { Object::Int(offset) => *offset as usize, _ => 0, }; let shape = parse_shape_arg(&args[2], "rebuild_tensor_v2")?; // args[3] is stride (unused) // args[4] is requires_grad (unused) // args[5] is backward_hooks (unused) rebuild_tensor_impl( storage_info, storage_tuple, storage_offset, shape, data_source, ) } /// Helper to convert storage type name to DType. fn storage_type_to_dtype(storage_type: &str) -> Result { match storage_type { "FloatStorage" => Ok(DType::F32), "DoubleStorage" => Ok(DType::F64), "HalfStorage" => Ok(DType::F16), "BFloat16Storage" => Ok(DType::BF16), "LongStorage" => Ok(DType::I64), "IntStorage" => Ok(DType::I32), "ShortStorage" => Ok(DType::I16), "CharStorage" => Ok(DType::I8), "ByteStorage" => Ok(DType::U8), "BoolStorage" => Ok(DType::Bool(BoolStore::Native)), _ => Err(PickleError::InvalidData(format!( "Unknown storage type: {}", storage_type ))), } } /// Core implementation for rebuilding tensors. /// Shared by both rebuild_tensor (legacy) and rebuild_tensor_v2 (modern). fn rebuild_tensor_impl( storage_info: Vec, storage_tuple: Option>, storage_offset: usize, shape: Vec, data_source: &Option>, ) -> Result { // Parse the storage info to extract dtype and storage key // The persistent ID is typically a tuple like: ('storage', 'FloatStorage', '0', 'cpu', 4) let (dtype, storage_key) = if let Some(tuple) = storage_tuple { // Direct tuple access if tuple.len() >= 3 { let storage_type = match &tuple[1] { Object::String(s) => s.as_str(), Object::Class { module_name: _, name, } => name.as_str(), other => { return Err(PickleError::InvalidData(format!( "Expected storage type as String or Class, got {:?}", other ))); } }; let dtype = storage_type_to_dtype(storage_type)?; let key = match &tuple[2] { Object::String(s) => s.clone(), other => { return Err(PickleError::InvalidData(format!( "Expected storage key as String, got {:?}", other ))); } }; (dtype, key) } else { return Err(PickleError::InvalidData(format!( "Storage tuple too short, expected at least 3 elements, got {}", tuple.len() ))); } } else if !storage_info.is_empty() { // Legacy string-based parsing let storage_str = String::from_utf8_lossy(&storage_info); if storage_str.starts_with("Tuple(") { // Parse from the debug representation we stored let parts: Vec<&str> = storage_str .trim_start_matches("Tuple(") .trim_end_matches(")") .split(", ") .map(|s| { let trimmed = s.trim_matches('"'); if let Some(inner) = trimmed .strip_prefix("Object::String(\"") .and_then(|s| s.strip_suffix("\")")) { inner } else { trimmed } }) .collect(); if parts.len() >= 3 { let dtype = storage_type_to_dtype(parts[1])?; (dtype, parts[2].to_string()) } else { return Err(PickleError::InvalidData(format!( "Storage info tuple too short, expected at least 3 parts, got {}", parts.len() ))); } } else { return Err(PickleError::InvalidData(format!( "Invalid storage info format: {}", storage_str ))); } } else { return Err(PickleError::InvalidData( "No storage information available".to_string(), )); }; // If no data source, we can't load tensor data let data_source = match data_source { Some(ds) => ds.clone(), None => { return Err(PickleError::InvalidData( "Cannot load tensor data without a data source".to_string(), )); } }; // Create clones for the closure let data_source_clone = data_source.clone(); let shape_clone = shape.clone(); // Find the correct data file key let data_file_key = { let exact_key = format!("data/{}", storage_key); if data_source.contains(&exact_key) { exact_key } else { // Try other patterns data_source .keys() .into_iter() .find(|key| { key.ends_with(&format!("/data/{}", storage_key)) || (key.contains("/data/") && key.rsplit('/').next() == Some(&storage_key)) }) .unwrap_or_else(|| format!("data/{}", storage_key)) } }; // Track storage usage IMMEDIATELY for lazy boundary detection // This must happen BEFORE creating the closure, not inside it! if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); let num_elements: usize = shape.iter().product(); let bytes_needed = storage_offset * dtype.size() + num_elements * dtype.size(); source.track_storage_usage(&storage_key, 0, bytes_needed); } // Create a TensorSnapshot with a closure that loads the actual data on-demand Ok(Object::TorchParam(TensorSnapshot::from_closure( Rc::new(move || { // Load data only when needed if let Ok(data) = data_source_clone.read(&data_file_key) { // Parse the binary data based on dtype let num_elements = shape_clone.iter().product::().max(1); // Use dtype.size() to get element size in bytes let element_size = dtype.size(); // Apply storage offset let offset_bytes = storage_offset * element_size; if offset_bytes >= data.len() { return Ok(TensorData::new( vec![0.0f32; num_elements], shape_clone.clone(), )); } let data_slice = &data[offset_bytes..]; let available_elements = data_slice.len() / element_size; let elements_to_read = num_elements.min(available_elements); // Convert bytes to the appropriate type match dtype { DType::F32 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let bytes = [ data_slice[i * element_size], data_slice[i * element_size + 1], data_slice[i * element_size + 2], data_slice[i * element_size + 3], ]; values.push(f32::from_le_bytes(bytes)); } // Pad with zeros if needed values.resize(num_elements, 0.0); Ok(TensorData::new(values, shape_clone.clone())) } DType::F64 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 8]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(f64::from_le_bytes(bytes)); } values.resize(num_elements, 0.0); Ok(TensorData::new(values, shape_clone.clone())) } DType::I64 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 8]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(i64::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::I32 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 4]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(i32::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::I16 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 2]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(i16::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::I8 => { let mut values = Vec::with_capacity(num_elements); for &byte in data_slice.iter().take(elements_to_read) { values.push(byte as i8); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::Bool(BoolStore::Native) => { let mut values = Vec::with_capacity(num_elements); for &byte in data_slice.iter().take(elements_to_read) { values.push(byte != 0); } values.resize(num_elements, false); Ok(TensorData::new(values, shape_clone.clone())) } DType::F16 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 2]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(f16::from_le_bytes(bytes)); } values.resize(num_elements, f16::ZERO); Ok(TensorData::new(values, shape_clone.clone())) } DType::BF16 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 2]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(bf16::from_le_bytes(bytes)); } values.resize(num_elements, bf16::ZERO); Ok(TensorData::new(values, shape_clone.clone())) } DType::U8 => { let mut values = Vec::with_capacity(num_elements); for &byte in data_slice.iter().take(elements_to_read) { values.push(byte); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::U16 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 2]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(u16::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::U32 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 4]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(u32::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } DType::U64 => { let mut values = Vec::with_capacity(num_elements); for i in 0..elements_to_read { let mut bytes = [0u8; 8]; bytes.copy_from_slice( &data_slice[i * element_size..(i + 1) * element_size], ); values.push(u64::from_le_bytes(bytes)); } values.resize(num_elements, 0); Ok(TensorData::new(values, shape_clone.clone())) } _ => { // For any remaining unsupported types, return an error Err(crate::TensorSnapshotError::DataError(format!( "Unsupported dtype for tensor data reading: {:?}", dtype ))) } } } else { // If no data file found, return zeros of the appropriate type let num_elements = shape_clone.iter().product::().max(1); match dtype { DType::F32 => Ok(TensorData::new( vec![0.0f32; num_elements], shape_clone.clone(), )), DType::F64 => Ok(TensorData::new( vec![0.0f64; num_elements], shape_clone.clone(), )), DType::F16 => Ok(TensorData::new( vec![f16::ZERO; num_elements], shape_clone.clone(), )), DType::BF16 => Ok(TensorData::new( vec![bf16::ZERO; num_elements], shape_clone.clone(), )), DType::I64 => Ok(TensorData::new( vec![0i64; num_elements], shape_clone.clone(), )), DType::I32 => Ok(TensorData::new( vec![0i32; num_elements], shape_clone.clone(), )), DType::I16 => Ok(TensorData::new( vec![0i16; num_elements], shape_clone.clone(), )), DType::I8 => Ok(TensorData::new( vec![0i8; num_elements], shape_clone.clone(), )), DType::U8 => Ok(TensorData::new( vec![0u8; num_elements], shape_clone.clone(), )), DType::U16 => Ok(TensorData::new( vec![0u16; num_elements], shape_clone.clone(), )), DType::U32 => Ok(TensorData::new( vec![0u32; num_elements], shape_clone.clone(), )), DType::U64 => Ok(TensorData::new( vec![0u64; num_elements], shape_clone.clone(), )), DType::Bool(BoolStore::Native) => Ok(TensorData::new( vec![false; num_elements], shape_clone.clone(), )), _ => { // For any remaining unsupported types, return an error Err(crate::TensorSnapshotError::DataError(format!( "Unsupported dtype for tensor data reading: {:?}", dtype ))) } } } }), dtype, shape.into(), vec![], // path_stack vec![], // container_stack ParamId::new(), // tensor_id ))) } pub struct Stack { stack: Vec, memo: HashMap, data_source: Option>, } impl Default for Stack { fn default() -> Self { Self::new() } } impl Stack { pub fn new() -> Self { // For cases where no data source is needed (pure pickle without tensor data) Self { stack: Vec::new(), memo: HashMap::new(), data_source: None, } } pub fn with_data_source(data_source: Arc) -> Self { Self { stack: Vec::new(), memo: HashMap::new(), data_source: Some(data_source), } } fn push(&mut self, o: Object) { self.stack.push(o) } fn pop(&mut self) -> Result { match self.stack.pop() { None => Err(PickleError::StackUnderflow), Some(o) => Ok(o), } } fn top(&self) -> Result { match self.stack.last() { None => Err(PickleError::StackUnderflow), Some(o) => Ok(o.clone()), } } fn pop_to_marker(&mut self) -> Result> { let marker_pos = self .stack .iter() .rposition(|o| { matches!(o, Object::Class { module_name, name } if module_name == "mark" && name == "mark") }) .ok_or(PickleError::InvalidData("marker not found".to_string()))?; let result = self.stack.split_off(marker_pos + 1); self.stack.pop(); // Remove the marker Ok(result) } fn last_mut(&mut self) -> Result<&mut Object> { match self.stack.last_mut() { None => Err(PickleError::StackUnderflow), Some(o) => Ok(o), } } fn push_mark(&mut self) { self.stack.push(Object::Class { module_name: "mark".to_string(), name: "mark".to_string(), }); } fn memo_get(&self, idx: u32) -> Result { self.memo .get(&idx) .cloned() .ok_or(PickleError::MemoNotFound(idx)) } fn memo_put(&mut self, idx: u32, obj: Object) { self.memo.insert(idx, obj); } fn memo_len(&self) -> usize { self.memo.len() } } fn read_global(r: &mut R, stack: &mut Stack) -> Result<()> { let module_name = buf_to_str(&read_to_newline(r)?)?; let name = buf_to_str(&read_to_newline(r)?)?; stack.push(Object::Class { module_name, name }); Ok(()) } fn read_long1(r: &mut R, stack: &mut Stack) -> Result<()> { let len = r.read_u8()? as usize; let mut data = vec![0u8; len]; r.read_exact(&mut data)?; // Handle little-endian signed integer let mut value = 0i64; for (i, &byte) in data.iter().enumerate().take(8) { // Only process up to 8 bytes for i64, and use wrapping to avoid overflow value |= (byte as i64).wrapping_shl((i as u32) * 8); } // Handle sign extension for negative numbers if len < 8 && data.last().is_some_and(|&b| b & 0x80 != 0) { // Sign extend for i in len..8 { value |= 0xffi64.wrapping_shl((i as u32) * 8); } } stack.push(Object::Int(value)); Ok(()) } fn read_string(r: &mut R, stack: &mut Stack, len: usize) -> Result<()> { let mut data = vec![0u8; len]; r.read_exact(&mut data)?; let s = buf_to_str(&data)?; stack.push(Object::String(s)); Ok(()) } fn read_bin_int(r: &mut R, stack: &mut Stack) -> Result<()> { let v = r.read_i32::()?; stack.push(Object::Int(v as i64)); Ok(()) } fn read_int(r: &mut R, stack: &mut Stack) -> Result<()> { // INT opcode reads an integer as ASCII string followed by newline let line = read_to_newline(r)?; let s = buf_to_str(&line)?; let v = s .parse::() .map_err(|e| PickleError::InvalidData(format!("Invalid INT value '{}': {}", s, e)))?; stack.push(Object::Int(v)); Ok(()) } fn read_bin_int1(r: &mut R, stack: &mut Stack) -> Result<()> { let v = r.read_u8()?; stack.push(Object::Int(v as i64)); Ok(()) } fn read_bin_int2(r: &mut R, stack: &mut Stack) -> Result<()> { let v = r.read_u16::()?; stack.push(Object::Int(v as i64)); Ok(()) } fn read_bin_float(r: &mut R, stack: &mut Stack) -> Result<()> { // Python's BINFLOAT uses big-endian encoding let v = r.read_f64::()?; stack.push(Object::Float(v)); Ok(()) } pub fn read_pickle(r: &mut R) -> Result { // For pure pickle without tensor data, no data source is needed read_pickle_with_optional_data(r, None) } /// Skip over a pickle without parsing it fully /// This is useful for legacy format where we need to skip the main object /// that contains tensors but we don't have a data source yet pub fn skip_pickle(r: &mut R) -> Result<()> { // Read the protocol marker if present let mut first_byte = [0u8; 1]; r.read_exact(&mut first_byte)?; if first_byte[0] == 0x80 { // PROTO marker - read protocol version let mut proto_version = [0u8; 1]; r.read_exact(&mut proto_version)?; } // If not PROTO, the first byte is an opcode - continue to main loop // Helper to skip until newline fn skip_line(r: &mut R) -> Result<()> { let mut buf = Vec::new(); r.read_until(b'\n', &mut buf)?; Ok(()) } // Helper to skip length-prefixed data fn skip_length_prefixed(r: &mut R, length: usize) -> Result<()> { let mut skip_buf = vec![0u8; length.min(8192)]; let mut skipped = 0; while skipped < length { let to_skip = (length - skipped).min(skip_buf.len()); r.read_exact(&mut skip_buf[..to_skip])?; skipped += to_skip; } Ok(()) } // Process first byte if it wasn't PROTO let mut pending_byte = if first_byte[0] != 0x80 { Some(first_byte[0]) } else { None }; // Scan until we find STOP (0x2e) opcode loop { let byte = if let Some(b) = pending_byte.take() { b } else { let mut byte = [0u8; 1]; r.read_exact(&mut byte)?; byte[0] }; match byte { 0x2e => { // STOP - end of pickle break; } // === Newline-terminated string opcodes === 0x63 => { // GLOBAL - two newline-terminated strings (module\nname\n) skip_line(r)?; skip_line(r)?; } 0x69 => { // INST - two newline-terminated strings skip_line(r)?; skip_line(r)?; } 0x53 => { // STRING - quoted string ending with newline skip_line(r)?; } 0x46 | 0x49 | 0x4c => { // FLOAT, INT, LONG - newline-terminated ASCII skip_line(r)?; } 0x50 => { // PERSID - newline-terminated persistent ID skip_line(r)?; } // === Length-prefixed binary opcodes === 0x58 | 0x42 | 0x43 | 0x54 | 0x55 | 0x56 | 0x8c | 0x8d | 0x8e => { // String/bytes opcodes with length prefixes let length = match byte { 0x43 | 0x55 | 0x8c => { // SHORT versions - 1 byte length let mut len_byte = [0u8; 1]; r.read_exact(&mut len_byte)?; len_byte[0] as usize } 0x42 | 0x54 | 0x58 | 0x56 => { // Regular versions - 4 byte length let mut len_bytes = [0u8; 4]; r.read_exact(&mut len_bytes)?; u32::from_le_bytes(len_bytes) as usize } 0x8d | 0x8e => { // 8-byte length versions let mut len_bytes = [0u8; 8]; r.read_exact(&mut len_bytes)?; u64::from_le_bytes(len_bytes) as usize } _ => 0, }; skip_length_prefixed(r, length)?; } // === Fixed-size integer opcodes === 0x4b => { // BININT1 - 1 byte let mut buf = [0u8; 1]; r.read_exact(&mut buf)?; } 0x4d => { // BININT2 - 2 bytes let mut buf = [0u8; 2]; r.read_exact(&mut buf)?; } 0x4a => { // BININT - 4 bytes (signed int) let mut buf = [0u8; 4]; r.read_exact(&mut buf)?; } 0x47 => { // BINFLOAT - 8 bytes let mut buf = [0u8; 8]; r.read_exact(&mut buf)?; } // === Variable-length integer opcodes === 0x8a => { // LONG1 - 1 byte length, then that many bytes let mut len_byte = [0u8; 1]; r.read_exact(&mut len_byte)?; let length = len_byte[0] as usize; skip_length_prefixed(r, length)?; } 0x8b => { // LONG4 - 4 byte length, then that many bytes let mut len_bytes = [0u8; 4]; r.read_exact(&mut len_bytes)?; let length = u32::from_le_bytes(len_bytes) as usize; skip_length_prefixed(r, length)?; } // === Memo opcodes === 0x71 | 0x68 => { // BINPUT, BINGET - 1 byte index let mut buf = [0u8; 1]; r.read_exact(&mut buf)?; } 0x72 | 0x6a => { // LONG_BINPUT, LONG_BINGET - 4 byte index let mut buf = [0u8; 4]; r.read_exact(&mut buf)?; } 0x67 | 0x70 => { // GET, PUT - newline-terminated decimal index skip_line(r)?; } // === Extension opcodes === 0x82 => { // EXT1 - 1 byte code let mut buf = [0u8; 1]; r.read_exact(&mut buf)?; } 0x83 => { // EXT2 - 2 byte code let mut buf = [0u8; 2]; r.read_exact(&mut buf)?; } 0x84 => { // EXT4 - 4 byte code let mut buf = [0u8; 4]; r.read_exact(&mut buf)?; } // === Frame opcode (protocol 4+) === 0x95 => { // FRAME - 8 byte frame size (we don't actually use framing, just skip the size) let mut buf = [0u8; 8]; r.read_exact(&mut buf)?; } // === Opcodes with no additional data === // These just manipulate the stack or are markers 0x28 | 0x29 | 0x30 | 0x31 | 0x32 | // MARK, TUPLE, POP, POP_MARK, DUP 0x4e | 0x52 | 0x5d | 0x5b | 0x7d | // NONE, REDUCE, LIST, EMPTY_LIST, EMPTY_DICT 0x61 | 0x62 | 0x64 | 0x65 | 0x73 | // APPEND, BUILD, DICT, APPENDS, SETITEM 0x74 | 0x75 | 0x85 | 0x86 | 0x87 | // TUPLE, SETITEMS, TUPLE1, TUPLE2, TUPLE3 0x88 | 0x89 | 0x8f | 0x90 | 0x91 | // NEWTRUE, NEWFALSE, STACK_GLOBAL, MEMOIZE, EMPTY_SET 0x92 | 0x93 | 0x94 | 0x51 | 0x81 => { // ADDITEMS, FROZENSET, NEWOBJ, BINPERSID, NEWOBJ_EX // No additional data to skip } _ => { // Unknown opcode - assume no additional data // This is a best-effort approach } } } Ok(()) } pub fn read_pickle_with_data( r: &mut R, data_source: Arc, ) -> Result { read_pickle_with_optional_data(r, Some(data_source)) } fn get_dict_key(obj: Object) -> Result { match obj { Object::String(s) => Ok(s), Object::Int(i) => Ok(i.to_string()), _ => Err(PickleError::InvalidData(format!( "dict key must be a valid type, got {obj:?}" ))), } } pub fn read_pickle_with_optional_data( r: &mut R, data_source: Option>, ) -> Result { let mut stack = match data_source { Some(ds) => Stack::with_data_source(ds), None => Stack::new(), }; loop { let op_code = r.read_u8()?; let op_code = OpCode::try_from(op_code).map_err(PickleError::InvalidOpCode)?; match op_code { OpCode::Proto => { let version = r.read_u8()?; if version > 5 { return Err(PickleError::InvalidProtocol(version)); } } OpCode::Global => read_global(r, &mut stack)?, OpCode::BinInt => read_bin_int(r, &mut stack)?, OpCode::Int => read_int(r, &mut stack)?, OpCode::BinInt1 => read_bin_int1(r, &mut stack)?, OpCode::BinInt2 => read_bin_int2(r, &mut stack)?, OpCode::BinFloat => read_bin_float(r, &mut stack)?, OpCode::BinUnicode => { let len = r.read_u32::()? as usize; read_string(r, &mut stack, len)? } OpCode::ShortBinString => { let len = r.read_u8()? as usize; read_string(r, &mut stack, len)? } OpCode::Long1 => read_long1(r, &mut stack)?, OpCode::None => stack.push(Object::None), OpCode::NewTrue => stack.push(Object::Bool(true)), OpCode::NewFalse => stack.push(Object::Bool(false)), OpCode::EmptyTuple => stack.push(Object::Tuple(Vec::new())), OpCode::EmptyList => stack.push(Object::List(Vec::new())), OpCode::EmptyDict => stack.push(Object::Dict(HashMap::new())), OpCode::Tuple => { let objs = stack.pop_to_marker()?; stack.push(Object::Tuple(objs)) } OpCode::Tuple1 => { let obj = stack.pop()?; stack.push(Object::Tuple(vec![obj])) } OpCode::Tuple2 => { let obj2 = stack.pop()?; let obj1 = stack.pop()?; stack.push(Object::Tuple(vec![obj1, obj2])) } OpCode::Tuple3 => { let obj3 = stack.pop()?; let obj2 = stack.pop()?; let obj1 = stack.pop()?; stack.push(Object::Tuple(vec![obj1, obj2, obj3])) } OpCode::Append => { let value = stack.pop()?; match stack.last_mut()? { Object::List(list) => list.push(value), _ => return Err(PickleError::UnexpectedOpCode(op_code)), } } OpCode::Appends => { let objs = stack.pop_to_marker()?; match stack.last_mut()? { Object::List(list) => list.extend(objs), _ => return Err(PickleError::UnexpectedOpCode(op_code)), } } OpCode::SetItem => { let value = stack.pop()?; let key = stack.pop()?; match stack.last_mut()? { Object::Dict(dict) => { if let Object::String(key) = key { dict.insert(key, value); } else { return Err(PickleError::InvalidData( "dict key must be a string".to_string(), )); } } _ => return Err(PickleError::UnexpectedOpCode(op_code)), } } OpCode::SetItems => { let mut objs = stack.pop_to_marker()?; if objs.len() % 2 != 0 { return Err(PickleError::InvalidData( "setitems requires even number of objects".to_string(), )); } match stack.last_mut()? { Object::Dict(dict) => { while !objs.is_empty() { let key = objs.remove(0); let value = objs.remove(0); let key = get_dict_key(key)?; dict.insert(key, value); } } _ => return Err(PickleError::UnexpectedOpCode(op_code)), } } OpCode::BinPut => { let idx = r.read_u8()? as u32; let obj = stack.top()?; stack.memo_put(idx, obj); } OpCode::LongBinPut => { let idx = r.read_u32::()?; let obj = stack.top()?; stack.memo_put(idx, obj); } OpCode::BinGet => { let idx = r.read_u8()? as u32; let obj = stack.memo_get(idx)?; stack.push(obj); } OpCode::LongBinGet => { let idx = r.read_u32::()?; let obj = stack.memo_get(idx)?; stack.push(obj); } OpCode::Mark => stack.push_mark(), OpCode::BinPersId => { let pid = stack.pop()?; match pid { Object::String(s) => { stack.push(Object::Persistent(s.into_bytes())); } Object::Tuple(tuple) => { // The persistent ID is a tuple (e.g., ('storage', 'FloatStorage', '0', 'cpu', 4)) // Store it as a PersistentTuple for proper handling stack.push(Object::PersistentTuple(tuple)); } _ => { return Err(PickleError::InvalidData(format!( "persistent id must be a string or tuple, got {:?}", pid ))); } } } OpCode::Reduce => { let args = stack.pop()?; let callable = stack.pop()?; // Check if this is an OrderedDict if let Object::Class { module_name, name } = &callable { if module_name == "collections" && name == "OrderedDict" { // OrderedDict can be created with items: OrderedDict([(key1, val1), ...]) // The args is typically a tuple containing a list of [key, value] pairs let mut dict = HashMap::new(); // Extract items from args let items = match &args { Object::Tuple(tuple) if !tuple.is_empty() => { // Args is a tuple, get the first element (the list of items) match &tuple[0] { Object::List(list) => Some(list.clone()), _ => None, } } Object::List(list) => Some(list.clone()), _ => None, }; if let Some(items) = items { for item in items { // Each item is a list/tuple of [key, value] match item { Object::List(pair) | Object::Tuple(pair) if pair.len() >= 2 => { if let Object::String(key) = &pair[0] { dict.insert(key.clone(), pair[1].clone()); } } _ => {} } } } stack.push(Object::Dict(dict)); } else { let _obj = Object::Reduce { callable: Box::new(callable.clone()), args: Box::new(args.clone()), }; let obj = rebuild_from_type_v2( Object::Tuple(vec![callable, args]), &mut stack.memo, &stack.data_source, )?; stack.push(obj); } } else { let _obj = Object::Reduce { callable: Box::new(callable.clone()), args: Box::new(args.clone()), }; let obj = rebuild_from_type_v2( Object::Tuple(vec![callable, args]), &mut stack.memo, &stack.data_source, )?; stack.push(obj); } } OpCode::Build => { let args = stack.pop()?; let obj = stack.pop()?; match obj { Object::Dict(mut dict) => { // For dicts, BUILD updates with the args if let Object::Dict(update) = args { dict.extend(update); } stack.push(Object::Dict(dict)); } _ => { stack.push(Object::Build { callable: Box::new(obj), args: Box::new(args), }); } } } OpCode::NewObj => { let args = stack.pop()?; let cls = stack.pop()?; stack.push(Object::Reduce { callable: Box::new(cls), args: Box::new(args), }); } OpCode::Dict => { let objs = stack.pop_to_marker()?; let mut dict = HashMap::new(); if objs.len() % 2 != 0 { return Err(PickleError::InvalidData( "dict requires even number of objects".to_string(), )); } for chunk in objs.chunks(2) { let key = get_dict_key(chunk[0].clone())?; dict.insert(key, chunk[1].clone()); } stack.push(Object::Dict(dict)); } OpCode::List => { let objs = stack.pop_to_marker()?; stack.push(Object::List(objs)); } OpCode::Memoize => { // Store top of stack in memo without popping // The memo index is the current number of items in the memo let obj = stack.top()?; let idx = stack.memo_len() as u32; stack.memo_put(idx, obj); } OpCode::Stop => break, } } stack.pop() } /// Load tensors from a pickle file (PyTorch checkpoint format) pub fn read_pickle_tensors(reader: &mut R) -> Result> { let obj = read_pickle(reader)?; // Extract tensors from the loaded object let mut tensors = HashMap::new(); let mut path = Vec::new(); extract_tensors(&obj, &mut path, &mut tensors); Ok(tensors) } fn extract_tensors<'a>( obj: &'a Object, path: &mut Vec<&'a str>, tensors: &mut HashMap, ) { match obj { Object::Dict(dict) => { for (key, value) in dict { path.push(key); extract_tensors(value, path, tensors); path.pop(); } } Object::TorchParam(snapshot) => { // Only allocate the string here when we actually insert tensors.insert(path.join("."), snapshot.clone()); } _ => {} } } ================================================ FILE: crates/burn-store/src/pytorch/reader.rs ================================================ //! PyTorch file reader implementation. //! //! This module provides support for reading PyTorch checkpoint files (.pt/.pth). //! //! # Supported Formats //! //! ## 1. Modern ZIP Format (PyTorch 1.6+) //! Files are ZIP archives containing: //! - `data.pkl` or `archive/data.pkl`: Pickled tensor metadata //! - `data/` directory: Binary tensor data files //! //! ## 2. TAR Format (older torchvision models like AlexNet, SqueezeNet) //! TAR archives containing: //! - `sys_info`: System info pickle (endianness, type sizes) //! - `pickle`: OrderedDict mapping tensor names to storage keys //! - `tensors`: Tensor metadata (unused, metadata is in pickle) //! - `storages`: Count pickle + sequential (metadata, num_elements, raw data) //! //! ## 3. Legacy Pickle Format (PyTorch 0.1.10 - 1.5) //! Sequential pickle streams with the structure: //! - Magic number pickle (0x1950a86a20f9469cfc6c) //! - Protocol version pickle (e.g., 1001) //! - System info pickle (endianness, type sizes) //! - Model data pickle (state_dict or full model) //! //! ## 4. Simple Pickle Format //! Direct pickle file with a dictionary at the root, commonly used for //! manually saved state_dicts. //! //! # Compatibility //! //! The reader handles backward compatibility by detecting the file format //! automatically. Files from PyTorch 0.1.10 through current versions are //! supported, though full model saves (vs state_dict) may have limitations //! as they contain Python code references. use crate::TensorSnapshot; use alloc::string::{String, ToString}; use alloc::vec::Vec; use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer}; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Seek, SeekFrom}; use std::path::Path; use super::lazy_data::LazyDataSource; use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data}; use std::sync::Arc; /// Error type for PyTorch file operations #[derive(Debug)] pub enum PytorchError { /// IO error Io(std::io::Error), /// Pickle parsing error Pickle(PickleError), /// Zip archive error Zip(zip::result::ZipError), /// TAR archive error Tar(std::io::Error), /// Invalid file format InvalidFormat(String), /// Key not found KeyNotFound(String), /// Serde deserialization error Serde(burn_core::record::serde::error::Error), } impl From for PytorchError { fn from(e: std::io::Error) -> Self { PytorchError::Io(e) } } impl From for PytorchError { fn from(e: PickleError) -> Self { PytorchError::Pickle(e) } } impl From for PytorchError { fn from(e: zip::result::ZipError) -> Self { PytorchError::Zip(e) } } impl From for PytorchError { fn from(e: burn_core::record::serde::error::Error) -> Self { PytorchError::Serde(e) } } impl std::fmt::Display for PytorchError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { PytorchError::Io(e) => write!(f, "IO error: {}", e), PytorchError::Pickle(e) => write!( f, "Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.", e ), PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e), PytorchError::Tar(e) => write!(f, "TAR archive error: {}", e), PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg), PytorchError::KeyNotFound(key) => write!( f, "Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.", key ), PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e), } } } impl std::error::Error for PytorchError {} type Result = std::result::Result; /// Metadata about a PyTorch file /// /// Contains information about the file format, version, and other properties /// that can be useful for debugging or compatibility checking. #[derive(Debug, Clone)] pub struct PytorchMetadata { /// Format version (e.g., "1.0" for modern ZIP format) pub format_version: Option, /// File format type (ZIP, Legacy, or Pickle) pub format_type: FileFormat, /// Byte order (endianness) - currently only LittleEndian is supported pub byte_order: ByteOrder, /// Whether the file has storage alignment information pub has_storage_alignment: bool, /// PyTorch version that saved the file (if available) pub pytorch_version: Option, /// Number of tensors in the file pub tensor_count: usize, /// Total size of tensor data in bytes (if available) pub total_data_size: Option, } impl PytorchMetadata { /// Check if this is a modern format file (ZIP-based, PyTorch 1.6+) pub fn is_modern_format(&self) -> bool { matches!(self.format_type, FileFormat::Zip) } /// Check if this is a legacy format file (PyTorch 0.1.10 - 1.5) pub fn is_legacy_format(&self) -> bool { matches!(self.format_type, FileFormat::Legacy) } } /// File format type #[derive(Debug, Clone, PartialEq)] pub enum FileFormat { /// ZIP-based format (PyTorch 1.6+) Zip, /// TAR-based format (older torchvision models) Tar, /// Legacy format (PyTorch 0.1.10 - 1.5) Legacy, /// Simple pickle file Pickle, } /// Byte order (endianness) #[derive(Debug, Clone, PartialEq)] pub enum ByteOrder { LittleEndian, BigEndian, } /// PyTorch checkpoint reader /// /// This is the main interface for reading PyTorch checkpoint files (.pt/.pth). /// It supports multiple PyTorch formats including modern ZIP-based format (1.6+), /// legacy format (0.1.10-1.5), and simple pickle files. /// /// # Example /// ```rust,no_run /// # use burn_store::pytorch::PytorchReader; /// # fn example() -> Result<(), Box> { /// // Load a checkpoint file /// let reader = PytorchReader::new("model.pt")?; /// /// // Get tensor names /// let keys = reader.keys(); /// /// // Access a specific tensor /// if let Some(tensor) = reader.get("conv1.weight") { /// let data = tensor.to_data(); // Materializes the tensor /// } /// /// // Check file metadata /// println!("Format: {:?}", reader.metadata().format_type); /// println!("Tensor count: {}", reader.metadata().tensor_count); /// # Ok(()) /// # } /// ``` pub struct PytorchReader { tensors: HashMap, metadata: PytorchMetadata, } impl PytorchReader { /// Load a PyTorch checkpoint file /// /// # Arguments /// * `path` - Path to the PyTorch file (.pt or .pth) /// /// # Returns /// A `PytorchReader` with lazy-loaded tensors and metadata pub fn new>(path: P) -> Result { let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?; Ok(Self { tensors, metadata }) } /// Load a PyTorch checkpoint with a specific top-level key /// /// Many PyTorch checkpoints store the model weights under a specific key /// like "state_dict", "model", or "model_state_dict". /// /// # Arguments /// * `path` - Path to the PyTorch file /// * `key` - Top-level key to extract (e.g., "state_dict") /// /// # Example /// ```rust,no_run /// # use burn_store::pytorch::PytorchReader; /// # fn example() -> Result<(), Box> { /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?; /// # Ok(()) /// # } /// ``` pub fn with_top_level_key>(path: P, key: &str) -> Result { let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?; Ok(Self { tensors, metadata }) } /// Load from a reader /// /// This method is useful when loading from non-file sources like memory buffers. /// Note: Metadata detection is limited when loading from a reader. /// /// # Arguments /// * `reader` - Any type implementing `Read` /// * `top_level_key` - Optional key to extract pub fn from_reader(reader: R, top_level_key: Option<&str>) -> Result { // For reader-based loading, we don't have full metadata access let tensors = load_from_reader(reader, top_level_key)?; let metadata = PytorchMetadata { format_version: None, format_type: FileFormat::Pickle, // Default assumption byte_order: ByteOrder::LittleEndian, has_storage_alignment: false, pytorch_version: None, tensor_count: tensors.len(), total_data_size: None, }; Ok(Self { tensors, metadata }) } /// Get all tensor names pub fn keys(&self) -> Vec { self.tensors.keys().cloned().collect() } /// Get a tensor by name pub fn get(&self, name: &str) -> Option<&TensorSnapshot> { self.tensors.get(name) } /// Get all tensors pub fn tensors(&self) -> &HashMap { &self.tensors } /// Take ownership of all tensors pub fn into_tensors(self) -> HashMap { self.tensors } /// Get metadata about the loaded file /// /// Provides information about the file format, version, endianness, etc. pub fn metadata(&self) -> &PytorchMetadata { &self.metadata } /// Get the number of tensors in the file pub fn len(&self) -> usize { self.tensors.len() } /// Check if the file contains no tensors pub fn is_empty(&self) -> bool { self.tensors.is_empty() } /// Read raw pickle data from a PyTorch file /// /// This is useful for extracting configuration or metadata that isn't tensor data. /// Returns a simplified JSON-like structure that can be easily converted to other formats. /// /// # Arguments /// * `path` - Path to the PyTorch file /// * `top_level_key` - Optional key to extract from the top-level dictionary /// /// # Returns /// A `PickleValue` representing the pickle data structure pub fn read_pickle_data>( path: P, top_level_key: Option<&str>, ) -> Result { read_pickle_as_value(path.as_ref(), top_level_key) } /// Load and deserialize configuration data from a PyTorch file /// /// This method reads configuration or metadata stored in PyTorch checkpoint files /// and deserializes it into the specified type. It's particularly useful for /// extracting model configurations that might be saved alongside model weights. /// /// # Arguments /// * `path` - Path to the PyTorch file (.pt or .pth) /// * `top_level_key` - Optional key to extract specific data within the pickle file. /// If `None`, the entire content is deserialized. /// /// # Type Parameters /// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`. /// /// # Returns /// A `Result` containing the deserialized configuration data, or an `Error` if /// reading or deserialization fails. /// /// # Example /// ```rust,no_run /// # use burn_store::pytorch::PytorchReader; /// # use serde::Deserialize; /// # fn example() -> Result<(), Box> { /// #[derive(Debug, Deserialize)] /// struct ModelConfig { /// hidden_size: usize, /// num_layers: usize, /// } /// /// let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?; /// # Ok(()) /// # } /// ``` pub fn load_config(path: P, top_level_key: Option<&str>) -> Result where D: DeserializeOwned, P: AsRef, { // Read the PyTorch file and extract the pickle data let pickle_value = Self::read_pickle_data(path, top_level_key)?; // Convert PickleValue to NestedValue let nested_value = convert_pickle_to_nested_value(pickle_value)?; // Create a deserializer with the default adapter let deserializer = Deserializer::::new(nested_value, false); // Deserialize the nested value into the target type let value = D::deserialize(deserializer)?; Ok(value) } } /// Simplified representation of pickle data /// /// This enum provides a JSON-like structure that's easier to work with /// than the internal pickle Object type. #[derive(Debug, Clone, PartialEq)] pub enum PickleValue { /// None/null value None, /// Boolean value Bool(bool), /// Integer value Int(i64), /// Floating point value Float(f64), /// String value String(String), /// List/array of values List(Vec), /// Dictionary/map of string keys to values Dict(HashMap), /// Binary data Bytes(Vec), } /// Internal function to load a PyTorch file with metadata fn load_pytorch_file_with_metadata( path: &Path, top_level_key: Option<&str>, ) -> Result<(HashMap, PytorchMetadata)> { // First, try to read as a zip file if let Ok(file) = File::open(path) && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file)) { // PyTorch saves the main data in various locations within the zip let mut pickle_data = Vec::new(); let mut pickle_found = false; // Try different common pickle file locations let possible_pickle_paths = [ "data.pkl", "archive/data.pkl", // Look for any .pkl file in the root or first-level directories ]; for pickle_path in &possible_pickle_paths { if archive.by_name(pickle_path).is_ok() { let mut pickle_file = archive.by_name(pickle_path)?; pickle_file.read_to_end(&mut pickle_data)?; pickle_found = true; break; } } // If not found in standard locations, search for any .pkl file if !pickle_found { for i in 0..archive.len() { let file = archive.by_index(i)?; let name = file.name().to_string(); drop(file); // Release the borrow if name.ends_with("data.pkl") { let mut file = archive.by_index(i)?; file.read_to_end(&mut pickle_data)?; pickle_found = true; break; } } } if !pickle_found { return Err(PytorchError::InvalidFormat( "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(), )); } // Check for format version (optional) let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") { let mut version_data = Vec::new(); version_file.read_to_end(&mut version_data)?; let version_str = String::from_utf8_lossy(&version_data); let version = version_str.trim().to_string(); Some(version) } else { None }; // Check for byteorder file to detect endianness let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") { let mut byteorder_data = Vec::new(); byteorder_file.read_to_end(&mut byteorder_data)?; let byteorder_str = String::from_utf8_lossy(&byteorder_data); byteorder_str.trim() == "big" } else { false // Default to little-endian if no byteorder file }; if is_big_endian { // Big-endian files are not yet supported as they require different byte order conversion // TODO: To support big-endian files, we need to: // 1. Pass endianness info through to pickle_reader // 2. Use from_be_bytes instead of from_le_bytes for tensor data // 3. Handle byte swapping for all numeric types (f32, f64, i32, etc.) return Err(PytorchError::InvalidFormat( "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string() )); } // Check for storage alignment file let has_storage_alignment = archive.by_name(".storage_alignment").is_ok(); // Check for PyTorch version (if saved) let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") { let mut version_data = Vec::new(); version_file.read_to_end(&mut version_data)?; Some(String::from_utf8_lossy(&version_data).trim().to_string()) } else { None }; // Create a lazy data source instead of loading all data upfront let data_source = Arc::new(LazyDataSource::from_zip(path)?); // Calculate total data size without loading let mut total_data_size = 0usize; for i in 0..archive.len() { let file = archive.by_index(i)?; let name = file.name(); // Look for data files - they can be in various locations let is_data_file = (name.contains("/data/") || name.starts_with("data/") || name.starts_with("archive/data/")) && !name.ends_with(".pkl") && !name.ends_with("/"); if is_data_file { total_data_size += file.size() as usize; } } // Parse the pickle data with lazy data source let mut pickle_reader = BufReader::new(pickle_data.as_slice()); let obj = read_pickle_with_data(&mut pickle_reader, data_source)?; // Extract tensors with their data let tensors = extract_tensors_with_data(obj, top_level_key)?; // Create metadata let metadata = PytorchMetadata { format_version, format_type: FileFormat::Zip, byte_order: if is_big_endian { ByteOrder::BigEndian } else { ByteOrder::LittleEndian }, has_storage_alignment, pytorch_version, tensor_count: tensors.len(), total_data_size: Some(total_data_size), }; return Ok((tensors, metadata)); } // If not a zip or zip reading failed, try TAR format if is_tar_file(path) { return load_tar_pytorch_file_with_metadata(path, top_level_key); } // Try reading as a plain pickle file let mut file = File::open(path)?; // Check for PyTorch legacy format (starts with magic number as pickled integer) let mut header = [0u8; 15]; // Use read() instead of read_exact() to handle files smaller than 15 bytes let bytes_read = file.read(&mut header)?; file.seek(std::io::SeekFrom::Start(0))?; // Only check for legacy format if we have enough bytes // PyTorch legacy format detection (PyTorch 0.1.10 - 1.3) // Reference: https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L65 // // These files use sequential pickle streams with metadata before the actual data. // Format structure: // 1. Magic number (0x1950a86a20f9469cfc6c) stored as LONG1 pickle // 2. Protocol version (e.g., 1001) // 3. System info dict (protocol_version, little_endian, type_sizes) // 4. Actual model data (state_dict or full model) // 5. Storage keys list (pickle) // 6. Raw binary data for each storage // // The pattern is: 0x80 0x02 0x8a 0x0a (PROTO 2, LONG1 with 10 bytes) // followed by 10 bytes of magic number (little-endian), then 0x2e (STOP) let is_legacy_format = bytes_read >= 15 && header[0] == 0x80 // PROTO opcode && header[1] == 0x02 // Protocol version 2 && header[2] == 0x8a // LONG1 opcode && header[3] == 0x0a // 10 bytes follow // Magic number 0x1950a86a20f9469cfc6c in little-endian && header[4] == 0x6c && header[5] == 0xfc && header[6] == 0x9c && header[7] == 0x46 && header[8] == 0xf9 && header[9] == 0x20 && header[10] == 0x6a && header[11] == 0xa8 && header[12] == 0x50 && header[13] == 0x19 && header[14] == 0x2e; // STOP opcode if is_legacy_format { return load_legacy_pytorch_file_with_metadata(path, top_level_key); } // Standard pickle file // This might be a pickle with tensor references, so we need to handle that case // For plain pickle files without a separate data section, we can't use lazy loading // so we'll just create empty placeholder tensors for the structure let file = File::open(path)?; let mut reader = BufReader::new(file); // Try reading without data source first match read_pickle(&mut reader) { Ok(obj) => { let tensors = extract_tensors_with_data(obj, top_level_key)?; let tensor_count = tensors.len(); Ok(( tensors, PytorchMetadata { format_version: None, format_type: FileFormat::Pickle, byte_order: ByteOrder::LittleEndian, has_storage_alignment: false, pytorch_version: None, tensor_count, total_data_size: None, }, )) } Err(e) if e.to_string() .contains("Cannot load tensor data without a data source") => { // This pickle file contains tensor data but we're trying to read it without // providing a data source. This shouldn't happen in normal usage as PyTorch // files with actual tensor data should be in ZIP or legacy format. Err(PytorchError::InvalidFormat( "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string() )) } Err(e) => Err(PytorchError::Pickle(e)), } } /// Load from a reader fn load_from_reader( reader: R, top_level_key: Option<&str>, ) -> Result> { let mut buf_reader = BufReader::new(reader); // Try reading without data source match read_pickle(&mut buf_reader) { Ok(obj) => extract_tensors_with_data(obj, top_level_key), Err(e) if e.to_string() .contains("Cannot load tensor data without a data source") => { // This reader contains tensor data but we can't load it without a file path Err(PytorchError::InvalidFormat( "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string() )) } Err(e) => Err(PytorchError::Pickle(e)), } } /// Extract tensors from a parsed pickle object fn extract_tensors_with_data( obj: Object, top_level_key: Option<&str>, ) -> Result> { let dict = match obj { Object::Dict(dict) => { if let Some(key) = top_level_key { // Extract the nested dictionary if a top-level key is specified match dict.get(key) { Some(Object::Dict(nested)) => nested.clone(), _ => { return Err(PytorchError::KeyNotFound(format!( "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}", key, dict.keys().collect::>() ))); } } } else { dict } } _ => { return Err(PytorchError::InvalidFormat( "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(), )); } }; let mut tensors = HashMap::new(); let mut path = Vec::new(); extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors); Ok(tensors) } /// Recursively extract tensors from an object fn extract_tensors_recursive<'a>( obj: &'a Object, path: &mut Vec<&'a str>, tensors: &mut HashMap, ) { match obj { Object::Dict(dict) => { for (key, value) in dict { path.push(key); extract_tensors_recursive(value, path, tensors); path.pop(); } } Object::TorchParam(snapshot) => { // The TensorSnapshot already contains the data loading closure // Only allocate the string here when we actually insert tensors.insert(path.join("."), snapshot.clone()); } _ => {} } } /// Load a legacy PyTorch file with metadata fn load_legacy_pytorch_file_with_metadata( path: &Path, top_level_key: Option<&str>, ) -> Result<(HashMap, PytorchMetadata)> { let file = File::open(path)?; let mut reader = BufReader::new(file); // Skip metadata pickles // 1. Magic number let _ = read_pickle(&mut reader).map_err(|e| { PytorchError::InvalidFormat(format!( "Failed to read magic number from legacy format: {}", e )) })?; // 2. Protocol version let _ = read_pickle(&mut reader).map_err(|e| { PytorchError::InvalidFormat(format!( "Failed to read protocol version from legacy format: {}", e )) })?; // 3. System info let _ = read_pickle(&mut reader).map_err(|e| { PytorchError::InvalidFormat(format!( "Failed to read system info from legacy format: {}", e )) })?; // Save position before main pickle let main_pickle_pos = reader.stream_position()?; // 4. Skip main object - it might contain tensors so we can't parse it yet // We'll re-read it with a data source later use crate::pytorch::pickle_reader::skip_pickle; skip_pickle(&mut reader).map_err(|e| { PytorchError::InvalidFormat(format!( "Failed to skip main object in legacy format: {}", e )) })?; // 5. Storage keys list (sorted keys as written by PyTorch) let storage_keys = match read_pickle(&mut reader) { Ok(Object::List(keys)) => keys .into_iter() .filter_map(|obj| match obj { Object::String(s) => Some(s), _ => None, }) .collect::>(), _ => vec![], }; // 6. Skip 8-byte header before raw binary data // PyTorch legacy format has an 8-byte header (possibly protocol version or alignment) // between the storage keys list and the actual tensor data let mut header = [0u8; 8]; if reader.read(&mut header).is_ok() { // Header read successfully, data starts after this } // 7. Raw binary data starts here let data_start_pos = reader.stream_position()?; let file_size = reader.seek(SeekFrom::End(0))?; let data_size = file_size - data_start_pos; // Create a lazy data source for legacy multi-storage format let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage( path, data_start_pos, data_size, )); // Set storage keys BEFORE parsing the main pickle // This is critical because track_storage_usage() is called during parsing // and it needs storage_keys to build the storage map if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source && !storage_keys.is_empty() { let source = source .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); source.set_storage_keys(storage_keys.clone()); } // Now re-read the main pickle with lazy data source reader.seek(SeekFrom::Start(main_pickle_pos))?; let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?; // Extract tensors normally let tensors = extract_tensors_with_data(main_obj, top_level_key)?; // Create metadata for legacy format let metadata = PytorchMetadata { format_version: None, // Legacy format doesn't have version files format_type: FileFormat::Legacy, byte_order: ByteOrder::LittleEndian, // Legacy format is little-endian has_storage_alignment: false, pytorch_version: None, // Could parse from protocol version, but not reliable tensor_count: tensors.len(), total_data_size: Some(data_size as usize), }; Ok((tensors, metadata)) } /// Check if a file is a TAR archive fn is_tar_file(path: &Path) -> bool { if let Ok(mut file) = File::open(path) { // TAR files have "ustar" magic at offset 257 let mut header = [0u8; 263]; if file.read_exact(&mut header).is_ok() { // Check for "ustar" magic at offset 257 return &header[257..262] == b"ustar"; } } false } /// Load a TAR format PyTorch file with metadata fn load_tar_pytorch_file_with_metadata( path: &Path, top_level_key: Option<&str>, ) -> Result<(HashMap, PytorchMetadata)> { use tar::Archive; let file = File::open(path)?; let mut archive = Archive::new(BufReader::new(file)); // Extract the main entries from the TAR archive let mut sys_info_data: Option> = None; let mut pickle_data: Option> = None; let mut storages_data: Option> = None; for entry in archive.entries().map_err(PytorchError::Tar)? { let mut entry = entry.map_err(PytorchError::Tar)?; let entry_path = entry .path() .map_err(PytorchError::Tar)? .to_string_lossy() .to_string(); // Skip PAX headers if entry_path.contains("@PaxHeader") { continue; } // Normalize path (remove ./ prefix if present) let normalized = entry_path.trim_start_matches("./"); match normalized { "sys_info" => { let mut data = Vec::new(); entry.read_to_end(&mut data).map_err(PytorchError::Tar)?; sys_info_data = Some(data); } "pickle" => { let mut data = Vec::new(); entry.read_to_end(&mut data).map_err(PytorchError::Tar)?; pickle_data = Some(data); } "storages" => { let mut data = Vec::new(); entry.read_to_end(&mut data).map_err(PytorchError::Tar)?; storages_data = Some(data); } _ => {} } } // Validate required entries let pickle_data = pickle_data.ok_or_else(|| { PytorchError::InvalidFormat("TAR file missing 'pickle' entry".to_string()) })?; let storages_data = storages_data.ok_or_else(|| { PytorchError::InvalidFormat("TAR file missing 'storages' entry".to_string()) })?; // Parse sys_info to check endianness let is_little_endian = if let Some(ref data) = sys_info_data { parse_tar_sys_info(data)? } else { true // Default to little-endian }; if !is_little_endian { return Err(PytorchError::InvalidFormat( "Big-endian TAR PyTorch files are not supported".to_string(), )); } // Create TarSource for lazy loading let data_source = Arc::new(LazyDataSource::from_tar(&storages_data)?); // Parse the pickle (OrderedDict of name -> storage_key) let mut pickle_reader = BufReader::new(pickle_data.as_slice()); let obj = read_pickle_with_data(&mut pickle_reader, data_source)?; // Extract tensors let tensors = extract_tensors_with_data(obj, top_level_key)?; let metadata = PytorchMetadata { format_version: None, format_type: FileFormat::Tar, byte_order: ByteOrder::LittleEndian, has_storage_alignment: false, pytorch_version: None, tensor_count: tensors.len(), total_data_size: Some(storages_data.len()), }; Ok((tensors, metadata)) } /// Parse sys_info pickle from TAR format to extract endianness fn parse_tar_sys_info(data: &[u8]) -> Result { let mut reader = BufReader::new(data); let obj = read_pickle(&mut reader)?; if let Object::Dict(dict) = obj && let Some(Object::Bool(little_endian)) = dict.get("little_endian") { return Ok(*little_endian); } Ok(true) // Default assumption } /// Read pickle data from a PyTorch file as a simplified value fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result { use crate::pytorch::lazy_data::LazyDataSource; use crate::pytorch::pickle_reader::{read_pickle, read_pickle_with_data}; use std::sync::Arc; // Try to open as ZIP first if let Ok(file) = File::open(path) && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file)) { // Read pickle data from ZIP let mut pickle_data = Vec::new(); // Try standard locations for pickle_path in &["data.pkl", "archive/data.pkl"] { if let Ok(mut pickle_file) = archive.by_name(pickle_path) { pickle_file.read_to_end(&mut pickle_data)?; break; } } // If not found, search for any .pkl file if pickle_data.is_empty() { for i in 0..archive.len() { let file = archive.by_index(i)?; let name = file.name().to_string(); drop(file); if name.ends_with("data.pkl") { let mut file = archive.by_index(i)?; file.read_to_end(&mut pickle_data)?; break; } } } if !pickle_data.is_empty() { // Create a data source for the ZIP file let data_source = LazyDataSource::from_zip(path)?; let data_source_arc = Arc::new(data_source); let mut reader = BufReader::new(pickle_data.as_slice()); let obj = read_pickle_with_data(&mut reader, data_source_arc)?; return convert_object_to_value(obj, top_level_key); } } // Try as plain pickle file // First attempt without data source (for pure metadata files) let file = File::open(path)?; let mut reader = BufReader::new(file); match read_pickle(&mut reader) { Ok(obj) => convert_object_to_value(obj, top_level_key), Err(e) if e.to_string() .contains("Cannot load tensor data without a data source") => { // File contains tensors, need to use full PytorchReader // Use the regular reader to get proper tensor handling let reader = PytorchReader::new(path)?; // Convert tensors to PickleValue structure let mut result = std::collections::HashMap::new(); for key in reader.keys() { // For pickle value extraction, we just need the structure, not the actual data result.insert( key.clone(), PickleValue::String(format!("", key)), ); } if let Some(key) = top_level_key { Ok(PickleValue::Dict( [(key.to_string(), PickleValue::Dict(result))] .into_iter() .collect(), )) } else { Ok(PickleValue::Dict(result)) } } Err(e) => Err(PytorchError::Pickle(e)), } } /// Convert internal Object to public PickleValue fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result { use crate::pytorch::pickle_reader::Object; // If a top-level key is specified, extract it first if let Some(key) = top_level_key && let Object::Dict(dict) = obj { if let Some(value) = dict.get(key) { return object_to_pickle_value(value.clone()); } else { return Err(PytorchError::KeyNotFound(format!( "Key '{}' not found in pickle data", key ))); } } object_to_pickle_value(obj) } /// Convert Object to PickleValue fn object_to_pickle_value(obj: Object) -> Result { use crate::pytorch::pickle_reader::Object; Ok(match obj { Object::None => PickleValue::None, Object::Bool(b) => PickleValue::Bool(b), Object::Int(i) => PickleValue::Int(i), Object::Float(f) => PickleValue::Float(f), Object::String(s) => PickleValue::String(s), Object::Persistent(data) => { // Persistent data is raw bytes PickleValue::Bytes(data) } Object::PersistentTuple(tuple) => { // Convert persistent tuples to lists let mut values = Vec::new(); for item in tuple { values.push(object_to_pickle_value(item)?); } PickleValue::List(values) } Object::List(list) => { let mut values = Vec::new(); for item in list { values.push(object_to_pickle_value(item)?); } PickleValue::List(values) } Object::Dict(dict) => { let mut map = HashMap::new(); for (k, v) in dict { map.insert(k, object_to_pickle_value(v)?); } PickleValue::Dict(map) } Object::Tuple(tuple) => { // Convert tuples to lists in the public API let mut values = Vec::new(); for item in tuple { values.push(object_to_pickle_value(item)?); } PickleValue::List(values) } Object::TorchParam(_) => { // Skip tensor parameters in config reading PickleValue::None } Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => { // Complex objects are represented as None for simplicity PickleValue::None } }) } /// Convert PickleValue to NestedValue for deserialization fn convert_pickle_to_nested_value(value: PickleValue) -> Result { Ok(match value { PickleValue::None => NestedValue::Default(None), PickleValue::Bool(b) => NestedValue::Bool(b), PickleValue::Int(i) => NestedValue::I64(i), PickleValue::Float(f) => NestedValue::F64(f), PickleValue::String(s) => NestedValue::String(s), PickleValue::List(list) => { let mut vec = Vec::new(); for item in list { vec.push(convert_pickle_to_nested_value(item)?); } NestedValue::Vec(vec) } PickleValue::Dict(dict) => { let mut map = HashMap::new(); for (k, v) in dict { map.insert(k, convert_pickle_to_nested_value(v)?); } NestedValue::Map(map) } PickleValue::Bytes(data) => { // Convert bytes to a list of u8 values let vec: Vec = data.into_iter().map(NestedValue::U8).collect(); NestedValue::Vec(vec) } }) } ================================================ FILE: crates/burn-store/src/pytorch/store.rs ================================================ //! PyTorch store implementation for saving and loading models in PyTorch format. use crate::{ ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter, TensorSnapshot, map_indices_contiguous, }; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec::Vec; use burn_tensor::backend::Backend; use core::fmt; use std::path::PathBuf; use super::reader::{PytorchError as ReaderError, PytorchReader}; /// Errors that can occur during PyTorch operations. #[derive(Debug)] pub enum PytorchStoreError { /// Reader error. Reader(ReaderError), /// I/O error. Io(std::io::Error), /// Tensor not found. TensorNotFound(String), /// Validation failed. ValidationFailed(String), /// Other error. Other(String), } impl fmt::Display for PytorchStoreError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Reader(e) => write!(f, "PyTorch reader error: {}", e), Self::Io(e) => write!(f, "I/O error: {}", e), Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name), Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg), Self::Other(msg) => write!(f, "{}", msg), } } } impl std::error::Error for PytorchStoreError {} impl From for PytorchStoreError { fn from(e: ReaderError) -> Self { PytorchStoreError::Reader(e) } } impl From for PytorchStoreError { fn from(e: std::io::Error) -> Self { PytorchStoreError::Io(e) } } /// PyTorch store for file-based storage only. /// /// This store allows loading models from PyTorch checkpoint files (.pt/.pth) /// with automatic weight transformation using `PyTorchToBurnAdapter`. /// Linear weights are automatically transposed and normalization parameters /// are renamed (gamma -> weight, beta -> bias). /// /// Note that saving to PyTorch format is not yet supported. pub struct PytorchStore { pub(crate) path: PathBuf, pub(crate) filter: PathFilter, pub(crate) remapper: KeyRemapper, pub(crate) validate: bool, pub(crate) allow_partial: bool, pub(crate) top_level_key: Option, pub(crate) skip_enum_variants: bool, /// Enable contiguous mapping of layer indices (default: true) pub(crate) map_indices_contiguous: bool, /// Cached tensor snapshots (parsed once, reused) snapshots_cache: Option>, } impl PytorchStore { /// Create a store for loading from a PyTorch file. /// /// # Arguments /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth) /// /// # Example /// ```rust,no_run /// use burn_store::PytorchStore; /// /// let store = PytorchStore::from_file("model.pth"); /// ``` pub fn from_file(path: impl Into) -> Self { Self { path: path.into(), filter: PathFilter::new(), remapper: KeyRemapper::new(), validate: true, allow_partial: false, top_level_key: None, // PyTorch models never include enum variant names in paths skip_enum_variants: true, // Enable contiguous index mapping by default for PyTorch files // This handles nn.Sequential models with gaps in layer indices map_indices_contiguous: true, snapshots_cache: None, } } /// Set a top-level key to extract tensors from. /// /// PyTorch files often contain nested dictionaries. Use this to extract /// tensors from a specific top-level key like "state_dict" or "model_state_dict". /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("checkpoint.pth") /// .with_top_level_key("model_state_dict"); /// ``` pub fn with_top_level_key(mut self, key: impl Into) -> Self { self.top_level_key = Some(key.into()); self } /// Filter which tensors to load. pub fn filter(mut self, filter: PathFilter) -> Self { self.filter = filter; self } /// Add a regex pattern to filter tensors. /// /// Multiple patterns can be added and they work with OR logic. /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.weight$"); // OR match any weight tensors /// ``` pub fn with_regex>(mut self, pattern: S) -> Self { self.filter = self.filter.with_regex(pattern); self } /// Add multiple regex patterns to filter tensors. pub fn with_regexes(mut self, patterns: I) -> Self where I: IntoIterator, S: AsRef, { self.filter = self.filter.with_regexes(patterns); self } /// Add an exact full path to match. /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_full_path("encoder.layer1.weight") /// .with_full_path("decoder.output.bias"); /// ``` pub fn with_full_path>(mut self, path: S) -> Self { self.filter = self.filter.with_full_path(path); self } /// Add multiple exact full paths to match. pub fn with_full_paths(mut self, paths: I) -> Self where I: IntoIterator, S: Into, { self.filter = self.filter.with_full_paths(paths); self } /// Add a predicate function for custom filtering logic. /// /// The predicate receives the tensor path and container path. /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias")); /// ``` pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self { self.filter = self.filter.with_predicate(predicate); self } /// Add multiple predicate functions. pub fn with_predicates(mut self, predicates: I) -> Self where I: IntoIterator bool>, { self.filter = self.filter.with_predicates(predicates); self } /// Set the filter to match all paths (disables filtering). pub fn match_all(mut self) -> Self { self.filter = self.filter.match_all(); self } /// Remap tensor names during load. pub fn remap(mut self, remapper: KeyRemapper) -> Self { self.remapper = remapper; self } /// Add a regex pattern to remap tensor names during load. /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// let store = PytorchStore::from_file("model.pth") /// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X /// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight /// ``` pub fn with_key_remapping( mut self, from_pattern: impl AsRef, to_pattern: impl Into, ) -> Self { self.remapper = self .remapper .add_pattern(from_pattern, to_pattern) .expect("Invalid regex pattern"); self } /// Set whether to validate tensors during loading (default: true). pub fn validate(mut self, validate: bool) -> Self { self.validate = validate; self } /// Allow partial loading of tensors (continue even if some tensors are missing). pub fn allow_partial(mut self, allow: bool) -> Self { self.allow_partial = allow; self } /// Skip enum variant names when matching tensor paths (default: true). /// /// When enabled, tensor paths from PyTorch that don't include enum variants /// can be matched against Burn module paths that do include them. /// For example, PyTorch path "feature.weight" can match Burn path "feature.BaseConv.weight". /// /// This defaults to `true` for PytorchStore since PyTorch models never include /// enum variant names in their parameter paths. /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// // Disable enum variant skipping (not typical) /// let store = PytorchStore::from_file("model.pth") /// .skip_enum_variants(false); /// ``` pub fn skip_enum_variants(mut self, skip: bool) -> Self { self.skip_enum_variants = skip; self } /// Enable or disable automatic contiguous mapping of layer indices (default: true). /// /// When enabled, non-contiguous numeric indices in tensor paths are renumbered /// to be contiguous. This is useful when loading PyTorch models that have gaps /// in layer numbering, such as when using `nn.Sequential` with mixed layer types /// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5). /// /// # Example /// /// With index mapping enabled (default): /// - `fc.0.weight` → `fc.0.weight` /// - `fc.2.weight` → `fc.1.weight` (gap filled) /// - `fc.4.weight` → `fc.2.weight` (gap filled) /// /// # Arguments /// /// * `map` - `true` to enable contiguous index mapping, `false` to disable /// /// # Example /// ```rust,no_run /// # use burn_store::PytorchStore; /// // Disable contiguous index mapping if your model already has contiguous indices /// let store = PytorchStore::from_file("model.pth") /// .map_indices_contiguous(false); /// ``` pub fn map_indices_contiguous(mut self, map: bool) -> Self { self.map_indices_contiguous = map; self } /// Apply remapping to tensor snapshots. fn apply_remapping(&self, snapshots: Vec) -> Vec { if self.remapper.is_empty() { return snapshots; } let (remapped, _) = self.remapper.remap(snapshots); remapped } /// Create a PytorchReader for the configured path and options. fn create_reader(&self) -> Result { let reader = if let Some(ref key) = self.top_level_key { PytorchReader::with_top_level_key(&self.path, key)? } else { PytorchReader::new(&self.path)? }; Ok(reader) } } impl ModuleStore for PytorchStore { type Error = PytorchStoreError; fn collect_from>( &mut self, _module: &M, ) -> Result<(), Self::Error> { // Saving to PyTorch format is not yet supported Err(PytorchStoreError::Other( "Saving to PyTorch format is not yet supported. Use other formats for saving." .to_string(), )) } fn apply_to>( &mut self, module: &mut M, ) -> Result { // Get snapshots from cache let snapshots: Vec = self.get_all_snapshots()?.values().cloned().collect(); // Get filter (convert to Option for apply) let filter_opt = if self.filter.is_empty() { None } else { Some(self.filter.clone()) }; // Apply to module with PyTorchToBurnAdapter (always used for PyTorch files) // This adapter handles: // - Transposing linear weights from PyTorch format to Burn format // - Renaming normalization parameters (gamma -> weight, beta -> bias) // Filter is applied here during apply, not during cache population let result = module.apply( snapshots, filter_opt, Some(Box::new(PyTorchToBurnAdapter)), self.skip_enum_variants, ); // Validate if needed if self.validate && !result.errors.is_empty() { return Err(PytorchStoreError::ValidationFailed(format!( "Import errors:\n{}", result ))); } if !self.allow_partial && !result.missing.is_empty() { return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result))); } Ok(result) } fn get_snapshot(&mut self, name: &str) -> Result, Self::Error> { self.ensure_snapshots_cache()?; Ok(self.snapshots_cache.as_ref().unwrap().get(name)) } fn get_all_snapshots(&mut self) -> Result<&BTreeMap, Self::Error> { self.ensure_snapshots_cache()?; Ok(self.snapshots_cache.as_ref().unwrap()) } fn keys(&mut self) -> Result, Self::Error> { // Always use the cache to ensure remapping is applied consistently Ok(self.get_all_snapshots()?.keys().cloned().collect()) } } impl PytorchStore { /// Ensure the snapshots cache is populated fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> { if self.snapshots_cache.is_some() { return Ok(()); } let reader = self.create_reader()?; // Convert to tensor snapshots let mut snapshots: Vec = reader .into_tensors() .into_iter() .map(|(key, mut snapshot)| { // Parse the key into path parts (split by '.') let path_parts: Vec = key.split('.').map(|s| s.to_string()).collect(); // Set the path stack from the key snapshot.path_stack = Some(path_parts); snapshot.container_stack = None; snapshot.tensor_id = None; snapshot }) .collect(); // Apply remapping (but NOT filtering - that's done at apply time) snapshots = self.apply_remapping(snapshots); // Apply contiguous index mapping if enabled // This must be done after remapping so that remapped paths are mapped if self.map_indices_contiguous { let (mapped, _) = map_indices_contiguous(snapshots); snapshots = mapped; } // Build cache as BTreeMap let cache: BTreeMap = snapshots.into_iter().map(|s| (s.full_path(), s)).collect(); self.snapshots_cache = Some(cache); Ok(()) } } ================================================ FILE: crates/burn-store/src/pytorch/tests/mod.rs ================================================ pub mod reader; pub mod store; ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/create_legacy_with_offsets.py ================================================ #!/usr/bin/env python3 # /// script # dependencies = ["torch"] # /// """Create a legacy format PyTorch file with specific storage offsets to test offset handling.""" import torch # Create tensors with known values at specific storage offsets # This will help us verify we're reading from the correct location # Create a state dict with tensors that share storage # This is common in PyTorch models (e.g., weight and transposed weight views) state_dict = {} # Create a base tensor with known pattern base_data = torch.arange(100, dtype=torch.float32) # tensor1: uses elements 10-19 (offset 10*4 = 40 bytes) tensor1 = base_data[10:20].clone() tensor1[:] = torch.arange(1.0, 1.1, 0.01)[:10] # 1.00, 1.01, 1.02, ... # tensor2: uses elements 30-35 (offset 30*4 = 120 bytes) tensor2 = base_data[30:35].clone() tensor2[:] = torch.arange(2.0, 2.5, 0.1)[:5] # 2.0, 2.1, 2.2, 2.3, 2.4 # tensor3: starts at beginning (offset 0) tensor3 = base_data[:5].clone() tensor3[:] = torch.arange(3.0, 3.5, 0.1)[:5] # 3.0, 3.1, 3.2, 3.3, 3.4 state_dict['tensor1'] = tensor1 state_dict['tensor2'] = tensor2 state_dict['tensor3'] = tensor3 # Save in legacy format output_file = 'test_data/legacy_with_offsets.pt' torch.save(state_dict, output_file, _use_new_zipfile_serialization=False) print(f"Created {output_file}") # Verify by loading loaded = torch.load(output_file, weights_only=False) print("\nVerification - expected values:") for key, tensor in loaded.items(): print(f" {key}: {tensor.tolist()}") print(f" Storage offset: {tensor.storage_offset()}") print(f" Storage size: {len(tensor.storage())}") # Also create a test with multiple tensors sharing the same storage # This is important for proper offset handling shared_storage = torch.randn(1000) # Create views into the same storage at different offsets view1 = shared_storage[100:110] # offset 100 view2 = shared_storage[500:520] # offset 500 view3 = shared_storage[0:10] # offset 0 # Need to save these properly - PyTorch will handle the storage sharing shared_dict = { 'view1': view1.clone(), # Clone to avoid view issues 'view2': view2.clone(), 'view3': view3.clone(), } output_file2 = 'test_data/legacy_shared_storage.pt' torch.save(shared_dict, output_file2, _use_new_zipfile_serialization=False) print(f"\nCreated {output_file2}") # Print exact values for test verification print("\nExact test values for legacy_with_offsets.pt:") print("tensor1 (10 elements starting at 1.0):") print(" First 3 values: [1.00, 1.01, 1.02]") print("tensor2 (5 elements starting at 2.0):") print(" All values: [2.0, 2.1, 2.2, 2.3, 2.4]") print("tensor3 (5 elements starting at 3.0):") print(" All values: [3.0, 3.1, 3.2, 3.3, 3.4]") ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/create_tar_format.py ================================================ #!/usr/bin/env python3 """ Create TAR format test fixtures for burn-store integration tests. The TAR format was used by very early versions of PyTorch (pre 0.1.10). Modern torch.save cannot create this format, so we construct it manually. TAR format structure: - sys_info: pickle with {protocol_version, little_endian, type_sizes} - pickle: pickle with OrderedDict containing _rebuild_tensor_v2 REDUCE calls - storages: count_pickle + for each storage: (key, device, class) pickle + u64 num_elements + raw data """ import io import pickle import struct import tarfile import os from collections import OrderedDict def create_sys_info(): """Create sys_info pickle data.""" sys_info = { "protocol_version": 1000, "little_endian": True, "type_sizes": { "short": 2, "int": 4, "long": 8, }, } return pickle.dumps(sys_info, protocol=2) def encode_tensor_data(values: list, storage_type: str) -> tuple: """Encode tensor values to bytes and return (bytes, element_size).""" fmt_map = { "FloatStorage": (" bytes: """ Create the storages binary blob manually. Args: tensors: List of (key, storage_type, element_size, data_bytes) tuples """ buffer = io.BytesIO() # Write storage count as pickle (simple integer) pickle.dump(len(tensors), buffer, protocol=2) for key, storage_type, element_size, data_bytes in tensors: # Manually construct the tuple pickle with GLOBAL class reference # Format: (key, "cpu", ) tuple_buffer = io.BytesIO() # Protocol 2 header tuple_buffer.write(b'\x80\x02') # Build tuple with MARK + items + TUPLE tuple_buffer.write(b'(') # MARK # First item: storage key (string) write_string(tuple_buffer, key) # Second item: device "cpu" tuple_buffer.write(b'U\x03cpu') # Third item: class reference using GLOBAL tuple_buffer.write(b'c') # GLOBAL opcode tuple_buffer.write(b'torch\n') # module tuple_buffer.write(storage_type.encode('ascii') + b'\n') # name # End tuple tuple_buffer.write(b't') # TUPLE tuple_buffer.write(b'.') # STOP buffer.write(tuple_buffer.getvalue()) # Write num_elements as u64 little-endian num_elements = len(data_bytes) // element_size buffer.write(struct.pack(" bytes: """ Create the main pickle containing _rebuild_tensor_v2 REDUCE calls. For each tensor, we need: - GLOBAL torch._utils _rebuild_tensor_v2 - MARK - args tuple: (persistent_id, offset, shape, stride, requires_grad, hooks) - TUPLE - REDUCE The persistent_id is a PersistentTuple: ('storage', , key, device, num_elements) """ buffer = io.BytesIO() # Protocol 2 header buffer.write(b'\x80\x02') # Build OrderedDict: GLOBAL + EMPTY_LIST + items + TUPLE + REDUCE # OrderedDict([('name1', tensor1), ('name2', tensor2)]) # GLOBAL collections OrderedDict buffer.write(b'ccollections\nOrderedDict\n') # Start list for items buffer.write(b'(') # MARK buffer.write(b']') # EMPTY_LIST # For each tensor, add (name, rebuilt_tensor) to the list for name, storage_key, storage_type, shape, num_elements in tensors_info: # Calculate stride for row-major (C) order stride = [] s = 1 for dim in reversed(shape): stride.insert(0, s) s *= dim # Build inner tuple: (name, tensor_value) buffer.write(b'(') # MARK for (name, value) tuple # Write name write_string(buffer, name) # Now build the tensor using _rebuild_tensor_v2 REDUCE # GLOBAL torch._utils _rebuild_tensor_v2 buffer.write(b'ctorch._utils\n_rebuild_tensor_v2\n') # Build args tuple for _rebuild_tensor_v2 # (persistent_id, offset, shape, stride, requires_grad, backward_hooks) buffer.write(b'(') # MARK for args tuple # arg 0: persistent_id tuple: ('storage', class, key, device, num_elements) # This will be converted to PersistentTuple by the reader buffer.write(b'(') # MARK for persistent_id write_string(buffer, 'storage') # Class reference - GLOBAL torch FloatStorage buffer.write(b'c') buffer.write(b'torch\n') buffer.write(storage_type.encode('ascii') + b'\n') # Storage key write_string(buffer, storage_key) # Device buffer.write(b'U\x03cpu') # num_elements write_int(buffer, num_elements) buffer.write(b't') # TUPLE - end persistent_id # arg 1: storage offset (0) buffer.write(b'K\x00') # arg 2: shape tuple buffer.write(b'(') for dim in shape: write_int(buffer, dim) buffer.write(b't') # arg 3: stride tuple buffer.write(b'(') for s_val in stride: write_int(buffer, s_val) buffer.write(b't') # arg 4: requires_grad (False) buffer.write(b'\x89') # NEWFALSE # arg 5: backward_hooks (empty OrderedDict) buffer.write(b'ccollections\nOrderedDict\n') buffer.write(b'(') buffer.write(b']') buffer.write(b't') buffer.write(b'R') # REDUCE to create empty OrderedDict buffer.write(b't') # TUPLE - end args tuple buffer.write(b'R') # REDUCE - call _rebuild_tensor_v2 with args buffer.write(b't') # TUPLE - end (name, tensor) tuple buffer.write(b'a') # APPEND to list buffer.write(b't') # TUPLE - wrap list in tuple for REDUCE buffer.write(b'R') # REDUCE - call OrderedDict with the list buffer.write(b'.') # STOP return buffer.getvalue() def create_tar_pytorch_file(filename: str, tensors: dict, dtypes: dict): """ Create a TAR format PyTorch file. Args: filename: Output file path tensors: Dict of tensor_name -> (values_list, shape) dtypes: Dict of tensor_name -> storage_type """ # Prepare storage data storage_list = [] # (key, storage_type, element_size, data_bytes) tensors_info = [] # (name, storage_key, storage_type, shape, num_elements) for idx, (name, (values, shape)) in enumerate(tensors.items()): storage_key = str(idx) storage_type = dtypes[name] data_bytes, element_size = encode_tensor_data(values, storage_type) num_elements = len(values) storage_list.append((storage_key, storage_type, element_size, data_bytes)) tensors_info.append((name, storage_key, storage_type, shape, num_elements)) # Create the three main entries sys_info_data = create_sys_info() pickle_data = create_main_pickle_manual(tensors_info) storages_data = create_storages_blob_manual(storage_list) # Write TAR archive os.makedirs(os.path.dirname(filename) or ".", exist_ok=True) with tarfile.open(filename, "w") as tar: # Add sys_info tarinfo = tarfile.TarInfo(name="sys_info") tarinfo.size = len(sys_info_data) tar.addfile(tarinfo, io.BytesIO(sys_info_data)) # Add pickle tarinfo = tarfile.TarInfo(name="pickle") tarinfo.size = len(pickle_data) tar.addfile(tarinfo, io.BytesIO(pickle_data)) # Add storages tarinfo = tarfile.TarInfo(name="storages") tarinfo.size = len(storages_data) tar.addfile(tarinfo, io.BytesIO(storages_data)) size = os.path.getsize(filename) print(f"Created {filename} ({size} bytes)") print(f" Tensors: {list(tensors.keys())}") def main(): # Create test_data directory os.makedirs("test_data", exist_ok=True) # Test 1: Single float32 tensor create_tar_pytorch_file( "test_data/tar_float32.tar", {"tensor": ([1.0, 2.5, -3.7, 0.0], [4])}, {"tensor": "FloatStorage"}, ) # Test 2: Single float64 tensor create_tar_pytorch_file( "test_data/tar_float64.tar", {"tensor": ([1.1, 2.2, 3.3], [3])}, {"tensor": "DoubleStorage"}, ) # Test 3: Single int64 tensor create_tar_pytorch_file( "test_data/tar_int64.tar", {"tensor": ([100, -200, 300, 0], [4])}, {"tensor": "LongStorage"}, ) # Test 4: Multiple tensors (weight + bias) create_tar_pytorch_file( "test_data/tar_weight_bias.tar", { "weight": ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 3]), "bias": ([0.01, 0.02], [2]), }, { "weight": "FloatStorage", "bias": "FloatStorage", }, ) # Test 5: Different dtypes in one file create_tar_pytorch_file( "test_data/tar_multi_dtype.tar", { "float_tensor": ([1.5, 2.5, 3.5], [3]), "double_tensor": ([1.111, 2.222], [2]), "int_tensor": ([10, 20, 30, 40], [4]), }, { "float_tensor": "FloatStorage", "double_tensor": "DoubleStorage", "int_tensor": "LongStorage", }, ) # Test 6: 2D tensor for shape verification create_tar_pytorch_file( "test_data/tar_2d_tensor.tar", { "matrix": ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [3, 4]), }, {"matrix": "FloatStorage"}, ) print("\nAll TAR format test files created!") print("\nTo run tests: cargo test -p burn-store --features pytorch test_tar") if __name__ == "__main__": main() ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/mod.rs ================================================ //! Tests for PyTorch file reader functionality //! //! Floating-point comparison tolerances: //! - F16/BF16: 1e-2 (~3 decimal digits precision) //! - F32: 1e-6 (~7 decimal digits precision) //! - F64: 1e-10 (~16 decimal digits precision) #![allow(clippy::needless_range_loop)] use crate::pytorch::PytorchReader; // Import internal types for testing only use crate::pytorch::reader::{ByteOrder, FileFormat}; use burn_tensor::{BoolStore, DType, shape}; use std::path::PathBuf; fn test_data_path(filename: &str) -> PathBuf { // Get the path relative to the crate root PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("src") .join("pytorch") .join("tests") .join("reader") .join("test_data") .join(filename) } #[test] fn test_float32_tensor() { let path = test_data_path("float32.pt"); let reader = PytorchReader::new(&path).expect("Failed to load float32.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 4); assert!((values[0] - 1.0).abs() < 1e-6); assert!((values[1] - 2.5).abs() < 1e-6); assert!((values[2] - (-3.7)).abs() < 1e-6); assert!((values[3] - 0.0).abs() < 1e-6); } #[test] fn test_float64_tensor() { let path = test_data_path("float64.pt"); let reader = PytorchReader::new(&path).expect("Failed to load float64.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F64); assert_eq!(tensor.shape, shape![3]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 3); assert!((values[0] - 1.1).abs() < 1e-10); assert!((values[1] - 2.2).abs() < 1e-10); assert!((values[2] - 3.3).abs() < 1e-10); } #[test] fn test_int64_tensor() { let path = test_data_path("int64.pt"); let reader = PytorchReader::new(&path).expect("Failed to load int64.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::I64); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[100, -200, 300, 0]); } #[test] fn test_int32_tensor() { let path = test_data_path("int32.pt"); let reader = PytorchReader::new(&path).expect("Failed to load int32.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::I32); assert_eq!(tensor.shape, shape![3]); let data = tensor.to_data().unwrap(); // Convert to the appropriate element type let data_converted = data.convert::(); let values = data_converted.as_slice::().unwrap(); assert_eq!(values, &[10, 20, -30]); } #[test] fn test_int16_tensor() { let path = test_data_path("int16.pt"); let reader = PytorchReader::new(&path).expect("Failed to load int16.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::I16); assert_eq!(tensor.shape, shape![3]); let data = tensor.to_data().unwrap(); let data_converted = data.convert::(); let values = data_converted.as_slice::().unwrap(); assert_eq!(values, &[1000, -2000, 3000]); } #[test] fn test_int8_tensor() { let path = test_data_path("int8.pt"); let reader = PytorchReader::new(&path).expect("Failed to load int8.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::I8); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let data_converted = data.convert::(); let values = data_converted.as_slice::().unwrap(); assert_eq!(values, &[127, -128, 0, 50]); } #[test] fn test_bool_tensor() { let path = test_data_path("bool.pt"); let reader = PytorchReader::new(&path).expect("Failed to load bool.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::Bool(BoolStore::Native)); assert_eq!(tensor.shape, shape![5]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[true, false, true, true, false]); } #[test] fn test_uint8_tensor() { let path = test_data_path("uint8.pt"); let reader = PytorchReader::new(&path).expect("Failed to load uint8.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::U8); assert_eq!(tensor.shape, shape![4]); // Verify actual U8 values [0, 128, 255, 42] from test_data.py let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[0, 128, 255, 42]); } #[test] fn test_float16_tensor() { use half::f16; let path = test_data_path("float16.pt"); let reader = PytorchReader::new(&path).expect("Failed to load float16.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F16); assert_eq!(tensor.shape, shape![3]); // Verify actual F16 values [1.5, -2.25, 3.125] from test_data.py let data = tensor.to_data().unwrap(); assert_eq!(data.shape, shape![3]); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 3); assert!((values[0].to_f32() - 1.5).abs() < 1e-2); assert!((values[1].to_f32() - (-2.25)).abs() < 1e-2); assert!((values[2].to_f32() - 3.125).abs() < 1e-2); } #[test] fn test_bfloat16_tensor() { use half::bf16; let path = test_data_path("bfloat16.pt"); let reader = PytorchReader::new(&path).expect("Failed to load bfloat16.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::BF16); assert_eq!(tensor.shape, shape![3]); // Verify actual BF16 values [1.5, -2.5, 3.5] from test_data.py let data = tensor.to_data().unwrap(); assert_eq!(data.shape, shape![3]); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 3); assert!((values[0].to_f32() - 1.5).abs() < 1e-2); assert!((values[1].to_f32() - (-2.5)).abs() < 1e-2); assert!((values[2].to_f32() - 3.5).abs() < 1e-2); } #[test] fn test_2d_tensor() { let path = test_data_path("tensor_2d.pt"); let reader = PytorchReader::new(&path).expect("Failed to load tensor_2d.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![3, 2]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 6); // Check flattened values [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] for (i, expected) in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() { assert!((values[i] - expected).abs() < 1e-6); } } #[test] fn test_3d_tensor() { let path = test_data_path("tensor_3d.pt"); let reader = PytorchReader::new(&path).expect("Failed to load tensor_3d.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![2, 3, 4]); let data = tensor.to_data().unwrap(); assert_eq!(data.shape, shape![2, 3, 4]); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 24); } #[test] fn test_4d_tensor() { let path = test_data_path("tensor_4d.pt"); let reader = PytorchReader::new(&path).expect("Failed to load tensor_4d.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![2, 3, 2, 2]); let data = tensor.to_data().unwrap(); assert_eq!(data.shape, shape![2, 3, 2, 2]); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 24); } #[test] fn test_state_dict() { let path = test_data_path("state_dict.pt"); let reader = PytorchReader::new(&path).expect("Failed to load state_dict.pt"); let keys = reader.keys(); assert_eq!(keys.len(), 4); assert!(keys.contains(&"weight".to_string())); assert!(keys.contains(&"bias".to_string())); assert!(keys.contains(&"running_mean".to_string())); assert!(keys.contains(&"running_var".to_string())); // Check weight tensor let weight = reader.get("weight").unwrap(); assert_eq!(weight.shape, shape![3, 4]); assert_eq!(weight.dtype, DType::F32); // Check bias tensor let bias = reader.get("bias").unwrap(); assert_eq!(bias.shape, shape![3]); assert_eq!(bias.dtype, DType::F32); // Check running_mean (should be zeros) let running_mean = reader.get("running_mean").unwrap(); assert_eq!(running_mean.shape, shape![3]); let mean_data = running_mean.to_data().unwrap(); let mean_values = mean_data.as_slice::().unwrap(); assert!(mean_values.iter().all(|&v| v.abs() < 1e-6)); // Check running_var (should be ones) let running_var = reader.get("running_var").unwrap(); assert_eq!(running_var.shape, shape![3]); let var_data = running_var.to_data().unwrap(); let var_values = var_data.as_slice::().unwrap(); assert!(var_values.iter().all(|&v| (v - 1.0).abs() < 1e-6)); } #[test] fn test_nested_dict() { let path = test_data_path("nested_dict.pt"); let reader = PytorchReader::new(&path).expect("Failed to load nested_dict.pt"); let keys = reader.keys(); assert_eq!(keys.len(), 4); assert!(keys.contains(&"layer1.weight".to_string())); assert!(keys.contains(&"layer1.bias".to_string())); assert!(keys.contains(&"layer2.weight".to_string())); assert!(keys.contains(&"layer2.bias".to_string())); // Check layer1.weight and load data let layer1_weight = reader.get("layer1.weight").unwrap(); assert_eq!(layer1_weight.shape, shape![2, 3]); assert_eq!(layer1_weight.dtype, DType::F32); let data = layer1_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 6); // 2x3 = 6 elements // Check layer2.weight and load data let layer2_weight = reader.get("layer2.weight").unwrap(); assert_eq!(layer2_weight.shape, shape![4, 2]); assert_eq!(layer2_weight.dtype, DType::F32); let data = layer2_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 8); // 4x2 = 8 elements } #[test] fn test_checkpoint() { let path = test_data_path("checkpoint.pt"); let reader = PytorchReader::new(&path).expect("Failed to load checkpoint.pt"); let keys = reader.keys(); // Should have model_state_dict entries and optimizer entries assert!(keys.contains(&"model_state_dict.fc1.weight".to_string())); assert!(keys.contains(&"model_state_dict.fc1.bias".to_string())); assert!(keys.contains(&"model_state_dict.fc2.weight".to_string())); assert!(keys.contains(&"model_state_dict.fc2.bias".to_string())); // Check fc1.weight dimensions and load data let fc1_weight = reader.get("model_state_dict.fc1.weight").unwrap(); assert_eq!(fc1_weight.shape, shape![10, 5]); let data = fc1_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 50); // 10x5 = 50 elements // Check fc2.weight dimensions and load data let fc2_weight = reader.get("model_state_dict.fc2.weight").unwrap(); assert_eq!(fc2_weight.shape, shape![3, 10]); let data = fc2_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 30); // 3x10 = 30 elements } #[test] fn test_empty_tensor() { let path = test_data_path("empty.pt"); let reader = PytorchReader::new(&path).expect("Failed to load empty.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.shape, shape![0]); // Empty tensor has shape [0] assert_eq!(tensor.dtype, DType::F32); // Note: Empty tensors cannot be loaded with to_data() due to TensorData validation // We verify the metadata is correct, which confirms the .pt file is being read } #[test] fn test_scalar_tensor() { let path = test_data_path("scalar.pt"); let reader = PytorchReader::new(&path).expect("Failed to load scalar.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.shape, shape![]); // Scalar has empty shape assert_eq!(tensor.dtype, DType::F32); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 1); assert!((values[0] - 42.0).abs() < 1e-6); } #[test] fn test_large_shape() { let path = test_data_path("large_shape.pt"); let reader = PytorchReader::new(&path).expect("Failed to load large_shape.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.shape, shape![100, 100]); assert_eq!(tensor.dtype, DType::F32); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 10000); // Check specific non-zero values assert!((values[0] - 1.0).abs() < 1e-6); // [0, 0] = 1.0 assert!((values[5050] - 2.0).abs() < 1e-6); // [50, 50] = 2.0 assert!((values[9999] - 3.0).abs() < 1e-6); // [99, 99] = 3.0 } #[test] fn test_mixed_types() { let path = test_data_path("mixed_types.pt"); let reader = PytorchReader::new(&path).expect("Failed to load mixed_types.pt"); let tensors = reader.tensors(); assert_eq!(tensors.len(), 4); // Check float32 tensor [1.0, 2.0] from test_data.py let float32 = reader.get("float32").unwrap(); assert_eq!(float32.dtype, DType::F32); assert_eq!(float32.shape, shape![2]); let data = float32.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert!((values[0] - 1.0).abs() < 1e-6); assert!((values[1] - 2.0).abs() < 1e-6); // Check int64 tensor [100, 200] from test_data.py let int64 = reader.get("int64").unwrap(); assert_eq!(int64.dtype, DType::I64); assert_eq!(int64.shape, shape![2]); let data = int64.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[100, 200]); // Check bool tensor [True, False] from test_data.py let bool_tensor = reader.get("bool").unwrap(); assert_eq!(bool_tensor.dtype, DType::Bool(BoolStore::Native)); assert_eq!(bool_tensor.shape, shape![2]); let data = bool_tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[true, false]); // Check float64 tensor [1.1, 2.2] from test_data.py let float64 = reader.get("float64").unwrap(); assert_eq!(float64.dtype, DType::F64); assert_eq!(float64.shape, shape![2]); let data = float64.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert!((values[0] - 1.1).abs() < 1e-10); assert!((values[1] - 2.2).abs() < 1e-10); } #[test] fn test_special_values() { let path = test_data_path("special_values.pt"); let reader = PytorchReader::new(&path).expect("Failed to load special_values.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![5]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 5); // Check for special values assert!(values[0].is_nan()); assert!(values[1].is_infinite() && values[1] > 0.0); assert!(values[2].is_infinite() && values[2] < 0.0); assert!((values[3] - 0.0).abs() < 1e-6); assert!((values[4] - 1.0).abs() < 1e-6); } #[test] fn test_extreme_values() { let path = test_data_path("extreme_values.pt"); let reader = PytorchReader::new(&path).expect("Failed to load extreme_values.pt"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 4); // Very small positive assert!(values[0] > 0.0 && values[0] < 1e-20); // Very large positive assert!(values[1] > 1e20); // Very small negative assert!(values[2] < 0.0 && values[2] > -1e-20); // Very large negative assert!(values[3] < -1e20); } #[test] fn test_parameter() { let path = test_data_path("parameter.pt"); let reader = PytorchReader::new(&path).expect("Failed to load parameter.pt"); let tensors = reader.tensors(); // nn.Parameter is typically saved as a regular tensor assert_eq!(tensors.len(), 1); let param = reader.get("param").unwrap(); assert_eq!(param.shape, shape![3, 3]); assert_eq!(param.dtype, DType::F32); let data = param.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 9); } #[test] fn test_buffers() { let path = test_data_path("buffers.pt"); let reader = PytorchReader::new(&path).expect("Failed to load buffers.pt"); let tensors = reader.tensors(); assert_eq!(tensors.len(), 2); // Check buffer1 (int32) let buffer1 = reader.get("buffer1").unwrap(); assert_eq!(buffer1.dtype, DType::I32); assert_eq!(buffer1.shape, shape![3]); let data1 = buffer1.to_data().unwrap(); let data1_converted = data1.convert::(); let values1 = data1_converted.as_slice::().unwrap(); assert_eq!(values1, &[1, 2, 3]); // Check buffer2 (bool) let buffer2 = reader.get("buffer2").unwrap(); assert_eq!(buffer2.dtype, DType::Bool(BoolStore::Native)); assert_eq!(buffer2.shape, shape![2]); let data2 = buffer2.to_data().unwrap(); let values2 = data2.as_slice::().unwrap(); assert_eq!(values2, &[true, false]); } #[test] fn test_complex_structure() { let path = test_data_path("complex_structure.pt"); let reader = PytorchReader::new(&path).expect("Failed to load complex_structure.pt"); let keys = reader.keys(); // Should have nested structure tensors assert!(keys.contains(&"state.encoder.layer_0.weight".to_string())); assert!(keys.contains(&"state.encoder.layer_0.bias".to_string())); assert!(keys.contains(&"state.encoder.layer_1.weight".to_string())); assert!(keys.contains(&"state.encoder.layer_1.bias".to_string())); assert!(keys.contains(&"state.decoder.weight".to_string())); assert!(keys.contains(&"state.decoder.bias".to_string())); // Check encoder layer_0 weight and load data let layer0_weight = reader.get("state.encoder.layer_0.weight").unwrap(); assert_eq!(layer0_weight.shape, shape![4, 3]); let data = layer0_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 12); // 4x3 = 12 elements // Check decoder weight and load data let decoder_weight = reader.get("state.decoder.weight").unwrap(); assert_eq!(decoder_weight.shape, shape![3, 2]); let data = decoder_weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 6); // 3x2 = 6 elements } #[test] fn test_read_pytorch_tensors_convenience() { // Test reading and materializing tensors into memory let path = test_data_path("state_dict.pt"); let reader = PytorchReader::new(&path).expect("Failed to read file"); let keys = reader.keys(); assert_eq!(keys.len(), 4); assert!(keys.contains(&"weight".to_string())); assert!(keys.contains(&"bias".to_string())); // Check that data can be materialized let weight = reader.get("weight").unwrap(); let weight_data = weight.to_data().unwrap(); assert_eq!(weight_data.shape, shape![3, 4]); assert_eq!(weight_data.dtype, DType::F32); } #[test] fn test_with_top_level_key() { // Test loading with a specific top-level key let path = test_data_path("checkpoint.pt"); // Load only model_state_dict let reader = PytorchReader::with_top_level_key(&path, "model_state_dict") .expect("Failed to load with top-level key"); let keys = reader.keys(); // Should only have model weights, not optimizer state assert!(keys.contains(&"fc1.weight".to_string())); assert!(keys.contains(&"fc1.bias".to_string())); assert!(keys.contains(&"fc2.weight".to_string())); assert!(keys.contains(&"fc2.bias".to_string())); // Should NOT have nested paths with model_state_dict prefix assert!(!keys.contains(&"model_state_dict.fc1.weight".to_string())); } #[test] fn test_legacy_format() { // Test loading PyTorch legacy format (pre-1.6) let path = test_data_path("simple_legacy.pt"); // This file has the sequential pickle structure of legacy PyTorch format let reader = PytorchReader::new(&path).expect("Failed to load legacy format"); let keys = reader.keys(); // Should have the tensors from the state dict assert!(keys.contains(&"weight".to_string()), "Missing 'weight' key"); assert!(keys.contains(&"bias".to_string()), "Missing 'bias' key"); assert!( keys.contains(&"running_mean".to_string()), "Missing 'running_mean' key" ); // Check weight tensor let weight = reader.get("weight").expect("weight not found"); assert_eq!(weight.shape, shape![2, 3]); assert_eq!(weight.dtype, DType::F32); // Check bias tensor let bias = reader.get("bias").expect("bias not found"); assert_eq!(bias.shape, shape![2]); assert_eq!(bias.dtype, DType::F32); // Verify bias values are all ones let bias_data = bias.to_data().unwrap(); let bias_values = bias_data.as_slice::().unwrap(); // Note: values in simple_legacy.pt are randomly generated, not necessarily 1.0 assert_eq!(bias_values.len(), 2); // Check running_mean tensor let running_mean = reader.get("running_mean").expect("running_mean not found"); assert_eq!(running_mean.shape, shape![2]); assert_eq!(running_mean.dtype, DType::F32); // Verify running_mean values are accessible let mean_data = running_mean.to_data().unwrap(); let mean_values = mean_data.as_slice::().unwrap(); assert_eq!(mean_values.len(), 2); } #[test] fn test_legacy_with_offsets() { // Test with legacy format file that has storage offsets let path = test_data_path("legacy_with_offsets.pt"); let reader = PytorchReader::new(&path).expect("Should read legacy file with offsets"); let keys = reader.keys(); assert_eq!(keys.len(), 3, "Should have 3 tensors"); // Check that tensors exist for key in &keys { assert!(reader.get(key).is_some(), "Should have tensor: {}", key); let tensor = reader.get(key).unwrap(); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert!(!values.is_empty(), "Tensor {} should have data", key); } } #[test] fn test_legacy_shared_storage() { // Test with legacy format file that has shared storage let path = test_data_path("legacy_shared_storage.pt"); let reader = PytorchReader::new(&path).expect("Should read legacy file with shared storage"); let keys = reader.keys(); assert!(keys.len() >= 2, "Should have at least 2 tensors"); // Check that tensors exist and can be loaded for key in &keys { assert!(reader.get(key).is_some(), "Should have tensor: {}", key); let tensor = reader.get(key).unwrap(); let data = tensor.to_data().unwrap(); // Verify tensor data can be accessed match tensor.dtype { DType::F32 => { let values = data.as_slice::().unwrap(); assert!(!values.is_empty(), "Tensor {} should have data", key); } DType::I64 => { let values = data.as_slice::().unwrap(); assert!(!values.is_empty(), "Tensor {} should have data", key); } _ => { // For other types, just verify we can convert to data assert!(!data.shape.is_empty(), "Tensor {} should have shape", key); } } } } #[test] fn test_metadata_zip_format() { // Test that metadata is properly populated for ZIP format files let path = test_data_path("float32.pt"); let reader = PytorchReader::new(&path).expect("Failed to load float32.pt"); // Check metadata let metadata = reader.metadata(); assert_eq!(metadata.format_type, FileFormat::Zip); assert_eq!(metadata.byte_order, ByteOrder::LittleEndian); assert_eq!(metadata.tensor_count, 1); assert!(metadata.total_data_size.is_some()); // Check that metadata is accessible assert!(metadata.is_modern_format()); assert!(!metadata.is_legacy_format()); } #[test] fn test_metadata_legacy_format() { // Test that metadata is properly populated for legacy format files let path = test_data_path("simple_legacy.pt"); let reader = PytorchReader::new(&path).expect("Failed to load legacy file"); // Check metadata let metadata = reader.metadata(); assert_eq!(metadata.format_type, FileFormat::Legacy); assert_eq!(metadata.byte_order, ByteOrder::LittleEndian); assert_eq!(metadata.tensor_count, 3); // weight, bias, running_mean assert!(metadata.total_data_size.is_some()); } #[test] fn test_legacy_metadata_detailed() { // Detailed test to prove we load all metadata for legacy format files let path = test_data_path("simple_legacy.pt"); let reader = PytorchReader::new(&path).expect("Failed to load legacy file"); // Get and examine metadata let metadata = reader.metadata(); // Verify the metadata is correct for legacy format assert_eq!( metadata.format_type, FileFormat::Legacy, "Should be Legacy format" ); assert_eq!( metadata.byte_order, ByteOrder::LittleEndian, "Legacy format is little-endian" ); assert_eq!( metadata.tensor_count, 3, "Should have 3 tensors: weight, bias, running_mean" ); assert!( metadata.total_data_size.is_some(), "Should have total data size" ); assert!( metadata.total_data_size.unwrap() > 0, "Data size should be positive" ); // Legacy format specifics assert_eq!( metadata.format_version, None, "Legacy format doesn't have version file" ); assert_eq!( metadata.pytorch_version, None, "Legacy format doesn't store PyTorch version reliably" ); assert!( !metadata.has_storage_alignment, "Legacy format doesn't have storage alignment" ); // Also verify we can access the tensors let keys = reader.keys(); assert!( keys.contains(&"weight".to_string()), "Should have weight tensor" ); assert!( keys.contains(&"bias".to_string()), "Should have bias tensor" ); assert!( keys.contains(&"running_mean".to_string()), "Should have running_mean tensor" ); } #[test] fn test_small_invalid_file() { // Test that we handle broken/invalid files gracefully let path = test_data_path("broken.pt"); // Should fail gracefully with an appropriate error let result = PytorchReader::new(&path); assert!(result.is_err(), "Expected error for broken file"); // The error should be a pickle error since the file is too small to be valid if let Err(e) = result { let err_str = format!("{}", e); assert!( err_str.contains("Pickle") || err_str.contains("Invalid"), "Error should mention pickle or invalid format: {}", err_str ); } } #[test] fn test_read_pickle_data_basic() { use crate::pytorch::reader::PickleValue; // Test reading pickle data from a checkpoint file let path = test_data_path("checkpoint.pt"); // Read the entire pickle data let data = PytorchReader::read_pickle_data(&path, None).expect("Failed to read pickle data"); // Should be a dictionary at the root if let PickleValue::Dict(dict) = data { // Check that expected keys exist assert!(dict.contains_key("model_state_dict")); assert!(dict.contains_key("optimizer_state_dict")); assert!(dict.contains_key("epoch")); assert!(dict.contains_key("loss")); // Check epoch value if let Some(PickleValue::Int(epoch)) = dict.get("epoch") { assert_eq!(*epoch, 42); } else { panic!("Expected epoch to be an integer"); } // Check loss value if let Some(PickleValue::Float(loss)) = dict.get("loss") { assert!(*loss > 0.0 && *loss < 1.0, "Loss should be between 0 and 1"); } else { panic!("Expected loss to be a float"); } } else { panic!("Expected root to be a dictionary"); } } #[test] fn test_read_pickle_data_with_key() { use crate::pytorch::reader::PickleValue; // Test reading specific key from checkpoint let path = test_data_path("checkpoint.pt"); // Read only the model_state_dict let data = PytorchReader::read_pickle_data(&path, Some("model_state_dict")) .expect("Failed to read pickle data with key"); // Should get the model_state_dict directly if let PickleValue::Dict(dict) = data { // Should have model weights assert!(dict.contains_key("fc1.weight")); assert!(dict.contains_key("fc1.bias")); assert!(dict.contains_key("fc2.weight")); assert!(dict.contains_key("fc2.bias")); // Should NOT have optimizer keys assert!(!dict.contains_key("optimizer_state_dict")); assert!(!dict.contains_key("epoch")); } else { panic!("Expected model_state_dict to be a dictionary"); } } #[test] fn test_read_pickle_data_nested_structure() { use crate::pytorch::reader::PickleValue; // Test reading nested dictionary structure let path = test_data_path("nested_dict.pt"); let data = PytorchReader::read_pickle_data(&path, None).expect("Failed to read nested structure"); if let PickleValue::Dict(dict) = data { // nested_dict.pt has a nested structure, not flat keys // It should have layer1 and layer2 as nested dicts assert!(!dict.is_empty(), "Dictionary should not be empty"); // The structure depends on how the file was saved // It could be flat keys like "layer1.weight" or nested dicts // Just verify it's a valid dict structure for (_key, value) in dict.iter() { // Values could be None (tensors), nested dicts, or other types assert!( matches!(value, PickleValue::None | PickleValue::Dict(_)), "Values should be None or nested dicts" ); } } else { panic!("Expected nested_dict to be a dictionary"); } } #[test] fn test_read_pickle_data_types() { use crate::pytorch::reader::PickleValue; // Test various data types in mixed_types.pt let path = test_data_path("mixed_types.pt"); let data = PytorchReader::read_pickle_data(&path, None).expect("Failed to read mixed types"); if let PickleValue::Dict(dict) = data { // The file contains different tensor types assert!(dict.len() >= 3, "Should have at least 3 tensor types"); // All tensor values should be None in pickle data for (_key, value) in dict.iter() { // All values should be None (tensors are not included in pickle data) assert!( matches!(value, PickleValue::None), "Tensors should be None in pickle data" ); } } else { panic!("Expected mixed_types to be a dictionary"); } } #[test] fn test_read_pickle_data_key_not_found() { // Test error handling when key doesn't exist let path = test_data_path("checkpoint.pt"); let result = PytorchReader::read_pickle_data(&path, Some("nonexistent_key")); assert!(result.is_err()); if let Err(e) = result { let err_str = format!("{}", e); assert!( err_str.contains("not found"), "Error should mention key not found: {}", err_str ); } } #[test] fn test_read_pickle_data_simple_pickle() { use crate::pytorch::reader::PickleValue; // Test reading a simple pickle file (not ZIP) // Note: simple_legacy.pt is a legacy format file, not a simple pickle // It may return None because legacy format reading is different let path = test_data_path("state_dict.pt"); // Use a proper simple pickle file let data = PytorchReader::read_pickle_data(&path, None).expect("Failed to read simple pickle"); // Should contain state dict entries if let PickleValue::Dict(dict) = data { // state_dict.pt has weight, bias, running_mean, running_var assert!(dict.len() >= 3); assert!(dict.contains_key("weight")); assert!(dict.contains_key("bias")); // All tensor values should be None in pickle data for (_key, value) in dict.iter() { assert!(matches!(value, PickleValue::None)); } } else { panic!("Expected state_dict to contain a dictionary"); } } #[test] fn test_load_config_basic() { let path = test_data_path("checkpoint.pt"); // Define a struct that matches part of the checkpoint data #[derive(Debug, serde::Deserialize, PartialEq)] struct CheckpointConfig { epoch: i64, loss: f64, } // Load config let config: CheckpointConfig = PytorchReader::load_config(&path, None).expect("Failed to load config"); // Verify values - based on test_read_pickle_data_basic assert_eq!(config.epoch, 42); assert!((config.loss - 0.123).abs() < 1e-6); } #[test] fn test_load_config_with_top_level_key() { // Test that we can extract a non-existent key and get an appropriate error let path = test_data_path("checkpoint.pt"); #[derive(Debug, serde::Deserialize, PartialEq)] struct DummyConfig { field: String, } // Try loading with a valid top-level key that exists but has wrong structure let result: Result = PytorchReader::load_config(&path, Some("epoch")); // This should fail because epoch is an integer, not a struct with a field assert!(result.is_err()); // Now test that we can load with a real key that has the right structure // Since checkpoint.pt doesn't have nested configs, let's use nested_dict.pt let path2 = test_data_path("nested_dict.pt"); // Try to extract a specific nested key if it exists // Since nested_dict has complex structure, let's just verify we can read it let data = PytorchReader::read_pickle_data(&path2, None).unwrap(); // Verify it's a dict if let crate::pytorch::reader::PickleValue::Dict(dict) = data { assert!(!dict.is_empty()); } else { panic!("Expected a dict"); } } #[test] fn test_load_config_complex_types() { // For this test, let's create a comprehensive test using checkpoint.pt // which has both metadata and state_dict fields let path = test_data_path("checkpoint.pt"); // Define a partial config that only captures metadata fields #[derive(Debug, serde::Deserialize, PartialEq)] struct PartialCheckpoint { epoch: i64, loss: f64, // We skip model_state_dict and optimizer_state_dict // as they contain tensor references that become None } // Load partial config let config: PartialCheckpoint = PytorchReader::load_config(&path, None).expect("Failed to load config"); // Verify we can extract the metadata assert_eq!(config.epoch, 42); assert!((config.loss - 0.123).abs() < 1e-6); } #[test] fn test_load_config_key_not_found() { let path = test_data_path("checkpoint.pt"); #[derive(Debug, serde::Deserialize)] struct DummyConfig { #[allow(dead_code)] field: String, } // Try to load with non-existent key let result: Result = PytorchReader::load_config(&path, Some("nonexistent")); assert!(result.is_err()); let error = result.unwrap_err(); assert!(error.to_string().contains("not found") || error.to_string().contains("Key")); } #[test] fn test_pickle_value_conversion() { use crate::pytorch::reader::PickleValue; // Test that PickleValue provides useful data structures let path = test_data_path("checkpoint.pt"); let data = PytorchReader::read_pickle_data(&path, None).unwrap(); // Test pattern matching and data extraction match data { PickleValue::Dict(dict) => { // Extract epoch as integer if let Some(PickleValue::Int(epoch)) = dict.get("epoch") { assert!(*epoch >= 0); } // Extract loss as float if let Some(PickleValue::Float(loss)) = dict.get("loss") { assert!(loss.is_finite()); } // Test nested access if let Some(PickleValue::Dict(model_dict)) = dict.get("model_state_dict") { assert!(!model_dict.is_empty()); } } _ => panic!("Unexpected root type"), } } // ============================================================================ // TAR Format Tests // ============================================================================ // The TAR format was used by very early versions of PyTorch (pre 0.1.10). // These tests verify that we can correctly load models saved in this format. #[test] fn test_tar_format_detection() { // Test that is_tar_file correctly detects TAR files let tar_path = test_data_path("tar_float32.tar"); let zip_path = test_data_path("float32.pt"); // TAR file should be detected as TAR let reader = PytorchReader::new(&tar_path).expect("Failed to load TAR file"); let metadata = reader.metadata(); assert_eq!(metadata.format_type, FileFormat::Tar); // ZIP file should NOT be detected as TAR let reader = PytorchReader::new(&zip_path).expect("Failed to load ZIP file"); let metadata = reader.metadata(); assert_ne!(metadata.format_type, FileFormat::Tar); } #[test] fn test_tar_float32_tensor() { let path = test_data_path("tar_float32.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_float32.tar"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F32); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 4); assert!((values[0] - 1.0).abs() < 1e-6); assert!((values[1] - 2.5).abs() < 1e-6); assert!((values[2] - (-3.7)).abs() < 1e-6); assert!((values[3] - 0.0).abs() < 1e-6); } #[test] fn test_tar_float64_tensor() { let path = test_data_path("tar_float64.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_float64.tar"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::F64); assert_eq!(tensor.shape, shape![3]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 3); assert!((values[0] - 1.1).abs() < 1e-10); assert!((values[1] - 2.2).abs() < 1e-10); assert!((values[2] - 3.3).abs() < 1e-10); } #[test] fn test_tar_int64_tensor() { let path = test_data_path("tar_int64.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_int64.tar"); let tensor = reader.get("tensor").expect("tensor key not found"); assert_eq!(tensor.dtype, DType::I64); assert_eq!(tensor.shape, shape![4]); let data = tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[100, -200, 300, 0]); } #[test] fn test_tar_multiple_tensors() { // Test loading multiple tensors (weight + bias) with correct shapes let path = test_data_path("tar_weight_bias.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_weight_bias.tar"); // Check weight tensor (2x3 matrix) let weight = reader.get("weight").expect("weight key not found"); assert_eq!(weight.dtype, DType::F32); assert_eq!(weight.shape, shape![2, 3]); let data = weight.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 6); assert!((values[0] - 0.1).abs() < 1e-6); assert!((values[1] - 0.2).abs() < 1e-6); assert!((values[5] - 0.6).abs() < 1e-6); // Check bias tensor (2-element vector) let bias = reader.get("bias").expect("bias key not found"); assert_eq!(bias.dtype, DType::F32); assert_eq!(bias.shape, shape![2]); let data = bias.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 2); assert!((values[0] - 0.01).abs() < 1e-6); assert!((values[1] - 0.02).abs() < 1e-6); } #[test] fn test_tar_multi_dtype() { // Test loading different dtypes from the same TAR file let path = test_data_path("tar_multi_dtype.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_multi_dtype.tar"); // Float32 tensor let float_tensor = reader .get("float_tensor") .expect("float_tensor key not found"); assert_eq!(float_tensor.dtype, DType::F32); let data = float_tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert!((values[0] - 1.5).abs() < 1e-6); // Float64 tensor let double_tensor = reader .get("double_tensor") .expect("double_tensor key not found"); assert_eq!(double_tensor.dtype, DType::F64); let data = double_tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert!((values[0] - 1.111).abs() < 1e-10); // Int64 tensor let int_tensor = reader.get("int_tensor").expect("int_tensor key not found"); assert_eq!(int_tensor.dtype, DType::I64); let data = int_tensor.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values, &[10, 20, 30, 40]); } #[test] fn test_tar_2d_tensor_shape() { // Test that 2D tensor shapes are correctly preserved let path = test_data_path("tar_2d_tensor.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_2d_tensor.tar"); let matrix = reader.get("matrix").expect("matrix key not found"); assert_eq!(matrix.dtype, DType::F32); assert_eq!(matrix.shape, shape![3, 4]); // 3 rows, 4 columns let data = matrix.to_data().unwrap(); let values = data.as_slice::().unwrap(); assert_eq!(values.len(), 12); // Verify values in row-major order for i in 0..12 { assert!((values[i] - (i as f32 + 1.0)).abs() < 1e-6); } } #[test] fn test_tar_metadata() { // Test that TAR metadata is correctly populated let path = test_data_path("tar_float32.tar"); let reader = PytorchReader::new(&path).expect("Failed to load tar_float32.tar"); let metadata = reader.metadata(); assert_eq!(metadata.format_type, FileFormat::Tar); assert_eq!(metadata.byte_order, ByteOrder::LittleEndian); assert_eq!(metadata.tensor_count, 1); assert!(metadata.total_data_size.is_some()); } ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/simple_legacy.py ================================================ #!/usr/bin/env python3 # /// script # dependencies = ["torch"] # /// """Create a simple legacy format PyTorch file.""" import torch # Create a simple state dict state_dict = { 'weight': torch.randn(2, 3), 'bias': torch.ones(2), 'running_mean': torch.zeros(2), } # Save without using zip format (legacy format) torch.save(state_dict, 'test_data/simple_legacy.pt', _use_new_zipfile_serialization=False) print("Created simple_legacy.pt") # Verify loaded = torch.load('test_data/simple_legacy.pt', weights_only=False) print(f"Loaded {len(loaded)} tensors") for key, val in loaded.items(): print(f" {key}: shape {val.shape}, dtype {val.dtype}") ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/test_data/broken.pt ================================================ abc ================================================ FILE: crates/burn-store/src/pytorch/tests/reader/test_data.py ================================================ #!/usr/bin/env python3 # /// script # dependencies = ["torch", "numpy"] # /// """ Generate test PyTorch .pt files for testing the burn-store PyTorch reader. Run with: uv run test_files.py """ import torch import numpy as np import os from pathlib import Path # Create test directory test_dir = Path(__file__).parent / "test_data" test_dir.mkdir(exist_ok=True) def save_test_file(filename, data, description): """Save a test file and print what was saved.""" filepath = test_dir / filename torch.save(data, filepath) print(f"✓ {filename}: {description}") return filepath # Test 1: Simple tensors of different types print("\n=== Generating Basic Tensor Tests ===") # Float32 tensor (wrap in dict for compatibility) float32_tensor = torch.tensor([1.0, 2.5, -3.7, 0.0], dtype=torch.float32) save_test_file("float32.pt", {"tensor": float32_tensor}, "Float32 tensor [1.0, 2.5, -3.7, 0.0]") # Float64 tensor float64_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float64) save_test_file("float64.pt", {"tensor": float64_tensor}, "Float64 tensor [1.1, 2.2, 3.3]") # Int64 tensor int64_tensor = torch.tensor([100, -200, 300, 0], dtype=torch.int64) save_test_file("int64.pt", {"tensor": int64_tensor}, "Int64 tensor [100, -200, 300, 0]") # Int32 tensor int32_tensor = torch.tensor([10, 20, -30], dtype=torch.int32) save_test_file("int32.pt", {"tensor": int32_tensor}, "Int32 tensor [10, 20, -30]") # Int16 tensor int16_tensor = torch.tensor([1000, -2000, 3000], dtype=torch.int16) save_test_file("int16.pt", {"tensor": int16_tensor}, "Int16 tensor [1000, -2000, 3000]") # Int8 tensor int8_tensor = torch.tensor([127, -128, 0, 50], dtype=torch.int8) save_test_file("int8.pt", {"tensor": int8_tensor}, "Int8 tensor [127, -128, 0, 50]") # Boolean tensor bool_tensor = torch.tensor([True, False, True, True, False], dtype=torch.bool) save_test_file("bool.pt", {"tensor": bool_tensor}, "Bool tensor [True, False, True, True, False]") # Float16 tensor (half precision) float16_tensor = torch.tensor([1.5, -2.25, 3.125], dtype=torch.float16) save_test_file("float16.pt", {"tensor": float16_tensor}, "Float16 tensor [1.5, -2.25, 3.125]") # BFloat16 tensor bfloat16_tensor = torch.tensor([1.5, -2.5, 3.5], dtype=torch.bfloat16) save_test_file("bfloat16.pt", {"tensor": bfloat16_tensor}, "BFloat16 tensor [1.5, -2.5, 3.5]") # UInt8 tensor uint8_tensor = torch.tensor([0, 128, 255, 42], dtype=torch.uint8) save_test_file("uint8.pt", {"tensor": uint8_tensor}, "UInt8 tensor [0, 128, 255, 42]") # Test 2: Multi-dimensional tensors print("\n=== Generating Multi-dimensional Tensor Tests ===") # 2D tensor tensor_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32) save_test_file("tensor_2d.pt", {"tensor": tensor_2d}, "2D tensor shape (3, 2)") # 3D tensor torch.manual_seed(42) tensor_3d = torch.randn(2, 3, 4) * 10 save_test_file("tensor_3d.pt", {"tensor": tensor_3d}, "3D tensor shape (2, 3, 4)") # 4D tensor (common for conv weights) tensor_4d = torch.randn(2, 3, 2, 2) save_test_file("tensor_4d.pt", {"tensor": tensor_4d}, "4D tensor shape (2, 3, 2, 2)") # Test 3: State dict (multiple tensors) print("\n=== Generating State Dict Tests ===") state_dict = { "weight": torch.randn(3, 4), "bias": torch.randn(3), "running_mean": torch.zeros(3), "running_var": torch.ones(3), } save_test_file("state_dict.pt", state_dict, "State dict with 4 tensors") # Nested state dict nested_dict = { "layer1": { "weight": torch.randn(2, 3), "bias": torch.randn(2) }, "layer2": { "weight": torch.randn(4, 2), "bias": torch.randn(4) } } save_test_file("nested_dict.pt", nested_dict, "Nested state dict") # Test 4: Model checkpoint format print("\n=== Generating Model Checkpoint Tests ===") # Typical checkpoint format (use string keys for compatibility) checkpoint = { "model_state_dict": { "fc1.weight": torch.randn(10, 5), "fc1.bias": torch.randn(10), "fc2.weight": torch.randn(3, 10), "fc2.bias": torch.randn(3), }, "optimizer_state_dict": { "state": { "0": { # Use string key instead of integer "momentum_buffer": torch.randn(10, 5) } } }, "epoch": 42, "loss": 0.123 } save_test_file("checkpoint.pt", checkpoint, "Full checkpoint with model and optimizer state") # Test 5: Edge cases print("\n=== Generating Edge Case Tests ===") # Empty tensor (1D with 0 elements) empty_tensor = torch.zeros(0) save_test_file("empty.pt", {"tensor": empty_tensor}, "Empty tensor") # Scalar tensor (0-dimensional) scalar_tensor = torch.tensor(42.0) save_test_file("scalar.pt", {"tensor": scalar_tensor}, "Scalar tensor (0-dim)") # Large shape but small data (testing shape vs actual data) sparse_like = torch.zeros(100, 100) sparse_like[0, 0] = 1.0 sparse_like[50, 50] = 2.0 sparse_like[99, 99] = 3.0 save_test_file("large_shape.pt", {"tensor": sparse_like}, "Large shape (100, 100) mostly zeros") # Test 6: Mixed types in dict print("\n=== Generating Mixed Type Tests ===") mixed_types = { "float32": torch.tensor([1.0, 2.0], dtype=torch.float32), "int64": torch.tensor([100, 200], dtype=torch.int64), "bool": torch.tensor([True, False], dtype=torch.bool), "float64": torch.tensor([1.1, 2.2], dtype=torch.float64), } save_test_file("mixed_types.pt", mixed_types, "Dict with mixed tensor types") # Test 7: Special values print("\n=== Generating Special Value Tests ===") # NaN and Inf values special_values = torch.tensor([float('nan'), float('inf'), float('-inf'), 0.0, 1.0]) save_test_file("special_values.pt", {"tensor": special_values}, "Tensor with NaN and Inf") # Very small and very large values extreme_values = torch.tensor([1e-30, 1e30, -1e-30, -1e30], dtype=torch.float32) save_test_file("extreme_values.pt", {"tensor": extreme_values}, "Tensor with extreme values") # Test 8: Parameter wrapper (common in models) print("\n=== Generating Parameter Tests ===") import torch.nn as nn param = nn.Parameter(torch.randn(3, 3)) param_dict = {"param": param} save_test_file("parameter.pt", param_dict, "nn.Parameter wrapped tensor") # Test 9: Buffer-style tensors print("\n=== Generating Buffer Tests ===") # Simulate model buffers buffers = { "buffer1": torch.tensor([1, 2, 3], dtype=torch.int32), "buffer2": torch.tensor([True, False], dtype=torch.bool), } save_test_file("buffers.pt", buffers, "Model buffers") # Test 10: Complex nested structure print("\n=== Generating Complex Structure Tests ===") complex_structure = { "metadata": { "version": 1, "name": "test_model" }, "state": { "encoder": { "layer_0": { "weight": torch.randn(4, 3), "bias": torch.randn(4) }, "layer_1": { "weight": torch.randn(2, 4), "bias": torch.randn(2) } }, "decoder": { "weight": torch.randn(3, 2), "bias": torch.randn(3) } }, "config": { "hidden_size": 4, "num_layers": 2 } } save_test_file("complex_structure.pt", complex_structure, "Complex nested structure") print(f"\n✅ Generated {len(list(test_dir.glob('*.pt')))} test files in {test_dir}") print("\nTest files can be used to verify PyTorch reader functionality:") print("- Different data types (float32, int64, bool, etc.)") print("- Multi-dimensional tensors") print("- State dicts and nested structures") print("- Edge cases (empty, scalar, special values)") print("- Model checkpoints and parameters") ================================================ FILE: crates/burn-store/src/pytorch/tests/store/mod.rs ================================================ //! Comprehensive tests for PytorchStore with real model application use burn_core as burn; use std::path::PathBuf; use crate::ModuleStore; use crate::pytorch::PytorchStore; use burn_core::module::Module; use burn_nn::conv::{Conv2d, Conv2dConfig}; use burn_nn::{Linear, LinearConfig}; use burn_tensor::Tensor; use burn_tensor::backend::Backend; /// Path to pytorch test files (now under burn-store) fn pytorch_test_path(subdir: &str, filename: &str) -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("pytorch-tests") .join("tests") .join(subdir) .join(filename) } /// Path to burn-store test data files fn test_data_path(filename: &str) -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("src") .join("pytorch") .join("tests") .join("reader") .join("test_data") .join(filename) } /// Path to store test data files fn store_test_data_path(filename: &str) -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("src") .join("pytorch") .join("tests") .join("store") .join("test_data") .join(filename) } #[cfg(test)] mod basic_tests { use super::*; #[test] fn test_store_creation() { let store = PytorchStore::from_file("model.pth"); assert!(store.validate); assert!(!store.allow_partial); assert!(store.top_level_key.is_none()); // Contiguous index mapping is enabled by default for PyTorch files assert!(store.map_indices_contiguous); } #[test] fn test_store_map_indices_contiguous_default() { // Verify that map_indices_contiguous is enabled by default let store = PytorchStore::from_file("model.pth"); assert!( store.map_indices_contiguous, "map_indices_contiguous should be enabled by default" ); } #[test] fn test_store_map_indices_contiguous_disabled() { // Verify that we can disable map_indices_contiguous let store = PytorchStore::from_file("model.pth").map_indices_contiguous(false); assert!( !store.map_indices_contiguous, "map_indices_contiguous should be disabled after explicit call" ); } #[test] fn test_store_with_top_level_key() { let store = PytorchStore::from_file("model.pth").with_top_level_key("state_dict"); assert_eq!(store.top_level_key, Some("state_dict".to_string())); } #[test] fn test_store_configuration() { let store = PytorchStore::from_file("model.pth") .validate(false) .allow_partial(true) .with_regex(r"^encoder\.") .with_full_path("decoder.weight"); assert!(!store.validate); assert!(store.allow_partial); assert!(!store.filter.is_empty()); } #[test] fn test_store_with_remapping() { let store = PytorchStore::from_file("model.pth").with_key_remapping(r"^old\.", "new."); assert!(!store.remapper.is_empty()); } #[test] fn test_store_save_not_supported() { // Currently, saving to PyTorch format is not implemented // The collect_from method always returns an error let store = PytorchStore::from_file("test.pth"); // Just verify that store creation works assert!(store.validate); // Note: Actually testing save would require a proper Module implementation // which is complex. The implementation guarantees it returns an error. } } #[cfg(test)] mod linear_model_tests { use super::*; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] pub struct SimpleLinearModel { fc1: Linear, fc2: Linear, } impl SimpleLinearModel { pub fn new(device: &B::Device) -> Self { Self { fc1: LinearConfig::new(2, 3).init(device), fc2: LinearConfig::new(3, 4).init(device), } } pub fn forward(&self, x: Tensor) -> Tensor { let x = self.fc1.forward(x); self.fc2.forward(x) } } #[test] fn test_load_linear_model() { let device = Default::default(); let path = pytorch_test_path("linear", "linear.pt"); // Create a model and load weights from PyTorch let mut model = SimpleLinearModel::::new(&device); let mut store = PytorchStore::from_file(path).allow_partial(true); // Apply the PyTorch weights to our model let result = store.apply_to::(&mut model); assert!( result.is_ok(), "Failed to load linear model: {:?}", result.err() ); let result = result.unwrap(); assert!(!result.applied.is_empty(), "No tensors were applied"); // Test forward pass with loaded weights let input = Tensor::::ones([1, 2], &device); let output = model.forward(input); // Verify output shape assert_eq!(&*output.shape(), [1, 4]); } #[test] fn test_load_linear_with_bias() { let device = Default::default(); let path = pytorch_test_path("linear", "linear_with_bias.pt"); // Single linear layer with bias #[derive(Module, Debug)] struct LinearWithBias { fc1: Linear, } let mut model = LinearWithBias { fc1: LinearConfig::new(2, 3).init(&device), }; let mut store = PytorchStore::from_file(path).allow_partial(true); let result = store.apply_to::(&mut model); assert!(result.is_ok(), "Failed to load model with bias"); // Verify biases were loaded let result = result.unwrap(); let bias_loaded = result.applied.iter().any(|s| s.contains("bias")); assert!(bias_loaded, "Bias parameters not loaded"); } #[test] fn test_filter_layers() { let device = Default::default(); let path = pytorch_test_path("linear", "linear.pt"); let mut model = SimpleLinearModel::::new(&device); // Only load fc1 layers let mut store = PytorchStore::from_file(path) .with_regex(r"^fc1\.") .allow_partial(true); let result = store.apply_to::(&mut model).unwrap(); // Should only have fc1 tensors for tensor in &result.applied { assert!(tensor.contains("fc1")); assert!(!tensor.contains("fc2")); } } #[test] fn test_remap_layer_names() { let device = Default::default(); let path = pytorch_test_path("linear", "linear.pt"); // Model with different layer names #[derive(Module, Debug)] struct RemappedModel { linear1: Linear, linear2: Linear, } let mut model = RemappedModel { linear1: LinearConfig::new(2, 3).init(&device), linear2: LinearConfig::new(3, 4).init(&device), }; let mut store = PytorchStore::from_file(path) .with_key_remapping(r"^fc1\.", "linear1.") .with_key_remapping(r"^fc2\.", "linear2.") .allow_partial(true); let result = store.apply_to::(&mut model); assert!(result.is_ok(), "Failed to load with remapped names"); let result = result.unwrap(); // Verify remapped names were applied let has_linear1 = result.applied.iter().any(|s| s.contains("linear1")); assert!(has_linear1, "Remapped names not applied"); } } #[cfg(test)] mod conv_model_tests { use super::*; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] struct SimpleConvModel { conv1: Conv2d, conv2: Conv2d, } impl SimpleConvModel { pub fn new(device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([3, 16], [3, 3]).init(device), conv2: Conv2dConfig::new([16, 32], [3, 3]).init(device), } } } #[test] fn test_load_conv2d_model() { let device = Default::default(); let path = pytorch_test_path("conv2d", "conv2d.pt"); // Check if file exists, skip if not if !path.exists() { println!("Skipping conv2d test - file not found: {:?}", path); return; } let mut model = SimpleConvModel::::new(&device); let mut store = PytorchStore::from_file(path).allow_partial(true); let result = store.apply_to::(&mut model); if let Ok(result) = result { assert!(!result.applied.is_empty(), "No conv tensors applied"); // Check for conv weights let has_conv_weights = result.applied.iter().any(|s| s.contains("weight")); assert!(has_conv_weights, "Conv weights not loaded"); } } #[test] fn test_load_conv1d_model() { let path = pytorch_test_path("conv1d", "conv1d.pt"); if !path.exists() { println!("Skipping conv1d test - file not found: {:?}", path); return; } // Just test that we can create a store for conv1d files let store = PytorchStore::from_file(path).allow_partial(true); assert!(store.allow_partial); } } #[cfg(test)] mod complex_model_tests { use super::*; type TestBackend = burn_ndarray::NdArray; #[test] fn test_load_with_top_level_key() { let path = test_data_path("checkpoint.pt"); // Just verify that we can create a store with top-level key let store = PytorchStore::from_file(path) .with_top_level_key("model_state_dict") .allow_partial(true); assert_eq!(store.top_level_key, Some("model_state_dict".to_string())); } #[test] fn test_load_nested_structure() { let path = test_data_path("complex_structure.pt"); // Just verify that we can create a store for nested structure let store = PytorchStore::from_file(path).allow_partial(true); assert!(store.allow_partial); } #[test] fn test_legacy_format() { let path = test_data_path("simple_legacy.pt"); if !path.exists() { println!("Skipping legacy format test - file not found: {:?}", path); return; } // Just verify that we can create a store for legacy format let store = PytorchStore::from_file(path).allow_partial(true); assert!(store.allow_partial); // Could load into an actual model if we had legacy model structure } #[test] fn test_key_remap_chained() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping key remap test - file not found: {:?}", path); return; } let device = Default::default(); // Model with different layer names that need remapping #[derive(Module, Debug)] struct RemappedChainModel { convolution1: Linear, // Will be remapped from fc1 linear2: Linear, // Will be remapped from fc2 } let mut model = RemappedChainModel { convolution1: LinearConfig::new(2, 3).init(&device), linear2: LinearConfig::new(3, 4).init(&device), }; // Chain multiple remappings let mut store = PytorchStore::from_file(path) .with_key_remapping(r"^fc1\.", "convolution1.") .with_key_remapping(r"^fc2\.", "linear2.") .allow_partial(true); let result = store.apply_to::(&mut model); if let Ok(result) = result { // Check that remapped names were applied assert!( !result.applied.is_empty(), "No tensors were applied after remapping" ); } } } #[cfg(test)] mod adapter_tests { use super::*; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] pub struct SimpleLinearModel { fc1: Linear, fc2: Linear, } impl SimpleLinearModel { pub fn new(device: &B::Device) -> Self { Self { fc1: LinearConfig::new(2, 3).init(device), fc2: LinearConfig::new(3, 4).init(device), } } } #[test] fn test_pytorch_adapter_always_applied() { // Test that PyTorchToBurnAdapter is always applied internally let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping adapter test - file not found: {:?}", path); return; } let device = Default::default(); let mut model = SimpleLinearModel::::new(&device); let mut store = PytorchStore::from_file(path).allow_partial(true); let result = store.apply_to::(&mut model); // PyTorchToBurnAdapter is always applied internally assert!( result.is_ok(), "Failed to load with internal PyTorchToBurnAdapter: {:?}", result.err() ); assert!(!result.unwrap().applied.is_empty()); } #[test] fn test_pytorch_adapter_with_filtering() { // Test that PyTorchToBurnAdapter works with filtering let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping filtering test - file not found: {:?}", path); return; } let device = Default::default(); let mut model = SimpleLinearModel::::new(&device); // Filter to exclude bias tensors let mut store = PytorchStore::from_file(path) .with_predicate(|path, _| !path.contains("bias")) .allow_partial(true); let result = store.apply_to::(&mut model).unwrap(); // Should not have any bias tensors due to filtering for applied_path in &result.applied { assert!( !applied_path.contains("bias"), "Bias tensor was not filtered: {}", applied_path ); } } } #[cfg(test)] mod error_handling_tests { use super::*; use burn_ndarray::NdArray; #[derive(Module, Debug)] pub struct SimpleLinearModel { fc1: Linear, fc2: Linear, } impl SimpleLinearModel { pub fn new(device: &B::Device) -> Self { Self { fc1: LinearConfig::new(2, 3).init(device), fc2: LinearConfig::new(3, 4).init(device), } } } #[test] fn test_missing_file() { let device = Default::default(); let mut model = SimpleLinearModel::::new(&device); let mut store = PytorchStore::from_file("nonexistent.pth"); let result = store.apply_to::(&mut model); assert!(result.is_err()); match result { Err(crate::pytorch::PytorchStoreError::Reader(_)) => {} _ => panic!("Expected reader error for missing file"), } } #[test] fn test_invalid_top_level_key() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!( "Skipping invalid top level key test - file not found: {:?}", path ); return; } let device = Default::default(); let mut model = SimpleLinearModel::::new(&device); let mut store = PytorchStore::from_file(path).with_top_level_key("nonexistent_key"); let result = store.apply_to::(&mut model); assert!(result.is_err(), "Should fail with invalid top level key"); } #[test] fn test_strict_validation() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!( "Skipping strict validation test - file not found: {:?}", path ); return; } let device = Default::default(); let mut model = SimpleLinearModel::::new(&device); // Apply very restrictive filter that matches nothing let mut store = PytorchStore::from_file(path) .with_regex(r"^this_will_never_match$") .validate(true) .allow_partial(false); let result = store.apply_to::(&mut model); // Should fail because no tensors match and allow_partial is false assert!( result.is_err(), "Should fail when no tensors match with allow_partial=false" ); } } #[cfg(test)] mod enum_variant_tests { use super::*; use crate::ModuleSnapshot; use burn_ndarray::NdArray; /// Enum representing different convolution block types (similar to YOLOX architecture) #[derive(Module, Debug)] pub enum ConvBlock { /// Base convolution block BaseConv(Linear), /// Depthwise separable convolution block DwsConv(Linear), } /// Model with enum field that will have variant names in Burn paths #[derive(Module, Debug)] pub struct ModelWithEnum { /// Feature extractor with enum variants feature: ConvBlock, /// Output classifier classifier: Linear, } impl ModelWithEnum { pub fn new(device: &B::Device) -> Self { Self { feature: ConvBlock::BaseConv(LinearConfig::new(3, 64).init(device)), classifier: LinearConfig::new(64, 10).init(device), } } } #[test] fn test_enum_variant_path_mismatch() { let device = Default::default(); let mut model = ModelWithEnum::::new(&device); // Load PyTorch model that was generated without enum variant names // PyTorch paths: "feature.weight", "feature.bias", "classifier.weight", "classifier.bias" // Burn paths: "feature.BaseConv.weight", "feature.BaseConv.bias", "classifier.weight", "classifier.bias" // ^^^^^^^^ enum variant name is included in Burn but not PyTorch let pytorch_file = store_test_data_path("model_without_enum_variants.pt"); // Try to load from PyTorch format (without enum variants) // Explicitly disable skip_enum_variants to demonstrate the mismatch problem let mut store = PytorchStore::from_file(pytorch_file) .skip_enum_variants(false) // Disable to show the mismatch .allow_partial(true) // Allow partial to see what's missing .validate(false); // Disable validation to get detailed missing info let result = store.apply_to::(&mut model); // The load should succeed (allow_partial=true) but report missing tensors match result { Ok(apply_result) => { // Verify we have missing tensors assert!( !apply_result.missing.is_empty(), "Should have missing tensors due to enum variant path mismatch" ); // Check that missing paths contain enum variants let enum_missing: Vec<_> = apply_result .missing .iter() .filter(|(_, container_stack)| container_stack.contains("Enum:")) .collect(); assert!( !enum_missing.is_empty(), "Missing tensors should be detected as having enum containers" ); // Verify the paths look like what we expect let has_base_conv_path = apply_result .missing .iter() .any(|(path, _)| path.contains("BaseConv")); assert!( has_base_conv_path, "Should have missing paths with 'BaseConv' enum variant. Missing: {:?}", apply_result .missing .iter() .map(|(p, _)| p) .collect::>() ); // Print the diagnostic output to show enum detection println!("\n{}", apply_result); // Verify the diagnostic message mentions enum variants let display_output = format!("{}", apply_result); assert!( display_output.contains("enum variant"), "Display output should mention enum variants" ); } Err(e) => panic!( "Load should succeed with allow_partial=true, got error: {}", e ), } } #[test] fn test_enum_variant_detection_in_container_stack() { let device = Default::default(); // Create model with enum let model = ModelWithEnum::::new(&device); // Collect snapshots to inspect container stacks let snapshots = model.collect(None, None, false); // Find a snapshot from inside the enum let enum_snapshot = snapshots .iter() .find(|s| s.full_path().contains("feature")) .expect("Should have feature snapshots"); // Verify container stack contains enum marker if let Some(container_stack) = &enum_snapshot.container_stack { let container_str = container_stack.join("."); assert!( container_str.contains("Enum:ConvBlock"), "Container stack should contain Enum:ConvBlock marker. Got: {}", container_str ); } else { panic!("Snapshot should have container_stack"); } } #[test] fn test_skip_enum_variants_feature() { let device = Default::default(); let mut model = ModelWithEnum::::new(&device); // Load PyTorch model that was generated without enum variant names // PyTorch paths: "feature.weight", "feature.bias", "classifier.weight", "classifier.bias" // Burn paths: "feature.BaseConv.weight", "feature.BaseConv.bias", "classifier.weight", "classifier.bias" let pytorch_file = store_test_data_path("model_without_enum_variants.pt"); // Try to load with skip_enum_variants enabled let mut store = PytorchStore::from_file(pytorch_file) .skip_enum_variants(true) // Enable enum variant skipping .allow_partial(true) .validate(false); let result = store.apply_to::(&mut model); // The load should succeed and all tensors should be loaded match result { Ok(apply_result) => { println!("\n{}", apply_result); // With skip_enum_variants enabled, we should successfully load the feature tensors let feature_applied = apply_result .applied .iter() .filter(|path| path.contains("feature")) .count(); assert!( feature_applied > 0, "Should have applied feature tensors with skip_enum_variants=true. Applied: {:?}", apply_result.applied ); // The feature tensors should NOT be in missing anymore let feature_missing = apply_result .missing .iter() .filter(|(path, _)| path.contains("feature")) .count(); assert_eq!( feature_missing, 0, "Feature tensors should not be missing with skip_enum_variants=true. Missing: {:?}", apply_result.missing ); } Err(e) => panic!( "Load with skip_enum_variants should succeed, got error: {}", e ), } } } #[cfg(test)] mod direct_access_tests { use super::*; #[test] fn test_get_all_snapshots() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let mut store = PytorchStore::from_file(path); let snapshots = store.get_all_snapshots().unwrap(); // linear.pt should have fc1.weight, fc1.bias, fc2.weight, fc2.bias assert!(!snapshots.is_empty(), "Should have snapshots"); assert!( snapshots.contains_key("fc1.weight"), "Should contain fc1.weight" ); assert!( snapshots.contains_key("fc1.bias"), "Should contain fc1.bias" ); } #[test] fn test_get_snapshot_existing() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let mut store = PytorchStore::from_file(path); // Get existing snapshot let snapshot = store.get_snapshot("fc1.weight").unwrap(); assert!(snapshot.is_some(), "Should find fc1.weight"); let snapshot = snapshot.unwrap(); // Linear weight should be 2D assert_eq!(snapshot.shape.len(), 2, "Weight should be 2D tensor"); // Verify we can load data let data = snapshot.to_data().unwrap(); assert!(!data.bytes.is_empty(), "Data should not be empty"); } #[test] fn test_get_snapshot_not_found() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let mut store = PytorchStore::from_file(path); // Get non-existent snapshot let snapshot = store.get_snapshot("nonexistent.weight").unwrap(); assert!(snapshot.is_none(), "Should not find nonexistent tensor"); } #[test] fn test_keys() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let mut store = PytorchStore::from_file(path); let keys = store.keys().unwrap(); assert!(!keys.is_empty(), "Should have keys"); assert!( keys.contains(&"fc1.weight".to_string()), "Keys should contain fc1.weight" ); assert!( keys.contains(&"fc1.bias".to_string()), "Keys should contain fc1.bias" ); } #[test] fn test_keys_fast_path() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } // Create fresh store - cache should be empty let mut store = PytorchStore::from_file(&path); // keys() should work without populating the full cache (fast path) let keys = store.keys().unwrap(); assert!(!keys.is_empty(), "Should have keys via fast path"); // Now call get_all_snapshots to populate cache let snapshots = store.get_all_snapshots().unwrap(); assert!(!snapshots.is_empty(), "Should have snapshots"); // keys() should now use the cached data let keys2 = store.keys().unwrap(); assert_eq!(keys.len(), keys2.len(), "Keys count should match"); } #[test] fn test_caching_behavior() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let mut store = PytorchStore::from_file(path); // First call populates cache let snapshots1 = store.get_all_snapshots().unwrap(); let count1 = snapshots1.len(); // Second call uses cache let snapshots2 = store.get_all_snapshots().unwrap(); let count2 = snapshots2.len(); assert_eq!(count1, count2, "Cached results should match"); } #[test] fn test_get_all_snapshots_with_remapping() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } // Create store with key remapping let mut store = PytorchStore::from_file(path).with_key_remapping(r"^fc1\.", "linear1."); let snapshots = store.get_all_snapshots().unwrap(); // Should have remapped keys assert!( snapshots.contains_key("linear1.weight"), "Should contain remapped key linear1.weight. Keys: {:?}", snapshots.keys().collect::>() ); assert!( snapshots.contains_key("linear1.bias"), "Should contain remapped key linear1.bias" ); // Original keys should not exist assert!( !snapshots.contains_key("fc1.weight"), "Should not contain original key fc1.weight" ); } #[test] fn test_get_snapshot_with_remapped_name() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } // Create store with key remapping let mut store = PytorchStore::from_file(path).with_key_remapping(r"^fc1\.", "linear1."); // Should find by remapped name let snapshot = store.get_snapshot("linear1.weight").unwrap(); assert!(snapshot.is_some(), "Should find tensor by remapped name"); // Should NOT find by original name let snapshot_orig = store.get_snapshot("fc1.weight").unwrap(); assert!( snapshot_orig.is_none(), "Should not find tensor by original name after remapping" ); } #[test] fn test_get_all_snapshots_ignores_filter() { let path = pytorch_test_path("linear", "linear.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } // Create store with filter that only matches fc1 let mut store = PytorchStore::from_file(path).with_regex(r"^fc1\."); // get_all_snapshots should return ALL tensors regardless of filter let snapshots = store.get_all_snapshots().unwrap(); // Should have both fc1 and fc2 tensors assert!( snapshots.contains_key("fc1.weight"), "Should contain fc1.weight" ); assert!( snapshots.contains_key("fc2.weight"), "Should contain fc2.weight (filter not applied to get_all_snapshots)" ); } } /// Tests for contiguous index mapping feature #[cfg(test)] mod map_indices_contiguous_tests { use super::*; type TestBackend = burn_ndarray::NdArray; /// Model with a Vec of Conv2d layers that expects contiguous indices #[derive(Module, Debug)] struct SequentialConvModel { fc: Vec>, } impl SequentialConvModel { pub fn new(device: &B::Device, num_layers: usize) -> Self { Self { fc: (0..num_layers) .map(|_| { Conv2dConfig::new([2, 2], [3, 3]) .with_bias(true) .init(device) }) .collect(), } } } #[test] fn test_load_non_contiguous_indexes_with_mapping() { // This test uses the non_contiguous_indexes.pt file which has: // fc.0.weight, fc.0.bias, fc.2.weight, fc.2.bias, fc.4.weight, ... (non-contiguous) // The Burn model expects fc.0, fc.1, fc.2, ... (contiguous) let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let device = Default::default(); // Create model with 5 conv layers (matching the PyTorch model) let mut model = SequentialConvModel::::new(&device, 5); // Load with contiguous index mapping enabled (default) let mut store = PytorchStore::from_file(&path) .map_indices_contiguous(true) .allow_partial(true) .validate(false); let result = store.apply_to::(&mut model); match result { Ok(apply_result) => { println!("Applied tensors: {:?}", apply_result.applied); println!("Missing tensors: {:?}", apply_result.missing); println!("Unused tensors: {:?}", apply_result.unused); // All fc layers should be loaded successfully assert!( !apply_result.applied.is_empty(), "Should have applied tensors" ); // Verify we have tensors from all 5 layers // With mapping: fc.0, fc.1, fc.2, fc.3, fc.4 for i in 0..5 { let has_weight = apply_result .applied .iter() .any(|p| p.contains(&format!("fc.{}.weight", i))); let has_bias = apply_result .applied .iter() .any(|p| p.contains(&format!("fc.{}.bias", i))); assert!( has_weight, "Should have applied fc.{}.weight, applied: {:?}", i, apply_result.applied ); assert!( has_bias, "Should have applied fc.{}.bias, applied: {:?}", i, apply_result.applied ); } // There should be no missing tensors (assuming model matches) let missing_fc: Vec<_> = apply_result .missing .iter() .filter(|(p, _)| p.starts_with("fc.")) .collect(); assert!( missing_fc.is_empty(), "Should have no missing fc tensors with index mapping. Missing: {:?}", missing_fc ); } Err(e) => panic!("Failed to load with index mapping: {}", e), } } #[test] fn test_load_non_contiguous_indexes_without_mapping() { // This test verifies that loading fails or has missing tensors when // map_indices_contiguous is disabled let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } let device = Default::default(); // Create model with 5 conv layers let mut model = SequentialConvModel::::new(&device, 5); // Load with contiguous index mapping DISABLED let mut store = PytorchStore::from_file(&path) .map_indices_contiguous(false) // Disable index mapping .allow_partial(true) .validate(false); let result = store.apply_to::(&mut model); match result { Ok(apply_result) => { println!( "Without index mapping - Applied tensors: {:?}", apply_result.applied ); println!( "Without index mapping - Missing tensors: {:?}", apply_result.missing ); // Without index mapping, we should have missing tensors for fc.1, fc.3 // because the source has fc.0, fc.2, fc.4, fc.6, fc.8 but model expects fc.0-4 let missing_fc: Vec<_> = apply_result .missing .iter() .filter(|(p, _)| p.starts_with("fc.")) .collect(); assert!( !missing_fc.is_empty(), "Should have missing fc tensors without index mapping (indices 1, 3 don't exist in file)" ); // Specifically, fc.1 and fc.3 should be missing let has_fc1_missing = apply_result .missing .iter() .any(|(p, _)| p.starts_with("fc.1.")); let has_fc3_missing = apply_result .missing .iter() .any(|(p, _)| p.starts_with("fc.3.")); assert!( has_fc1_missing || has_fc3_missing, "Should have fc.1 or fc.3 missing. Missing: {:?}", apply_result.missing ); } Err(e) => panic!("Unexpected error: {}", e), } } #[test] fn test_mapping_applied_to_keys() { // Verify that the keys returned by the store are mapped let path = pytorch_test_path("non_contiguous_indexes", "non_contiguous_indexes.pt"); if !path.exists() { println!("Skipping test - file not found: {:?}", path); return; } // With index mapping enabled (default) let mut store_mapped = PytorchStore::from_file(&path).map_indices_contiguous(true); let keys_mapped = store_mapped.keys().unwrap(); println!("Keys with index mapping: {:?}", keys_mapped); // Should have contiguous keys: fc.0, fc.1, fc.2, fc.3, fc.4 assert!( keys_mapped.iter().any(|k| k.starts_with("fc.1.")), "With index mapping, should have fc.1 (from fc.2)" ); assert!( keys_mapped.iter().any(|k| k.starts_with("fc.2.")), "With index mapping, should have fc.2 (from fc.4)" ); // Without index mapping let mut store_no_mapping = PytorchStore::from_file(&path).map_indices_contiguous(false); let keys_no_mapping = store_no_mapping.keys().unwrap(); println!("Keys without index mapping: {:?}", keys_no_mapping); // Should have original non-contiguous keys: fc.0, fc.2, fc.4, fc.6, fc.8 assert!( keys_no_mapping.iter().any(|k| k.starts_with("fc.2.")), "Without index mapping, should have original fc.2" ); assert!( keys_no_mapping.iter().any(|k| k.starts_with("fc.4.")), "Without index mapping, should have original fc.4" ); assert!( !keys_no_mapping.iter().any(|k| k.starts_with("fc.1.")), "Without index mapping, should NOT have fc.1 (not in original file)" ); } } ================================================ FILE: crates/burn-store/src/pytorch/tests/store/test_data/generate_enum_test.py ================================================ #!/usr/bin/env python3 """ Generate PyTorch test data for enum variant path mismatch testing. This script creates a PyTorch checkpoint that simulates how PyTorch models export their state dicts WITHOUT enum variant names in the paths. Example: - PyTorch path: "feature.weight" - Burn path: "feature.BaseConv.weight" (includes enum variant "BaseConv") Run with: uv run generate_enum_test.py """ import torch import torch.nn as nn class SimpleModel(nn.Module): """ Simple PyTorch model that represents what a Burn enum model would look like WITHOUT the enum variant names in the path. In Burn, this would be: struct ModelWithEnum { feature: ConvBlock, // enum with BaseConv, DwsConv variants classifier: Linear, } But PyTorch exports it as flat paths without the enum variant names. """ def __init__(self): super().__init__() # This represents the "feature" field which is an enum in Burn # PyTorch doesn't have enums, so it's just a Linear layer # Path will be: "feature.weight" and "feature.bias" self.feature = nn.Linear(3, 64) # This represents the "classifier" field # Path will be: "classifier.weight" and "classifier.bias" self.classifier = nn.Linear(64, 10) def forward(self, x): x = self.feature(x) x = torch.relu(x) x = self.classifier(x) return x def generate_enum_variant_mismatch_test(): """Generate test file demonstrating enum variant path mismatch.""" model = SimpleModel() # Initialize with some deterministic weights for testing torch.manual_seed(42) for param in model.parameters(): param.data.normal_(0, 0.1) # Save the state dict # PyTorch paths: "feature.weight", "feature.bias", "classifier.weight", "classifier.bias" # Burn paths: "feature.BaseConv.weight", "feature.BaseConv.bias", ... # ^^^^^^^^ enum variant is missing in PyTorch torch.save(model.state_dict(), "model_without_enum_variants.pt") print("Generated: model_without_enum_variants.pt") print("\nPyTorch state dict keys:") for key in model.state_dict().keys(): shape = tuple(model.state_dict()[key].shape) print(f" {key}: {shape}") print("\nExpected Burn paths (with enum variant):") print(" feature.BaseConv.weight: (3, 64)") print(" feature.BaseConv.bias: (64,)") print(" classifier.weight: (64, 10)") print(" classifier.bias: (10,)") print("\n⚠️ Notice: Burn includes 'BaseConv' enum variant, PyTorch doesn't!") if __name__ == "__main__": generate_enum_variant_mismatch_test() ================================================ FILE: crates/burn-store/src/safetensors/mod.rs ================================================ //! SafeTensors format support for Burn deep learning framework. //! //! [SafeTensors](https://github.com/huggingface/safetensors) is a simple, safe, and efficient format //! for storing and loading tensors. It provides fast zero-copy deserialization and strong safety //! guarantees, making it ideal for production environments. //! //! # Features //! //! - **Fast Loading**: Zero-copy tensor access using safetensors' built-in mechanisms //! - **Safety**: Prevents arbitrary code execution during model loading //! - **Efficiency**: Memory-mapped files enable lazy loading without reading entire file //! - **Filtering**: Load only specific tensors using path filters //! - **Remapping**: Transform tensor names during load/save operations //! - **Metadata**: Store and retrieve custom metadata alongside tensors (automatic `format`, `producer` and `version` metadata included) //! - **Cross-Platform**: Works on all platforms including no-std environments //! //! # Usage Examples //! //! ## Basic Save and Load //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save a model to a file //! let mut store = SafetensorsStore::from_file("model.safetensors"); //! model.save_into(&mut store)?; //! //! // Load a model from a file //! let mut store = SafetensorsStore::from_file("model.safetensors"); //! let mut model = Model::new(&device); //! model.load_from(&mut store)?; //! ``` //! //! ## Memory-Based Operations //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save to memory buffer //! let mut store = SafetensorsStore::from_bytes(None); //! model.save_into(&mut store)?; //! let bytes = store.get_bytes()?; //! //! // Load from memory buffer //! let mut store = SafetensorsStore::from_bytes(Some(bytes)); //! let mut model = Model::new(&device); //! model.load_from(&mut store)?; //! ``` //! //! ## Advanced Features //! //! ### Filter Configuration with Builder Pattern //! //! ```rust,no_run //! # use burn_store::SafetensorsStore; //! // Filter with regex patterns (OR logic - matches any pattern) //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_regex(r"^encoder\..*") // Match all encoder tensors //! .with_regex(r".*\.bias$"); // OR match any bias tensors //! //! // Filter with exact paths //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_full_path("encoder.weight") //! .with_full_path("encoder.bias") //! .with_full_paths(vec!["decoder.scale", "decoder.norm"]); //! //! // Custom filter logic with predicate //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_predicate(|path, _dtype| { //! // Only save layer weights (not biases) //! path.contains("layer") && path.ends_with("weight") //! }); //! //! // Combine multiple filter methods //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_regex(r"^encoder\..*") // All encoder tensors //! .with_full_path("decoder.scale") // Plus specific decoder.scale //! .with_predicate(|path, _| { // Plus any projection tensors //! path.contains("projection") //! }); //! //! // Save or load all tensors (no filtering) //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .match_all(); //! ``` //! //! ### Tensor Name Remapping //! //! Remap tensor names during load/save operations for compatibility between different frameworks: //! //! ```rust,no_run //! # use burn_store::{SafetensorsStore, KeyRemapper}; //! // Using builder pattern for common remapping patterns //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X //! .with_key_remapping(r"\.gamma$", ".weight") // X.gamma -> X.weight //! .with_key_remapping(r"\.beta$", ".bias"); // X.beta -> X.bias //! //! // Or using a pre-configured KeyRemapper for complex transformations //! let remapper = KeyRemapper::new() //! .add_pattern(r"^pytorch\.(.*)", "burn.$1").expect("valid regex") // pytorch.layer -> burn.layer //! .add_pattern(r"^(.*)\.running_mean$", "$1.mean").expect("valid regex") // layer.running_mean -> layer.mean //! .add_pattern(r"^(.*)\.running_var$", "$1.variance").expect("valid regex"); // layer.running_var -> layer.variance //! //! let mut store = SafetensorsStore::from_file("model.safetensors") //! .remap(remapper); //! ``` //! //! ### Framework Adapters //! //! Use adapters for automatic framework-specific transformations: //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter, BurnToPyTorchAdapter}; //! //! // Loading PyTorch model into Burn //! let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") //! .with_from_adapter(PyTorchToBurnAdapter) // Transposes linear weights, renames norm params //! .allow_partial(true); // PyTorch models may have extra tensors //! //! let mut burn_model = Model::new(&device); //! burn_model.load_from(&mut store)?; //! //! // Saving Burn model for PyTorch //! let mut store = SafetensorsStore::from_file("for_pytorch.safetensors") //! .with_to_adapter(BurnToPyTorchAdapter); // Transposes weights back, renames for PyTorch //! //! burn_model.save_into(&mut store)?; //! ``` //! //! ### Additional Configuration Options //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! let mut store = SafetensorsStore::from_file("model.safetensors") //! // Add custom metadata //! .metadata("version", "1.0.0") //! .metadata("producer", "burn") //! // Allow partial loading (continue even if some tensors are missing) //! .allow_partial(true) //! // Disable validation for faster loading //! .validate(false); //! //! // Use the configured store //! model.save_into(&mut store)?; // For saving //! // or //! model.load_from(&mut store)?; // For loading //! ``` //! //! # Efficient Loading with SafeTensors //! //! SafeTensors provides efficient tensor loading through its zero-copy design: //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! let mut store = SafetensorsStore::from_file("large_model.safetensors"); //! // Uses memory mapping (when available) for zero-copy access //! // Falls back to buffered reading when mmap is not available //! let mut model = Model::new(&device); //! model.load_from(&mut store)?; //! ``` //! //! The safetensors approach provides: //! - Zero-copy views - tensors are accessed directly from the mapped file //! - Lazy loading - only accessed tensors are materialized //! - Efficient memory usage - no unnecessary data duplication //! //! # Lazy Loading and Inspection //! //! SafeTensors provides efficient inspection and selective loading through its //! zero-copy design and built-in metadata handling: //! //! ```rust,ignore //! use burn_store::SafetensorsStore; //! //! // Open a file - uses safetensors' efficient header reading //! let store = SafetensorsStore::from_file("large_model.safetensors"); //! //! // List all tensor names from the metadata //! let tensor_names = store.list_tensors()?; //! println!("Model contains {} tensors", tensor_names.len()); //! //! // Get tensor metadata without loading tensor data //! if let Some((shape, dtype)) = store.tensor_info("encoder.weight")? { //! println!("Encoder weight shape: {:?}, dtype: {:?}", shape, dtype); //! } //! //! // Selectively load tensors - safetensors handles efficient access //! let encoder_tensors = store.load_tensors(&[ //! "encoder.weight", //! "encoder.bias", //! "encoder.norm" //! ])?; //! //! // Distributed loading: each worker loads only its assigned layers //! // SafeTensors' zero-copy views ensure minimal memory usage //! let worker_layers = match worker_id { //! 0 => vec!["encoder.layer1", "encoder.layer2"], //! 1 => vec!["encoder.layer3", "encoder.layer4"], //! 2 => vec!["decoder.layer1", "decoder.layer2"], //! _ => vec!["head.weight", "head.bias"], //! }; //! let worker_tensors = store.load_tensors(&worker_layers)?; //! ``` //! //! # Builder Pattern API Reference //! //! The SafetensorsStore provides a fluent builder API for configuration: //! //! ## Filtering Methods //! //! - **`with_regex(pattern)`** - Add regex pattern to match tensor names (OR logic with multiple patterns) //! - **`with_full_path(path)`** - Add exact tensor path to include //! - **`with_full_paths(paths)`** - Add multiple exact tensor paths to include //! - **`with_predicate(fn)`** - Add custom filter function `fn(&str, &str) -> bool` //! - **`match_all()`** - Disable filtering, include all tensors //! //! ## Remapping Methods //! //! - **`with_key_remapping(from, to)`** - Add regex pattern to rename tensors //! - **`remap(KeyRemapper)`** - Use a pre-configured KeyRemapper for complex transformations //! //! ## Adapter Methods //! //! - **`with_from_adapter(adapter)`** - Set adapter for loading (e.g., PyTorchToBurnAdapter) //! - **`with_to_adapter(adapter)`** - Set adapter for saving (e.g., BurnToPyTorchAdapter) //! //! ## Configuration Methods //! //! - **`metadata(key, value)`** - Add custom metadata to saved files (in addition to automatic `format`, `producer` and `version`) //! - **`allow_partial(bool)`** - Allow loading even if some tensors are missing //! - **`validate(bool)`** - Enable/disable tensor validation during loading //! //! All methods return `Self` for chaining: //! //! ```rust,no_run //! use burn_store::{SafetensorsStore, PyTorchToBurnAdapter}; //! //! let store = SafetensorsStore::from_file("model.safetensors") //! .with_regex(r"^encoder\..*") //! .with_key_remapping(r"\.gamma$", ".weight") //! .with_from_adapter(PyTorchToBurnAdapter) //! .allow_partial(true) //! .metadata("version", "2.0"); //! ``` //! //! # Working with Bytes //! //! For direct byte operations without files: //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot}; //! //! // Save to bytes with filtering and remapping //! let mut store = SafetensorsStore::from_bytes(None) //! .with_regex(r"^encoder\..*") // Only save encoder tensors //! .with_key_remapping(r"^encoder\.", "transformer.") // Rename encoder.X -> transformer.X //! .metadata("subset", "encoder_only"); //! model.save_into(&mut store)?; //! let bytes = store.get_bytes()?; //! //! // Load from bytes (allow partial since we only saved encoder) //! let mut store = SafetensorsStore::from_bytes(Some(bytes)) //! .with_key_remapping(r"^transformer\.", "encoder.") // Rename back: transformer.X -> encoder.X //! .allow_partial(true); //! let mut model = Model::new(&device); //! let result = model.load_from(&mut store)?; //! println!("Applied {} tensors", result.applied.len()); //! ``` //! //! # Complete Example: PyTorch Model Migration //! //! Migrating a PyTorch model to Burn with filtering, remapping, and adapters: //! //! ```rust,ignore //! use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter}; //! //! // Load PyTorch model with all transformations //! let mut store = SafetensorsStore::from_file("pytorch_model.safetensors") //! // Use PyTorch adapter for automatic transformations //! .with_from_adapter(PyTorchToBurnAdapter) //! // Only load transformer layers //! .with_regex(r"^transformer\..*") //! // Rename old layer names to new structure //! .with_key_remapping(r"^transformer\.h\.(\d+)\.", "transformer.layer$1.") //! // Skip unexpected tensors from PyTorch //! .allow_partial(true) //! // Add metadata about the conversion //! .metadata("source", "pytorch") //! .metadata("converted_by", "burn-store"); //! //! let mut model = TransformerModel::new(&device); //! let result = model.load_from(&mut store)?; //! //! println!("Successfully loaded {} tensors", result.applied.len()); //! if !result.missing.is_empty() { //! println!("Missing tensors: {:?}", result.missing); //! } //! ``` //! //! # Format Details //! //! SafeTensors uses a simple binary format: //! - **8 bytes**: Header size (unsigned little-endian 64-bit integer) //! - **N bytes**: JSON header with tensor metadata //! - Contains: `{"tensor_name": {"dtype": "F32", "shape": [1, 2, 3], "data_offsets": [start, end]}, ...}` //! - Special key `__metadata__` for user-defined string metadata //! - **Rest**: Raw tensor data (referenced by offsets in header) //! //! The format enables: //! - **Secure loading**: No code execution, just data //! - **Efficient access**: Use offsets to read only needed tensors //! - **Simple parsing**: Standard JSON header with fixed structure mod store; pub use store::{SafetensorsStore, SafetensorsStoreError}; #[cfg(test)] mod tests; ================================================ FILE: crates/burn-store/src/safetensors/store.rs ================================================ //! SafeTensors store implementation using the official safetensors crate. use crate::{ ApplyResult, IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot, }; #[cfg(feature = "std")] use crate::{KeyRemapper, map_indices_contiguous}; use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; use alloc::vec::Vec; use burn_core::module::ParamId; use burn_tensor::backend::Backend; use burn_tensor::{BoolStore, DType, TensorData}; use core::fmt; use core::ops::Deref; use hashbrown::HashMap; // Arc is only available on targets with atomic pointers #[cfg(target_has_atomic = "ptr")] use alloc::sync::Arc; // For targets without atomic pointers, we use Box instead #[cfg(not(target_has_atomic = "ptr"))] type Arc = Box; /// Errors that can occur during SafeTensors operations. #[derive(Debug)] pub enum SafetensorsStoreError { /// SafeTensors crate error. Safetensors(safetensors::SafeTensorError), /// I/O error. #[cfg(feature = "std")] Io(std::io::Error), /// Tensor not found. TensorNotFound(String), /// Validation failed. ValidationFailed(String), /// Other error. Other(String), } impl fmt::Display for SafetensorsStoreError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Safetensors(e) => write!(f, "SafeTensors error: {}", e), #[cfg(feature = "std")] Self::Io(e) => write!(f, "I/O error: {}", e), Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name), Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg), Self::Other(msg) => write!(f, "{}", msg), } } } impl core::error::Error for SafetensorsStoreError {} impl From for SafetensorsStoreError { fn from(e: safetensors::SafeTensorError) -> Self { SafetensorsStoreError::Safetensors(e) } } #[cfg(feature = "std")] impl From for SafetensorsStoreError { fn from(e: std::io::Error) -> Self { SafetensorsStoreError::Io(e) } } /// SafeTensors store supporting both file and memory storage. pub enum SafetensorsStore { /// File-based storage. #[cfg(feature = "std")] File(FileStore), /// Memory-based storage. Memory(MemoryStore), } impl Default for SafetensorsStore { /// Create a default memory-based store. fn default() -> Self { Self::from_bytes(None) } } impl SafetensorsStore { /// Get the default metadata that includes Burn framework information. /// /// This includes: /// - `format`: "safetensors" /// - `producer`: "burn" /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION) /// /// These metadata fields are automatically added to all saved models. pub fn default_metadata() -> HashMap { let mut metadata = HashMap::new(); metadata.insert("format".to_string(), "safetensors".to_string()); metadata.insert("producer".to_string(), "burn".to_string()); metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string()); metadata } /// Create a store for loading from or saving to a file. #[cfg(feature = "std")] pub fn from_file(path: impl Into) -> Self { Self::File(FileStore { path: path.into(), filter: PathFilter::new(), remapper: KeyRemapper::new(), metadata: Self::default_metadata(), validate: true, allow_partial: false, overwrite: false, skip_enum_variants: false, // Contiguous index mapping is off by default for SafeTensors // (SafeTensors files typically have clean, contiguous indices) map_indices_contiguous: false, from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), snapshots_cache: None, }) } /// Create a store for working with bytes in memory. pub fn from_bytes(bytes: Option>) -> Self { Self::Memory(MemoryStore { data: bytes.map(Arc::new), filter: PathFilter::new(), #[cfg(feature = "std")] remapper: KeyRemapper::new(), metadata: Self::default_metadata(), validate: true, allow_partial: false, skip_enum_variants: false, // Contiguous index mapping is off by default for SafeTensors #[cfg(feature = "std")] map_indices_contiguous: false, from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), snapshots_cache: None, }) } /// Filter which tensors to load/save. pub fn filter(mut self, filter: PathFilter) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = filter, Self::Memory(p) => p.filter = filter, } self } /// Add a regex pattern to filter tensors. /// /// Multiple patterns can be added and they work with OR logic. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_regex(r"^encoder\..*") // Match all encoder tensors /// .with_regex(r".*\.weight$"); // OR match any weight tensors /// ``` #[cfg(feature = "std")] pub fn with_regex>(mut self, pattern: S) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_regex(pattern), Self::Memory(p) => p.filter = p.filter.clone().with_regex(pattern), } self } /// Add multiple regex patterns to filter tensors. #[cfg(feature = "std")] pub fn with_regexes(mut self, patterns: I) -> Self where I: IntoIterator, S: AsRef, { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_regexes(patterns), Self::Memory(p) => p.filter = p.filter.clone().with_regexes(patterns), } self } /// Add an exact full path to match. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_full_path("encoder.layer1.weight") /// .with_full_path("decoder.output.bias"); /// ``` pub fn with_full_path>(mut self, path: S) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_full_path(path), Self::Memory(p) => p.filter = p.filter.clone().with_full_path(path), } self } /// Add multiple exact full paths to match. pub fn with_full_paths(mut self, paths: I) -> Self where I: IntoIterator, S: Into, { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_full_paths(paths), Self::Memory(p) => p.filter = p.filter.clone().with_full_paths(paths), } self } /// Add a predicate function for custom filtering logic. /// /// The predicate receives the tensor path and container path. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias")); /// ``` pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_predicate(predicate), Self::Memory(p) => p.filter = p.filter.clone().with_predicate(predicate), } self } /// Add multiple predicate functions. pub fn with_predicates(mut self, predicates: I) -> Self where I: IntoIterator bool>, { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().with_predicates(predicates), Self::Memory(p) => p.filter = p.filter.clone().with_predicates(predicates), } self } /// Set the filter to match all paths (disables filtering). pub fn match_all(mut self) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.filter = p.filter.clone().match_all(), Self::Memory(p) => p.filter = p.filter.clone().match_all(), } self } /// Remap tensor names during load/save. #[cfg(feature = "std")] pub fn remap(mut self, remapper: KeyRemapper) -> Self { match &mut self { Self::File(p) => p.remapper = remapper, Self::Memory(p) => p.remapper = remapper, } self } /// Add a regex pattern to remap tensor names during load/save. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// let store = SafetensorsStore::from_file("model.safetensors") /// .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X /// .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight /// ``` #[cfg(feature = "std")] pub fn with_key_remapping( mut self, from_pattern: impl AsRef, to_pattern: impl Into, ) -> Self { match &mut self { Self::File(p) => { p.remapper = p .remapper .clone() .add_pattern(from_pattern, to_pattern) .expect("Invalid regex pattern"); } Self::Memory(p) => { p.remapper = p .remapper .clone() .add_pattern(from_pattern, to_pattern) .expect("Invalid regex pattern"); } } self } /// Add metadata to be saved with the tensors. pub fn metadata(mut self, key: impl Into, value: impl Into) -> Self { let key = key.into(); let value = value.into(); match &mut self { #[cfg(feature = "std")] Self::File(p) => { p.metadata.insert(key, value); } Self::Memory(p) => { p.metadata.insert(key, value); } } self } /// Clear all metadata including the default Burn framework metadata. /// /// This removes the automatic `format`, `producer` and `version` fields. /// Use this when you need complete control over metadata or when /// saving models for use with other frameworks. pub fn clear_metadata(mut self) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => { p.metadata.clear(); } Self::Memory(p) => { p.metadata.clear(); } } self } /// Set whether to validate tensors during loading (default: true). pub fn validate(mut self, validate: bool) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.validate = validate, Self::Memory(p) => p.validate = validate, } self } /// Allow partial loading of tensors (continue even if some tensors are missing). pub fn allow_partial(mut self, allow: bool) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.allow_partial = allow, Self::Memory(p) => p.allow_partial = allow, } self } /// Skip enum variant names when loading or saving tensor paths. /// /// When enabled during **loading**, tensor paths from the source that don't include enum variants /// can be matched against Burn module paths that do include them. /// For example, source path "feature.weight" can match Burn path "feature.BaseConv.weight". /// /// When enabled during **saving**, enum variant names are omitted from the exported tensor paths, /// making them compatible with PyTorch naming conventions. /// For example, "feature.BaseConv.weight" becomes "feature.weight" in the exported file. /// /// This is useful when working with models from/to formats that don't include enum variant /// names in their parameter paths (like PyTorch models). /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// // For PyTorch compatibility /// let store = SafetensorsStore::from_file("model.safetensors") /// .skip_enum_variants(true); /// ``` pub fn skip_enum_variants(mut self, skip: bool) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.skip_enum_variants = skip, Self::Memory(p) => p.skip_enum_variants = skip, } self } /// Enable or disable automatic contiguous mapping of layer indices (default: false). /// /// When enabled, non-contiguous numeric indices in tensor paths are renumbered /// to be contiguous. This is useful when loading models that have gaps /// in layer numbering, such as PyTorch models using `nn.Sequential` with mixed /// layer types (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5). /// /// # Example /// /// With index mapping enabled: /// - `fc.0.weight` → `fc.0.weight` /// - `fc.2.weight` → `fc.1.weight` (gap filled) /// - `fc.4.weight` → `fc.2.weight` (gap filled) /// /// # Arguments /// /// * `map` - `true` to enable contiguous index mapping, `false` to disable /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// // Enable contiguous index mapping for PyTorch-exported safetensors /// let store = SafetensorsStore::from_file("model.safetensors") /// .map_indices_contiguous(true); /// ``` #[cfg(feature = "std")] pub fn map_indices_contiguous(mut self, map: bool) -> Self { match &mut self { Self::File(p) => p.map_indices_contiguous = map, Self::Memory(p) => p.map_indices_contiguous = map, } self } /// Set whether to overwrite existing files when saving (default: false). /// /// When set to `false`, attempting to save to an existing file will result in an error. /// When set to `true`, existing files will be overwritten without warning. /// /// This setting only applies to file-based stores. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// let mut store = SafetensorsStore::from_file("model.safetensors") /// .overwrite(true); /// // Will overwrite if file exists when saving /// ``` #[cfg(feature = "std")] pub fn overwrite(mut self, overwrite: bool) -> Self { match &mut self { Self::File(p) => p.overwrite = overwrite, Self::Memory(_) => { // Memory stores don't have overwrite semantics, ignore } } self } /// Set the adapter for loading tensors (converting from source format to Burn). pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.from_adapter = Box::new(adapter), Self::Memory(p) => p.from_adapter = Box::new(adapter), } self } /// Set the adapter for saving tensors (converting from Burn to target format). pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self { match &mut self { #[cfg(feature = "std")] Self::File(p) => p.to_adapter = Box::new(adapter), Self::Memory(p) => p.to_adapter = Box::new(adapter), } self } /// Get saved bytes from memory-based store. /// /// # Example /// ```rust,no_run /// # use burn_store::SafetensorsStore; /// # fn example() -> Result<(), Box> { /// let mut store = SafetensorsStore::from_bytes(None); /// // After saving model with collect_to()... /// let bytes = store.get_bytes()?; /// # Ok(()) /// # } /// ``` pub fn get_bytes(&self) -> Result, SafetensorsStoreError> { match self { #[cfg(feature = "std")] Self::File(_) => Err(SafetensorsStoreError::Other( "Cannot get bytes from file-based store".to_string(), )), Self::Memory(p) => p .data() .map(|arc| arc.as_ref().clone()) .ok_or_else(|| SafetensorsStoreError::Other("No data available".to_string())), } } } /// File-based store. #[cfg(feature = "std")] pub struct FileStore { path: std::path::PathBuf, filter: PathFilter, remapper: KeyRemapper, metadata: HashMap, validate: bool, allow_partial: bool, overwrite: bool, skip_enum_variants: bool, /// Enable contiguous mapping of layer indices (default: false) map_indices_contiguous: bool, from_adapter: Box, to_adapter: Box, /// Cached tensor snapshots (parsed once, reused) snapshots_cache: Option>, } /// Memory-based store. pub struct MemoryStore { data: Option>>, filter: PathFilter, #[cfg(feature = "std")] remapper: KeyRemapper, metadata: HashMap, validate: bool, allow_partial: bool, skip_enum_variants: bool, /// Enable contiguous mapping of layer indices (default: false) #[cfg(feature = "std")] map_indices_contiguous: bool, from_adapter: Box, to_adapter: Box, /// Cached tensor snapshots (parsed once, reused) snapshots_cache: Option>, } impl Default for MemoryStore { fn default() -> Self { Self { data: None, filter: PathFilter::new(), #[cfg(feature = "std")] remapper: KeyRemapper::new(), metadata: HashMap::new(), validate: true, allow_partial: false, skip_enum_variants: false, #[cfg(feature = "std")] map_indices_contiguous: false, from_adapter: Box::new(IdentityAdapter), to_adapter: Box::new(IdentityAdapter), snapshots_cache: None, } } } impl MemoryStore { #[cfg(test)] pub(crate) fn data(&self) -> Option>> { self.data.clone() } #[cfg(not(test))] fn data(&self) -> Option>> { self.data.clone() } #[cfg(test)] pub(crate) fn set_data(&mut self, data: Vec) { self.data = Some(Arc::new(data)); } } // Adapter to use TensorSnapshot directly with safetensors #[derive(Debug)] struct TensorSnapshotAdapter(TensorSnapshot); impl safetensors::View for TensorSnapshotAdapter { fn dtype(&self) -> safetensors::Dtype { // Convert from burn dtype to safetensors dtype dtype_to_safetensors(self.0.dtype).unwrap_or(safetensors::Dtype::F32) } fn shape(&self) -> &[usize] { &self.0.shape } fn data(&self) -> alloc::borrow::Cow<'_, [u8]> { // Only materialize data when actually needed for serialization let data = self .0 .to_data() .unwrap_or_else(|e| panic!("Failed to get tensor data: {:?}", e)); alloc::borrow::Cow::Owned(data.bytes.deref().to_vec()) } fn data_len(&self) -> usize { // Use the efficient data_len method from TensorSnapshot self.0.data_len() } } impl ModuleStore for SafetensorsStore { type Error = SafetensorsStoreError; fn collect_from>( &mut self, module: &M, ) -> Result<(), Self::Error> { // Invalidate cache since we're writing new data match self { #[cfg(feature = "std")] Self::File(p) => p.snapshots_cache = None, Self::Memory(p) => p.snapshots_cache = None, } // Collect tensor snapshots from module with adapter // The to_adapter converts from Burn format to target format for saving let to_adapter = match self { #[cfg(feature = "std")] Self::File(p) => p.to_adapter.clone(), Self::Memory(p) => p.to_adapter.clone(), }; let mut snapshots = module.collect(None, Some(to_adapter), self.get_skip_enum_variants()); // Apply filtering snapshots = apply_filter(snapshots, self.get_filter()); // Apply remapping #[cfg(feature = "std")] { snapshots = apply_remapping(snapshots, self.get_remapper()); } // Get metadata (already includes format, producer and version from default_metadata) let metadata = self.get_metadata().clone(); #[cfg(feature = "std")] let std_metadata: std::collections::HashMap = metadata .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(); // Write to storage match self { #[cfg(feature = "std")] Self::File(p) => { // Check if file exists and overwrite is disabled if p.path.exists() && !p.overwrite { return Err(SafetensorsStoreError::Other(format!( "File already exists: {}. Use .overwrite(true) to overwrite.", p.path.display() ))); } // Convert to safetensors format let tensors = snapshots_to_safetensors(snapshots)?; // Use serialize_to_file which streams directly to disk // This calls the lazy closures on-demand without buffering everything safetensors::serialize_to_file(tensors, Some(std_metadata), &p.path)?; Ok(()) } Self::Memory(p) => { // For memory, we need to serialize to bytes let tensors = snapshots_to_safetensors(snapshots)?; // For no-std, serialize still needs std HashMap when std feature is enabled #[cfg(feature = "std")] let data = safetensors::serialize(tensors, Some(std_metadata))?; #[cfg(not(feature = "std"))] let data = safetensors::serialize(tensors, Some(metadata))?; p.data = Some(Arc::new(data)); Ok(()) } } } fn apply_to>( &mut self, module: &mut M, ) -> Result { // Get snapshots from cache let snapshots: Vec = self.get_all_snapshots()?.values().cloned().collect(); // Get the adapter let adapter: Box = match self { #[cfg(feature = "std")] Self::File(p) => p.from_adapter.clone(), Self::Memory(p) => p.from_adapter.clone(), }; // Get filter (cloned to Option for apply) let filter = self.get_filter(); let filter_opt = if filter.is_empty() { None } else { Some(filter.clone()) }; // Apply to module with adapter // The adapter will be applied during module traversal with proper container info // Filter is applied here during apply, not during cache population let result = module.apply( snapshots, filter_opt, Some(adapter), self.get_skip_enum_variants(), ); // Validate if needed if self.get_validate() && !result.errors.is_empty() { return Err(SafetensorsStoreError::ValidationFailed(format!( "Import errors: {:?}", result.errors ))); } if !self.get_allow_partial() && !result.missing.is_empty() { return Err(SafetensorsStoreError::TensorNotFound(format!( "\n{}", result ))); } Ok(result) } fn get_snapshot(&mut self, name: &str) -> Result, Self::Error> { // Ensure cache is populated self.ensure_snapshots_cache()?; let cache = match self { #[cfg(feature = "std")] Self::File(p) => p.snapshots_cache.as_ref().unwrap(), Self::Memory(p) => p.snapshots_cache.as_ref().unwrap(), }; Ok(cache.get(name)) } fn get_all_snapshots(&mut self) -> Result<&BTreeMap, Self::Error> { // Ensure cache is populated self.ensure_snapshots_cache()?; let cache = match self { #[cfg(feature = "std")] Self::File(p) => p.snapshots_cache.as_ref().unwrap(), Self::Memory(p) => p.snapshots_cache.as_ref().unwrap(), }; Ok(cache) } fn keys(&mut self) -> Result, Self::Error> { // Always use the cache to ensure remapping is applied consistently Ok(self.get_all_snapshots()?.keys().cloned().collect()) } } impl SafetensorsStore { fn get_filter(&self) -> &PathFilter { match self { #[cfg(feature = "std")] Self::File(p) => &p.filter, Self::Memory(p) => &p.filter, } } #[cfg(feature = "std")] fn get_remapper(&self) -> &KeyRemapper { match self { Self::File(p) => &p.remapper, Self::Memory(p) => &p.remapper, } } fn get_metadata(&self) -> &HashMap { match self { #[cfg(feature = "std")] Self::File(p) => &p.metadata, Self::Memory(p) => &p.metadata, } } fn get_validate(&self) -> bool { match self { #[cfg(feature = "std")] Self::File(p) => p.validate, Self::Memory(p) => p.validate, } } fn get_allow_partial(&self) -> bool { match self { #[cfg(feature = "std")] Self::File(p) => p.allow_partial, Self::Memory(p) => p.allow_partial, } } fn get_skip_enum_variants(&self) -> bool { match self { #[cfg(feature = "std")] Self::File(p) => p.skip_enum_variants, Self::Memory(p) => p.skip_enum_variants, } } #[cfg(feature = "std")] fn get_map_indices_contiguous(&self) -> bool { match self { Self::File(p) => p.map_indices_contiguous, Self::Memory(p) => p.map_indices_contiguous, } } /// Ensure the snapshots cache is populated fn ensure_snapshots_cache(&mut self) -> Result<(), SafetensorsStoreError> { // Check if cache exists let has_cache = match self { #[cfg(feature = "std")] Self::File(p) => p.snapshots_cache.is_some(), Self::Memory(p) => p.snapshots_cache.is_some(), }; if has_cache { return Ok(()); } // Load snapshots #[allow(unused_mut)] let mut snapshots = match self { #[cfg(feature = "std")] Self::File(p) => safetensors_to_snapshots_lazy_file(&p.path)?, Self::Memory(p) => { let data_arc = p .data .clone() .ok_or_else(|| SafetensorsStoreError::Other("No data loaded".to_string()))?; safetensors_to_snapshots_lazy(data_arc)? } }; // Apply remapping (but NOT filtering - that's done at apply time) #[cfg(feature = "std")] { snapshots = match self { Self::File(p) => apply_remapping(snapshots, &p.remapper), Self::Memory(p) => apply_remapping(snapshots, &p.remapper), }; } // Apply contiguous index mapping if enabled // This must be done after remapping so that remapped paths are mapped #[cfg(feature = "std")] if self.get_map_indices_contiguous() { let (mapped, _) = map_indices_contiguous(snapshots); snapshots = mapped; } // Build cache as BTreeMap let cache: BTreeMap = snapshots.into_iter().map(|s| (s.full_path(), s)).collect(); // Store cache match self { #[cfg(feature = "std")] Self::File(p) => p.snapshots_cache = Some(cache), Self::Memory(p) => p.snapshots_cache = Some(cache), } Ok(()) } } /// Apply filter to tensor snapshots. fn apply_filter(mut snapshots: Vec, filter: &PathFilter) -> Vec { if filter.is_empty() { return snapshots; } snapshots.retain(|snapshot| { let path = snapshot.full_path(); filter.matches(&path) }); snapshots } /// Apply remapping to tensor snapshots. #[cfg(feature = "std")] fn apply_remapping(snapshots: Vec, remapper: &KeyRemapper) -> Vec { if remapper.is_empty() { return snapshots; } let (remapped, _) = remapper.remap(snapshots); remapped } /// Convert TensorSnapshots to safetensors format lazily. fn snapshots_to_safetensors( snapshots: Vec, ) -> Result, SafetensorsStoreError> { let mut tensors = Vec::new(); for snapshot in snapshots { let name = snapshot.full_path(); // No need to materialize data - TensorSnapshot now has dtype and shape cached! tensors.push((name, TensorSnapshotAdapter(snapshot))); } Ok(tensors) } /// Convert safetensors to TensorSnapshots with lazy loading. fn safetensors_to_snapshots_lazy( data_arc: Arc>, ) -> Result, SafetensorsStoreError> { // Parse to get metadata let tensors = safetensors::SafeTensors::deserialize(&data_arc)?; let mut snapshots = Vec::new(); for (name, tensor_snapshot) in tensors.tensors() { // Extract metadata without materializing data let dtype = safetensor_dtype_to_burn(tensor_snapshot.dtype())?; let shape = tensor_snapshot.shape(); let path_parts: Vec = name.split('.').map(|s| s.to_string()).collect(); // Create a lazy closure that will deserialize only this tensor when needed #[cfg(target_has_atomic = "ptr")] let data_clone = Arc::clone(&data_arc); #[cfg(not(target_has_atomic = "ptr"))] let data_clone = data_arc.clone(); let name_clone = name.to_string(); let data_fn = alloc::rc::Rc::new(move || { // Re-deserialize when needed (this is cheap, just parsing header) let tensors = safetensors::SafeTensors::deserialize(&data_clone).map_err(|e| { crate::TensorSnapshotError::IoError(format!( "Failed to re-deserialize safetensors: {}", e )) })?; // Find our specific tensor let tensor = tensors.tensor(&name_clone).map_err(|e| { crate::TensorSnapshotError::DataError(format!( "Tensor '{}' not found: {}", name_clone, e )) })?; // Now materialize just this tensor's data let bytes = burn_tensor::Bytes::from_bytes_vec(tensor.data().to_vec()); Ok(TensorData { bytes, shape: tensor.shape().into(), dtype: safetensor_dtype_to_burn(tensor.dtype()) .map_err(|_| crate::TensorSnapshotError::DataError("Invalid dtype".into()))?, }) }); let snapshot = TensorSnapshot::from_closure( data_fn, dtype, shape.into(), path_parts, vec![], // Empty container_stack - will be filled during module traversal ParamId::new(), ); snapshots.push(snapshot); } Ok(snapshots) } /// Convert safetensors to TensorSnapshots with true on-demand loading from file. /// This reads only the header initially, then loads tensor data on demand. #[cfg(feature = "std")] fn safetensors_to_snapshots_lazy_file( path: &std::path::Path, ) -> Result, SafetensorsStoreError> { // Always use memory mapping for the most efficient access use memmap2::MmapOptions; // Memory map the file for efficient access let file = std::fs::File::open(path)?; let mmap = unsafe { MmapOptions::new().map(&file)? }; let mmap_arc = Arc::new(mmap); // Parse just to get metadata (safetensors won't copy data with mmap) let tensors = safetensors::SafeTensors::deserialize(&mmap_arc)?; let mut snapshots = Vec::new(); for (name, tensor_snapshot) in tensors.tensors() { let dtype = safetensor_dtype_to_burn(tensor_snapshot.dtype())?; let shape = tensor_snapshot.shape(); let path_parts: Vec = name.split('.').map(|s| s.to_string()).collect(); // Create a lazy closure that accesses the mmap'd data let mmap_clone = Arc::clone(&mmap_arc); let name_clone = name.to_string(); let data_fn = alloc::rc::Rc::new(move || { // Re-parse to get the tensor snapshot (this is cheap with mmap) let tensors = safetensors::SafeTensors::deserialize(&mmap_clone).map_err(|e| { crate::TensorSnapshotError::IoError(format!("Failed to deserialize: {}", e)) })?; let tensor = tensors.tensor(&name_clone).map_err(|e| { crate::TensorSnapshotError::DataError(format!( "Tensor '{}' not found: {}", name_clone, e )) })?; // Only now do we actually copy the tensor data Ok(TensorData { bytes: burn_tensor::Bytes::from_bytes_vec(tensor.data().to_vec()), shape: tensor.shape().into(), dtype: safetensor_dtype_to_burn(tensor.dtype()) .map_err(|_| crate::TensorSnapshotError::DataError("Invalid dtype".into()))?, }) }); let snapshot = TensorSnapshot::from_closure( data_fn, dtype, shape.into(), path_parts, vec![], // Empty container_stack - will be filled during module traversal ParamId::new(), ); snapshots.push(snapshot); } Ok(snapshots) } /// Helper to convert safetensors Dtype to burn DType. fn safetensor_dtype_to_burn(dtype: safetensors::Dtype) -> Result { use safetensors::Dtype; match dtype { Dtype::F64 => Ok(DType::F64), Dtype::F32 => Ok(DType::F32), Dtype::F16 => Ok(DType::F16), Dtype::BF16 => Ok(DType::BF16), Dtype::I64 => Ok(DType::I64), Dtype::I32 => Ok(DType::I32), Dtype::I16 => Ok(DType::I16), Dtype::I8 => Ok(DType::I8), Dtype::U64 => Ok(DType::U64), Dtype::U32 => Ok(DType::U32), Dtype::U8 => Ok(DType::U8), Dtype::BOOL => Ok(DType::Bool(BoolStore::Native)), _ => Err(SafetensorsStoreError::Other(format!( "Unsupported dtype: {:?}", dtype ))), } } /// Helper to convert DType to safetensors Dtype. fn dtype_to_safetensors(dtype: DType) -> Result { use safetensors::Dtype; match dtype { DType::F64 => Ok(Dtype::F64), DType::F32 | DType::Flex32 => Ok(Dtype::F32), // Flex32 is stored as F32 DType::F16 => Ok(Dtype::F16), DType::BF16 => Ok(Dtype::BF16), DType::I64 => Ok(Dtype::I64), DType::I32 => Ok(Dtype::I32), DType::I16 => Ok(Dtype::I16), DType::I8 => Ok(Dtype::I8), DType::U64 => Ok(Dtype::U64), DType::U32 => Ok(Dtype::U32), DType::U16 => Err(SafetensorsStoreError::Other( "U16 dtype not yet supported in safetensors".to_string(), )), DType::U8 => Ok(Dtype::U8), DType::Bool(BoolStore::Native) => Ok(Dtype::BOOL), DType::Bool(BoolStore::U32) => Ok(Dtype::U32), DType::Bool(BoolStore::U8) => Ok(Dtype::U8), DType::QFloat(_) => Err(SafetensorsStoreError::Other( "Quantized tensors not yet supported in safetensors".to_string(), )), } } ================================================ FILE: crates/burn-store/src/safetensors/tests/adapter.rs ================================================ use burn_core as burn; use crate::{ BurnToPyTorchAdapter, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore, }; use burn_core::module::{Module, Param}; use burn_nn::{Linear, LinearConfig}; use burn_tensor::Tensor; use burn_tensor::backend::Backend; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] struct TestModel { linear: Linear, norm_weight: Param>, norm_bias: Param>, } impl TestModel { fn new(device: &B::Device) -> Self { Self { linear: LinearConfig::new(4, 2).with_bias(true).init(device), norm_weight: Param::from_data([1.0, 1.0], device), norm_bias: Param::from_data([0.0, 0.0], device), } } } #[test] fn pytorch_to_burn_adapter_linear_transpose() { let device = Default::default(); let model = TestModel::::new(&device); // Save with BurnToPyTorch adapter (will transpose linear weights) let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter); model.save_into(&mut save_store).unwrap(); // Load with PyTorchToBurn adapter (will transpose back) let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let mut model2 = TestModel::::new(&device); let result = model2.load_from(&mut load_store).unwrap(); // Should successfully load all tensors assert!(!result.applied.is_empty()); // Verify the linear weights are the same after round-trip let weight1 = model.linear.weight.val().to_data(); let weight2 = model2.linear.weight.val().to_data(); assert_eq!(weight1.shape, weight2.shape); let data1 = weight1.to_vec::().unwrap(); let data2 = weight2.to_vec::().unwrap(); for (a, b) in data1.iter().zip(data2.iter()) { assert!( (a - b).abs() < 1e-6, "Weights differ after adapter round-trip" ); } } #[test] fn pytorch_to_burn_adapter_norm_rename() { let device = Default::default(); // Create a model with norm-like naming #[derive(Module, Debug)] struct NormModel { norm_gamma: Param>, norm_beta: Param>, } impl NormModel { fn new(device: &B::Device) -> Self { Self { norm_gamma: Param::from_data([1.0, 2.0, 3.0], device), norm_beta: Param::from_data([0.1, 0.2, 0.3], device), } } } let model = NormModel::::new(&device); // Save with BurnToPyTorch adapter (will rename gamma->weight, beta->bias) let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter); model.save_into(&mut save_store).unwrap(); // The saved data should have PyTorch naming convention // We can't directly verify the internal names, but we can verify round-trip works // Load with PyTorchToBurn adapter (will rename weight->gamma, bias->beta) let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let mut model2 = NormModel::::new(&device); let result = model2.load_from(&mut load_store).unwrap(); // Should load successfully assert!(!result.applied.is_empty()); // Verify data is preserved let gamma1 = model.norm_gamma.val().to_data().to_vec::().unwrap(); let gamma2 = model2.norm_gamma.val().to_data().to_vec::().unwrap(); let beta1 = model.norm_beta.val().to_data().to_vec::().unwrap(); let beta2 = model2.norm_beta.val().to_data().to_vec::().unwrap(); assert_eq!(gamma1, gamma2); assert_eq!(beta1, beta2); } #[test] fn no_adapter_preserves_original() { let device = Default::default(); let model = TestModel::::new(&device); // Save without adapter let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).unwrap(); // Load without adapter let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let mut model2 = TestModel::::new(&device); let result = model2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert!(!result.applied.is_empty()); // Verify data is exactly the same let weight1 = model.linear.weight.val().to_data(); let weight2 = model2.linear.weight.val().to_data(); assert_eq!(weight1.shape, weight2.shape); assert_eq!( weight1.to_vec::().unwrap(), weight2.to_vec::().unwrap() ); } #[test] #[cfg(all(feature = "std", target_has_atomic = "ptr"))] fn adapter_with_pytorch_import() { use crate::PyTorchToBurnAdapter; let device = Default::default(); // Reference the safetensors file from burn-store let safetensors_path = concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ); // Simple test model that matches some of the PyTorch structure #[derive(Module, Debug)] struct SimpleNet { fc1: Linear, } impl SimpleNet { fn new(device: &B::Device) -> Self { Self { fc1: LinearConfig::new(4 * 8 * 8, 16).init(device), } } } // Load with PyTorchToBurn adapter let mut store = SafetensorsStore::from_file(safetensors_path) .with_from_adapter(PyTorchToBurnAdapter) .validate(false) .allow_partial(true); let mut model = SimpleNet::::new(&device); let result = model.load_from(&mut store).unwrap(); // Should load some tensors (fc1 if it exists in the file) // This mainly tests that the adapter works with real PyTorch files assert!(!result.applied.is_empty() || !result.missing.is_empty()); } #[test] fn half_precision_adapter_round_trip() { use crate::HalfPrecisionAdapter; use burn_tensor::DType; let device = Default::default(); let model = TestModel::::new(&device); // Save with HalfPrecisionAdapter (F32 -> F16) let adapter = HalfPrecisionAdapter::new(); let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter.clone()); model.save_into(&mut save_store).unwrap(); // Verify Linear tensors are F16, raw params stay F32 (no recognized module type) let save_bytes = match &save_store { SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(), _ => panic!("Expected memory store"), }; let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes.clone())); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (name, snapshot) in snapshots.iter() { if name.starts_with("linear") { assert_eq!( snapshot.dtype, DType::F16, "Linear tensor '{}' should be F16", name ); } else { assert_eq!( snapshot.dtype, DType::F32, "Raw param '{}' should stay F32", name ); } } // Load back with same adapter (F16 -> F32) let mut load_store = SafetensorsStore::from_bytes(Some(save_bytes)).with_from_adapter(adapter); let mut model2 = TestModel::::new(&device); let result = model2.load_from(&mut load_store).unwrap(); assert!(!result.applied.is_empty()); // Verify values are close (F32 -> F16 -> F32 has rounding) let w1 = model.linear.weight.val().to_data().to_vec::().unwrap(); let w2 = model2 .linear .weight .val() .to_data() .to_vec::() .unwrap(); for (a, b) in w1.iter().zip(w2.iter()) { assert!( (a - b).abs() < 0.01, "Weight values differ too much after F16 round-trip: {} vs {}", a, b ); } } #[test] fn half_precision_adapter_without_module() { use crate::HalfPrecisionAdapter; use burn_nn::{LayerNorm, LayerNormConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct MixedModel { linear: Linear, norm: LayerNorm, } let device = Default::default(); let model = MixedModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), norm: LayerNormConfig::new(2).init(&device), }; // Save: exclude LayerNorm from half-precision conversion let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm"); let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter); model.save_into(&mut save_store).unwrap(); // Verify: Linear tensors are F16, LayerNorm tensors remain F32 let save_bytes = match &save_store { SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(), _ => panic!("Expected memory store"), }; let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes)); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (name, snapshot) in snapshots { if name.starts_with("linear") { assert_eq!( snapshot.dtype, DType::F16, "Linear tensor '{}' should be F16", name ); } else if name.starts_with("norm") { assert_eq!( snapshot.dtype, DType::F32, "LayerNorm tensor '{}' should stay F32", name ); } } } #[test] fn half_precision_adapter_default_converts_layer_norm() { use crate::HalfPrecisionAdapter; use burn_nn::{LayerNorm, LayerNormConfig}; use burn_tensor::DType; #[derive(Module, Debug)] struct NormModel { linear: Linear, norm: LayerNorm, } let device = Default::default(); let model = NormModel:: { linear: LinearConfig::new(4, 2).with_bias(true).init(&device), norm: LayerNormConfig::new(2).init(&device), }; // Default adapter converts LayerNorm let adapter = HalfPrecisionAdapter::new(); let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter); model.save_into(&mut save_store).unwrap(); let save_bytes = match &save_store { SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(), _ => panic!("Expected memory store"), }; let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes)); let snapshots = inspect_store.get_all_snapshots().unwrap(); for (name, snapshot) in snapshots { assert_eq!( snapshot.dtype, DType::F16, "All tensors should be F16 by default, but '{}' is {:?}", name, snapshot.dtype ); } } ================================================ FILE: crates/burn-store/src/safetensors/tests/direct_access.rs ================================================ use burn_core as burn; use crate::{ModuleStore, SafetensorsStore}; use burn_core::module::{Module, Param}; use burn_tensor::backend::Backend; use burn_tensor::{Tensor, shape}; type TestBackend = burn_ndarray::NdArray; // Test module for direct access tests #[derive(Module, Debug)] struct DirectAccessTestModule { weight: Param>, bias: Param>, nested: DirectAccessNestedModule, } #[derive(Module, Debug)] struct DirectAccessNestedModule { gamma: Param>, beta: Param>, } impl DirectAccessTestModule { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device), bias: Param::from_data([0.1, 0.2], device), nested: DirectAccessNestedModule { gamma: Param::from_data([1.0, 2.0], device), beta: Param::from_data([0.5, 0.5], device), }, } } } #[test] fn test_memory_get_all_snapshots() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); // Save module to memory let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); // Get bytes and create load store let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); // Get all snapshots let snapshots = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots.len(), 4); assert!(snapshots.contains_key("weight")); assert!(snapshots.contains_key("bias")); assert!(snapshots.contains_key("nested.gamma")); assert!(snapshots.contains_key("nested.beta")); } #[test] fn test_memory_get_snapshot_existing() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); // Get existing snapshot let snapshot = load_store.get_snapshot("weight").unwrap(); assert!(snapshot.is_some()); let snapshot = snapshot.unwrap(); assert_eq!(snapshot.shape, shape![2, 2]); // Verify data let data = snapshot.to_data().unwrap(); let values: Vec = data.to_vec().unwrap(); assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]); } #[test] fn test_memory_get_snapshot_nested() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); // Get nested snapshot let snapshot = load_store.get_snapshot("nested.gamma").unwrap(); assert!(snapshot.is_some()); let snapshot = snapshot.unwrap(); let data = snapshot.to_data().unwrap(); let values: Vec = data.to_vec().unwrap(); assert_eq!(values, vec![1.0, 2.0]); } #[test] fn test_memory_get_snapshot_not_found() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); // Get non-existent snapshot let snapshot = load_store.get_snapshot("nonexistent").unwrap(); assert!(snapshot.is_none()); } #[test] fn test_memory_keys() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let keys = load_store.keys().unwrap(); assert_eq!(keys.len(), 4); assert!(keys.contains(&"weight".to_string())); assert!(keys.contains(&"bias".to_string())); assert!(keys.contains(&"nested.gamma".to_string())); assert!(keys.contains(&"nested.beta".to_string())); } #[test] fn test_memory_caching_behavior() { let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let mut save_store = SafetensorsStore::from_bytes(None); save_store.collect_from(&module).unwrap(); let bytes = save_store.get_bytes().unwrap(); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); // Call get_all_snapshots multiple times - should return same cached data let snapshots1 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots1.len(), 4); let snapshots2 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots2.len(), 4); // Verify we can still access individual snapshots after caching let snapshot = load_store.get_snapshot("bias").unwrap(); assert!(snapshot.is_some()); } // ============================================================================ // Tests for FileStore variant // ============================================================================ #[test] #[cfg(feature = "std")] fn test_file_get_all_snapshots() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_get_all_snapshots.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); let snapshots = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots.len(), 4); assert!(snapshots.contains_key("weight")); assert!(snapshots.contains_key("bias")); assert!(snapshots.contains_key("nested.gamma")); assert!(snapshots.contains_key("nested.beta")); } #[test] #[cfg(feature = "std")] fn test_file_get_snapshot_existing() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_get_snapshot.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); let snapshot = load_store.get_snapshot("weight").unwrap(); assert!(snapshot.is_some()); let snapshot = snapshot.unwrap(); assert_eq!(snapshot.shape, shape![2, 2]); let data = snapshot.to_data().unwrap(); let values: Vec = data.to_vec().unwrap(); assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]); } #[test] #[cfg(feature = "std")] fn test_file_get_snapshot_not_found() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_not_found.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); let snapshot = load_store.get_snapshot("nonexistent").unwrap(); assert!(snapshot.is_none()); } #[test] #[cfg(feature = "std")] fn test_file_keys() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_keys.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); let keys = load_store.keys().unwrap(); assert_eq!(keys.len(), 4); assert!(keys.contains(&"weight".to_string())); assert!(keys.contains(&"bias".to_string())); assert!(keys.contains(&"nested.gamma".to_string())); assert!(keys.contains(&"nested.beta".to_string())); } #[test] #[cfg(feature = "std")] fn test_file_keys_fast_path() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_keys_fast.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); // Create fresh store - cache should be empty let mut load_store = SafetensorsStore::from_file(&path); // keys() should work without populating the full cache (fast path) let keys = load_store.keys().unwrap(); assert_eq!(keys.len(), 4); // Now call get_all_snapshots to populate cache let snapshots = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots.len(), 4); // keys() should now use the cached data let keys2 = load_store.keys().unwrap(); assert_eq!(keys2.len(), 4); } #[test] #[cfg(feature = "std")] fn test_file_caching_behavior() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_caching.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); // First call populates cache let snapshots1 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots1.len(), 4); // Second call uses cache let snapshots2 = load_store.get_all_snapshots().unwrap(); assert_eq!(snapshots2.len(), 4); } #[test] #[cfg(feature = "std")] fn test_file_cache_invalidation_on_save() { use tempfile::tempdir; let device = Default::default(); let module = DirectAccessTestModule::::new(&device); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_invalidation.safetensors"); // Create store, save, and populate cache let mut store = SafetensorsStore::from_file(&path).overwrite(true); store.collect_from(&module).unwrap(); let snapshots1 = store.get_all_snapshots().unwrap(); assert_eq!(snapshots1.len(), 4); // Save again (this should invalidate cache) store.collect_from(&module).unwrap(); // Cache should be repopulated with fresh data let snapshots2 = store.get_all_snapshots().unwrap(); assert_eq!(snapshots2.len(), 4); } ================================================ FILE: crates/burn-store/src/safetensors/tests/error_handling.rs ================================================ use crate::{ModuleSnapshot, SafetensorsStore}; use burn_nn::LinearConfig; type TestBackend = burn_ndarray::NdArray; #[test] fn shape_mismatch_errors() { let device = Default::default(); // Create a module let module = LinearConfig::new(2, 2) .with_bias(true) .init::(&device); // Save module let mut save_store = SafetensorsStore::from_bytes(None); module.save_into(&mut save_store).unwrap(); // Try to load into incompatible module (different dimensions) let mut incompatible_module = LinearConfig::new(3, 3) .with_bias(true) .init::(&device); // Load without validation - should return errors in the result let mut load_store = SafetensorsStore::from_bytes(None).validate(false); // Disable validation to get errors in result if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { // Get Arc and extract data let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let result = incompatible_module.load_from(&mut load_store).unwrap(); // Should have errors due to shape mismatch assert!(!result.errors.is_empty()); // Try again with validation enabled - should return Err let mut load_store_with_validation = SafetensorsStore::from_bytes(None).validate(true); if let SafetensorsStore::Memory(ref mut p) = load_store_with_validation && let SafetensorsStore::Memory(ref p_save) = save_store { // Get Arc and extract data let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let validation_result = incompatible_module.load_from(&mut load_store_with_validation); assert!(validation_result.is_err()); } ================================================ FILE: crates/burn-store/src/safetensors/tests/file_io.rs ================================================ use burn_core as burn; use crate::{ModuleSnapshot, ModuleStore, SafetensorsStore}; use burn_core::module::{Module, Param}; use burn_nn::{Initializer, LinearConfig}; use burn_tensor::Tensor; use burn_tensor::backend::Backend; use tempfile::tempdir; type TestBackend = burn_ndarray::NdArray; // Define a test model with forward pass #[derive(Module, Debug)] struct ForwardTestModel { linear1: burn_nn::Linear, linear2: burn_nn::Linear, } impl ForwardTestModel { fn forward(&self, input: Tensor) -> Tensor { let x = self.linear1.forward(input); let x = burn::tensor::activation::gelu(x); self.linear2.forward(x) } } // Define config for the model #[derive(burn::config::Config, Debug)] struct ForwardTestModelConfig { input_size: usize, hidden_size: usize, output_size: usize, } impl ForwardTestModelConfig { fn init(&self, device: &B::Device) -> ForwardTestModel { ForwardTestModel { linear1: LinearConfig::new(self.input_size, self.hidden_size) .with_bias(true) .init(device), linear2: LinearConfig::new(self.hidden_size, self.output_size) .with_bias(true) .init(device), } } } #[derive(Module, Debug)] pub struct ModuleBasic { weight_basic: Param>, } impl ModuleBasic { fn new(device: &B::Device) -> Self { Self { weight_basic: Initializer::Normal { std: 1.0, mean: 0.0, } .init([20, 20], device), } } } #[derive(Module, Debug)] pub struct ModuleComposed { weight: Param>, basic: ModuleBasic, tuple: (ModuleBasic, ModuleBasic), } impl ModuleComposed { fn new(device: &B::Device) -> Self { let weight = Initializer::Normal { std: 1.0, mean: 0.0, } .init([20, 20], device); Self { weight, basic: ModuleBasic::new(device), tuple: (ModuleBasic::new(device), ModuleBasic::new(device)), } } } #[test] fn file_based_loading() { use std::fs; let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Create temp file path let temp_dir = std::env::temp_dir(); let file_path = temp_dir.join("test_safetensors.st"); // Save to file let mut save_store = SafetensorsStore::from_file(&file_path).metadata("test", "file_loading"); module.save_into(&mut save_store).unwrap(); // Verify file exists assert!(file_path.exists()); // Load from file (will use memory-mapped loading if available) let mut load_store = SafetensorsStore::from_file(&file_path); let mut loaded_module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = loaded_module.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 2); // weight and bias // Clean up fs::remove_file(file_path).ok(); } #[test] fn test_store_overwrite_protection() { use tempfile::tempdir; let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Create temp directory and file path (file doesn't exist yet) let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_model.safetensors"); // First save - should succeed let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&module).unwrap(); assert!(path.exists()); // Second save without overwrite flag - should fail let mut save_store2 = SafetensorsStore::from_file(&path); let result = save_store2.collect_from(&module); assert!(result.is_err()); assert!( result .unwrap_err() .to_string() .contains("File already exists") ); // Third save with overwrite flag - should succeed let mut save_store3 = SafetensorsStore::from_file(&path).overwrite(true); save_store3.collect_from(&module).unwrap(); // Verify file still exists and is valid let mut load_store = SafetensorsStore::from_file(&path); let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = load_store.apply_to(&mut module2).unwrap(); assert!(result.is_success()); } #[test] fn test_store_overwrite_with_metadata() { use tempfile::tempdir; let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Create temp directory and file path let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("test_model_metadata.safetensors"); // First save with v1 metadata and overwrite enabled let mut save_store = SafetensorsStore::from_file(&path) .metadata("model_version", "v1") .overwrite(true); save_store.collect_from(&module).unwrap(); // Second save with v2 metadata and overwrite enabled let mut save_store2 = SafetensorsStore::from_file(&path) .metadata("model_version", "v2") .overwrite(true); save_store2.collect_from(&module).unwrap(); // Load and verify the metadata was updated to v2 let mut load_store = SafetensorsStore::from_file(&path); // Since we can't easily access metadata after loading, we just verify the file loads successfully let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } #[test] fn test_forward_pass_preservation_after_save_load() { let device = Default::default(); // Create model config let config = ForwardTestModelConfig { input_size: 4, hidden_size: 8, output_size: 2, }; // Initialize model1 with random weights let model1 = config.init::(&device); // Create random input let input = Tensor::::random( [1, 4], burn_tensor::Distribution::Uniform(-1.0, 1.0), &device, ); // Forward pass with model1 -> output1 let output1 = model1.forward(input.clone()); // Save model1 weights let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("forward_test_model.safetensors"); let mut save_store = SafetensorsStore::from_file(&path); save_store.collect_from(&model1).unwrap(); // Initialize model2 with different random weights let mut model2 = config.init::(&device); // Forward pass with model2 -> output2 (should differ from output1) let output2 = model2.forward(input.clone()); // Verify output2 differs from output1 (different random weights) assert!( !output1 .clone() .all_close(output2.clone(), Some(1e-6), Some(1e-6)), "output2 should differ from output1 (different random initializations)" ); // Load model1 weights into model2 let mut load_store = SafetensorsStore::from_file(&path); let result = load_store.apply_to(&mut model2).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases // Forward pass with model2 (now has model1 weights) -> output3 let output3 = model2.forward(input.clone()); // Verify output3 equals output1 (same weights) assert!( output1.all_close(output3, Some(1e-6), Some(1e-6)), "output3 should equal output1 after loading weights" ); } #[test] fn should_save_load_compose() { let device = ::Device::default(); let module_1 = ModuleComposed::::new(&device); let mut module_2 = ModuleComposed::::new(&device); assert_ne!(module_1.weight.to_data(), module_2.weight.to_data()); assert_ne!( module_1.basic.weight_basic.to_data(), module_2.basic.weight_basic.to_data() ); let temp_dir = tempdir().unwrap(); let path = temp_dir.path().join("save_load_compose.safetensors"); let mut store = SafetensorsStore::from_file(&path); module_1.save_into(&mut store).unwrap(); let mut load_store = SafetensorsStore::from_file(&path); let result = module_2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(module_1.weight.to_data(), module_2.weight.to_data()); assert_eq!( module_1.basic.weight_basic.to_data(), module_2.basic.weight_basic.to_data() ); } ================================================ FILE: crates/burn-store/src/safetensors/tests/filtering.rs ================================================ use crate::{ModuleSnapshot, SafetensorsStore}; use super::round_trip::ComplexModule; type TestBackend = burn_ndarray::NdArray; #[test] #[cfg(target_has_atomic = "ptr")] fn filtered_export_import() { let device = Default::default(); let module1 = ComplexModule::::new(&device); let mut module2 = ComplexModule::::new_zeros(&device); // Export only encoder tensors using the builder pattern let mut save_store = SafetensorsStore::from_bytes(None).with_regex(r"^encoder\..*"); module1.save_into(&mut save_store).unwrap(); // Import filtered tensors - need to allow partial since we only saved encoder tensors let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { // Get Arc and extract data let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 3); // encoder.weight, encoder.bias, encoder.norm assert!(!result.missing.is_empty()); // decoder and layers tensors are missing } #[test] #[cfg(target_has_atomic = "ptr")] fn builder_pattern_filtering() { let device = Default::default(); let module = ComplexModule::::new(&device); // Test with_regex - multiple patterns (OR logic) let mut store = SafetensorsStore::from_bytes(None) .with_regex(r"^encoder\..*") // Match encoder tensors .with_regex(r".*\.bias$"); // OR match any bias tensors let views = module.collect(None, None, false); let filtered_count = views .iter() .filter(|v| { let path = v.full_path(); path.starts_with("encoder.") || path.ends_with(".bias") }) .count(); module.save_into(&mut store).unwrap(); // Verify we saved the expected number of tensors if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); let tensors = safetensors::SafeTensors::deserialize(&data).unwrap(); assert_eq!(tensors.len(), filtered_count); } } #[test] fn builder_pattern_exact_paths() { let device = Default::default(); let module = ComplexModule::::new(&device); // Test with_full_path and with_full_paths let paths = vec!["encoder.weight", "decoder.scale"]; let mut store = SafetensorsStore::from_bytes(None) .with_full_path("encoder.norm") .with_full_paths(paths.clone()); module.save_into(&mut store).unwrap(); // Verify only specified tensors were saved if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); let tensors = safetensors::SafeTensors::deserialize(&data).unwrap(); assert_eq!(tensors.len(), 3); // encoder.norm + encoder.weight + decoder.scale for (name, _) in tensors.tensors() { assert!(name == "encoder.norm" || name == "encoder.weight" || name == "decoder.scale"); } } } #[test] fn builder_pattern_with_predicate() { let device = Default::default(); let module = ComplexModule::::new(&device); // Test with_predicate - custom logic let mut store = SafetensorsStore::from_bytes(None).with_predicate(|path, _| { // Only save tensors with "layer" in the path and ending with "weight" path.contains("layer") && path.ends_with("weight") }); module.save_into(&mut store).unwrap(); // Verify only layer weights were saved if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); let tensors = safetensors::SafeTensors::deserialize(&data).unwrap(); for (name, _) in tensors.tensors() { assert!(name.contains("layer")); assert!(name.ends_with("weight")); } } } #[test] fn builder_pattern_combined() { let device = Default::default(); let module = ComplexModule::::new(&device); // Combine multiple filter methods #[cfg(target_has_atomic = "ptr")] { let mut store = SafetensorsStore::from_bytes(None) .with_regex(r"^encoder\..*") // All encoder tensors .with_full_path("decoder.scale") // Plus specific decoder.scale .with_predicate(|path, _| { // Plus any projection tensors path.contains("projection") }); module.save_into(&mut store).unwrap(); if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); let tensors = safetensors::SafeTensors::deserialize(&data).unwrap(); // Should have encoder.*, decoder.scale, and projection tensors let mut names = Vec::new(); for (name, _) in tensors.tensors() { names.push(name); } assert!(names.iter().any(|n| n == "encoder.weight")); assert!(names.iter().any(|n| n == "encoder.bias")); assert!(names.iter().any(|n| n == "encoder.norm")); assert!(names.iter().any(|n| n == "decoder.scale")); // decoder.projection.* should also be included due to predicate assert!(names.iter().any(|n| n.contains("projection"))); } } } #[test] fn builder_pattern_match_all() { let device = Default::default(); let module = ComplexModule::::new(&device); let all_views = module.collect(None, None, false); let total_count = all_views.len(); // Test match_all - should save everything let mut store = SafetensorsStore::from_bytes(None).match_all(); module.save_into(&mut store).unwrap(); if let SafetensorsStore::Memory(ref p) = store { let data = p.data().unwrap(); let tensors = safetensors::SafeTensors::deserialize(&data).unwrap(); assert_eq!(tensors.len(), total_count); } } ================================================ FILE: crates/burn-store/src/safetensors/tests/integration.rs ================================================ use burn_core as burn; use crate::{ModuleSnapshot, SafetensorsStore}; use burn_core::module::{Module, Param}; use burn_tensor::Tensor; use burn_tensor::backend::Backend; type TestBackend = burn_ndarray::NdArray; // Integration tests demonstrating the SafeTensors store API #[derive(Module, Debug)] struct IntegrationTestModel { encoder: IntegrationEncoderModule, decoder: IntegrationDecoderModule, head: IntegrationHeadModule, } #[derive(Module, Debug)] struct IntegrationEncoderModule { layer1: IntegrationLinearLayer, layer2: IntegrationLinearLayer, norm: IntegrationNormLayer, } #[derive(Module, Debug)] struct IntegrationDecoderModule { layer1: IntegrationLinearLayer, layer2: IntegrationLinearLayer, norm: IntegrationNormLayer, } #[derive(Module, Debug)] struct IntegrationHeadModule { weight: Param>, bias: Param>, } #[derive(Module, Debug)] struct IntegrationLinearLayer { weight: Param>, bias: Param>, } #[derive(Module, Debug)] struct IntegrationNormLayer { scale: Param>, shift: Param>, } impl IntegrationTestModel { fn new(device: &B::Device) -> Self { Self { encoder: IntegrationEncoderModule::new(device), decoder: IntegrationDecoderModule::new(device), head: IntegrationHeadModule::new(device), } } } impl IntegrationEncoderModule { fn new(device: &B::Device) -> Self { Self { layer1: IntegrationLinearLayer::new(device, 1), layer2: IntegrationLinearLayer::new(device, 2), norm: IntegrationNormLayer::new(device), } } } impl IntegrationDecoderModule { fn new(device: &B::Device) -> Self { Self { layer1: IntegrationLinearLayer::new(device, 3), layer2: IntegrationLinearLayer::new(device, 4), norm: IntegrationNormLayer::new(device), } } } impl IntegrationHeadModule { fn new(device: &B::Device) -> Self { Self { weight: Param::from_data([[5.0, 6.0], [7.0, 8.0]], device), bias: Param::from_data([9.0, 10.0], device), } } } impl IntegrationLinearLayer { fn new(device: &B::Device, seed: i32) -> Self { let weight_data = [ [seed as f32, (seed + 1) as f32], [(seed + 2) as f32, (seed + 3) as f32], ]; let bias_data = [(seed + 4) as f32, (seed + 5) as f32]; Self { weight: Param::from_data(weight_data, device), bias: Param::from_data(bias_data, device), } } } impl IntegrationNormLayer { fn new(device: &B::Device) -> Self { Self { scale: Param::from_data([1.0, 2.0], device), shift: Param::from_data([0.1, 0.2], device), } } } #[test] fn basic_usage() { let device = Default::default(); let model = IntegrationTestModel::::new(&device); // Save using new API (format, producer and version are automatically added) let mut save_store = SafetensorsStore::from_bytes(None).metadata("model_name", "test_model"); // Use collect_to method model.save_into(&mut save_store).unwrap(); // Load using new API let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let mut target_model = IntegrationTestModel::::new(&device); let result = target_model.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert_eq!(result.applied.len(), 14); // All tensors should be applied assert_eq!(result.errors.len(), 0); assert_eq!(result.unused.len(), 0); } #[test] #[cfg(target_has_atomic = "ptr")] fn with_filtering() { let device = Default::default(); let model = IntegrationTestModel::::new(&device); // Save only encoder tensors using the builder pattern let mut save_store = SafetensorsStore::from_bytes(None) .with_regex(r"^encoder\..*") .metadata("subset", "encoder_only"); model.save_into(&mut save_store).unwrap(); // Load into new model - need to allow partial loading since we only saved encoder tensors let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let mut target_model = IntegrationTestModel::::new(&device); let result = target_model.load_from(&mut load_store).unwrap(); // Only encoder tensors should be applied assert_eq!(result.applied.len(), 6); // encoder has 6 tensors (2 layers × 2 + norm × 2) // Check that only encoder tensors were applied for tensor_name in &result.applied { assert!(tensor_name.starts_with("encoder.")); } } ================================================ FILE: crates/burn-store/src/safetensors/tests/metadata.rs ================================================ use crate::{ModuleSnapshot, SafetensorsStore}; use burn_nn::LinearConfig; type TestBackend = burn_ndarray::NdArray; #[test] fn default_metadata_included() { // Verify that default metadata is automatically included let default_metadata = SafetensorsStore::default_metadata(); // Check that format, producer and version are present assert_eq!(default_metadata.get("format").unwrap(), "safetensors"); assert_eq!(default_metadata.get("producer").unwrap(), "burn"); assert!(default_metadata.contains_key("version")); // The version should be the crate version let version = default_metadata.get("version").unwrap(); assert!(!version.is_empty()); } #[test] fn metadata_preservation() { let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Write with metadata - note that format, producer and version are automatically added let mut save_store = SafetensorsStore::from_bytes(None) .metadata("model_type", "linear") .metadata("custom_field", "test_value"); module.save_into(&mut save_store).unwrap(); // Verify metadata was saved (would need to add a method to check metadata) // For now, just verify the round trip works let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { // Get Arc and extract data let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } #[test] fn clear_metadata_removes_all() { let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Create store with custom metadata, then clear all let mut save_store = SafetensorsStore::from_bytes(None) .metadata("model_type", "linear") .metadata("custom_field", "test_value") .clear_metadata(); // Should remove all metadata including defaults module.save_into(&mut save_store).unwrap(); // Load and verify the module still works (metadata is optional) let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } #[test] fn clear_then_add_custom_metadata() { let device = Default::default(); let module = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); // Clear all metadata, then add only custom ones let mut save_store = SafetensorsStore::from_bytes(None) .clear_metadata() .metadata("only_custom", "value"); module.save_into(&mut save_store).unwrap(); // Verify round-trip works let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let mut module2 = LinearConfig::new(4, 2) .with_bias(true) .init::(&device); let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); } ================================================ FILE: crates/burn-store/src/safetensors/tests/mixed_datatypes.rs ================================================ use burn_core as burn; use burn_core::module::{Module, Param, ParamId}; use burn_nn as nn; use burn_tensor::{Bool, Int, Tensor, backend::Backend}; use crate::{ModuleSnapshot, SafetensorsStore}; /// Simple model with different data types for testing #[derive(Module, Debug)] pub struct MixedDtypeModel { // Standard neural network layers (float tensors) linear: nn::Linear, // Direct tensor parameters of different types float_tensor: Param>, int_tensor: Param>, bool_tensor: Param>, } impl MixedDtypeModel { pub fn new(device: &B::Device) -> Self { Self { linear: nn::LinearConfig::new(3, 3).init(device), // Simple float values float_tensor: Param::from_tensor(Tensor::from_floats( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device, )), // Simple integer values int_tensor: Param::initialized( ParamId::new(), Tensor::from_ints([[1, 2, 3], [4, 5, 6]], device), ), // Simple boolean values bool_tensor: Param::initialized( ParamId::new(), Tensor::from_bool( burn::tensor::TensorData::new( vec![true, false, true, false, true, false], [2, 3], ), device, ), ), } } } #[cfg(test)] #[allow(clippy::excessive_precision)] mod tests { use burn_tensor::BoolStore; use super::*; #[test] fn test_mixed_dtypes_round_trip() { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); // Create model with mixed data types let model = MixedDtypeModel::::new(&device); // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load into a new model let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = MixedDtypeModel::::new(&device); loaded_model .load_from(&mut load_store) .expect("Failed to load"); // Verify float tensor is preserved let orig_float = model.float_tensor.val().into_data(); let loaded_float = loaded_model.float_tensor.val().into_data(); assert_eq!(orig_float, loaded_float, "Float tensor not preserved"); // Verify integer tensor is preserved let orig_int = model.int_tensor.val().into_data(); let loaded_int = loaded_model.int_tensor.val().into_data(); assert_eq!(orig_int, loaded_int, "Integer tensor not preserved"); // Verify boolean tensor is preserved let orig_bool = model.bool_tensor.val().into_data(); let loaded_bool = loaded_model.bool_tensor.val().into_data(); assert_eq!(orig_bool, loaded_bool, "Boolean tensor not preserved"); } #[test] fn test_dtype_detection() { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); let model = MixedDtypeModel::::new(&device); let snapshots = model.collect(None, None, false); for snapshot in snapshots { let path = snapshot.full_path(); let dtype = snapshot.dtype; if path.contains("float_tensor") || path.contains("linear") { assert_eq!( dtype, burn::tensor::DType::F32, "Float tensor {} should have F32 dtype", path ); } else if path.contains("int_tensor") { assert!( matches!( dtype, burn::tensor::DType::I64 | burn::tensor::DType::I32 | burn::tensor::DType::I16 | burn::tensor::DType::I8 ), "Integer tensor {} should have integer dtype, got {:?}", path, dtype ); } else if path.contains("bool_tensor") { assert_eq!( dtype, burn::tensor::DType::Bool(BoolStore::Native), "Boolean tensor {} should have Bool dtype", path ); } } } #[test] fn test_extreme_values() { type TestBackend = burn_ndarray::NdArray; let device = ::Device::default(); #[derive(Module, Debug)] struct ExtremeValueModel { large_floats: Param>, small_floats: Param>, large_ints: Param>, } impl ExtremeValueModel { fn new(device: &B::Device) -> Self { Self { large_floats: Param::from_tensor(Tensor::from_floats( [1e30, -1e30, f32::MAX, f32::MIN], device, )), small_floats: Param::from_tensor(Tensor::from_floats( [1e-30, -1e-30, f32::MIN_POSITIVE, f32::EPSILON], device, )), large_ints: Param::initialized( ParamId::new(), Tensor::from_ints([i32::MAX, i32::MIN, 0, -1], device), ), } } } let model = ExtremeValueModel::::new(&device); // Save and load let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = ExtremeValueModel::::new(&device); loaded_model .load_from(&mut load_store) .expect("Failed to load"); // Check exact preservation assert_eq!( model.large_floats.val().into_data(), loaded_model.large_floats.val().into_data(), "Large floats not preserved" ); assert_eq!( model.small_floats.val().into_data(), loaded_model.small_floats.val().into_data(), "Small floats not preserved" ); assert_eq!( model.large_ints.val().into_data(), loaded_model.large_ints.val().into_data(), "Large integers not preserved" ); } #[test] fn test_mixed_precision_floats() { // Note: While SafeTensors format supports storing tensors with different precisions // (F16, BF16, F32, F64, etc.) in the same file, Burn's backend architecture currently // requires all tensors in a model instance to share the same floating-point precision. // This is determined at the backend level (e.g., NdArray or NdArray). // // However, for storage purposes, SafeTensors can correctly save and load tensors // with their original precision, preserving the data type information in the file format. // This test demonstrates that different precision backends work correctly with SafeTensors. // Test with f32 backend { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); let model = MixedDtypeModel::::new(&device); // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = MixedDtypeModel::::new(&device); loaded_model .load_from(&mut load_store) .expect("Failed to load"); assert_eq!( model.float_tensor.val().into_data(), loaded_model.float_tensor.val().into_data(), "F32 float tensor not preserved" ); } // Test with f64 backend { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); #[derive(Module, Debug)] struct F64Model { linear: nn::Linear, double_precision: Param>, } let model = F64Model:: { linear: nn::LinearConfig::new(2, 2).init(&device), double_precision: Param::from_tensor(Tensor::from_floats( [ [1.234567890123456789, 2.345678901234567890], [3.456789012345678901, 4.567890123456789012], ], &device, )), }; // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = F64Model:: { linear: nn::LinearConfig::new(2, 2).init(&device), double_precision: Param::from_tensor(Tensor::zeros([2, 2], &device)), }; loaded_model .load_from(&mut load_store) .expect("Failed to load"); let orig = model.double_precision.val().into_data(); let loaded = loaded_model.double_precision.val().into_data(); assert_eq!(orig, loaded, "F64 double precision not preserved"); } } #[test] fn test_mixed_precision_integers() { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); #[derive(Module, Debug)] struct MultiIntModel { // Note: Burn's Tensor uses the backend's default int type // We can't directly specify i8, i16, etc. in the type system // But we can test with different values that would fit in different ranges small_ints: Param>, // Values that fit in i8 medium_ints: Param>, // Values that fit in i16 large_ints: Param>, // Values that need i32/i64 } let model = MultiIntModel:: { small_ints: Param::initialized( ParamId::new(), Tensor::from_ints([127i32, -128, 0, 42], &device), ), medium_ints: Param::initialized( ParamId::new(), Tensor::from_ints([32767i32, -32768, 1000, -1000], &device), ), large_ints: Param::initialized( ParamId::new(), Tensor::from_ints([i32::MAX, i32::MIN, 1_000_000, -1_000_000], &device), ), }; // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load and verify let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = MultiIntModel:: { small_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)), medium_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)), large_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)), }; loaded_model .load_from(&mut load_store) .expect("Failed to load"); assert_eq!( model.small_ints.val().into_data(), loaded_model.small_ints.val().into_data(), "Small ints (i8 range) not preserved" ); assert_eq!( model.medium_ints.val().into_data(), loaded_model.medium_ints.val().into_data(), "Medium ints (i16 range) not preserved" ); assert_eq!( model.large_ints.val().into_data(), loaded_model.large_ints.val().into_data(), "Large ints (i32 range) not preserved" ); } #[test] fn test_comprehensive_mixed_types() { type TestBackend = burn_ndarray::NdArray; let device = Default::default(); #[derive(Module, Debug)] struct ComprehensiveModel { // Neural network layers linear1: nn::Linear, conv2d: nn::conv::Conv2d, // Different tensor types float32_weights: Param>, integer_indices: Param>, boolean_mask: Param>, } let model = ComprehensiveModel:: { linear1: nn::LinearConfig::new(4, 8).init(&device), conv2d: nn::conv::Conv2dConfig::new([3, 16], [3, 3]).init(&device), float32_weights: Param::from_tensor(Tensor::from_floats( [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], &device, )), integer_indices: Param::initialized( ParamId::new(), Tensor::from_ints( [[0, 1, 2, 3], [10, 20, 30, 40], [100, 200, 300, 400]], &device, ), ), boolean_mask: Param::initialized( ParamId::new(), Tensor::from_bool( burn::tensor::TensorData::new( vec![true, false, false, true, false, true, true, false], [2, 4], ), &device, ), ), }; // Collect all tensors let snapshots = model.collect(None, None, false); // Verify we have all expected tensors let paths: Vec = snapshots.iter().map(|s| s.full_path()).collect(); assert!(paths.iter().any(|p| p.contains("linear1"))); assert!(paths.iter().any(|p| p.contains("conv2d"))); assert!(paths.iter().any(|p| p.contains("float32_weights"))); assert!(paths.iter().any(|p| p.contains("integer_indices"))); assert!(paths.iter().any(|p| p.contains("boolean_mask"))); // Save to bytes let mut save_store = SafetensorsStore::from_bytes(None); model.save_into(&mut save_store).expect("Failed to save"); let bytes = save_store.get_bytes().expect("Failed to get bytes"); // Load into fresh model let mut load_store = SafetensorsStore::from_bytes(Some(bytes)); let mut loaded_model = ComprehensiveModel:: { linear1: nn::LinearConfig::new(4, 8).init(&device), conv2d: nn::conv::Conv2dConfig::new([3, 16], [3, 3]).init(&device), float32_weights: Param::from_tensor(Tensor::zeros([2, 2, 2], &device)), integer_indices: Param::initialized(ParamId::new(), Tensor::zeros([3, 4], &device)), boolean_mask: Param::initialized( ParamId::new(), Tensor::from_bool( burn::tensor::TensorData::new(vec![false; 8], [2, 4]), &device, ), ), }; loaded_model .load_from(&mut load_store) .expect("Failed to load"); // Verify all data is preserved assert_eq!( model.float32_weights.val().into_data(), loaded_model.float32_weights.val().into_data(), "Float32 weights not preserved" ); assert_eq!( model.integer_indices.val().into_data(), loaded_model.integer_indices.val().into_data(), "Integer indices not preserved" ); assert_eq!( model.boolean_mask.val().into_data(), loaded_model.boolean_mask.val().into_data(), "Boolean mask not preserved" ); } } ================================================ FILE: crates/burn-store/src/safetensors/tests/mod.rs ================================================ mod adapter; mod direct_access; mod error_handling; #[cfg(feature = "std")] mod file_io; mod filtering; mod integration; mod metadata; mod mixed_datatypes; mod multi_layer_verify; mod pytorch_import; mod round_trip; ================================================ FILE: crates/burn-store/src/safetensors/tests/multi_layer_verify.rs ================================================ //! Tests for multi-layer model loading with SafeTensors format use burn_core as burn; use burn_core::module::Module; use burn_tensor::{Tensor, backend::Backend}; use burn_nn::{ BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu, conv::{Conv2d, Conv2dConfig}, }; /// Multi-layer neural network model for testing #[derive(Module, Debug)] pub struct Net { conv1: Conv2d, norm1: BatchNorm, fc1: Linear, relu: Relu, } impl Net { /// Create a new network instance pub fn new(device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([3, 4], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), norm1: BatchNormConfig::new(4).init(device), fc1: LinearConfig::new(4 * 8 * 8, 16).init(device), relu: Relu::new(), } } /// Forward pass of the model pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); let x = self.norm1.forward(x); let x = self.relu.forward(x); // Flatten all dimensions except the batch dimension let x = x.flatten(1, 3); self.fc1.forward(x) } } use crate::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore}; use burn_tensor::Tolerance; type TestBackend = burn_ndarray::NdArray; /// Path to the multi_layer.safetensors test file fn get_safetensors_path() -> &'static str { concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ) } #[test] fn multi_layer_model() { let device = Default::default(); let safetensors_path = get_safetensors_path(); // Load model from SafeTensors file with PyTorch adapter let mut store = SafetensorsStore::from_file(safetensors_path) .with_from_adapter(PyTorchToBurnAdapter) .validate(false) .allow_partial(true); let mut model = Net::::new(&device); let result = model.load_from(&mut store).unwrap(); // Verify loading was successful assert!( !result.applied.is_empty(), "Should have loaded some tensors" ); assert!( result.errors.is_empty(), "Should have no errors: {:?}", result.errors ); // Test forward pass let input = Tensor::::ones([1, 3, 8, 8], &device); let output = model.forward(input); // Expected output values from PyTorch model let expected = Tensor::::from_data( [[ 0.04971555, -0.16849735, 0.05182848, -0.18032673, 0.23138367, 0.05041867, 0.13005908, -0.32202929, -0.07915690, -0.03232457, -0.19790289, -0.17476529, -0.19627589, -0.21757686, -0.31376451, 0.08377837, ]], &device, ); // Verify output matches expected values output .to_data() .assert_approx_eq::(&expected.to_data(), Tolerance::default()); } ================================================ FILE: crates/burn-store/src/safetensors/tests/pytorch_import.rs ================================================ use burn_core as burn; use crate::{ModuleSnapshot, SafetensorsStore}; use burn_core::module::Module; use burn_nn::{ BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu, conv::{Conv2d, Conv2dConfig}, }; use burn_tensor::Tensor; use burn_tensor::backend::Backend; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] pub struct Net { conv1: Conv2d, norm1: BatchNorm, fc1: Linear, relu: Relu, } impl Net { pub fn new(device: &B::Device) -> Self { Self { conv1: Conv2dConfig::new([3, 4], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .init(device), norm1: BatchNormConfig::new(4).init(device), fc1: LinearConfig::new(4 * 8 * 8, 16).init(device), relu: Relu::new(), } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.conv1.forward(x); let x = self.norm1.forward(x); let x = self.relu.forward(x); // Flatten all dimensions except the batch dimension let x = x.flatten(1, 3); self.fc1.forward(x) } } #[test] #[cfg(all(feature = "std", target_has_atomic = "ptr"))] fn multi_layer_model_import() { let device = Default::default(); // Reference the safetensors file from burn-import let safetensors_path = concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ); // Load the model using SafetensorsStore // Note: PyTorch and Burn have different conventions for linear layer weights // PyTorch stores as [out_features, in_features], Burn as [in_features, out_features] // Also, tensor names may differ (e.g., PyTorch uses different names for BatchNorm params) let mut store = SafetensorsStore::from_file(safetensors_path) .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format .allow_partial(true); // Allow partial loading due to naming differences let mut model = Net::::new(&device); let result = model.load_from(&mut store).unwrap(); // With the adapter, weights should load correctly assert!(!result.applied.is_empty()); assert!( result.errors.is_empty(), "Should have no errors with adapter: {:?}", result.errors ); // Test forward pass with the loaded weights // Note: Due to shape mismatches (PyTorch vs Burn conventions for linear layers), // we can't directly compare outputs with PyTorch model. // This test mainly verifies that the loading mechanism works. let input = Tensor::::ones([1, 3, 8, 8], &device); let _output = model.forward(input); // Verify that some tensors were loaded successfully // Conv and BatchNorm layers should load correctly assert!(result.applied.iter().any(|n| n.contains("conv1"))); assert!(result.applied.iter().any(|n| n.contains("norm1"))); } #[test] #[cfg(all(feature = "std", target_has_atomic = "ptr"))] fn safetensors_round_trip_with_pytorch_model() { let device = Default::default(); // Reference the safetensors file from burn-import let safetensors_path = concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ); // Load the model from PyTorch safetensors let mut load_store = SafetensorsStore::from_file(safetensors_path) .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format .allow_partial(true); // Allow partial loading due to naming differences let mut model = Net::::new(&device); let load_result = model.load_from(&mut load_store).unwrap(); // With the adapter, weights should load correctly assert!(!load_result.applied.is_empty()); assert!( load_result.errors.is_empty(), "Should have no errors with adapter: {:?}", load_result.errors ); // Save the model to memory // Note: format, producer and version are automatically added let mut save_store = SafetensorsStore::from_bytes(None).metadata("source", "pytorch"); model.save_into(&mut save_store).unwrap(); // Load into a new model let mut model2 = Net::::new(&device); let mut load_store2 = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store2 && let SafetensorsStore::Memory(ref p_save) = save_store { p.set_data(p_save.data().unwrap().as_ref().clone()); } let result = model2.load_from(&mut load_store2).unwrap(); assert!(!result.applied.is_empty()); // Verify both models produce the same output let input = Tensor::::ones([1, 3, 8, 8], &device); let output1 = model.forward(input.clone()); let output2 = model2.forward(input); // Check outputs are identical let output1_data = output1.to_data().to_vec::().unwrap(); let output2_data = output2.to_data().to_vec::().unwrap(); for (a, b) in output1_data.iter().zip(output2_data.iter()) { assert!((a - b).abs() < 1e-7, "Outputs differ after round trip"); } } #[test] #[cfg(all(feature = "std", target_has_atomic = "ptr"))] fn partial_load_from_pytorch_model() { let device = Default::default(); // Reference the safetensors file from burn-import let safetensors_path = concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ); // Load only conv1 and norm1 parameters (not fc1) let mut store = SafetensorsStore::from_file(safetensors_path) .validate(false) // Disable validation due to shape differences .allow_partial(true); let mut model = Net::::new(&device); // Save initial fc1 weights for comparison let _initial_fc1_weight = model.fc1.weight.val().to_data(); let result = model.load_from(&mut store).unwrap(); // Should load available tensors (with some errors due to shape mismatch) assert!(!result.applied.is_empty()); // fc1 weight should remain unchanged if not in the file // or should be updated if it is in the file // This test verifies that partial loading works correctly } #[test] #[cfg(all(feature = "std", target_has_atomic = "ptr"))] fn verify_tensor_names_from_pytorch() { let device = Default::default(); // Reference the safetensors file from burn-import let safetensors_path = concat!( env!("CARGO_MANIFEST_DIR"), "/safetensors-tests/tests/multi_layer/multi_layer.safetensors" ); // Create a model and load from PyTorch let mut model = Net::::new(&device); let mut store = SafetensorsStore::from_file(safetensors_path) .validate(false) // Disable validation due to shape differences .allow_partial(true); // Allow partial loading due to naming differences let result = model.load_from(&mut store).unwrap(); // Check that we loaded some tensors (with errors due to shape mismatch) assert!(!result.applied.is_empty()); // Collect tensor names from the model let views = model.collect(None, None, false); let tensor_names: Vec = views.iter().map(|v| v.full_path()).collect(); // Verify expected tensor names are present assert!(tensor_names.iter().any(|n| n.contains("conv1"))); assert!(tensor_names.iter().any(|n| n.contains("norm1"))); assert!(tensor_names.iter().any(|n| n.contains("fc1"))); } ================================================ FILE: crates/burn-store/src/safetensors/tests/round_trip.rs ================================================ use burn_core as burn; use crate::{ModuleSnapshot, SafetensorsStore}; use burn_core::module::{Module, Param}; use burn_nn::{Linear, LinearConfig}; use burn_tensor::backend::Backend; use burn_tensor::{Tensor, shape}; type TestBackend = burn_ndarray::NdArray; #[derive(Module, Debug)] pub(super) struct ComplexModule { pub encoder: EncoderModule, pub decoder: DecoderModule, pub layers: Vec>, } #[derive(Module, Debug)] pub(super) struct EncoderModule { pub weight: Param>, pub bias: Param>, pub norm: Param>, } #[derive(Module, Debug)] pub(super) struct DecoderModule { pub projection: Linear, pub scale: Param>, } impl ComplexModule { pub fn new(device: &B::Device) -> Self { Self { encoder: EncoderModule { weight: Param::from_data( [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], device, ), bias: Param::from_data([0.1, 0.2, 0.3], device), norm: Param::from_data([1.0, 1.0, 1.0], device), }, decoder: DecoderModule { projection: LinearConfig::new(4, 2).with_bias(true).init(device), scale: Param::from_data([[0.5, 0.5], [0.5, 0.5]], device), }, layers: vec![ LinearConfig::new(3, 4).with_bias(false).init(device), LinearConfig::new(4, 3).with_bias(true).init(device), ], } } pub fn new_zeros(device: &B::Device) -> Self { Self { encoder: EncoderModule { weight: Param::from_tensor(Tensor::zeros([2, 2, 2], device)), bias: Param::from_tensor(Tensor::zeros([3], device)), norm: Param::from_tensor(Tensor::zeros([3], device)), }, decoder: DecoderModule { projection: LinearConfig::new(4, 2).with_bias(true).init(device), scale: Param::from_tensor(Tensor::zeros([2, 2], device)), }, layers: vec![ LinearConfig::new(3, 4).with_bias(false).init(device), LinearConfig::new(4, 3).with_bias(true).init(device), ], } } } #[test] fn complex_module_round_trip() { let device = Default::default(); let module1 = ComplexModule::::new(&device); let mut module2 = ComplexModule::::new_zeros(&device); // Save module1 using new store API let mut save_store = SafetensorsStore::from_bytes(None); module1.save_into(&mut save_store).unwrap(); // Load into module2 let mut load_store = SafetensorsStore::from_bytes(None); if let SafetensorsStore::Memory(ref mut p) = load_store && let SafetensorsStore::Memory(ref p_save) = save_store { // Get Arc and extract data let data_arc = p_save.data().unwrap(); p.set_data(data_arc.as_ref().clone()); } let result = module2.load_from(&mut load_store).unwrap(); assert!(result.is_success()); assert!(result.applied.len() > 5); assert_eq!(result.errors.len(), 0); // Verify data was imported correctly let module2_views = module2.collect(None, None, false); let encoder_weight = module2_views .iter() .find(|v| v.full_path() == "encoder.weight") .unwrap() .to_data() .unwrap(); assert_eq!(encoder_weight.shape, shape![2, 2, 2]); } ================================================ FILE: crates/burn-store/src/tensor_snapshot.rs ================================================ use alloc::rc::Rc; use alloc::string::String; use alloc::string::ToString; use alloc::vec::Vec; use burn_core::module::ParamId; use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape}; use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend}; use half::f16; /// Returns the byte size of a quantization parameter type. // TODO: Add `size_bytes()` method to `QuantParam` in cubecl and use it here. const fn quant_param_size(param: QuantParam) -> usize { match param { QuantParam::F32 => core::mem::size_of::(), QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::(), QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::(), } } /// Error type for TensorSnapshot operations #[derive(Debug, Clone)] pub enum TensorSnapshotError { /// I/O error occurred while loading tensor data IoError(String), /// Data corruption or invalid format DataError(String), /// Panic occurred while loading tensor data PanicError(String), } impl core::fmt::Display for TensorSnapshotError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::IoError(e) => write!(f, "I/O error: {}", e), Self::DataError(e) => write!(f, "Data error: {}", e), Self::PanicError(e) => write!(f, "Panic error: {}", e), } } } impl core::error::Error for TensorSnapshotError {} /// A lightweight snapshot of a tensor that can lazily produce TensorData. /// /// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting) /// and only materializes the actual data when `to_data()` is called. This allows /// efficient inspection of module structure without the overhead of copying all tensor data. /// /// The dtype and shape are cached for efficient access without requiring data materialization, /// which is particularly useful for serialization formats that need metadata upfront. pub struct TensorSnapshot { /// Function to get tensor data when needed (Rc allows cloning) data_fn: Rc Result>, /// Data type of the tensor (cached for efficient access) pub dtype: burn_tensor::DType, /// Shape of the tensor (cached for efficient access) pub shape: Shape, /// Path stack representing the module hierarchy pub path_stack: Option>, /// Container stack representing the container types at each level pub container_stack: Option>, /// Unique identifier for the tensor parameter pub tensor_id: Option, } impl TensorSnapshot { /// Create a new tensor snapshot from a float tensor pub fn from_float( tensor: &Tensor, path_stack: Vec, container_stack: Vec, tensor_id: ParamId, ) -> Self { let dtype = tensor.dtype(); let shape = tensor.shape(); let tensor = tensor.clone(); // Clone is cheap (reference counted) Self { data_fn: Rc::new(move || Ok(tensor.to_data())), dtype, shape, path_stack: Some(path_stack), container_stack: Some(container_stack), tensor_id: Some(tensor_id), } } /// Create a new tensor snapshot from an int tensor pub fn from_int( tensor: &Tensor, path_stack: Vec, container_stack: Vec, tensor_id: ParamId, ) -> Self { let dtype = tensor.dtype(); let shape = tensor.shape(); let tensor = tensor.clone(); // Clone is cheap (reference counted) Self { data_fn: Rc::new(move || Ok(tensor.to_data())), dtype, shape, path_stack: Some(path_stack), container_stack: Some(container_stack), tensor_id: Some(tensor_id), } } /// Create a new tensor snapshot from a bool tensor pub fn from_bool( tensor: &Tensor, path_stack: Vec, container_stack: Vec, tensor_id: ParamId, ) -> Self { let dtype = tensor.dtype(); let shape = tensor.shape(); let tensor = tensor.clone(); // Clone is cheap (reference counted) Self { data_fn: Rc::new(move || Ok(tensor.to_data())), dtype, shape, path_stack: Some(path_stack), container_stack: Some(container_stack), tensor_id: Some(tensor_id), } } /// Convert to TensorData (this is where actual data copy happens) #[cfg(feature = "std")] pub fn to_data(&self) -> Result { // Use AssertUnwindSafe since we're working with Rc which is not UnwindSafe std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else( |_| { Err(TensorSnapshotError::PanicError( "Panic occurred while loading tensor data".to_string(), )) }, ) } /// Convert to TensorData (this is where actual data copy happens) #[cfg(not(feature = "std"))] pub fn to_data(&self) -> Result { (self.data_fn)() // Can't catch panics in no-std, do it when core::panic::AssertUnwindSafe is available } /// Get the full path by joining the path stack pub fn full_path(&self) -> String { self.path_stack .as_ref() .map(|stack| stack.join(".")) .unwrap_or_default() } /// Get the full container path by joining the container stack pub fn container_path(&self) -> String { self.container_stack .as_ref() .map(|stack| stack.join(".")) .unwrap_or_default() } /// Get the module type (last Struct/Enum in the hierarchy) /// /// Returns the last user-defined module type, skipping primitive containers /// like "Vec", "Array". This is useful for determining which user-defined /// module a tensor belongs to. /// /// # Examples /// - `Linear.weight` → `Some("Struct:Linear")` /// - `Vec[0].weight` → `Some("Struct:Linear")` /// - `Linear.bias` (Optional) → `Some("Struct:Linear")` /// - `Vec[0]` (no module) → `None` pub fn module_type(&self) -> Option { self.container_stack.as_ref().and_then(|stack| { // Find the last user-defined type (Struct: or Enum:) stack .iter() .rev() .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:")) .cloned() }) } /// Get the immediate container type (last in the container stack) /// /// Returns the last element in the container stack, which could be a /// user-defined type ("Struct:", "Enum:") or a collection type ("Vec", "Array"). /// This is useful for understanding the full container hierarchy. /// /// # Examples /// - `Linear.weight` → `"Struct:Linear"` /// - `Vec[0].weight` → `"Struct:Linear"` (the Linear, not the Vec) /// - `Vec[0]` → `"Vec"` pub fn container_type(&self) -> String { self.container_stack .as_ref() .and_then(|stack| stack.last()) .cloned() .unwrap_or_else(|| "Unknown".to_string()) } /// Create a TensorSnapshot from a closure that produces TensorData /// This is used internally for lazy loading pub fn from_closure( data_fn: Rc Result>, dtype: burn_tensor::DType, shape: Shape, path_stack: Vec, container_stack: Vec, tensor_id: ParamId, ) -> Self { Self { data_fn, dtype, shape, path_stack: Some(path_stack), container_stack: Some(container_stack), tensor_id: Some(tensor_id), } } /// Create a TensorSnapshot from TensorData directly pub fn from_data( data: TensorData, path_stack: Vec, container_stack: Vec, tensor_id: ParamId, ) -> Self { let dtype = data.dtype; let shape = data.shape.clone(); Self { data_fn: Rc::new(move || Ok(data.clone())), dtype, shape, path_stack: Some(path_stack), container_stack: Some(container_stack), tensor_id: Some(tensor_id), } } /// Get the size of the tensor data in bytes without materializing it. /// /// For regular (non-quantized) types, this is simply `shape.product() * dtype.size()`. /// /// For quantized types (`QFloat`), this accounts for: /// - The quantized values (packed according to the quantization scheme) /// - Alignment padding (values are aligned to 4-byte boundary) /// - Quantization parameters (scale values appended to the data) pub fn data_len(&self) -> usize { const BITS_PER_BYTE: usize = 8; let num_elements: usize = self.shape.iter().product(); match self.dtype { DType::QFloat(scheme) => { // Calculate value bytes using scheme's packing information let num_storage_elements = num_elements.div_ceil(scheme.num_quants()); let value_bytes = num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE); // Calculate number of quantization parameters (scales) let num_params = params_shape(&self.shape, scheme.level).num_elements(); let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN; let scale_bytes = num_params * quant_param_size(scheme.param); aligned_value_bytes + scale_bytes } _ => num_elements * self.dtype.size(), } } /// Clone the data function for lazy composition pub fn clone_data_fn(&self) -> Rc Result> { self.data_fn.clone() } } impl Clone for TensorSnapshot { fn clone(&self) -> Self { // Clone lazily - keep the same data function Self { data_fn: self.data_fn.clone(), dtype: self.dtype, shape: self.shape.clone(), path_stack: self.path_stack.clone(), container_stack: self.container_stack.clone(), tensor_id: self.tensor_id, } } } impl core::fmt::Debug for TensorSnapshot { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("TensorSnapshot") .field("dtype", &self.dtype) .field("shape", &self.shape) .field("path_stack", &self.path_stack) .field("container_stack", &self.container_stack) .field("tensor_id", &self.tensor_id) .finish() } } #[cfg(all(test, feature = "std"))] mod tests { use super::*; type TestBackend = burn_ndarray::NdArray; use alloc::string::ToString; use burn_tensor::{BoolStore, DType, shape}; #[test] fn tensor_view_float() { let device = Default::default(); let tensor = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let snapshot = TensorSnapshot::from_float( &tensor, vec!["test".to_string(), "weight".to_string()], vec!["TestModule".to_string(), "Param".to_string()], ParamId::new(), ); // Test metadata access without materialization assert_eq!(snapshot.dtype, DType::F32); assert_eq!(snapshot.shape, shape![2, 2]); assert_eq!(snapshot.full_path(), "test.weight"); assert_eq!(snapshot.container_path(), "TestModule.Param"); // Test data materialization let data = snapshot.to_data().unwrap(); assert_eq!(data.shape, shape![2, 2]); assert_eq!(data.dtype, DType::F32); } #[test] fn tensor_view_int() { let device = Default::default(); let tensor = Tensor::::from_data([[1, 2], [3, 4]], &device); let snapshot = TensorSnapshot::from_int( &tensor, vec!["test".to_string(), "int".to_string()], vec!["TestModule".to_string(), "Param".to_string()], ParamId::new(), ); // Test metadata access without materialization // TestBackend uses I64 for integers assert_eq!(snapshot.dtype, DType::I64); assert_eq!(snapshot.shape, shape![2, 2]); let data = snapshot.to_data().unwrap(); assert_eq!(data.shape, shape![2, 2]); assert_eq!(data.dtype, DType::I64); } #[test] fn tensor_view_bool() { let device = Default::default(); let tensor = Tensor::::from_data([[true, false], [false, true]], &device); let snapshot = TensorSnapshot::from_bool( &tensor, vec!["test".to_string(), "bool".to_string()], vec!["TestModule".to_string(), "Param".to_string()], ParamId::new(), ); // Test metadata access without materialization assert_eq!(snapshot.dtype, DType::Bool(BoolStore::Native)); assert_eq!(snapshot.shape, shape![2, 2]); let data = snapshot.to_data().unwrap(); assert_eq!(data.shape, shape![2, 2]); assert_eq!(data.dtype, DType::Bool(BoolStore::Native)); } #[test] fn data_len() { let device = Default::default(); // Test F32 tensor (4 bytes per element) let tensor_f32 = Tensor::::from_data([[1.0, 2.0], [3.0, 4.0]], &device); let view_f32 = TensorSnapshot::from_float( &tensor_f32, vec!["test".to_string()], vec!["Module".to_string()], ParamId::new(), ); assert_eq!(view_f32.data_len(), 16); // 4 elements * 4 bytes // Test I64 tensor (8 bytes per element) - TestBackend uses I64 for Int let tensor_i64 = Tensor::::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device); let view_i64 = TensorSnapshot::from_int( &tensor_i64, vec!["test".to_string()], vec!["Module".to_string()], ParamId::new(), ); assert_eq!(view_i64.data_len(), 64); // 8 elements * 8 bytes (I64) // Test Bool tensor (1 byte per element) let tensor_bool = Tensor::::from_data([[true, false], [false, true]], &device); let view_bool = TensorSnapshot::from_bool( &tensor_bool, vec!["test".to_string()], vec!["Module".to_string()], ParamId::new(), ); assert_eq!(view_bool.data_len(), 4); // 4 elements * 1 byte } #[test] fn from_closure() { let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]); let dtype = data.dtype; let shape = data.shape.clone(); let snapshot = TensorSnapshot::from_closure( Rc::new(move || Ok(data.clone())), dtype, shape.clone(), vec!["model".to_string(), "layer".to_string()], vec!["Model".to_string(), "Layer".to_string()], ParamId::new(), ); // Test metadata access assert_eq!(snapshot.dtype, DType::F32); assert_eq!(snapshot.shape, shape![4]); assert_eq!(snapshot.full_path(), "model.layer"); assert_eq!(snapshot.data_len(), 16); // 4 * 4 bytes // Test data materialization let materialized = snapshot.to_data().unwrap(); assert_eq!(materialized.shape, shape![4]); } #[test] fn from_data() { let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]); let original_dtype = data.dtype; let original_shape = data.shape.clone(); let snapshot = TensorSnapshot::from_data( data, vec!["encoder".to_string(), "weight".to_string()], vec!["Struct:Encoder".to_string(), "Struct:Dense".to_string()], ParamId::new(), ); // Test metadata assert_eq!(snapshot.dtype, original_dtype); assert_eq!(snapshot.shape, original_shape); assert_eq!(snapshot.full_path(), "encoder.weight"); assert_eq!(snapshot.container_type(), "Struct:Dense"); assert_eq!(snapshot.data_len(), 24); // 6 * 4 bytes // Test data materialization let materialized = snapshot.to_data().unwrap(); assert_eq!(materialized.shape, original_shape); } #[test] #[cfg(feature = "std")] fn panic_catching_in_to_data() { use alloc::rc::Rc; // Create a TensorSnapshot with a closure that panics let snapshot = TensorSnapshot { data_fn: Rc::new(|| panic!("Test panic in data_fn")), dtype: DType::F32, shape: shape![2, 2], path_stack: Some(vec!["test".to_string()]), container_stack: Some(vec!["Test".to_string()]), tensor_id: Some(ParamId::new()), }; // When std is available, to_data should catch the panic and return an error let result = snapshot.to_data(); assert!(result.is_err()); match result { Err(TensorSnapshotError::PanicError(msg)) => { assert!(msg.contains("Panic occurred")); } _ => panic!("Expected PanicError with panic message"), } } #[test] fn error_propagation_in_closure() { use alloc::rc::Rc; // Create a snapshot with a closure that returns an error let snapshot = TensorSnapshot::from_closure( Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))), DType::F32, shape![2, 2], vec!["error_test".into()], vec![], ParamId::new(), ); // Should return an error when trying to get data let result = snapshot.to_data(); assert!(result.is_err()); match result { Err(TensorSnapshotError::IoError(msg)) => { assert!(msg.contains("Simulated IO error")); } _ => panic!("Expected IoError"), } } #[test] fn container_type_extraction() { let device = Default::default(); let tensor = Tensor::::from_data([1.0, 2.0, 3.0], &device); let snapshot = TensorSnapshot::from_float( &tensor, vec![ "model".to_string(), "layer1".to_string(), "weight".to_string(), ], vec![ "Struct:Model".to_string(), "Struct:Conv2d".to_string(), "Struct:Param".to_string(), ], ParamId::new(), ); assert_eq!(snapshot.container_type(), "Struct:Param"); assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string())); assert_eq!( snapshot.container_path(), "Struct:Model.Struct:Conv2d.Struct:Param" ); assert_eq!(snapshot.full_path(), "model.layer1.weight"); } #[test] fn container_type_vs_module_type() { let device = Default::default(); let tensor = Tensor::::from_data([1.0, 2.0, 3.0], &device); // Test case 1: Tensor inside a Vec // container_stack: ["Struct:Model", "Vec", "Struct:Linear"] let snapshot = TensorSnapshot::from_float( &tensor, vec![ "model".to_string(), "layers".to_string(), "0".to_string(), "weight".to_string(), ], vec![ "Struct:Model".to_string(), "Vec".to_string(), "Struct:Linear".to_string(), ], ParamId::new(), ); // container_type() returns the last element (Struct:Linear in this case) assert_eq!(snapshot.container_type(), "Struct:Linear"); // module_type() also returns Some(Struct:Linear) (skipping Vec) assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string())); // Test case 2: Tensor that's just in a Vec // container_stack: ["Vec"] let snapshot2 = TensorSnapshot::from_float( &tensor, vec!["data".to_string(), "0".to_string()], vec!["Vec".to_string()], ParamId::new(), ); // container_type() returns Vec assert_eq!(snapshot2.container_type(), "Vec"); // module_type() returns None (no Struct/Enum found) assert_eq!(snapshot2.module_type(), None); // Test case 3: Nested collections // container_stack: ["Struct:Model", "Vec", "Array", "Struct:Linear"] let snapshot3 = TensorSnapshot::from_float( &tensor, vec![ "model".to_string(), "layers".to_string(), "0".to_string(), "sublayers".to_string(), "1".to_string(), "weight".to_string(), ], vec![ "Struct:Model".to_string(), "Vec".to_string(), "Array".to_string(), "Struct:Linear".to_string(), ], ParamId::new(), ); // container_type() returns the immediate container assert_eq!(snapshot3.container_type(), "Struct:Linear"); // module_type() returns the last Struct/Enum assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string())); } } ================================================ FILE: crates/burn-store/src/traits.rs ================================================ use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::string::String; use alloc::vec::Vec; use super::applier::Applier; use super::apply_result::ApplyResult; use crate::collector::Collector; use crate::{ModuleAdapter, PathFilter, TensorSnapshot}; use burn_core::module::Module; use burn_tensor::backend::Backend; /// Extension trait for modules that provides tensor storage functionality. /// /// This trait provides convenient methods to collect and apply tensor snapshots from any Burn module. /// Collection operations create lightweight tensor snapshots without immediately copying data. /// Apply operations apply tensor data from snapshots to the corresponding tensors in the module. pub trait ModuleSnapshot: Module { /// Collects tensor snapshots for inspection without copying data. /// /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data. /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`. /// /// # Arguments /// /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect. /// When `None`, all tensors are collected. /// * `adapter` - Optional adapter to transform tensors based on container types. /// Applied to all collected tensors before returning. /// * `skip_enum_variants` - Skip enum variant names when building paths. /// When true, paths will not include enum variant names (e.g., "feature.weight" /// instead of "feature.BaseConv.weight"). Useful when exporting to formats /// like PyTorch/SafeTensors that don't use enum variants. fn collect( &self, filter: Option, adapter: Option>, skip_enum_variants: bool, ) -> Vec { let mut collector = Collector::new(filter, adapter, skip_enum_variants); self.visit(&mut collector); collector.into_tensors() } /// Applies tensor snapshots to the module. /// /// This is the primary apply method that applies tensor data from `TensorSnapshot`s /// to the corresponding tensors in the module. The snapshots are typically obtained /// from `collect()` or loaded from storage. /// /// # Arguments /// /// * `snapshots` - A vector of TensorSnapshot objects /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply. /// When `None`, all available tensors are applied. /// * `adapter` - Optional adapter to transform tensors based on container types /// * `skip_enum_variants` - Skip enum variant names when matching tensor paths /// /// # Returns /// /// An [`ApplyResult`] containing information about applied, skipped, missing, /// and unused tensors, as well as any errors encountered. /// /// # Examples /// /// ```rust,ignore /// use burn_store::PathFilter; /// /// // Apply all tensors /// let result = model.apply(snapshots, None, None, false); /// /// // Apply only encoder tensors /// let filter = PathFilter::new().with_regex(r"^encoder\..*"); /// let result = model.apply(snapshots, Some(filter), None, false); /// /// // Apply with complex filter /// let filter = PathFilter::new() /// .with_regex(r"^encoder\..*") /// .with_regex(r"^decoder\..*") /// .with_full_path("head.weight"); /// let result = model.apply(snapshots, Some(filter), None, false); /// /// // Apply with enum variant skipping (for PyTorch models) /// let result = model.apply(snapshots, None, None, true); /// ``` fn apply( &mut self, snapshots: Vec, filter: Option, adapter: Option>, skip_enum_variants: bool, ) -> ApplyResult where Self: Sized, { let mut applier = Applier::new(snapshots, filter, adapter, skip_enum_variants); // Use unsafe to avoid cloning the entire module, which would double the memory usage // We read the module out, map it, then write it back // See https://github.com/tracel-ai/burn/issues/3754 unsafe { // Read the module out of self (moves it, leaving self in undefined state) let module = core::ptr::read(self as *const Self); // Map the module to create a new one with updated tensors let new_module = module.map(&mut applier); // Write the new module back to self core::ptr::write(self as *mut Self, new_module); } applier.into_result() } /// Saves tensor snapshots into a [`ModuleStore`]. /// /// This method allows using a `ModuleStore` implementation to handle the /// collection and writing logic in a configurable way. /// /// # Arguments /// /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors fn save_into

(&self, store: &mut P) -> Result<(), P::Error> where P: ModuleStore, { store.collect_from(self) } /// Loads tensor data from a [`ModuleStore`]. /// /// This method allows using a `ModuleStore` implementation to handle the /// loading and application logic in a configurable way. /// /// # Arguments /// /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors fn load_from

(&mut self, store: &mut P) -> Result where P: ModuleStore, { store.apply_to(self) } } /// A trait for handling module storage operations. /// /// `ModuleStore` provides a unified interface for saving and loading module /// tensor data with support for various storage formats and advanced features like filtering, /// remapping, and metadata handling. pub trait ModuleStore { /// The error type that can be returned during storage operations. /// /// This should be a format-specific error type that provides detailed /// information about what went wrong (e.g., I/O errors, format violations, /// unsupported tensor types). type Error: core::fmt::Debug + core::fmt::Display; /// Collect tensor data from a module and store it to storage. /// /// This method traverses the module structure, collects all tensor data /// according to the store's configuration (filters, remapping, etc.), /// and writes it to the underlying storage. /// /// # Arguments /// /// * `module` - The module to collect tensor data from. The module must /// implement `ModuleSnapshot` to provide tensor access. /// /// # Returns /// /// * `Ok(())` - If all tensors were successfully collected and stored /// * `Err(Self::Error)` - If an error occurred during collection or writing fn collect_from>( &mut self, module: &M, ) -> Result<(), Self::Error>; /// Load stored tensor data and apply it to a module. /// /// This method reads tensor data from storage and applies it to the provided /// module. The operation is flexible and can handle partial matches, missing /// tensors, and extra tensors in the storage. /// /// # Arguments /// /// * `module` - The module to apply tensor data to. The module must /// implement `ModuleSnapshot` to allow tensor updates. /// /// # Returns /// /// * `Ok(ApplyResult)` - Detailed information about the apply operation: /// - `applied`: List of successfully applied tensor names /// - `missing`: Tensors expected by the module but not found in storage /// - `skipped`: Tensors in storage that were not applied (filtered or not needed) /// - `errors`: Non-critical errors that occurred during apply /// * `Err(Self::Error)` - If a critical error prevented the apply operation fn apply_to>( &mut self, module: &mut M, ) -> Result; /// Get a single tensor snapshot by name. /// /// This method provides direct access to individual tensors in storage without /// requiring a module. The returned `TensorSnapshot` uses lazy loading - tensor /// data is only materialized when `to_data()` is called. /// /// **Note:** Key remapping is applied, so use the remapped name if configured. /// Filters are NOT applied - use `apply_to()` for filtered loading. /// /// Results are cached after the first call for efficient repeated access. /// /// # Arguments /// /// * `name` - The tensor name/path (e.g., "encoder.layer1.weight") /// /// # Returns /// /// * `Ok(Some(&TensorSnapshot))` - Reference to the tensor snapshot if found /// * `Ok(None)` - If no tensor with that name exists /// * `Err(Self::Error)` - If an error occurred accessing storage /// /// # Example /// /// ```rust,ignore /// let mut store = BurnpackStore::from_file("model.bpk"); /// if let Some(snapshot) = store.get_snapshot("encoder.weight")? { /// println!("Shape: {:?}", snapshot.shape); /// println!("Dtype: {:?}", snapshot.dtype); /// let data = snapshot.to_data()?; // Lazy load /// } /// ``` fn get_snapshot(&mut self, name: &str) -> Result, Self::Error>; /// Get all tensor snapshots from storage as an ordered map. /// /// This method returns all tensors in storage as lazy-loading snapshots, /// organized in a `BTreeMap` for efficient lookup by name. The map preserves /// alphabetical ordering of tensor names. /// /// **Note:** This returns ALL tensors in storage, regardless of any filter /// settings. Filters are only applied during `apply_to()`. Key remapping /// IS applied, so tensor names reflect any configured remapping. /// /// Results are cached after the first call for efficient repeated access. /// /// # Returns /// /// * `Ok(&BTreeMap)` - Reference to all tensor snapshots /// * `Err(Self::Error)` - If an error occurred accessing storage /// /// # Example /// /// ```rust,ignore /// let mut store = SafetensorsStore::from_file("model.safetensors"); /// let snapshots = store.get_all_snapshots()?; /// for (name, snapshot) in snapshots { /// println!("{}: {:?}", name, snapshot.shape); /// } /// ``` fn get_all_snapshots(&mut self) -> Result<&BTreeMap, Self::Error>; /// Get all tensor names/keys in storage. /// /// This method returns the names of all tensors in storage. /// Useful for inspecting storage contents or checking if specific tensors exist. /// /// **Note:** Returns ALL tensor names regardless of filter settings. /// Key remapping IS applied, so names reflect any configured remapping. /// /// # Returns /// /// * `Ok(Vec)` - All tensor names in storage /// * `Err(Self::Error)` - If an error occurred accessing storage /// /// # Example /// /// ```rust,ignore /// let mut store = PytorchStore::from_file("model.pth"); /// let keys = store.keys()?; /// println!("Tensors in file: {:?}", keys); /// ``` fn keys(&mut self) -> Result, Self::Error>; } // Blanket implementation for all modules impl> ModuleSnapshot for M {} ================================================ FILE: crates/burn-tch/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "LibTorch backend for the Burn framework using the tch bindings." documentation = "https://docs.rs/burn-tch" edition.workspace = true keywords = ["deep-learning", "machine-learning", "data"] license.workspace = true name = "burn-tch" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tch" version.workspace = true [lints] workspace = true [features] default = ["std"] std = ["burn-backend/std"] doc = ["tch/doc-only"] tracing = [ "burn-backend/tracing", ] [dependencies] burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } libc = { workspace = true } log = { workspace = true } tch = { workspace = true, features = ["download-libtorch"] } torch-sys = { workspace = true } # for build script lib dir detection [build-dependencies] cc = "1.2.56" [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-tch/README.md ================================================ # Burn Torch Backend [Burn](https://github.com/tracel-ai/burn) Torch backend [![Current Crates.io Version](https://img.shields.io/crates/v/burn-tch.svg)](https://crates.io/crates/burn-tch) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tch/blob/master/README.md) This crate provides a Torch backend for [Burn](https://github.com/tracel-ai/burn) utilizing the [`tch-rs`](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the [PyTorch](https://pytorch.org/) C++ API. The backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html) (multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS). ## Installation [`tch-rs`](https://github.com/LaurentMazare/tch-rs) requires the C++ PyTorch library (LibTorch) to be available on your system. By default, the CPU distribution is installed for LibTorch v2.9.0 as required by `tch-rs`.

CUDA To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment variable before the `tch-rs` dependency is retrieved with `cargo`. ```shell export TORCH_CUDA_VERSION=cu128 ``` On Windows: ```powershell $Env:TORCH_CUDA_VERSION = "cu128" ``` > Note: `tch` doesn't expose the downloaded libtorch directory on Windows when using the automatic > download feature, so the `torch_cuda.dll` cannot be detected properly during build. In this case, > you can set the `LIBTORCH` environment variable to point to the `libtorch/` folder in `torch-sys` > `OUT_DIR` (or move the downloaded lib to a different folder and point to it). For example, running the validation sample for the first time could be done with the following commands: ```shell export TORCH_CUDA_VERSION=cu128 cargo run --bin cuda --release ``` **Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA Toolkit installation is not required since LibTorch ships with the appropriate CUDA runtimes. Having the latest driver version is recommended, but you can always take a look at the [toolkit driver version table](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id4) or [minimum required driver version](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#minor-version-compatibility) (limited feature-set, might not work with all operations).

Once your installation is complete, you should be able to build/run your project. You can also validate your installation by running the appropriate `cpu`, `cuda` or `mps` sample as below. ```shell cargo run --bin cpu --release cargo run --bin cuda --release cargo run --bin mps --release ``` _Note: no MPS distribution is available for automatic download at this time, please check out the [manual instructions](#metal-mps)._ ### Manual Download To install `tch-rs` with a different LibTorch distribution, you will have to manually download the desired LibTorch distribution. The instructions are detailed in the sections below for each platform. | Compute Platform | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM | | :------------------------ | :----------------------------: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: | | [CPU](#cpu) | Yes | No | Yes | Yes | Yes | Yes | Yes | No | | [CUDA](#cuda) | Yes [[1]](#cpu-sup) | Yes | Yes | No | Yes | No | No | No | | [Metal (MPS)](#metal-mps) | No | Yes | No | Yes | No | No | No | No | | Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | No | No | [1] The LibTorch CUDA distribution also comes with CPU support. #### CPU
🐧 Linux First, download the LibTorch CPU distribution. ```shell wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip unzip libtorch.zip ``` Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it. ```shell export LIBTORCH=/absolute/path/to/libtorch/ export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH ```

🍎 Mac First, download the LibTorch CPU distribution. ```shell wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip unzip libtorch.zip ``` Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it. ```shell export LIBTORCH=/absolute/path/to/libtorch/ export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH ```

🪟 Windows First, download the LibTorch CPU distribution. ```powershell wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip -OutFile libtorch.zip Expand-Archive libtorch.zip ``` Then, set the `LIBTORCH` environment variable and append the library to your path as with the PowerShell commands below before building `burn-tch` or a crate which depends on it. ```powershell $Env:LIBTORCH = "/absolute/path/to/libtorch/" $Env:Path += ";/absolute/path/to/libtorch/" ```

#### CUDA LibTorch 2.9.0 currently includes binary distributions with CUDA 12.6, 12.8 or 13.0 runtimes. The manual installation instructions are detailed below for CUDA 12.6, but can be applied to the other CUDA versions by replacing `cu126` with the corresponding version string (e.g., `cu130`).
🐧 Linux First, download the LibTorch CUDA 12.6 distribution. ```shell wget -O libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-shared-with-deps-2.9.0%2Bcu126.zip unzip libtorch.zip ``` Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables before building `burn-tch` or a crate which depends on it. ```shell export LIBTORCH=/absolute/path/to/libtorch/ export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH ``` **Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.

🪟 Windows First, download the LibTorch CUDA 12.6 distribution. ```powershell wget https://download.pytorch.org/libtorch/cu126/libtorch-win-shared-with-deps-2.9.0%2Bcu126.zip -OutFile libtorch.zip Expand-Archive libtorch.zip ``` Then, set the `LIBTORCH` environment variable and append the library to your path as with the PowerShell commands below before building `burn-tch` or a crate which depends on it. ```powershell $Env:LIBTORCH = "/absolute/path/to/libtorch/" $Env:Path += ";/absolute/path/to/libtorch/" ```

#### Metal (MPS) There is no official LibTorch distribution with MPS support at this time, so the easiest alternative is to use a PyTorch installation. This requires a Python installation. _Note: MPS acceleration is available on MacOS 12.3+._ ```shell pip install torch==2.9.0 numpy==1.26.4 setuptools export LIBTORCH_USE_PYTORCH=1 export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH ``` **Note:** if `venv` is used, it should be activated during coding and building, or the compiler may not work properly. ## Example Usage For a simple example, check out any of the test programs in [`src/bin/`](./src/bin/). Each program sets the device to use and performs a simple element-wise addition. For a more complete example using the `tch` backend, take a loot at the [Burn mnist example](https://github.com/tracel-ai/burn/tree/main/examples/mnist). ## Too many environment variables? Try `.cargo/config.toml` ([cargo book](https://doc.rust-lang.org/cargo/reference/config.html#env)). Instead of setting the environments in your shell, you can manually add them to your `.cargo/config.toml`: ```toml [env] LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib" LIBTORCH = "/absolute/path/to/libtorch/libtorch" ``` Or use bash commands below: ```bash mkdir .cargo cat < .cargo/config.toml [env] LD_LIBRARY_PATH = "/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH" LIBTORCH = "/absolute/path/to/libtorch/libtorch" EOF ``` This will automatically include the old `LD_LIBRARY_PATH` value in the new one. ================================================ FILE: crates/burn-tch/build.rs ================================================ // The LIBTORCH environment variable can be used to specify the directory // where libtorch has been installed. // When not specified this script downloads the cpu version for libtorch // and extracts it in OUT_DIR. // // On Linux, the TORCH_CUDA_VERSION environment variable can be used, // like 9.0, 90, or cu90 to specify the version of CUDA to use for libtorch. use std::path::{Path, PathBuf}; use std::{env, fs}; const PYTHON_PRINT_PYTORCH_DETAILS: &str = r" import torch from torch.utils import cpp_extension print('LIBTORCH_VERSION:', torch.__version__.split('+')[0]) print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI) for include_path in cpp_extension.include_paths(): print('LIBTORCH_INCLUDE:', include_path) for library_path in cpp_extension.library_paths(): print('LIBTORCH_LIB:', library_path) "; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum Os { Linux, Macos, Windows, } #[allow(dead_code)] #[derive(Debug, Clone)] struct SystemInfo { os: Os, cxx11_abi: String, libtorch_include_dirs: Vec, libtorch_lib_dir: PathBuf, } fn env_var_rerun(name: &str) -> Result { println!("cargo:rerun-if-env-changed={name}"); env::var(name) } impl SystemInfo { fn new() -> Option { let os = match env::var("CARGO_CFG_TARGET_OS") .expect("Unable to get TARGET_OS") .as_str() { "linux" => Os::Linux, "windows" => Os::Windows, "macos" => Os::Macos, os => panic!("unsupported TARGET_OS '{os}'"), }; // Locate the currently active Python binary, similar to: // https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547 let python_interpreter = match os { Os::Windows => PathBuf::from("python.exe"), Os::Linux | Os::Macos => { if env::var_os("VIRTUAL_ENV").is_some() { PathBuf::from("python") } else { PathBuf::from("python3") } } }; let mut libtorch_include_dirs = vec![]; let mut libtorch_lib_dir = None; let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() { let output = std::process::Command::new(&python_interpreter) .arg("-c") .arg(PYTHON_PRINT_PYTORCH_DETAILS) .output() .expect("error running python interpreter"); let mut cxx11_abi = None; for line in String::from_utf8_lossy(&output.stdout).lines() { match line.strip_prefix("LIBTORCH_CXX11: ") { Some("True") => cxx11_abi = Some("1".to_owned()), Some("False") => cxx11_abi = Some("0".to_owned()), _ => {} } if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { libtorch_include_dirs.push(PathBuf::from(path)) } if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { libtorch_lib_dir = Some(PathBuf::from(path)) } } match cxx11_abi { Some(cxx11_abi) => cxx11_abi, None => panic!("no cxx11 abi returned by python {output:?}"), } } else { let libtorch = Self::prepare_libtorch_dir(os)?; let includes = env_var_rerun("LIBTORCH_INCLUDE") .map(PathBuf::from) .unwrap_or_else(|_| libtorch.clone()); let lib = env_var_rerun("LIBTORCH_LIB") .map(PathBuf::from) .unwrap_or_else(|_| libtorch.clone()); libtorch_include_dirs.push(includes.join("include")); libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include")); if lib.ends_with("lib") { // DEP_TCH_LIBTORCH_LIB might already point to /lib libtorch_lib_dir = Some(lib); } else { libtorch_lib_dir = Some(lib.join("lib")); } env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()) }; let libtorch_lib_dir = libtorch_lib_dir?; Some(Self { os, cxx11_abi, libtorch_include_dirs, libtorch_lib_dir, }) } fn check_system_location(os: Os) -> Option { match os { Os::Linux => Path::new("/usr/lib/libtorch.so") .exists() .then(|| PathBuf::from("/usr")), _ => None, } } fn prepare_libtorch_dir(os: Os) -> Option { if let Ok(libtorch) = env_var_rerun("DEP_TCH_LIBTORCH_LIB") { Some(PathBuf::from(libtorch)) } else if let Ok(libtorch) = env_var_rerun("LIBTORCH") { Some(PathBuf::from(libtorch)) } else if let Some(pathbuf) = Self::check_system_location(os) { Some(pathbuf) } else { check_out_dir() } } fn make(&self, use_cuda: bool, use_hip: bool) { let cuda_dependency = if use_cuda || use_hip { "src/cuda_hack/dummy_cuda_dependency.cpp" } else { "src/cuda_hack/fake_cuda_dependency.cpp" }; println!("cargo:rerun-if-changed={cuda_dependency}"); match self.os { Os::Linux | Os::Macos => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) .flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display())) .flag("-std=c++17") .flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi)) .files(&[cuda_dependency]) .compile("burn-tch"); } Os::Windows => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) .flag("/std:c++17") .files(&[cuda_dependency]) .compile("burn-tch"); } }; } fn make_cpu() { let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp"; println!("cargo:rerun-if-changed={cuda_dependency}"); let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); match os.as_str() { "windows" => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .flag("/std:c++17") .files(&[cuda_dependency]) .compile("burn-tch"); } _ => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .flag("-std=c++17") .files(&[cuda_dependency]) .compile("tch"); } }; } } fn check_out_dir() -> Option { let out_dir = env_var_rerun("OUT_DIR").ok()?; let libtorch_dir = PathBuf::from(out_dir).join("libtorch"); libtorch_dir.exists().then_some(libtorch_dir) } fn main() { let system_info = SystemInfo::new(); let out_dir = env_var_rerun("OUT_DIR").expect("Failed to get out dir"); let mut gpu_found = false; let found_dir = system_info.is_some(); if let Some(system_info) = &system_info { let si_lib = &system_info.libtorch_lib_dir; let use_cuda = si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists(); let use_hip = si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists(); system_info.make(use_cuda, use_hip); gpu_found = use_cuda || use_hip; } else { SystemInfo::make_cpu(); } let check_file = PathBuf::from(out_dir).join("tch_gpu_check.rs"); if gpu_found { fs::write(check_file, "#[allow(clippy::no_effect)]\n()").unwrap(); } else { let message = if !found_dir { r#"Could not find libtorch dir. If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions. If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."# } else { "No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device" }; fs::write(check_file, format!("panic!(\"{message}\")")).unwrap(); } } ================================================ FILE: crates/burn-tch/src/backend.rs ================================================ use std::marker::PhantomData; use crate::IntoKind; use super::TchTensor; use super::element::TchElement; use burn_backend::backend::{Backend, DeviceId, DeviceOps, ExecutionError}; use burn_backend::ops::IntTensorOps; #[derive(Clone, Copy, Debug, PartialEq, Eq)] /// The device struct when using the `tch` backend. /// /// Note that you need to provide the device index when using Cuda. /// /// # Example /// /// ```no_run /// use burn_tch::LibTorchDevice; /// /// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU /// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU /// let device_cpu = LibTorchDevice::Cpu; // CPU /// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders /// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan /// ``` #[derive(Default)] pub enum LibTorchDevice { /// CPU device. #[default] Cpu, /// Cuda device with the given index. The index is the index of the Cuda device in the list of /// all Cuda devices found on the system. Cuda(usize), /// Metal Performance Shaders device. Mps, /// Vulkan device. Vulkan, } impl From for tch::Device { #[allow( unreachable_code, reason = "CUDA branch always panics if the library is missing" )] fn from(device: LibTorchDevice) -> Self { match device { LibTorchDevice::Cpu => tch::Device::Cpu, LibTorchDevice::Cuda(_num) => { include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs")); tch::Device::Cuda(_num) } LibTorchDevice::Mps => tch::Device::Mps, LibTorchDevice::Vulkan => tch::Device::Vulkan, } } } impl From for LibTorchDevice { fn from(device: tch::Device) -> Self { match device { tch::Device::Cpu => LibTorchDevice::Cpu, tch::Device::Cuda(num) => LibTorchDevice::Cuda(num), tch::Device::Mps => LibTorchDevice::Mps, tch::Device::Vulkan => LibTorchDevice::Vulkan, } } } impl burn_backend::Device for LibTorchDevice { fn from_id(device_id: DeviceId) -> Self { match device_id.type_id { 0 => Self::Cuda(device_id.index_id as usize), 1 => Self::Mps, 2 => Self::Cpu, 3 => Self::Vulkan, _ => LibTorchDevice::Cpu, } } fn to_id(&self) -> DeviceId { match self { LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32), LibTorchDevice::Mps => DeviceId::new(1, 0), LibTorchDevice::Cpu => DeviceId::new(2, 0), LibTorchDevice::Vulkan => DeviceId::new(3, 0), } } fn device_count(_type_id: u16) -> usize { // TODO: Somehow find the info using the tch API. 1 } } impl DeviceOps for LibTorchDevice {} /// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations. /// /// This backend is compatible with a wide range of hardwares ranging from CPUs to GPUs, but /// requires `LibTorch` to be installed correctly. The CPU version can be downloaded /// automatically and the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment /// variable. For more complex configurations, check out the manual installation for /// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch). /// /// Refer to the [tch] crate for more information. #[derive(Clone, Copy, Default, Debug)] pub struct LibTorch { _e: PhantomData, } impl Backend for LibTorch { type Device = LibTorchDevice; type FloatTensorPrimitive = TchTensor; type FloatElem = E; type IntTensorPrimitive = TchTensor; type IntElem = i64; type BoolTensorPrimitive = TchTensor; type BoolElem = bool; type QuantizedTensorPrimitive = TchTensor; fn seed(_device: &Self::Device, seed: u64) { tch::manual_seed(seed as i64); } fn ad_enabled(_device: &Self::Device) -> bool { false } fn name(device: &Self::Device) -> String { match device { LibTorchDevice::Cpu => "libtorch", LibTorchDevice::Cuda(_) => "libtorch", LibTorchDevice::Mps => "libtorch", LibTorchDevice::Vulkan => "libtorch", } .to_string() } fn sync(device: &Self::Device) -> Result<(), ExecutionError> { match device { LibTorchDevice::Cpu => (), LibTorchDevice::Cuda(index) => { tch::Cuda::synchronize(*index as i64); } _ => { // When there is no explicit way to synchronize, we write and read one value to sync burn_backend::read_sync(Self::int_into_data(Self::int_zeros( [1].into(), device, E::dtype().into(), ))) .unwrap(); } }; Ok(()) } fn dtype_usage( _device: &Self::Device, dtype: burn_backend::DType, ) -> burn_backend::DTypeUsageSet { if dtype.try_into_kind().is_ok() { burn_backend::DTypeUsage::general() } else { burn_backend::DTypeUsageSet::empty() } } } ================================================ FILE: crates/burn-tch/src/bin/cpu.rs ================================================ use burn_backend::{TensorMetadata, ops::FloatTensorOps}; use burn_tch::{LibTorch, LibTorchDevice}; fn main() { type B = LibTorch; let device = LibTorchDevice::Cpu; // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device); let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into()); // Print the element-wise addition of the two tensors. println!("{}", B::float_add(tensor_1, tensor_2)); } ================================================ FILE: crates/burn-tch/src/bin/cuda.rs ================================================ use burn_backend::{TensorMetadata, ops::FloatTensorOps}; use burn_tch::{LibTorch, LibTorchDevice}; fn main() { assert!( tch::utils::has_cuda(), "Could not detect valid CUDA configuration" ); type B = LibTorch; let device = LibTorchDevice::Cuda(0); // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device); let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into()); // Print the element-wise addition of the two tensors. println!("{}", B::float_add(tensor_1, tensor_2)); } ================================================ FILE: crates/burn-tch/src/bin/mps.rs ================================================ use burn_backend::{TensorMetadata, ops::FloatTensorOps}; use burn_tch::{LibTorch, LibTorchDevice}; fn main() { assert!(tch::utils::has_mps(), "Could not detect MPS"); type B = LibTorch; let device = LibTorchDevice::Mps; // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device); let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into()); // Print the element-wise addition of the two tensors. println!("{}", B::float_add(tensor_1, tensor_2)); } ================================================ FILE: crates/burn-tch/src/cuda_hack/dummy_cuda_dependency.cpp ================================================ #include #include #include #include using namespace std; extern "C" { void dummy_cuda_dependency(); } struct cublasContext; namespace at { namespace cuda { cublasContext *getCurrentCUDABlasHandle(); int warp_size(); } // namespace cuda } // namespace at char *magma_strerror(int err); void dummy_cuda_dependency() { try { at::cuda::getCurrentCUDABlasHandle(); at::cuda::warp_size(); } catch (std::exception &e) { if (getenv("TCH_PRINT_CUDA_INIT_ERROR") != nullptr) { std::cerr << "error initializing cuda: " << e.what() << std::endl; } } } ================================================ FILE: crates/burn-tch/src/cuda_hack/fake_cuda_dependency.cpp ================================================ extern "C" { void dummy_cuda_dependency(); } void dummy_cuda_dependency() {} ================================================ FILE: crates/burn-tch/src/element.rs ================================================ use burn_backend::Element; use burn_backend::{bf16, f16}; /// The element type for the tch backend. pub trait TchElement: Element + tch::kind::Element { /// Returns the associated tensor kind for [`tch::kind::Element`]. fn kind() -> tch::Kind { Self::KIND } } impl TchElement for f64 {} impl TchElement for f32 {} impl TchElement for f16 {} impl TchElement for bf16 { fn kind() -> tch::Kind { let mut kind = ::KIND; // Incorrect kind mapping in tch definitions, force bfloat16 if matches!(Self::dtype(), burn_backend::DType::BF16) && kind == tch::Kind::Half { kind = tch::Kind::BFloat16 } kind } } impl TchElement for i64 {} impl TchElement for i32 {} impl TchElement for i16 {} impl TchElement for i8 {} impl TchElement for u8 {} impl TchElement for bool {} #[cfg(test)] mod tests { use super::*; #[test] fn test_elem_kinds() { assert_eq!(f64::kind(), tch::Kind::Double); assert_eq!(f32::kind(), tch::Kind::Float); assert_eq!(f16::kind(), tch::Kind::Half); assert_eq!(bf16::kind(), tch::Kind::BFloat16); assert_eq!(i64::kind(), tch::Kind::Int64); assert_eq!(i32::kind(), tch::Kind::Int); assert_eq!(i16::kind(), tch::Kind::Int16); assert_eq!(i8::kind(), tch::Kind::Int8); assert_eq!(bool::kind(), tch::Kind::Bool); } } ================================================ FILE: crates/burn-tch/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] #![allow(clippy::single_range_in_vec_init)] //! Burn Tch Backend mod backend; mod element; mod ops; mod tensor; pub use backend::*; pub use element::*; pub use tensor::*; ================================================ FILE: crates/burn-tch/src/ops/activation.rs ================================================ use crate::{LibTorch, TchTensor, element::TchElement}; use burn_backend::ops::ActivationOps; impl ActivationOps for LibTorch { fn relu(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()) } fn gelu(tensor: TchTensor) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.gelu_("none"), |tensor| tensor.gelu("none"), ) } fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none"); TchTensor::from_existing(tensor, storage) } fn sigmoid(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid()) } fn log_sigmoid(tensor: TchTensor) -> TchTensor { // NOTE: we don't override log_sigmoid_backward because Torch has a special backward // formula that uses a buffer with computed values from the forward pass // no in-place log_sigmoid_ let storage = tensor.storage.clone(); let tensor = tensor.tensor.log_sigmoid(); TchTensor::from_existing(tensor, storage) } } ================================================ FILE: crates/burn-tch/src/ops/base.rs ================================================ use burn_backend::{Shape, TensorMetadata}; use tch::Scalar; use crate::{LibTorchDevice, TchShape, TchTensor}; pub struct TchOps { // e: PhantomData, } impl TchOps { pub fn to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor { let device = (*device).into(); // We have to manually check if the device is the same, since when it's the case, we need to keep // the same storage reference and not create a new one. if tensor.tensor.device() == device { return tensor; } TchTensor::new(tensor.tensor.to(device)) } pub fn reshape(tensor: TchTensor, shape: Shape) -> TchTensor { let shape_tch: TchShape = shape.into(); TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) } pub fn repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { let mut dims = vec![1; tensor.shape().num_dims()]; dims[dim] = times as i64; let tensor = tch::Tensor::repeat(&tensor.tensor, dims); TchTensor::new(tensor) } pub fn slice_with_steps(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor { let storage = tensor.storage.clone(); let mut tensor = tensor.tensor.shallow_clone(); for (dim, slice) in slices.iter().enumerate() { let dim_i64 = dim as i64; // Convert slice to range using a dummy size (we'll use tensor dimensions) let dim_size = tensor.size()[dim]; let range = slice.to_range(dim_size as usize); let start = range.start as i64; let end = range.end as i64; let step = slice.step as i64; if step > 0 { // Forward stepping - use native slice tensor = tensor.slice(dim_i64, Some(start), Some(end), step); } else { // Negative stepping - we need to handle the semantics correctly // For negative steps, we iterate backwards from end-1 // PyTorch's negative step works differently than our semantics // We need to reverse the selected range // First get the slice with positive step tensor = tensor.slice(dim_i64, Some(start), Some(end), 1); // Then reverse it and apply the step if step == -1 { // Simple reversal tensor = tensor.flip([dim_i64]); } else { // Reverse and then take every nth element tensor = tensor.flip([dim_i64]); let abs_step = step.abs(); tensor = tensor.slice(dim_i64, None, None, abs_step); } } } TchTensor::partial(tensor, storage) } pub fn slice_assign( tensor: TchTensor, slices: &[burn_backend::Slice], value: TchTensor, ) -> TchTensor { // PyTorch's narrow operation only supports contiguous slices (step=1) // For non-unit steps, we use advanced indexing as a workaround let all_unit_steps = slices.iter().all(|s| s.step == 1); if all_unit_steps { // Fast path: use narrow and copy_ for unit steps let tch_shape = TchShape::from(tensor.shape()); // Copy the input tensor if we can't mutate it let tensor_original: TchTensor = tensor.unary_ops(|tensor| tensor, |tensor| tensor.copy()); let tensor_original = tensor_original.tensor; let mut tensor = tensor_original.view_(tch_shape.dims); for (i, slice) in slices.iter().enumerate().take(slices.len()) { // Convert Slice to range for narrow operation let dim_size = tensor.size()[i] as usize; let range = slice.to_range(dim_size); let start = range.start as i64; let length = (range.end - range.start) as i64; tensor = tensor.narrow(i as i64, start, length); } tensor.copy_(&value.tensor); TchTensor::new(tensor_original) } else { // Workaround for non-unit steps: use PyTorch's index_put operation // This generates explicit indices for the slice and uses advanced indexing let tensor_shape = tensor.shape(); let dims = tensor_shape.clone(); // Copy the tensor since we'll modify it let result_tensor = tensor.tensor.shallow_clone(); // Use advanced indexing to set the values Self::slice_assign_with_advanced_indexing(result_tensor, slices, value.tensor, &dims) } } /// Generate indices for a slice with potentially non-unit step. /// For negative steps, generates indices in reverse order. fn generate_slice_indices(slice: &burn_backend::Slice, dim_size: usize) -> Vec { let step = slice.step; let range = slice.to_range(dim_size); let mut indices = Vec::new(); if step > 0 { let mut idx = range.start as i64; while idx < range.end as i64 { indices.push(idx); idx += step as i64; } } else if step < 0 { // For negative steps, iterate backwards through the range let mut idx = (range.end - 1) as i64; while idx >= range.start as i64 { indices.push(idx); idx += step as i64; // step is negative, so this decreases } } indices } /// Implementation using advanced indexing for non-unit steps. /// Uses PyTorch's index_put operation to assign values at specific indices. fn slice_assign_with_advanced_indexing( mut tensor: tch::Tensor, slices: &[burn_backend::Slice], value: tch::Tensor, dims: &[usize], ) -> TchTensor { // Generate all index combinations for the sliced regions let mut index_sets: Vec> = Vec::new(); for (i, slice) in slices.iter().enumerate() { let dim_size = if i < dims.len() { dims[i] } else { 1 }; let indices = Self::generate_slice_indices(slice, dim_size); index_sets.push(indices); } // For unsliced dimensions, include all indices for &dim_size in dims.iter().skip(slices.len()) { let indices: Vec = (0..dim_size as i64).collect(); index_sets.push(indices); } // Convert index sets to tensors for index_put let mut final_indices = Vec::new(); let total_elements = index_sets.iter().map(|s| s.len()).product::(); // Build flattened index arrays for each dimension using cartesian product // This creates the index tensors needed for PyTorch's index_put operation for dim_idx in 0..index_sets.len() { let mut dim_indices = Vec::with_capacity(total_elements); let repeat = index_sets[dim_idx + 1..] .iter() .map(|s| s.len()) .product::() .max(1); let tile = index_sets[..dim_idx] .iter() .map(|s| s.len()) .product::() .max(1); for _ in 0..tile { for &idx in &index_sets[dim_idx] { for _ in 0..repeat { dim_indices.push(idx); } } } let indices_tensor = tch::Tensor::from_slice(&dim_indices).to_device(tensor.device()); final_indices.push(indices_tensor); } // PyTorch's index_put handles assignment correctly for negative steps // following NumPy semantics: values[i] goes to selected_indices[i] let value_flat = value.view(-1); // Use index_put to assign values - convert to Option let final_indices_opt: Vec> = final_indices.into_iter().map(Some).collect(); tensor = tensor.index_put(&final_indices_opt, &value_flat, false); TchTensor::new(tensor) } pub fn gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false); TchTensor::from_existing(tensor, storage) } pub fn scatter( dim: usize, tensor: TchTensor, indices: TchTensor, value: TchTensor, ) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor .tensor .scatter_add(dim as i64, &indices.tensor, &value.tensor); TchTensor::from_existing(tensor, storage) } pub fn index_select_dim(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor); TchTensor::from_existing(tensor, storage) } pub fn select_assign( tensor: TchTensor, dim: usize, indices: TchTensor, value: TchTensor, ) -> TchTensor { tensor.clone().unary_ops( |mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor), |tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor), ) } pub fn cat(tensors: Vec, dim: usize) -> TchTensor { let tensors: Vec = tensors .into_iter() .map(|t| t.tensor.shallow_clone()) .collect(); let tensor = tch::Tensor::cat(&tensors, dim as i64); TchTensor::new(tensor) } pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool), |lhs, rhs| lhs.eq_tensor(rhs), ) } pub fn equal_elem + Clone>(lhs: TchTensor, rhs: S) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool), |tensor| tensor.eq(rhs.clone().into()), ) } pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool), |lhs, rhs| lhs.greater_tensor(rhs), ) } pub fn greater_elem + Clone>(lhs: TchTensor, rhs: S) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool), |tensor| tensor.greater(rhs.clone().into()), ) } pub fn greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool), |lhs, rhs| lhs.greater_equal_tensor(rhs), ) } pub fn greater_equal_elem + Clone>(lhs: TchTensor, rhs: S) -> TchTensor { lhs.unary_ops( |mut tensor| { tensor .greater_equal_(rhs.clone().into()) .to_kind(tch::Kind::Bool) }, |tensor| tensor.greater_equal(rhs.clone().into()), ) } pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool), |lhs, rhs| lhs.less_tensor(rhs), ) } pub fn lower_elem + Clone>(lhs: TchTensor, rhs: S) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool), |tensor| tensor.less(rhs.clone().into()), ) } pub fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool), |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool), |lhs, rhs| lhs.less_equal_tensor(rhs), ) } pub fn lower_equal_elem + Clone>(lhs: TchTensor, rhs: S) -> TchTensor { lhs.unary_ops( |mut tensor| { tensor .less_equal_(rhs.clone().into()) .to_kind(tch::Kind::Bool) }, |tensor| tensor.less_equal(rhs.clone().into()), ) } pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_add_(rhs).unwrap(), |lhs, rhs| rhs.f_add_(lhs).unwrap(), |lhs, rhs| lhs.f_add(rhs).unwrap(), ) } pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_sub_(rhs).unwrap(), |lhs, rhs| lhs.f_sub(rhs).unwrap(), |lhs, rhs| lhs.f_sub(rhs).unwrap(), ) } pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_mul_(rhs).unwrap(), |lhs, rhs| rhs.f_mul_(lhs).unwrap(), |lhs, rhs| lhs.f_mul(rhs).unwrap(), ) } pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_div_(rhs).unwrap(), |lhs, rhs| lhs.f_div(rhs).unwrap(), |lhs, rhs| lhs.f_div(rhs).unwrap(), ) } pub fn remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_remainder_tensor_(rhs).unwrap(), |lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(), |lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(), ) } pub fn mean(tensor: TchTensor) -> TchTensor { // view as 1d tensor let tensor = tensor.tensor.mean(tensor.tensor.kind()).view(1); TchTensor::new(tensor) } pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor .tensor .mean_dim(Some([dim as i64].as_slice()), true, tensor.tensor.kind()), tensor.storage, ) } pub fn sum(tensor: TchTensor) -> TchTensor { // view as 1d tensor let tensor = tensor.tensor.sum(tensor.tensor.kind()).view(1); TchTensor::new(tensor) } pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor.tensor.sum_dim_intlist( Some([dim as i64].as_slice()), true, tensor.tensor.kind(), ), tensor.storage, ) } pub fn prod(tensor: TchTensor) -> TchTensor { // view as 1d tensor let tensor = tensor.tensor.prod(tensor.tensor.kind()).view(1); TchTensor::new(tensor) } pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor .tensor .prod_dim_int(dim as i64, true, tensor.tensor.kind()), tensor.storage, ) } pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()), tensor.storage, ) } pub fn cumprod(tensor: TchTensor, dim: usize) -> TchTensor { TchTensor::from_existing( tensor.tensor.cumprod(dim as i64, tensor.tensor.kind()), tensor.storage, ) } pub fn cummin(tensor: TchTensor, dim: usize) -> TchTensor { let (values, _indices) = tensor.tensor.cummin(dim as i64); TchTensor::from_existing(values, tensor.storage) } pub fn cummax(tensor: TchTensor, dim: usize) -> TchTensor { // cummax returns (values, indices) tuple in PyTorch, we only need values let (values, _indices) = tensor.tensor.cummax(dim as i64); TchTensor::from_existing(values, tensor.storage) } pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor.tensor.argmax(dim as i64, true); TchTensor::from_existing(tensor, storage) } pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor { let storage = tensor.storage.clone(); let tensor = tensor.tensor.argmin(dim as i64, true); TchTensor::from_existing(tensor, storage) } pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor { let storage = tensor.storage.clone(); let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true); TchTensor::from_existing(tensor, storage) } pub fn max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { let storage = tensor.storage.clone(); let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true); let tensor = TchTensor::from_existing(tensor, storage); let indices = TchTensor::new(indices); (tensor, indices) } pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor { let storage = tensor.storage.clone(); let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true); TchTensor::from_existing(tensor, storage) } pub fn min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { let storage = tensor.storage.clone(); let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true); let tensor = TchTensor::from_existing(tensor, storage); let indices = TchTensor::new(indices); (tensor, indices) } pub fn clamp_min + Clone + Copy>(tensor: TchTensor, min: S) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.clamp_min_(min), |tensor| tensor.clamp_min(min), ) } pub fn clamp_max + Clone + Copy>(tensor: TchTensor, max: S) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.clamp_max_(max), |tensor| tensor.clamp_max(max), ) } pub fn clamp + Clone + Copy>( tensor: TchTensor, min: S, max: S, ) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.clamp_(min, max), |tensor| tensor.clamp(min, max), ) } pub fn swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor { let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64); TchTensor::new(tensor) } pub fn permute(tensor: TchTensor, axes: &[usize]) -> TchTensor { let tensor = tensor .tensor .permute(axes.iter().map(|x| *x as i64).collect::>()); TchTensor::new(tensor) } pub fn flip(tensor: TchTensor, axes: &[usize]) -> TchTensor { let dims = axes.iter().map(|x| *x as i64).collect::>(); let tensor = tensor.tensor.flip(dims); TchTensor::new(tensor) } pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( tensor, exponent, |lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(), ) } pub fn sign(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign()) } pub fn expand(tensor: TchTensor, shape: Shape) -> TchTensor { let storage = tensor.storage.clone(); let broadcasted_tensor = tensor.tensor.broadcast_to(TchShape::from(shape).dims); TchTensor::from_existing(broadcasted_tensor, storage) } pub fn unfold(tensor: TchTensor, dim: usize, size: usize, step: usize) -> TchTensor { let storage = tensor.storage.clone(); let uf_tensor = tensor.tensor.unfold(dim as i64, size as i64, step as i64); TchTensor::from_existing(uf_tensor, storage) } pub fn sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.sort(dim as i64, descending).0) } pub fn sort_with_indices( tensor: TchTensor, dim: usize, descending: bool, ) -> (TchTensor, TchTensor) { let sorted = tensor.tensor.sort(dim as i64, descending); (TchTensor::new(sorted.0), TchTensor::new(sorted.1)) } pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(), |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(), |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(), ) } pub fn bitwise_and_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(), |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(), ) } pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(), |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(), |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(), ) } pub fn bitwise_or_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(), |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(), ) } pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(), |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(), |lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(), ) } pub fn bitwise_xor_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(), |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(), ) } pub fn bitwise_not(tensor: TchTensor) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.f_bitwise_not_().unwrap(), |tensor| tensor.f_bitwise_not().unwrap(), ) } pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(), |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), ) } pub fn bitwise_left_shift_scalar + Clone>( tensor: TchTensor, scalar: S, ) -> TchTensor { tensor.unary_ops( |mut tensor| { tensor .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into()) .unwrap() }, |tensor| { tensor .f_bitwise_left_shift_tensor_scalar(scalar.clone().into()) .unwrap() }, ) } pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(), |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), ) } pub fn bitwise_right_shift_scalar + Clone>( tensor: TchTensor, scalar: S, ) -> TchTensor { tensor.unary_ops( |mut tensor| { tensor .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into()) .unwrap() }, |tensor| { tensor .f_bitwise_right_shift_tensor_scalar(scalar.clone().into()) .unwrap() }, ) } pub fn atan2(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.f_atan2_(rhs).unwrap(), |lhs, rhs| lhs.f_atan2(rhs).unwrap(), |lhs, rhs| lhs.f_atan2(rhs).unwrap(), ) } } ================================================ FILE: crates/burn-tch/src/ops/bool_tensor.rs ================================================ use super::TchOps; use crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; use burn_backend::BoolStore; use burn_backend::ExecutionError; use burn_backend::Scalar; use burn_backend::tensor::BoolTensor; use burn_backend::tensor::IntTensor; use burn_backend::{Shape, TensorData, TensorMetadata, ops::BoolTensorOps}; impl BoolTensorOps for LibTorch { fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { burn_backend::DType::Bool(BoolStore::Native) => { TchTensor::from_data::(data, (*device).into()) } _ => unimplemented!("Unsupported dtype for `bool_from_data`"), } } fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { TchOps::repeat_dim(tensor, dim, times) } async fn bool_into_data(tensor: TchTensor) -> Result { let shape = tensor.shape(); let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); Ok(TensorData::new(values.unwrap(), shape)) } fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor { TchOps::to_device(tensor, device) } fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor { TchOps::reshape(tensor, shape) } fn bool_device(tensor: &TchTensor) -> LibTorchDevice { tensor.tensor.device().into() } fn bool_empty(shape: Shape, device: &LibTorchDevice) -> TchTensor { let tensor = tch::Tensor::empty( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), ); TchTensor::new(tensor) } fn bool_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor { let tensor = tch::Tensor::zeros( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), ); TchTensor::new(tensor) } fn bool_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor { let tensor = tch::Tensor::ones( TchShape::from(shape).dims, (tch::Kind::Bool, (*device).into()), ); TchTensor::new(tensor) } fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor { TchOps::slice_with_steps(tensor, slices) } fn bool_slice_assign( tensor: TchTensor, slices: &[burn_backend::Slice], value: TchTensor, ) -> TchTensor { TchOps::slice_assign(tensor, slices, value) } fn bool_cat(tensors: Vec, dim: usize) -> TchTensor { TchOps::cat(tensors, dim) } fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::equal(lhs, rhs) } fn bool_not(tensor: TchTensor) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool), |tensor| tensor.eq(0), ) } fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.logical_and_(rhs), |lhs, rhs| rhs.logical_and_(lhs), |lhs, rhs| lhs.logical_and(rhs), ) } fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( lhs, rhs, |lhs, rhs| lhs.logical_or_(rhs), |lhs, rhs| rhs.logical_or_(lhs), |lhs, rhs| lhs.logical_or(rhs), ) } fn bool_into_int(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.to_kind(tch::Kind::Int64); TchTensor::new(tensor) } fn bool_into_float(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.to_kind(E::kind()); TchTensor::new(tensor) } fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor { TchOps::swap_dims(tensor, dim1, dim2) } fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor { TchOps::permute(tensor, axes) } fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor { TchOps::flip(tensor, axes) } async fn bool_argwhere(tensor: TchTensor) -> TchTensor { TchTensor::new(tensor.tensor.argwhere()) } fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor { TchOps::index_select_dim(tensor, dim, indices) } fn bool_select_or( tensor: TchTensor, dim: usize, indices: TchTensor, value: TchTensor, ) -> TchTensor { TchOps::select_assign(tensor, dim, indices, value) } fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor { TchOps::expand(tensor, shape) } fn bool_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { TchOps::unfold(tensor, dim, size, step) } fn bool_mask_where( tensor: BoolTensor, mask: BoolTensor, value: BoolTensor, ) -> BoolTensor { TchTensor::binary_ops_tensor( tensor, value, |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), ) } fn bool_mask_fill( tensor: BoolTensor, mask: BoolTensor, value: Scalar, ) -> BoolTensor { tensor.unary_ops( |mut tensor| { tensor .f_masked_fill_(&mask.tensor, value.elem::()) .unwrap() }, |tensor| { tensor .f_masked_fill(&mask.tensor, value.elem::()) .unwrap() }, ) } fn bool_gather( dim: usize, tensor: BoolTensor, indices: IntTensor, ) -> BoolTensor { TchOps::gather(dim, tensor, indices) } fn bool_scatter_or( dim: usize, tensor: BoolTensor, indices: IntTensor, value: BoolTensor, ) -> BoolTensor { TchOps::scatter(dim, tensor, indices, value) } fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { TchOps::equal_elem(lhs, rhs.elem::()) } } ================================================ FILE: crates/burn-tch/src/ops/int_tensor.rs ================================================ use std::ops::Range; use burn_backend::{ Distribution, ExecutionError, IntDType, Scalar, Shape, TensorData, TensorMetadata, ops::{FloatTensorOps, IntTensorOps}, tensor::IntTensor, }; use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; use super::TchOps; impl IntTensorOps for LibTorch { fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { burn_backend::DType::I64 => TchTensor::from_data::(data, (*device).into()), _ => unimplemented!("Unsupported dtype for `int_from_data`"), } } fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { TchOps::repeat_dim(tensor, dim, times) } async fn int_into_data(tensor: TchTensor) -> Result { let shape = tensor.shape(); let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()])); let values: Result, tch::TchError> = tensor.tensor.shallow_clone().try_into(); Ok(TensorData::new(values.unwrap(), shape)) } fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor { TchOps::to_device(tensor, device) } fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor { TchOps::reshape(tensor, shape) } fn int_device(tensor: &TchTensor) -> LibTorchDevice { tensor.tensor.device().into() } fn int_empty(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor { let tensor = tch::Tensor::empty( TchShape::from(shape).dims, (dtype.into_kind(), (*device).into()), ); TchTensor::new(tensor) } fn int_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor { TchOps::slice_with_steps(tensor, slices) } fn int_slice_assign( tensor: TchTensor, slices: &[burn_backend::Slice], value: TchTensor, ) -> TchTensor { TchOps::slice_assign(tensor, slices, value) } fn int_cat(tensors: Vec, dim: usize) -> TchTensor { TchOps::cat(tensors, dim) } fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { let lhs = Self::int_into_float(lhs); let rhs = Self::int_into_float(rhs); let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap(); Self::float_into_int(TchTensor::new(out)) } fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::equal(lhs, rhs) } fn int_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::equal_elem(lhs, rhs.elem::()) } fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::greater(lhs, rhs) } fn int_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::greater_elem(lhs, rhs.elem::()) } fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::greater_equal(lhs, rhs) } fn int_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::greater_equal_elem(lhs, rhs.elem::()) } fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::lower(lhs, rhs) } fn int_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::lower_elem(lhs, rhs.elem::()) } fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::lower_equal(lhs, rhs) } fn int_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::lower_equal_elem(lhs, rhs.elem::()) } fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::add(lhs, rhs) } fn int_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.f_add_scalar_(rhs.elem::()).unwrap(), |tensor| tensor.f_add_scalar(rhs.elem::()).unwrap(), ) } fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::sub(lhs, rhs) } fn int_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.f_sub_scalar_(rhs.elem::()).unwrap(), |tensor| tensor.f_sub_scalar(rhs.elem::()).unwrap(), ) } fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::mul(lhs, rhs) } fn int_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { lhs.unary_ops( |mut tensor| tensor.f_mul_scalar_(rhs.elem::()).unwrap(), |tensor| tensor.f_mul_scalar(rhs.elem::()).unwrap(), ) } fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { let dtype = lhs.tensor.kind(); let copy = false; let non_blocking = true; let lhs: TchTensor = TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy)); let rhs: TchTensor = TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy)); let out = TchOps::div(lhs, rhs); TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy)) } fn int_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let dtype = lhs.tensor.kind(); let copy = false; let non_blocking = true; let lhs: TchTensor = TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy)); let out: TchTensor = lhs.unary_ops( |mut tensor| tensor.f_div_scalar_(rhs.elem::()).unwrap(), |tensor| tensor.f_div_scalar(rhs.elem::()).unwrap(), ); TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy)) } fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor { let dtype = lhs.tensor.kind(); let copy = false; let non_blocking = true; let lhs: TchTensor = TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy)); let rhs: TchTensor = TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy)); let out = TchOps::remainder(lhs, rhs); TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy)) } fn int_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { lhs.unary_ops( |tensor| tensor.f_remainder(rhs.elem::()).unwrap(), |tensor| tensor.f_remainder(rhs.elem::()).unwrap(), ) } fn int_zeros(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor { let shape = TchShape::from(shape); let device: tch::Device = (*device).into(); TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device))) } fn int_ones(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor { let shape = TchShape::from(shape); let device: tch::Device = (*device).into(); TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device))) } fn int_full( shape: Shape, fill_value: Scalar, device: &LibTorchDevice, dtype: IntDType, ) -> TchTensor { let shape = TchShape::from(shape); let device: tch::Device = (*device).into(); TchTensor::new(tch::Tensor::full( shape.dims, fill_value.elem::(), (dtype.into_kind(), device), )) } fn int_sum(tensor: TchTensor) -> TchTensor { TchOps::sum(tensor) } fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::sum_dim(tensor, dim) } fn int_prod(tensor: TchTensor) -> TchTensor { TchOps::prod(tensor) } fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::prod_dim(tensor, dim) } fn int_mean(tensor: TchTensor) -> TchTensor { let dtype = tensor.tensor.kind(); let tensor: TchTensor = TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor); TchTensor::new(output.tensor.to_dtype(dtype, true, false)) } fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { let dtype = tensor.tensor.kind(); let tensor: TchTensor = TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false)); let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor); TchTensor::new(output.tensor.to_dtype(dtype, true, false)) } fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cumsum(tensor, dim) } fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cumprod(tensor, dim) } fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cummin(tensor, dim) } fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cummax(tensor, dim) } fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor { TchOps::gather(dim, tensor, indices) } fn int_scatter_add( dim: usize, tensor: TchTensor, indices: TchTensor, value: TchTensor, ) -> TchTensor { TchOps::scatter(dim, tensor, indices, value) } fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor { TchOps::index_select_dim(tensor, dim, indices) } fn int_select_add( tensor: TchTensor, dim: usize, indices: TchTensor, value: TchTensor, ) -> TchTensor { TchOps::select_assign(tensor, dim, indices, value) } fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor { TchTensor::binary_ops_tensor( tensor, source, |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(), ) } fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor { let value = value.elem::(); tensor.unary_ops( |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), ) } fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::argmax(tensor, dim) } fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::argmin(tensor, dim) } fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::max_dim(tensor, dim) } fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { TchOps::max_dim_with_indices(tensor, dim) } fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::min_dim(tensor, dim) } fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { TchOps::min_dim_with_indices(tensor, dim) } fn int_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor { TchOps::clamp_min(tensor, min.elem::()) } fn int_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor { TchOps::clamp_max(tensor, max.elem::()) } fn int_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor { TchOps::clamp(tensor, min.elem::(), max.elem::()) } fn int_abs(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) } fn int_into_float(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.to_kind(E::kind()); TchTensor::new(tensor) } fn int_swap_dims(tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { TchOps::swap_dims(tensor, dim1, dim2) } fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor { match distribution { Distribution::Default => TchTensor::new(tch::Tensor::randint_low( 0, 255, shape.iter().map(|i| *i as i64).collect::>(), (tch::Kind::Int64, (*device).into()), )), Distribution::Bernoulli(prob) => { let mut tensor = TchTensor::empty::(shape, *device); tensor .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) .unwrap() } Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low( from as i64, to as i64, shape.iter().map(|i| *i as i64).collect::>(), (tch::Kind::Int64, (*device).into()), )), Distribution::Normal(mean, std) => { let mut tensor = TchTensor::empty::(shape, *device); tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() } } } fn int_arange(range: Range, device: &LibTorchDevice) -> TchTensor { let device: tch::Device = (*device).into(); let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device)); if range.start != 0 { tensor = tensor.f_add_scalar_(range.start).unwrap(); } TchTensor::new(tensor) } fn int_permute(tensor: IntTensor, axes: &[usize]) -> IntTensor { TchOps::permute(tensor, axes) } fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { TchOps::flip(tensor, axes) } fn int_sign(tensor: IntTensor) -> IntTensor { TchOps::sign(tensor) } fn int_expand(tensor: IntTensor, shape: Shape) -> IntTensor { TchOps::expand(tensor, shape) } fn int_sort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::sort(tensor, dim, descending) } fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { TchOps::bitwise_and(lhs, rhs) } fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { TchOps::bitwise_or(lhs, rhs) } fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { TchOps::bitwise_xor(lhs, rhs) } fn bitwise_not(tensor: IntTensor) -> IntTensor { TchOps::bitwise_not(tensor) } fn bitwise_and_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { TchOps::bitwise_and_scalar(lhs, rhs.elem::()) } fn bitwise_or_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { TchOps::bitwise_or_scalar(lhs, rhs.elem::()) } fn bitwise_xor_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { TchOps::bitwise_xor_scalar(lhs, rhs.elem::()) } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { TchOps::bitwise_left_shift(lhs, rhs) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { TchOps::bitwise_right_shift(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { TchOps::bitwise_left_shift_scalar(lhs, rhs.elem::()) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: Scalar) -> IntTensor { TchOps::bitwise_right_shift_scalar(lhs, rhs.elem::()) } fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc // Type promotion is not automatic on all backends so this behavior might differ let kind = dtype.into_kind(); if tensor.tensor.kind() == kind { tensor } else { TchTensor::new(tensor.tensor.to_kind(kind)) } } fn int_unfold( tensor: IntTensor, dim: usize, size: usize, step: usize, ) -> IntTensor { TchOps::unfold(tensor, dim, size, step) } } ================================================ FILE: crates/burn-tch/src/ops/mod.rs ================================================ mod activation; mod base; mod bool_tensor; mod int_tensor; mod module; mod qtensor; mod tensor; mod transaction; pub(crate) use base::*; ================================================ FILE: crates/burn-tch/src/ops/module.rs ================================================ use crate::{LibTorch, TchTensor, element::TchElement}; use burn_backend::{ TensorMetadata, ops::{ AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback, }, }; impl ModuleOps for LibTorch { fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor { // Workaround for MPS "Placeholder storage has not been allocated" error. // See: https://github.com/pytorch/pytorch/issues/123995 // MPS uses lazy allocation and the embedding operation (which uses index_select) // can fail if the tensors haven't been materialized yet. // We work around this by performing the embedding on CPU and transferring back to MPS. if matches!(weights.tensor.device(), tch::Device::Mps) { let cpu_weights = weights.tensor.to(tch::Device::Cpu); let cpu_indices = indices.tensor.to(tch::Device::Cpu); let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false) .to(tch::Device::Mps); return TchTensor::new(result); } let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false); TchTensor::new(tensor) } fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor { let [n_embedding, _d_model] = weights.shape().dims(); // Workaround for MPS "Placeholder storage has not been allocated" error. // See: https://github.com/pytorch/pytorch/issues/123995 if matches!(output.tensor.device(), tch::Device::Mps) { let cpu_output = output.tensor.to(tch::Device::Cpu); let cpu_indices = indices.tensor.to(tch::Device::Cpu); let result = tch::Tensor::embedding_backward( &cpu_output, &cpu_indices, n_embedding as i64, -1, false, false, ) .to(tch::Device::Mps); return TchTensor::new(result); } let tensor = tch::Tensor::embedding_backward( &output.tensor, &indices.tensor, n_embedding as i64, -1, false, false, ); TchTensor::new(tensor) } fn conv1d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvOptions<1>, ) -> TchTensor { let tensor = tch::Tensor::conv1d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.dilation.map(|i| i as i64), options.groups as i64, ); TchTensor::new(tensor) } fn conv2d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvOptions<2>, ) -> TchTensor { let tensor = tch::Tensor::conv2d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.dilation.map(|i| i as i64), options.groups as i64, ); TchTensor::new(tensor) } fn conv3d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvOptions<3>, ) -> TchTensor { let tensor = tch::Tensor::conv3d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.dilation.map(|i| i as i64), options.groups as i64, ); TchTensor::new(tensor) } fn deform_conv2d( _x: TchTensor, _offset: TchTensor, _weight: TchTensor, _mask: Option, _bias: Option, _options: DeformConvOptions<2>, ) -> TchTensor { unimplemented!("Torch bindings don't support deform_conv2d"); } fn deform_conv2d_backward( _x: TchTensor, _offset: TchTensor, _weight: TchTensor, _mask: Option, _bias: Option, _out_grad: TchTensor, _options: DeformConvOptions<2>, ) -> DeformConv2dBackward { unimplemented!("Torch bindings don't support deform_conv2d"); } fn conv_transpose1d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvTransposeOptions<1>, ) -> TchTensor { let tensor = tch::Tensor::conv_transpose1d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.padding_out.map(|i| i as i64), options.groups as i64, options.dilation.map(|i| i as i64), ); TchTensor::new(tensor) } fn conv_transpose2d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvTransposeOptions<2>, ) -> TchTensor { let tensor = tch::Tensor::conv_transpose2d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.padding_out.map(|i| i as i64), options.groups as i64, options.dilation.map(|i| i as i64), ); TchTensor::new(tensor) } fn conv_transpose3d( x: TchTensor, weight: TchTensor, bias: Option, options: ConvTransposeOptions<3>, ) -> TchTensor { let tensor = tch::Tensor::conv_transpose3d( &x.tensor, &weight.tensor, bias.map(|t| t.tensor), options.stride.map(|i| i as i64), options.padding.map(|i| i as i64), options.padding_out.map(|i| i as i64), options.groups as i64, options.dilation.map(|i| i as i64), ); TchTensor::new(tensor) } fn avg_pool1d( x: TchTensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> TchTensor { let tensor = tch::Tensor::avg_pool1d( &x.tensor, [kernel_size as i64], [stride as i64], [padding as i64], ceil_mode, count_include_pad, ); TchTensor::new(tensor) } fn avg_pool2d( x: TchTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> TchTensor { let tensor = tch::Tensor::avg_pool2d( &x.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], [stride[0] as i64, stride[1] as i64], [padding[0] as i64, padding[1] as i64], ceil_mode, count_include_pad, None, ); TchTensor::new(tensor) } fn avg_pool2d_backward( x: TchTensor, grad: TchTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> TchTensor { let tensor = tch::Tensor::avg_pool2d_backward( &x.tensor, &grad.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], [stride[0] as i64, stride[1] as i64], [padding[0] as i64, padding[1] as i64], ceil_mode, count_include_pad, None, ); TchTensor::new(tensor) } fn max_pool1d( x: TchTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> TchTensor { let tensor = tch::Tensor::max_pool1d( &x.tensor, kernel_size as i64, stride as i64, padding as i64, dilation as i64, ceil_mode, ); TchTensor::new(tensor) } fn max_pool1d_with_indices( x: TchTensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> MaxPool1dWithIndices { let (tensor, indices) = tch::Tensor::max_pool1d_with_indices( &x.tensor, kernel_size as i64, stride as i64, padding as i64, dilation as i64, ceil_mode, ); MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) } fn max_pool2d( x: TchTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> TchTensor { let tensor = tch::Tensor::max_pool2d( &x.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], [stride[0] as i64, stride[1] as i64], [padding[0] as i64, padding[1] as i64], [dilation[0] as i64, dilation[1] as i64], ceil_mode, ); TchTensor::new(tensor) } fn max_pool2d_with_indices( x: TchTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> MaxPool2dWithIndices { let (tensor, indices) = tch::Tensor::max_pool2d_with_indices( &x.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], [stride[0] as i64, stride[1] as i64], [padding[0] as i64, padding[1] as i64], [dilation[0] as i64, dilation[1] as i64], ceil_mode, ); MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) } fn max_pool2d_with_indices_backward( x: TchTensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, output_grad: TchTensor, indices: TchTensor, ) -> MaxPool2dBackward { let grad = tch::Tensor::max_pool2d_with_indices_backward( &x.tensor, &output_grad.tensor, [kernel_size[0] as i64, kernel_size[1] as i64], [stride[0] as i64, stride[1] as i64], [padding[0] as i64, padding[1] as i64], [dilation[0] as i64, dilation[1] as i64], ceil_mode, &indices.tensor, ); MaxPool2dBackward::new(TchTensor::new(grad)) } fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); TchTensor::new(tensor) } fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor); TchTensor::new(tensor) } fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor { let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64); TchTensor::new(tensor) } fn interpolate( x: TchTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> TchTensor { let output_size = output_size.map(|e| e as i64); let align_corners = options.align_corners; let tensor = match options.mode { InterpolateMode::Nearest => { tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None) } InterpolateMode::Bilinear => { tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None) } InterpolateMode::Bicubic => { tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None) } InterpolateMode::Lanczos3 => { panic!("lanczos3 interpolation is not supported by PyTorch/tch backend") } }; TchTensor::new(tensor) } fn interpolate_backward( x: TchTensor, grad: TchTensor, output_size: [usize; 2], options: InterpolateOptions, ) -> TchTensor { let output_size = output_size.map(|e| e as i64); let [n, c, h_in, w_in] = x.shape().dims(); let input_size = [n as i64, c as i64, h_in as i64, w_in as i64]; let align_corners = options.align_corners; let tensor = match options.mode { InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward( &grad.tensor, output_size, input_size, None, None, ), InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward( &grad.tensor, output_size, input_size, align_corners, None, None, ), InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward( &grad.tensor, output_size, input_size, align_corners, None, None, ), InterpolateMode::Lanczos3 => { panic!("lanczos3 interpolation backward is not supported by PyTorch/tch backend") } }; TchTensor::new(tensor) } fn attention( query: TchTensor, key: TchTensor, value: TchTensor, mask: Option, attn_bias: Option, options: AttentionModuleOptions, ) -> TchTensor { if attn_bias.is_some() { return attention_fallback::(query, key, value, mask, attn_bias, options); } TchTensor::new(tch::Tensor::scaled_dot_product_attention( &query.tensor, &key.tensor, &value.tensor, mask.map(|m| m.tensor), 0., options.is_causal, options.scale, false, )) } } ================================================ FILE: crates/burn-tch/src/ops/qtensor.rs ================================================ use burn_backend::{ ExecutionError, Shape, TensorData, ops::QTensorOps, quantization::{QuantScheme, QuantizationParametersPrimitive}, tensor::{Device, FloatTensor, IntTensor, QuantizedTensor}, }; use crate::{LibTorch, LibTorchDevice, TchElement}; impl QTensorOps for LibTorch { fn q_from_data(_data: TensorData, _device: &LibTorchDevice) -> QuantizedTensor { unimplemented!() } fn quantize( _tensor: FloatTensor, _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { unimplemented!() } fn quantize_dynamic( _tensor: FloatTensor, _scheme: &QuantScheme, ) -> QuantizedTensor { unimplemented!() } fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { unimplemented!() } fn q_device(_tensor: &QuantizedTensor) -> LibTorchDevice { unimplemented!() } fn q_to_device( _tensor: QuantizedTensor, _device: &Device, ) -> QuantizedTensor { unimplemented!() } fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } async fn q_into_data(_tensor: QuantizedTensor) -> Result { unimplemented!() } fn q_swap_dims( _tensor: QuantizedTensor, _dim1: usize, _dim2: usize, ) -> QuantizedTensor { unimplemented!() } fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { unimplemented!() } fn q_select( _tensor: QuantizedTensor, _dim: usize, _indices: IntTensor, ) -> QuantizedTensor { unimplemented!() } fn q_slice( _tensor: QuantizedTensor, _slices: &[burn_backend::Slice], ) -> QuantizedTensor { unimplemented!() } fn q_argmax(_tensor: QuantizedTensor, _dim: usize) -> IntTensor { unimplemented!() } fn q_argmin(_tensor: QuantizedTensor, _dim: usize) -> IntTensor { unimplemented!() } fn q_max_dim_with_indices( _tensor: QuantizedTensor, _dim: usize, ) -> (QuantizedTensor, IntTensor) { unimplemented!() } fn q_max_dim(_tensor: QuantizedTensor, _dim: usize) -> QuantizedTensor { unimplemented!() } fn q_min_dim(_tensor: QuantizedTensor, _dim: usize) -> QuantizedTensor { unimplemented!() } fn q_min_dim_with_indices( _tensor: QuantizedTensor, _dim: usize, ) -> (QuantizedTensor, IntTensor) { unimplemented!() } fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { unimplemented!() } fn q_sort( _tensor: QuantizedTensor, _dim: usize, _descending: bool, ) -> QuantizedTensor { unimplemented!() } fn q_sort_with_indices( _tensor: QuantizedTensor, _dim: usize, _descending: bool, ) -> (QuantizedTensor, IntTensor) { unimplemented!() } fn q_argsort( _tensor: QuantizedTensor, _dim: usize, _descending: bool, ) -> IntTensor { unimplemented!() } } ================================================ FILE: crates/burn-tch/src/ops/tensor.rs ================================================ use super::TchOps; use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement}; use burn_backend::backend::ExecutionError; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor}; use burn_backend::{ DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps, }; use burn_backend::{Scalar, bf16, f16}; impl FloatTensorOps for LibTorch { fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor { match data.dtype { DType::F64 => TchTensor::from_data::(data, (*device).into()), DType::F32 => TchTensor::from_data::(data, (*device).into()), DType::F16 => TchTensor::from_data::(data, (*device).into()), DType::BF16 => TchTensor::from_data::(data, (*device).into()), _ => unimplemented!("Unsupported dtype for `float_from_data`"), } } fn float_random( shape: Shape, distribution: Distribution, device: &LibTorchDevice, ) -> TchTensor { match distribution { Distribution::Default => { let mut tensor = TchTensor::empty::(shape, *device); tensor .mut_ops(|tensor| tensor.rand_like_out(tensor)) .unwrap() } Distribution::Bernoulli(prob) => { let mut tensor = TchTensor::empty::(shape, *device); tensor .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap()) .unwrap() } Distribution::Uniform(from, to) => { let mut tensor = TchTensor::empty::(shape, *device); tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap() } Distribution::Normal(mean, std) => { let mut tensor = TchTensor::empty::(shape, *device); tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap() } } } fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor { TchOps::repeat_dim(tensor, dim, times) } fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor { let shape = TchShape::from(shape); let device: tch::Device = (*device).into(); TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device))) } fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor { let shape = TchShape::from(shape); let device: tch::Device = (*device).into(); TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device))) } async fn float_into_data(tensor: TchTensor) -> Result { let shape = tensor.shape(); let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()])); Ok(match tensor.tensor.kind() { tch::Kind::Half => { let values = Vec::::try_from(&tensor).unwrap(); TensorData::new(values, shape) } tch::Kind::Float => { let values = Vec::::try_from(&tensor).unwrap(); TensorData::new(values, shape) } tch::Kind::Double => { let values = Vec::::try_from(&tensor).unwrap(); TensorData::new(values, shape) } tch::Kind::BFloat16 => { let values = Vec::::try_from(&tensor).unwrap(); TensorData::new(values, shape) } _ => panic!("Not a valid float kind"), }) } fn float_device(tensor: &TchTensor) -> LibTorchDevice { tensor.tensor.device().into() } fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor { TchOps::to_device(tensor, device) } fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor { let tensor = tch::Tensor::empty( TchShape::from(shape).dims, (dtype.into_kind(), (*device).into()), ); TchTensor::new(tensor) } fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::add(lhs, rhs) } fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( |mut tensor| tensor.f_add_scalar_(rhs).unwrap(), |tensor| tensor.f_add_scalar(rhs).unwrap(), ) } fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::sub(lhs, rhs) } fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(), |tensor| tensor.f_sub_scalar(rhs).unwrap(), ) } fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::mul(lhs, rhs) } fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(), |tensor| tensor.f_mul_scalar(rhs).unwrap(), ) } fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::div(lhs, rhs) } fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( |mut tensor| tensor.f_div_scalar_(rhs).unwrap(), |tensor| tensor.f_div_scalar(rhs).unwrap(), ) } fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::remainder(lhs, rhs) } fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor { let rhs: f64 = rhs.elem(); lhs.unary_ops( |tensor| tensor.f_remainder(rhs).unwrap(), |tensor| tensor.f_remainder(rhs).unwrap(), ) } fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor { let tensor = lhs.tensor.matmul(&rhs.tensor); TchTensor::new(tensor) } fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor { let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64); TchTensor::new(tensor) } fn float_recip(tensor: TchTensor) -> TchTensor { TchTensor::new(tensor.tensor.reciprocal()) } fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor { TchOps::swap_dims(tensor, dim1, dim2) } fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor { TchOps::reshape(tensor, shape) } fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor { TchOps::gather(dim, tensor, indices) } fn float_scatter_add( dim: usize, tensor: TchTensor, indices: TchTensor, value: TchTensor, ) -> TchTensor { TchOps::scatter(dim, tensor, indices, value) } fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor { TchOps::index_select_dim(tensor, dim, indices) } fn float_select_add( tensor: TchTensor, dim: usize, indices: TchTensor, value: TchTensor, ) -> TchTensor { TchOps::select_assign(tensor, dim, indices, value) } fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor { TchOps::slice_with_steps(tensor, slices) } fn float_slice_assign( tensor: TchTensor, slices: &[burn_backend::Slice], value: TchTensor, ) -> TchTensor { TchOps::slice_assign(tensor, slices, value) } fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor { let output = value.tensor.where_self(&mask.tensor, &tensor.tensor); TchTensor::new(output) } fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor { let value: f64 = value.elem(); tensor.unary_ops( |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(), |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(), ) } fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::equal(lhs, rhs) } fn float_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::equal_elem(lhs, rhs.elem::()) } fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::greater(lhs, rhs) } fn float_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::greater_elem(lhs, rhs.elem::()) } fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::greater_equal(lhs, rhs) } fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::greater_equal_elem(lhs, rhs.elem::()) } fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::lower(lhs, rhs) } fn float_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::lower_elem(lhs, rhs.elem::()) } fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::lower_equal(lhs, rhs) } fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor { TchOps::lower_equal_elem(lhs, rhs.elem::()) } fn float_mean(tensor: TchTensor) -> TchTensor { TchOps::mean(tensor) } fn float_sum(tensor: TchTensor) -> TchTensor { TchOps::sum(tensor) } fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::sum_dim(tensor, dim) } fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::mean_dim(tensor, dim) } fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cumsum(tensor, dim) } fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cumprod(tensor, dim) } fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cummin(tensor, dim) } fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::cummax(tensor, dim) } fn float_prod(tensor: TchTensor) -> TchTensor { TchOps::prod(tensor) } fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::prod_dim(tensor, dim) } fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::argmax(tensor, dim) } fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::argmin(tensor, dim) } fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::max_dim(tensor, dim) } fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { TchOps::max_dim_with_indices(tensor, dim) } fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::min_dim(tensor, dim) } fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) { TchOps::min_dim_with_indices(tensor, dim) } fn float_exp(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp()) } fn float_log(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log()) } fn float_log1p(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p()) } fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor { tensor.unary_ops( |mut tensor| tensor.f_pow_(value.elem::()).unwrap(), |tensor| tensor.pow_tensor_scalar(value.elem::()), ) } fn float_sqrt(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt()) } fn float_abs(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs()) } fn float_cos(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos()) } fn float_cosh(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh()) } fn float_sin(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin()) } fn float_sinh(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh()) } fn float_tan(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan()) } fn float_tanh(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) } fn float_acos(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos()) } fn float_acosh(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh()) } fn float_asin(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin()) } fn float_asinh(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh()) } fn float_atan(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan()) } fn float_atanh(tensor: FloatTensor) -> FloatTensor { tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh()) } fn float_atan2(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { TchOps::atan2(lhs, rhs) } fn float_round(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round()) } fn float_floor(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor()) } fn float_ceil(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil()) } fn float_trunc(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc()) } fn float_erf(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) } fn float_cat(tensors: Vec, dim: usize) -> TchTensor { TchOps::cat(tensors, dim) } fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor { TchOps::clamp_min(tensor, min.elem::()) } fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor { TchOps::clamp_max(tensor, max.elem::()) } fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor { TchOps::clamp(tensor, min.elem::(), max.elem::()) } fn float_into_int(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.to_kind(tch::Kind::Int64); TchTensor::new(tensor) } fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor { TchOps::powf(lhs, rhs) } fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor { TchOps::permute(tensor, axes) } fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor { TchOps::flip(tensor, axes) } fn float_sign(tensor: TchTensor) -> TchTensor { TchOps::sign(tensor) } fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor { TchOps::expand(tensor, shape) } fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchOps::sort(tensor, dim, descending) } fn float_sort_with_indices( tensor: TchTensor, dim: usize, descending: bool, ) -> (TchTensor, TchTensor) { TchOps::sort_with_indices(tensor, dim, descending) } fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor { // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc // Type promotion is not automatic on all backends so this behavior might differ let kind = dtype.into_kind(); if tensor.tensor.kind() == kind { tensor } else { TchTensor::new(tensor.tensor.to_kind(kind)) } } fn float_unfold( tensor: FloatTensor, dim: usize, size: usize, step: usize, ) -> FloatTensor { TchOps::unfold(tensor, dim, size, step) } fn float_is_nan(tensor: FloatTensor) -> BoolTensor { TchTensor::new(tensor.tensor.isnan()) } fn float_is_inf(tensor: FloatTensor) -> BoolTensor { TchTensor::new(tensor.tensor.isinf()) } } ================================================ FILE: crates/burn-tch/src/ops/transaction.rs ================================================ use burn_backend::ops::TransactionOps; use crate::{LibTorch, TchElement}; impl TransactionOps for LibTorch {} ================================================ FILE: crates/burn-tch/src/tensor.rs ================================================ use crate::{LibTorchDevice, TchElement}; use burn_backend::{BoolStore, DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata}; use libc::c_void; use std::sync::Arc; /// A reference to a tensor storage. /// /// We manually implement `Sync` and `Send` unsafely, so even if we could use `Rc`, it isn't safe. #[allow(clippy::arc_with_non_send_sync)] pub type StorageRef = Arc<*mut c_void>; /// A reference to a tensor storage. #[derive(PartialEq, Debug, Clone)] pub enum Storage { /// When a tensor is a partial view of another tensor. View { /// Storage reference for the whole buffer. buffer_ref: StorageRef, /// Storage reference for the partial buffer. view_ref: StorageRef, }, /// When a tensor use all of its buffer. Owned { /// Storage reference for the whole buffer. buffer_ref: StorageRef, }, } impl Storage { /// Check if the storage can be used inplace. pub fn can_mut(&self) -> bool { match self { Storage::View { buffer_ref: start_ref, view_ref, } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1, Storage::Owned { buffer_ref: start_ref, } => Arc::strong_count(start_ref) == 1, } } /// Get the whole buffer reference. pub fn buffer_ref(&self) -> &StorageRef { match self { Storage::View { buffer_ref: start_ref, view_ref: _, } => start_ref, Storage::Owned { buffer_ref: start_ref, } => start_ref, } } } /// A tensor using the tch backend. #[derive(Debug, PartialEq)] pub struct TchTensor { /// Handle to the tensor. Call methods on this field. pub tensor: tch::Tensor, /// The tensor's storage pub storage: Storage, } impl TensorMetadata for TchTensor { fn dtype(&self) -> DType { match self.tensor.kind() { tch::Kind::Uint8 => DType::U8, tch::Kind::Int8 => DType::I8, tch::Kind::Int16 => DType::I16, tch::Kind::Int => DType::I32, tch::Kind::Int64 => DType::I64, tch::Kind::Half => DType::F16, tch::Kind::Float => DType::F32, tch::Kind::Double => DType::F64, tch::Kind::Bool => DType::Bool(BoolStore::Native), tch::Kind::BFloat16 => DType::BF16, // Complex and quantization types are not valid/implemented. _ => unimplemented!(), } } fn shape(&self) -> Shape { Shape::from(self.tensor.size()) } fn rank(&self) -> usize { self.tensor.dim() } } impl burn_backend::QTensorPrimitive for TchTensor { fn scheme(&self) -> &burn_backend::quantization::QuantScheme { unimplemented!("Quantization is not supported") } } impl core::fmt::Display for TchTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.tensor) } } pub(crate) trait IntoKind { fn try_into_kind(self) -> Result; fn into_kind(self) -> tch::Kind where Self: Sized, { self.try_into_kind().unwrap() } } impl IntoKind for IntDType { fn try_into_kind(self) -> Result { let dtype: DType = self.into(); dtype.try_into_kind() } } impl IntoKind for FloatDType { fn try_into_kind(self) -> Result { let dtype: DType = self.into(); dtype.try_into_kind() } } impl IntoKind for DType { fn try_into_kind(self) -> Result { match self { DType::F64 => Ok(tch::Kind::Double), DType::F32 => Ok(tch::Kind::Float), DType::Flex32 => Ok(tch::Kind::Float), DType::F16 => Ok(tch::Kind::Half), DType::BF16 => Ok(tch::Kind::BFloat16), DType::I64 => Ok(tch::Kind::Int64), DType::I32 => Ok(tch::Kind::Int), DType::I16 => Ok(tch::Kind::Int16), DType::I8 => Ok(tch::Kind::Int8), DType::U8 => Ok(tch::Kind::Uint8), DType::Bool(BoolStore::Native) => Ok(tch::Kind::Bool), other => Err(tch::TchError::Kind(format!("Unsupported dtype {other:?}"))), } } } impl TchTensor { /// Create a new tensor. /// /// Note that if the tensor was created from an operation that may reuse the same tensor /// storage as the parent, you should use [from_existing](TchTensor::from_existing) /// instead. pub fn new(tensor: tch::Tensor) -> Self { #[allow(clippy::arc_with_non_send_sync)] let storage = Storage::Owned { buffer_ref: Arc::new(tensor.data_ptr()), }; Self { tensor, storage } } /// Create a tensor that was created from an operation executed on a parent tensor. /// /// If the child tensor shared the same storage as its parent, it will be cloned, effectively /// tracking how much tensors point to the same memory space. pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self { let storage_child = tensor.data_ptr(); let mut is_a_new_tensor = true; match &storage_parent { Storage::View { buffer_ref: start_ref, view_ref, } => { if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() { is_a_new_tensor = false; } } Storage::Owned { buffer_ref: start_ref, } => { if storage_child == *start_ref.as_ref() { is_a_new_tensor = false; } } }; let storage = match is_a_new_tensor { true => Storage::Owned { #[allow(clippy::arc_with_non_send_sync)] buffer_ref: Arc::new(storage_child), }, false => storage_parent.clone(), }; Self { tensor, storage } } /// Create a tensor that uses a part of its parent tensor such as slice and narrow. pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self { let storage = Storage::View { buffer_ref: storage_parent.buffer_ref().clone(), #[allow(clippy::arc_with_non_send_sync)] view_ref: Arc::new(tensor.data_ptr()), }; Self { tensor, storage } } } // This is safe since we don't use autodiff from LibTorch. // Also, atomic reference counting is used to know if the tensor's data can be reused. // If there are multiple reference on the same tensor, it becomes read only. unsafe impl Send for TchTensor {} unsafe impl Sync for TchTensor {} impl TchTensor { /// Checks if the tensor can be mutated in-place. /// /// Returns `true` if the tensor's stride does not contain zero (no broadcasting) /// and the storage can be mutated. pub fn can_mut(&self) -> bool { let stride_contains_zero = self.tensor.stride().contains(&0); !stride_contains_zero && self.storage.can_mut() } /// Executes an operation on a tensor if the data can be reused. pub fn mut_ops tch::Tensor>( &mut self, func: F, ) -> Option { if !self.can_mut() { return None; } let data = self.storage.clone(); Some(TchTensor::from_existing(func(&mut self.tensor), data)) } /// Executes a unary operation, reusing the tensor data if possible. pub fn unary_ops(self, fown: FOwn, fref: FRef) -> TchTensor where FOwn: Fn(tch::Tensor) -> tch::Tensor, FRef: Fn(&tch::Tensor) -> tch::Tensor, { if !self.can_mut() { return TchTensor::from_existing(fref(&self.tensor), self.storage); } TchTensor::from_existing(fown(self.tensor), self.storage) } /// Executes a binary operation, reusing the tensor data if possible. pub fn binary_ops_tensor( mut lhs: Self, mut rhs: Self, flmut: FLMut, frmut: FRMut, fref: FRef, ) -> TchTensor where FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor, FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor, FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor, { let lhs_shape = lhs.shape(); let rhs_shape = rhs.shape(); // Both lhs and rhs are expected to have the same rank let d_out = lhs_shape.num_dims(); let mut out_shape = Shape::from(vec![1usize; d_out]); for i in 0..d_out { out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]); } let num_elements_out = out_shape.num_elements(); // Attempt to mutate lhs tensor if lhs_shape.num_elements() == num_elements_out && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) { return output; } // Attempt to mutate rhs tensor if rhs_shape.num_elements() == num_elements_out && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) { return output; } let storage = lhs.storage; let tensor = fref(&lhs.tensor, &rhs.tensor); TchTensor::from_existing(tensor, storage) } } impl Clone for TchTensor { fn clone(&self) -> Self { Self { tensor: self.tensor.shallow_clone(), storage: self.storage.clone(), } } } /// A shape that can be used by LibTorch. #[derive(Debug)] pub struct TchShape { /// The shape's dimensions. pub dims: Vec, } impl From for TchShape { fn from(shape: Shape) -> Self { TchShape { dims: shape.iter().map(|d| *d as i64).collect(), } } } impl From<&[usize]> for TchShape { fn from(shape: &[usize]) -> Self { TchShape { dims: shape.iter().map(|d| *d as i64).collect(), } } } impl TchTensor { /// Creates a new tensor from a shape and a device. /// /// # Arguments /// /// * `data` - The tensor's data. /// * `device` - The device on which the tensor will be allocated. /// /// # Returns /// /// A new tensor. pub fn from_data(data: TensorData, device: tch::Device) -> Self { let shape_tch = TchShape::from(data.shape.as_slice()); let tensor = tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device); Self::new(tensor) } } impl TchTensor { /// Creates an empty tensor from a shape and a device. /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// A new empty tensor. pub fn empty(shape: Shape, device: LibTorchDevice) -> Self { let shape_tch = TchShape::from(shape); let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into())); Self::new(tensor) } } // Adapted from `tch` to use patched `T::kind()` instead of `T::KIND` which is incorrect for bf16. // TODO: remove when fixed in `tch` release (https://github.com/LaurentMazare/tch-rs/pull/996). impl TryFrom<&TchTensor> for Vec { type Error = tch::TchError; fn try_from(tensor: &TchTensor) -> Result { let tensor = &tensor.tensor; let size = tensor.size(); if size.len() != 1 { Err(tch::TchError::Convert(format!( "Attempting to convert a Tensor with {} dimensions to flat vector", size.len() )))?; } let numel = size[0] as usize; let mut vec = vec![T::ZERO; numel]; // Adapted to use patched `T::kind()` instead // TODO: tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?; f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?; Ok(vec) } } unsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option { if !ptr.is_null() { unsafe { let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned(); libc::free(ptr as *mut libc::c_void); Some(str) } } else { None } } /// Copies `numel` elements from `self` to `dst`. fn f_copy_data( tensor: &mut tch::Tensor, dst: &mut [T], numel: usize, ) -> Result<(), tch::TchError> { if T::kind() != tensor.f_kind()? { return Err(tch::TchError::Kind(format!( "incoherent elt kind, {:?} != {:?}", tensor.f_kind(), T::kind() ))); } if dst.len() < numel { return Err(tch::TchError::Shape(format!("slice len < {numel}"))); } unsafe { torch_sys::at_copy_data( tensor.as_mut_ptr(), dst.as_mut_ptr() as *const c_void, numel, T::kind().elt_size_in_bytes(), ); match ptr_to_string(torch_sys::get_and_reset_last_err()) { None => Ok(()), Some(c_error) => Err(tch::TchError::Torch(c_error)), } } } #[cfg(test)] mod tests { use super::*; use burn_backend::ops::FloatTensorOps; use burn_backend::{Backend, quantization::QuantScheme, read_sync}; type B = crate::LibTorch; #[test] fn should_have_bf16_kind() { let data = TensorData::from([4.0, 4.0]); let tensor_1: TchTensor = B::float_from_data(data, &Default::default()); let tensor_2 = B::float_cast(tensor_1, DType::BF16.into()); assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16); let out = read_sync(B::float_into_data(tensor_2)).unwrap(); out.assert_eq(&TensorData::from([4.0, 4.0]), false); } #[test] fn should_support_dtypes() { let device = Default::default(); assert!(B::supports_dtype(&device, DType::F64)); assert!(B::supports_dtype(&device, DType::F32)); assert!(B::supports_dtype(&device, DType::Flex32)); assert!(B::supports_dtype(&device, DType::F16)); assert!(B::supports_dtype(&device, DType::BF16)); assert!(B::supports_dtype(&device, DType::I64)); assert!(B::supports_dtype(&device, DType::I32)); assert!(B::supports_dtype(&device, DType::I16)); assert!(B::supports_dtype(&device, DType::I8)); assert!(B::supports_dtype(&device, DType::U8)); assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native))); assert!(!B::supports_dtype(&device, DType::U64)); assert!(!B::supports_dtype(&device, DType::U32)); assert!(!B::supports_dtype(&device, DType::U16)); assert!(!B::supports_dtype( &device, DType::QFloat(QuantScheme::default()) )); } #[test] fn should_support_from_bf16() { let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16); let tensor_1: TchTensor = B::float_from_data(data, &Default::default()); let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16); let tensor_2 = B::float_from_data(data, &Default::default()); let tensor_3 = B::float_add(tensor_1, tensor_2); assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16); let out = read_sync(B::float_into_data(tensor_3)).unwrap(); out.assert_eq(&TensorData::from([[3.0], [3.0]]), false); } } unsafe extern "C" { /// Dummy function to get CUDA to link properly pub fn dummy_cuda_dependency(); } #[used] static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency]; ================================================ FILE: crates/burn-tensor/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science", "no-std", "embedded", "wasm"] description = "Tensor library with user-friendly APIs and automatic differentiation support" documentation = "https://docs.rs/burn-tensor" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-tensor" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor" version.workspace = true [lints] workspace = true [features] default = ["std"] doc = ["default"] std = [ "num-traits/std", "burn-std/std", "burn-backend/std", "colored", ] tracing = [ "burn-std/tracing", "burn-backend/tracing", ] cubecl = ["burn-std/cubecl", "burn-backend/cubecl"] cubecl-cuda = ["burn-backend/cubecl-cuda"] cubecl-hip = ["burn-backend/cubecl-hip"] cubecl-wgpu = ["burn-backend/cubecl-wgpu"] cubecl-cpu = ["burn-backend/cubecl-cpu"] [dependencies] burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", default-features = false } burn-backend = { path = "../burn-backend", version = "=0.21.0-pre.2", default-features = false } colored = { workspace = true, optional = true } derive-new = { workspace = true } num-traits = { workspace = true } # Device hashbrown = { workspace = true } spin = { workspace = true } thiserror = { workspace = true } # Serialization serde = { workspace = true } [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic-util = { workspace = true } [dev-dependencies] serial_test = { workspace = true } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] ================================================ FILE: crates/burn-tensor/README.md ================================================ # Burn Tensor > [Burn](https://github.com/tracel-ai/burn) Tensor Library [![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor.svg)](https://crates.io/crates/burn-tensor) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tensor/blob/master/README.md) This library provides the core abstractions required to run tensor operations with Burn. `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations. Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait. ================================================ FILE: crates/burn-tensor/src/device.rs ================================================ use alloc::format; use alloc::string::String; use burn_backend::{Backend, Device, DeviceId, DeviceOps}; use burn_std::stub::RwLock; use burn_std::{DType, FloatDType, IntDType}; #[cfg(target_has_atomic = "ptr")] use alloc::sync::Arc; #[cfg(not(target_has_atomic = "ptr"))] use portable_atomic_util::Arc; use thiserror::Error; use core::any::TypeId; #[cfg(feature = "std")] pub use std::collections::HashMap; #[cfg(feature = "std")] use std::sync::LazyLock; #[cfg(not(feature = "std"))] pub use hashbrown::HashMap; #[cfg(not(feature = "std"))] use spin::Lazy as LazyLock; /// Policy controlling default device behavior. /// /// This includes default data types used for tensor creation. #[derive(Debug, Clone, Copy, Default)] pub(crate) struct DevicePolicy { /// Default floating-point data type for tensor creation. float_dtype: Option, /// Default integer data type for tensor creation. int_dtype: Option, } impl DevicePolicy { /// Returns the default floating-point data type used for tensor creation. pub(crate) fn float_dtype(&self) -> Option { self.float_dtype } /// Returns the default integer data type used for tensor creation. pub(crate) fn int_dtype(&self) -> Option { self.int_dtype } /// Sets the default floating-point data type. pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) { self.float_dtype = Some(dtype); } /// Sets the default integer data type. pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) { self.int_dtype = Some(dtype); } } /// Key for the registry: physical device type + device id type RegistryKey = (DeviceId, TypeId); /// Global registry mapping devices to their policies. static REGISTRY: LazyLock>>> = LazyLock::new(|| RwLock::new(HashMap::new())); /// Device policy management for controlling default tensor creation behavior. /// /// # Policy Semantics /// /// Device policies use snapshot semantics: when you retrieve a policy with /// [`get_device_policy`], you get an immutable snapshot of the current configuration. /// Updates to the policy (via [`set_default_dtypes`], [`set_default_float_dtype`], etc.) /// only affect future policy retrievals, not existing references. /// /// This is intended for the common case where policies are set once during /// initialization and then read frequently during tensor creation. struct DevicePolicyRegistry; impl DevicePolicyRegistry { /// Get the policy for a physical device type and device id. /// /// If no policy exists yet, a default one is created and stored. fn get(device: &D) -> Arc { let key = Self::key(device); if let Some(policy) = REGISTRY.read().unwrap().get(&key) { return Arc::clone(policy); } let mut map = REGISTRY.write().unwrap(); Arc::clone( map.entry(key) .or_insert_with(|| Arc::new(DevicePolicy::default())), ) } /// Mutate the policy for a given device. fn update(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) { let key = Self::key(device); let mut map = REGISTRY.write().unwrap(); let policy = map .entry(key) .or_insert_with(|| Arc::new(DevicePolicy::default())); // Update the policy let policy_mut = Arc::make_mut(policy); update_fn(policy_mut); } /// Returns the device registry key. fn key(device: &D) -> RegistryKey { (device.to_id(), TypeId::of::()) } } /// Get the [`device`'s policy](DevicePolicy). /// /// Returns an immutable snapshot of the device's current policy. If the policy /// is updated after retrieval, this snapshot will not reflect those changes. pub(crate) fn get_device_policy(device: &D) -> Arc { DevicePolicyRegistry::get(device) } /// Errors that can occur during device-related operations. /// /// This covers errors related to hardware capability mismatches, such as /// requesting a data type not supported by the device, and configuration /// errors like attempting to change a policy in an invalid context. #[derive(Debug, Error)] pub enum DeviceError { /// Unsupported data type by the device. #[error("Device {device} does not support the requested data type {dtype:?}")] UnsupportedDType { /// The string representation of the device. device: String, /// The data type that caused the error. dtype: DType, }, // TODO: `InvalidContext` if a device policy cannot be changed after init / during training / etc. } impl DeviceError { /// Helper to create a [`DeviceError::UnsupportedDType`] from any device. pub fn unsupported_dtype(device: &D, dtype: DType) -> Self { Self::UnsupportedDType { device: format!("{device:?}"), dtype, } } } fn check_dtype_support( device: &B::Device, dtype: impl Into, ) -> Result<(), DeviceError> { let dtype = dtype.into(); // Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized // operations should not be used as default. if B::supports_dtype(device, dtype) { Ok(()) } else { Err(DeviceError::unsupported_dtype(device, dtype)) } } /// Sets the default data types for the device. /// /// This updates the device's default data types used for tensor creation. /// The policy should typically be set once during initialization and then /// remains global for all subsequent operations on that device. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{DType, Int, Tensor, set_default_dtypes}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Update the device policy /// set_default_dtypes::(&device, DType::F16, DType::I32); /// /// // All float tensors created after this will use F16 by default /// let tensor = Tensor::::zeros([2, 3], &device); /// // All int tensors created after this will use I32 default /// let tensor = Tensor::::zeros([2, 3], &device); /// } /// ``` pub fn set_default_dtypes( device: &B::Device, float_dtype: impl Into, int_dtype: impl Into, ) -> Result<(), DeviceError> { let float_dtype = float_dtype.into(); let int_dtype = int_dtype.into(); check_dtype_support::(device, float_dtype)?; check_dtype_support::(device, int_dtype)?; set_default_dtypes_unchecked(device, float_dtype, int_dtype); Ok(()) } /// Sets the default floating-point data type for the device. /// /// This updates the device's default data types used for tensor creation. /// The policy should typically be set once during initialization and then /// remains global for all subsequent operations on that device. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{DType, Tensor, set_default_float_dtype}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Update the device policy /// set_default_float_dtype::(&device, DType::F16); /// /// // All float tensors created after this will use F16 by default /// let tensor = Tensor::::zeros([2, 3], &device); /// } /// ``` pub fn set_default_float_dtype( device: &B::Device, dtype: impl Into, ) -> Result<(), DeviceError> { let dtype = dtype.into(); check_dtype_support::(device, dtype)?; set_default_float_dtype_unchecked(device, dtype); Ok(()) } /// Sets the default integer data type for the device. /// /// This updates the device's default data types used for tensor creation. /// The policy should typically be set once during initialization and then /// remains global for all subsequent operations on that device. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{DType, Int, Tensor, set_default_int_dtype}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Update the device policy /// set_default_int_dtype::(&device, DType::I32); /// /// // All int tensors created after this will use I32 default /// let tensor = Tensor::::zeros([2, 3], &device); /// } /// ``` pub fn set_default_int_dtype( device: &B::Device, dtype: impl Into, ) -> Result<(), DeviceError> { let dtype = dtype.into(); check_dtype_support::(device, dtype)?; set_default_int_dtype_unchecked(device, dtype); Ok(()) } // Unchecked versions fn set_default_dtypes_unchecked( device: &D, float_dtype: FloatDType, int_dtype: IntDType, ) { DevicePolicyRegistry::update(device, |p| { p.set_float_dtype(float_dtype); p.set_int_dtype(int_dtype); }); } fn set_default_float_dtype_unchecked(device: &D, dtype: FloatDType) { DevicePolicyRegistry::update(device, |p| { p.set_float_dtype(dtype); }); } fn set_default_int_dtype_unchecked(device: &D, dtype: IntDType) { DevicePolicyRegistry::update(device, |p| { p.set_int_dtype(dtype); }); } #[cfg(all(test, feature = "std"))] mod tests { use serial_test::serial; use super::*; fn clear_registry() { REGISTRY.write().unwrap().clear(); } #[derive(Clone, Debug, Default, PartialEq, new)] pub struct TestDeviceA { index: u32, } impl Device for TestDeviceA { fn from_id(device_id: DeviceId) -> Self { Self { index: device_id.index_id, } } fn to_id(&self) -> DeviceId { DeviceId { type_id: 0, index_id: self.index, } } fn device_count(_type_id: u16) -> usize { 1 } } impl DeviceOps for TestDeviceA {} #[derive(Clone, Debug, Default, PartialEq, new)] pub struct TestDeviceB { index: u32, } impl Device for TestDeviceB { fn from_id(device_id: DeviceId) -> Self { Self { index: device_id.index_id, } } fn to_id(&self) -> DeviceId { DeviceId { type_id: 0, index_id: self.index, } } fn device_count(_type_id: u16) -> usize { 1 } } impl DeviceOps for TestDeviceB {} #[test] #[serial] fn default_policy_is_created_and_shared() { clear_registry(); // reset registry for each test let device = TestDeviceA::new(0); let p1 = get_device_policy(&device); let p2 = get_device_policy(&device); assert!(Arc::ptr_eq(&p1, &p2)); // Not explicitly set assert!(p1.float_dtype().is_none()); assert!(p1.int_dtype().is_none()); assert!(p2.float_dtype().is_none()); assert!(p2.int_dtype().is_none()); } #[test] #[serial] fn updated_policy_is_shared() { clear_registry(); // reset registry for each test let device = TestDeviceA::new(0); // The device policy is meant to be set once at initialization set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32); let p1 = get_device_policy(&device); let p2 = get_device_policy(&device); assert!(Arc::ptr_eq(&p1, &p2)); assert_eq!(p1.float_dtype(), Some(FloatDType::BF16)); assert_eq!(p1.int_dtype(), Some(IntDType::I32)); assert_eq!(p2.float_dtype(), Some(FloatDType::BF16)); assert_eq!(p2.int_dtype(), Some(IntDType::I32)); } #[test] #[serial] fn policy_is_device_id_specific() { clear_registry(); // reset registry for each test let d1 = TestDeviceA::new(0); let d2 = TestDeviceA::new(1); set_default_float_dtype_unchecked(&d1, FloatDType::F16); let p1 = get_device_policy(&d1); let p2 = get_device_policy(&d2); assert!(!Arc::ptr_eq(&p1, &p2)); assert_eq!(p1.float_dtype(), Some(FloatDType::F16)); assert!(p1.int_dtype().is_none()); assert!(p2.float_dtype().is_none()); assert!(p2.int_dtype().is_none()); } #[test] #[serial] fn policy_is_device_type_specific() { clear_registry(); // reset registry for each test let d1 = TestDeviceA::new(0); let d2 = TestDeviceB::new(0); set_default_float_dtype_unchecked(&d2, FloatDType::F16); let p1 = get_device_policy(&d1); let p2 = get_device_policy(&d2); assert!(p1.float_dtype().is_none()); assert!(p1.int_dtype().is_none()); assert_eq!(p2.float_dtype(), Some(FloatDType::F16)); assert!(p2.int_dtype().is_none()); } #[test] #[serial] fn updating_policy_should_not_affect_snapshot() { clear_registry(); // reset registry for each test // The device policy is meant to be set once at initialization let device = TestDeviceA::new(0); let before = get_device_policy(&device); set_default_float_dtype_unchecked(&device, FloatDType::BF16); let after = get_device_policy(&device); assert!(!Arc::ptr_eq(&before, &after)); assert_eq!(after.float_dtype(), Some(FloatDType::BF16)); assert!(before.float_dtype().is_none()); } #[test] #[serial] fn set_default_dtypes_overwrites_fields() { clear_registry(); // reset registry for each test let device = TestDeviceA::new(0); set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64); let policy = get_device_policy(&device); assert_eq!(policy.float_dtype(), Some(FloatDType::F16)); assert_eq!(policy.int_dtype(), Some(IntDType::I64)); } } ================================================ FILE: crates/burn-tensor/src/lib.rs ================================================ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! This library provides the core abstractions required to run tensor operations with Burn. //! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations. //! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait. #[macro_use] extern crate derive_new; extern crate alloc; mod tensor; pub(crate) use tensor::check::macros::check; pub use tensor::*; // Re-exported types pub use burn_backend::{AllocationProperty, Bytes, StreamId, bf16, f16, read_sync, try_read_sync}; mod device; pub use device::*; ================================================ FILE: crates/burn-tensor/src/tensor/activation/base.rs ================================================ use crate::backend::Backend; use crate::check::TensorCheck; use crate::{Tensor, TensorPrimitive, check, s}; /// Applies the rectified linear unit function element-wise /// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375). /// #[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")] #[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")] pub fn relu(tensor: Tensor) -> Tensor { tensor.relu() } /// Applies the leaky rectified linear unit function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\) $$ or $$ \text{LeakyReLU}(x) = \begin{cases} x & \text{if } x \geq 0 \newline \text{negative\\_slope} \cdot x & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`" )] pub fn leaky_relu( tensor: Tensor, negative_slope: f64, ) -> Tensor { Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu( tensor.primitive.tensor(), negative_slope.into(), ))) } /// Applies the Gaussian Error Linear Units function as described in the paper /// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf). /// #[cfg_attr( doc, doc = r#" $$ \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right) $$ where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution. "# )] #[cfg_attr( not(doc), doc = r#" `GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))` where `Φ(x)` is the cumulative distribution function for the Gaussian distribution. "# )] pub fn gelu(tensor: Tensor) -> Tensor { Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor()))) } /// Applies the tanh-based approximate GELU function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{GELU\_approx}(x) = \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right) $$ "# )] #[cfg_attr( not(doc), doc = "`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`" )] pub fn gelu_approximate(tensor: Tensor) -> Tensor { /// sqrt(2/π) precomputed as FRAC_2_SQRT_PI * FRAC_1_SQRT_2 const SQRT_2_OVER_PI: f64 = core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2; let x = tensor; let inner = x.clone() + x.clone().powf_scalar(3.0) * 0.044715; let inner = inner * SQRT_2_OVER_PI; (x.clone() * (inner.tanh() + 1)) * 0.5 } /// Applies Parametric ReLu activation function as described in the paper /// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852). /// /// - The tensor is assumed to be of shape `[batch_size, channels, ...]`. /// - `alpha` is assumed to be of shape `[channels]` or `[1]`. /// #[cfg_attr( doc, doc = r#" $$ \text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\) $$ or $$ \text{PReLU}(x) = \begin{cases} x & \text{if } x \geq 0 \newline \alpha x & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr(not(doc), doc = "`PReLu(x) = max(0,x) + alpha * min(0,x)`")] pub fn prelu( tensor: Tensor, alpha: Tensor, ) -> Tensor { check!(TensorCheck::check_prelu_shape::( &tensor.shape(), &alpha.shape() )); let weight = if alpha.dims()[0] == 1 { // if there is only 1 weight, then reshape it to (1,1,1... D times) so that the rank is D alpha.reshape([1; D]) } else { // D>=2 because the case where D==1 and num_weights >1 is handled by check function // there is more than 1 weight and rank is more than 2 let num_weights = alpha.dims()[0]; let mut s = [1; D]; s[1] = num_weights; // reshape the weights to (1, channels,1 ...) alpha.reshape(s) }; Tensor::from_primitive(TensorPrimitive::Float(B::prelu( tensor.primitive.tensor(), weight.primitive.tensor(), ))) } /// Applies the softmax function on the input tensor along the given dimension. /// #[cfg_attr( doc, doc = r#" $$ \text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)} $$ "# )] #[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")] /// /// # Arguments /// - `dim`: the dimension along which Softmax will be computed. /// /// # Panics /// - If `dim` is outside [0, D) pub fn softmax(tensor: Tensor, dim: usize) -> Tensor { check!(TensorCheck::dim_ops::("softmax", dim)); let tensor = tensor.clone() - tensor.detach().max_dim(dim); let tensor = tensor.exp(); let tensor_tmp = tensor.clone().sum_dim(dim); tensor.div(tensor_tmp) } /// Applies the softmin function on the input tensor along the given dimension. /// #[cfg_attr( doc, doc = r#" $$ \text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)} $$ "# )] #[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)`")] /// /// # Arguments /// - `dim`: the dimension along which Softmax will be computed. /// /// # Panics /// - If `dim` is outside [0, D) pub fn softmin(tensor: Tensor, dim: usize) -> Tensor { check!(TensorCheck::dim_ops::("softmin", dim)); softmax(tensor.neg(), dim) } /// Applies the SoftPlus function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\) $$ "# )] #[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")] /// /// The SoftPlus function is a smooth approximation of the ReLU function. pub fn softplus(tensor: Tensor, beta: f64) -> Tensor { let tensor = (tensor.mul_scalar(beta).exp() + 1).log(); tensor.div_scalar(beta) } /// Applies the "quiet softmax" function on the input tensor along the given dimension. /// /// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html). /// /// This function is similar to the softmax function, but it allows for "no selection" when /// all the outputs are close to zero. /// #[cfg_attr( doc, doc = r#" $$ \text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)} $$ "# )] #[cfg_attr( not(doc), doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`" )] /// /// # Arguments /// - `dim`: the dimension along which Softmax will be computed. /// /// # Panics /// - If `dim` is outside [0, D) pub fn quiet_softmax(tensor: Tensor, dim: usize) -> Tensor { check!(TensorCheck::dim_ops::("softmax", dim)); let max_vals = tensor.clone().detach().max_dim(dim); let exp_x = (tensor - max_vals.clone()).exp(); let sum_exp = exp_x.clone().sum_dim(dim); exp_x.div(sum_exp + max_vals.neg().exp()) } /// Applies the log softmax function on the input tensor along the given dimension. /// #[cfg_attr( doc, doc = r#" $$ \text{log\\_softmax}\(x_i\) = \log\left(\text{softmax}\(x_i\)\right) = \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right) $$ "# )] #[cfg_attr( not(doc), doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`" )] /// /// # Arguments /// - `dim`: the dimension along which Softmax will be computed. /// /// # Panics /// - If `dim` is outside [0, D) pub fn log_softmax(tensor: Tensor, dim: usize) -> Tensor { check!(TensorCheck::dim_ops::("log softmax", dim)); let tensor = tensor.clone() - tensor.detach().max_dim(dim); let tensor_tmp = tensor.clone().exp().sum_dim(dim).log(); tensor.sub(tensor_tmp) } /// Applies the sigmoid function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{sigmoid}\(x\) = \sigma(x) = \frac{1}{1 + \exp(-x)} $$ "# )] #[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")] pub fn sigmoid(tensor: Tensor) -> Tensor { Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid( tensor.primitive.tensor(), ))) } /// Applies the hard sigmoid function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta)) $$ "# )] #[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")] pub fn hard_sigmoid( tensor: Tensor, alpha: f64, beta: f64, ) -> Tensor { Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid( tensor.primitive.tensor(), alpha.into(), beta.into(), ))) } /// Applies the log sigmoid function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right) $$ "# )] #[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")] pub fn log_sigmoid(tensor: Tensor) -> Tensor { Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid( tensor.primitive.tensor(), ))) } /// Applies the SiLU function (also known as the swish function) element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)} $$ "# )] #[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")] pub fn silu(tensor: Tensor) -> Tensor { tensor.clone().mul(sigmoid(tensor)) } /// Applies the hard swish function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5)) $$ "# )] #[cfg_attr( not(doc), doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`" )] pub fn hard_swish(tensor: Tensor) -> Tensor { tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5)) } /// Applies the Mish function as described in the paper in /// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). /// #[cfg_attr( doc, doc = r#" $$ \text{Mish}\(x\) = x \cdot \tanh(\text{Softplus}(x)) = \tanh\left(\log\(1 + \exp\(x\)\)\right) $$ "# )] #[cfg_attr( not(doc), doc = "`mish(x) = x * tanh(softplus(x)) = tanh(log(1 + exp(x)))`" )] pub fn mish(tensor: Tensor) -> Tensor { tensor.clone().mul(softplus(tensor, 1.0).tanh()) } /// Applies the tanh function element-wise. pub fn tanh(tensor: Tensor) -> Tensor { tensor.tanh() } /// Applies the Exponential Linear Unit function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{ELU}\(x\) = \begin{cases} x & \text{if } x > 0 \newline \alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0 \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`f(x) =`\n- `x for x > 0`\n- `alpha * (exp(x) - 1) for x <= 0`" )] pub fn elu(tensor: Tensor, alpha: f64) -> Tensor { let mask = tensor.clone().lower_equal_elem(0); let scaled = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha); tensor.mask_where(mask, scaled) } /// Applies the Continuously Differentiable Exponential Linear Unit function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{CELU}(x) = \begin{cases} x & \text{if } x \geq 0 \newline \alpha \cdot \left(\exp\left(\frac{x}{\alpha}\right) - 1\right) & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`" )] /// /// See also [CELU](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html) /// /// # Arguments /// - `alpha`: scaling parameter for the negative part. pub fn celu(tensor: Tensor, alpha: f64) -> Tensor { let mask = tensor.clone().lower_equal_elem(0); let scaled = tensor .clone() .div_scalar(alpha) .exp() .sub_scalar(1) .mul_scalar(alpha); tensor.mask_where(mask, scaled) } /// Applies the Scaled Exponential Linear Unit function element-wise /// as described in the paper [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515). /// #[cfg_attr( doc, doc = r#" $$ \text{SELU}\(x\) = \gamma \cdot \begin{cases} x & \text{if } x > 0 \newline \alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0 \end{cases} $$ where $\alpha \approx 1.6733$ and $\gamma \approx 1.0507$. "# )] #[cfg_attr( not(doc), doc = "`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`" )] pub fn selu(tensor: Tensor) -> Tensor { // Constants from the SELU paper / ONNX spec const ALPHA: f64 = 1.6732632423543772848170429916717_f64; const GAMMA: f64 = 1.0507009873554804934193349852946_f64; let mask = tensor.clone().greater_equal_elem(0.0); let positive = tensor.clone().mul_scalar(GAMMA); let negative = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA); negative.mask_where(mask, positive) } /// Applies the thresholded rectified linear unit function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{ThresholdedReLU}(x) = \begin{cases} x & \text{if } x > \alpha \newline 0 & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr(not(doc), doc = "`f(x) =`\n- `x if x > alpha`\n- `0 otherwise`")] /// /// # Arguments /// - `alpha`: threshold value (default in ONNX is 1.0). pub fn thresholded_relu( tensor: Tensor, alpha: f64, ) -> Tensor { let mask = tensor.clone().lower_equal_elem(alpha); tensor.mask_fill(mask, 0) } /// Applies the gated linear unit function. /// /// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half. /// /// **Note**: /// * The size of the input tensor along `dim` must be divisible by 2. /// /// ### Arguments /// * `tensor` - The input tensor. /// /// ### Returns /// * A tensor with the same shape as the input, except the size along `dim` is halved. pub fn glu(tensor: Tensor, dim: usize) -> Tensor { // TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU. assert!( tensor.dims()[dim].is_multiple_of(2), "Input tensor along dimension {dim} must have an even size. N is divisible by 2." ); let new_len = tensor.dims()[dim] / 2; let a = tensor.clone().slice_dim(dim, s![0..new_len]); let b = tensor.slice_dim(dim, s![new_len..new_len * 2]); a.mul(sigmoid(b)) } /// Applies the Softsign function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{softsign}(x) = \frac{x}{1 + |x|} $$ "# )] #[cfg_attr(not(doc), doc = "`softsign(x_i) = x_i / (1 + |x_i|)`")] pub fn softsign(tensor: Tensor) -> Tensor { tensor.clone().div(tensor.abs() + 1) } /// Applies the HardShrink function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{hard\_shrink}(x) = \begin{cases} x & \text{if } x > \lambda \newline x & \text{if } x < -\lambda \newline 0 & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`" )] /// # Arguments /// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5. pub fn hard_shrink(tensor: Tensor, lambda: f64) -> Tensor { let mask = tensor.clone().abs().lower_equal_elem(lambda); tensor.mask_fill(mask, 0) } /// Applies the SoftShrink function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{soft\_shrink}(x) = \begin{cases} x - \lambda & \text{if } x > \lambda \newline x + \lambda & \text{if } x < -\lambda \newline 0 & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`" )] /// # Arguments /// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5. pub fn soft_shrink(tensor: Tensor, lambda: f64) -> Tensor { shrink(tensor, lambda, lambda) } /// Applies the Shrink function element-wise. /// #[cfg_attr( doc, doc = r#" $$ \text{shrink}(x) = \begin{cases} x - \text{bias} & \text{if } x > \lambda \newline x + \text{bias} & \text{if } x < -\lambda \newline 0 & \text{otherwise} \end{cases} $$ "# )] #[cfg_attr( not(doc), doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`" )] /// # Arguments /// - `lambda`: the lambda value for the Shrink formulation. /// - `bias`: the bias value for the Shrink formulation. pub fn shrink( tensor: Tensor, lambda: f64, bias: f64, ) -> Tensor { let abs_tensor = tensor.clone().abs(); let sign = tensor.clone().sign(); let shrunk = tensor.sub(sign.mul_scalar(bias)); let mask = abs_tensor.lower_equal_elem(lambda); shrunk.mask_fill(mask, 0) } ================================================ FILE: crates/burn-tensor/src/tensor/activation/mod.rs ================================================ mod base; pub use base::*; ================================================ FILE: crates/burn-tensor/src/tensor/api/autodiff.rs ================================================ pub use burn_backend::tensor::BasicAutodiffOps; use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend}; impl Tensor { /// Backward pass of the tensor. pub fn backward(&self) -> B::Gradients { B::backward(self.primitive.clone().tensor()) } /// Get the gradients of a tensor if it exist. /// /// Returns a new reference to the same tensor. Therefore the same grad tensor can /// be accessed multiple times. If you only need to get the gradients one time, /// consider using [grad_remove](Tensor::grad_remove) for better performance. pub fn grad(&self, grads: &B::Gradients) -> Option> { match &self.primitive { TensorPrimitive::Float(tensor) => B::grad(tensor, grads) .map(TensorPrimitive::Float) .map(Tensor::new), TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads) .map(TensorPrimitive::Float) .map(Tensor::new), } } /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result. pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option> { match &self.primitive { TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads) .map(TensorPrimitive::Float) .map(Tensor::new), TensorPrimitive::QFloat(_tensor) => { B::grad_remove(&self.primitive.clone().tensor(), grads) .map(TensorPrimitive::Float) .map(Tensor::new) } } } /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided /// gradient. pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor) { match &self.primitive { TensorPrimitive::Float(tensor) => { B::grad_replace(tensor, grads, grad.primitive.tensor()) } TensorPrimitive::QFloat(_tensor) => B::grad_replace( &self.primitive.clone().tensor(), grads, grad.primitive.tensor(), ), } } } impl> Tensor { /// Returns the inner tensor without the autodiff information. pub fn inner(self) -> Tensor { Tensor::new(K::inner(self.primitive)) } /// Convert a tensor to the autodiff backend. /// /// # Arguments /// /// * `inner` - The tensor to convert. /// /// # Returns /// /// The tensor converted to the autodiff backend. pub fn from_inner(inner: Tensor) -> Self { Self::new(K::from_inner(inner.primitive)) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/base.rs ================================================ #![allow(clippy::single_range_in_vec_init)] use crate::backend::ExecutionError; use crate::check::unwrap_shape_reshape; use burn_backend::Scalar; pub use burn_backend::tensor::BasicOps; use alloc::vec::Vec; use alloc::format; use alloc::string::String; use alloc::vec; use burn_std::{SliceOps, stub::RwLock}; use core::iter::repeat; use core::{fmt::Debug, ops::Range}; use serde::{Deserialize, Deserializer}; use crate::{AsIndex, Slice, SliceArg, wrap_index}; use crate::{ Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata, backend::Backend, check, }; use crate::{DType, Element}; use crate::{IndexingUpdateOp, TensorCreationOptions}; use crate::{cast::ToElement, check::TensorCheck}; use serde::{Serialize, Serializer}; /// A tensor with a given backend, shape and data type. /// /// # Indexing /// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types /// or [`select`](Tensor::select) for numeric types. /// /// ## Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// use burn_tensor::Int; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data( /// [ /// [3.0, 4.9, 2.0], /// [2.0, 1.9, 3.0], /// [6.0, 1.5, 7.0], /// [3.0, 4.9, 9.0], /// ], /// &device, /// ); /// /// // Slice the tensor to get the second and third rows: /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]] /// // The resulting tensor will have dimensions [2, 3]. /// let slice = tensor.clone().slice([1..3]); /// println!("{slice}"); /// /// // Slice the tensor to get the first two rows and the first 2 columns: /// // [[3.0, 4.9], [2.0, 1.9]] /// // The resulting tensor will have dimensions [2, 2]. /// let slice = tensor.clone().slice([0..2, 0..2]); /// println!("{slice}"); /// /// // Index the tensor along the dimension 1 to get the elements 0 and 2: /// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]] /// // The resulting tensor will have dimensions [4, 2] /// let indices = Tensor::::from_data([0, 2], &device); /// let indexed = tensor.select(1, indices); /// println!("{indexed}"); /// } /// ``` #[derive(new, Clone, Debug)] pub struct Tensor where B: Backend, K: TensorKind, { pub(crate) primitive: K::Primitive, } impl From for Tensor where B: Backend, K: BasicOps, T: Into, { fn from(value: T) -> Self { Tensor::from_data(value.into(), &Default::default()) } } impl Tensor where B: Backend, K: BasicOps, K::Elem: Element, { /// Executes an operation on the tensor and modifies its value. /// /// # Notes /// /// This won't necessarily reuse the same tensor data/buffer, but it should if there is /// no other reference pointing to the same tensor. /// /// Wrapping operations with inplace is not an optimization, it's mainly there if you /// want to mutate a tensor by using owned operations. A plausible usage would be to /// update the weights of a mutable model reference. pub fn inplace Self>(&mut self, func: F) { let mut tensor_owned = Tensor::empty([0; D], &self.device()); core::mem::swap(&mut tensor_owned, self); let mut tensor_new = func(tensor_owned); core::mem::swap(&mut tensor_new, self); } /// Converts the tensor into a primitive tensor. pub fn into_primitive(self) -> K::Primitive { self.primitive } /// Converts from a primitive tensor into a tensor. pub fn from_primitive(tensor: K::Primitive) -> Self { Self::new(tensor) } /// Returns the number of dimensions of the tensor. pub fn rank(&self) -> usize { self.primitive.rank() } /// Returns the tensor primitive data type. /// /// # Note /// Some element types are encoded in different primitive types depending on the backend /// (e.g., bool could be encoded as `u8` or `u32`). pub fn dtype(&self) -> DType { self.primitive.dtype() } /// Create an empty tensor of the given shape. /// /// # Arguments /// /// - `shape`: The shape of the tensor. /// - `device`: The device where the tensor will be created. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create an empty tensor with dimensions [2, 3, 4]. /// let tensor = Tensor::::empty([2, 3, 4], &device); /// } /// ``` pub fn empty>(shape: S, options: impl Into>) -> Self { let opt = options.into(); let shape = shape.into(); let dtype = opt.resolve_policy::(); check!(TensorCheck::creation_ops::("Empty", &shape)); Self::new(K::empty(shape, &opt.device, dtype)) } /// Create a tensor of the given shape where each element is zero. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::zeros(Shape::new([2, 3]), &device); /// println!("{tensor}"); /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] /// } /// ``` pub fn zeros>(shape: S, options: impl Into>) -> Self { let opt = options.into(); let shape = shape.into(); let dtype = opt.resolve_policy::(); check!(TensorCheck::creation_ops::("Zeros", &shape)); Self::new(K::zeros(shape, &opt.device, dtype)) } /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.zeros_like(); /// println!("{tensor}"); /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] /// } /// ``` pub fn zeros_like(&self) -> Self { Self::new(K::zeros(self.shape(), &self.device(), self.dtype())) } /// Create a tensor of the given shape where each element is one. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::ones(Shape::new([2, 3]), &device); /// println!("{tensor}"); /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] /// } /// ``` pub fn ones>(shape: S, options: impl Into>) -> Self { let opt = options.into(); let shape = shape.into(); let dtype = opt.resolve_policy::(); check!(TensorCheck::creation_ops::("Ones", &shape)); Self::new(K::ones(shape, &opt.device, dtype)) } /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.ones_like(); /// println!("{tensor}"); /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] /// } /// ``` pub fn ones_like(&self) -> Self { Self::new(K::ones(self.shape(), &self.device(), self.dtype())) } /// Create a tensor of the given shape where each element is equal to the provided value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::full(Shape::new([2, 3]), 5.0, &device); /// println!("{tensor}"); /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]] /// } /// ``` pub fn full, E: ElementConversion>( shape: S, fill_value: E, options: impl Into>, ) -> Self { let opt = options.into(); let shape = shape.into(); let dtype = opt.resolve_policy::(); check!(TensorCheck::creation_ops::("Full", &shape)); Self::new(K::full( shape, Scalar::new(fill_value, &dtype), &opt.device, dtype, )) } /// Returns a new tensor with the same shape, dtype, and device as the current tensor, /// filled with the provided value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.full_like(5.0); /// println!("{tensor}"); /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]] /// } /// ``` pub fn full_like(&self, fill_value: E) -> Self { let dtype = self.dtype(); Self::new(K::full( self.shape(), Scalar::new(fill_value, &dtype), &self.device(), dtype, )) } /// Returns the dimensions of the current tensor. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::ones([2, 3, 4], &device); /// let dims = tensor.dims(); // [2, 3, 4] /// println!("{dims:?}"); /// } /// ``` pub fn dims(&self) -> [usize; D] { Self::shape(self).dims() } /// Returns the shape of the current tensor. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::ones([2, 3, 4], &device); /// // Shape { dims: [2, 3, 4] } /// let shape = tensor.shape(); /// } /// ``` pub fn shape(&self) -> Shape { self.primitive.shape() } /// Reshape the tensor to have the given shape. /// /// The tensor has the same data and number of elements as the input. /// /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]` /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12]. /// /// A `0` in the shape instructs to keep the current dimension from the original tensor, /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4]. /// This is useful when reshaping tensors with unknown dimensions and combining with `-1` /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor /// with [1, 3, 4] dimensions to [1, 12]. /// /// # Arguments /// - `shape`: The new shape of the tensor. /// /// # Panics /// - If the tensor contains more than one `-1` in the shape. /// - If the tensor contains values that are not positive (other than -1). /// - If the shape does not match the number of elements of the original shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a tensor with dimensions [2, 3, 4] /// let tensor = Tensor::::ones([2, 3, 4], &device); /// // Reshape it to [2, 12], where 12 is inferred from the number of elements. /// let reshaped = tensor.reshape([2, -1]); /// println!("{reshaped}"); /// } /// ``` pub fn reshape>(self, shape: S) -> Tensor { // Convert reshape args to shape let shape = shape.into_shape::(self.shape()); Tensor::new(K::reshape(self.primitive, shape)) } /// Transpose the tensor. /// /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is /// applied on the last two dimensions. For example, the transpose of a tensor with shape /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`. /// /// See also [`permute`](Tensor::permute). /// /// # Arguments /// /// * `tensor` - The tensor to transpose. /// /// # Returns /// /// The transposed tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor of shape [2, 3] /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// /// // Transpose the tensor: /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]] /// // The resulting tensor will have dimensions [3, 2]. /// let transposed = tensor.transpose(); /// println!("{transposed}"); /// } /// ``` pub fn transpose(self) -> Tensor { Tensor::new(K::transpose(self.primitive)) } /// Alias for `transpose`. #[inline(always)] pub fn t(self) -> Tensor { self.transpose() } /// Swaps two dimensions of a tensor. /// /// This is a no-op when `dim1 == dim2`, assuming both are within bounds. /// /// # Arguments /// /// * `tensor` - The tensor to swap the dimensions of. /// * `dim1` - The first dimension to swap, supports negative indexing. /// * `dim2` - The second dimension to swap, supports negative indexing. /// /// # Returns /// /// The tensor with the dimensions swapped. /// /// # Panics /// /// When dimensions are out of bounds. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor of shape [2, 3] /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// /// // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`): /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]] /// // The resulting tensor will have dimensions [3, 2]. /// let swapped = tensor.swap_dims(0, -1); /// println!("{swapped}"); /// } /// ``` pub fn swap_dims(self, dim1: Dim1, dim2: Dim2) -> Tensor where Dim1: AsIndex, Dim2: AsIndex, { let dim1 = dim1.expect_dim_index(D); let dim2 = dim2.expect_dim_index(D); check!(TensorCheck::swap_dims::(dim1, dim2)); if dim1 == dim2 { self } else { Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) } } /// Permute the dimensions of the tensor. /// /// This is a no-op when the resolved `axes` match the current order. /// /// # Arguments /// /// * `axes` - The new order of the dimensions. The length of the axes /// must be equal to the number of dimensions of the tensor. /// The values must be unique and in the range of the number of dimensions. /// The values can be negative, in which case they are used as an offset from the end. /// /// # Returns /// /// The tensor with the dimensions permuted. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor of shape [3, 2] /// let tensor = Tensor::::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device); /// /// // Permute the dimensions 1 and 0: /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]] /// // The resulting tensor will have dimensions [3, 2]. /// let permuted = tensor.permute([1, 0]); /// println!("{permuted}"); /// } /// ``` pub fn permute(self, axes: [Dim; D]) -> Tensor where Dim: AsIndex, { let mut no_op = true; let mut fixed_axes = [0; D]; for (i, axis) in axes.into_iter().enumerate() { let dim = axis.expect_dim_index(D); no_op &= dim == i; fixed_axes[i] = dim; } if no_op { self } else { check!(TensorCheck::permute(fixed_axes)); Tensor::new(K::permute(self.primitive, &fixed_axes)) } } /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination. /// /// Other dimensions of input that are not explicitly moved remain in their original order and appear /// at the positions not specified in destination. /// /// # Arguments /// /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions. /// The values can be negative, in which case they are used as an offset from the end. /// /// * `dst` - Destination positions for each of the original dims. These must also be unique. /// /// # Panics /// /// - If the source and destination dimensions are not of the same length. /// - If the source and destination vectors contain duplicate values. /// - If the source and destination vectors contain values that are out of bounds. /// /// # Returns /// /// The tensor with the dimensions moved. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 3D tensor of shape [3, 2, 1] /// let tensor = Tensor::::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device); /// /// // Move the dimensions 0 and 1: /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]] /// // The resulting tensor will have dimensions [2, 3, 1]. /// let moved = tensor.movedim(1, 0); /// println!("{moved}"); /// } /// ``` /// /// # Note /// /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op /// for it pub fn movedim(self, src: S1, dst: S2) -> Tensor { let source_dims = src.into_dim_vec::(); let destination_dims = dst.into_dim_vec::(); check!(TensorCheck::movedim_args_length( &source_dims, &destination_dims )); let mut m = [-1; D]; for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) { m[d] = s as isize; } let mut axes: [isize; D] = [0; D]; let mut source_i = 0; for (dest_i, item) in axes.iter_mut().enumerate().take(D) { *item = if m[dest_i] != -1 { m[dest_i] } else { while source_dims.contains(&source_i) { source_i += 1; } let result = source_i as isize; source_i += 1; result }; } self.permute(axes) } /// Reverse the order of elements in the tensor along the given dimensions. /// /// # Arguments /// /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions. /// The values can be negative, in which case they are used as an offset from the end. /// /// # Returns /// /// The tensor with the axes flipped. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [4, 3] /// let tensor = Tensor::::from_data( /// [ /// [3.0, 4.9, 2.0], /// [2.0, 1.9, 3.0], /// [4.0, 5.9, 8.0], /// [1.4, 5.8, 6.0], /// ], /// &device, /// ); /// /// // Flip the elements in dimensions 0 and 1: /// // [[6.0, 5.8, 1.4], /// // [8.0, 5.9, 4.0], /// // [3.0, 1.9, 2.0], /// // [2.0, 4.9, 3.0]] /// // The resulting tensor will have dimensions [4, 3]. /// let flipped = tensor.flip([0, 1]); /// println!("{flipped}"); /// } /// ``` pub fn flip(self, axes: [isize; N]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; N] = [0; N]; for (i, &x) in axes.iter().enumerate() { transformed_axes[i] = if x < 0 { (D as isize + x) as usize } else { x as usize }; } // Check if the axes are valid check!(TensorCheck::flip(D, &transformed_axes)); Tensor::new(K::flip(self.primitive, &transformed_axes)) } /// Flatten the tensor along a given range of dimensions. /// /// This function collapses the specified range of dimensions into a single dimension, /// effectively flattening the tensor in that range. /// /// # Arguments /// /// - `start_dim`: The starting dimension of the range to be flattened, /// supports negative indexing. /// - `end_dim`: The ending dimension of the range to be flattened (inclusive), /// supports negative indexing. /// /// # Type Parameters /// /// - `D2`: The resulting number of dimensions in the flattened tensor. /// /// # Returns /// /// A new `Tensor` instance with the specified range of dimensions flattened. /// /// # Example /// /// ```rust /// /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 3D tensor with dimensions [2, 3, 4] /// let tensor = Tensor::::ones(Shape::new([2, 3, 4]), &device); /// /// // Flatten the tensor from dimensions 1 to 2 (inclusive). /// // The resulting tensor will have dimensions [2, 12] /// let flattened: Tensor = tensor.flatten(1, 2); /// println!("{flattened}"); /// } /// ``` pub fn flatten( self, start_dim: impl AsIndex, end_dim: impl AsIndex, ) -> Tensor { let start_dim = start_dim.expect_dim_index(D); let end_dim = end_dim.expect_dim_index(D); check!(TensorCheck::flatten::(start_dim, end_dim)); let new_shape = self.shape().flatten_dims(start_dim, end_dim); Tensor::new(K::reshape(self.primitive, new_shape)) } /// Squeeze the tensor along all dimensions, removing dimensions /// of size one, and effectively reducing the rank of the tensor. /// /// # Type Parameters /// /// - `D2`: The resulting number of dimensions in the squeezed tensor. /// /// # Returns /// /// A new `Tensor` instance with the specified dimension removed. /// /// # Example /// /// ```rust /// /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 4D tensor with dimensions [1, 3, 1, 3] /// let tensor = Tensor::::from_data( /// [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]], /// &device, /// ); /// /// // Squeeze the tensor dimensions. /// // The resulting tensor will have dimensions [3, 3]. /// let squeezed = tensor.squeeze::<2>(); /// println!("{squeezed}"); /// } /// ``` pub fn squeeze(self) -> Tensor { let new_dims = self .shape() .iter() .filter_map(|&dim| if dim == 1 { None } else { Some(dim) }) .collect::>(); check!(TensorCheck::squeeze_dims_len::(new_dims.len())); Tensor::new(K::reshape(self.primitive, new_dims.into())) } /// Squeeze the tensor along the given dimension, removing the specified dimension /// of size one, and effectively reducing the rank of the tensor by one. /// /// # Arguments /// /// - `dim`: The dimension to be squeezed. /// /// # Type Parameters /// /// - `D2`: The resulting number of dimensions in the squeezed tensor. /// /// # Panics /// /// If the size in the squeezed dimension is not 1. /// /// # Returns /// /// A new `Tensor` instance with the specified dimension removed. /// /// # Example /// /// ```rust /// /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 3D tensor with dimensions [3, 1, 3] /// let tensor = Tensor::::from_data( /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]], /// &device, /// ); /// /// // Squeeze the dimension 1. /// // The resulting tensor will have dimensions [3, 3]. /// let squeezed = tensor.squeeze_dim::<2>(1); /// println!("{squeezed}"); /// } /// ``` pub fn squeeze_dim(self, dim: usize) -> Tensor { check!(TensorCheck::squeeze::(dim, &self.shape())); let current_dims = self.shape(); let mut new_dims: [usize; D2] = [0; D2]; new_dims[..dim].copy_from_slice(¤t_dims[..dim]); new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]); check!(TensorCheck::squeeze_dims_len::(new_dims.len())); Tensor::new(K::reshape(self.primitive, new_dims.into())) } /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted /// from the back. /// /// # Arguments /// /// - `dims`: The dimension(s) to be squeezed. /// /// # Type Parameters /// /// - `D2`: The resulting number of dimensions in the squeezed tensor. /// /// # Returns /// /// A new `Tensor` instance with the specified dimensions removed. /// /// # Example /// /// ```rust /// /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 4D tensor with dimensions [2, 1, 4, 1] /// let tensor = Tensor::::ones(Shape::new([2, 1, 4, 1]), &device); /// /// // Squeeze the dimensions 1 and 3. /// // The resulting tensor will have dimensions [2, 4]. /// let squeezed: Tensor = tensor.squeeze_dims(&[1, 3]); /// println!("{squeezed}"); /// } /// ``` pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { let current_dims = self.shape(); let mut dim_indices: Vec; // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries if dims.is_empty() { dim_indices = current_dims .iter() .enumerate() .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None }) .collect(); } else { // If negative dims, count from the back dim_indices = dims .iter() .map(|&d| { if d < 0 { (current_dims.len() as isize + d) as usize } else { d as usize } }) .collect(); } // Sort indices and remove duplicates dim_indices.sort_unstable(); dim_indices.dedup(); // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions check!(TensorCheck::squeeze_dims_input::( &dim_indices, ¤t_dims )); // Calculate new dimensions let mut new_dims = Vec::new(); for (index, &dim_size) in current_dims.iter().enumerate() { // Exclude the dimension if it's explicitly marked for squeezing if dim_indices.contains(&index) { check!(TensorCheck::squeeze::(index, ¤t_dims)); continue; } new_dims.push(dim_size); } // Check that after squeezing, we still respect the D2 size check!(TensorCheck::squeeze_dims_len::(new_dims.len())); Tensor::new(K::reshape(self.primitive, new_dims.into())) } /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size. /// /// # Type Parameters /// /// - `D2`: The resulting number of dimensions in the unsqueezed tensor. /// /// # Panics /// /// If the output size `D2` is smaller than the current number of dimensions. /// /// # Returns /// /// A new `Tensor` instance with the specified dimensions added. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [3, 3] /// let tensor = Tensor::::ones(Shape::new([3, 3]), &device); /// // Unsqueeze the tensor up to 4 dimensions. /// // The resulting tensor will have dimensions [1, 1, 3, 3]. /// let unsqueezed = tensor.unsqueeze::<4>(); /// println!("{unsqueezed}"); /// } /// ``` pub fn unsqueeze(self) -> Tensor { check!(TensorCheck::unsqueeze::()); let mut dims = [1; D2]; let num_ones = D2 - D; let shape = self.shape(); dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]); let shape = Shape::new(dims); self.reshape(shape) } /// Creates a new tensor with a dimension of size one inserted at the specified position. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [3, 3] /// let tensor = Tensor::::ones(Shape::new([3, 3]), &device); /// // Unsqueeze the dimension 1. /// // The resulting tensor will have dimensions [3, 1, 3]. /// let unsqueezed: Tensor = tensor.unsqueeze_dim(1); /// println!("{unsqueezed}"); /// } /// ``` pub fn unsqueeze_dim(self, dim: usize) -> Tensor { check!(TensorCheck::unsqueeze_dim::(dim)); let mut dims = [1; D2]; let shape = self.shape(); dims[0..dim].copy_from_slice(&shape[0..dim]); if dim < D { dims[dim] = 1; dims[(dim + 1)..].copy_from_slice(&shape[dim..]); } else { dims[dim] = 1; } let shape = Shape::new(dims); self.reshape(shape) } /// Creates a new tensor with added dimensions of size one inserted at the specified indices. /// The indices can be negative, in which case they are counted from the last to the first dimension. /// the axes can contain duplicates, in which case the number of dimensions inserted at the index /// is the number of duplicates. /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// // Create a 3D tensor with dimensions [3, 4, 5] /// let tensor = Tensor::::ones(Shape::new([3, 4, 5]), &device); /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice. /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1]. /// let unsqueezed: Tensor = tensor.unsqueeze_dims(&[0, -1, -1]); /// println!("{unsqueezed}"); /// } /// ``` pub fn unsqueeze_dims(self, axes: &[impl AsIndex]) -> Tensor { let mut new_dims = [1; D2]; let old_dims = self.shape(); //for checking if the dimension is in the acceptable range //part 1: convert the negative indices to positive let mut neg_offset = D2; let mut dim_indices = axes .iter() .map(|d| { let d = d.as_index(); // check if the dimension is in the acceptable range check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d)); (if d < 0 { neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse) d + neg_offset as isize + 1 } else { d }) as usize }) .collect::>(); //sort the indices dim_indices.sort_unstable(); //Now use this to copy the chunks of the dims let mut prev_idx: usize = 0; let mut current_left_b: usize = 0; let mut current_right_b: usize = 0; let mut offset: usize = 0; dim_indices.iter().for_each(|d| { //check if there is space for at least one dimension if prev_idx < *d { current_right_b = *d - offset; //copy the chunks of the dims if current_right_b < D { new_dims[prev_idx..*d] .copy_from_slice(&old_dims[current_left_b..current_right_b]); } else { new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]); } prev_idx = *d + 1; //offset is equal to the number of extracted elements from the original shape offset += current_right_b - current_left_b; current_left_b = current_right_b; } else { //it's sorted so the only reason this would happen //is if multiple indices are the same prev_idx += 1; } }); //copy over anything past the index of the last new dimension if current_left_b < D { new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]); } //lastly, create the shape and reshape let shape = Shape::new(new_dims); self.reshape(shape) } /// Roll operation along a specific dimension; wrapping around the elements. /// /// ## Parameters /// /// - `shift`: The roll extent; supports negative values and wraps around. /// - `dim`: The dimension to roll; supports negative indexing. /// /// ## Returns /// /// A new tensor with the specified dimension rolled by the given shift amount. pub fn roll_dim(self, shift: Shift, dim: Dim) -> Self where Shift: AsIndex, Dim: AsIndex, { let dim = dim.expect_dim_index(D); let size = self.shape()[dim]; if size == 0 { // If the dimension is empty, return the tensor as is. return self; } let shift = wrap_index(shift, size); if shift == 0 { // If the shift is zero, return the tensor as is. return self; } self.unchecked_roll_dim(shift, dim) } /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts. /// /// ## Parameters /// /// - `shift`: The number of positions to shift; must be (0 < shift < size). /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape. /// /// ## Returns /// /// A new tensor with the specified dimension rolled by the given shift amount. #[inline(always)] fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self { #[cfg(debug_assertions)] { let size = self.shape()[dim]; assert!( 0 < shift && shift < size, "Expected: 0 < shift < size: found shift={shift}, size={size}", ); assert!( dim < self.shape().num_dims(), "Expected: dim < num_dims: found dim={dim}, num_dims={size}", ); } Tensor::cat( vec![ self.clone().slice_dim(dim, shift..), self.slice_dim(dim, ..shift), ], dim, ) } /// Roll operation. /// /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length. /// /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially. /// /// ## Parameters /// /// - `shifts`: A slice of shifts corresponding to each dimension; /// supports negative values and wraps around. /// - `dims`: A slice of dimensions to roll; supports negative indexing. /// /// ## Returns /// /// A new tensor with the specified dimensions rolled by the given shifts. pub fn roll(self, shifts: &[Shift], dims: &[Dim]) -> Self where Shift: AsIndex, Dim: AsIndex, { assert_eq!( dims.len(), shifts.len(), "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}", ); // This is a fair amount of complexity, which could be replaced // by a simple canonicalization of `dims` and wrapping of `shifts`. // The work is done here to ensure that any roll operation // which could be a no-op is a no-op; simplifying the accounting // needed by backend-specific implementations of the inner roll op. let item_count = dims.len(); let shape = self.shape(); // Accumulate the effective shifts for each dimension. let mut accumulated_shifts: Vec = vec![0; shape.len()]; for i in 0..item_count { let dim = dims[i].expect_dim_index(D); accumulated_shifts[dim] += shifts[i].as_index(); } // Do this after we've checked the validity of `dims` and `shifts`. if self.shape().num_elements() == 0 { // If the tensor is empty, return it as is. return self; } // Wrap the accumulated shifts, and filter out empty dimensions. let mut effective_dims: Vec = Vec::with_capacity(item_count); let mut effective_shifts: Vec = Vec::with_capacity(item_count); for dim in 0..shape.len() { // `wrap_index` should inline, and has a fast-exit path for zero shifts. let shift = wrap_index(accumulated_shifts[dim], shape[dim]); if shift == 0 { continue; } effective_dims.push(dim); effective_shifts.push(shift); } // If no shifts are needed, return the original tensor. if effective_shifts.is_empty() { return self; } // At this point: // - `dims` contains the effective dimensions to roll, in index order, // - `shifts` contains the effective usize shifts for each dimension. // - Every shift is non-zero, and less than the size of the corresponding dimension. self.unchecked_roll(&effective_shifts, &effective_dims) } /// `roll` internal implementation. /// /// ## Parameters /// /// - `shifts`: A slice of shifts corresponding to each dimension; /// must be non-empty, the same length as `dims`, and all ``1..``. /// - `dims`: A slice of dimensions to roll; must be non-empty; /// the same length as `shifts`, and must not contain repeats. /// /// ## Panics /// /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats. /// /// ## Returns /// /// A new tensor with the specified dimensions rolled by the given shifts. #[inline(always)] fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self { #[cfg(debug_assertions)] { assert!(!shifts.is_empty()); assert_eq!( shifts.len(), dims.len(), "Shifts and dimensions must align; found {} shifts and {} dims", shifts.len(), dims.len() ); let mut unique_dims = dims.to_vec(); unique_dims.dedup(); assert_eq!( unique_dims.len(), dims.len(), "Dimensions must not contain repeats; found {} unique dims and {} total dims", unique_dims.len(), dims.len() ) } let x = self.unchecked_roll_dim(shifts[0], dims[0]); if dims.len() == 1 { x } else { x.unchecked_roll(&shifts[1..], &dims[1..]) } } /// Returns a tensor containing the elements selected from the given slices. /// /// This method provides flexible tensor slicing with support for various range types, /// negative indices, and stepped slicing. The method accepts both single slices and /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns. /// /// # Arguments /// /// * `slices` - Can be: /// - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`) /// - An array of ranges (e.g., `[0..2, 1..4]`) /// - The [`s!`] macro output for advanced slicing with steps /// - a `&Vec` or `&[Slice]` /// /// # Behavior /// /// - Supports partial and full slicing in any number of dimensions /// - Handles negative indices by wrapping from the end (-1 is the last element) /// - Automatically clamps ranges that exceed tensor dimensions /// - Supports stepped slicing for selecting every nth element /// - Negative steps reverse the selection order /// /// # Panics /// /// - If the number of slices exceeds the tensor's dimensions /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step /// - If a step is zero /// /// # Examples /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, s}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Single dimension slicing - no brackets needed! /// let tensor = Tensor::::arange(0..10, &device); /// let slice = tensor.clone().slice(2..8); // Simple range /// assert_eq!(slice.into_data().to_vec::().unwrap(), vec![2, 3, 4, 5, 6, 7]); /// /// // Using s! macro for single dimension with step /// let slice = tensor.clone().slice(s![0..10;2]); // Every 2nd element /// assert_eq!(slice.into_data().to_vec::().unwrap(), vec![0, 2, 4, 6, 8]); /// /// // Reverse a dimension with negative step /// let slice = tensor.slice(s![..;-1]); // Reverse entire tensor /// assert_eq!(slice.into_data().to_vec::().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]); /// /// // Multi-dimensional slicing /// let tensor = Tensor::::ones(Shape::new([4, 6]), &device); /// /// // Array syntax for simple ranges /// let slice = tensor.clone().slice([1..3, 2..5]); /// assert_eq!(slice.dims(), [2, 3]); /// /// // Advanced multi-dimensional with s! macro /// let slice = tensor.clone().slice(s![0..4;2, ..;-1]); // Every 2nd row, reverse columns /// assert_eq!(slice.dims(), [2, 6]); /// /// // Complex 3D example with mixed slice types /// let tensor = Tensor::::ones(Shape::new([4, 6, 8]), &device); /// let slice = tensor.slice(s![1..3, ..;2, -3..]); // Rows 1-2, every 2nd col, last 3 depth /// assert_eq!(slice.dims(), [2, 3, 3]); /// /// // Using negative indices /// let tensor = Tensor::::ones(Shape::new([4, 6]), &device); /// let slice = tensor.slice(s![-2.., ..-1]); // Last 2 rows, all but last column /// assert_eq!(slice.dims(), [2, 5]); /// } /// ``` /// /// # See Also /// /// - [`s!`] - The recommended macro for creating complex slice specifications /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension /// /// [`s!`]: crate::s! pub fn slice(self, slices: S) -> Self where S: SliceArg, { let shape = self.shape(); let slices = slices.into_slices(&shape); // Validate slices check!(TensorCheck::slice::(&shape, &slices)); // Calculate output shape and check for empty slices let mut output_dims = shape.clone(); for (dim, slice) in slices.iter().enumerate() { output_dims[dim] = slice.output_size(shape[dim]); } // Return empty tensor if any dimension is 0 (empty slice) if output_dims.contains(&0) { return Self::empty(output_dims, &self.device()); } Self::new(K::slice(self.primitive, &slices)) } /// Assigns values to a slice of the tensor and returns the updated tensor. /// /// This method supports advanced slicing with steps, including negative steps for reverse /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro /// providing powerful syntax for complex patterns. /// /// # Arguments /// /// * `slices` - Slice specification (same format as `slice` method) /// * `values` - Tensor with values to assign (must match slice dimensions) /// /// # Panics /// /// - If slices exceed tensor dimensions /// - If values dimensions don't match the selected slice shape /// - If a step is zero /// /// # Examples /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, s}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Simple assignment to a sub-region /// let mut tensor = Tensor::::zeros([4, 6], &device); /// let values = Tensor::::ones([2, 3], &device); /// tensor = tensor.slice_assign([1..3, 2..5], values); /// // Now tensor[1..3, 2..5] contains ones /// /// // Single dimension assignment with step /// let mut tensor = Tensor::::zeros([10], &device); /// let values = Tensor::::ones([5], &device); /// tensor = tensor.slice_assign(s![0..10;2], values); /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] /// /// // Reverse assignment with negative step /// let mut tensor = Tensor::::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device); /// let values = Tensor::::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device); /// tensor = tensor.slice_assign(s![..;-1], values); /// // Assigns in reverse: [14, 13, 12, 11, 10] /// /// // Complex multi-dimensional assignment /// let mut tensor = Tensor::::zeros([4, 6, 8], &device); /// let values = Tensor::::ones([2, 3, 3], &device); /// tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values); /// // Assigns to every 2nd row, every 2nd column, last 3 in depth /// /// // Mixed syntax example /// let mut tensor = Tensor::::zeros([8, 8], &device); /// let pattern = Tensor::::ones([4, 4], &device); /// tensor = tensor.slice_assign(s![..;2, ..;2], pattern); /// // Creates a checkerboard pattern with ones /// } /// ``` /// /// # See Also /// /// - [`s!`] - The recommended macro for creating complex slice specifications /// - [`slice`](Self::slice) - Extract a slice from a tensor /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value /// /// [`s!`]: crate::s! pub fn slice_assign(self, slices: S, values: Self) -> Self where S: SliceArg, { let shape = self.shape(); let slices = slices.into_slices(&shape); // Check if any slice produces 0 elements (empty assignment). // Empty assignments are no-ops and would cause issues in backend implementations. let is_empty_assignment = slices .iter() .enumerate() .any(|(i, slice)| slice.output_size(shape[i]) == 0); if is_empty_assignment { return self; } check!(TensorCheck::slice_assign::( &shape, &values.shape(), &slices )); Self::new(K::slice_assign(self.primitive, &slices, values.primitive)) } /// Fills a slice of the tensor with a constant value and returns the updated tensor. /// /// Like other slice methods, accepts both single slices and arrays. However, this method /// currently **does not support stepped slicing** - use [`slice_assign`](Self::slice_assign) /// with a constant tensor for stepped patterns. /// /// # Arguments /// /// * `slices` - Slice specification (same format as `slice` method, but no steps) /// * `value` - The value to fill the slice with /// /// # Panics /// /// - If slices exceed tensor dimensions /// - If any slice has a step != 1 (not yet supported) /// /// # Examples /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, s}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Simple fill for a single dimension /// let mut tensor = Tensor::::zeros([10], &device); /// tensor = tensor.slice_fill(2..5, 1.0); /// // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0] /// /// // Multi-dimensional fill /// let mut tensor = Tensor::::zeros([4, 6], &device); /// tensor = tensor.slice_fill([1..3, 2..5], -1.0); /// // Fills the rectangle at rows 1-2, columns 2-4 with -1 /// /// // Using negative indices /// let mut tensor = Tensor::::zeros([10], &device); /// tensor = tensor.slice_fill(-3.., 2.0); /// // Fills the last 3 elements with 2.0 /// /// // Complex multi-dimensional example /// let mut tensor = Tensor::::ones([4, 6, 8], &device); /// tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0); /// // Sets rows 1-2, all columns, last 2 in depth to 0 /// /// // Stepped slicing is supported /// let mut tensor = Tensor::::zeros([10], &device); /// tensor = tensor.slice_fill(s![0..10;2], 1.0); /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] /// } /// ``` /// /// # See Also /// /// - [`s!`] - The macro for creating slice specifications with steps /// - [`slice`](Self::slice) - Extract a slice from a tensor /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice /// /// [`s!`]: crate::s! pub fn slice_fill(self, slices: S, value: E) -> Self where S: SliceArg, { let shape = self.shape(); let slices = slices.into_slices(&shape); check!(TensorCheck::slice::(&shape, &slices)); let slice_shape = shape.slice(&slices).unwrap(); let value = Tensor::::from_data_dtype( [value.elem::()], &self.device(), self.dtype(), ); let value = value.expand(slice_shape); self.slice_assign(&slices, value) } /// Returns a new tensor with the specified dimension sliced. /// /// # Arguments /// /// * `dim`: The dimension to slice. /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`), /// slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into`. /// /// # Returns /// /// A new tensor with the specified dimension sliced. /// /// # Panics /// /// If the slice is out of bounds for the specified dimension. /// /// # Examples /// /// ```rust /// # use burn_tensor::{Tensor, s}; /// # use burn_tensor::backend::Backend; /// # /// # fn example() { /// # let device = B::Device::default(); /// let tensor = Tensor::::zeros([3, 4, 5], &device); /// /// // Simple range slicing /// let sliced = tensor.clone().slice_dim(1, 1..3); /// assert_eq!(sliced.shape().as_slice(), [3, 2, 5]); /// /// // Slicing with step - take every 2nd element /// let sliced = tensor.clone().slice_dim(2, s![0..5;2]); /// assert_eq!(sliced.shape().as_slice(), [3, 4, 3]); // Takes indices 0, 2, 4 /// /// // Reverse slicing with negative step /// let sliced = tensor.clone().slice_dim(1, s![..;-1]); /// assert_eq!(sliced.shape().as_slice(), [3, 4, 5]); // Reverses dimension 1 /// /// // Select from index 2 with step 3 /// let sliced = tensor.clone().slice_dim(0, s![2..;3]); /// assert_eq!(sliced.shape().as_slice(), [1, 4, 5]); // Takes only index 2 /// /// // Select single index (reduces dimension to size 1) /// let sliced = tensor.slice_dim(0, 1); /// assert_eq!(sliced.shape().as_slice(), [1, 4, 5]); /// # } /// ``` /// /// # See Also /// /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously /// - [`s!`] - The macro for creating complex slice specifications /// /// [`s!`]: crate::s! pub fn slice_dim(self, dim: usize, slice: S) -> Self where S: Into, { check!(TensorCheck::check_dim::(dim)); let slice: Slice = slice.into(); let mut slices = vec![Slice::full(); D]; slices[dim] = slice; self.slice(&slices) } /// Returns the device of the current tensor. pub fn device(&self) -> B::Device { K::device(&self.primitive) } /// Move the tensor to the given device. pub fn to_device(self, device: &B::Device) -> Self { Self::new(K::to_device(self.primitive, device)) } /// Select tensor elements along the given dimension corresponding to the given indices. /// /// # Arguments /// /// * `dim` - The dimension to select from. Supports negative indexing. /// * `indices` - The indices of the elements to select. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Int}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device); /// let indices = Tensor::::from_data([0], &device); /// let tensor = tensor.select(0, indices); /// println!("{tensor}"); /// // [[1.0, -2.0, 3.0]] /// } /// ``` pub fn select(self, dim: impl AsIndex, indices: Tensor) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::select::(dim)); Self::new(K::select(self.primitive, dim, indices.primitive)) } /// Assign the selected elements along the given dimension corresponding to the given indices /// from the value tensor to the original tensor using sum reduction. /// /// # Note /// For booleans, the sum operator is logical or. /// /// # Arguments /// /// * `dim` - The dimension along which to select. Supports negative indexing. /// * `indices` - The indices to select from the tensor. /// * `values` - The values to assign to the selected indices. /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add). /// /// # Example /// /// Example using a 3D tensor: /// /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0` /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1` /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2` /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)` /// /// # Warning /// /// Not all backends have runtime bound checks for the indices, so make sure they are valid. /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking. pub fn select_assign( self, dim: impl AsIndex, indices: Tensor, values: Tensor, update: IndexingUpdateOp, ) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::select_assign::( dim, &indices.shape(), &values.shape() )); Self::new(K::select_assign( self.primitive, dim, indices.primitive, values.primitive, update, )) } /// Update the given tensor with the value tensor where the mask is true. /// /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of /// a scalar. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, Bool}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let mask = Tensor::::from_data([[true, false, true], [false, true, false]], &device); /// let value = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor.mask_where(mask, value); /// println!("{tensor}"); /// // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]] /// } /// ``` pub fn mask_where(self, mask: Tensor, value: Self) -> Self { Self::new(K::mask_where( self.primitive, mask.primitive, value.primitive, )) } /// Update the given tensor with the value where the mask is true. /// /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of /// a tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, Bool}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let mask = Tensor::::from_data([[true, false, true], [false, true, false]], &device); /// let tensor = tensor.mask_fill(mask, 3.0); /// println!("{tensor}"); /// // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]] /// } /// ``` pub fn mask_fill(self, mask: Tensor, value: E) -> Self { let value = Scalar::new(value, &self.dtype()); Self::new(K::mask_fill(self.primitive, mask.primitive, value)) } /// Gather tensor elements corresponding to the given indices from the specified dim. /// /// Example using a 3D tensor: /// /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0` /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1` /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2` /// /// # Notes /// /// The index tensor should have the same shape as the original tensor except for the dim /// specified. /// /// # Warning /// Not all backends have runtime bound checks for the indices, so make sure the they are valid. /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking. pub fn gather(self, dim: usize, indices: Tensor) -> Self { check!(TensorCheck::gather::( dim, &self.shape(), &indices.shape() )); Self::new(K::gather(dim, self.primitive, indices.primitive)) } /// Assign the gathered elements corresponding to the given indices along the specified dimension /// from the value tensor to the original tensor using sum reduction. /// /// Example using a 3D tensor: /// /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0` /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1` /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2` /// /// # Arguments /// * `dim` - The axis along which to scatter elements. /// * `indices` - The indices of the elements to scatter. /// * `values` - The values to scatter into the tensor. /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add). /// /// # Notes /// /// The index tensor should have the same shape as the original tensor except for the specified /// dimension. The value and index tensors should have the same shape. /// /// Other references to the input tensor will not be modified by this operation. /// /// # Warning /// Not all backends have runtime bound checks for the indices, so make sure the they are valid. /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking. pub fn scatter( self, dim: usize, indices: Tensor, values: Self, update: IndexingUpdateOp, ) -> Self { check!(TensorCheck::scatter::( dim, &self.shape(), &indices.shape(), &values.shape() )); Self::new(K::scatter( dim, self.primitive, indices.primitive, values.primitive, update, )) } /// Converts the data of the current tensor. /// /// # Note /// /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple /// tensors at once. This may improve laziness, especially if executed on a different /// thread in native environments. pub fn into_data(self) -> TensorData { self.try_into_data().expect( "Error while reading data: use `try_into_data` instead to catch the error at runtime", ) } /// Converts the data of the current tensor and returns any error that might have occurred since the /// last time the device was synchronized. /// /// # Note /// /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple /// tensors at once. This may improve laziness, especially if executed on a different /// thread in native environments. pub fn try_into_data(self) -> Result { crate::try_read_sync(self.into_data_async()).expect( "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using into_data_async instead.", ) } /// Converts the data of the current tensor. /// /// # Note /// /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple /// tensors at once. This may improve laziness, especially if executed on a different /// thread in native environments. pub fn to_data(&self) -> TensorData { self.clone().into_data() } /// Returns the data of the current tensor. pub async fn into_data_async(self) -> Result { K::into_data_async(self.primitive).await } /// Returns the data of the current tensor. pub async fn to_data_async(&self) -> Result { self.clone().into_data_async().await } /// Create a tensor from the given data on the given device. pub fn from_data(data: T, device: &B::Device) -> Self where T: Into, { let data = data.into(); check!(TensorCheck::creation_ops::( "From Data", data.shape.as_slice() )); Self::new(K::from_data(data, device)) } /// Create a tensor from the given data on the given device enforcing the given data type. pub fn from_data_dtype(data: T, device: &B::Device, dtype: DType) -> Self where T: Into, { let data = data.into(); check!(TensorCheck::creation_ops::( "From Data", data.shape.as_slice() )); Self::new(K::from_data_dtype(data, device, dtype)) } /// Repeat the tensor along the given dimension. /// /// The output tensor has the same shape, except along the given dimension. /// /// # Arguments /// - `dim`: The dimension to repeat. /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor. /// /// # Returns /// /// A new tensor with the given dimension repeated `times` times. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [3, 2] /// let tensor = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// /// // Repeat the tensor along the dimension 0 twice. /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]] /// // The resulting tensor will have dimensions [6, 2]. /// let repeated = tensor.repeat_dim(0, 2); /// println!("{repeated}"); /// } /// ``` pub fn repeat_dim(self, dim: usize, times: usize) -> Self { if times > 0 { Self::new(K::repeat_dim(self.primitive, dim, times)) } else { let shape = self.shape().repeat(dim, times).unwrap(); Self::empty(shape, &self.device()) } } /// Repeat the tensor along the given dimensions. /// # Arguments /// - `sizes`: Borrowed slice of the number of times to repeat each dimension. /// /// # Returns /// /// A new tensor with the given dimensions repeated `times` times. /// /// # Panics /// /// If `sizes` contains more elements than the number of dimensions. /// /// # Example /// /// ```rust /// /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [3, 2] /// let tensor = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once. /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]] /// // The resulting tensor will have dimensions [6, 2]. /// let repeated = tensor.repeat(&[2, 1]); /// } /// ``` pub fn repeat(self, sizes: &[usize]) -> Self { if sizes.contains(&0) { let mut shape = self.shape(); for (dim, ×) in sizes.iter().enumerate() { shape = shape.repeat(dim, times).unwrap(); } return Self::empty(shape, &self.device()); } let mut tensor = self; for (dim, ×) in sizes.iter().enumerate() { if times > 1 { tensor = tensor.repeat_dim(dim, times); } } tensor } /// Applies element-wise equal comparison. /// /// # Returns /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let t1 = Tensor::::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// let t2 = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// // Compare the elements of the two 2D tensors with dimensions [3, 2]. /// // [[false, true], [true, true], [true, true]] /// let equal = t1.equal(t2); /// println!("{equal}"); /// } /// ``` pub fn equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); Tensor::new(K::equal(self.primitive, other.primitive)) } /// Applies element-wise non-equality comparison. /// /// # Returns /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let t1 = Tensor::::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// let t2 = Tensor::::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device); /// // Compare the elements of the two 2D tensors for inequality. /// // [[true, false], [false, false], [false, false]] /// let not_equal = t1.not_equal(t2); /// println!("{not_equal}"); /// } /// ``` pub fn not_equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); Tensor::new(K::not_equal(self.primitive, other.primitive)) } /// Applies element wise equal comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.equal_elem(3.0); /// println!("{tensor}"); /// // [[false, false, true], [false, false, false]] /// } /// ``` pub fn equal_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::equal_elem(self.primitive, other)) } /// Applies element wise non-equality comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.not_equal_elem(3.0); /// println!("{tensor}"); /// // [[true, true, false], [true, true, true]] /// } /// ``` pub fn not_equal_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::not_equal_elem(self.primitive, other)) } /// Concatenates all tensors into a new one along the given dimension. /// /// # Panics /// /// - If `dim` is higher than the rank. /// - If `tensors` is an empty vector. /// - If all tensors don't have the same shape (the dimension `dim` is ignored). /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let t1 = Tensor::::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device); /// let t2 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); /// /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1. /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]] /// // The resulting tensor will have shape [2, 7]. /// let concat = Tensor::cat(vec![t1, t2], 1); /// println!("{concat}"); /// } /// ``` pub fn cat(tensors: Vec, dim: usize) -> Self { check!(TensorCheck::cat(&tensors, dim)); // Filter out tensors with size 0 along the concatenation dimension. // Empty tensors don't contribute to the output and would cause issues // in backend implementations (e.g., division by zero in slice_assign). // Safety: TensorCheck::cat ensures tensors is non-empty let first_tensor = tensors.first().unwrap(); let device = first_tensor.device(); let mut shape = first_tensor.shape(); let non_empty_primitives: Vec<_> = tensors .into_iter() .filter(|t| t.shape()[dim] > 0) .map(|t| t.primitive) .collect(); // If all tensors were empty, return an empty tensor with size 0 on concat dim if non_empty_primitives.is_empty() { shape[dim] = 0; return Self::empty(shape, &device); } Self::new(K::cat(non_empty_primitives, dim)) } /// Concatenates all tensors into a new one along a new dimension. /// /// # Panics /// /// - If all tensors don't have the same shape. /// - If given dimension is not with range of 0..D2 /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let t1 = Tensor::::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device); /// let t2 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); /// let t3 = Tensor::::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device); /// /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0. /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]] /// // The resulting tensor will have shape [3, 2, 3]. /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0); /// println!("{stacked}"); /// } /// ``` pub fn stack(tensors: Vec>, dim: usize) -> Tensor { check!(TensorCheck::stack::(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); Tensor::::cat(tensors, dim) } /// Iterate over slices of tensors alongside a given dimension. /// /// # Panics /// /// If given dimension is greater than or equal to tensor rank. /// /// # Returns /// /// A tensor iterator. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device); /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0. /// let iter = tensor.iter_dim(0); /// for (i,tensor) in iter.enumerate() { /// println!("Tensor {}: {}", i, tensor); /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... } /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... } /// } /// } /// ``` pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); DimIter::new(self, dim) } /// Returns a new tensor with the given dimension narrowed to the given range. /// /// # Panics /// /// - If the dimension is greater than the number of dimensions of the tensor. /// - If the given range exceeds the number of elements on the given dimension. /// /// # Returns /// /// A new tensor with the given dimension narrowed to the given range. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [4, 3] /// let tensor = Tensor::::from_data( /// [ /// [3.0, 4.9, 2.0], /// [2.0, 1.9, 3.0], /// [6.0, 1.5, 7.0], /// [3.0, 4.9, 9.0], /// ], /// &device, /// ); /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1. /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]] /// // The resulting tensor will have dimensions [3, 3]. /// let narrowed = tensor.narrow(0, 1, 3); /// println!("{narrowed}"); /// } /// ``` pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); check!(TensorCheck::narrow(&self, dim, start, length)); let dims = self.dims(); let ranges: [Range; D] = dims .iter() .enumerate() .map(|(i, d)| { if i == dim { start..(start + length) } else { 0..*d } }) .collect::>() .try_into() .unwrap(); Self::slice(self, ranges) } /// Attempts to split the tensor into a specified number of chunks along a given dimension. /// May return less chunks than requested if the tensor size is not divisible by the number of chunks. /// /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size. /// Otherwise all chunks will be of equal size except for the last one. /// /// # Panics /// /// If the dimension is greater than the number of dimensions of the tensor. /// /// # Returns /// A vector of tensors. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [4, 3] /// let tensor = Tensor::::from_data( /// [ /// [3.0, 4.9, 2.0], /// [2.0, 1.9, 3.0], /// [6.0, 1.5, 7.0], /// [3.0, 4.9, 9.0], /// ], /// &device, /// ); /// // Split the tensor along the dimension 1 into 2 chunks. /// // The first chuck will have shape [4, 2]: /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]] /// // The second chunk will have shape [4, 1]: /// // [[2.0], [3.0], [7.0], [9.0]] /// let chunks = tensor.chunk(2, 1); /// println!("{chunks:?}"); /// } /// ``` pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); let size = self.shape()[dim]; if size < chunks { return (0..size) .map(|i| Self::narrow(self.clone(), dim, i, 1)) .collect(); } let mut tensors = Vec::with_capacity(chunks); let mut sum_chunk_size = 0; if size.is_multiple_of(chunks) { let chunk_size = size / chunks; for _ in 0..chunks { tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size)); sum_chunk_size += chunk_size; } } else { let chunk_size = (size / chunks) + 1; // assumes not divisible for _ in 0..chunks - 1 { tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size)); sum_chunk_size += chunk_size; } let remainder = size % chunk_size; tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder)); } tensors } /// Splits the tensor into chunks of a specified size along a given dimension. /// Each chunk is a view of the original tensor. /// /// If the tensor size along the given dimension is not divisible by `split_size`, /// then the last chunk will be smaller. /// /// # Panics /// /// If the specified dimension to split along is greater than the number of dimensions of the tensor. /// /// # Returns /// /// A vector of tensors. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 1D tensor with 5 elements /// let tensor = Tensor::::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device); /// // Split the tensor into chunks of size 2 along dimension 0 /// let chunks = tensor.split(2, 0); /// // The result is a vector of tensors: /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])] /// println!("{:?}", chunks); /// } /// ``` pub fn split(self, split_size: usize, dim: usize) -> Vec { check!(TensorCheck::split::(&self.shape(), split_size, dim)); let size = self.shape()[dim]; let mut tensors = Vec::new(); let mut start = 0; while start < size { let length = usize::min(split_size, size - start); tensors.push(Self::narrow(self.clone(), dim, start, length)); start += length; } tensors } /// Splits the tensor into chunks with the specified sizes along a given dimension. /// Each chunk is a view of the original tensor. /// /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes /// in `split_sizes` must equal the size of the tensor along the specified dimension. /// /// # Panics /// /// If the specified dimension to split along is greater than the number of dimensions of the tensor or /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`. /// /// # Returns /// /// A vector of tensors. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 1D tensor with 5 elements /// let tensor = Tensor::::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device); /// // Split the tensor into chunks with sizes [2, 3] along dimension 0 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0); /// // The result is a vector of tensors: /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])] /// println!("{:?}", chunks); /// } /// ``` pub fn split_with_sizes(self, split_sizes: Vec, dim: usize) -> Vec { check!(TensorCheck::split_with_sizes::( &self.shape(), &split_sizes, dim )); let mut tensors = Vec::new(); let mut start = 0; for length in split_sizes { if length == 0 { continue; } tensors.push(Self::narrow(self.clone(), dim, start, length)); start += length; } tensors } /// Tests if any element in the `tensor` evaluates to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported. /// /// # Returns /// /// A boolean tensor `Tensor` containing a single element, True if any element in the input tensor /// evaluates to True, False otherwise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_data([[true,false,true],[false,true,false]], &device); /// let tensor_two = Tensor::::from_data([[false,false,false],[false,false,false]], &device); /// /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True. /// let any_tensor = tensor.any(); /// println!("{}", any_tensor); /// // Tensor { data: [true], ... } /// /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True. /// let any_tensor_two = tensor_two.any(); /// println!("{}", any_tensor_two); /// // Tensor { data: [false], ... } /// } /// ``` pub fn any(self) -> Tensor { Tensor::new(K::any(self.primitive)) } /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same shape as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = /// Tensor::::from_data([[true, false, false], [false, true, false]], &device); /// // Check if any element in the tensor evaluates to True along the dimension 1. /// // [[true], [true]], /// let any_dim = tensor.clone().any_dim(1); /// println!("{any_dim}"); /// } /// ``` pub fn any_dim(self, dim: usize) -> Tensor { Tensor::new(K::any_dim(self.primitive, dim)) } /// Tests if all elements in the `tensor` evaluate to True. /// /// # Arguments /// /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported. /// /// # Returns /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = /// Tensor::::from_data([[true, false, true], [true, true, true]], &device); /// // Check if all elements in the tensor evaluate to True (which is not the case). /// // [false] /// let all = tensor.all(); /// println!("{all}"); /// } /// ``` pub fn all(self) -> Tensor { Tensor::new(K::all(self.primitive)) } /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// /// # Arguments /// /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported. /// * `dim` - The axis along which to test. /// /// # Returns /// /// A boolean tensor `Tensor` with the same shape as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = /// Tensor::::from_data([[true, true, false], [true, true, true]], &device); /// // Check if all elements in the tensor evaluate to True along the dimension 1. /// // [[true, true, false]] /// let all_dim = tensor.clone().all_dim(0); /// println!("{all_dim}"); /// } /// ``` pub fn all_dim(self, dim: usize) -> Tensor { Tensor::new(K::all_dim(self.primitive, dim)) } /// Convert the tensor into a scalar. /// /// # Panics /// /// - If the tensor doesn't have one element. /// - If the backend fails to read the tensor data synchronously. /// /// # Returns /// /// The scalar value of the tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_data([[3.0]], &device); /// // Convert the tensor with a single element into a scalar. /// let scalar = tensor.into_scalar(); /// println!("{scalar}"); /// } /// ``` pub fn into_scalar(self) -> K::Elem { crate::try_read_sync(self.into_scalar_async()) .expect( "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. Try into_scalar_async instead.", ) .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime") } /// Convert the tensor into a scalar and returns any error that might have occurred since the /// last time the device was synchronized. /// /// # Panics /// /// - If the tensor doesn't have one element. /// - If the backend fails to read the tensor data synchronously. /// /// # Returns /// /// The scalar value of the tensor. pub fn try_into_scalar(self) -> Result { crate::try_read_sync(self.into_scalar_async()).expect( "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. Try into_scalar_async instead.", ) } /// Convert the tensor into a scalar. /// /// # Panics /// /// If the tensor doesn't have one element. pub async fn into_scalar_async(self) -> Result { check!(TensorCheck::into_scalar::(&self.shape())); Ok(self.into_data_async().await?.iter().next().unwrap()) } /// Broadcast the tensor to the given shape. /// /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size /// (which can be inferred with `-1`). /// /// # Arguments /// /// * `shape` - The shape to broadcast the tensor to. /// Can contain -1 for dimensions that should be inferred. /// The number of elements in the shape must be greater or equal as /// the number of dimensions of the tensor. /// /// # Panics /// /// If the tensor cannot be broadcasted to the given shape. /// /// # Returns /// /// A new tensor with the given shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// // Create a 2D tensor with dimensions [3, 1] /// let tensor = Tensor::::from_data([[1.], [2.], [3.]], &device); /// // Expand the tensor to a new shape [3, 4] /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]] /// let expanded = tensor.expand([3, 4]); /// println!("{}", expanded); /// } /// ``` pub fn expand>(self, shape: S) -> Tensor { let shape = shape.into_shape(&self.shape()); check!(TensorCheck::expand::( "expand", &self.shape(), &shape, )); Tensor::::new(K::expand(self.primitive, shape)) } /// Unfold windows along a dimension. /// /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`; /// where windows are advanced by `step` at each index. /// /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. /// /// The new view will have the unfolded dimension replaced by two dimensions; /// one in the position of the original dimension, with size equal to the number of windows, /// and one appended to the right-most position, with size equal to `size`. /// /// # Warning /// /// For the `ndarray` backend; this is not a view but a copy /// with duplicated data. /// /// # Arguments /// /// * `dim` - the dimension to unfold. /// * `size` - the size of each unfolded window. /// * `step` - the step between each window. /// /// # Returns /// /// A tensor view with the shape ``[pre=..., windows, post=..., size]``. pub fn unfold( self, dim: I, size: usize, step: usize, ) -> Tensor { let dim = dim.expect_dim_index(D); check!(TensorCheck::unfold::( "unfold", &self.shape(), dim, size, step, )); Tensor::::new(K::unfold(self.primitive, dim, size, step)) } } /// Iterator given by (Tensor::iter_dim). pub struct DimIter where B: Backend, K: BasicOps, { start: usize, end: usize, dim: usize, ranges: [Range; D], tensor: Tensor, } impl> Iterator for DimIter { type Item = Tensor; fn next(&mut self) -> Option { if self.start >= self.end { return None; } let mut ranges = self.ranges.clone(); ranges[self.dim] = self.start..(self.start + 1); let slice = self.tensor.clone().slice(ranges); self.start += 1; Some(slice) } } impl> DoubleEndedIterator for DimIter { fn next_back(&mut self) -> Option { if self.start >= self.end { return None; } let mut ranges = self.ranges.clone(); ranges[self.dim] = (self.end - 1)..self.end; let slice = self.tensor.clone().slice(ranges); self.end = self.end.saturating_sub(1); Some(slice) } } impl> DimIter { fn new(tensor: Tensor, dim: usize) -> Self { let dims = tensor.dims(); let ranges = dims .iter() .map(|&dim| 0..dim) .collect::>>(); let ranges: [Range; D] = ranges.try_into().unwrap(); Self { end: dims[dim], ranges, start: 0, dim, tensor, } } } impl Tensor where B: Backend, K: BasicOps, >::Elem: Debug, { #[inline] fn push_newline_indent(acc: &mut String, indent: usize) { acc.push('\n'); for _ in 0..indent { acc.push(' '); } } fn fmt_inner_tensor( &self, acc: &mut String, depth: usize, multi_index: &mut [usize], range: (usize, usize), precision: Option, ) { let (start, end) = range; for i in start..end { if i > 0 { acc.push_str(", "); } multi_index[depth] = i; let range: [Range; D] = core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1); let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async()); if let Some(Ok(data)) = data { let elem = data.iter::<>::Elem>().next().unwrap(); match (precision, K::name()) { (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")), (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())), _ => acc.push_str(&format!("{elem:?}")), } } else { acc.push_str(""); } } } fn fmt_outer_tensor( &self, acc: &mut String, depth: usize, multi_index: &mut [usize], print_options: &PrintOptions, summarize: bool, range: (usize, usize), ) { let (start, end) = range; for i in start..end { if i > start { acc.push(','); Self::push_newline_indent(acc, depth + 1); } acc.push('['); multi_index[depth] = i; self.display_recursive(acc, depth + 1, multi_index, print_options, summarize); acc.push(']'); } } /// Recursively formats the tensor data for display and appends it to the provided accumulator string. /// /// This function is designed to work with tensors of any dimensionality. /// It traverses the tensor dimensions recursively, converting the elements /// to strings and appending them to the accumulator string with the /// appropriate formatting. /// /// # Arguments /// /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output. /// * `depth` - The current depth of the tensor dimensions being processed. /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension. fn display_recursive( &self, acc: &mut String, depth: usize, multi_index: &mut [usize], print_options: &PrintOptions, summarize: bool, ) { let edge_items = print_options.edge_items; if depth == 0 { acc.push('['); } if depth == self.dims().len() - 1 { // if we are at the innermost dimension, just push its elements into the accumulator if summarize && self.dims()[depth] > 2 * edge_items { // print the starting `edge_items` elements self.fmt_inner_tensor( acc, depth, multi_index, (0, edge_items), print_options.precision, ); acc.push_str(", ..."); // print the last `edge_items` elements self.fmt_inner_tensor( acc, depth, multi_index, (self.dims()[depth] - edge_items, self.dims()[depth]), print_options.precision, ); } else { // print all the elements self.fmt_inner_tensor( acc, depth, multi_index, (0, self.dims()[depth]), print_options.precision, ); } } else { // otherwise, iterate through the current dimension and recursively display the inner tensors if summarize && self.dims()[depth] > 2 * edge_items { self.fmt_outer_tensor( acc, depth, multi_index, print_options, summarize, (0, edge_items), ); acc.push(','); Self::push_newline_indent(acc, depth + 1); acc.push_str("..."); Self::push_newline_indent(acc, depth + 1); self.fmt_outer_tensor( acc, depth, multi_index, print_options, summarize, (self.dims()[depth] - edge_items, self.dims()[depth]), ); } else { self.fmt_outer_tensor( acc, depth, multi_index, print_options, summarize, (0, self.dims()[depth]), ); } } if depth == 0 { acc.push(']'); } } } #[derive(Clone, Debug)] /// Options for Tensor pretty printing pub struct PrintOptions { /// number of elements to start summarizing tensor pub threshold: usize, /// number of starting elements and ending elements to display pub edge_items: usize, /// Precision for floating point numbers pub precision: Option, } static PRINT_OPTS: RwLock = RwLock::new(PrintOptions::const_default()); impl PrintOptions { /// Print options with default values pub const fn const_default() -> Self { Self { threshold: 1000, edge_items: 3, precision: None, } } } impl Default for PrintOptions { fn default() -> Self { Self::const_default() } } /// Set print options pub fn set_print_options(options: PrintOptions) { let mut print_opts = PRINT_OPTS.write().unwrap(); *print_opts = options; } /// Pretty print tensors impl core::fmt::Display for Tensor where B: Backend, B::IntElem: core::fmt::Display, K: BasicOps, >::Elem: Debug, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { writeln!(f, "Tensor {{")?; { // Do not lock the mutex for the whole function let mut po = { PRINT_OPTS.read().unwrap().clone() }; // Override the precision if it is set from the formatter // This will be possible when the tensor is printed using the `{:.*}` syntax if let Some(precision) = f.precision() { po.precision = Some(precision); } let mut acc = String::new(); let mut multi_index = vec![0; D]; let summarize = self.shape().num_elements() > po.threshold; self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize); writeln!(f, " data:")?; write!(f, "{acc}")?; writeln!(f, ",")?; } writeln!(f, " shape: {:?},", self.dims())?; writeln!(f, " device: {:?},", self.device())?; writeln!(f, " backend: {:?},", B::name(&self.device()))?; writeln!(f, " kind: {:?},", K::name())?; let dtype = self.primitive.dtype(); writeln!(f, " dtype: {:?},", dtype.name())?; write!(f, "}}") } } /// Trait used for movedim arguments pub trait MovedimArgs { /// Converts into a set of dimensions `Vec` for the `tensor.movedim()` function fn into_dim_vec(self) -> Vec; } impl MovedimArgs for Vec { fn into_dim_vec(self) -> Vec { let set = self .iter() .map(|&dim| { if dim < 0 { (D as i32 + dim) as usize } else { dim as usize } }) .collect::>(); check!(TensorCheck::movedim_args_vec::(&set)); set } } impl MovedimArgs for Vec { fn into_dim_vec(self) -> Vec { check!(TensorCheck::movedim_args_vec::(&self)); self } } impl MovedimArgs for usize { #[allow(clippy::vec_init_then_push)] fn into_dim_vec(self) -> Vec { check!(TensorCheck::movedim_args_usize::(self)); let mut set = Vec::with_capacity(1); set.push(self); set } } impl MovedimArgs for i32 { #[allow(clippy::vec_init_then_push)] fn into_dim_vec(self) -> Vec { check!(TensorCheck::movedim_args_i32::(self)); let dim = if self < 0 { (D as i32 + self) as usize } else { self as usize }; let mut set = Vec::with_capacity(1); set.push(dim); set } } /// Trait used for reshape arguments. pub trait ReshapeArgs: Debug { /// Converts to a shape. fn into_shape(self, source: Shape) -> Shape; } impl ReshapeArgs for [I; D2] { fn into_shape(self, source: Shape) -> Shape { unwrap_shape_reshape(source.reshape(self)) } } impl ReshapeArgs for Shape { fn into_shape(self, source: Shape) -> Shape { unwrap_shape_reshape(source.reshape(self)) } } /// Trait used for broadcast arguments. pub trait BroadcastArgs { /// Converts to a shape. fn into_shape(self, shape: &Shape) -> Shape; } impl BroadcastArgs for Shape { fn into_shape(self, _shape: &Shape) -> Shape { self } } impl BroadcastArgs for [E; D2] { // Passing -1 as the size for a dimension means not changing the size of that dimension. fn into_shape(self, shape: &Shape) -> Shape { if self.len() < shape.num_dims() { panic!("Broadcast arguments must be greater than the number of dimensions"); } // Zip the two shapes in reverse order and replace -1 with the actual dimension value. let new_shape: Vec<_> = self .iter() .rev() .map(|x| { let primitive = x.as_index(); if primitive < -1 || primitive == 0 { panic!("Broadcast arguments must be positive or -1"); } primitive }) .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s .map(|(x, &y)| if x == -1 { y } else { x as usize }) .collect::>() .into_iter() .rev() .collect(); if new_shape.contains(&0) { panic!("Cannot substitute -1 for a non-existing dimension"); } let new_shape: [usize; D2] = new_shape.try_into().unwrap(); Shape::from(new_shape) } } impl Serialize for Tensor where B: Backend, K: BasicOps, K::Elem: Debug + Copy + Serialize, { fn serialize(&self, serializer: S) -> Result { let data = self.to_data(); data.serialize(serializer) } } impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor where B: Backend, K: BasicOps, K::Elem: Debug + Copy + Deserialize<'de>, { fn deserialize>(deserializer: De) -> Result { let tensor = Tensor::from_data( TensorData::deserialize(deserializer)?, &::default(), ); Ok(tensor) } } #[cfg(test)] mod tests { use burn_std::SliceOps; use crate::{Shape, s}; #[test] fn slice_range_single_dim_leading() { let shape = Shape::new([8, 4]); // Half-open range let slices = shape.clone().into_slices([0..5]); assert_eq!(slices[0].to_range(8), 0..5); let slices = shape.clone().into_slices([-3..-1]); assert_eq!(slices[0].to_range(8), 5..7); // Inclusive range let slices = shape.clone().into_slices([0..=4]); assert_eq!(slices[0].to_range(8), 0..5); let slices = shape.clone().into_slices([-2..=-1]); assert_eq!(slices[0].to_range(8), 6..8); // Unbounded start let slices = shape.clone().into_slices([..3]); assert_eq!(slices[0].to_range(8), 0..3); let slices = shape.clone().into_slices([..-5]); assert_eq!(slices[0].to_range(8), 0..3); // Unbounded end let slices = shape.clone().into_slices([5..]); assert_eq!(slices[0].to_range(8), 5..8); let slices = shape.clone().into_slices([-3..]); assert_eq!(slices[0].to_range(8), 5..8); // Full range let slices = shape.into_slices([..]); assert_eq!(slices[0].to_range(8), 0..8); } #[test] fn test_negative_slice_indices() { use crate::Slice; // Test negative indices conversion let slice: Slice = (-3..-1).into(); assert_eq!(slice.start, -3); assert_eq!(slice.end, Some(-1)); // Test to_range conversion with size 8 let range = slice.to_range(8); assert_eq!(range, 5..7); // Test with shape slice let shape = Shape::new([8, 4]); let result = shape.clone().into_slices([-3..-1]); assert_eq!(result[0].to_range(8), 5..7); // Test more negative index cases let slice2: Slice = (-5..).into(); assert_eq!(slice2.to_range(10), 5..10); let slice3: Slice = (..-2).into(); assert_eq!(slice3.to_range(10), 0..8); // Test with s! macro - single dimension returns Slice directly let slice4 = s![-3..-1]; assert_eq!(slice4.start, -3); assert_eq!(slice4.end, Some(-1)); } #[test] fn slice_range_multi_dim() { let shape = Shape::new([8, 4]); // Multiple ways to provide ranges let slices = shape.clone().into_slices([0..5, 0..4]); assert_eq!(slices[0].to_range(8), 0..5); assert_eq!(slices[1].to_range(4), 0..4); let slices = shape.clone().into_slices([0.., 0..]); assert_eq!(slices[0].to_range(8), 0..8); assert_eq!(slices[1].to_range(4), 0..4); let slices = shape.clone().into_slices([0..=7, 0..=3]); assert_eq!(slices[0].to_range(8), 0..8); assert_eq!(slices[1].to_range(4), 0..4); let slices = shape.clone().into_slices([0..5, 0..3]); assert_eq!(slices[0].to_range(8), 0..5); assert_eq!(slices[1].to_range(4), 0..3); let slices = shape.into_slices([0.., 0..]); assert_eq!(slices[0].to_range(8), 0..8); assert_eq!(slices[1].to_range(4), 0..4); } #[test] fn slice_range_multi_dim_index() { let shape = Shape::new([8, 4]); // Indices (single integer) should also convert to correct range let slices = shape.clone().into_slices([0, 2]); assert_eq!(slices[0].to_range(8), 0..1); assert_eq!(slices[1].to_range(4), 2..3); let slices = shape.into_slices([-1, -1]); assert_eq!(slices[0].to_range(8), 7..8); assert_eq!(slices[1].to_range(4), 3..4); } #[test] fn slice_range_multi_dim_heterogeneous() { // Slice macro `s![]` can be used to provide different range types let shape = Shape::new([8, 4, 2]); let slice = s![0..5, .., -1]; let slices = shape.into_slices(slice); assert_eq!(slices[0].to_range(8), 0..5); assert_eq!(slices[1].to_range(4), 0..4); assert_eq!(slices[2].to_range(2), 1..2); let shape = Shape::new([8, 4, 2, 3]); let slice = s![..=4, 0..=3, .., -2..]; let slices = shape.into_slices(slice); assert_eq!(slices[0].to_range(8), 0..5); assert_eq!(slices[1].to_range(4), 0..4); assert_eq!(slices[2].to_range(2), 0..2); assert_eq!(slices[3].to_range(3), 1..3); let shape = Shape::new([3, 4]); let slice = s![1..-1, ..]; let slices = shape.into_slices(slice); assert_eq!(slices[0].to_range(3), 1..2); assert_eq!(slices[1].to_range(4), 0..4); } } ================================================ FILE: crates/burn-tensor/src/tensor/api/bool.rs ================================================ use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend}; use alloc::{vec, vec::Vec}; use crate::try_read_sync; /// The part of the tensor to keep when creating a triangular mask. enum TriPart { /// Upper triangular part. Upper, /// Lower triangular part. Lower, /// Diagonal part. Diagonal, } impl Tensor where B: Backend, { /// Create a boolean tensor from data on the given device. /// /// # Arguments /// /// * `data` - The tensor data. /// * `device` - The device on which the tensor will be allocated. /// /// # Returns /// /// A boolean tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_bool([[true, false], [false, true]].into(), &device); /// println!("{tensor}"); /// } /// ``` pub fn from_bool(data: TensorData, device: &B::Device) -> Self { Self::new(B::bool_from_data(data.convert::(), device)) } /// Convert the bool tensor into an int tensor. /// /// # Returns /// /// An integer tensor where `true` is converted to `1` and `false` to `0`. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let bool_tensor = Tensor::::from_bool([true, false, true].into(), &device); /// let int_tensor = bool_tensor.int(); /// println!("{int_tensor}"); // [1, 0, 1] /// } /// ``` pub fn int(self) -> Tensor { Tensor::new(B::bool_into_int(self.primitive)) } /// Convert the bool tensor into a float tensor. /// /// # Returns /// /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let bool_tensor = Tensor::::from_bool([true, false, true].into(), &device); /// let float_tensor = bool_tensor.float(); /// println!("{float_tensor}"); // [1.0, 0.0, 1.0] /// } /// ``` pub fn float(self) -> Tensor { Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive))) } /// Inverses boolean values. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_bool([[true, false], [false, true]].into(), &device); /// let inverted = tensor.bool_not(); /// println!("{inverted}"); // [[false, true], [true, false]] /// } /// ``` pub fn bool_not(self) -> Self { Tensor::new(B::bool_not(self.primitive)) } /// Performs logical and (`&&`) on two boolean tensors. /// /// # Arguments /// /// * `rhs` - The right-hand side tensor for the AND operation. /// /// # Returns /// /// A boolean tensor where each element is the result of `self[i] && rhs[i]`. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let a = Tensor::::from_bool([[true, true], [false, false]].into(), &device); /// let b = Tensor::::from_bool([[true, false], [true, false]].into(), &device); /// let result = a.bool_and(b); /// println!("{result}"); // [[true, false], [false, false]] /// } /// ``` pub fn bool_and(self, rhs: Tensor) -> Tensor { Tensor::new(B::bool_and(self.primitive, rhs.primitive)) } /// Performs logical or (`||`) on two boolean tensors. /// /// # Arguments /// /// * `rhs` - The right-hand side tensor for the OR operation. /// /// # Returns /// /// A boolean tensor where each element is the result of `self[i] || rhs[i]`. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let a = Tensor::::from_bool([[true, true], [false, false]].into(), &device); /// let b = Tensor::::from_bool([[true, false], [true, false]].into(), &device); /// let result = a.bool_or(b); /// println!("{result}"); // [[true, true], [true, false]] /// } /// ``` pub fn bool_or(self, rhs: Tensor) -> Tensor { Tensor::new(B::bool_or(self.primitive, rhs.primitive)) } /// Performs logical xor (`^`) on two boolean tensors. /// /// # Arguments /// /// * `rhs` - The right-hand side tensor for the XOR operation. /// /// # Returns /// /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`. /// Returns `true` when exactly one of the operands is `true`. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let a = Tensor::::from_bool([[true, true], [false, false]].into(), &device); /// let b = Tensor::::from_bool([[true, false], [true, false]].into(), &device); /// let result = a.bool_xor(b); /// println!("{result}"); // [[false, true], [true, false]] /// } /// ``` pub fn bool_xor(self, rhs: Tensor) -> Tensor { Tensor::new(B::bool_xor(self.primitive, rhs.primitive)) } /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors). /// /// # Returns /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_bool( /// [[true, false, true], [false, true, false], [false, true, false]].into(), /// &device, /// ); /// let indices = tensor.nonzero(); /// println!("{}", indices[0]); // [0, 0, 1, 2] /// println!("{}", indices[1]); // [0, 2, 1, 1] /// } /// ``` pub fn nonzero(self) -> Vec> { try_read_sync(self.nonzero_async()) .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.") } /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors). /// /// # Returns /// /// A vector of tensors, one for each dimension of the given tensor, containing the indices of /// the non-zero elements in that dimension. pub async fn nonzero_async(self) -> Vec> { let indices = self.argwhere_async().await; if indices.shape().num_elements() == 0 { // Return empty vec when all elements are zero return vec![]; } let dims = indices.shape(); indices .chunk(dims[1], 1) .into_iter() .map(|t| t.reshape(Shape::new([dims[0]]))) .collect() } /// Compute the indices of the elements that are true, grouped by element. /// /// # Returns /// /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the /// result contains the indices of a non-zero element. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_bool( /// [[true, false, true], [false, true, false], [false, true, false]].into(), /// &device, /// ); /// let indices = tensor.argwhere(); /// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]] /// } /// ``` pub fn argwhere(self) -> Tensor { try_read_sync(self.argwhere_async()) .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.") } /// Compute the indices of the elements that are true, grouped by element. /// /// # Returns /// /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the /// result contains the indices of a non-zero element. pub async fn argwhere_async(self) -> Tensor { Tensor::new(B::bool_argwhere(self.primitive).await) } /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to /// fill the specified area with a value. fn tri_mask>( shape: S, tri_part: TriPart, offset: i64, device: &B::Device, ) -> Self { let shape: Shape = shape.into(); let height = shape[D - 2]; let width = shape[D - 1]; // Generate row and column index tensors. let row_indices: Tensor = Tensor::arange(0..height as i64, device); let col_indices: Tensor = Tensor::arange(0..width as i64, device); // Prepare shapes for broadcasting. let mut row_shape = [1; D]; row_shape[D - 2] = height; let mut col_shape = [1; D]; col_shape[D - 1] = width; // Reshape for broadcasting. let row_broadcast: Tensor = row_indices.reshape(Shape::new(row_shape)); let col_broadcast = col_indices.reshape(Shape::new(col_shape)); // Broadcasting trick to create a matrix that facilitates comparison for mask generation. let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset); // Select the appropriate comparison function based on `tri_part`. let compare = match tri_part { TriPart::Upper => Tensor::greater_elem, TriPart::Lower => Tensor::lower_elem, TriPart::Diagonal => Tensor::not_equal_elem, }; // Generate and return the mask by applying the comparison to the matrix. compare(matrix, 0).unsqueeze() } /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified /// area with a value. /// /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix. /// /// # Arguments /// /// * `shape`: The shape of the matrix. /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift /// towards the upper triangle. /// * `device`: The device on which the tensor will be allocated. /// /// # Returns /// /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the /// upper triangle taking into account the specified `offset`. All other elements are `true`. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let mask = Tensor::::triu_mask([3, 3], 0, &Default::default()); /// println!("{mask}"); /// // [[false, false, false], /// // [true, false, false], /// // [true, true, false]] /// } /// ``` pub fn triu_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Upper, offset, device) } /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified /// area with a value. /// /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix. /// /// # Arguments /// /// * `shape`: The shape of the matrix. /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift /// towards the lower triangle. /// * `device`: The device on which the tensor will be allocated. /// /// # Returns /// /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the /// lower triangle taking into account the specified `offset`. All other elements are `true`. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let mask = Tensor::::tril_mask([3, 3], 0, &Default::default()); /// println!("{mask}"); /// // [[false, true, true], /// // [false, false, true], /// // [false, false, false]] /// } /// ``` pub fn tril_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Lower, offset, device) } /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified /// area with a value. /// /// This function generates a boolean tensor representing the mask of the diagonal of a matrix. /// /// # Arguments /// /// * `shape`: The shape of the matrix. /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift /// towards the upper triangle. /// * `device`: The device on which the tensor will be allocated. /// /// # Returns /// /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the /// diagonal. All other elements are `true`. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool}; /// /// fn example() { /// let mask = Tensor::::diag_mask([3, 3], 0, &Default::default()); /// println!("{mask}"); /// // [[false, true, true], /// // [true, false, true], /// // [true, true, false]] /// } /// ``` pub fn diag_mask>(shape: S, offset: i64, device: &B::Device) -> Self { Self::tri_mask(shape, TriPart::Diagonal, offset, device) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/cartesian_grid.rs ================================================ use crate::{Int, Shape, Tensor, backend::Backend}; use alloc::vec::Vec; /// Generates a cartesian grid for the given tensor shape on the specified device. /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element. /// /// # Arguments /// /// * `shape` - The shape specifying the dimensions of the tensor. /// * `device` - The device to create the tensor on. /// /// # Panics /// /// Panics if `D2` is not equal to `D+1`. /// /// # Examples /// /// ```rust /// use burn_tensor::Int; /// use burn_tensor::{backend::Backend, Shape, Tensor}; /// fn example() { /// let device = Default::default(); /// let result: Tensor = Tensor::::cartesian_grid([2, 3], &device); /// println!("{}", result); /// } /// ``` pub fn cartesian_grid, const D: usize, const D2: usize>( shape: S, device: &B::Device, ) -> Tensor { if D2 != D + 1 { panic!("D2 must equal D + 1 for Tensor::cartesian_grid") } let dims = shape.into(); let mut indices: Vec> = Vec::new(); for dim in 0..D { let dim_range: Tensor = Tensor::arange(0..dims[dim] as i64, device); let mut shape = [1; D]; shape[dim] = dims[dim]; let mut dim_range = dim_range.reshape(shape); for (i, &item) in dims.iter().enumerate() { if i == dim { continue; } dim_range = dim_range.repeat_dim(i, item); } indices.push(dim_range); } Tensor::stack::(indices, D) } ================================================ FILE: crates/burn-tensor/src/tensor/api/check.rs ================================================ use crate::ops::FloatElem; use crate::{BasicOps, Shape, Slice, Tensor, backend::Backend, cast::ToElement}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; use alloc::vec::Vec; use burn_backend::tensor::Ordered; /// The struct should always be used with the [check](crate::check) macro. /// /// This is a simple pub(crate) data structure that efficiently checks tensor operations and /// formats clear error messages. It's crucial that the checks are really fast, but it doesn't matter /// when a failed check is discovered since the program will panic. /// /// # Notes /// /// Failing tensor checks will always result in a panic. /// As mentioned in [The Rust Programming Language book](https://doc.rust-lang.org/book/ch09-03-to-panic-or-not-to-panic.html), /// when there is no way to recover, panic should be used instead of a result. /// /// Most users will unwrap the results anyway, which will worsen the clarity of the code. Almost /// all checks highlight programming errors, which means invalid programs that should be fixed. /// Checks are not the ideal way to help users write correct programs, but they are still better /// than backend errors. Other forms of compile-time validation could be developed, such as named /// tensors, but we have to carefully evaluate the ease of use of the Tensor API. Adding overly /// complex type validation checks might drastically worsen the API and result in harder-to-maintain /// programs. /// /// # Design /// /// Maybe the Backend API should return a result for each operation, which would allow handling /// all checks, even the ones that can't be efficiently checked before performing an operation, /// such as the `index_select` operation. The downside of that approach is that all backend /// implementation might re-implement the same checks, which may result in unnecessary code /// duplication. Maybe a combination of both strategies could help to cover all use cases. pub(crate) enum TensorCheck { Ok, Failed(FailedTensorCheck), } impl TensorCheck { /// Checks device and shape compatibility for element wise binary operations. pub(crate) fn binary_ops_ew>( ops: &str, lhs: &Tensor, rhs: &Tensor, ) -> Self { Self::Ok .binary_ops_device(ops, &lhs.device(), &rhs.device()) .binary_ops_ew_shape::(ops, &lhs.shape(), &rhs.shape()) } pub(crate) fn into_scalar(shape: &Shape) -> Self { let mut check = Self::Ok; if shape.num_elements() != 1 { check = check.register( "Into Scalar", TensorError::new("Only tensors with 1 element can be converted into scalar.") .details(format!( "Current tensor has {} elements", shape.num_elements() )), ); } check } pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { let mut check = Self::Ok; if dim >= D { check = check.register( ops, TensorError::new("Given dimension is higher than the tensor rank.") .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), ); } check } pub(crate) fn creation_ops(ops: &str, dims: &[usize]) -> Self { let mut check = Self::Ok; if D == 0 { check = check.register( ops, TensorError::new("Tried to create a 0-dim tensor, which is invalid.") .details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")), ); } if dims.len() != D { check = check.register( ops, TensorError::new("Given dimensions differ from the tensor rank.") .details(format!("Tensor rank: '{D}', given dimensions: '{dims:?}'.")), ); } check } pub(crate) fn narrow>( tensor: &Tensor, dim: usize, start: usize, length: usize, ) -> Self { let mut check = Self::Ok; if length == 0 { check = check.register( "Narrow", TensorError::new(format!( "Can't narrow at dimension {dim}, length must be greater than 0", )), ); } if start >= tensor.shape()[dim] { check = check.register( "Narrow", TensorError::new(format!( "Can't narrow at dimension {dim}, start exceeds the size of the tensor along \ this dimension (Size={})", tensor.shape()[dim] )), ); } if start + length > tensor.shape()[dim] { check = check.register( "Narrow", TensorError::new(format!( "Can't narrow at dimension {dim}, start + length exceeds the size of the tensor \ along this dimension (Size={})", tensor.shape()[dim] )), ); } check } pub(crate) fn movedim_args_usize(dim: usize) -> Self { let mut check = Self::Ok; if dim >= D { check = check.register( "Movedim", TensorError::new( "The given dimension exceeds the number of dimensions of the current tensor.", ) .details(format!( "Current tensor has {D} dimensions, but the given dimension is {dim}.", )), ); } check } pub(crate) fn movedim_args_i32(dim: i32) -> Self { let mut check = Self::Ok; if dim < -(D as i32) || dim >= D as i32 { check = check.register( "Movedim", TensorError::new( "The given dimension is out of bounds for the current tensor dimensions.", ) .details(format!( "Current tensor has {D} dimensions, but the given dimension is {dim}.", )), ); } check } pub(crate) fn movedim_args_vec(dims: &Vec) -> Self { let mut check = Self::Ok; // Check out of bounds if dims.iter().any(|&x| x >= D) { check = check.register( "Movedim", TensorError::new("The given dimensions are out of bounds.").details(format!( "Current tensor has {D} dimensions, but the given dimensions are {dims:?}.", )), ); } // Check there are no duplicates for (i, &dim_i) in dims.iter().enumerate() { for &dim_j in dims.iter().skip(i + 1) { if dim_i == dim_j { check = check.register( "Movedim", TensorError::new("The given dimensions contain duplicates.").details( format!( "The dimension {dim_i} is duplicated in the given dimensions {dims:?}.", ), ), ); } } } check } pub(crate) fn movedim_args_length( source_dims: &Vec, destination_dims: &Vec, ) -> Self { let mut check = Self::Ok; if source_dims.len() != destination_dims.len() { check = check.register( "Movedim", TensorError::new( "The number of dimensions in source and destination must be equal.", ) .details(format!( "Source dimensions: {source_dims:?}, Destination dimensions: {destination_dims:?}.", )), ) } check } pub(crate) fn flatten( start_dim: usize, end_dim: usize, ) -> Self { let mut check = Self::Ok; if start_dim > end_dim { check = check.register( "Flatten", TensorError::new(format!( "The start dim ({start_dim}) must be smaller than or equal to the end dim ({end_dim})" )), ); } if D2 > D1 { check = check.register( "Flatten", TensorError::new(format!( "Result dim ({D2}) must be smaller than or equal to ({D1})" )), ); } if D1 < end_dim + 1 { check = check.register( "Flatten", TensorError::new(format!( "The end dim ({end_dim}) must be smaller than the tensor dim ({D1})" )), ); } if (D2 as i32) < (D1 as i32 - (end_dim as i32 - start_dim as i32)) { check = check.register( "Flatten", TensorError::new(format!( "The destination dimension ({D2}) must be large enough to accommodate the \ flattening operation." )), ); } check } pub(crate) fn tri() -> Self { let mut check = Self::Ok; if D < 2 { check = check.register( "Tri", TensorError::new(format!( "The input tensor must have at least 2 dimensions, got {D}" )), ); } check } pub(crate) fn squeeze(dim: usize, tensor_dims: &[usize]) -> Self { let mut check = Self::Ok; // This should actually be to check that the dimension to squeeze // has a size of 1 if tensor_dims[dim] != 1 { check = check.register( "Squeeze", TensorError::new(format!( "Can't squeeze dimension {dim} because its size is not 1", )), ); } if dim >= tensor_dims.len() { check = check.register( "Squeeze", TensorError::new(format!( "Dimension index {dim} is out of bounds for tensor dimensions {tensor_dims:?}.", )), ); } check } pub(crate) fn squeeze_dims_input( dim_indices: &[usize], current_dims: &[usize], ) -> Self { let mut check = Self::Ok; if dim_indices.len() >= current_dims.len() { check = check.register( "Squeeze", TensorError::new("Attempted to squeeze too many dimensions!"), ); } check } pub(crate) fn squeeze_dims_len(new_dims_len: usize) -> Self { let mut check = Self::Ok; if new_dims_len == 0 { // 0-dim tensor not supported check = check.register( "Squeeze", TensorError::new( "Resulting dimensions cannot be zero. To remove specific singleton dimensions while preserving at least one, use `squeeze_dims` instead.".to_string() ), ); } if new_dims_len != D2 { check = check.register( "Squeeze", TensorError::new(format!( "Resulting dimensions {new_dims_len} do not match the required D2 size {D2}.", )), ); } check } pub(crate) fn unsqueeze() -> Self { let mut check = Self::Ok; if D2 < D1 { check = check.register( "Unsqueeze", TensorError::new(format!( "Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}", )), ); } check } pub(crate) fn unsqueeze_dim(dim: usize) -> Self { let mut check = Self::Ok; if D2 <= D1 { check = check.register( "Unsqueeze", TensorError::new(format!( "The unsqueezed rank must be greater than the input rank (D={D1}; D2={D2})", )), ); } if dim > D1 { check = check.register( "Unsqueeze", TensorError::new(format!( "Can't unsqueeze at dimension {dim}, exceeds tensor dimensions (D={D1})", )), ); } if dim >= D2 { check = check.register( "Unsqueeze", TensorError::new(format!( "Can't unsqueeze at dimension {dim}, exceeds output tensor dimensions (D2={D2})", )), ); } check } pub(crate) fn unsqueeze_dims(dim: isize) -> Self { let mut check = Self::Ok; let output_rank = D as isize; //contains is right exclusive, so this is to spec if !(-output_rank..output_rank).contains(&dim) { check = check.register( "Unsqueeze", TensorError::new(format!( "unsqueeze arg {dim} is out of range for the output tensor of rank {output_rank}", )), ); } check } pub(crate) fn one_hot_tensor>( index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; if index_tensor .clone() .greater_equal_elem(num_classes as i32) .any() .into_scalar() .to_bool() { check = check.register( "One Hot", TensorError::new(format!( "Can't create a one hot tensor from ({index_tensor:?}) containing indexes greater or equal to the number of classes ({num_classes})", )), ); } else if num_classes <= 1 { check = check.register( "One Hot", TensorError::new("Can't create a one hot tensor with less then 2 classes"), ) } check } pub(crate) fn one_hot_tensor_rank() -> Self { let mut check = Self::Ok; if D + 1 != D2 { check = check.register( "One Hot", TensorError::new( "The one-hot tensor rank must correspond to the rank of the tensor + 1", ) .details(format!("Expected D2={}, got {D2}", D + 1)), ); } check } pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; if dim1 > D || dim2 > D { check = check.register( "Swap Dims", TensorError::new("The swap dimensions must be smaller than the tensor dimension") .details(format!( "Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions." )), ); } check } pub(crate) fn permute(axes: [usize; D]) -> Self { let check = Self::Ok; // Check if the axes are within the tensor dimensions if let Some(axis) = axes.iter().find(|&x| *x >= D) { return check.register( "permute", TensorError::new("The axes must be smaller than the tensor dimension.") .details(format!("The '{axis}' axis is greater than {D} dimensions.")), ); } // Check if the axes are unique let mut seen = [false; D]; axes.iter().for_each(|&x| seen[x] = true); if seen.iter().any(|&x| !x) { return check.register( "permute", TensorError::new("The axes must be unique.") .details(format!("The axes '{axes:?}' are not unique.")), ); } check } pub(crate) fn flip(rank: usize, axes: &[usize]) -> Self { let check = Self::Ok; // Check if the axes are within the tensor dimensions if let Some(axis) = axes.iter().find(|&x| *x >= rank) { return check.register( "flip", TensorError::new("The axes must be smaller than the tensor dimension.").details( format!("The '{axis}' axis is greater than {rank} dimensions."), ), ); } // Check if the axes are unique let mut dedup = axes.to_vec(); dedup.sort_unstable(); dedup.dedup(); if dedup.len() != axes.len() { return check.register( "flip", TensorError::new("The axes must be unique.") .details(format!("The axes '{axes:?}' are not unique.")), ); } check } pub(crate) fn matmul( lhs: &Tensor, rhs: &Tensor, ) -> Self where K: BasicOps, { let mut check = Self::Ok; check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); if D < 2 { return check; } let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); let dim_lhs = shape_lhs[D - 1]; let dim_rhs = shape_rhs[D - 2]; if dim_lhs != dim_rhs { check = check.register( "Matmul", TensorError::new(format!( "The inner dimension of matmul should be the same, but got {dim_lhs} and \ {dim_rhs}." )) .details(format!( "Lhs shape {:?}, rhs shape {:?}.", shape_lhs, shape_rhs )), ); } check } pub(crate) fn cross( lhs: &Tensor, rhs: &Tensor, dim: usize, ) -> Self where K: BasicOps, { let mut check = Self::Ok; check = check.binary_ops_device("Cross", &lhs.device(), &rhs.device()); let shape_lhs = lhs.shape(); let shape_rhs = rhs.shape(); if dim >= D { check = check.register( "Cross", TensorError::new(format!( "Dimension {dim} is out of bounds for tensors with {D} dimensions." )), ); return check; } let dim_size_lhs = shape_lhs[dim]; let dim_size_rhs = shape_rhs[dim]; if dim_size_lhs != 3 || dim_size_rhs != 3 { check = check.register( "Cross", TensorError::new(format!( "Cross product requires dimension {dim} to have size 3, but got {dim_size_lhs} and {dim_size_rhs}." )), ); } // Check broadcastability of other dimensions for i in 0..D { if i != dim { let l = shape_lhs[i]; let r = shape_rhs[i]; if l != r && l != 1 && r != 1 { check = check.register( "Cross", TensorError::new(format!( "Tensors are not broadcastable along dimension {i}: {l} and {r}." )), ); } } } check } pub(crate) fn stack, const D2: usize>( tensors: &[Tensor], dim: usize, ) -> Self { let mut check = Self::Ok; if dim > D1 { check = check.register( "Stack", TensorError::new( "Can't stack tensors on a dim that exceeds the tensors dimension (inclusive)", ) .details(format!( "Trying to concatenate tensors with {D1} dimensions on axis {dim}." )), ); } if D1 == D2 { check = check.register( "Stack", TensorError::new(format!( "Can't stack tensors on existing dimension {dim}, the input and output ranks are the same (D={D1}; D2={D2}).\ If you want to concatenate the tensors along the specified dimension ({dim}), use `Tensor::cat` instead.", )), ); } if tensors.is_empty() { return check.register( "Stack", TensorError::new("Can't stack an empty list of tensors."), ); } let shape_reference = tensors.first().unwrap().shape(); for tensor in tensors { let shape = tensor.shape(); if shape_reference != shape { return check.register( "Stack", TensorError::new("Can't stack tensors with different shapes").details(format!( "Provided dimension ({dim}), tensors shapes: {:?}", tensors.iter().map(Tensor::shape).collect::>() )), ); } } check } pub(crate) fn cat>( tensors: &[Tensor], dim: usize, ) -> Self { let mut check = Self::Ok; if dim >= D { check = check.register( "Cat", TensorError::new( "Can't concatenate tensors on a dim that exceeds the tensors dimension", ) .details(format!( "Trying to concatenate tensors with {D} dimensions on axis {dim}." )), ); } if tensors.is_empty() { return check.register( "Cat", TensorError::new("Can't concatenate an empty list of tensors."), ); } let mut shape_reference = tensors.first().unwrap().shape(); shape_reference[dim] = 1; // We want to check every dims except the one where the // concatenation happens. for tensor in tensors { let mut shape = tensor.shape(); shape[dim] = 1; // Ignore the concatenate dim. if shape_reference != shape { return check.register( "Cat", TensorError::new( "Can't concatenate tensors with different shapes, except for the provided \ dimension", ) .details(format!( "Provided dimension ({dim}), tensors shapes: {:?}", tensors.iter().map(Tensor::shape).collect::>() )), ); } } check } pub(crate) fn slice(shape: &Shape, slices: &[Slice]) -> Self { let mut check = Self::Ok; let n_dims_tensor = R; let n_dims_slices = slices.len(); if n_dims_tensor < n_dims_slices { check = check.register( "Slice", TensorError::new( "The provided slices array has a higher number of dimensions than the current \ tensor.", ) .details(format!( "The slices array must be smaller or equal to the tensor number of \ dimensions. Tensor number of dimensions: {n_dims_tensor}, slices array \ length {n_dims_slices}." )), ); } for (i, slice) in slices.iter().enumerate().take(R) { let d_tensor = shape[i]; // Check the raw end value before conversion if let Some(end) = slice.end && end > 0 && end as usize > d_tensor { check = check.register( "Slice", TensorError::new( "The provided slice has a range that exceeds the current tensor \ size.", ) .details(format!( "The slice end index {} exceeds the size of the tensor ({}) at dimension {}. \ Tensor shape {:?}.", end, d_tensor, i, shape, )), ); } // Empty slices (start >= end) are allowed and produce a tensor with size 0 // in that dimension. This matches PyTorch behavior and is required for ONNX // compatibility where dynamic slice ranges may become empty at runtime. if slice.step() == 0 { check = check.register( "Slice", TensorError::new("The provided slice has a step of 0.").details(format!( "The slice at dimension '{i}' has a step of 0. Step must be non-zero.", )), ); } } check } pub(crate) fn slice_assign( shape: &Shape, shape_value: &Shape, slices: &[crate::Slice], ) -> Self { let mut check = Self::Ok; let n_dims_slices = slices.len(); if R < n_dims_slices { check = check.register( "Slice Assign", TensorError::new( "The provided slices array has a higher number of dimensions than the current \ tensor.", ) .details(format!( "The slices array must be smaller or equal to the tensor number of \ dimensions. Tensor number of dimensions: {R}, slices array length {n_dims_slices}." )), ); } for (i, slice) in slices.iter().enumerate().take(usize::min(R, n_dims_slices)) { let d_tensor = shape[i]; let d_tensor_value = shape_value[i]; let range = slice.to_range(d_tensor); if range.end > d_tensor { check = check.register( "Range Assign", TensorError::new( "The provided slice has a range that exceeds the current tensor \ size.", ) .details(format!( "The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \ Current tensor shape {:?}, value tensor shape {:?}.", range.start, range.end, d_tensor, i, shape, shape_value, )), ); } // Calculate the number of elements selected with the given step let num_elements = slice.output_size(d_tensor); if num_elements != d_tensor_value { check = check.register( "Slice Assign", TensorError::new( "The value tensor must match the amount of elements selected with the \ slices array", ) .details(format!( "The slice with range ({}..{}) and step {} selects {} elements but the value \ tensor has {} elements at dimension {}. Current tensor shape {:?}, value tensor \ shape {:?}.", range.start, range.end, slice.step, num_elements, d_tensor_value, i, shape, shape_value, )), ); } // Note: Empty slices (start >= end with positive step) are handled at the API level // by returning the original tensor unchanged, so we don't check for them here. } check } pub(crate) fn check_dim(dim: usize) -> Self { let mut check = Self::Ok; if dim >= D { check = check.register( "Check Dim", TensorError::new("The provided dimension exceeds the tensor dimensions.").details( format!("Tensor has {D} dimensions, but the provided dimension is {dim}."), ), ); } check } pub(crate) fn gather(dim: usize, shape: &Shape, shape_indices: &Shape) -> Self { Self::check_gather_scatter_indices::(Self::Ok, "Gather", dim, shape, shape_indices) } pub(crate) fn scatter( dim: usize, shape: &Shape, shape_indices: &Shape, shape_value: &Shape, ) -> Self { let ops = "Scatter"; let mut check = Self::check_gather_scatter_indices::(Self::Ok, ops, dim, shape, shape_indices); if shape_indices != shape_value { check = check.register( ops, TensorError::new( "Indices tensor shape should be the same as the value tensor shape." .to_string(), ) .details(format!( "The shape differs: {:?} != {:?}", shape_indices, shape_value )), ); } check } pub(crate) fn select(dim: usize) -> Self { Self::check_select_basic::(Self::Ok, "select", dim) } pub(crate) fn take(dim: usize) -> Self { let mut check = Self::check_select_basic::(Self::Ok, "Take", dim); // Calculate expected output dimensions // DO = D - 1 + DI (remove 1 dim, add DI dims) let expected_do = D + DI - 1; if DO != expected_do { check = check.register( "Take", TensorError::new("Output dimension mismatch").details(format!( "Expected output dimension {} (D={} + DI={} - 1) but got DO={}", expected_do, D, DI, DO )), ); } check } pub(crate) fn diag() -> Self { let mut check = Self::Ok; if D < 2 { check = check.register( "Diag", TensorError::new( "Diagonal operations require tensors with at least 2 dimensions.", ) .details(format!( "Got tensor with {D} dimensions, expected at least 2" )), ); } if DO != D - 1 { check = check.register( "Diag", TensorError::new("Output rank must be input rank minus 1 for diagonal") .details(format!("Expected output rank {}, got {DO}", D - 1)), ); } check } pub(crate) fn select_assign( dim: usize, shape_indices: &Shape, shape_value: &Shape, ) -> Self { let mut check = Self::check_select_basic::(Self::Ok, "Select Assign", dim); if shape_value[dim] != shape_indices[0] { check = check.register( "Select Assign", TensorError::new( format!( "Number of indices ({}) should be equal to value tensor dimensions {:?} on axis (dim={dim})", shape_indices[0], shape_value ), ) ); } check } fn check_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { if dim > D { check = check.register( ops, TensorError::new(format!( "Can't index a tensor with ({D}) dimensions on axis ({dim})" )), ); } check } fn check_gather_scatter_indices( mut check: Self, ops: &str, dim: usize, shape: &Shape, shape_indices: &Shape, ) -> Self { if dim > D { check = check.register( ops, TensorError::new(format!( "Can't index a tensor with ({D}) dimensions on axis ({dim})" )), ); } for i in 0..D { if i == dim { continue; } let tensor_dim_i = shape[i]; let indices_dim_i = shape_indices[i]; if tensor_dim_i != indices_dim_i { check = check.register( ops, TensorError::new( "The tensor shape should be the same as the index tensor shape." .to_string(), ) .details(format!( "The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}" )), ); } } check } pub(crate) fn check_prelu_shape( shape_tensor: &Shape, shape_weight: &Shape, ) -> Self { let mut check = Self::Ok; if shape_weight[0] == 1 { check } else if D >= 2 { let channels = shape_tensor[1]; let num_weights = shape_weight[0]; if channels != num_weights { check = check.register( "PReLu", TensorError::new( "Number of channels in input tensor and number of weights must be equal", ) .details(format!( "Got no. of channels: {channels}, no. of weights: {num_weights}", )), ); return check; } check } else { check = check.register( "PReLu", TensorError::new( "Number of channels in input tensor and number of weights must be equal", ) .details(format!( "Got no. of channels: 1, no. of weights: {}", shape_weight[0] )), ); check } } /// Checks aggregate dimension such as mean and sum. pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { let mut check = Self::Ok; if dim > D { check = check.register( ops, TensorError::new(format!( "Can't aggregate a tensor with ({D}) dimensions on axis ({dim})" )), ); } check } pub(crate) fn sort_dim(ops: &str, dim: usize) -> Self { let mut check = Self::Ok; if dim > D { check = check.register( ops, TensorError::new(format!( "Can't sort a tensor with ({D}) dimensions on axis ({dim})" )), ); } check } pub(crate) fn split( tensor_dims: &[usize], split_size: usize, dim: usize, ) -> Self { let mut check = Self::Ok; let op = "split"; let tensor_rank = tensor_dims.len(); if dim >= tensor_rank { check = check.register( op, TensorError::new("Given dimension is greater than or equal to the tensor rank.") .details(format!("Tensor rank: '{D}', given dimension: '{dim}'")), ); } else { let tensor_size = tensor_dims[dim]; if split_size == 0 && tensor_size != 0 { check = check.register( op, TensorError::new("split_size must be greater than 0 unless the tensor size along the dimension is 0.") .details(format!("split_size: '{split_size}', tensor size along dim '{dim}': '{tensor_size}'.")), ); } } check } pub(crate) fn split_with_sizes( tensor_dims: &[usize], split_sizes: &[usize], dim: usize, ) -> Self { let mut check = Self::Ok; let op = "split_with_sizes"; let tensor_rank = tensor_dims.len(); if dim >= tensor_rank { check = check.register( op, TensorError::new("Given dimension is greater than or equal to the tensor rank.") .details(format!("Tensor rank: '{D}', given dimension: '{dim}'.")), ); } else { // Validate split_sizes add up to size of dimension to split along let tensor_size = tensor_dims[dim]; let total_split_size: usize = split_sizes.iter().sum(); if total_split_size != tensor_size { check = check.register( op, TensorError::new("The sum of split_sizes must equal the tensor size along the specified dimension.") .details(format!("Sum of split_sizes: '{total_split_size}', tensor size along dim '{dim}': '{tensor_size}'.")), ); } } check } /// The goal is to minimize the cost of checks when there are no error, but it's way less /// important when an error occurred, crafting a comprehensive error message is more important /// than optimizing string manipulation. fn register(self, ops: &str, error: TensorError) -> Self { let errors = match self { Self::Ok => vec![error], Self::Failed(mut failed) => { failed.errors.push(error); failed.errors } }; Self::Failed(FailedTensorCheck { ops: ops.to_string(), errors, }) } /// Checks if shapes are compatible for element wise operations supporting broadcasting. pub(crate) fn binary_ops_ew_shape( self, ops: &str, lhs: &Shape, rhs: &Shape, ) -> Self { let mut check = self; for i in 0..D { let d_lhs = lhs[i]; let d_rhs = rhs[i]; if d_lhs != d_rhs { let is_broadcast = d_lhs == 1 || d_rhs == 1; if is_broadcast { continue; } check = check.register( ops, TensorError::new("The provided tensors have incompatible shapes.").details( format!( "Incompatible size at dimension '{}' => '{} != {}', which can't be \ broadcasted. Lhs tensor shape {:?}, Rhs tensor shape {:?}.", i, d_lhs, d_rhs, lhs, rhs, ), ), ); } } check } /// Checks if tensor devices are equal. fn binary_ops_device( self, ops: &str, lhs: &Device, rhs: &Device, ) -> Self { match lhs != rhs { true => self.register( ops, TensorError::new("The provided tensors are not on the same device.").details( format!("Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.",), ), ), false => self, } } /// Checks if expand operation is possible for the given shapes. pub fn expand(ops: &str, shape: &Shape, to: &Shape) -> Self { let mut check = TensorCheck::Ok; let max_dims = core::cmp::max(D1, D2); // Calculate the starting indices for each shape array, ensuring alignment from the right. let start_index_shape = max_dims.saturating_sub(D1); let start_index_to = max_dims.saturating_sub(D2); for i in 0..max_dims { // Use 1 as the default dimension size for dimensions beyond the tensor's rank. let d_shape = if i >= start_index_shape { shape[i - start_index_shape] } else { 1 }; let d_to = if i >= start_index_to { to[i - start_index_to] } else { 1 }; if d_shape != d_to && d_shape != 1 && d_to != 1 { // Register an incompatibility error. check = check.register( ops, TensorError::new( "The provided tensor can't be broadcasted to the target shape.", ) .details(format!( "Incompatible size at dimension '{}' => '{} != {}', which can't be \ broadcasted. Tensor shape {:?}, Target shape {:?}.", max_dims - i - 1, d_shape, d_to, shape, to, )), ); break; // Incompatibility found, no need to check further. } } check } /// Checks if unfold operation is possible for the given shapes. pub fn unfold( ops: &str, _shape: &Shape, _dim: usize, _size: usize, _step: usize, ) -> Self { let mut check = TensorCheck::Ok; if D2 != D1 + 1 { check = check.register( ops, TensorError::new("The unfold rank is incompatible with the input tensor rank.") .details(format!( "The output rank '{D2}' != the input rank + 1 '{D1}'.", )), ); } check } /// Checks if input is compatible with convolution weights. pub fn conv( ops: &str, x: [usize; D1], weight: [usize; D2], groups: usize, ) -> Self { let mut check = TensorCheck::Ok; let channels = x[1]; let expected = weight[1] * groups; if channels != expected { check = check.register( ops, TensorError::new("Number of channels in input tensor and input channels of convolution must be equal.") .details(format!("got: {channels}, expected: {expected}")), ); } check } /// Checks if input is compatible with transposed convolution weights. pub fn conv_transpose( ops: &str, x: [usize; D1], weight: [usize; D2], ) -> Self { let mut check = TensorCheck::Ok; let channels = x[1]; let expected = weight[0]; if channels != expected { check = check.register( ops, TensorError::new("Number of channels in input tensor and input channels of convolution must be equal.") .details(format!("got: {channels}, expected: {expected}")), ); } check } /// Check if input is compatible with LU decomposition. pub fn is_square(ops: &str, shape: &Shape) -> Self { let mut check = TensorCheck::Ok; if shape[D - 1] != shape[D - 2] { check = check.register( ops, TensorError::new("The input tensor must be square.").details(format!( "Got tensor with shape {:?}, expected last two dimensions to be equal", shape )), ); } check } /// Check pivot is valid for LU decomposition. pub fn lu_decomposition_pivot(pivot: FloatElem) -> Self { let mut check = TensorCheck::Ok; if pivot.to_f64().abs() <= 1e-6 { check = check.register( "lu_decomposition", TensorError::new("LU decomposition requires a valid pivot.") .details(format!("Got pivot value too close to zero: {}", pivot)), ); } check } } pub(crate) struct FailedTensorCheck { ops: String, errors: Vec, } impl FailedTensorCheck { /// Format all the checks into a single message ready to be printed by a [panic](core::panic). pub(crate) fn format(self) -> String { self.errors.into_iter().enumerate().fold( format!( "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", self.ops ), |accum, (number, error)| accum + error.format(number + 1).as_str(), ) + "\n" } } struct TensorError { description: String, details: Option, } impl TensorError { pub(crate) fn new>(description: S) -> Self { TensorError { description: description.into(), details: None, } } pub(crate) fn details>(mut self, details: S) -> Self { self.details = Some(details.into()); self } fn format(self, number: usize) -> String { let mut message = format!("\n {number}. "); message += self.description.as_str(); message += " "; if let Some(details) = self.details { message += details.as_str(); message += " "; } message } } /// Module where we defined macros that can be used only in the project. pub(crate) mod macros { /// We use a macro for all checks, since the panic message file and line number will match the /// function that does the check instead of a generic error.rs crate private unrelated file /// and line number. macro_rules! check { ($check:expr) => { if let TensorCheck::Failed(check) = $check { core::panic!("{}", check.format()); } }; } pub(crate) use check; } pub(crate) fn unwrap_shape_reshape(result: Result) -> Shape { match result { Ok(shape) => shape, // `shape.reshape(new_shape)` should only return `MetadataError::Invalid`. Err(burn_std::MetadataError::Invalid { reason }) => { macros::check!({ TensorCheck::Ok.register("Reshape", crate::check::TensorError::new(reason)) }); unreachable!() } Err(e) => panic!("{e:?}"), } } #[cfg(test)] mod tests { use super::*; use macros::check; #[test] #[should_panic] fn index_range_exceed_dimension() { let slices = vec![Slice::from(0..2), Slice::from(0..4), Slice::from(1..8)]; check!(TensorCheck::slice::<3>(&Shape::new([3, 5, 7]), &slices)); } #[test] #[should_panic] fn index_range_exceed_number_of_dimensions() { let slices = vec![Slice::from(0..1), Slice::from(0..1), Slice::from(0..1)]; check!(TensorCheck::slice::<2>(&Shape::new([3, 5]), &slices)); } #[test] #[should_panic] fn binary_ops_shapes_no_broadcast() { check!(TensorCheck::binary_ops_ew_shape::<2>( TensorCheck::Ok, "TestOps", &Shape::new([3, 5]), &Shape::new([3, 6]) )); } #[test] fn binary_ops_shapes_with_broadcast() { check!(TensorCheck::binary_ops_ew_shape::<2>( TensorCheck::Ok, "Test", &Shape::new([3, 5]), &Shape::new([1, 5]) )); } #[test] #[should_panic] fn binary_ops_devices() { check!(TensorCheck::binary_ops_device( TensorCheck::Ok, "Test", &5, // We can pass anything that implements PartialEq as device &8 )); } #[test] #[should_panic] fn movedim_args_out_of_bounds() { check!(TensorCheck::movedim_args_usize::<3>(5)); } #[test] fn movedim_args_i32() { check!(TensorCheck::movedim_args_i32::<3>(-3)); } #[test] #[should_panic] fn movedim_args_too_negative() { check!(TensorCheck::movedim_args_i32::<3>(-4)); } #[test] #[should_panic] fn movedim_args_vec_out_of_bounds() { check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 3])); } #[test] #[should_panic] fn movedim_args_vec_duplicates() { check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 1])); } #[test] #[should_panic] fn movedim_args_length() { check!(TensorCheck::movedim_args_length( &vec![0, 1], &vec![0, 1, 2] )); } #[test] #[should_panic] fn unsqueeze_dim_same_rank() { check!(TensorCheck::unsqueeze_dim::<3, 3>(2)); } } ================================================ FILE: crates/burn-tensor/src/tensor/api/float.rs ================================================ use crate::AsIndex; use crate::FloatDType; use crate::Tensor; use crate::cast::ToElement; use crate::check; use crate::check::TensorCheck; use crate::ops::GridSampleOptions; use crate::quantization::{QuantScheme, QuantizationParameters}; use crate::tensor::backend::Backend; use crate::tensor::stats; use crate::tensor::{Distribution, TensorData}; use crate::{Bool, Int, TensorPrimitive}; use burn_backend::ElementConversion; use burn_backend::Scalar; use burn_backend::tensor::quantization::QuantizationParametersPrimitive; use core::f32; /// Default RTOL value for `is_close` and `all_close`. pub const DEFAULT_RTOL: f64 = 1e-5; /// Default ATOL value for `is_close` and `all_close`. pub const DEFAULT_ATOL: f64 = 1e-8; impl Tensor where B: Backend, { /// Applies element wise exponential operation. /// #[cfg_attr(doc, doc = "$y_i = e^{x_i}$")] #[cfg_attr(not(doc), doc = "`y = e^x`")] pub fn exp(self) -> Self { Self::new(TensorPrimitive::Float(B::float_exp( self.primitive.tensor(), ))) } /// Applies element wise natural log operation *ln*. /// #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = log(x_i)`")] pub fn log(self) -> Self { Self::new(TensorPrimitive::Float(B::float_log( self.primitive.tensor(), ))) } /// Applies the natural logarithm of one plus the input tensor, element-wise. /// #[cfg_attr(doc, doc = r#"$y_i = \log_e\(x_i + 1\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = log(x_i + 1)`")] pub fn log1p(self) -> Self { Self::new(TensorPrimitive::Float(B::float_log1p( self.primitive.tensor(), ))) } /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise. /// #[cfg_attr( doc, doc = r#" $y_i = \text{erf}\(x_i\)$ The error function is defined as: $$\text{erf}\(x\) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt$$ "# )] #[cfg_attr(not(doc), doc = "`y_i = erf(x_i)`")] pub fn erf(self) -> Self { Self::new(TensorPrimitive::Float(B::float_erf( self.primitive.tensor(), ))) } /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse) /// (or multiplicative inverse) element wise. /// #[cfg_attr(doc, doc = r#"$y_i = \frac{1}{x_i}$"#)] #[cfg_attr(not(doc), doc = "`y_i = 1/x_i`")] pub fn recip(self) -> Self { Self::new(TensorPrimitive::Float(B::float_recip( self.primitive.tensor(), ))) } /// Applies element wise square operation. /// #[cfg_attr(doc, doc = r#"$y_i = x_i * x_i$"#)] #[cfg_attr(not(doc), doc = "`y_i = x_i * x_i`")] pub fn square(self) -> Self { self.powi_scalar(2) } /// Applies element wise root square operation. /// #[cfg_attr(doc, doc = r#"$y_i = \sqrt{x_i}$"#)] #[cfg_attr(not(doc), doc = "`y_i = sqrt(x_i)`")] pub fn sqrt(self) -> Self { Self::new(TensorPrimitive::Float(B::float_sqrt( self.primitive.tensor(), ))) } /// Applies element wise cosine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \cos\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = cos(x_i)`")] pub fn cos(self) -> Self { Self::new(TensorPrimitive::Float(B::float_cos( self.primitive.tensor(), ))) } /// Applies element wise sine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \sin\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = sin(x_i)`")] pub fn sin(self) -> Self { Self::new(TensorPrimitive::Float(B::float_sin( self.primitive.tensor(), ))) } /// Applies element wise tangent operation. /// #[cfg_attr(doc, doc = r#"$y_i = \tan\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = tan(x_i)`")] pub fn tan(self) -> Self { Self::new(TensorPrimitive::Float(B::float_tan( self.primitive.tensor(), ))) } /// Applies element wise hyperbolic cosine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \cosh\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = cosh(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 2.0], &device); /// println!("{}", tensor.cosh()); // [1.0, 1.5430, 3.7621] /// } /// ``` pub fn cosh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_cosh( self.primitive.tensor(), ))) } /// Applies element wise hyperbolic sine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \sinh\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = sinh(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 2.0], &device); /// println!("{}", tensor.sinh()); // [0.0, -1.1752, 3.6269] /// } /// ``` pub fn sinh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_sinh( self.primitive.tensor(), ))) } /// Applies element wise hyperbolic tangent operation. /// #[cfg_attr(doc, doc = r#"$y_i = \tanh\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = tanh(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 2.0], &device); /// println!("{}", tensor.tanh()); // [0.0, -0.7616, 0.9640] /// } /// ``` pub fn tanh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_tanh( self.primitive.tensor(), ))) } /// Applies element wise inverse sine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \asin\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 1.0], &device); /// println!("{}", tensor.asin()); // [ 0.0000, -1.5708, 1.5708] /// } /// ``` pub fn asin(self) -> Self { Self::new(TensorPrimitive::Float(B::float_asin( self.primitive.tensor(), ))) } /// Applies element wise inverse hyperbolic sine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \asinh\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 1.0], &device); /// println!("{}", tensor.asinh()); // [ 0.0000, -0.8814, 0.8814] /// } /// ``` pub fn asinh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_asinh( self.primitive.tensor(), ))) } /// Applies element wise inverse cosine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \acos\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 1.0], &device); /// println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0] /// } /// ``` pub fn acos(self) -> Self { Self::new(TensorPrimitive::Float(B::float_acos( self.primitive.tensor(), ))) } /// Applies element wise inverse hyperbolic cosine operation. /// #[cfg_attr(doc, doc = r#"$y_i = \acosh\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([1.0, 2.0, 3.0], &device); /// println!("{}", tensor.sinh()); // [0.0000, 1.3170, 1.7627] /// } /// ``` pub fn acosh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_acosh( self.primitive.tensor(), ))) } /// Applies element wise inverse tangent operation. /// #[cfg_attr(doc, doc = r#"$y_i = \atan\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -1.0, 2.0], &device); /// println!("{}", tensor.sinh()); // [ 0.0, -0.7854, 1.1071] /// } /// ``` pub fn atan(self) -> Self { Self::new(TensorPrimitive::Float(B::float_atan( self.primitive.tensor(), ))) } /// Applies element wise inverse hyperbolic tangent operation. /// #[cfg_attr(doc, doc = r#"$y_i = \atan\(x_i\)$"#)] #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let tensor = Tensor::::from_data([0.0, -0.5, 0.5], &device); /// println!("{}", tensor.sinh()); // [ 0.0, -0.5493, 0.5493] /// } /// ``` pub fn atanh(self) -> Self { Self::new(TensorPrimitive::Float(B::float_atanh( self.primitive.tensor(), ))) } /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant. /// #[cfg_attr(doc, doc = r#"$z_i = \atan2\(y_i, x_i\)$"#)] #[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")] /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// /// let lhs = Tensor::::from_data([-2.0, 2.0, -2.0], &device); /// let rhs = Tensor::::from_data([1.0, -1.0, -1.0], &device); /// println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344] /// } /// ``` pub fn atan2(self, other: Self) -> Self { Self::new(TensorPrimitive::Float(B::float_atan2( self.primitive.tensor(), other.primitive.tensor(), ))) } /// Converts each of the elements of the input tensor from angles in degrees to radians. /// /// # Example /// ```ignore /// let tensor_in_radians = tensor.deg2rad(); /// ``` pub fn deg2rad(self) -> Self { self.mul_scalar(f32::consts::PI / 180.0) } /// Converts each of the elements of the input tensor from angles in radians to degrees. /// /// # Example /// ```ignore /// let tensor_in_degrees = tensor.rad2deg(); /// ``` pub fn rad2deg(self) -> Self { self.mul_scalar(180.0 / f32::consts::PI) } /// Applies element wise round operation. /// /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even) /// strategy, with halfway cases rounded to the nearest even integer value. pub fn round(self) -> Self { Self::new(TensorPrimitive::Float(B::float_round( self.primitive.tensor(), ))) } /// Applies element wise floor operation. pub fn floor(self) -> Self { Self::new(TensorPrimitive::Float(B::float_floor( self.primitive.tensor(), ))) } /// Applies element wise ceil operation. pub fn ceil(self) -> Self { Self::new(TensorPrimitive::Float(B::float_ceil( self.primitive.tensor(), ))) } /// Create a tensor from floats (f32) on a given device. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = B::Device::default(); /// let _ = Tensor::::from_floats([1.0, 2.0], &device); /// let _ = Tensor::::from_floats([[1.0, 2.0], [3.0, 4.0]], &device); /// } /// ``` pub fn from_floats>(floats: A, device: &B::Device) -> Self { Self::from_data(floats.into().convert::(), device) } /// Returns a new tensor with the same shape and device as the current tensor and the data /// cast to Integer. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = Default::default(); /// let float_tensor = Tensor::::from_floats([1.0, 2.0], &device); /// let int_tensor = float_tensor.int(); /// } /// ``` pub fn int(self) -> Tensor { Tensor::new(B::float_into_int(self.primitive.tensor())) } /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled random /// values sampled from the given distribution. pub fn random_like(&self, distribution: Distribution) -> Self { Self::new(TensorPrimitive::Float(B::float_random( self.shape(), distribution, &self.device(), ))) .cast(self.dtype()) } /// Calculate the variance along the given dimension. pub fn var(self, dim: usize) -> Self { stats::var(self, dim) } /// Calculate the variance along the given dimension without applying the Bessel’s correction. pub fn var_bias(self, dim: usize) -> Self { stats::var_bias(self, dim) } /// Calculate the variance along the given dimension and also returns the mean. pub fn var_mean(self, dim: usize) -> (Self, Self) { let mean = self.clone().mean_dim(dim); let var = stats::var_with_mean(self, mean.clone(), dim); (var, mean) } /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean. pub fn var_mean_bias(self, dim: usize) -> (Self, Self) { let mean = self.clone().mean_dim(dim); let var = stats::var_with_mean_bias(self, mean.clone(), dim); (var, mean) } /// Returns the median value along the specified dimension. /// /// The median is not unique for input tensors with an even number of elements /// in the reduced dimension. In this case, the lower of the two medians is returned, /// following PyTorch's behavior. /// /// # Note /// /// The current implementation performs a full sort along the specified dimension, /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back /// to CPU for the sort operation, which may result in slower performance compared /// to native GPU operations. /// /// # Arguments /// /// - `dim` - The dimension along which to compute the median. /// /// # Returns /// /// - A tensor containing the median values along the specified dimension. /// /// # Example 1 /// /// ```ignore /// // Assuming backend B /// let device = B::Device::default(); /// let tensor = Tensor::::from_data( /// [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]], /// &device, /// ); /// /// // Median along dimension 0: /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0] /// let median = tensor.median(0); /// // Result: [[1.0, 4.0, 3.0, 2.0]] /// /// // Median along dimension 1: /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0] /// let median = tensor.median(1); /// // Result: [[2.0], [6.0]] /// ``` /// /// # Example 2 /// /// The median across all elements can be calculated as follows: /// /// ```ignore /// // D is the number of dimensions of the tensor /// let flattened_tensor: Tensor = tensor.flatten(0, D - 1); /// /// // Calculate median for dim 0 since the tensor has become 1 dimensional /// let median = flattened_tensor.median(0); /// // Result: [4.0] /// ``` pub fn median(self, dim: usize) -> Self { // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl // instead of leveraging a full sort to get the median. stats::median(self, dim) } /// Returns the median value along the specified dimension and its index. /// /// The median is not unique for input tensors with an even number of elements /// in the reduced dimension. In this case, the lower of the two medians is returned, /// following PyTorch's behavior. /// /// # Note /// /// The current implementation performs a full sort along the specified dimension, /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back /// to CPU for the sort operation, which may result in slower performance compared /// to native GPU operations. /// /// # Arguments /// /// - `dim` - The dimension along which to compute the median. /// /// # Returns /// /// A tuple containing: /// - A tensor with the median values. /// - A tensor with the indices of the median values in the original tensor. /// /// # Example /// /// ```ignore /// // Assuming backend B /// let device = B::Device::default(); /// let tensor = Tensor::::from_data( /// [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]], /// &device, /// ); /// /// // Median along dimension 1: /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0] /// let (values, indices) = tensor.median_with_indices(1); /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor) /// ``` pub fn median_with_indices(self, dim: usize) -> (Self, Tensor) { // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl // instead of leveraging a full sort to get the median. stats::median_with_indices(self, dim) } /// Converts a tensor to the specified floating point data type. /// /// This is always a no-op when casting to the current dtype. /// /// # Warning /// Most backends don't have automatic type promotion at this time, so make sure that all tensors /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops). pub fn cast>(self, dtype: F) -> Tensor { let dtype = dtype.into(); let self_type: FloatDType = self.dtype().into(); if dtype == self_type { // no-op. return self; } Tensor::new(TensorPrimitive::Float(B::float_cast( self.primitive.tensor(), dtype, ))) } /// Detach the current tensor from the autodiff graph. /// /// This function does nothing when autodiff is not enabled. /// This can be used in batchers or elsewhere to ensure that previous operations are not /// considered in the autodiff graph. pub fn detach(self) -> Self { Self::new(TensorPrimitive::Float(B::float_detach( self.primitive.tensor(), ))) } /// Mark the tensor to keep gradients during the backward pass. /// /// This function does nothing when autodiff is not enabled. pub fn require_grad(self) -> Self { self.set_require_grad(true) } /// Returns true if the tensor requires gradients during the backward pass. pub fn is_require_grad(&self) -> bool { match &self.primitive { TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor), TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor), } } /// Mark the tensor as tracked or untracked depending on the require_grad argument. /// When tracked, the gradients will be available after the backward pass. /// /// This function does nothing when autodiff is not enabled. pub fn set_require_grad(self, require_grad: bool) -> Self { let primitive = match self.primitive { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad)) } TensorPrimitive::QFloat(tensor) => { TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad)) } }; Self::new(primitive) } /// Applies the relu function to the tensor. pub(crate) fn relu(self) -> Self { Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor()))) } /// Calculate covaraince matrix between different entries alongside a given dimension. /// /// # Arguments /// /// * `size` - The size of the square matrix. /// * `correction_factor` - Is usually 1 for samples and 0 for population. pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor { let n = self.dims()[dim]; let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0); centered .clone() .transpose() .matmul(centered) .div_scalar(n as f32 - correction_factor as f32) } /// Convert the tensor to a lower precision data type based on the quantization scheme. /// /// # Arguments /// /// * `scheme` - The quantization scheme. /// * `qparams` - The pre-computed quantization parameters. /// /// # Returns /// /// The quantized tensor. pub fn quantize( self, scheme: &QuantScheme, qparams: QuantizationParameters, ) -> Tensor { Tensor::new(TensorPrimitive::QFloat(B::quantize( self.primitive.tensor(), scheme, QuantizationParametersPrimitive { scales: qparams.scales.primitive.tensor(), }, ))) } /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. /// /// # Arguments /// /// * `scheme` - The quantization scheme. /// /// # Returns /// /// The quantized tensor. /// /// # Notes /// This uses [min-max calibration](crate::quantization::Calibration::MinMax). pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor { Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( self.primitive.tensor(), scheme, ))) } /// Convert the tensor back to a higher precision data type. /// /// If the tensor is not quantized, its value is simply returned. /// /// # Returns /// /// The dequantized tensor. pub fn dequantize(self) -> Tensor { Tensor::new(TensorPrimitive::Float(self.primitive.tensor())) } /// Checks element wise if the tensor is close to another tensor. /// /// The tolerance is defined by the following equation: /// /// ```text /// abs(a - b) <= (atol + rtol * abs(b)) /// /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance, /// and `atol` is the absolute tolerance. /// ``` /// /// # Arguments /// /// * `other` - The tensor to compare with. /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`. /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensors. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor1.is_close(tensor2, None, None); /// println!("{tensor}"); /// // [[true, true, true], [true, true, true]] /// } /// ``` pub fn is_close(self, other: Self, rtol: Option, atol: Option) -> Tensor { let rtol = rtol.unwrap_or(DEFAULT_RTOL); let atol = atol.unwrap_or(DEFAULT_ATOL); // check finite difference is close let is_close_finite_val = self .clone() .sub(other.clone()) .abs() .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol)) .bool_and(self.clone().is_finite()) .bool_and(other.clone().is_finite()); // check if both are infinite and have same sign let inf_same_sign = self .clone() .is_finite() .bool_not() .bool_and(other.clone().is_finite().bool_not()) .bool_and(self.equal(other)); is_close_finite_val.bool_or(inf_same_sign) } /// Checks if all elements are close to another tensor. /// /// The tolerance is defined by the following equation: /// /// ```text /// /// abs(a - b) <= (atol + rtol * abs(b)) /// /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance, /// and `atol` is the absolute tolerance. /// /// ``` /// /// # Arguments /// /// * `other` - The tensor to compare with. /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`. /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`. /// /// # Returns /// /// A boolean scalar. /// /// # Remarks /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let result = tensor1.all_close(tensor2, None, None); /// println!("{}", result); /// // true /// } /// ``` pub fn all_close(self, other: Self, rtol: Option, atol: Option) -> bool { self.is_close(other, rtol, atol) .all() .into_scalar() .to_bool() } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// /// # Returns /// /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.is_nan(); /// println!("{tensor}"); /// // [[false, true, false], [false, false, false]] /// } /// ``` pub fn is_nan(self) -> Tensor { Tensor::new(B::float_is_nan(self.primitive.tensor())) } /// Checks if the tensor contains any NaN values. /// /// # Returns /// /// A boolean tensor with a single element indicating whether the tensor contains any NaN values. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device); /// let tensor = tensor.contains_nan(); /// println!("{tensor}"); /// // [true] /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.contains_nan(); /// println!("{tensor}"); /// // [false] /// } /// ``` pub fn contains_nan(self) -> Tensor { // Summing the tensor will result in NaN if the tensor contains any NaN values // This is faster than checking each element individually // because it rolls up the NaN values into a single value let sum = self.sum(); sum.is_nan() } /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF). /// /// # Returns /// /// A boolean tensor where `true` indicates that the value is infinite /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device); /// let tensor = tensor.is_finite(); /// println!("{tensor}"); /// // [[false, true, false], [false, false, false]] /// } /// ``` pub fn is_inf(self) -> Tensor { Tensor::new(B::float_is_inf(self.primitive.tensor())) } /// Returns a new tensor with boolean elements indicating whether each element of the input is finite /// /// # Returns /// /// A boolean tensor where `true` indicates that the value is finite and `false` indicates /// either INF, -INF or NAN /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Bool, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device); /// let tensor = tensor.is_finite(); /// println!("{tensor}"); /// // [[true, false, true], [false, true, true]] /// } /// ``` pub fn is_finite(self) -> Tensor { self.clone() .is_nan() .bool_not() .bool_and(self.is_inf().bool_not()) } /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values, /// using the given locations in [-1, 1]. /// /// # Arguments /// /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right /// * `options` - Grid sampling options (mode, padding_mode, align_corners) /// /// # Returns /// /// A tensor with shape (N, C, H_out, W_out) /// /// # Example /// /// ```ignore /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; /// /// // Default options (bilinear, zeros padding, align_corners=false) /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default()); /// /// // Custom options /// let options = GridSampleOptions::new(InterpolateMode::Bilinear) /// .with_padding_mode(GridSamplePaddingMode::Border) /// .with_align_corners(true); /// let output = tensor.grid_sample_2d(grid, options); /// ``` pub fn grid_sample_2d( self, grid: Tensor, options: impl Into, ) -> Tensor { Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d( self.primitive.tensor(), grid.primitive.tensor(), options.into(), ))) } /// Computes the cross product of `self` and another tensor along a given dimension. /// /// Both `self` and `other` **must have size 3** along the specified `dim`, /// because the cross product is only defined in three-dimensional space. /// /// # Arguments /// /// * `other` - The other tensor to take the cross product with. /// * `dim` - The dimension along which to compute the cross product. /// /// # Returns /// /// A tensor containing the cross product of `self` and `other` along `dim`. pub fn cross(self, other: Tensor, dim: Dim) -> Tensor { let dim = dim.expect_dim_index(D); check!(TensorCheck::cross(&self, &other, dim)); Tensor::new(TensorPrimitive::Float(B::float_cross( self.primitive.tensor(), other.primitive.tensor(), dim, ))) } /// Applies element wise power operation with a float Tensor /// /// # Arguments /// /// * `other` - The tensor to apply the power operation with. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.powf(tensor2); /// println!("{tensor}"); /// // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]] /// } /// ``` pub fn powf(self, other: Self) -> Self { let primitive = match (self.primitive, other.primitive) { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_powf(lhs, rhs)) } (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs), (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_powf(B::dequantize(lhs), rhs)) } (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => { TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs))) } }; Tensor::new(primitive) } /// Applies element wise power operation with a float scalar /// /// # Arguments /// /// * `other` - The scalar to apply the power operation with. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.powf_scalar(2.0); /// println!("{tensor}"); /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]] /// } /// ``` pub fn powf_scalar(self, other: E) -> Self { let rhs = Scalar::new(other, &self.dtype()); let primitive = match self.primitive { TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)), TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs), }; Tensor::new(primitive) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/fmod.rs ================================================ use crate::{Float, Tensor, backend::Backend}; impl Tensor where B: Backend, { /// Computes the floating-point remainder of dividing `self` by `other`. /// /// The result has the same sign as `self` and magnitude less than `other`. /// This is equivalent to the IEEE 754 remainder operation. /// /// # Special Cases (IEEE 754 compliant) /// /// - If `self` is ±∞ and `other` is not NaN, NaN is returned /// - If `other` is ±0 and `self` is not NaN, NaN is returned /// - If `other` is ±∞ and `self` is finite, `self` is returned /// - If either argument is NaN, NaN is returned /// /// # Arguments /// /// * `other` - The divisor tensor. Must have the same shape as `self`. /// /// # Returns /// /// A tensor with the same shape where each element is the floating-point remainder. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = B::Device::default(); /// let dividend = Tensor::::from_data([5.3, -5.3, 5.3, -5.3], &device); /// let divisor = Tensor::::from_data([2.0, 2.0, -2.0, -2.0], &device); /// let result = dividend.fmod(divisor); /// /// // Result: [1.3, -1.3, 1.3, -1.3] /// } /// ``` pub fn fmod(self, other: Self) -> Self { // Normal case: fmod(x, y) = x - y * trunc(x / y) let quotient = self.clone().div(other.clone()); let truncated = quotient.trunc(); let product = other.clone() * truncated.clone(); // When divisor is infinity and dividend is finite: // - quotient is 0, truncated is 0 // - but 0 * infinity = NaN, which is wrong // We need to handle this case by replacing NaN with 0 when appropriate // Check if the product is NaN due to 0 * inf let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf()); let zero_tensor = self.clone().mul_scalar(0.0); let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor); self - corrected_product } /// Computes the floating-point remainder of dividing `self` by a scalar. /// /// The result has the same sign as `self` and magnitude less than the scalar. /// /// # Special Cases (IEEE 754 compliant) /// /// - If `self` is ±∞ and scalar is not NaN, NaN is returned /// - If scalar is ±0 and `self` is not NaN, NaN is returned /// - If scalar is ±∞ and `self` is finite, `self` is returned /// - If either argument is NaN, NaN is returned /// /// # Arguments /// /// * `scalar` - The scalar divisor. /// /// # Returns /// /// A tensor with the same shape where each element is the floating-point remainder. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([5.3, -5.3, 7.5, -7.5], &device); /// let result = tensor.fmod_scalar(2.0); /// /// // Result: [1.3, -1.3, 1.5, -1.5] /// } /// ``` pub fn fmod_scalar(self, scalar: f32) -> Self { // Normal case: fmod(x, y) = x - y * trunc(x / y) let quotient = self.clone().div_scalar(scalar); let truncated = quotient.trunc(); let product = truncated.mul_scalar(scalar); // Handle the special case where scalar is infinity // When scalar is ±∞ and self is finite, quotient is 0, truncated is 0 // but 0 * infinity = NaN, which is wrong - it should be 0 if scalar.is_infinite() { // For finite values, fmod(x, ±∞) = x // For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic) return self; } self - product } } ================================================ FILE: crates/burn-tensor/src/tensor/api/int.rs ================================================ use burn_backend::Scalar; use crate::{ Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend, cartesian_grid, }; use core::ops::Range; impl Tensor where B: Backend, { /// Returns a new integer tensor on the specified device. /// /// # Arguments /// /// * `range` - The range of values to generate. /// * `device` - The device to create the tensor on. pub fn arange(range: Range, device: &B::Device) -> Self { Tensor::new(B::int_arange(range, device)) } /// Returns a new integer tensor on the specified device. /// /// # Arguments /// /// * `range` - The range of values to generate. /// * `step` - The step between each value. pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } } impl Tensor where B: Backend, { /// Create a tensor from integers (i32), placing it on a given device. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Int}; /// /// fn example() { /// let device = B::Device::default(); /// let _x: Tensor = Tensor::from_ints([1, 2], &device); /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]], &device); /// } /// ``` pub fn from_ints>(ints: A, device: &B::Device) -> Self { Self::from_data(ints.into().convert::(), device) } /// Returns a new tensor with the same shape and device as the current tensor and the data /// cast to Float. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let int_tensor = Tensor::::arange(0..5, &device); /// let float_tensor = int_tensor.float(); /// } /// ``` pub fn float(self) -> Tensor { Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive))) } /// Generates a cartesian grid for the given tensor shape on the specified device. /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element. /// /// # Arguments /// /// * `shape` - The shape specifying the dimensions of the tensor. /// * `device` - The device to create the tensor on. /// /// # Panics /// /// Panics if `D2` is not equal to `D+1`. /// /// # Examples /// /// ```rust /// use burn_tensor::Int; /// use burn_tensor::{backend::Backend, Shape, Tensor}; /// fn example() { /// let device = Default::default(); /// let result: Tensor = Tensor::::cartesian_grid([2, 3], &device); /// println!("{}", result); /// } /// ``` pub fn cartesian_grid, const D2: usize>( shape: S, device: &B::Device, ) -> Tensor { cartesian_grid::(shape, device) } /// Applies the bitwise logical and operation with each bit representing the integer. pub fn bitwise_and(self, other: Self) -> Self { Self::new(B::bitwise_and(self.primitive, other.primitive)) } /// Applies the bitwise logical or operation with another tensor. pub fn bitwise_or(self, other: Self) -> Self { Self::new(B::bitwise_or(self.primitive, other.primitive)) } /// Applies the bitwise logical xor operation with another tensor. pub fn bitwise_xor(self, other: Self) -> Self { Self::new(B::bitwise_xor(self.primitive, other.primitive)) } /// Applies the bitwise logical not operation. pub fn bitwise_not(self) -> Self { Self::new(B::bitwise_not(self.primitive)) } /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor. pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(B::bitwise_and_scalar(self.primitive, other)) } /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor. pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(B::bitwise_or_scalar(self.primitive, other)) } /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor. pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(B::bitwise_xor_scalar(self.primitive, other)) } /// Applies the bitwise left shift operation with the integers in the tensor. pub fn bitwise_left_shift(self, other: Self) -> Self { Self::new(B::bitwise_left_shift(self.primitive, other.primitive)) } /// Applies the bitwise right shift operation with the integers in the tensor. pub fn bitwise_right_shift(self, other: Self) -> Self { Self::new(B::bitwise_right_shift(self.primitive, other.primitive)) } /// Applies the bitwise left shift operation with the scalar. pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(B::bitwise_left_shift_scalar(self.primitive, other)) } /// Applies the bitwise right shift operation with the scalar. pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(B::bitwise_right_shift_scalar(self.primitive, other)) } /// Converts a tensor to the specified integer data type. /// /// This is always a no-op when casting to the current dtype. /// /// # Warning /// Most backends don't have automatic type promotion at this time, so make sure that all tensors /// have the same integer data type for operations multiple input tensors (e.g., binary ops). pub fn cast>(self, dtype: F) -> Tensor { let dtype = dtype.into(); let self_dtype: IntDType = self.dtype().into(); if dtype == self_dtype { // no-op. return self; } Tensor::new(B::int_cast(self.primitive, dtype)) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/mod.rs ================================================ pub(crate) mod check; mod autodiff; mod base; mod bool; mod cartesian_grid; mod float; mod fmod; mod int; mod numeric; mod options; mod orderable; mod pad; pub use pad::IntoPadding; mod take; mod transaction; mod trunc; pub use autodiff::*; pub use base::*; pub use cartesian_grid::cartesian_grid; pub use float::{DEFAULT_ATOL, DEFAULT_RTOL}; pub use numeric::*; pub use options::*; pub use transaction::*; pub use burn_backend::tensor::IndexingUpdateOp; ================================================ FILE: crates/burn-tensor/src/tensor/api/numeric.rs ================================================ use burn_backend::Scalar; pub use burn_backend::tensor::Numeric; use crate::alloc::borrow::ToOwned; use alloc::vec::Vec; use crate::IndexingUpdateOp; use crate::{ AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend, check, check::TensorCheck, }; impl Tensor where B: Backend, K: Numeric, K::Elem: Element, { /// Applies element wise addition operation. /// /// `y = x2 + x1` /// /// # Arguments /// /// * `other` - The tensor to add. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1 + tensor2; /// println!("{tensor}"); /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]] /// } /// ``` #[allow(clippy::should_implement_trait)] pub fn add(self, other: Self) -> Self { check!(TensorCheck::binary_ops_ew("Add", &self, &other)); Self::new(K::add(self.primitive, other.primitive)) } /// Applies element wise addition operation with a scalar. /// /// `y = x + s` /// /// # Arguments /// /// * `other` - The scalar to add, element wise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let scalar = 2.0; /// let tensor = tensor + scalar; /// println!("{tensor}"); /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]] /// } /// ``` pub fn add_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::add_scalar(self.primitive, other)) } /// Applies element wise subtraction operation. /// /// `y = x2 - x1` /// /// # Arguments /// /// * `other` - The tensor to subtract. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1 - tensor2; /// println!("{tensor}"); /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]] /// } /// ``` #[allow(clippy::should_implement_trait)] pub fn sub(self, other: Self) -> Self { check!(TensorCheck::binary_ops_ew("Sub", &self, &other)); Self::new(K::sub(self.primitive, other.primitive)) } /// Applies element wise subtraction operation with a scalar. /// /// `y = x - s` /// /// # Arguments /// /// * `other` - The scalar to subtract, element wise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let scalar = 2.0; /// let tensor = tensor - scalar; /// println!("{tensor}"); /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]] /// } /// ``` pub fn sub_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::sub_scalar(self.primitive, other)) } /// Applies element wise division operation. /// /// `y = x2 / x1` /// /// # Arguments /// /// * `other` - The tensor to divide. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1 / tensor2; /// println!("{tensor}"); /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]] /// } /// ``` #[allow(clippy::should_implement_trait)] pub fn div(self, other: Self) -> Self { check!(TensorCheck::binary_ops_ew("Div", &self, &other)); Self::new(K::div(self.primitive, other.primitive)) } /// Applies element wise division operation with a scalar. /// /// `y = x / s` /// /// # Arguments /// /// * `other` - The scalar to divide, element wise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let scalar = 2.0; /// let tensor = tensor / scalar; /// println!("{tensor}"); /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]] /// } /// ``` pub fn div_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::div_scalar(self.primitive, other)) } /// Applies element wise the remainder operation with a scalar. /// /// `y = x2 % x1` pub fn remainder(self, other: Self) -> Self { Self::new(K::remainder(self.primitive, other.primitive)) } /// Applies element wise the remainder operation with a scalar. /// /// `y = x % s` /// /// # Arguments /// /// * `other` - The scalar to divide, element wise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let scalar = 2.0; /// let tensor = tensor1 % scalar; /// println!("{tensor}"); /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]] /// } /// ``` pub fn remainder_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::remainder_scalar(self.primitive, other)) } /// Applies element wise multiplication operation. /// /// `y = x2 * x1` /// /// # Arguments /// /// * `other` - The tensor to multiply. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1 * tensor2; /// println!("{tensor}"); /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]] /// } /// ``` #[allow(clippy::should_implement_trait)] pub fn mul(self, other: Self) -> Self { check!(TensorCheck::binary_ops_ew("Mul", &self, &other)); Self::new(K::mul(self.primitive, other.primitive)) } /// Applies element wise multiplication operation with a scalar. /// /// `y = x * s` /// /// # Arguments /// /// * `other` - The scalar to multiply, element wise. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let scalar = 2.0; /// let tensor = tensor * scalar; /// println!("{tensor}"); /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]] /// } /// ``` pub fn mul_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::mul_scalar(self.primitive, other)) } /// Switch sign of each element in the tensor. /// /// `y = -x` /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = -tensor; /// println!("{tensor}"); /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]] /// } /// ``` #[allow(clippy::should_implement_trait)] pub fn neg(self) -> Self { Self::new(K::neg(self.primitive)) } /// Returns the signs of the elements of the input tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.sign(); /// println!("{tensor}"); /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]] /// } /// ``` pub fn sign(self) -> Self { Self::new(K::sign(self.primitive)) } /// Aggregate all elements in the tensor with the mean operation. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.mean(); /// println!("{tensor}"); /// // [3.6666667] /// } /// ``` pub fn mean(self) -> Tensor { Tensor::new(K::mean(self.primitive)) } /// Aggregate all elements in the tensor with the sum operation. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.sum(); /// println!("{tensor}"); /// // [22.0] /// } /// ``` pub fn sum(self) -> Tensor { Tensor::new(K::sum(self.primitive)) } /// Aggregate all elements along the given *dimension* or *axis* /// in the tensor with the mean operation. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements; /// supports negative indexing. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.clone().mean_dim(0); /// println!("{tensor}"); /// // [[3.0, 3.5, 4.5]] /// let tensor = tensor.clone().mean_dim(1); /// println!("{tensor}"); /// // [[0.6666667], [6.6666665]] /// } /// ``` pub fn mean_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Mean", dim)); Self::new(K::mean_dim(self.primitive, dim)) } /// Aggregate all elements along the given *axes* /// in the tensor with the mean operation. /// /// # Arguments /// /// * `dims` - the dimensions to aggregate; supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[2.0, 4.0], [6.0, -4.0]], &device); /// let tensor = tensor.clone().mean_dims(&[0, 1]); /// println!("{tensor}"); /// // [[2.0]] /// } /// ``` pub fn mean_dims(self, dims: &[I]) -> Self { dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim)) } /// Aggregate all elements along the given *dimension* or *axis* /// in the tensor with the sum operation. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements; /// supports negative indexing. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.clone().sum_dim(0); /// println!("{tensor}"); /// // [[6.0, 7.0, 9.0]] /// let tensor = tensor.clone().sum_dim(1); /// println!("{tensor}"); /// // [[2.0], [20.0]] /// } /// ``` pub fn sum_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Sum", dim)); Self::new(K::sum_dim(self.primitive, dim)) } /// Aggregate all elements along the given *axes* /// in the tensor with the sum operation. /// /// # Arguments /// /// * `dims` - the dimensions to aggregate; supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.clone().sum_dims(&[0, 1]); /// println!("{tensor}"); /// // [[27]] /// } /// ``` pub fn sum_dims(self, dims: &[I]) -> Self { dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim)) } /// Aggregate and squeeze along the given dimensions. /// /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)`` /// /// # Arguments /// /// * `dims` - the dimensions to aggregate; supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([ /// [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], /// [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]], /// ], &device); /// let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]); /// println!("{tensor}"); /// // [20.0, 16.0, 21.0] /// } /// ``` pub fn sum_dims_squeeze(self, dims: &[I]) -> Tensor { // TODO: remove idims when squeeze_dims uses AsIndex. let idims = dims .iter() .map(|&dim| (dim.expect_dim_index(D)) as isize) .collect::>(); self.sum_dims(dims).squeeze_dims::(&idims) } /// Aggregate all elements in the tensor with the product operation. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.prod(); /// println!("{tensor}"); /// // [-1620.0] /// } /// ``` pub fn prod(self) -> Tensor { Tensor::new(K::prod(self.primitive)) } /// Aggregate all elements along the given *dimension* or *axis* /// in the tensor with the product operation. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements, /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimension will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.clone().prod_dim(0); /// println!("{tensor}"); /// // [[5.0, -18.0, 18.0]] /// let tensor = tensor.clone().prod_dim(1); /// println!("{tensor}"); /// // [[-6.0], [270.0]] /// } /// ``` pub fn prod_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Prod", dim)); Self::new(K::prod_dim(self.primitive, dim)) } /// Aggregate all elements along the given *axes* /// in the tensor with the prod operation. /// /// # Arguments /// /// * `dims` - the dimensions to aggregate, supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.clone().sum_dims(&[0, 1]); /// println!("{tensor}"); /// // [[-1620.0]] /// } /// ``` pub fn prod_dims(self, dims: &[I]) -> Self { dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim)) } /// Computes the cumulative sum of elements along the given *dimension* or *axis*. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to compute the cumulative sum. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); /// let result = tensor.clone().cumsum(0); /// println!("{result}"); /// // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]] /// let result = tensor.cumsum(1); /// println!("{result}"); /// // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]] /// } /// ``` pub fn cumsum(self, dim: usize) -> Self { check!(TensorCheck::aggregate_dim::("CumSum", dim)); Self::new(K::cumsum(self.primitive, dim)) } /// Computes the cumulative product of elements along the given *dimension* or *axis*. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to compute the cumulative product. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); /// let result = tensor.clone().cumprod(0); /// println!("{result}"); /// // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]] /// let result = tensor.cumprod(1); /// println!("{result}"); /// // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]] /// } /// ``` pub fn cumprod(self, dim: usize) -> Self { check!(TensorCheck::aggregate_dim::("CumProd", dim)); Self::new(K::cumprod(self.primitive, dim)) } /// Apply element wise absolute value operation. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device); /// let tensor = tensor.abs(); /// println!("{tensor}"); /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]] /// } /// ``` /// /// # Notes /// /// For signed integer dtypes, this operation uses two's-complement wraparound semantics, similar to /// `x.wrapping_abs()`. For example, `abs(i64::MIN) == i64::MIN`. pub fn abs(self) -> Self { Self::new(K::abs(self.primitive)) } /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, /// the other elements of the result tensor out are set to 0. /// /// See also [`triu_mask`](Tensor::triu_mask). /// /// # Arguments /// /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift /// towards the upper triangle. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints( /// [ /// [1, 2, 3], /// [4, 5, 6], /// [7, 8, 9] /// ], /// &device /// ); /// let tensor = tensor.triu(1); /// println!("{tensor}"); /// // [ /// // [0, 2, 3], /// // [0, 0, 6], /// // [0, 0, 0] /// // ] /// } /// ``` pub fn triu(self, diagonal: i64) -> Self { check!(TensorCheck::tri::<{ D }>()); // last two dimensions let shape = &self.shape()[D - 2..].to_owned(); let mask = Tensor::::triu_mask(shape, diagonal, &self.device()).unsqueeze(); self.mask_fill(mask, 0) } /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, /// the other elements of the result tensor out are set to 0. /// /// See also [`tril_mask`](Tensor::tril_mask). /// /// # Arguments /// /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift /// towards the upper triangle. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints( /// [ /// [1, 2, 3], /// [4, 5, 6], /// [7, 8, 9] /// ], /// &device /// ); /// /// let tensor = tensor.tril(-1); /// println!("{tensor}"); /// // [ /// // [0, 0, 0], /// // [4, 0, 0], /// // [7, 8, 0] /// // ] /// } /// ``` pub fn tril(self, diagonal: i64) -> Self { check!(TensorCheck::tri::<{ D }>()); // last two dimensions let shape = &self.shape()[D - 2..].to_owned(); let mask = Tensor::::tril_mask(shape, diagonal, &self.device()).unsqueeze(); self.mask_fill(mask, 0) } /// Applies element wise power operation with a integer Tensor /// /// # Arguments /// /// * `other` - The tensor to apply the power operation with. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, Int}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_ints([[1, -2, 3], [5, 9, 6]], &device); /// let tensor2 = Tensor::::from_ints([[2, 3, 4], [1, 2, 3]], &device); /// let tensor = tensor1.powi(tensor2); /// println!("{tensor}"); /// // [[1, -8, 81], [5, 81, 216]] /// } /// ``` pub fn powi(self, other: Self) -> Self { Self::new(K::powi(self.primitive, other.primitive)) } /// Applies element wise power operation with a integer scalar /// /// # Arguments /// /// * `other` - The scalar to apply the power operation with. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, Int}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_ints([[1, -2, 3], [5, 9, 6]], &device); /// let tensor = tensor.powi_scalar(2); /// println!("{tensor}"); /// /// // [[1, 4, 9], [25, 81, 36]] /// let tensor = Tensor::::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device); /// let tensor = tensor.powi_scalar(2); /// println!("{tensor}"); /// // [[2.25, 4., 9.], [25., 81., 36.]] /// } /// ``` pub fn powi_scalar(self, other: E) -> Self { let other = Scalar::new(other, &self.dtype()); Self::new(K::powi_scalar(self.primitive, other)) } /// Converts the tensor to a boolean tensor by checking if the elements are non-zero. /// /// # Returns /// /// A boolean tensor with the same shape as the input tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device); /// let tensor = tensor.bool(); /// println!("{tensor}"); /// // [ /// // [true, true, true], /// // [false, true, true] /// // ] /// } pub fn bool(self) -> Tensor { self.not_equal_elem(0) } /// Create a random tensor of the given shape on the given device where each element is /// sampled from the given distribution. /// /// See also [`random_like`](Tensor::random_like). /// /// # Arguments /// /// * `shape` - The shape of the tensor. /// * `distribution` - The distribution to sample from. /// * `device` - The device to create the tensor on. /// /// # Returns /// /// A new tensor with the given shape and elements sampled from the given distribution. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape, Distribution}; /// /// fn example() { /// let device = B::Device::default(); /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0 /// let tensor = Tensor::::random(Shape::new([2, 3]), distribution, &device); /// println!("{tensor}"); /// // [ /// // [0.08347523, 0.70498955, 0.60332155], /// // [0.08173251, 0.18028641, 0.97942924] /// // ] /// } /// ``` pub fn random>( shape: S, distribution: Distribution, device: &B::Device, ) -> Self { Self::new(K::random(shape.into(), distribution, device)) } /// Applies the matrix multiplication operation. /// /// ```math /// C = AB /// ``` /// /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general /// matmul, which is often faster. pub fn matmul(self, other: Self) -> Self { check!(TensorCheck::matmul(&self, &other)); if D >= 3 { let batch_index = D - 3; let vector_index = D - 2; let lhs_dims = &self.shape()[batch_index..D]; let rhs_dims = &other.shape()[batch_index..D]; if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims) && k1 == k2 { return Tensor::new(K::matmul( self.swap_dims(batch_index, vector_index).primitive, other.primitive, )) .swap_dims(batch_index, vector_index); } } Tensor::new(K::matmul(self.primitive, other.primitive)) } } impl Tensor where B: Backend, K: Numeric, K::Elem: Element, { /// Calculates the dot product with another tensor. /// /// `y = x2.dot(x1)` /// /// # Arguments /// /// * `other` - The tensor to compute dot product with. /// /// # Notes /// /// Both tensors must have the same number of elements. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([1.0, 2.0], &device); /// let tensor2 = Tensor::::from_data([-2.0, 3.0], &device); /// let tensor = tensor1.dot(tensor2); /// println!("{tensor}"); /// // [4] /// } /// ``` pub fn dot(self, other: Self) -> Self { self.mul(other).sum() } } impl Tensor where B: Backend, K: Numeric, K::Elem: Element, { /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere. /// /// # Arguments /// /// * `size` - The size of the square matrix. pub fn eye(size: usize, device: &B::Device) -> Self { let indices = Tensor::::arange(0..size as i64, device).unsqueeze::<2>(); let ones = Self::ones([1, size], device); let zeros = Self::zeros([size, size], device); zeros.scatter(0, indices, ones, IndexingUpdateOp::Add) } } // Tensor + tensor impl> core::ops::Add for Tensor where K::Elem: Element, { type Output = Self; fn add(self, rhs: Self) -> Self::Output { Self::add(self, rhs) } } // Tensor + scalar impl> core::ops::Add for Tensor where K::Elem: Element, { type Output = Self; fn add(self, other: E) -> Self::Output { Tensor::add_scalar(self, other) } } // Scalar + tensor macro_rules! impl_tensor_scalar_add { ($($t:ty),*) => { $( impl> core::ops::Add> for $t where K::Elem: Element, { type Output = Tensor; fn add(self, tensor: Tensor) -> Self::Output { Tensor::add_scalar(tensor, self) } } )* } } impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64); // Tensor - tensor impl> core::ops::Sub for Tensor where K::Elem: Element, { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { Tensor::sub(self, rhs) } } // Tensor - scalar impl> core::ops::Sub for Tensor where K::Elem: Element, { type Output = Self; fn sub(self, other: E) -> Self::Output { Tensor::sub_scalar(self, other) } } // Scalar - tensor macro_rules! impl_tensor_scalar_sub { ($($t:ty),*) => { $( impl> core::ops::Sub> for $t where K::Elem: Element, { type Output = Tensor; fn sub(self, tensor: Tensor) -> Self::Output { Tensor::add_scalar(Tensor::neg(tensor), self) } } )* } } impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64); // Tensor / tensor impl> core::ops::Div for Tensor where K::Elem: Element, { type Output = Self; fn div(self, rhs: Self) -> Self::Output { Tensor::div(self, rhs) } } // Tensor / scalar impl> core::ops::Div for Tensor where K::Elem: Element, { type Output = Self; fn div(self, other: E) -> Self::Output { Tensor::div_scalar(self, other) } } // Scalar / tensor (float only) macro_rules! impl_tensor_scalar_div { ($($t:ty),*) => { $( impl core::ops::Div> for $t { type Output = Tensor; fn div(self, tensor: Tensor) -> Self::Output { tensor.recip().mul_scalar(self) } } )* } } impl_tensor_scalar_div!(f32, f64); // Tensor % tensor. impl> core::ops::Rem for Tensor where K::Elem: Element, { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { Tensor::remainder(self, rhs) } } // Tensor % scalar. impl> core::ops::Rem for Tensor where K::Elem: Element, { type Output = Self; fn rem(self, other: E) -> Self::Output { Tensor::remainder_scalar(self, other) } } // Tensor * tensor. impl> core::ops::Mul for Tensor where K::Elem: Element, { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { Tensor::mul(self, rhs) } } // Tensor * scalar. impl> core::ops::Mul for Tensor where K::Elem: Element, { type Output = Self; fn mul(self, other: E) -> Self::Output { Tensor::mul_scalar(self, other) } } macro_rules! impl_tensor_scalar_mul { ($($t:ty),*) => { $( impl> core::ops::Mul> for $t where K::Elem: Element, { type Output = Tensor; fn mul(self, other: Tensor) -> Self::Output { Tensor::mul_scalar(other, self) } } )* } } impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64); impl core::ops::Neg for Tensor where B: Backend, K: Numeric, K::Elem: Element, { type Output = Self; fn neg(self) -> Self::Output { Tensor::neg(self) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/options.rs ================================================ use burn_backend::{ Backend, Element, tensor::{BasicOps, Device}, }; use burn_std::DType; use crate::get_device_policy; /// Options for tensor creation. /// /// This struct allows specifying the `device` and overriding the data type when creating a tensor. /// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used. #[derive(Debug, Clone)] pub struct TensorCreationOptions { /// Device where the tensor will be created. pub device: Device, /// Optional data type. /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes). pub dtype: Option, } impl Default for TensorCreationOptions { /// Returns new options with the backend's default device. fn default() -> Self { Self::new(Default::default()) } } impl TensorCreationOptions { /// Create new options with a specific device. /// /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation. pub fn new(device: Device) -> Self { Self { device, dtype: None, } } /// Set the tensor creation data type. pub fn with_dtype(mut self, dtype: DType) -> Self { self.dtype = Some(dtype); self } /// Set the tensor creation device. pub fn with_device(mut self, device: Device) -> Self { self.device = device; self } /// Create options with backend's default device and float dtype. pub fn float() -> Self { Self::default().with_dtype(::dtype()) } /// Create options with backend's default device and int dtype. pub fn int() -> Self { Self::default().with_dtype(::dtype()) } /// Create options with backend's default device and bool dtype. pub fn bool() -> Self { Self::default().with_dtype(::dtype()) } /// Returns the tensor data type, or a provided default if not set. /// /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`. pub fn dtype_or(&self, dtype: DType) -> DType { self.dtype.unwrap_or(dtype) } /// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes). pub(crate) fn resolve_policy>(&self) -> DType { let dtype = K::Elem::dtype(); let kind_name = K::name(); // TODO: tensor kind enum? self.dtype.unwrap_or_else(|| { let policy = get_device_policy(&self.device); if dtype.is_float() && kind_name == "Float" && let Some(float_dtype) = policy.float_dtype() { float_dtype.into() } else if (dtype.is_int() || dtype.is_uint()) && kind_name == "Int" && let Some(int_dtype) = policy.int_dtype() { int_dtype.into() } else { // If policy was not explicitly set, use the fallback dtype (default backend elem type) dtype } }) } } impl From<&Device> for TensorCreationOptions { /// Convenience conversion from a reference to a device. /// /// Example: /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::TensorCreationOptions; /// /// fn example(device: B::Device) { /// let options: TensorCreationOptions = (&device).into(); /// } /// ``` fn from(device: &Device) -> Self { TensorCreationOptions::new(device.clone()) } } impl From<(&Device, DType)> for TensorCreationOptions { /// Convenience conversion for a specified `(&device, dtype)` tuple. fn from(args: (&Device, DType)) -> Self { TensorCreationOptions::new(args.0.clone()).with_dtype(args.1) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/orderable.rs ================================================ use burn_backend::{ Backend, ElementConversion, Scalar, tensor::{Bool, IndexingUpdateOp, Int, Ordered}, }; use burn_std::AsIndex; use crate::check; use crate::{Tensor, check::TensorCheck}; impl Tensor where B: Backend, K: Ordered, { /// Sort the elements by value in ascending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Returns /// /// A new tensor with the elements sorted in ascending order along the given dimension. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let tensor = tensor.sort(0); /// println!("{tensor}"); /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]] /// let tensor = tensor.sort(1); /// println!("{tensor}"); /// // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]] /// } /// ``` pub fn sort(self, dim: usize) -> Self { check!(TensorCheck::sort_dim::("Sort", dim)); Tensor::new(K::sort(self.primitive, dim, /*descending*/ false)) } /// Sort the elements by value in descending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Returns /// /// A new tensor with the elements sorted in descending order along the given dimension. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let tensor = tensor.sort_descending(0); /// println!("{tensor}"); /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]] /// let tensor = tensor.sort_descending(1); /// println!("{tensor}"); /// // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]] /// } /// ``` pub fn sort_descending(self, dim: usize) -> Self { check!(TensorCheck::sort_dim::("Sort", dim)); Tensor::new(K::sort(self.primitive, dim, /*descending*/ true)) } /// Sort the elements by value in ascending order along a given dimension. /// Also returns the indices. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Returns /// /// A tuple containing the sorted tensor and the indices tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let (tensor, indices) = tensor.sort_with_indices(0); /// println!("{tensor}"); /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]] /// println!("{}", indices); /// // [[1, 0, 0], [0, 1, 1]] /// } /// ``` pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor) { check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ false); (Tensor::new(values), Tensor::new(indices)) } /// Sort the elements by value in descending order along a given dimension. /// Also returns the indices. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let (tensor, indices) = tensor.sort_descending_with_indices(0); /// println!("{tensor}"); /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]] /// println!("{}", indices); /// // [[0, 1, 1], [1, 0, 0]] /// } /// ``` pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor) { check!(TensorCheck::sort_dim::("Sort_with_indices", dim)); let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true); (Tensor::new(values), Tensor::new(indices)) } /// Returns the indices that sort the elements by value in ascending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let tensor = tensor.argsort(0); /// println!("{tensor}"); /// // [[1, 0, 0], [0, 1, 1]] /// } /// ``` pub fn argsort(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Argsort", dim)); Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false)) } /// Returns the indices that sort the elements by value in descending order along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). /// /// # Arguments /// /// * `dim` - The dimension to sort along. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let tensor = tensor.argsort_descending(0); /// println!("{tensor}"); /// // [[0, 1, 1], [1, 0, 0]] /// let tensor = tensor.argsort_descending(1); /// println!("{tensor}"); /// // [[0, 2, 1], [2, 0, 1]] /// } /// ``` pub fn argsort_descending(self, dim: usize) -> Tensor { check!(TensorCheck::sort_dim::("Argsort", dim)); Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true)) } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// /// # Arguments /// /// * `k` - The number of elements to return. /// /// # Returns /// /// A new tensor with the `k` largest elements along the given dimension. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let tensor = tensor.topk(2, 0); /// println!("{tensor}"); /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]] /// let tensor = tensor.topk(1, 1); /// println!("{tensor}"); /// // [[12.0], [6.0]] /// } /// ``` pub fn topk(self, k: usize, dim: usize) -> Self { let k_indices = Tensor::arange(0..k as i64, &self.device()); self.sort_descending(dim).select(dim, k_indices) } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. /// /// # Arguments /// /// * `k` - The number of elements to return. /// * `dim` - The dimension to sort along. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// let (tensor, indices) = tensor.topk_with_indices(2, 0); /// println!("{tensor}"); /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]] /// println!("{}", indices); /// // [[0, 1, 1], [1, 0, 0]] /// let (tensor, indices) = tensor.topk_with_indices(1, 1); /// println!("{tensor}"); /// // [[12.0], [6.0]] /// println!("{indices}"); /// // [[0], [2]] /// } /// ``` pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); let (values, indices) = self.sort_descending_with_indices(dim); ( values.select(dim, k_indices.clone()), indices.select(dim, k_indices), ) } /// Create a one hot tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example(){ /// let device = Default::default(); /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); /// let one_hot: Tensor = indices.one_hot(4); /// println!("{}", one_hot.to_data()); /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] /// } /// ``` pub fn one_hot(self, num_classes: usize) -> Tensor { check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); self.one_hot_fill(num_classes, 1.0, 0.0, -1) } /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. /// /// # Arguments /// /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. /// * `on_value`: The value to assign for active positions (corresponding to indices). /// * `off_value`: The value to assign for inactive positions. /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. /// /// # Returns /// /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. /// /// # Example /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Float}; /// fn example>>() { /// let device = B::Device::default(); /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); /// // One-hot encoding /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); /// println!("{tensor}"); /// // [[[5.0, 0.0, 0.0], /// // [0.0, 0.0, 5.0]], /// // [[0.0, 5.0, 0.0], /// // [0.0, 0.0, 5.0]]] /// } /// ``` pub fn one_hot_fill( self, num_classes: usize, on_value: f32, off_value: f32, axis: i64, ) -> Tensor { check!(TensorCheck::one_hot_tensor_rank::()); // Initialize shape from the current tensor dimensions and prepare for modification let mut shape = self.shape(); let device = self.device(); let rank = self.dims().len(); // Adjust negative axis to a positive index let axis = if axis < 0 { axis + rank as i64 + 1 } else { axis }; // Ensure axis is within valid range if axis < 0 || axis > rank as i64 { panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); } // Convert the input tensor to integer indices let indices: Tensor = Tensor::from_data(self.to_data().convert::(), &device); // Insert the new dimension for the one-hot representation shape.insert(axis as usize, num_classes); // Adjust indices to valid range and handle invalid indices let adjusted_indices = indices .clone() .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices // Unsqueeze the indices tensor along the specified axis let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); // Initialize the output tensor with the off_value let output = Tensor::full(shape.clone(), off_value, &device); // Prepare scatter tensor for on_value and off_value adjustments let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); // Scatter on_value at the appropriate indices to create the one-hot representation output.scatter( axis as usize, indices_unsqueezed, scatter_on_values, IndexingUpdateOp::Add, ) } /// Applies element wise greater comparison and returns a boolean tensor. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.greater(tensor2); /// println!("{tensor}"); /// // [[false, false, false], [true, true, true]] /// } /// ``` pub fn greater(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Greater", &self, &other)); Tensor::new(K::greater(self.primitive, other.primitive)) } /// Applies element wise greater-equal comparison and returns a boolean tensor. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.greater_equal(tensor2); /// println!("{tensor}"); /// // [[true, false, false], [true, true, true]] /// } /// ``` pub fn greater_equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other)); Tensor::new(K::greater_equal(self.primitive, other.primitive)) } /// Applies element wise lower comparison and returns a boolean tensor. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.lower(tensor2); /// println!("{tensor}"); /// // [[false, true, true], [false, false, false]] /// } /// ``` pub fn lower(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Lower", &self, &other)); Tensor::new(K::lower(self.primitive, other.primitive)) } /// Applies element wise lower-equal comparison and returns a boolean tensor. /// /// # Panics /// /// If the two tensors don't have the same shape. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.lower_equal(tensor2); /// println!("{tensor}"); /// // [[true, true, true], [false, false, false]] /// } /// ``` pub fn lower_equal(self, other: Self) -> Tensor { check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other)); Tensor::new(K::lower_equal(self.primitive, other.primitive)) } /// Applies greater than `other` comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.greater_elem(3.0); /// println!("{tensor}"); /// // [[false, false, true], [true, true, true]] /// } /// ``` pub fn greater_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::greater_elem(self.primitive, other)) } /// Applies greater-equal than `other` comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.greater_equal_elem(3.0); /// println!("{tensor}"); /// // [[false, false, true], [true, true, true]] /// } /// ``` pub fn greater_equal_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::greater_equal_elem(self.primitive, other)) } /// Applies lower than `other` comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.lower_elem(3.0); /// println!("{tensor}"); /// // [[true, true, false], [false, false, false]] /// } /// ``` pub fn lower_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::lower_elem(self.primitive, other)) } /// Applies lower-equal than `other` comparison and returns a boolean tensor. /// /// # Arguments /// /// * `other` - The element to compare. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.lower_equal_elem(3.0); /// println!("{tensor}"); /// // [[true, true, true], [false, false, false]] /// } /// ``` pub fn lower_equal_elem(self, other: E) -> Tensor { let other = Scalar::new(other, &self.dtype()); Tensor::new(K::lower_equal_elem(self.primitive, other)) } /// Applies the argmax function along the given dimension and returns an integer tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::ones(Shape::new([2, 3, 3]), &device); /// let tensor = tensor.argmax(1); /// println!("{:?}", tensor.shape()); /// // Shape { dims: [2, 1, 3] } /// } /// ``` pub fn argmax(self, dim: usize) -> Tensor { Tensor::new(K::argmax(self.primitive, dim)) } /// Find the maximum value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.max(); /// println!("{tensor}"); /// // [9.0] /// } /// ``` pub fn max(self) -> Tensor { Tensor::new(K::max(self.primitive)) } /// Find the maximum value along the given dimension. /// /// Also returns the indices. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let (tensor, index) = tensor.max_dim_with_indices(0); /// // [[5.0, 9.0, 6.0]] /// println!("{tensor}"); /// // [[1, 1, 1]] /// println!("{index}"); /// } /// ``` pub fn max_dim_with_indices(self, dim: I) -> (Self, Tensor) { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Max", dim)); let (tensor, index) = K::max_dim_with_indices(self.primitive, dim); let tensor = Tensor::new(tensor); let index = Tensor::new(index); (tensor, index) } /// Find the maximum absolute value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device); /// let tensor = tensor.max_abs(); /// println!("{tensor}"); /// // [7.0] /// } /// ``` pub fn max_abs(self) -> Tensor { Tensor::new(K::max_abs(self.primitive)) } /// Finds the maximum pair wise values with another tensor. /// /// # Arguments /// /// * `other` - Other tensor to find maximum elements with /// /// # Returns /// /// A tensor with the same shape as the input tensors containing the maximum value found /// in the input tensors. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.max_pair(tensor2); /// println!("{tensor}"); /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]] /// } /// ``` pub fn max_pair(self, other: Self) -> Self { let mask = self.clone().lower(other.clone()); self.mask_where(mask, other) } /// Find the maximum absolute value along the given dimension. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements, /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimension will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.max_dim(0); /// println!("{tensor}"); /// // [[5.0, 9.0, 6.0]] /// } /// ``` pub fn max_abs_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("MaxAbs", dim)); Tensor::new(K::max_abs_dim(self.primitive, dim)) } /// Find the maximum absolute value along the given dimensions. /// /// # Arguments /// /// * `dims` - The dimensions or axes along which to aggregate the elements, /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.max_abs_dims(&[0, 1]); /// println!("{tensor}"); /// // [[9.0]] /// } /// ``` pub fn max_abs_dims(self, dims: &[I]) -> Self { dims.iter() .fold(self, |tensor, &dim| tensor.max_abs_dim(dim)) } /// Applies the argmin function along the given dimension and returns an integer tensor. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::ones(Shape::new([2, 3, 3]), &device); /// let tensor = tensor.argmin(1); /// println!("{:?}", tensor.shape()); /// // Shape { dims: [2, 1, 3] } /// } /// ``` pub fn argmin(self, dim: usize) -> Tensor { Tensor::new(K::argmin(self.primitive, dim)) } /// Find the minimum value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.min(); /// println!("{tensor}"); /// // [-2.0] /// } /// ``` pub fn min(self) -> Tensor { Tensor::new(K::min(self.primitive)) } /// Find the minimum value along the given dimension. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements; /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimension will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.min_dim(0); /// println!("{tensor}"); /// // [[1.0, -2.0, 3.0]] /// } /// ``` pub fn min_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Min", dim)); Tensor::new(K::min_dim(self.primitive, dim)) } /// Find the minimum value along the given dimensions. /// /// # Arguments /// /// * `dims` - The dimensions or axes along which to aggregate the elements; /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.min_dims(&[0, 1]); /// println!("{tensor}"); /// // [[-2.0]] /// } /// ``` pub fn min_dims(self, dims: &[I]) -> Self { dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim)) } /// Find the minimum value along the given dimension. /// /// Also returns the indices. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let (tensor, index) = tensor.min_dim_with_indices(0); /// println!("{tensor}"); /// // [[5.0, -2.0, 3.0]] /// println!("{}", index); /// // [[1, 0, 0]] /// } /// ``` pub fn min_dim_with_indices(self, dim: I) -> (Self, Tensor) { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Min", dim)); let (tensor, index) = K::min_dim_with_indices(self.primitive, dim); let tensor = Tensor::new(tensor); let index = Tensor::new(index); (tensor, index) } /// Finds the minimum pair wise values with another tensor. /// /// # Arguments /// /// * `other` - Other tensor to find minimum elements with /// /// # Returns /// /// A tensor with the same shape as the input tensors containing the minimum value found /// between each element of the two source tensors. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor1 = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor2 = Tensor::::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device); /// let tensor = tensor1.min_pair(tensor2); /// println!("{tensor}"); /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]] /// } pub fn min_pair(self, other: Self) -> Self { let mask = other.clone().lower(self.clone()); self.mask_where(mask, other) } /// Clamp element wise between the given min and max values. /// /// # Arguments /// /// * `min` - The minimum value. /// * `max` - The maximum value. /// /// # Returns /// /// A new tensor with the values clamped between the given min and max values. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints( /// [ /// [1, 2, 3], /// [4, 5, 6], /// [7, 8, 9] /// ], /// &device); /// let tensor = tensor.clamp(2, 6); /// println!("{tensor}"); /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]] /// } /// ``` pub fn clamp(self, min: E, max: E) -> Self { let dtype = self.dtype(); Self::new(K::clamp( self.primitive, Scalar::new(min, &dtype), Scalar::new(max, &dtype), )) } /// Clamp element wise under a minimum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `min` - The minimum value. /// /// # Returns /// /// A new tensor with the values clamped under the given min value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints( /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]], /// &device); /// let tensor = tensor.clamp_min(4); /// println!("{tensor}"); /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]] /// } /// ``` pub fn clamp_min(self, min: E) -> Self { let min = Scalar::new(min, &self.dtype()); Self::new(K::clamp_min(self.primitive, min)) } /// Clamp element wise over a maximum value. /// /// # Arguments /// /// * `tensor` - The tensor to clamp. /// * `max` - The maximum value. /// /// # Returns /// /// A new tensor with the values clamped over the given max value. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Int, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor = Tensor::::from_ints( /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]], /// &device); /// let tensor = tensor.clamp_max(5); /// println!("{tensor}"); /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]] /// } /// ``` pub fn clamp_max(self, max: E) -> Self { let max = Scalar::new(max, &self.dtype()); Self::new(K::clamp_max(self.primitive, max)) } /// Computes the cumulative minimum of elements along the given *dimension* or *axis*. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to compute the cumulative minimum. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device); /// let result = tensor.clone().cummin(0); /// println!("{result}"); /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]] /// let result = tensor.cummin(1); /// println!("{result}"); /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]] /// } /// ``` pub fn cummin(self, dim: usize) -> Self { check!(TensorCheck::aggregate_dim::("CumMin", dim)); Self::new(K::cummin(self.primitive, dim)) } /// Computes the cumulative maximum of elements along the given *dimension* or *axis*. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to compute the cumulative maximum. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device); /// let result = tensor.clone().cummax(0); /// println!("{result}"); /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]] /// let result = tensor.cummax(1); /// println!("{result}"); /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]] /// } /// ``` pub fn cummax(self, dim: usize) -> Self { check!(TensorCheck::aggregate_dim::("CumMax", dim)); Self::new(K::cummax(self.primitive, dim)) } /// Find the maximum value along the given dimension. /// /// # Arguments /// /// * `dim` - The dimension or axis along which to aggregate the elements; /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimension will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.max_dim(0); /// println!("{tensor}"); /// // [[5.0, 9.0, 6.0]] /// } /// ``` pub fn max_dim(self, dim: I) -> Self { let dim = dim.expect_dim_index(D); check!(TensorCheck::aggregate_dim::("Max", dim)); Tensor::new(K::max_dim(self.primitive, dim)) } /// Find the maximum value along the given dimensions. /// /// # Arguments /// /// * `dims` - The dimensions or axis along which to aggregate the elements; /// supports negative indexing. /// /// # Returns /// /// The returned tensor will have the same rank, /// but the aggregated dimensions will have size 1. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device); /// let tensor = tensor.max_dims(&[0, 1]); /// println!("{tensor}"); /// // [[9.0]] /// } /// ``` pub fn max_dims(self, dims: &[I]) -> Self { dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim)) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/pad.rs ================================================ use alloc::vec::Vec; use core::ops::Range; use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode}; use super::Numeric; /// Trait for types that can be used as padding specifications. /// /// Padding is specified as `(before, after)` pairs per dimension, returned as a /// fixed-size array `[(usize, usize); D]`. If fewer pairs than dimensions are provided, /// they apply to the **last** N dimensions (earlier dimensions are left unpadded). pub trait IntoPadding { /// Converts into a fixed-size array of `(before, after)` padding pairs. fn into_padding(self) -> [(usize, usize); D]; } impl IntoPadding for [(usize, usize); N] { fn into_padding(self) -> [(usize, usize); D] { assert!( N <= D, "Padding has {} pairs but tensor only has {} dimensions", N, D ); let mut result = [(0usize, 0usize); D]; let offset = D - N; for (i, pair) in self.into_iter().enumerate() { result[offset + i] = pair; } result } } /// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions. /// /// Equivalent to `[(top, bottom), (left, right)]`. impl IntoPadding for (usize, usize, usize, usize) { fn into_padding(self) -> [(usize, usize); D] { let (left, right, top, bottom) = self; let mut result = [(0usize, 0usize); D]; result[D - 2] = (top, bottom); result[D - 1] = (left, right); result } } impl IntoPadding for &[(usize, usize)] { fn into_padding(self) -> [(usize, usize); D] { assert!( self.len() <= D, "Padding has {} pairs but tensor only has {} dimensions", self.len(), D ); let mut result = [(0usize, 0usize); D]; let offset = D - self.len(); for (i, &pair) in self.iter().enumerate() { result[offset + i] = pair; } result } } impl IntoPadding for Vec<(usize, usize)> { fn into_padding(self) -> [(usize, usize); D] { assert!( self.len() <= D, "Padding has {} pairs but tensor only has {} dimensions", self.len(), D ); let mut result = [(0usize, 0usize); D]; let offset = D - self.len(); for (i, pair) in self.into_iter().enumerate() { result[offset + i] = pair; } result } } /// Helper to build a range array for slice_assign, selecting a portion of one dimension. fn build_slice_ranges( dims: [usize; D], target_dim: usize, start: usize, len: usize, ) -> [Range; D] { dims.iter() .enumerate() .map(|(i, &size)| { if i == target_dim { start..start + len } else { 0..size } }) .collect::>>() .try_into() .unwrap() } impl Tensor where B: Backend, K: Numeric, K::Elem: Element, { /// Pads the tensor using the specified padding mode. /// /// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions /// are provided, they apply to the **last** N dimensions (unspecified leading dimensions /// are left unpadded). /// /// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted, /// which pads the last two dimensions. /// /// # Arguments /// /// * `padding` - Padding specification. Accepts: /// - `[(before, after); N]` fixed-size array of pairs (N <= D) /// - `&[(before, after)]` slice of pairs per dimension /// - `Vec<(before, after)>` vector of pairs /// - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility /// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`. /// /// # Returns /// /// A new tensor with the specified padding applied. /// /// # Panics /// /// - Panics if more padding pairs are provided than tensor dimensions. /// - `Reflect` mode panics if padding exceeds `dimension_size - 1`. /// - `Edge` mode panics if padding is applied to a zero-sized dimension. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Shape}; /// use burn_tensor::ops::PadMode; /// /// fn example>>() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); /// /// // Constant padding with value 0.0 (backward-compatible tuple) /// let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0)); /// /// // Pad arbitrary dimensions with slice of (before, after) pairs /// let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0)); /// /// // Pad only the last dimension /// let padded = tensor.pad([(1, 1)], PadMode::Reflect); /// } /// ``` pub fn pad(self, padding: impl IntoPadding, mode: impl Into) -> Self { let pairs = padding.into_padding(); match mode.into() { PadMode::Constant(value) => pad_constant(self, &pairs, value), PadMode::Reflect => pad_reflect(self, &pairs), PadMode::Edge => pad_edge(self, &pairs), } } } /// Pad with a constant value. fn pad_constant( tensor: Tensor, padding: &[(usize, usize); D], value: E, ) -> Tensor where B: Backend, K: Numeric, K::Elem: Element, E: ElementConversion, { let mut padded_dims: [usize; D] = tensor.dims(); for (i, &(before, after)) in padding.iter().enumerate() { padded_dims[i] += before + after; } let ranges: [Range; D] = padded_dims .iter() .enumerate() .map(|(i, &dim)| { let (before, after) = padding[i]; before..dim - after }) .collect::>>() .try_into() .unwrap(); let padded_tensor = Tensor::full(padded_dims, value, &tensor.device()); padded_tensor.slice_assign(ranges, tensor) } /// Pad using reflection at the boundaries (excluding edge values). /// /// For ONNX "reflect" mode: mirrors from index 1, not index 0. /// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]` fn pad_reflect( tensor: Tensor, padding: &[(usize, usize); D], ) -> Tensor where B: Backend, K: Numeric, K::Elem: Element, { let dims = tensor.dims(); for (i, &(before, after)) in padding.iter().enumerate() { if before > 0 || after > 0 { assert!( before < dims[i] && after < dims[i], "Reflect padding ({}, {}) must be less than dimension {} size ({})", before, after, i, dims[i] ); } } let mut result = tensor; for (i, &(before, after)) in padding.iter().enumerate() { if before > 0 || after > 0 { result = pad_reflect_dim(result, i, before, after); } } result } /// Helper to pad a single dimension using reflection. fn pad_reflect_dim( tensor: Tensor, dim: usize, pad_before: usize, pad_after: usize, ) -> Tensor where B: Backend, K: Numeric, K::Elem: Element, { let dims = tensor.dims(); let dim_size = dims[dim]; // Calculate output dimensions let mut output_dims = dims; output_dims[dim] += pad_before + pad_after; // Create output tensor and place original in the center let output = Tensor::zeros(output_dims, &tensor.device()); let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size); let mut output = output.slice_assign(original_range, tensor.clone()); // Assign reflected "before" padding (e.g., top or left) // Reflect excludes the edge, so we take indices [1..pad_before+1] and flip if pad_before > 0 { let before_slice = tensor.clone().narrow(dim, 1, pad_before); let before_flipped = before_slice.flip([dim as isize]); let before_range = build_slice_ranges(output_dims, dim, 0, pad_before); output = output.slice_assign(before_range, before_flipped); } // Assign reflected "after" padding (e.g., bottom or right) // Take indices [dim_size - pad_after - 1..dim_size - 1] and flip if pad_after > 0 { let start = dim_size - pad_after - 1; let after_slice = tensor.narrow(dim, start, pad_after); let after_flipped = after_slice.flip([dim as isize]); let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after); output = output.slice_assign(after_range, after_flipped); } output } /// Pad by replicating edge values. /// /// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]` fn pad_edge( tensor: Tensor, padding: &[(usize, usize); D], ) -> Tensor where B: Backend, K: Numeric, K::Elem: Element, { let dims = tensor.dims(); for (i, &(before, after)) in padding.iter().enumerate() { if before > 0 || after > 0 { assert!( dims[i] > 0, "Cannot apply edge padding to zero-sized dimension {}", i ); } } let mut result = tensor; for (i, &(before, after)) in padding.iter().enumerate() { if before > 0 || after > 0 { result = pad_edge_dim(result, i, before, after); } } result } /// Helper to pad a single dimension by replicating edge values. fn pad_edge_dim( tensor: Tensor, dim: usize, pad_before: usize, pad_after: usize, ) -> Tensor where B: Backend, K: Numeric, K::Elem: Element, { let dims = tensor.dims(); let dim_size = dims[dim]; // Calculate output dimensions let mut output_dims = dims; output_dims[dim] += pad_before + pad_after; // Create output tensor and place original in the center let output = Tensor::zeros(output_dims, &tensor.device()); let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size); let mut output = output.slice_assign(original_range, tensor.clone()); // Assign "before" padding by repeating the first element if pad_before > 0 { let first_slice = tensor.clone().narrow(dim, 0, 1); let before_pad = first_slice.repeat_dim(dim, pad_before); let before_range = build_slice_ranges(output_dims, dim, 0, pad_before); output = output.slice_assign(before_range, before_pad); } // Assign "after" padding by repeating the last element if pad_after > 0 { let last_slice = tensor.narrow(dim, dim_size - 1, 1); let after_pad = last_slice.repeat_dim(dim, pad_after); let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after); output = output.slice_assign(after_range, after_pad); } output } ================================================ FILE: crates/burn-tensor/src/tensor/api/take.rs ================================================ use crate::{AsIndex, BasicOps, Int, Tensor, backend::Backend, check, check::TensorCheck}; use alloc::vec::Vec; impl Tensor where B: Backend, K: BasicOps, { /// Takes elements from the tensor along the given dimension using indices of any dimensionality. /// /// This behaves like numpy's take function. When indices is multi-dimensional, /// the output shape will be: input.shape\[:dim\] + indices.shape + input.shape\[dim+1:\] /// /// # Arguments /// /// * `dim` - The dimension along which to select elements. Supports negative indexing. /// * `indices` - The indices of elements to select. Can be any dimensionality. /// Must be valid indices in the range [0, dim_size). /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::{Tensor, Int}; /// /// fn example() { /// let device = B::Device::default(); /// /// // Example with 1D indices /// let tensor = Tensor::::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); /// let indices = Tensor::::from_data([2, 0, 1], &device); /// let result: Tensor = tensor.clone().take::<1, 2>(-1, indices); // -1 refers to last dimension /// println!("{result}"); /// // [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]] /// /// // Example with 2D indices - output will have +1 dimension (2D -> 3D) /// let indices_2d = Tensor::::from_data([[0, 2], [1, 0]], &device); /// let result: Tensor = tensor.take::<2, 3>(1, indices_2d); /// println!("{result}"); /// // [[[1.0, 3.0], [2.0, 1.0]], [[4.0, 6.0], [5.0, 4.0]]] /// } /// ``` pub fn take( self, dim: impl AsIndex, indices: Tensor, ) -> Tensor { let dim = dim.expect_dim_index(D); check!(TensorCheck::take::(dim)); // Store the indices shape for reshaping later let indices_shape = indices.shape(); let indices_dims = indices_shape.clone(); // Flatten indices to 1D for processing let indices_flat = indices.reshape([indices_shape.num_elements()]); // Perform the selection with the flattened indices let selected = self.select(dim, indices_flat); // Build the output shape // Output shape = input.shape[:dim] + indices.shape + input.shape[dim+1:] let selected_shape = selected.shape(); let mut new_shape = Vec::with_capacity(DO); // Add dimensions before the selected dimension for i in 0..dim { new_shape.push(selected_shape[i]); } // Add all indices dimensions for &idx_dim in indices_dims.iter() { new_shape.push(idx_dim); } // Add dimensions after the selected dimension for i in (dim + 1)..D { new_shape.push(selected_shape[i]); } // Verify we have the correct number of dimensions assert_eq!( new_shape.len(), DO, "Internal error: shape calculation resulted in {} dims but expected {}", new_shape.len(), DO ); // Convert to fixed-size array for reshape let mut shape_array = [0; DO]; for (i, &s) in new_shape.iter().enumerate() { shape_array[i] = s; } selected.reshape(shape_array) } } ================================================ FILE: crates/burn-tensor/src/tensor/api/transaction.rs ================================================ use super::{BasicOps, Tensor}; use crate::{ TensorData, backend::{Backend, ExecutionError}, ops::TransactionPrimitive, }; use alloc::vec::Vec; #[derive(Default)] /// A transaction can [read](Self::register) multiple tensors at once with a single operation improving /// compute utilization with optimized laziness. /// /// # Example /// /// ```rust,ignore /// let [output_data, loss_data, targets_data] = Transaction::default() /// .register(output) /// .register(loss) /// .register(targets) /// .execute() /// .try_into() /// .expect("Correct amount of tensor data"); /// ``` pub struct Transaction { op: TransactionPrimitive, } impl Transaction { /// Add a [tensor](Tensor) to the transaction to be read. pub fn register>(mut self, tensor: Tensor) -> Self { K::register_transaction(&mut self.op, tensor.into_primitive()); self } /// Executes the transaction synchronously and returns the [data](TensorData) in the same order /// in which they were [registered](Self::register). pub fn execute(self) -> Vec { burn_std::future::block_on(self.execute_async()) .expect("Error while reading data: use `try_execute` to handle error at runtime") } /// Executes the transaction synchronously and returns the [data](TensorData) in the same /// order in which they were [registered](Self::register). /// /// # Returns /// /// Any error that might have occurred since the last time the device was synchronized. pub fn try_execute(self) -> Result, ExecutionError> { burn_std::future::block_on(self.execute_async()) } /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order /// in which they were [registered](Self::register). pub async fn execute_async(self) -> Result, ExecutionError> { self.op.execute_async().await } } ================================================ FILE: crates/burn-tensor/src/tensor/api/trunc.rs ================================================ use crate::{Float, Tensor, TensorPrimitive, backend::Backend}; impl Tensor where B: Backend, { /// Truncates the tensor element-wise, rounding toward zero. /// /// This function returns a new tensor with the same shape as the input tensor, /// where each element is truncated toward zero. For positive values, this is /// equivalent to floor, and for negative values, it's equivalent to ceil. /// /// # Special Cases (IEEE 754 compliant) /// /// - `trunc(±0)` returns ±0 (preserves sign of zero) /// - `trunc(±∞)` returns ±∞ /// - `trunc(NaN)` returns NaN /// /// # Returns /// /// A tensor with the same shape where each element has been truncated toward zero. /// /// # Example /// /// ```rust /// use burn_tensor::backend::Backend; /// use burn_tensor::Tensor; /// /// fn example() { /// let device = B::Device::default(); /// let tensor = Tensor::::from_data([2.3, -1.7, 0.5, -0.5, 3.9], &device); /// let truncated = tensor.trunc(); /// /// // Result: [2.0, -1.0, 0.0, -0.0, 3.0] /// } /// ``` pub fn trunc(self) -> Self { Self::new(TensorPrimitive::Float(B::float_trunc( self.primitive.tensor(), ))) } } ================================================ FILE: crates/burn-tensor/src/tensor/grid/affine_grid.rs ================================================ use crate::ElementConversion; use crate::backend::Backend; use crate::s; use crate::tensor::{Int, Tensor}; use alloc::vec; /// Generate a tensor with homogeonous coordinates of each element's /// transformed location /// /// /// See: /// - [torch.nn.functional.affine_grid](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html) /// /// * `transform` - Transformation with shape (batch_size, 2, 3) /// * `dims` - dimensions as (batch_size, channels, height, width) /// /// # Returns /// /// Tensor with shape (batch_size, height, width, 2), where dim 2 is (x, y) /// All coordinates are broadcast on the batch dim pub fn affine_grid_2d(transform: Tensor, dims: [usize; 4]) -> Tensor { let [batch_size, _c, height, width] = dims; let device = &transform.device(); let x = Tensor::::arange(0..width as i64, device) .reshape([1, width]) .expand([height, width]); let y = Tensor::::arange(0..height as i64, device) .reshape([height, 1]) .expand([height, width]); // from ints (0..(width-1)) and (0..(height-1)), to (-1.0..1.0) let x = x .float() .div_scalar(((width - 1) as f32 / 2.0).elem::()) .sub_scalar((1_f32).elem::()); let y = y .float() .div_scalar(((height - 1) as f32 / 2.0).elem::()) .sub_scalar((1_f32).elem::()); // Broadcast to batch dimension let x = x.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W] let y = y.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W] // Apply affine transform let a_11 = transform.clone().slice(s![.., 0, 0]); let a_12 = transform.clone().slice(s![.., 0, 1]); let trans_x = transform.clone().slice(s![.., 0, 2]); let a_21 = transform.clone().slice(s![.., 1, 0]); let a_22 = transform.clone().slice(s![.., 1, 1]); let trans_y = transform.slice(s![.., 1, 2]); let grid_x = a_11.mul(x.clone()).add(a_12.mul(y.clone())).add(trans_x); let grid_y = a_21.mul(x).add(a_22.mul(y)).add(trans_y); Tensor::stack(vec![grid_x, grid_y], 3) } ================================================ FILE: crates/burn-tensor/src/tensor/grid/meshgrid.rs ================================================ use crate::backend::Backend; use crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos}; use crate::tensor::{BasicOps, Tensor}; use alloc::vec::Vec; /// Return a collection of coordinate matrices for coordinate vectors. /// /// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates /// in one dimension across an N-dimensional grid. /// /// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`: /// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension. /// * In `Dense` mode, output tensors will be expanded to the full grid shape. /// /// Based upon `options.indexing`, the generated coordinate tensors will use either: /// * `Matrix` indexing, where dimensions are in the same order as their cardinality. /// * `Cartesian` indexing; where the first two dimensions are swapped. /// /// See: /// - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html) /// - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html) /// /// # Arguments /// /// * `tensors` - A slice of 1D tensors /// * `options` - the options. /// /// # Returns /// /// A vector of N N-dimensional tensors representing the grid coordinates. pub fn meshgrid( tensors: &[Tensor; N], options: O, ) -> [Tensor; N] where K: BasicOps, O: Into, { let options = options.into(); let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1; let dense = options.sparsity == GridSparsity::Dense; let grid_shape: [usize; N] = tensors .iter() .map(|t| t.dims()[0]) .collect::>() .try_into() .unwrap(); tensors .iter() .enumerate() .map(|(i, tensor)| { let mut coord_tensor_shape = [1; N]; coord_tensor_shape[i] = grid_shape[i]; // Reshape the tensor to have singleton dimensions in all but the i-th dimension let mut tensor = tensor.clone().reshape(coord_tensor_shape); if dense { tensor = tensor.expand(grid_shape); } if swap_dims { tensor = tensor.swap_dims(0, 1); } tensor }) .collect::>() .try_into() .unwrap() } /// Return a coordinate matrix for a given set of 1D coordinate tensors. /// /// Equivalent to stacking a dense matrix `meshgrid`, /// where the stack is along the first or last dimension. /// /// # Arguments /// /// * `tensors`: A slice of 1D tensors. /// * `index_pos`: The position of the index in the output tensor. /// /// # Returns /// /// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``, /// of coordinates, indexed on the first or last dimension. pub fn meshgrid_stack( tensors: &[Tensor; D], index_pos: IndexPos, ) -> Tensor where K: BasicOps, { assert_eq!(D2, D + 1, "D2 ({D2}) != D ({D}) + 1"); let xs: Vec> = meshgrid(tensors, GridOptions::default()) .into_iter() .collect(); let dim = match index_pos { IndexPos::First => 0, IndexPos::Last => D, }; Tensor::stack(xs, dim) } ================================================ FILE: crates/burn-tensor/src/tensor/grid/mod.rs ================================================ mod affine_grid; mod meshgrid; pub use meshgrid::*; pub use affine_grid::*; /// Enum to specify index cardinal layout. #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] pub enum GridIndexing { /// Dimensions are in the same order as the cardinality of the inputs. /// Equivalent to "ij" indexing in NumPy and PyTorch. #[default] Matrix, /// The same as Matrix, but the first two dimensions are swapped. /// Equivalent to "xy" indexing in NumPy and PyTorch. Cartesian, } /// Enum to specify grid sparsity mode. #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] pub enum GridSparsity { /// The grid is fully expanded to the full cartesian product shape. #[default] Dense, /// The grid is sparse, expanded only at the cardinal dimensions. Sparse, } /// Grid policy options. #[derive(new, Default, Debug, Copy, Clone)] pub struct GridOptions { /// Indexing mode. pub indexing: GridIndexing, /// Sparsity mode. pub sparsity: GridSparsity, } impl From for GridOptions { fn from(value: GridIndexing) -> Self { Self { indexing: value, ..Default::default() } } } impl From for GridOptions { fn from(value: GridSparsity) -> Self { Self { sparsity: value, ..Default::default() } } } /// Enum to specify the index dimension position. #[derive(Default, Debug, Copy, Clone)] pub enum IndexPos { /// The index is in the first dimension. #[default] First, /// The index is in the last dimension. Last, } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/cosine_similarity.rs ================================================ use crate::ElementConversion; use crate::backend::Backend; use crate::tensor::Tensor; use super::l2_norm; /// Default epsilon value to avoid division by zero pub const DEFAULT_EPSILON: f64 = 1e-8; /// Computes the cosine similarity between two tensors along a specified dimension. /// /// Calculates the cosine of the angle between inputs as their dot product divided /// by the product of their L2 norms. /// /// # Arguments /// /// * `x1` - First input tensor /// * `x2` - Second input tensor /// * `dim` - Dimension along which to compute the similarity /// (negative indices allowed: -1 for last dimension) /// * `eps` - Small value to avoid division by zero (default: 1e-8) /// /// # Returns /// /// Tensor containing the cosine similarity between x1 and x2 pub fn cosine_similarity( x1: Tensor, x2: Tensor, dim: i32, eps: Option, ) -> Tensor { let eps = eps.unwrap_or_else(|| B::FloatElem::from_elem(DEFAULT_EPSILON)); // Convert negative dimension to positive let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize; // Compute dot product: sum(x1 * x2) along the specified dimension let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx); // Compute L2 norms: ||x1|| and ||x2|| let norm_x1 = l2_norm(x1, dim_idx); let norm_x2 = l2_norm(x2, dim_idx); // Calculate the denominator (product of the norms) with epsilon to avoid division by zero let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps); // Return the cosine similarity (dot product divided by the product of norms) dot_product / denominator } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/diag.rs ================================================ use crate::backend::Backend; use crate::check; use crate::check::TensorCheck; use crate::tensor::{Int, Tensor}; use crate::{BasicOps, TensorKind}; /// Returns the diag of a matrix. /// /// For batched inputs, returns of each matrix in the batch independently. /// /// The diag operation extracts the diagonal elements of the last two dimensions, /// treating them as the matrix dimensions, while preserving all leading batch dimensions. /// /// # Arguments /// /// * `tensor` - The input tensor with at least 2 dimensions. /// /// # Returns /// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input. pub fn diag( tensor: Tensor, ) -> Tensor where K: TensorKind + BasicOps, { check!(TensorCheck::diag::()); let shape = tensor.shape(); let rows = shape[D - 2]; let cols = shape[D - 1]; let diag_len = rows.min(cols); let device = tensor.device(); // create the indices for the diag let mut flat_shape = shape.clone(); flat_shape[D - 2] = rows * cols; flat_shape[D - 1] = 1; let flat: Tensor = tensor.reshape(flat_shape); let range = Tensor::::arange(0..diag_len as i64, &device); let step_tensor = Tensor::::from_data([cols as i64 + 1], &device); let indices = range * step_tensor; flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1) } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/lu_decomposition.rs ================================================ use crate::{ Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s, tensor::Tensor, }; /// Performs PLU decomposition of a square matrix. /// /// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`, /// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`. /// The permutation vector `p` represents the row swaps made during the decomposition process. /// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used /// during the elimination process below the diagonal. The upper triangular matrix `U` contains /// the resulting upper triangular form of the matrix after the elimination process. /// /// # Arguments /// * `tensor` - A square matrix to decompose, represented as a 2D tensor. /// /// # Returns /// A tuple containing: /// - A 2D tensor representing the combined `L` and `U` matrices. /// - A 1D tensor representing the permutation vector `p`. /// /// # Panics and numerical issues /// - The function will panic if the input matrix is singular or near-singular. /// - The function will panic if the input matrix is not square. /// # Performance note (synchronization / device transfers) /// This function may involve multiple synchronizations and device transfers, especially /// when determining pivot elements and performing row swaps. This can impact performance, pub fn lu_decomposition(tensor: Tensor) -> (Tensor, Tensor) { check!(TensorCheck::is_square::<2>( "lu_decomposition", &tensor.shape() )); let dims = tensor.shape().dims::<2>(); let n = dims[0]; let mut permutations = Tensor::arange(0..n as i64, &tensor.device()); let mut tensor = tensor; for k in 0..n { // Find the pivot row let p = tensor .clone() .slice(s![k.., k]) .abs() .argmax(0) .into_scalar() .to_usize() + k; let max = tensor.clone().slice(s![p, k]).abs(); // Avoid division by zero let pivot = max.into_scalar(); check!(TensorCheck::lu_decomposition_pivot::(pivot)); if p != k { tensor = swap_slices(tensor, s![k, ..], s![p, ..]); permutations = swap_slices(permutations, s![k], s![p]); } // Normalize k-th column under the diagonal if k < n - 1 { let a_kk = tensor.clone().slice(s![k, k]); let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk; tensor = tensor.slice_assign(s![(k + 1).., k], column); } // Update the trailing submatrix for i in (k + 1)..n { // a[i, k+1..] -= a[i, k] * a[k, k+1..] let a_ik = tensor.clone().slice(s![i, k]); let row_k = tensor.clone().slice(s![k, (k + 1)..]); let update = a_ik * row_k; let row_i = tensor.clone().slice(s![i, (k + 1)..]); tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update); } } (tensor, permutations) } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/matvec.rs ================================================ use crate::Numeric; use crate::backend::Backend; use crate::tensor::{BasicOps, Shape, Tensor}; /// Performs matrix-vector multiplication with optional batch dimensions. /// /// The `matrix` tensor is expected to have rank `DM` with the last two dimensions representing /// the matrix rows and columns. The `vector` tensor should have rank `DV = DM - 1`, sharing /// broadcast-compatible batch dimensions and matching the last dimension of the matrix. /// /// # Panics /// /// * If the matrix rank is lower than 2. /// * If the vector rank isn't one less than the matrix rank. /// * If batch dimensions differ between the operands. /// * If the inner dimensions are incompatible for multiplication. pub fn matvec( matrix: Tensor, vector: Tensor, ) -> Tensor where K: BasicOps + Numeric, { assert!( DM >= 2, "matvec expects the matrix to be at least rank 2 (got {DM})" ); assert!( DM == DV + 1, "matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})", ); let matrix_dims = matrix.shape().dims::(); let vector_dims = vector.shape().dims::(); // Validate batch dimensions (all leading dimensions prior to the matrix axes). let batch_rank = DM.saturating_sub(2); if batch_rank > 0 { let matrix_batch = Shape::from(&matrix_dims[..batch_rank]); let vector_batch = Shape::from(&vector_dims[..batch_rank]); assert!( matrix_batch.broadcast(&vector_batch).is_ok(), "Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}", &matrix_dims[..batch_rank], &vector_dims[..batch_rank] ); } let matrix_inner = matrix_dims[DM - 1]; let vector_inner = vector_dims[DV - 1]; assert!( matrix_inner == vector_inner, "Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries", ); let vector_expanded = vector.unsqueeze_dim::(DV); matrix.matmul(vector_expanded).squeeze_dim::(DM - 1) } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/mod.rs ================================================ mod cosine_similarity; mod diag; mod lu_decomposition; mod matvec; mod outer; mod trace; mod vector_norm; pub use cosine_similarity::*; pub use diag::*; pub use lu_decomposition::*; pub use matvec::*; pub use outer::*; pub use trace::*; pub use vector_norm::*; use crate::{BasicOps, SliceArg, Tensor, TensorKind, backend::Backend}; /// Swaps two slices of a tensor. /// # Arguments /// * `tensor` - The input tensor. /// * `slices1` - The first slice to swap. /// * `slices2` - The second slice to swap. /// # Returns /// A new tensor with the specified slices swapped. /// # Notes /// This method will be useful for matrix factorization algorithms. fn swap_slices( tensor: Tensor, slices1: S, slices2: S, ) -> Tensor where S: SliceArg + Clone, K: TensorKind + BasicOps, { let temporary = tensor.clone().slice(slices1.clone()); let tensor = tensor .clone() .slice_assign(slices1, tensor.slice(slices2.clone())); tensor.slice_assign(slices2, temporary) } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/outer.rs ================================================ use crate::backend::Backend; use crate::tensor::{BasicOps, Tensor}; use crate::{AsIndex, Numeric}; /// Computes the outer product for the last columns of 2 tensors. /// /// See also: [`outer_dim`]. /// /// # Arguments /// - `lhs`: the "row" tensor, with shape ``[..., i]``. /// - `rhs`: the "col" tensor, with shape ``[..., j]``. /// /// # Returns /// /// A tensor of rank `R = D + 1`, where: /// /// `` /// result[..., i, j] = lhs[..., i] * rhs[..., j] /// `` pub fn outer( lhs: Tensor, rhs: Tensor, ) -> Tensor where K: BasicOps + Numeric, { outer_dim(lhs, rhs, -1) } /// Computes the outer product along a specific dimension, broadcasting over others. /// /// For the given `dim`, computes the outer product of elements along that dimension, /// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``. /// /// # Arguments /// /// - `lhs`: left operand, the "row" tensor, with size `M` at dimension `dim`. /// - `rhs`: right operand, the "col" tensor, with size `N` at dimension `dim`. /// - `dim`: dimension to compute the outer product along (supports negative indexing). /// /// # Returns /// /// A tensor of rank `R = D + 1`, where: /// /// `` /// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...] /// `` // // Notes: // - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant // than broadcasted elemwise multiply; benchmarking needed to confirm. pub fn outer_dim( lhs: Tensor, rhs: Tensor, dim: Dim, ) -> Tensor where K: BasicOps + Numeric, { assert_eq!( R, D + 1, "`outer` with D={D} expects R={} (got R={R})", D + 1 ); let dim = dim.expect_dim_index(D); // (..., i, 1, ...) let x = lhs.unsqueeze_dim::(dim + 1); // (..., 1, j, ...) let y = rhs.unsqueeze_dim::(dim); // (..., i, j, ...) x * y } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/trace.rs ================================================ use super::diag; use crate::backend::Backend; use crate::tensor::Tensor; /// Computes the trace of a matrix. /// /// For batched inputs, computes the trace of each matrix in the batch independently. /// /// The trace operation sums the diagonal elements of the last two dimensions, /// treating them as the matrix dimensions, while preserving all leading batch dimensions. /// /// # Arguments /// /// * `tensor` - The input tensor with at least 2 dimensions. /// /// # Returns /// /// A tensor of rank `D - 1`, where the last dimension contains the sum along the diagonals /// of the input. pub fn trace(tensor: Tensor) -> Tensor { let diag_tensor = diag::<_, D, DO, _>(tensor); diag_tensor.sum_dim(DO - 1) } ================================================ FILE: crates/burn-tensor/src/tensor/linalg/vector_norm.rs ================================================ use burn_backend::tensor::Ordered; use crate::backend::Backend; use crate::tensor::{BasicOps, Tensor}; use crate::{ElementConversion, Numeric}; #[allow(unused_imports)] use num_traits::float::Float; /// Specifies the type of norm to compute. #[derive(Debug, Clone, Copy, PartialEq)] pub enum Norm { /// L0 norm (count of non-zero elements) L0, /// L1 norm (sum of absolute values) L1, /// L2 norm (Euclidean norm) L2, /// L:INFINITY norm (maximum absolute value) LInf, /// L:NEG_INFINITY norm (minimum absolute value) LNegInf, /// Lp norm (generalized norm) Lp(f64), } impl Norm { /// Get the exponent of the norm. pub fn to_exponent(self) -> f64 { use Norm::*; match self { L0 => 0.0, L1 => 1.0, L2 => 2.0, LInf => f64::INFINITY, LNegInf => f64::NEG_INFINITY, Lp(p) => p, } } } impl From for Norm { fn from(value: u32) -> Self { use Norm::*; match value { 0 => L0, 1 => L1, 2 => L2, u32::MAX => LInf, _ => Lp(value as f64), } } } impl From for Norm { fn from(value: i32) -> Self { use Norm::*; match value { 0 => L0, 1 => L1, 2 => L2, i32::MAX => LInf, i32::MIN => LNegInf, _ => Lp(value as f64), } } } impl From for Norm { fn from(value: f32) -> Self { use Norm::*; match value { 0.0 => L0, 1.0 => L1, 2.0 => L2, f32::INFINITY => LInf, f32::NEG_INFINITY => LNegInf, _ => Lp(value as f64), } } } impl From for Norm { fn from(value: f64) -> Self { use Norm::*; match value { 0.0 => L0, 1.0 => L1, 2.0 => L2, f64::INFINITY => LInf, f64::NEG_INFINITY => LNegInf, _ => Lp(value), } } } /// Computes the vector norm of a tensor along a specified dimension. /// /// Generic dispatch wrapper over specialized / optimized norms. /// /// See: /// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html) /// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html) /// /// # Arguments /// /// * `x` - The input tensor. /// * `norm` - The selected norm. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The vector norm of the input tensor. pub fn vector_norm( x: Tensor, norm: impl Into, dim: usize, ) -> Tensor { lp_norm(x, norm.into().to_exponent(), dim) } /// Computes the general ``L(p)`` norm of a tensor along a specified dimension. /// /// Uses the specialized implementations for: /// * 0.0 /// * 1.0 /// * 2.0 /// * 2 * N for integral N, /// * f64::INFINITY, /// * f64::NEG_INFINITY, /// /// # Arguments /// /// * `x` - The input tensor. /// * `p` - The exponent of the Lp norm. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The ``L(p)`` norm of the input tensor. pub fn lp_norm(x: Tensor, p: f64, dim: usize) -> Tensor { match p { 0.0 => l0_norm(x, dim), 1.0 => l1_norm(x, dim), 2.0 => l2_norm(x, dim), p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim), f64::INFINITY => max_abs_norm(x, dim), f64::NEG_INFINITY => min_abs_norm(x, dim), _ => lp_norm_base(x, p, dim), } } /// Normalize a tensor versus its `vector_norm`. /// /// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``. /// /// # Arguments /// /// * `x` - The input tensor. /// * `norm` - The selected norm. /// * `dim` - The dimension to compute the norm over. /// * `eps` - The epsilon for the norm. /// /// # Returns /// /// The normalized tensor. pub fn vector_normalize( x: Tensor, norm: impl Into, dim: usize, eps: E, ) -> Tensor { let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps); x / norm } /// Computes the L0 norm of a tensor along a specified dimension. /// /// # Arguments /// /// * `x` - The input tensor. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The L0 norm of the input tensor. pub fn l0_norm(x: Tensor, dim: usize) -> Tensor where K: BasicOps + Numeric, { x.zeros_like() .mask_fill(x.not_equal_elem(0), 1) .sum_dim(dim) } /// Computes the L1 norm of a tensor along a specified dimension. /// /// This is a convenience function that wraps `vector_norm` with `p = 1.0`. /// /// # Arguments /// /// * `x` - The input tensor. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The L1 norm of the input tensor. pub fn l1_norm(x: Tensor, dim: usize) -> Tensor where K: BasicOps + Numeric, { x.abs().sum_dim(dim) } /// Computes the L2 norm of a tensor along a specified dimension. /// /// # Arguments /// /// * `x` - The input tensor. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The L2 norm of the input tensor. pub fn l2_norm(x: Tensor, dim: usize) -> Tensor { x.square().sum_dim(dim).sqrt() } fn is_even_integer(x: f64) -> bool { x.fract() == 0.0 && (x as i64) % 2 == 0 } /// Computes ``L(2*n)`` for even integer ``n``. /// /// This lets us skip the abs. fn lp_signed_norm(x: Tensor, p: u32, dim: usize) -> Tensor { x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64)) } /// Computes the general ``L(p)`` using the generalized method. /// /// This uses no specialized implementations and cannot handle: /// * 0.0 /// * f64::INFINITY, /// * f64::NEG_INFINITY, fn lp_norm_base(x: Tensor, p: f64, dim: usize) -> Tensor { x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p) } /// Computes the L:INFINITY norm of a tensor along a specified dimension. /// /// # Arguments /// /// * `x` - The input tensor. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The L:INFINITY norm of the input tensor. pub fn max_abs_norm( x: Tensor, dim: usize, ) -> Tensor where K: Ordered, { x.max_abs_dim(dim) } /// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension. /// /// # Arguments /// /// * `x` - The input tensor. /// * `dim` - The dimension to compute the norm over. /// /// # Returns /// /// The L:NEG_INFINITY norm of the input tensor. pub fn min_abs_norm( x: Tensor, dim: usize, ) -> Tensor where K: Ordered, { x.abs().min_dim(dim) } ================================================ FILE: crates/burn-tensor/src/tensor/loss/mod.rs ================================================ use crate::backend::Backend; use crate::{Tensor, activation}; /// Computes the log softmax cross entropy between logits and target probabilities. /// /// # Arguments /// /// * `logits` - The logits. /// * `target_probs` - The target probabilities. /// /// # Returns /// /// The log softmax cross entropy. pub fn cross_entropy_with_logits( logits: Tensor, target_probs: Tensor, ) -> Tensor { let tensor = activation::log_softmax(logits, D - 1); let tensor = tensor.mul(target_probs); let tensor = tensor.sum_dim(D - 1); tensor.mean().neg() } ================================================ FILE: crates/burn-tensor/src/tensor/mod.rs ================================================ pub(crate) mod stats; mod api; pub use api::*; // Re-exported types pub use burn_backend::{ BoolDType, BoolStore, DType, DataError, FloatDType, IntDType, TensorData, TensorMetadata, TensorPrimitive, Tolerance, distribution::*, element::*, indexing::*, ops::TransactionPrimitive, shape::*, slice::*, tensor::{Bool, Float, Int, TensorKind}, }; /// The activation module. pub mod activation; /// The backend module. pub mod backend { pub use burn_backend::backend::*; } /// The container module. pub mod container { pub use burn_backend::tensor::TensorContainer; } /// The grid module. pub mod grid; /// The linalg module. pub mod linalg; /// The loss module. pub mod loss; /// The neural network module. pub mod module; /// Operations on tensors module. pub mod ops { pub use burn_backend::backend::ops::*; pub use burn_backend::tensor::{ BoolElem, BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor, }; } /// Tensor quantization module. pub mod quantization; #[cfg(feature = "std")] pub use report::*; #[cfg(feature = "std")] mod report; pub use ops::Device; // Re-export device so that it's available from `burn_tensor::Device`. ================================================ FILE: crates/burn-tensor/src/tensor/module.rs ================================================ use crate::{ Bool, Int, Tensor, TensorPrimitive, backend::Backend, check, check::TensorCheck, ops::{ AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode, PaddedConvOptions, UnfoldOptions, }, }; use super::ops::DeformConvOptions; /// Applies the [embedding module](crate::ops::ModuleOps::embedding). pub fn embedding(weights: Tensor, indices: Tensor) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::embedding( weights.primitive.tensor(), indices.primitive, ))) } /// Applies a [1D convolution](crate::ops::ModuleOps::conv1d). /// /// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for /// asymmetric padding. When asymmetric padding is specified, an explicit pad /// operation is applied before the convolution backend op. pub fn conv1d( x: Tensor, weight: Tensor, bias: Option>, options: impl Into>, ) -> Tensor where B: Backend, { let padded_options = options.into(); check!(TensorCheck::conv( "conv1d", x.dims(), weight.dims(), padded_options.options.groups, )); if let Some(padding_end) = padded_options.padding_end { let left = padded_options.options.padding[0]; let right = padding_end[0]; // For 1D (NCL format), pad the length dimension let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0)); let zero_options = ConvOptions::new( padded_options.options.stride, [0], padded_options.options.dilation, padded_options.options.groups, ); Tensor::new(TensorPrimitive::Float(B::conv1d( padded.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), zero_options, ))) } else { Tensor::new(TensorPrimitive::Float(B::conv1d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), padded_options.options, ))) } } /// Applies a [2D convolution](crate::ops::ModuleOps::conv2d). /// /// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for /// asymmetric padding. When asymmetric padding is specified, an explicit pad /// operation is applied before the convolution backend op. pub fn conv2d( x: Tensor, weight: Tensor, bias: Option>, options: impl Into>, ) -> Tensor where B: Backend, { let padded_options = options.into(); check!(TensorCheck::conv( "conv2d", x.dims(), weight.dims(), padded_options.options.groups, )); if let Some(padding_end) = padded_options.padding_end { let top = padded_options.options.padding[0]; let left = padded_options.options.padding[1]; let bottom = padding_end[0]; let right = padding_end[1]; // For 2D (NCHW format), pad height and width let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0)); let zero_options = ConvOptions::new( padded_options.options.stride, [0, 0], padded_options.options.dilation, padded_options.options.groups, ); Tensor::new(TensorPrimitive::Float(B::conv2d( padded.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), zero_options, ))) } else { Tensor::new(TensorPrimitive::Float(B::conv2d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), padded_options.options, ))) } } /// Applies a [3D convolution](crate::ops::ModuleOps::conv3d). /// /// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for /// asymmetric padding. Asymmetric 3D padding is not yet supported. pub fn conv3d( x: Tensor, weight: Tensor, bias: Option>, options: impl Into>, ) -> Tensor where B: Backend, { let padded_options = options.into(); check!(TensorCheck::conv( "conv3d", x.dims(), weight.dims(), padded_options.options.groups, )); if padded_options.is_asymmetric() { panic!("Asymmetric padding is not yet supported for conv3d"); } Tensor::new(TensorPrimitive::Float(B::conv3d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), padded_options.options, ))) } /// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d). pub fn deform_conv2d( x: Tensor, offset: Tensor, weight: Tensor, mask: Option>, bias: Option>, options: DeformConvOptions<2>, ) -> Tensor where B: Backend, { check!(TensorCheck::conv( "deform_conv2d", x.dims(), weight.dims(), options.weight_groups, )); Tensor::new(TensorPrimitive::Float(B::deform_conv2d( x.primitive.tensor(), offset.primitive.tensor(), weight.primitive.tensor(), mask.map(|m| m.primitive.tensor()), bias.map(|b| b.primitive.tensor()), options, ))) } /// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d). pub fn conv_transpose1d( x: Tensor, weight: Tensor, bias: Option>, options: ConvTransposeOptions<1>, ) -> Tensor where B: Backend, { check!(TensorCheck::conv_transpose( "conv_transpose1d", x.dims(), weight.dims(), )); Tensor::new(TensorPrimitive::Float(B::conv_transpose1d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), options, ))) } /// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d). pub fn conv_transpose2d( x: Tensor, weight: Tensor, bias: Option>, options: ConvTransposeOptions<2>, ) -> Tensor where B: Backend, { check!(TensorCheck::conv_transpose( "conv_transpose2d", x.dims(), weight.dims(), )); Tensor::new(TensorPrimitive::Float(B::conv_transpose2d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), options, ))) } /// Applies a 3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d). pub fn conv_transpose3d( x: Tensor, weight: Tensor, bias: Option>, options: ConvTransposeOptions<3>, ) -> Tensor where B: Backend, { check!(TensorCheck::conv_transpose( "conv_transpose3d", x.dims(), weight.dims(), )); Tensor::new(TensorPrimitive::Float(B::conv_transpose3d( x.primitive.tensor(), weight.primitive.tensor(), bias.map(|b| b.primitive.tensor()), options, ))) } /// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d). pub fn unfold4d(x: Tensor, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::unfold4d( x.primitive.tensor(), kernel_size, options, ))) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d( x: Tensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::max_pool1d( x.primitive.tensor(), kernel_size, stride, padding, dilation, ceil_mode, ))) } /// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d). pub fn max_pool2d( x: Tensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::max_pool2d( x.primitive.tensor(), kernel_size, stride, padding, dilation, ceil_mode, ))) } /// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d). pub fn avg_pool2d( x: Tensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], count_include_pad: bool, ceil_mode: bool, ) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::avg_pool2d( x.primitive.tensor(), kernel_size, stride, padding, count_include_pad, ceil_mode, ))) } /// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d). pub fn avg_pool1d( x: Tensor, kernel_size: usize, stride: usize, padding: usize, count_include_pad: bool, ceil_mode: bool, ) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::avg_pool1d( x.primitive.tensor(), kernel_size, stride, padding, count_include_pad, ceil_mode, ))) } /// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d). pub fn max_pool1d_with_indices( x: Tensor, kernel_size: usize, stride: usize, padding: usize, dilation: usize, ceil_mode: bool, ) -> (Tensor, Tensor) where B: Backend, { let output = B::max_pool1d_with_indices( x.primitive.tensor(), kernel_size, stride, padding, dilation, ceil_mode, ); ( Tensor::new(TensorPrimitive::Float(output.output)), Tensor::new(output.indices), ) } /// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices). pub fn max_pool2d_with_indices( x: Tensor, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2], ceil_mode: bool, ) -> (Tensor, Tensor) where B: Backend, { let output = B::max_pool2d_with_indices( x.primitive.tensor(), kernel_size, stride, padding, dilation, ceil_mode, ); ( Tensor::new(TensorPrimitive::Float(output.output)), Tensor::new(output.indices), ) } /// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d). pub fn adaptive_avg_pool2d(x: Tensor, output_size: [usize; 2]) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d( x.primitive.tensor(), output_size, ))) } /// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d). pub fn adaptive_avg_pool1d(x: Tensor, output_size: usize) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d( x.primitive.tensor(), output_size, ))) } /// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate). pub fn interpolate( x: Tensor, output_size: [usize; 2], options: InterpolateOptions, ) -> Tensor where B: Backend, { Tensor::new(TensorPrimitive::Float(B::interpolate( x.primitive.tensor(), output_size, options, ))) } /// Applies a linear transformation to the input tensor using the given weight and bias. /// /// ```math /// y = x @ weight + [bias] /// ``` /// /// # Arguments: /// /// - `input` is the input tensor, ``[..., d_input]``. /// - `weight` is the weight tensor, ``[d_input, d_output]``. /// - `bias` is the bias tensor (optional), ``[d_output]``. /// /// # Returns: /// /// The transformed tensor, ``[..., d_output]``. /// /// # Compatibility /// /// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not /// transpose the weight matrix. In PyTorch, the weight matrix is transposed before /// multiplication: /// /// ```math /// y = x @ weight^T + [bias] /// ``` pub fn linear( input: Tensor, weight: Tensor, bias: Option>, ) -> Tensor { if D == 1 { // Insert and remove an extra batch dimension for the batch matmul to work. let input = input.unsqueeze::<2>(); let output = linear(input, weight, bias); return output.squeeze_dim(0); } // Perform broadcasting // // Important to be done before doing operations to easily fuse. let weight = weight.unsqueeze::(); let bias = bias.map(|bias| bias.unsqueeze::()); let output = input.matmul(weight); match bias { Some(bias) => output.add(bias), None => output, } } /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V, /// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`). /// Optionally applies masking, additive bias, causal masking, and softcap. /// /// # Arguments /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]` /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]` /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]` /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`, /// where `true` indicates positions to mask (i.e. set to -inf before softmax). /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]` /// added to the attention scores before softmax (e.g. ALiBi, relative position biases). /// - `options`: Additional attention options (custom scale, softcap, causal masking). /// /// # Returns /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]` /// representing the attended context per head. /// /// # Note /// This implementation does not support dropout and is intended for inference or /// use cases where dropout is not needed. pub fn attention( query: Tensor, key: Tensor, value: Tensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> Tensor { Tensor::new(TensorPrimitive::Float(B::attention( query.primitive.tensor(), key.primitive.tensor(), value.primitive.tensor(), mask.map(|mask| mask.primitive), attn_bias.map(|bias| bias.primitive.tensor()), options, ))) } /// Exports attention fallback to test backend's attention against. pub fn attention_fallback( query: Tensor, key: Tensor, value: Tensor, mask: Option>, attn_bias: Option>, options: AttentionModuleOptions, ) -> Tensor { Tensor::new(TensorPrimitive::Float( crate::ops::attention::attention_fallback::( query.primitive.tensor(), key.primitive.tensor(), value.primitive.tensor(), mask.map(|mask| mask.primitive), attn_bias.map(|bias| bias.primitive.tensor()), options, ), )) } ================================================ FILE: crates/burn-tensor/src/tensor/quantization.rs ================================================ use crate::{Tensor, TensorPrimitive, backend::Backend}; use burn_backend::tensor::quantization; // We re-export those types. pub use burn_backend::{QTensorPrimitive, quantization::*}; /// The tensor quantization parameters. pub type QuantizationParameters = QParams>; /// The observed input calibration range. #[derive(Clone, Debug)] pub struct CalibrationRange { /// Minimum observed value(s). pub min: Tensor, /// Maximum observed value(s). pub max: Tensor, } /// Compute the quantization range mapping. pub fn compute_range( scheme: &QuantScheme, tensor: &Tensor, calibration: &Calibration, ) -> CalibrationRange { let (min, max) = match &tensor.primitive { TensorPrimitive::Float(tensor) => { quantization::compute_range::(scheme, tensor.clone(), calibration) } TensorPrimitive::QFloat(_) => unreachable!(), }; CalibrationRange { min: Tensor::from_primitive(TensorPrimitive::Float(min)), max: Tensor::from_primitive(TensorPrimitive::Float(max)), } } /// Compute the quantization parameters. pub fn compute_q_params( scheme: &QuantScheme, range: CalibrationRange, ) -> QuantizationParameters { match (range.min.primitive, range.max.primitive) { (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => { let qparams = quantization::compute_q_params::(scheme, min, max); QuantizationParameters { scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)), } } _ => unreachable!(), } } ================================================ FILE: crates/burn-tensor/src/tensor/report.rs ================================================ use super::{Tensor, backend::Backend}; use colored::*; /// Checks the closeness of two tensors and prints the results. /// /// Compares tensors by checking the absolute difference between each element. /// Prints the percentage of elements within specified tolerances. /// /// # Arguments /// /// * `output` - The output tensor. /// * `expected` - The expected tensor. /// /// # Example /// /// ```no_run /// use burn_tensor::backend::Backend; /// use burn_tensor::{check_closeness, Tensor}; /// /// fn example() { /// let device = Default::default(); /// let tensor1 = Tensor::::from_floats( /// [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1], /// &device, /// ); /// let tensor2 = Tensor::::from_floats( /// [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004], /// &device, /// ); /// check_closeness(&tensor1, &tensor2); ///} /// ``` /// /// # Output /// /// ```text /// Tensor Closeness Check Results: /// =============================== /// Epsilon: 1e-1 /// Close elements: 10/10 (100.00%) /// [PASS] All elements are within tolerance /// /// Epsilon: 1e-2 /// Close elements: 10/10 (100.00%) /// [PASS] All elements are within tolerance /// /// Epsilon: 1e-3 /// Close elements: 9/10 (90.00%) /// [WARN] Most elements are within tolerance /// /// Epsilon: 1e-4 /// Close elements: 6/10 (60.00%) /// [FAIL] Significant differences detected /// /// Epsilon: 1e-5 /// Close elements: 5/10 (50.00%) /// [FAIL] Significant differences detected /// /// Epsilon: 1e-6 /// Close elements: 5/10 (50.00%) /// [FAIL] Significant differences detected /// /// Epsilon: 1e-7 /// Close elements: 5/10 (50.00%) /// [FAIL] Significant differences detected /// /// Epsilon: 1e-8 /// Close elements: 5/10 (50.00%) /// [FAIL] Significant differences detected /// /// Closeness check complete. /// ``` pub fn check_closeness(output: &Tensor, expected: &Tensor) { println!("{}", "Tensor Closeness Check Results:".bold()); println!("==============================="); for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8].iter() { println!("{} {:e}", "Epsilon:".bold(), epsilon); let close = output .clone() .is_close(expected.clone(), Some(*epsilon), Some(*epsilon)); let data = close.clone().into_data(); let num_elements = data.num_elements(); // Count the number of elements that are close (true) let count = data.iter::().filter(|x| *x).count(); let percentage = (count as f64 / num_elements as f64) * 100.0; println!(" Close elements: {count}/{num_elements} ({percentage:.2}%)"); if percentage == 100.0 { println!(" {} All elements are within tolerance", "[PASS]".green()); } else if percentage >= 90.0 { println!(" {} Most elements are within tolerance", "[WARN]".yellow()); } else { println!(" {} Significant differences detected", "[FAIL]".red()); } println!(); } println!("{}", "Closeness check complete.".bold()); } ================================================ FILE: crates/burn-tensor/src/tensor/stats/mod.rs ================================================ use crate::{Tensor, backend::Backend}; use burn_backend::tensor::Int; pub fn var(tensor: Tensor, dim: usize) -> Tensor { let mean = tensor.clone().mean_dim(dim); var_with_mean(tensor, mean, dim) } pub fn var_with_mean( tensor: Tensor, mean: Tensor, dim: usize, ) -> Tensor { let n = tensor.shape()[dim] - 1; var_with_mean_n(tensor, mean, dim, n) } pub fn var_bias(tensor: Tensor, dim: usize) -> Tensor { let mean = tensor.clone().mean_dim(dim); var_with_mean_bias(tensor, mean, dim) } pub fn var_with_mean_bias( tensor: Tensor, mean: Tensor, dim: usize, ) -> Tensor { let n = tensor.shape()[dim]; var_with_mean_n(tensor, mean, dim, n) } pub fn var_with_mean_n( tensor: Tensor, mean: Tensor, dim: usize, n: usize, ) -> Tensor { tensor.sub(mean).square().sum_dim(dim).div_scalar(n as f32) } pub fn median(tensor: Tensor, dim: usize) -> Tensor { let total_elem_numbers = tensor.dims()[dim]; let sorted_tensor = tensor.sort(dim); // Following the PyTorch behavior: // - Odd count: the median // - Even count: the lower of the two median elements // // Example: // - 5 elements: (5 - 1) / 2 = 4 / 2 = 2 // - 4 elements: (4 - 1) / 2 = 3 / 2 = 1 let median_index = (total_elem_numbers - 1) / 2; sorted_tensor.narrow(dim, median_index, 1) } pub fn median_with_indices( tensor: Tensor, dim: usize, ) -> (Tensor, Tensor) { let total_elem_numbers = tensor.dims()[dim]; let (sorted_tensor, indices) = tensor.sort_with_indices(dim); // Following the PyTorch behavior: // - Odd count: the median // - Even count: the lower of the two median elements // // Example: // - 5 elements: (5 - 1) / 2 = 4 / 2 = 2 // - 4 elements: (4 - 1) / 2 = 3 / 2 = 1 let median_index = (total_elem_numbers - 1) / 2; let median_values = sorted_tensor.narrow(dim, median_index, 1); let median_indices = indices.narrow(dim, median_index, 1); (median_values, median_indices) } ================================================ FILE: crates/burn-tensor-testgen/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] description = "Test generation crate for burn-tensor" edition.workspace = true license.workspace = true name = "burn-tensor-testgen" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor-testgen" version.workspace = true [lints] workspace = true [lib] proc-macro = true [dependencies] proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } ================================================ FILE: crates/burn-tensor-testgen/README.md ================================================ # Burn Tensor Test Generation > [Burn](https://github.com/tracel-ai/burn) tensor test generation [![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor-testgen.svg)](https://crates.io/crates/burn-tensor-testgen) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tensor-testgen/blob/master/README.md) ================================================ FILE: crates/burn-tensor-testgen/src/lib.rs ================================================ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue, parse_macro_input}; // Define a structure to parse the attribute arguments struct AttributeArgs { args: Punctuated, } impl Parse for AttributeArgs { fn parse(input: ParseStream) -> syn::Result { Ok(AttributeArgs { args: Punctuated::parse_terminated(input)?, }) } } #[allow(clippy::test_attr_in_doctest)] /// **This is only meaningful when the `reason` is specific and clear.** /// /// A proc macro attribute that adds panic handling to test functions. /// /// # Usage /// ```rust, ignore /// #[might_panic(reason = "expected panic message prefix")] /// #[test] /// fn test_that_might_panic() { /// // test code that might panic (with acceptable reason) /// } /// ``` /// /// # Behavior /// - If the test does not panic, it passes. /// - If the test panics with a message starting with the expected prefix, the failure is ignored. /// - If the test panics with a different message, the test fails. /// /// # Note /// This proc macro uses [`std::panic::catch_unwind`]. As such, it does not work in a no-std environment. /// Make sure it is feature gated when an `"std"` feature is available. #[proc_macro_attribute] pub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream { // Parse the attribute arguments let args = parse_macro_input!(args as AttributeArgs); let input_fn = parse_macro_input!(input as ItemFn); // Extract the expected panic reason let mut expected_reason = None; for arg in args.args.iter() { if let Meta::NameValue(MetaNameValue { path, value, .. }) = arg && path.is_ident("reason") && let Expr::Lit(lit) = value && let Lit::Str(ref lit_str) = lit.lit { expected_reason = Some(lit_str.value()); } } let expected_reason = match expected_reason { Some(reason) => reason, None => { return syn::Error::new( proc_macro2::Span::call_site(), "The #[might_panic] attribute requires a 'reason' parameter", ) .to_compile_error() .into(); } }; let fn_name = &input_fn.sig.ident; let fn_vis = &input_fn.vis; let fn_generics = &input_fn.sig.generics; let fn_block = &input_fn.block; let fn_attrs = input_fn .attrs .iter() .filter(|attr| !attr.path().is_ident("test")) .collect::>(); // Create a wrapped test function let wrapper_name = format_ident!("{}_might_panic", fn_name); quote! { #(#fn_attrs)* #fn_vis fn #fn_name #fn_generics() { #fn_block } #[test] #fn_vis fn #wrapper_name #fn_generics() { use std::panic::{self, AssertUnwindSafe}; use std::sync::{Arc, Mutex, OnceLock}; let get_msg = |p: &(dyn std::any::Any + Send)| -> String { p.downcast_ref::().cloned() .or_else(|| p.downcast_ref::<&str>().map(|s| s.to_string())) .unwrap_or_else(|| "Unknown panic".to_string()) }; // An append-only list of all panic messages across the entire process. // This is required because cubecl's `CallError` hides the original panic message // occurring in the device threads. // // A global log also prevents parallel tests from overwriting each other's panic hooks. static PANIC_LOG: OnceLock>> = OnceLock::new(); let log = PANIC_LOG.get_or_init(|| Mutex::new(Vec::new())); static HOOK: OnceLock<()> = OnceLock::new(); HOOK.get_or_init(|| { let prev = panic::take_hook(); panic::set_hook(Box::new(move |info| { if let Ok(mut v) = log.lock() { v.push(get_msg(info.payload())); } prev(info); })); }); // We only care about panics that occur during this test's execution window, so // we start at the number of panics logged before this test starts. let start_idx = log.lock().unwrap().len(); let result = panic::catch_unwind(AssertUnwindSafe(|| #fn_name())); if let Err(e) = result { let main_msg = get_msg(&*e); let panic_logs = log.lock().unwrap(); let window = &panic_logs[start_idx..]; let matched = window.iter().chain(std::iter::once(&main_msg)) .any(|m| m.contains(#expected_reason)); if !matched { let all = window.iter().chain(std::iter::once(&main_msg)) .map(|m| format!("- {m}")).collect::>().join("\n"); panic!("\nTest '{}' failed.\nExpected: '{}'\nFound:\n{}\n", stringify!(#fn_name), #expected_reason, all); } else { let all = window.iter().chain(std::iter::once(&main_msg)) .map(|m| format!("- {m}")).collect::>().join("\n"); println!("\nTest '{}' failed.\nExpected: '{}'\nFound:\n{}\n", stringify!(#fn_name), #expected_reason, all); } } } } .into() } ================================================ FILE: crates/burn-train/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] categories = ["science"] description = "Training crate for the Burn framework" edition.workspace = true keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"] license.workspace = true name = "burn-train" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-train" documentation = "https://docs.rs/burn-train" version.workspace = true [lints] workspace = true [features] default = ["sys-metrics", "tui", "rl"] doc = ["default"] vision = ["burn-nn", "burn-store/pytorch", "burn-std/network", "dirs"] tracing = [ "burn-core/tracing", "burn-optim/tracing", "burn-collective?/tracing", ] sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] rl = ["burn-rl"] # Distributed Data Parallel ddp = ["burn-collective", "burn-optim/collective"] [dependencies] burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [ "dataset", "std", ], default-features = false } burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [ "std", ], default-features = false } burn-rl = { path = "../burn-rl", version = "=0.21.0-pre.2", optional = true, default-features = false } burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true } burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] } burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] } burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] } dirs = { workspace = true, optional = true } log = { workspace = true } tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } tracing-core = { workspace = true } # System Metrics nvml-wrapper = { workspace = true, optional = true } sysinfo = { workspace = true, optional = true } systemstat = { workspace = true, optional = true } # Text UI ratatui = { workspace = true, optional = true, features = [ "all-widgets", "crossterm", ] } # Utilities derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } async-channel = { workspace = true } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } rstest.workspace = true thiserror.workspace = true rand.workspace = true [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" } [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] ================================================ FILE: crates/burn-train/README.md ================================================ # Burn Train This crate should be used with [burn](https://github.com/tracel-ai/burn). [![Current Crates.io Version](https://img.shields.io/crates/v/burn-train.svg)](https://crates.io/crates/burn-train) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-train/blob/master/README.md) ================================================ FILE: crates/burn-train/src/checkpoint/async_checkpoint.rs ================================================ use super::{Checkpointer, CheckpointerError}; use crate::Interrupter; use burn_core::{record::Record, tensor::backend::Backend}; use std::sync::mpsc; enum Message { Restore( usize, B::Device, mpsc::SyncSender>, Option, ), Save(usize, R, Option), Delete(usize, Option), End, } #[derive(new)] struct CheckpointerThread { checkpointer: C, receiver: mpsc::Receiver>, } impl CheckpointerThread where C: Checkpointer, R: Record, B: Backend, { fn run(self) { for item in self.receiver.iter() { match item { Message::Restore(epoch, device, callback, interrupter) => { let record = self.checkpointer.restore(epoch, &device); callback.send(record).unwrap_or_else(|err| { interrupter.map_or_else( || { panic!( "Error when sending response through callback channel: {err}" ) }, |int| int.stop(Some(&err.to_string())), ) }); } Message::Save(epoch, state, interrupter) => { self.checkpointer.save(epoch, state).unwrap_or_else(|err| { interrupter.map_or_else( || panic!("Error when saving the state: {err}"), |int| int.stop(Some(&err.to_string())), ) }); } Message::Delete(epoch, interrupter) => { self.checkpointer.delete(epoch).unwrap_or_else(|err| { interrupter.map_or_else( || panic!("Error when deleting the state: {err}"), |int| int.stop(Some(&err.to_string())), ) }); } Message::End => { return; } }; } } } /// Async checkpointer. pub struct AsyncCheckpointer { sender: mpsc::SyncSender>, handler: Option>, interrupter: Option, } impl AsyncCheckpointer where R: Record + 'static, B: Backend, { /// Create a new async checkpointer. /// /// # Arguments /// /// * `checkpointer` - The checkpointer. /// /// # Returns /// /// The async checkpointer. pub fn new(checkpointer: C) -> Self where C: Checkpointer + Send + 'static, { // Only on checkpoint can be done in advance. let (sender, receiver) = mpsc::sync_channel(0); let thread = CheckpointerThread::new(checkpointer, receiver); let handler = Some(std::thread::spawn(move || thread.run())); Self { sender, handler, interrupter: None, } } /// Assign a handle used to interrupt training in case of checkpointing error. pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self { self.interrupter = Some(interrupter); self } } impl Checkpointer for AsyncCheckpointer where R: Record + 'static, B: Backend, { fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { self.sender .send(Message::Save(epoch, record, self.interrupter.clone())) .expect("Can send message to checkpointer thread."); Ok(()) } fn restore(&self, epoch: usize, device: &B::Device) -> Result { let (sender, receiver) = mpsc::sync_channel(1); self.sender .send(Message::Restore( epoch, device.clone(), sender, self.interrupter.clone(), )) .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; if let Ok(record) = receiver.recv() { return record; }; Err(CheckpointerError::Unknown("Channel error.".to_string())) } fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { self.sender .send(Message::Delete(epoch, self.interrupter.clone())) .map_err(|e| CheckpointerError::Unknown(e.to_string()))?; Ok(()) } } impl Drop for AsyncCheckpointer where B: Backend, { fn drop(&mut self) { self.sender .send(Message::End) .expect("Can send the end message to the checkpointer thread."); let handler = self.handler.take(); if let Some(handler) = handler { handler .join() .expect("The checkpointer thread should stop."); } } } ================================================ FILE: crates/burn-train/src/checkpoint/base.rs ================================================ use burn_core::{ record::{Record, RecorderError}, tensor::backend::Backend, }; use thiserror::Error; /// The error type for checkpointer. #[derive(Error, Debug)] pub enum CheckpointerError { /// IO error. #[error("I/O Error: `{0}`")] IOError(std::io::Error), /// Recorder error. #[error("Recorder error: `{0}`")] RecorderError(RecorderError), /// Other errors. #[error("Unknown error: `{0}`")] Unknown(String), } /// The trait for checkpointer. pub trait Checkpointer: Send + Sync where R: Record, B: Backend, { /// Save the record. /// /// # Arguments /// /// * `epoch` - The epoch. /// * `record` - The record. fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>; /// Delete the record at the given epoch if present. fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>; /// Restore the record. /// /// # Arguments /// /// * `epoch` - The epoch. /// * `device` - The device used to restore the record. /// /// # Returns /// /// The record. fn restore(&self, epoch: usize, device: &B::Device) -> Result; } ================================================ FILE: crates/burn-train/src/checkpoint/file.rs ================================================ use std::path::{Path, PathBuf}; use super::{Checkpointer, CheckpointerError}; use burn_core::{ record::{FileRecorder, Record}, tensor::backend::Backend, }; /// The file checkpointer. pub struct FileCheckpointer { directory: PathBuf, name: String, recorder: FR, } impl FileCheckpointer { /// Creates a new file checkpointer. /// /// # Arguments /// /// * `recorder` - The file recorder. /// * `directory` - The directory to save the checkpoints. /// * `name` - The name of the checkpoint. pub fn new(recorder: FR, directory: impl AsRef, name: &str) -> Self { let directory = directory.as_ref(); std::fs::create_dir_all(directory).ok(); Self { directory: directory.to_path_buf(), name: name.to_string(), recorder, } } fn path_for_epoch(&self, epoch: usize) -> PathBuf { self.directory.join(format!("{}-{}", self.name, epoch)) } } impl Checkpointer for FileCheckpointer where R: Record, FR: FileRecorder, B: Backend, { fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> { let file_path = self.path_for_epoch(epoch); log::trace!("Saving checkpoint {} to {}", epoch, file_path.display()); self.recorder .record(record, file_path) .map_err(CheckpointerError::RecorderError)?; Ok(()) } fn restore(&self, epoch: usize, device: &B::Device) -> Result { let file_path = self.path_for_epoch(epoch); log::info!( "Restoring checkpoint {} from {}", epoch, file_path.display() ); let record = self .recorder .load(file_path, device) .map_err(CheckpointerError::RecorderError)?; Ok(record) } fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> { let file_to_remove = format!( "{}.{}", self.path_for_epoch(epoch).display(), FR::file_extension(), ); if std::path::Path::new(&file_to_remove).exists() { log::trace!("Removing checkpoint {file_to_remove}"); std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?; } Ok(()) } } ================================================ FILE: crates/burn-train/src/checkpoint/mod.rs ================================================ mod async_checkpoint; mod base; mod file; mod strategy; pub use async_checkpoint::*; pub use base::*; pub use file::*; pub use strategy::*; ================================================ FILE: crates/burn-train/src/checkpoint/strategy/base.rs ================================================ use std::ops::DerefMut; use crate::metric::store::EventStoreClient; /// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer). #[derive(Clone, PartialEq, Debug)] pub enum CheckpointingAction { /// Delete the given epoch. Delete(usize), /// Save the current record. Save, } /// Define when checkpoint should be saved and deleted. pub trait CheckpointingStrategy: Send { /// Based on the epoch, determine if the checkpoint should be saved. fn checkpointing( &mut self, epoch: usize, collector: &EventStoreClient, ) -> Vec; } // We make dyn box implement the checkpointing strategy so that it can be used with generic, but // still be dynamic. impl CheckpointingStrategy for Box { fn checkpointing( &mut self, epoch: usize, collector: &EventStoreClient, ) -> Vec { self.deref_mut().checkpointing(epoch, collector) } } ================================================ FILE: crates/burn-train/src/checkpoint/strategy/composed.rs ================================================ use crate::metric::store::EventStoreClient; use super::{CheckpointingAction, CheckpointingStrategy}; use std::collections::HashSet; /// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an /// epoch to be deleted. pub struct ComposedCheckpointingStrategy { strategies: Vec>, deleted: Vec>, } /// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones. #[derive(Default)] pub struct ComposedCheckpointingStrategyBuilder { strategies: Vec>, } impl ComposedCheckpointingStrategyBuilder { /// Add a new [checkpointing strategy](CheckpointingStrategy). #[allow(clippy::should_implement_trait)] pub fn add(mut self, strategy: S) -> Self where S: CheckpointingStrategy + 'static, { self.strategies.push(Box::new(strategy)); self } /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). pub fn build(self) -> ComposedCheckpointingStrategy { ComposedCheckpointingStrategy::new(self.strategies) } } impl ComposedCheckpointingStrategy { fn new(strategies: Vec>) -> Self { Self { deleted: strategies.iter().map(|_| HashSet::new()).collect(), strategies, } } /// Create a new builder which help compose multiple /// [checkpointing strategies](CheckpointingStrategy). pub fn builder() -> ComposedCheckpointingStrategyBuilder { ComposedCheckpointingStrategyBuilder::default() } } impl CheckpointingStrategy for ComposedCheckpointingStrategy { fn checkpointing( &mut self, epoch: usize, collector: &EventStoreClient, ) -> Vec { let mut saved = false; let mut actions = Vec::new(); let mut epochs_to_check = Vec::new(); for (i, strategy) in self.strategies.iter_mut().enumerate() { let actions = strategy.checkpointing(epoch, collector); // We assume that the strategy would not want the current epoch to be saved. // So we flag it as deleted. if actions.is_empty() { self.deleted .get_mut(i) .expect("As many 'deleted' as 'strategies'.") .insert(epoch); } for action in actions { match action { CheckpointingAction::Delete(epoch) => { self.deleted .get_mut(i) .expect("As many 'deleted' as 'strategies'.") .insert(epoch); epochs_to_check.push(epoch); } CheckpointingAction::Save => saved = true, } } } if saved { actions.push(CheckpointingAction::Save); } for epoch in epochs_to_check.into_iter() { let mut num_true = 0; for i in 0..self.strategies.len() { if self .deleted .get(i) .expect("Ad many 'deleted' as 'strategies'.") .contains(&epoch) { num_true += 1; } } if num_true == self.strategies.len() { actions.push(CheckpointingAction::Delete(epoch)); for i in 0..self.strategies.len() { self.deleted .get_mut(i) .expect("As many 'deleted' as 'strategies'.") .remove(&epoch); } } } actions } } #[cfg(test)] mod tests { use super::*; use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; #[test] fn should_delete_when_both_deletes() { let store = EventStoreClient::new(LogEventStore::default()); let mut strategy = ComposedCheckpointingStrategy::builder() .add(KeepLastNCheckpoints::new(1)) .add(KeepLastNCheckpoints::new(2)) .build(); assert_eq!( vec![CheckpointingAction::Save], strategy.checkpointing(1, &store) ); assert_eq!( vec![CheckpointingAction::Save], strategy.checkpointing(2, &store) ); assert_eq!( vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], strategy.checkpointing(3, &store) ); } } ================================================ FILE: crates/burn-train/src/checkpoint/strategy/lastn.rs ================================================ use super::CheckpointingStrategy; use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient}; /// Keep the last N checkpoints. /// /// Very useful when training, minimizing disk space while ensuring that the training can be /// resumed even if something goes wrong. #[derive(new)] pub struct KeepLastNCheckpoints { num_keep: usize, } impl CheckpointingStrategy for KeepLastNCheckpoints { fn checkpointing( &mut self, epoch: usize, _store: &EventStoreClient, ) -> Vec { let mut actions = vec![CheckpointingAction::Save]; if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) && epoch > 0 { actions.push(CheckpointingAction::Delete(epoch)); } actions } } #[cfg(test)] mod tests { use super::*; use crate::metric::store::LogEventStore; #[test] fn should_always_delete_lastn_epoch_if_higher_than_one() { let mut strategy = KeepLastNCheckpoints::new(2); let store = EventStoreClient::new(LogEventStore::default()); assert_eq!( vec![CheckpointingAction::Save], strategy.checkpointing(1, &store) ); assert_eq!( vec![CheckpointingAction::Save], strategy.checkpointing(2, &store) ); assert_eq!( vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], strategy.checkpointing(3, &store) ); } } ================================================ FILE: crates/burn-train/src/checkpoint/strategy/metric.rs ================================================ use super::CheckpointingStrategy; use crate::{ checkpoint::CheckpointingAction, metric::{ Metric, MetricName, store::{Aggregate, Direction, EventStoreClient, Split}, }, }; /// Keep the best checkpoint based on a metric. pub struct MetricCheckpointingStrategy { current: Option, aggregate: Aggregate, direction: Direction, split: Split, name: MetricName, } impl MetricCheckpointingStrategy { /// Create a new metric checkpointing strategy. pub fn new(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self where M: Metric, { Self { current: None, name: metric.name(), aggregate, direction, split, } } } impl CheckpointingStrategy for MetricCheckpointingStrategy { fn checkpointing( &mut self, epoch: usize, store: &EventStoreClient, ) -> Vec { let best_epoch = match store.find_epoch(&self.name, self.aggregate, self.direction, &self.split) { Some(epoch_best) => epoch_best, None => epoch, }; let mut actions = Vec::new(); if let Some(current) = self.current && current != best_epoch { actions.push(CheckpointingAction::Delete(current)); } if best_epoch == epoch { actions.push(CheckpointingAction::Save); } self.current = Some(best_epoch); actions } } #[cfg(test)] mod tests { use crate::{ EventProcessorTraining, TestBackend, logger::InMemoryMetricLogger, metric::{ LossMetric, processor::{ MetricsTraining, MinimalEventProcessor, test_utils::{end_epoch, process_train}, }, store::LogEventStore, }, }; use super::*; use std::sync::Arc; #[test] fn always_keep_the_best_epoch() { let loss = LossMetric::::new(); let mut store = LogEventStore::default(); let mut strategy = MetricCheckpointingStrategy::new( &loss, Aggregate::Mean, Direction::Lowest, Split::Train, ); let mut metrics = MetricsTraining::::default(); // Register an in memory logger. store.register_logger(InMemoryMetricLogger::default()); // Register the loss metric. metrics.register_train_metric_numeric(loss); let store = Arc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); processor.process_train(crate::LearnerEvent::Start); // Two points for the first epoch. Mean 0.75 let mut epoch = 1; process_train(&mut processor, 1.0, epoch); process_train(&mut processor, 0.5, epoch); end_epoch(&mut processor, epoch); // Should save the current record. assert_eq!( vec![CheckpointingAction::Save], strategy.checkpointing(epoch, &store) ); // Two points for the second epoch. Mean 0.4 epoch += 1; process_train(&mut processor, 0.5, epoch); process_train(&mut processor, 0.3, epoch); end_epoch(&mut processor, epoch); // Should save the current record and delete the previous one. assert_eq!( vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], strategy.checkpointing(epoch, &store) ); // Two points for the last epoch. Mean 2.0 epoch += 1; process_train(&mut processor, 1.0, epoch); process_train(&mut processor, 3.0, epoch); end_epoch(&mut processor, epoch); // Should not delete the previous record, since it's the best one, and should not save a // new one. assert!(strategy.checkpointing(epoch, &store).is_empty()); } } ================================================ FILE: crates/burn-train/src/checkpoint/strategy/mod.rs ================================================ mod base; mod composed; mod lastn; mod metric; pub use base::*; pub use composed::*; pub use lastn::*; pub use metric::*; ================================================ FILE: crates/burn-train/src/components.rs ================================================ use crate::{InferenceStep, TrainStep}; use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend}; use burn_optim::{Optimizer, lr_scheduler::LrScheduler}; use std::marker::PhantomData; /// Components used for a model to learn, grouped in one trait. pub trait LearningComponentsTypes { /// The backend used for training. type Backend: AutodiffBackend; /// The learning rate scheduler used for training. type LrScheduler: LrScheduler + 'static; /// The model to train. type TrainingModel: TrainStep + AutodiffModule + core::fmt::Display + 'static; /// The non-autodiff type of the model. type InferenceModel: InferenceStep; /// The optimizer used for training. type Optimizer: Optimizer + 'static; } /// Concrete type that implements the [LearningComponentsTypes](LearningComponentsTypes) trait. pub struct LearningComponentsMarker { _backend: PhantomData, _lr_scheduler: PhantomData, _model: PhantomData, _optimizer: PhantomData, } impl LearningComponentsTypes for LearningComponentsMarker where B: AutodiffBackend, LR: LrScheduler + 'static, M: TrainStep + AutodiffModule + core::fmt::Display + 'static, M::InnerModule: InferenceStep, O: Optimizer + 'static, { type Backend = B; type LrScheduler = LR; type TrainingModel = M; type InferenceModel = M::InnerModule; type Optimizer = O; } /// The training backend. pub type TrainingBackend = ::Backend; /// The inference backend. pub(crate) type InferenceBackend = <::Backend as AutodiffBackend>::InnerBackend; /// The model used for training. pub type TrainingModel = ::TrainingModel; /// The non-autodiff model. pub(crate) type InferenceModel = ::InferenceModel; /// Type for training input. pub(crate) type TrainingModelInput = <::TrainingModel as TrainStep>::Input; /// Type for inference input. pub(crate) type InferenceModelInput = <::InferenceModel as InferenceStep>::Input; /// Type for training output. pub(crate) type TrainingModelOutput = <::TrainingModel as TrainStep>::Output; /// Type for inference output. pub(crate) type InferenceModelOutput = <::InferenceModel as InferenceStep>::Output; ================================================ FILE: crates/burn-train/src/evaluator/base.rs ================================================ use crate::{ AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep, Interrupter, LearnerSummaryConfig, evaluator::components::EvaluatorComponentTypes, metric::processor::{EvaluatorEvent, EventProcessorEvaluation}, renderer::{EvaluationName, MetricsRenderer}, }; use burn_core::{data::dataloader::DataLoader, module::Module}; use std::sync::Arc; pub(crate) type TestBackend = ::Backend; pub(crate) type TestInput = <::Model as InferenceStep>::Input; pub(crate) type TestOutput = <::Model as InferenceStep>::Output; pub(crate) type TestLoader = Arc, TestInput>>; /// Evaluates a model on a specific dataset. pub struct Evaluator { pub(crate) model: EC::Model, pub(crate) interrupter: Interrupter, pub(crate) event_processor: AsyncProcessorEvaluation>>, /// Config for creating a summary of the evaluation pub summary: Option, } impl Evaluator { /// Run the evaluation on the given dataset. /// /// The data will be stored and displayed under the provided name. pub fn eval( mut self, name: S, dataloader: TestLoader, ) -> Box { // Move dataloader to the model device let dataloader = dataloader.to_device(self.model.devices().first().unwrap()); let name = EvaluationName::new(name); let mut iterator = dataloader.iter(); let mut iteration = 0; self.event_processor.process_test(EvaluatorEvent::Start); while let Some(item) = iterator.next() { let progress = iterator.progress(); iteration += 1; let item = self.model.step(item); let item = EvaluationItem::new(item, progress, Some(iteration)); self.event_processor .process_test(EvaluatorEvent::ProcessedItem(name.clone(), item)); if self.interrupter.should_stop() { log::info!("Testing interrupted."); break; } } let summary = self.summary.and_then(|summary| { summary .init() .map(|summary| summary.with_model(self.model.to_string())) .ok() }); self.event_processor .process_test(EvaluatorEvent::End(summary)); self.event_processor.renderer() } } ================================================ FILE: crates/burn-train/src/evaluator/builder.rs ================================================ use crate::{ ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, InferenceStep, Interrupter, LearnerSummaryConfig, TestOutput, evaluator::components::{EvaluatorComponentTypes, EvaluatorComponentTypesMarker}, logger::FileMetricLogger, metric::{ Adaptor, ItemLazy, Metric, Numeric, processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation}, store::{EventStoreClient, LogEventStore}, }, renderer::{MetricsRenderer, default_renderer}, }; use burn_core::{module::Module, prelude::Backend}; use std::{ collections::BTreeSet, path::{Path, PathBuf}, sync::Arc, }; /// Struct to configure and create an [evaluator](Evaluator). /// /// The generics components of the builder should probably not be set manually, as they are /// optimized for Rust type inference. pub struct EvaluatorBuilder { tracing_logger: Option>, event_store: LogEventStore, summary_metrics: BTreeSet, renderer: Option>, interrupter: Interrupter, metrics: MetricsEvaluation>, directory: PathBuf, summary: bool, } impl EvaluatorBuilder> where B: Backend, M: Module + InferenceStep + core::fmt::Display + 'static, { /// Creates a new evaluator builder. /// /// # Arguments /// /// * `directory` - The directory to save the checkpoints. pub fn new(directory: impl AsRef) -> Self { let directory = directory.as_ref().to_path_buf(); let log_file = directory.join("evaluation.log"); Self { tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(log_file))), event_store: LogEventStore::default(), summary_metrics: Default::default(), renderer: None, interrupter: Interrupter::new(), summary: false, metrics: MetricsEvaluation::default(), directory, } } } impl EvaluatorBuilder { /// Registers [numeric](crate::metric::Numeric) test [metrics](Metric). pub fn metrics>(self, metrics: Me) -> Self { metrics.register(self) } /// Registers text [metrics](Metric). pub fn metrics_text>(self, metrics: Me) -> Self { metrics.register(self) } /// By default, Rust logs are captured and written into /// `evaluation.log`. If disabled, standard Rust log handling /// will apply. pub fn with_application_logger( mut self, logger: Option>, ) -> Self { self.tracing_logger = logger; self } /// Register a [numeric](crate::metric::Numeric) test [metric](Metric). pub fn metric_numeric(mut self, metric: Me) -> Self where Me: Metric + Numeric + 'static, as ItemLazy>::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_test_metric_numeric(metric); self } /// Register a text test [metric](Metric). pub fn metric(mut self, metric: Me) -> Self where Me: Metric + 'static, as ItemLazy>::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_test_metric(metric); self } /// Replace the default CLI renderer with a custom one. /// /// # Arguments /// /// * `renderer` - The custom renderer. pub fn renderer(mut self, renderer: Box) -> Self { self.renderer = Some(renderer); self } /// Enable the evaluation summary report. /// /// The summary will be displayed at the end of `.eval()`. pub fn summary(mut self) -> Self { self.summary = true; self } /// Builds the evaluator. #[allow(clippy::type_complexity)] pub fn build(mut self, model: EC::Model) -> Evaluator { let renderer = self .renderer .unwrap_or_else(|| default_renderer(self.interrupter.clone(), None)); self.event_store .register_logger(FileMetricLogger::new_eval(self.directory.clone())); let event_store = Arc::new(EventStoreClient::new(self.event_store)); let event_processor = AsyncProcessorEvaluation::new(FullEventProcessorEvaluation::new( self.metrics, renderer, event_store, )); let summary = if self.summary { Some(LearnerSummaryConfig { directory: self.directory, metrics: self.summary_metrics.into_iter().collect::>(), }) } else { None }; Evaluator { model, interrupter: self.interrupter, event_processor, summary, } } } /// Trait to fake variadic generics. pub trait EvalMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: EvaluatorBuilder) -> EvaluatorBuilder; } /// Trait to fake variadic generics. pub trait EvalTextMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: EvaluatorBuilder) -> EvaluatorBuilder; } macro_rules! gen_tuple { ($($M:ident),*) => { impl<$($M,)* EC: EvaluatorComponentTypes> EvalTextMetricRegistration for ($($M,)*) where $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: EvaluatorBuilder, ) -> EvaluatorBuilder { let ($($M,)*) = self; $(let builder = builder.metric($M);)* builder } } impl<$($M,)* EC: EvaluatorComponentTypes> EvalMetricRegistration for ($($M,)*) where $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + $crate::metric::Numeric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: EvaluatorBuilder, ) -> EvaluatorBuilder { let ($($M,)*) = self; $(let builder = builder.metric_numeric($M);)* builder } } }; } gen_tuple!(M1); gen_tuple!(M1, M2); gen_tuple!(M1, M2, M3); gen_tuple!(M1, M2, M3, M4); gen_tuple!(M1, M2, M3, M4, M5); gen_tuple!(M1, M2, M3, M4, M5, M6); ================================================ FILE: crates/burn-train/src/evaluator/components.rs ================================================ use crate::InferenceStep; use burn_core::{module::Module, prelude::Backend}; use std::marker::PhantomData; /// All components necessary to evaluate a model grouped in one trait. pub trait EvaluatorComponentTypes { /// The backend in used for the evaluation. type Backend: Backend; /// The model to evaluate. type Model: Module + InferenceStep + core::fmt::Display + 'static; } /// A marker type used to provide [evaluation components](EvaluatorComponentTypes). pub struct EvaluatorComponentTypesMarker { _p: PhantomData<(B, M)>, } impl EvaluatorComponentTypes for EvaluatorComponentTypesMarker where B: Backend, M: Module + InferenceStep + core::fmt::Display + 'static, { type Backend = B; type Model = M; } ================================================ FILE: crates/burn-train/src/evaluator/mod.rs ================================================ mod base; mod builder; pub(crate) mod components; pub use base::*; pub use builder::*; ================================================ FILE: crates/burn-train/src/learner/application_logger.rs ================================================ use std::path::{Path, PathBuf}; use tracing_core::{Level, LevelFilter}; use tracing_subscriber::filter::filter_fn; use tracing_subscriber::prelude::*; use tracing_subscriber::{Layer, registry}; /// This trait is used to install an application logger. pub trait ApplicationLoggerInstaller { /// Install the application logger. fn install(&self) -> Result<(), String>; } /// This struct is used to install a local file application logger to output logs to a given file path. pub struct FileApplicationLoggerInstaller { path: PathBuf, } impl FileApplicationLoggerInstaller { /// Create a new file application logger. pub fn new(path: impl AsRef) -> Self { Self { path: path.as_ref().to_path_buf(), } } } impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller { fn install(&self) -> Result<(), String> { let path = Path::new(&self.path); let writer = tracing_appender::rolling::never( path.parent().unwrap_or_else(|| Path::new(".")), path.file_name().unwrap_or_else(|| { panic!("The path '{}' to point to a file.", self.path.display()) }), ); let layer = tracing_subscriber::fmt::layer() .with_ansi(false) .with_writer(writer) .with_filter(LevelFilter::INFO) .with_filter(filter_fn(|m| { if let Some(path) = m.module_path() { // The wgpu crate is logging too much, so we skip `info` level. if path.starts_with("wgpu") && *m.level() >= Level::INFO { return false; } } true })); if registry().with(layer).try_init().is_err() { return Err("Failed to install the file logger.".to_string()); } let hook = std::panic::take_hook(); let file_path = self.path.to_owned(); std::panic::set_hook(Box::new(move |info| { log::error!("PANIC => {info}"); eprintln!( "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \ '{}'\n=============", file_path.display() ); hook(info); })); Ok(()) } } ================================================ FILE: crates/burn-train/src/learner/base.rs ================================================ use crate::LearningComponentsMarker; use crate::checkpoint::{ AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy, }; use crate::components::{LearningComponentsTypes, TrainingBackend}; use crate::metric::store::EventStoreClient; use crate::{ CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput, TrainingModelOutput, }; use burn_core::module::{AutodiffModule, Module}; use burn_core::prelude::Backend; use burn_core::tensor::Device; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::lr_scheduler::LrScheduler; use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; /// The record of the learner's model. pub type LearnerModelRecord = <::TrainingModel as Module>>::Record; /// The record of the optimizer. pub type LearnerOptimizerRecord = <::Optimizer as Optimizer< ::TrainingModel, TrainingBackend, >>::Record; /// The record of the LR scheduler. pub type LearnerSchedulerRecord = <::LrScheduler as LrScheduler>::Record>; /// Learner struct encapsulating all components necessary to train a Neural Network model. pub struct Learner { pub(crate) model: LC::TrainingModel, optim: LC::Optimizer, lr_scheduler: LC::LrScheduler, lr: f64, } impl Clone for Learner { fn clone(&self) -> Self { Self { model: self.model.clone(), optim: self.optim.clone(), lr_scheduler: self.lr_scheduler.clone(), lr: self.lr, } } } impl Learner> where B: AutodiffBackend, LR: LrScheduler + 'static, M: TrainStep + AutodiffModule + core::fmt::Display + 'static, M::InnerModule: InferenceStep, O: Optimizer + 'static, { /// Create a learner. pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self { Self { model, optim, lr_scheduler, lr: 0.0, } } } impl Learner { /// Fork the learner's model to the given device. pub fn fork(&mut self, device: & as Backend>::Device) { self.model = self.model().fork(device); } /// Returns the current model. pub fn model(&self) -> LC::TrainingModel { self.model.clone() } /// Returns the current learning rate. pub fn lr_current(&self) -> f64 { self.lr } /// Executes a step of the learning rate scheduler. pub fn lr_step(&mut self) { self.lr = self.lr_scheduler.step(); } /// Runs a step of the model for training, which executes the forward and backward passes. /// /// # Arguments /// /// * `item` - The input for the model. /// /// # Returns /// /// The output containing the model output and the gradients. pub fn train_step(&self, item: TrainingModelInput) -> TrainOutput> { self.model.step(item) } /// Optimize the current module with the provided gradients and learning rate. /// /// # Arguments /// /// * `optim`: Optimizer used for learning. /// * `lr`: The learning rate used for this step. /// * `grads`: The gradients of each parameter in the current model. pub fn optimizer_step(&mut self, grads: GradientsParams) { self.model = self.model().optimize(&mut self.optim, self.lr, grads); } /// Optimize the current module with the provided gradients and learning rate. /// /// # Arguments /// /// * `optim`: Optimizer used for learning. /// * `lr`: The learning rate used for this step. /// * `grads`: Multiple gradients associated to each parameter in the current model. pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) { self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads); } /// Load the module state from a [record](LearnerModelRecord). pub fn load_model(&mut self, record: LearnerModelRecord) { self.model = self.model.clone().load_record(record); } /// Load the state of the learner's optimizer as a [record](LearnerOptimizerRecord). pub fn load_optim(&mut self, record: LearnerOptimizerRecord) { self.optim = self.optim.clone().load_record(record); } /// Load the state of the learner's scheduler as a [record](LearnerSchedulerRecord). pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord) { self.lr_scheduler = self.lr_scheduler.clone().load_record(record); } } #[derive(new)] /// Used to create, delete, or load checkpoints of the training process. pub struct LearningCheckpointer { model: AsyncCheckpointer, LC::Backend>, optim: AsyncCheckpointer, LC::Backend>, lr_scheduler: AsyncCheckpointer, LC::Backend>, strategy: Box, } impl LearningCheckpointer { /// Create checkpoint for the training process. pub fn checkpoint(&mut self, learner: &Learner, epoch: usize, store: &EventStoreClient) { let actions = self.strategy.checkpointing(epoch, store); for action in actions { match action { CheckpointingAction::Delete(epoch) => { self.model .delete(epoch) .expect("Can delete model checkpoint."); self.optim .delete(epoch) .expect("Can delete optimizer checkpoint."); self.lr_scheduler .delete(epoch) .expect("Can delete learning rate scheduler checkpoint."); } CheckpointingAction::Save => { self.model .save(epoch, learner.model.clone().into_record()) .expect("Can save model checkpoint."); self.optim .save(epoch, learner.optim.to_record()) .expect("Can save optimizer checkpoint."); self.lr_scheduler .save(epoch, learner.lr_scheduler.to_record()) .expect("Can save learning rate scheduler checkpoint."); } } } } /// Load a training checkpoint. pub fn load_checkpoint( &self, mut learner: Learner, device: &Device, epoch: usize, ) -> Learner { let record = self .model .restore(epoch, device) .expect("Can load model checkpoint."); learner.load_model(record); let record = self .optim .restore(epoch, device) .expect("Can load optimizer checkpoint."); learner.load_optim(record); let record = self .lr_scheduler .restore(epoch, device) .expect("Can load learning rate scheduler checkpoint."); learner.load_scheduler(record); learner } } /// Cloneable reference to an early stopping strategy pub(crate) type EarlyStoppingStrategyRef = Box; #[derive(Clone, Default)] /// A handle that allows aborting the training/evaluation process early. pub struct Interrupter { state: Arc, message: Arc>>, } impl Interrupter { /// Create a new instance. pub fn new() -> Self { Self::default() } /// Notify the learner that it should stop. /// # Arguments /// * `reason` - A string describing the reason the training was stopped. pub fn stop(&self, reason: Option<&str>) { self.state.store(true, Ordering::Relaxed); reason.inspect(|r| { let mut message = self.message.lock().unwrap(); *message = Some(String::from(*r)); }); } /// Reset the interrupter. pub fn reset(&self) { self.state.store(false, Ordering::Relaxed); } /// True if .stop() has been called. pub fn should_stop(&self) -> bool { self.state.load(Ordering::Relaxed) } /// Get the message associated with the interrupt. pub fn get_message(&self) -> Option { let message = self.message.lock().unwrap(); message.clone() } } ================================================ FILE: crates/burn-train/src/learner/classification.rs ================================================ use crate::metric::{ AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput, TopKAccuracyInput, processor::ItemLazy, }; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor, Transaction}; use burn_ndarray::NdArray; /// Simple classification output adapted for multiple metrics. /// /// Supported metrics: /// - Accuracy /// - AUROC /// - TopKAccuracy /// - Perplexity /// - Precision (via ConfusionStatsInput) /// - Recall (via ConfusionStatsInput) /// - FBetaScore (via ConfusionStatsInput) /// - Loss. #[derive(new)] pub struct ClassificationOutput { /// The loss. pub loss: Tensor, /// The class logits or probabilities. Shape: \[batch_size, num_classes\]. pub output: Tensor, /// The ground truth class index for each sample. Shape: \[batch_size\]. pub targets: Tensor, } impl ItemLazy for ClassificationOutput { type ItemSync = ClassificationOutput; fn sync(self) -> Self::ItemSync { let [output, loss, targets] = Transaction::default() .register(self.output) .register(self.loss) .register(self.targets) .execute() .try_into() .expect("Correct amount of tensor data"); let device = &Default::default(); ClassificationOutput { output: Tensor::from_data(output, device), loss: Tensor::from_data(loss, device), targets: Tensor::from_data(targets, device), } } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> AccuracyInput { AccuracyInput::new(self.output.clone(), self.targets.clone()) } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> AurocInput { AurocInput::new(self.output.clone(), self.targets.clone()) } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> LossInput { LossInput::new(self.loss.clone()) } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> TopKAccuracyInput { TopKAccuracyInput::new(self.output.clone(), self.targets.clone()) } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> PerplexityInput { PerplexityInput::new(self.output.clone(), self.targets.clone()) } } impl Adaptor> for ClassificationOutput { fn adapt(&self) -> ConfusionStatsInput { let [_, num_classes] = self.output.dims(); if num_classes > 1 { ConfusionStatsInput::new( self.output.clone(), self.targets.clone().one_hot(num_classes).bool(), ) } else { ConfusionStatsInput::new( self.output.clone(), self.targets.clone().unsqueeze_dim(1).bool(), ) } } } /// Multi-label classification output adapted for multiple metrics. /// /// Supported metrics: /// - HammingScore /// - Precision (via ConfusionStatsInput) /// - Recall (via ConfusionStatsInput) /// - FBetaScore (via ConfusionStatsInput) /// - Loss #[derive(new)] pub struct MultiLabelClassificationOutput { /// The loss. pub loss: Tensor, /// The label logits or probabilities. Shape: \[batch_size, num_classes\]. pub output: Tensor, /// The ground truth labels. Shape: \[batch_size, num_classes\]. pub targets: Tensor, } impl ItemLazy for MultiLabelClassificationOutput { type ItemSync = MultiLabelClassificationOutput; fn sync(self) -> Self::ItemSync { let [output, loss, targets] = Transaction::default() .register(self.output) .register(self.loss) .register(self.targets) .execute() .try_into() .expect("Correct amount of tensor data"); let device = &Default::default(); MultiLabelClassificationOutput { output: Tensor::from_data(output, device), loss: Tensor::from_data(loss, device), targets: Tensor::from_data(targets, device), } } } impl Adaptor> for MultiLabelClassificationOutput { fn adapt(&self) -> HammingScoreInput { HammingScoreInput::new(self.output.clone(), self.targets.clone()) } } impl Adaptor> for MultiLabelClassificationOutput { fn adapt(&self) -> LossInput { LossInput::new(self.loss.clone()) } } impl Adaptor> for MultiLabelClassificationOutput { fn adapt(&self) -> ConfusionStatsInput { ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool()) } } ================================================ FILE: crates/burn-train/src/learner/early_stopping.rs ================================================ use crate::metric::{ Metric, MetricName, store::{Aggregate, Direction, EventStoreClient, Split}, }; /// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow. #[derive(Clone)] pub enum StoppingCondition { /// When no improvement has happened since the given number of epochs. NoImprovementSince { /// The number of epochs allowed to worsen before it gets better. n_epochs: usize, }, } /// A strategy that checks if the training should be stopped. pub trait EarlyStoppingStrategy: Send { /// Update its current state and returns if the training should be stopped. fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; } /// A helper trait to provide type-erased cloning. pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send { /// Clone into a boxed trait object. fn clone_box(&self) -> Box; } /// Blanket-implement `CloneEarlyStoppingStrategy` for any `T` that /// already implements your strategy + `Clone` + `Send` + `'static`. impl CloneEarlyStoppingStrategy for T where T: EarlyStoppingStrategy + Clone + Send + 'static, { fn clone_box(&self) -> Box { Box::new(self.clone()) } } /// Now you can `impl Clone` for the boxed trait object. impl Clone for Box { fn clone(&self) -> Box { self.clone_box() } } /// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected /// during training or validation. #[derive(Clone)] pub struct MetricEarlyStoppingStrategy { condition: StoppingCondition, metric_name: MetricName, aggregate: Aggregate, direction: Direction, split: Split, best_epoch: usize, best_value: f64, warmup_epochs: Option, } impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy { fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { let current_value = match store.find_metric(&self.metric_name, epoch, self.aggregate, &self.split) { Some(value) => value, None => { log::warn!("Can't find metric for early stopping."); return false; } }; let is_best = match self.direction { Direction::Lowest => current_value < self.best_value, Direction::Highest => current_value > self.best_value, }; if is_best { log::info!( "New best epoch found {} {}: {}", epoch, self.metric_name, current_value ); self.best_value = current_value; self.best_epoch = epoch; return false; } if let Some(warmup_epochs) = self.warmup_epochs && epoch <= warmup_epochs { return false; } match self.condition { StoppingCondition::NoImprovementSince { n_epochs } => { let should_stop = epoch - self.best_epoch >= n_epochs; if should_stop { log::info!( "Stopping training loop, no improvement since epoch {}, {}: {}, current \ epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value ); } should_stop } } } } impl MetricEarlyStoppingStrategy { /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected /// during training or validation. /// /// # Notes /// /// The metric should be registered for early stopping to work, otherwise no data is collected. pub fn new( metric: &Me, aggregate: Aggregate, direction: Direction, split: Split, condition: StoppingCondition, ) -> Self { let init_value = match direction { Direction::Lowest => f64::MAX, Direction::Highest => f64::MIN, }; Self { metric_name: metric.name(), condition, aggregate, direction, split, best_epoch: 1, best_value: init_value, warmup_epochs: None, } } /// Get the warmup period. /// /// Early stopping will not trigger during the warmup epochs. pub fn warmup_epochs(&self) -> Option { self.warmup_epochs } /// Set the warmup epochs. /// /// Early stopping will not trigger during the warmup epochs. /// /// # Arguments /// - `warmup`: the number of warmup epochs, or None. pub fn with_warmup_epochs(self, warmup: Option) -> Self { Self { warmup_epochs: warmup, ..self } } } #[cfg(test)] mod tests { use std::sync::Arc; use crate::{ EventProcessorTraining, TestBackend, logger::InMemoryMetricLogger, metric::{ LossMetric, processor::{ MetricsTraining, MinimalEventProcessor, test_utils::{end_epoch, process_train}, }, store::LogEventStore, }, }; use super::*; #[test] fn never_early_stop_while_it_is_improving() { test_early_stopping( None, 1, &[ (&[0.5, 0.3], false, "Should not stop first epoch"), (&[0.4, 0.3], false, "Should not stop when improving"), (&[0.3, 0.3], false, "Should not stop when improving"), (&[0.2, 0.3], false, "Should not stop when improving"), ], ); } #[test] fn early_stop_when_no_improvement_since_two_epochs() { test_early_stopping( None, 2, &[ (&[1.0, 0.5], false, "Should not stop first epoch"), (&[0.5, 0.3], false, "Should not stop when improving"), ( &[1.0, 3.0], false, "Should not stop first time it gets worse", ), ( &[1.0, 2.0], true, "Should stop since two following epochs didn't improve", ), ], ); } #[test] fn early_stopping_with_warmup() { test_early_stopping( Some(3), 2, &[ (&[1.0, 0.5], false, "Should not stop during warmup"), (&[1.0, 0.5], false, "Should not stop during warmup"), (&[1.0, 0.5], false, "Should not stop during warmup"), ( &[1.0, 0.5], true, "Should stop when not improving after warmup", ), ], ) } #[test] fn early_stop_when_stays_equal() { test_early_stopping( None, 2, &[ (&[0.5, 0.3], false, "Should not stop first epoch"), ( &[0.5, 0.3], false, "Should not stop first time it stars the same", ), ( &[0.5, 0.3], true, "Should stop since two following epochs didn't improve", ), ], ); } fn test_early_stopping(warmup: Option, n_epochs: usize, data: &[(&[f64], bool, &str)]) { let loss = LossMetric::::new(); let mut early_stopping = MetricEarlyStoppingStrategy::new( &loss, Aggregate::Mean, Direction::Lowest, Split::Train, StoppingCondition::NoImprovementSince { n_epochs }, ) .with_warmup_epochs(warmup); let mut store = LogEventStore::default(); let mut metrics = MetricsTraining::::default(); store.register_logger(InMemoryMetricLogger::default()); metrics.register_train_metric_numeric(loss); let store = Arc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); let mut epoch = 1; processor.process_train(crate::LearnerEvent::Start); for (points, should_start, comment) in data { for point in points.iter() { process_train(&mut processor, *point, epoch); } end_epoch(&mut processor, epoch); assert_eq!( *should_start, early_stopping.should_stop(epoch, &store), "{comment}" ); epoch += 1; } } } ================================================ FILE: crates/burn-train/src/learner/mod.rs ================================================ #[cfg(feature = "rl")] mod rl; #[cfg(feature = "rl")] pub use rl::*; mod application_logger; mod base; mod classification; mod early_stopping; mod regression; mod sequence; mod summary; mod supervised; mod train_val; pub use application_logger::*; pub use base::*; pub use classification::*; pub use early_stopping::*; pub use regression::*; pub use sequence::*; pub use summary::*; pub use supervised::*; pub use train_val::*; ================================================ FILE: crates/burn-train/src/learner/regression.rs ================================================ use crate::metric::processor::ItemLazy; use crate::metric::{Adaptor, LossInput}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Tensor, Transaction}; use burn_ndarray::NdArray; /// Regression output adapted for the loss metric. #[derive(new)] pub struct RegressionOutput { /// The loss. pub loss: Tensor, /// The predicted values. Shape: \[batch_size, num_targets\]. pub output: Tensor, /// The ground truth values. Shape: \[batch_size, num_targets\]. pub targets: Tensor, } impl Adaptor> for RegressionOutput { fn adapt(&self) -> LossInput { LossInput::new(self.loss.clone()) } } impl ItemLazy for RegressionOutput { type ItemSync = RegressionOutput; fn sync(self) -> Self::ItemSync { let [output, loss, targets] = Transaction::default() .register(self.output) .register(self.loss) .register(self.targets) .execute() .try_into() .expect("Correct amount of tensor data"); let device = &Default::default(); RegressionOutput { output: Tensor::from_data(output, device), loss: Tensor::from_data(loss, device), targets: Tensor::from_data(targets, device), } } } ================================================ FILE: crates/burn-train/src/learner/rl/checkpointer.rs ================================================ use burn_core::tensor::Device; use burn_rl::{Policy, PolicyLearner, PolicyState}; use crate::RLAgentRecord; use crate::{ RLComponentsTypes, RLPolicyRecord, checkpoint::Checkpointer, checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy}, metric::store::EventStoreClient, }; #[derive(new)] /// Used to create, delete, or load checkpoints of the training process. pub struct RLCheckpointer { policy: AsyncCheckpointer, RLC::Backend>, learning_agent: AsyncCheckpointer, RLC::Backend>, strategy: Box, } impl RLCheckpointer { /// Create checkpoint for the training process. pub fn checkpoint( &mut self, policy: &RLC::PolicyState, learning_agent: &RLC::LearningAgent, epoch: usize, store: &EventStoreClient, ) { let actions = self.strategy.checkpointing(epoch, store); for action in actions { match action { CheckpointingAction::Delete(epoch) => { self.policy .delete(epoch) .expect("Can delete policy checkpoint."); self.learning_agent .delete(epoch) .expect("Can delete learning agent checkpoint.") } CheckpointingAction::Save => { self.policy .save(epoch, policy.clone().into_record()) .expect("Can save policy checkpoint."); self.learning_agent .save(epoch, learning_agent.record()) .expect("Can save learning agent checkpoint."); } } } } /// Load a training checkpoint. pub fn load_checkpoint( &self, learning_agent: RLC::LearningAgent, device: &Device, epoch: usize, ) -> RLC::LearningAgent { let record = self .policy .restore(epoch, device) .expect("Can load model checkpoint."); let policy = learning_agent.policy().load_record(record); let record = self .learning_agent .restore(epoch, device) .expect("Can load learning agent checkpoint."); let mut learning_agent = learning_agent.load_record(record); learning_agent.update_policy(policy); learning_agent } } ================================================ FILE: crates/burn-train/src/learner/rl/components.rs ================================================ use std::marker::PhantomData; use burn_core::tensor::backend::AutodiffBackend; use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState}; use crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent}; /// All components used by the reinforcement learning paradigm, grouped in one trait. pub trait RLComponentsTypes { /// The backend used for training. type Backend: AutodiffBackend; /// The learning environment. type Env: Environment + 'static; /// Specifies how to initialize the environment. type EnvInit: EnvironmentInit + Send + 'static; /// The type of the environment state. type State: Into<>::Observation> + Clone + Send + 'static; /// The type of the environment action. type Action: From<>::Action> + Into<>::Action> + Clone + Send + 'static; /// The policy used to take actions in the environment. type Policy: Policy< Self::Backend, Observation = Self::PolicyObs, ActionDistribution = Self::PolicyAD, Action = Self::PolicyAction, ActionContext = Self::ActionContext, PolicyState = Self::PolicyState, > + Send + 'static; /// The policy's observation type. type PolicyObs: Clone + Send + Batchable + 'static; /// The policy's action distribution type. type PolicyAD: Clone + Send + Batchable; /// The policy's action type. type PolicyAction: Clone + Send + Batchable; /// Additional data as context for an agent's action. type ActionContext: ItemLazy + Clone + Send + 'static; /// The state of the parameterized policy. type PolicyState: Clone + Send + PolicyState + 'static; /// The learning agent. type LearningAgent: PolicyLearner< Self::Backend, TrainContext = Self::TrainingOutput, InnerPolicy = Self::Policy, > + Send + 'static; /// The output data of a training step. type TrainingOutput: ItemLazy + Clone + Send; } /// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait. pub struct RLComponentsMarker { _backend: PhantomData, _env: PhantomData, _env_init: PhantomData, _agent: PhantomData, } impl RLComponentsTypes for RLComponentsMarker where B: AutodiffBackend, E: Environment + 'static, EI: EnvironmentInit + Send + 'static, A: PolicyLearner + Send + 'static, A::TrainContext: ItemLazy + Clone + Send, A::InnerPolicy: Policy + Send, >::Observation: Batchable + Clone + Send, >::ActionDistribution: Batchable + Clone + Send, >::Action: Batchable + Clone + Send, >::ActionContext: ItemLazy + Clone + Send + 'static, >::PolicyState: Clone + Send, E::State: Into<>::Observation> + Clone + Send + 'static, E::Action: From<>::Action> + Into<>::Action> + Clone + Send + 'static, { type Backend = B; type Env = E; type EnvInit = EI; type LearningAgent = A; type Policy = A::InnerPolicy; type PolicyObs = >::Observation; type PolicyAD = >::ActionDistribution; type PolicyAction = >::Action; type ActionContext = >::ActionContext; type PolicyState = >::PolicyState; type TrainingOutput = A::TrainContext; type State = E::State; type Action = E::Action; } pub(crate) type RlPolicy = <::LearningAgent as PolicyLearner< ::Backend, >>::InnerPolicy; /// The event processor type for reinforcement learning. pub type RLEventProcessorType = AsyncProcessorTraining< RLEvent<::TrainingOutput, ::ActionContext>, AgentEvaluationEvent<::ActionContext>, >; /// The record of the policy. pub type RLPolicyRecord = <<::Policy as Policy< ::Backend, >>::PolicyState as PolicyState<::Backend>>::Record; /// The record of the learning agent. pub type RLAgentRecord = <::LearningAgent as PolicyLearner< ::Backend, >>::Record; ================================================ FILE: crates/burn-train/src/learner/rl/env_runner/async_runner.rs ================================================ use rand::prelude::SliceRandom; use std::{ sync::mpsc::{Receiver, Sender}, thread::spawn, }; use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device}; use burn_rl::EnvironmentInit; use burn_rl::Policy; use burn_rl::Transition; use burn_rl::{AsyncPolicy, Environment}; use crate::{ AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining, Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory, RlPolicy, TimeStep, Trajectory, }; enum RequestMessage { Step(), Episode(), } /// Configuration for an async agent/environment loop. pub struct AsyncAgentEnvLoopConfig { /// If the loop is used for evaluation (as opposed to training). pub eval: bool, /// If the agent should take action deterministically. pub deterministic: bool, /// An arbitrary ID for the loop. pub id: usize, } /// An asynchronous agent/environement interface. pub struct AgentEnvAsyncLoop { eval: bool, agent: AsyncPolicy>, transition_receiver: Receiver>, trajectory_receiver: Receiver>, request_sender: Sender, } impl AgentEnvAsyncLoop { /// Create a new asynchronous runner. /// /// # Arguments /// /// * `env_init` - A function returning an environment instance. /// * `agent` - An [AsyncPolicy](AsyncPolicy) taking actions in the loop. /// * `config` - An [AsyncAgentEnvLoopConfig](AsyncAgentEnvLoopConfig). /// * `transition_sender` - Optional sender for transitions if you want to drive the requests from outside of the loop instance. /// * `trajectory_sender` - Optional sender for trajectories if you want to drive the requests from outside of the loop instance. /// /// # Returns /// /// An async Agent/Environement loop. pub fn new( env_init: RLC::EnvInit, agent: AsyncPolicy>, config: AsyncAgentEnvLoopConfig, transition_device: &Device, transition_sender: Option>>, trajectory_sender: Option>>, ) -> Self { let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel(); let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel(); let (request_sender, request_receiver) = std::sync::mpsc::channel(); let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender); let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender); let device = transition_device.clone(); let mut loop_agent = agent.clone(); let eval = config.eval; let mut current_steps = vec![]; let mut current_reward = 0.0; let mut step_num = 0; spawn(move || { let mut env = env_init.init(); env.reset(); let mut request_episode = false; loop { let state = env.state(); let (action, context) = loop_agent.action(state.clone().into(), config.deterministic); let env_action = RLC::Action::from(action); let step_result = env.step(env_action.clone()); current_reward += step_result.reward; step_num += 1; let transition = Transition::new( state.clone(), step_result.next_state, env_action, Tensor::from_data([step_result.reward], &device), Tensor::from_data( [(step_result.done || step_result.truncated) as i32 as f64], &device, ), ); if !request_episode { loop_agent.decrement_agents(1); let request = match request_receiver.recv() { Ok(req) => req, Err(err) => { log::error!("Error in env runner : {}", err); break; } }; loop_agent.increment_agents(1); match request { RequestMessage::Step() => (), RequestMessage::Episode() => request_episode = true, } } let time_step = TimeStep { env_id: config.id, transition, done: step_result.done, ep_len: step_num, cum_reward: current_reward, action_context: context[0].clone(), }; current_steps.push(time_step.clone()); if !request_episode && let Err(err) = loop_transition_sender.send(time_step) { log::error!("Error in env runner : {}", err); break; } if step_result.done || step_result.truncated { if request_episode { request_episode = false; loop_trajectory_sender .send(Trajectory { timesteps: current_steps.clone(), }) .expect("Can send trajectory to main thread."); } current_steps.clear(); env.reset(); current_reward = 0.; step_num = 0; } } }); Self { eval, agent, transition_receiver, trajectory_receiver, request_sender, } } } impl AgentEnvLoop for AgentEnvAsyncLoop where BT: Backend, RLC: RLComponentsTypes, { fn run_steps( &mut self, num_steps: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec> { let mut items = vec![]; for _ in 0..num_steps { self.request_sender .send(RequestMessage::Step()) .expect("Can request transitions."); let transition = self .transition_receiver .recv() .expect("Can receive transitions."); items.push(transition.clone()); if !self.eval { progress.items_processed += 1; processor.process_train(RLEvent::TimeStep(EvaluationItem::new( transition.action_context, progress.clone(), None, ))); if transition.done { processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( EpisodeSummary { episode_length: transition.ep_len, cum_reward: transition.cum_reward, }, progress.clone(), None, ))); } } if interrupter.should_stop() { break; } } items } fn run_episodes( &mut self, num_episodes: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, _progress: &mut Progress, ) -> Vec> { let mut items = vec![]; self.agent.increment_agents(1); for episode_num in 0..num_episodes { self.request_sender .send(RequestMessage::Episode()) .expect("Can request episodes."); let trajectory = self .trajectory_receiver .recv() .expect("Main thread can receive trajectory."); for (i, step) in trajectory.timesteps.iter().enumerate() { // TODO : clean this. if self.eval { processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( step.action_context.clone(), Progress::new(i, i), None, ))); if step.done { processor.process_valid(AgentEvaluationEvent::EpisodeEnd( EvaluationItem::new( EpisodeSummary { episode_length: step.ep_len, cum_reward: step.cum_reward, }, Progress::new(episode_num + 1, num_episodes), None, ), )); } } else { processor.process_train(RLEvent::TimeStep(EvaluationItem::new( step.action_context.clone(), Progress::new(i, i), None, ))); if step.done { processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( EpisodeSummary { episode_length: step.ep_len, cum_reward: step.cum_reward, }, Progress::new(episode_num + 1, num_episodes), None, ))); } } } items.push(trajectory); if interrupter.should_stop() { break; } } self.agent.decrement_agents(1); items } fn update_policy(&mut self, update: RLC::PolicyState) { self.agent.update(update); } fn policy(&self) -> RLC::PolicyState { self.agent.state() } } /// An asynchronous runner for multiple agent/environement interfaces. pub struct MultiAgentEnvLoop { num_envs: usize, eval: bool, agent: AsyncPolicy, transition_receiver: Receiver>, trajectory_receiver: Receiver>, request_senders: Vec>, } impl MultiAgentEnvLoop { /// Create a new asynchronous runner for multiple agent/environement interfaces. pub fn new( num_envs: usize, env_init: RLC::EnvInit, agent: AsyncPolicy, eval: bool, deterministic: bool, device: &Device, ) -> Self { let (transition_sender, transition_receiver) = std::sync::mpsc::channel(); let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel(); let mut request_senders = vec![]; // Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps. agent.increment_agents(num_envs); for i in 0..num_envs { let config = AsyncAgentEnvLoopConfig { eval, deterministic, id: i, }; let runner = AgentEnvAsyncLoop::::new( env_init.clone(), agent.clone(), config, &device.clone(), Some(transition_sender.clone()), Some(trajectory_sender.clone()), ); request_senders.push(runner.request_sender.clone()); } // Double batching : The environments are always one step ahead. request_senders.iter().for_each(|s| { s.send(RequestMessage::Step()) .expect("Main thread can send step requests.") }); Self { num_envs, eval, agent: agent.clone(), transition_receiver, trajectory_receiver, request_senders, } } } impl AgentEnvLoop for MultiAgentEnvLoop where BT: Backend, RLC: RLComponentsTypes, { fn run_steps( &mut self, num_steps: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec> { let mut items = vec![]; for _ in 0..num_steps { let transition = self .transition_receiver .recv() .expect("Can receive transitions."); items.push(transition.clone()); self.request_senders[transition.env_id] .send(RequestMessage::Step()) .expect("Main thread can request steps."); if !self.eval { progress.items_processed += 1; processor.process_train(RLEvent::TimeStep(EvaluationItem::new( transition.action_context, progress.clone(), None, ))); if transition.done { processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( EpisodeSummary { episode_length: transition.ep_len, cum_reward: transition.cum_reward, }, progress.clone(), None, ))); } } if interrupter.should_stop() { break; } } items } fn update_policy(&mut self, update: RLC::PolicyState) { self.agent.update(update); } fn run_episodes( &mut self, num_episodes: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, _progress: &mut Progress, ) -> Vec> { // Send `num_episodes` initial requests. let mut idx = vec![]; if num_episodes < self.num_envs { let mut rng = rand::rng(); let mut vec: Vec = (0..self.num_envs).collect(); vec.shuffle(&mut rng); idx = vec.into_iter().take(num_episodes).collect(); } else { idx = (0..self.num_envs).collect(); } let num_requests = self.num_envs.min(num_episodes); idx.into_iter().for_each(|i| { self.request_senders[i] .send(RequestMessage::Episode()) .expect("Main thread can request steps."); }); let mut items = vec![]; for episode_num in 0..num_episodes { let trajectory = self .trajectory_receiver .recv() .expect("Can receive trajectory."); items.push(trajectory.clone()); if items.len() + num_requests <= num_episodes { self.request_senders[trajectory.timesteps[0].env_id] .send(RequestMessage::Episode()) .expect("Main thread can request steps."); } for (i, step) in trajectory.timesteps.iter().enumerate() { if self.eval { processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( step.action_context.clone(), Progress::new(i, i), None, ))); if step.done { processor.process_valid(AgentEvaluationEvent::EpisodeEnd( EvaluationItem::new( EpisodeSummary { episode_length: step.ep_len, cum_reward: step.cum_reward, }, Progress::new(episode_num + 1, num_episodes), None, ), )); } } else { processor.process_train(RLEvent::TimeStep(EvaluationItem::new( step.action_context.clone(), Progress::new(i, i), None, ))); if step.done { processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( EpisodeSummary { episode_length: step.ep_len, cum_reward: step.cum_reward, }, Progress::new(episode_num + 1, num_episodes), None, ))); } } } if interrupter.should_stop() { break; } } items } fn policy(&self) -> RLC::PolicyState { self.agent.state() } } #[cfg(test)] #[allow(clippy::needless_range_loop)] mod tests { use burn_core::data::dataloader::Progress; use burn_rl::AsyncPolicy; use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig; use crate::learner::rl::env_runner::base::AgentEnvLoop; use crate::learner::tests::{MockPolicyState, MockProcessor}; use crate::{ AgentEnvAsyncLoop, TestBackend, learner::tests::{MockEnvInit, MockPolicy, MockRLComponents}, }; use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop}; fn setup_async_loop( state: usize, eval: bool, deterministic: bool, ) -> AgentEnvAsyncLoop { let env_init = MockEnvInit; let agent = MockPolicy(state); let config = AsyncAgentEnvLoopConfig { eval, deterministic, id: 0, }; AgentEnvAsyncLoop::::new( env_init, AsyncPolicy::new(1, agent), config, &Default::default(), None, None, ) } fn setup_multi_loop( num_envs: usize, autobatch_size: usize, state: usize, eval: bool, deterministic: bool, ) -> MultiAgentEnvLoop { let env_init = MockEnvInit; let agent = MockPolicy(state); MultiAgentEnvLoop::::new( num_envs, env_init, AsyncPolicy::new(autobatch_size, agent), eval, deterministic, &Default::default(), ) } #[test] fn test_policy_async_loop() { let runner = setup_async_loop(1000, false, false); let policy_state = runner.policy(); assert_eq!(policy_state.0, 1000); } #[test] fn test_update_policy_async_loop() { let mut runner = setup_async_loop(0, false, false); runner.update_policy(MockPolicyState(1)); assert_eq!(runner.policy().0, 1); } #[test] fn run_steps_returns_requested_number_async_loop() { let mut runner = setup_async_loop(0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), 1); let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), 8); } #[test] fn run_episodes_returns_requested_number_async_loop() { let mut runner = setup_async_loop(0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), 1); assert_ne!(trajectories[0].timesteps.len(), 0); let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), 8); for i in 0..8 { assert_ne!(trajectories[i].timesteps.len(), 0); } } #[test] fn test_policy_multi_loop() { let runner = setup_multi_loop(4, 4, 1000, false, false); let policy_state = runner.policy(); assert_eq!(policy_state.0, 1000); } #[test] fn test_update_policy_multi_loop() { let mut runner = setup_multi_loop(4, 4, 0, false, false); runner.update_policy(MockPolicyState(1)); assert_eq!(runner.policy().0, 1); } #[test] fn run_steps_returns_requested_number_multi_loop() { fn run_test(num_envs: usize, autobatch_size: usize) { let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; // Kickstart tests by running some steps to make sure it's not a double batching edge case success. let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), 8); for i in 0..16 { let steps = runner.run_steps(i, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), i); } } // num_envs == autobatch_size run_test(1, 1); run_test(4, 4); // num_envs < autobatch_size run_test(1, 2); run_test(1, 3); run_test(2, 3); run_test(2, 4); run_test(5, 19); // num_envs > autobatch_size run_test(2, 1); run_test(8, 1); run_test(3, 2); run_test(8, 2); run_test(8, 3); run_test(8, 7); } #[test] fn run_episodes_returns_requested_number_multi_loop() { fn run_test(num_envs: usize, autobatch_size: usize) { let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; // Kickstart tests by running some episodes to make sure it's not a double batching edge case success. let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), 8); for j in 0..8 { assert_ne!(trajectories[j].timesteps.len(), 0); } for i in 0..16 { let trajectories = runner.run_episodes(i, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), i); for j in 0..i { assert_ne!(trajectories[j].timesteps.len(), 0); } } } // num_envs == autobatch_size run_test(1, 1); run_test(4, 4); // num_envs < autobatch_size run_test(1, 2); run_test(1, 3); run_test(2, 3); run_test(2, 4); run_test(5, 19); // num_envs > autobatch_size run_test(2, 1); run_test(8, 1); run_test(3, 2); run_test(8, 2); run_test(8, 3); run_test(8, 7); } } ================================================ FILE: crates/burn-train/src/learner/rl/env_runner/base.rs ================================================ use std::marker::PhantomData; use burn_core::data::dataloader::Progress; use burn_core::{Tensor, prelude::Backend}; use burn_rl::Policy; use burn_rl::Transition; use burn_rl::{Environment, EnvironmentInit}; use crate::RLEvent; use crate::{ AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining, RLEventProcessorType, }; use crate::{Interrupter, RLComponentsTypes}; /// A trajectory, i.e. a list of ordered [TimeStep](TimeStep). #[derive(Clone, new)] pub struct Trajectory { /// A list of ordered [TimeStep](TimeStep)s. pub timesteps: Vec>, } /// A timestep debscribing an iteration of the state/decision process. #[derive(Clone)] pub struct TimeStep { /// The environment id. pub env_id: usize, /// The [burn_rl::Transition](burn_rl::Transition). pub transition: Transition, /// True if the environment reaches a terminal state. pub done: bool, /// The running length of the current episode. pub ep_len: usize, /// The running cumulative reward. pub cum_reward: f64, /// The action's context for this timestep. pub action_context: C, } pub(crate) type RLTimeStep = TimeStep< B, ::State, ::Action, ::ActionContext, >; pub(crate) type RLTrajectory = Trajectory< B, ::State, ::Action, ::ActionContext, >; /// Trait for a structure that implements an agent/environement interface. pub trait AgentEnvLoop { /// Run a certain number of timesteps. /// /// # Arguments /// /// * `num_steps` - The number of time_steps to run. /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining). /// * `interrupter` - An [crate::Interrupter](crate::Interrupter). /// * `num_steps` - The number of time_steps to run. /// * `progress` - A mutable reference to the learning progress. /// /// # Returns /// /// A list of ordered timesteps. fn run_steps( &mut self, num_steps: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec>; /// Run a certain number of episodes. /// /// # Arguments /// /// * `num_episodes` - The number of episodes to run. /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining). /// * `interrupter` - An [crate::Interrupter](crate::Interrupter). /// * `progress` - A mutable reference to the learning progress. /// /// # Returns /// /// A list of ordered timesteps. fn run_episodes( &mut self, num_episodes: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec>; /// Update the runner's agent. fn update_policy(&mut self, update: RLC::PolicyState); /// Get the state of the runner's agent. fn policy(&self) -> RLC::PolicyState; } /// A simple, synchronized agent/environement interface. pub struct AgentEnvBaseLoop { env: RLC::Env, eval: bool, agent: RLC::Policy, deterministic: bool, current_reward: f64, run_num: usize, step_num: usize, _backend: PhantomData, } impl AgentEnvBaseLoop { /// Create a new base runner. pub fn new( env_init: RLC::EnvInit, agent: RLC::Policy, eval: bool, deterministic: bool, ) -> Self { let mut env = env_init.init(); env.reset(); Self { env, eval, agent: agent.clone(), deterministic, current_reward: 0.0, run_num: 0, step_num: 0, _backend: PhantomData, } } } impl AgentEnvLoop for AgentEnvBaseLoop where BT: Backend, RLC: RLComponentsTypes, { fn run_steps( &mut self, num_steps: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec> { let mut items = vec![]; let device = Default::default(); for _ in 0..num_steps { let state = self.env.state(); let (action, context) = self.agent.action(state.clone().into(), self.deterministic); let step_result = self.env.step(RLC::Action::from(action.clone())); self.current_reward += step_result.reward; self.step_num += 1; let transition = Transition::new( state.clone(), step_result.next_state, RLC::Action::from(action), Tensor::from_data([step_result.reward], &device), Tensor::from_data( [(step_result.done || step_result.truncated) as i32 as f64], &device, ), ); items.push(TimeStep { env_id: 0, transition, done: step_result.done, ep_len: self.step_num, cum_reward: self.current_reward, action_context: context[0].clone(), }); if !self.eval { progress.items_processed += 1; processor.process_train(RLEvent::TimeStep(EvaluationItem::new( context[0].clone(), progress.clone(), None, ))); if step_result.done { processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new( EpisodeSummary { episode_length: self.step_num, cum_reward: self.current_reward, }, progress.clone(), None, ))); } } if interrupter.should_stop() { break; } if step_result.done || step_result.truncated { self.env.reset(); self.current_reward = 0.; self.step_num = 0; self.run_num += 1; } } items } fn update_policy(&mut self, update: RLC::PolicyState) { self.agent.update(update); } fn run_episodes( &mut self, num_episodes: usize, processor: &mut RLEventProcessorType, interrupter: &Interrupter, progress: &mut Progress, ) -> Vec> { self.env.reset(); let mut items = vec![]; for ep in 0..num_episodes { let mut steps = vec![]; loop { let step = self.run_steps(1, processor, interrupter, progress)[0].clone(); steps.push(step.clone()); if self.eval { processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new( step.action_context.clone(), Progress::new(steps.len() + 1, steps.len() + 1), None, ))); if step.done { processor.process_valid(AgentEvaluationEvent::EpisodeEnd( EvaluationItem::new( EpisodeSummary { episode_length: step.ep_len, cum_reward: step.cum_reward, }, Progress::new(ep + 1, num_episodes), None, ), )); } } if interrupter.should_stop() || step.done { break; } } items.push(Trajectory::new(steps)); if interrupter.should_stop() { break; } } items } fn policy(&self) -> RLC::PolicyState { self.agent.state() } } #[cfg(test)] #[allow(clippy::needless_range_loop)] mod tests { use crate::{AsyncProcessorTraining, TestBackend}; use crate::learner::tests::{ MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents, }; use super::*; fn setup( state: usize, eval: bool, deterministic: bool, ) -> AgentEnvBaseLoop { let env_init = MockEnvInit; let agent = MockPolicy(state); AgentEnvBaseLoop::::new(env_init, agent, eval, deterministic) } #[test] fn test_policy_returns_agent_state() { let runner = setup(1000, false, false); let policy_state = runner.policy(); assert_eq!(policy_state.0, 1000); } #[test] fn test_update_policy() { let mut runner = setup(0, false, false); runner.update_policy(MockPolicyState(1)); assert_eq!(runner.policy().0, 1); } #[test] fn run_steps_returns_requested_number() { let mut runner = setup(0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), 1); let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress); assert_eq!(steps.len(), 8); } #[test] fn run_episodes_returns_requested_number() { let mut runner = setup(0, false, false); let mut processor = AsyncProcessorTraining::new(MockProcessor); let interrupter = Interrupter::new(); let mut progress = Progress { items_processed: 0, items_total: 1, }; let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), 1); assert_ne!(trajectories[0].timesteps.len(), 0); let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress); assert_eq!(trajectories.len(), 8); for i in 0..8 { assert_ne!(trajectories[i].timesteps.len(), 0); } } } ================================================ FILE: crates/burn-train/src/learner/rl/env_runner/mod.rs ================================================ mod async_runner; mod base; pub use async_runner::*; pub use base::*; #[cfg(test)] pub(crate) mod tests { use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyState}; use crate::tests::TestAutodiffBackend; use crate::{ AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLComponentsTypes, RLEvent, }; use burn_rl::{LearnerTransitionBatch, PolicyLearner, RLTrainOutput, StepResult}; /// Mock policy for testing /// /// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution) /// containing a list of 0s of the same length as the observation. /// /// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockPolicyAction](MockPolicyAction) with a list of actions of the same length as the observation. /// The actions are all 1 if the call is requested as deterministic, or else 0. #[derive(Clone)] pub(crate) struct MockPolicy(pub usize); impl Policy for MockPolicy { type Observation = MockObservation; type ActionDistribution = MockActionDistribution; type Action = MockPolicyAction; type ActionContext = MockActionContext; type PolicyState = MockPolicyState; fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution { let mut dists = vec![]; for _ in obs.0 { dists.push(MockActionDistribution(vec![0.])); } MockActionDistribution::batch(dists) } fn action( &mut self, obs: Self::Observation, deterministic: bool, ) -> (Self::Action, Vec) { let mut actions = vec![]; let mut contexts = vec![]; for _ in obs.0 { if deterministic { actions.push(MockPolicyAction(vec![1])); } else { actions.push(MockPolicyAction(vec![0])); } contexts.push(MockActionContext); } (MockPolicyAction::batch(actions), contexts) } fn update(&mut self, update: Self::PolicyState) { self.0 = update.0; } fn state(&self) -> Self::PolicyState { MockPolicyState(self.0) } fn load_record( self, _record: >::Record, ) -> Self { self } } /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockObservation(pub Vec); /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockPolicyAction(pub Vec); /// Mock action distribution for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it. #[derive(Clone)] pub(crate) struct MockActionDistribution(Vec); #[derive(Clone)] pub(crate) struct MockActionContext; /// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy. #[derive(Clone)] pub(crate) struct MockPolicyState(pub usize); impl PolicyState for MockPolicyState { type Record = (); fn into_record(self) -> Self::Record {} fn load_record(&self, _record: Self::Record) -> Self { self.clone() } } impl Batchable for MockObservation { fn batch(items: Vec) -> Self { MockObservation(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { vec![MockObservation(self.0)] } } impl Batchable for MockPolicyAction { fn batch(items: Vec) -> Self { MockPolicyAction(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { let mut actions = vec![]; for a in self.0 { actions.push(MockPolicyAction(vec![a])); } actions } } impl Batchable for MockActionDistribution { fn batch(items: Vec) -> Self { MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect()) } fn unbatch(self) -> Vec { let mut dists = vec![]; for _ in self.0 { dists.push(MockActionDistribution(vec![0.])); } dists } } /// Mock environment for testing #[derive(Clone)] pub(crate) struct MockEnv { counter: usize, } #[derive(Clone, Debug)] pub(crate) struct MockState; #[derive(Clone, Debug)] pub(crate) struct MockAction(pub i32); impl From for MockObservation { fn from(_value: MockState) -> Self { MockObservation(vec![0.]) } } impl From for MockAction { fn from(value: MockPolicyAction) -> Self { MockAction(value.0[0]) } } impl From for MockPolicyAction { fn from(value: MockAction) -> Self { MockPolicyAction(vec![value.0]) } } impl ItemLazy for MockActionContext { type ItemSync = MockActionContext; fn sync(self) -> Self::ItemSync { self } } impl MockEnv { fn new() -> Self { Self { counter: 0 } } } impl Environment for MockEnv { type State = MockState; type Action = MockAction; const MAX_STEPS: usize = 5; fn reset(&mut self) { self.counter = 0; } fn step(&mut self, _action: Self::Action) -> StepResult { self.counter += 1; let done = self.counter >= Self::MAX_STEPS; burn_rl::StepResult { next_state: MockState, reward: 1.0, done, truncated: false, } } fn state(&self) -> Self::State { MockState } } /// Mock environment init for testing #[derive(Clone)] pub(crate) struct MockEnvInit; impl EnvironmentInit for MockEnvInit { fn init(&self) -> MockEnv { MockEnv::new() } } // Mock RLComponentsTypes for testing pub(crate) struct MockRLComponents; impl RLComponentsTypes for MockRLComponents { type Backend = TestAutodiffBackend; type Env = MockEnv; type EnvInit = MockEnvInit; type State = MockState; type Action = MockAction; type Policy = MockPolicy; type PolicyObs = MockObservation; type PolicyAD = MockActionDistribution; type PolicyAction = MockPolicyAction; type ActionContext = MockActionContext; type PolicyState = MockPolicyState; type LearningAgent = MockLearningAgent; type TrainingOutput = (); } // Mock learning agent for testing #[derive(Clone)] pub(crate) struct MockLearningAgent; impl PolicyLearner for MockLearningAgent { type InnerPolicy = MockPolicy; type TrainContext = (); type Record = (); fn train( &mut self, _input: LearnerTransitionBatch, ) -> RLTrainOutput< Self::TrainContext, >::PolicyState, > { unimplemented!() } fn policy(&self) -> Self::InnerPolicy { unimplemented!() } fn update_policy(&mut self, _update: Self::InnerPolicy) { unimplemented!() } fn record(&self) -> Self::Record { unimplemented!() } fn load_record(self, _record: Self::Record) -> Self { unimplemented!() } } // Mock event processor for testing pub(crate) struct MockProcessor; impl EventProcessorTraining< RLEvent<(), MockActionContext>, AgentEvaluationEvent, > for MockProcessor { fn process_train(&mut self, _event: RLEvent<(), MockActionContext>) { // Mock process train } fn process_valid(&mut self, _event: AgentEvaluationEvent) { // Mock process valid } fn renderer(self) -> Box { unimplemented!() } } } ================================================ FILE: crates/burn-train/src/learner/rl/mod.rs ================================================ mod checkpointer; mod components; mod env_runner; mod off_policy; mod output; mod paradigm; mod strategy; pub use checkpointer::*; pub use components::*; pub use env_runner::*; pub use off_policy::*; pub use output::*; pub use paradigm::*; pub use strategy::*; ================================================ FILE: crates/burn-train/src/learner/rl/off_policy.rs ================================================ use std::marker::PhantomData; use crate::{ AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem, EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent, RLEventProcessorType, RLStrategy, }; use burn_core::{self as burn}; use burn_core::{config::Config, data::dataloader::Progress}; use burn_ndarray::NdArray; use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer}; /// Parameters of an on policy training with multi environments and double-batching. #[derive(Config, Debug)] pub struct OffPolicyConfig { /// The number of environments to run simultaneously for experience collection. #[config(default = 1)] pub num_envs: usize, /// Number of environment state to accumulate before running one step of inference with the policy. /// Must be equal or less than the number of simultaneous environments. #[config(default = 1)] pub autobatch_size: usize, /// Max number of transitions stored in the replay buffer. #[config(default = 1024)] pub replay_buffer_size: usize, /// The number of steps to collect between each step of training. #[config(default = 1)] pub train_interval: usize, /// Number of optimization steps done each `train_interval`. #[config(default = 1)] pub train_steps: usize, /// The number of steps to collect between each evaluation. #[config(default = 10_000)] pub eval_interval: usize, /// The number of episodes to run for each evaluation. #[config(default = 1)] pub eval_episodes: usize, /// The number of transition to train on. #[config(default = 32)] pub train_batch_size: usize, /// Number of steps to collect before starting to train. #[config(default = 0)] pub warmup_steps: usize, } /// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching. pub struct OffPolicyStrategy { config: OffPolicyConfig, _components: PhantomData, } impl OffPolicyStrategy { /// Create a new off-policy base strategy. pub fn new(config: OffPolicyConfig) -> Self { Self { config, _components: PhantomData, } } } impl RLStrategy for OffPolicyStrategy where RLC: RLComponentsTypes, RLC::PolicyObs: SliceAccess, RLC::PolicyAction: SliceAccess, { fn train_loop( &self, training_components: RLComponents, learner_agent: &mut RLC::LearningAgent, starting_epoch: usize, env_init: RLC::EnvInit, ) -> (RLC::Policy, RLEventProcessorType) { let mut event_processor = training_components.event_processor; let mut checkpointer = training_components.checkpointer; let num_steps_total = training_components.num_steps; let mut env_runner = MultiAgentEnvLoop::::new( self.config.num_envs, env_init.clone(), AsyncPolicy::new( self.config.num_envs.min(self.config.autobatch_size), learner_agent.policy(), ), false, false, &Default::default(), ); let runner_config = AsyncAgentEnvLoopConfig { eval: true, deterministic: true, id: 0, }; let mut env_runner_valid = AgentEnvAsyncLoop::::new( env_init, AsyncPolicy::new(1, learner_agent.policy()), runner_config, &Default::default(), None, None, ); let device: ::Device = Default::default(); let mut transition_buffer = TransitionBuffer::< RLC::Backend, RLC::PolicyObs, RLC::PolicyAction, >::new(self.config.replay_buffer_size, &device); let mut valid_next = self.config.eval_interval + starting_epoch - 1; let mut progress = Progress { items_processed: starting_epoch, items_total: num_steps_total, }; let mut intermediary_update: Option<>::PolicyState> = None; while progress.items_processed < num_steps_total { if training_components.interrupter.should_stop() { let reason = training_components .interrupter .get_message() .unwrap_or(String::from("Reason unknown")); log::info!("Training interrupted: {reason}"); break; } let previous_steps = progress.items_processed; let items = env_runner.run_steps( self.config.train_interval, &mut event_processor, &training_components.interrupter, &mut progress, ); for item in &items { let t = &item.transition; let state: RLC::PolicyObs = t.state.clone().into(); let next_state: RLC::PolicyObs = t.next_state.clone().into(); let action: RLC::PolicyAction = t.action.clone().into(); let reward = t.reward.to_data().to_vec::().unwrap()[0]; let done = t.done.to_data().to_vec::().unwrap()[0] > 0.5; transition_buffer.push(state, next_state, action, reward, done); } if transition_buffer.len() >= self.config.train_batch_size && progress.items_processed >= self.config.warmup_steps { if let Some(ref u) = intermediary_update { env_runner.update_policy(u.clone()); } for _ in 0..self.config.train_steps { let batch = transition_buffer.sample(self.config.train_batch_size); let train_item = learner_agent.train(batch); intermediary_update = Some(train_item.policy); event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new( train_item.item, progress.clone(), None, ))); } } if valid_next > previous_steps && valid_next <= progress.items_processed { env_runner_valid.update_policy(learner_agent.policy().state()); env_runner_valid.run_episodes( self.config.eval_episodes, &mut event_processor, &training_components.interrupter, &mut progress, ); if let Some(checkpointer) = &mut checkpointer { checkpointer.checkpoint( &env_runner.policy(), learner_agent, valid_next, &training_components.event_store, ); } valid_next += self.config.eval_interval; } } (learner_agent.policy(), event_processor) } } ================================================ FILE: crates/burn-train/src/learner/rl/output.rs ================================================ use crate::{ ItemLazy, metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput}, }; /// Summary of an episode. pub struct EpisodeSummary { /// The total length of the episode. pub episode_length: usize, /// The final cumulative reward. pub cum_reward: f64, } impl ItemLazy for EpisodeSummary { type ItemSync = EpisodeSummary; fn sync(self) -> Self::ItemSync { self } } impl Adaptor for EpisodeSummary { fn adapt(&self) -> EpisodeLengthInput { EpisodeLengthInput::new(self.episode_length as f64) } } impl Adaptor for EpisodeSummary { fn adapt(&self) -> CumulativeRewardInput { CumulativeRewardInput::new(self.cum_reward) } } ================================================ FILE: crates/burn-train/src/learner/rl/paradigm.rs ================================================ use crate::checkpoint::{ AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, KeepLastNCheckpoints, MetricCheckpointingStrategy, }; use crate::learner::base::Interrupter; use crate::logger::{FileMetricLogger, MetricLogger}; use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split}; use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric}; use crate::renderer::{MetricsRenderer, default_renderer}; use crate::{ ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy, LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer, RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics, RLPolicyRecord, RLStrategy, }; use crate::{EpisodeSummary, RLStrategies}; use burn_core::record::FileRecorder; use burn_core::tensor::backend::AutodiffBackend; use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess}; use std::collections::BTreeSet; use std::path::{Path, PathBuf}; use std::sync::Arc; /// Structure to configure and launch reinforcement learning trainings. pub struct RLTraining { // Not that complex. Extracting into yet another type would only make it more confusing. #[allow(clippy::type_complexity)] checkpointers: Option<( AsyncCheckpointer, RLC::Backend>, AsyncCheckpointer, RLC::Backend>, )>, num_steps: usize, checkpoint: Option, directory: PathBuf, grad_accumulation: Option, renderer: Option>, metrics: RLMetrics, event_store: LogEventStore, interrupter: Interrupter, tracing_logger: Option>, checkpointer_strategy: Box, learning_strategy: RLStrategies, // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order summary_metrics: BTreeSet, summary: bool, env_initializer: RLC::EnvInit, } impl RLTraining> where B: AutodiffBackend, E: Environment + 'static, EI: EnvironmentInit + Send + 'static, A: PolicyLearner + Send + 'static, A::TrainContext: ItemLazy + Clone + Send, A::InnerPolicy: Policy + Send, >::Observation: Batchable + Clone + Send, >::ActionDistribution: Batchable + Clone + Send, >::Action: Batchable + Clone + Send, >::ActionContext: ItemLazy + Clone + Send + 'static, >::PolicyState: Clone + Send, E::State: Into<>::Observation> + Clone + Send + 'static, E::Action: From<>::Action> + Into<>::Action> + Clone + Send + 'static, { /// Creates a new runner for reinforcement learning. /// /// # Arguments /// /// * `directory` - The directory to save the checkpoints. /// * `env_init` - Specifies how to initialize the environment. pub fn new(directory: impl AsRef, env_initializer: EI) -> Self { let directory = directory.as_ref().to_path_buf(); let experiment_log_file = directory.join("experiment.log"); Self { num_steps: 1, checkpoint: None, checkpointers: None, directory, grad_accumulation: None, metrics: RLMetrics::default(), event_store: LogEventStore::default(), renderer: None, interrupter: Interrupter::new(), tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new( experiment_log_file, ))), checkpointer_strategy: Box::new( ComposedCheckpointingStrategy::builder() .add(KeepLastNCheckpoints::new(2)) .add(MetricCheckpointingStrategy::new( &EpisodeLengthMetric::new(), // default to evaluations' cumulative reward. Aggregate::Mean, Direction::Lowest, Split::Valid, )) .build(), ), learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()), summary_metrics: BTreeSet::new(), summary: false, env_initializer, } } } impl RLTraining { /// Replace the default learning strategy (Off Policy learning) with the provided one. /// /// # Arguments /// /// * `training_strategy` - The training strategy. pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies) -> Self { self.learning_strategy = learning_strategy; self } /// Replace the default metric loggers with the provided ones. /// /// # Arguments /// /// * `logger` - The training logger. pub fn with_metric_logger(mut self, logger: ML) -> Self where ML: MetricLogger + 'static, { self.event_store.register_logger(logger); self } /// Update the checkpointing_strategy. pub fn with_checkpointing_strategy( mut self, strategy: CS, ) -> Self { self.checkpointer_strategy = Box::new(strategy); self } /// Replace the default CLI renderer with a custom one. /// /// # Arguments /// /// * `renderer` - The custom renderer. pub fn renderer(mut self, renderer: MR) -> Self where MR: MetricsRenderer + 'static, { self.renderer = Some(Box::new(renderer)); self } /// Register numerical metrics for a training step of the agent. pub fn metrics_train>(self, metrics: Me) -> Self { metrics.register(self) } /// Register textual metrics for a training step of the agent. pub fn text_metrics_train>(self, metrics: Me) -> Self { metrics.register(self) } /// Register numerical metrics for each action of the agent. pub fn metrics_agent>(self, metrics: Me) -> Self { metrics.register(self) } /// Register textual metrics for each action of the agent. pub fn text_metrics_agent>(self, metrics: Me) -> Self { metrics.register(self) } /// Register numerical metrics for a completed episode. pub fn metrics_episode>(self, metrics: Me) -> Self { metrics.register(self) } /// Register textual metrics for a completed episode. pub fn text_metrics_episode>(self, metrics: Me) -> Self { metrics.register(self) } /// Register a textual metric for a training step. pub fn text_metric_train(mut self, metric: Me) -> Self where ::ItemSync: Adaptor, { self.metrics.register_text_metric_train(metric); self } /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step. pub fn metric_train(mut self, metric: Me) -> Self where Me: Metric + Numeric + 'static, ::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_metric_train(metric); self } /// Register a textual metric for each action taken by the agent. pub fn text_metric_agent(mut self, metric: Me) -> Self where ::ItemSync: Adaptor, { self.metrics.register_text_metric_agent(metric.clone()); self.metrics.register_text_metric_agent_valid(metric); self } /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent. pub fn metric_agent(mut self, metric: Me) -> Self where Me: Metric + Numeric + 'static, ::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_agent_metric(metric.clone()); self.metrics.register_agent_metric_valid(metric); self } /// Register a textual metric for a completed episode. pub fn text_metric_episode(mut self, metric: Me) -> Self where EpisodeSummary: Adaptor + 'static, { self.metrics.register_text_metric_episode(metric.clone()); self.metrics.register_text_metric_episode_valid(metric); self } /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode. pub fn metric_episode(mut self, metric: Me) -> Self where Me: Metric + Numeric + 'static, EpisodeSummary: Adaptor + 'static, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_episode_metric(metric.clone()); self.metrics.register_episode_metric_valid(metric); self } /// The number of environment steps to train for. pub fn num_steps(mut self, num_steps: usize) -> Self { self.num_steps = num_steps; self } /// The step from which the training must resume. pub fn checkpoint(mut self, checkpoint: usize) -> Self { self.checkpoint = Some(checkpoint); self } /// Provides a handle that can be used to interrupt training. pub fn interrupter(&self) -> Interrupter { self.interrupter.clone() } /// Override the handle for stopping training with an externally provided handle pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self { self.interrupter = interrupter; self } /// By default, Rust logs are captured and written into /// `experiment.log`. If disabled, standard Rust log handling /// will apply. pub fn with_application_logger( mut self, logger: Option>, ) -> Self { self.tracing_logger = logger; self } /// Register a checkpointer that will save the environment runner's [policy](Policy) /// and the [PolicyLearner](PolicyLearner) state to different files. pub fn with_file_checkpointer(mut self, recorder: FR) -> Self where FR: FileRecorder + 'static, FR: FileRecorder<::InnerBackend> + 'static, { let checkpoint_dir = self.directory.join("checkpoint"); let checkpointer_policy = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy"); let checkpointer_learning = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent"); self.checkpointers = Some(( AsyncCheckpointer::new(checkpointer_policy), AsyncCheckpointer::new(checkpointer_learning), )); self } /// Enable the training summary report. /// /// The summary will be displayed after `.launch()`, when the renderer is dropped. pub fn summary(mut self) -> Self { self.summary = true; self } /// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment. pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult where RLC::PolicyObs: SliceAccess, RLC::PolicyAction: SliceAccess, { if self.tracing_logger.is_some() && let Err(e) = self.tracing_logger.as_ref().unwrap().install() { log::warn!("Failed to install the experiment logger: {e}"); } let renderer = self .renderer .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint)); if !self.event_store.has_loggers() { self.event_store .register_logger(FileMetricLogger::new(self.directory.clone())); } let event_store = Arc::new(EventStoreClient::new(self.event_store)); let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new( self.metrics, renderer, event_store.clone(), )); let checkpointer = self.checkpointers.map(|(policy, learning_agent)| { RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy) }); let summary = if self.summary { Some(LearnerSummaryConfig { directory: self.directory, metrics: self.summary_metrics.into_iter().collect::>(), }) } else { None }; let components = RLComponents:: { checkpoint: self.checkpoint, checkpointer, interrupter: self.interrupter, event_processor, event_store, num_steps: self.num_steps, grad_accumulation: self.grad_accumulation, summary, }; match self.learning_strategy { RLStrategies::OffPolicyStrategy(config) => { let strategy = OffPolicyStrategy::new(config); strategy.train(learner_agent, components, self.env_initializer) } RLStrategies::Custom(strategy) => { strategy.train(learner_agent, components, self.env_initializer) } } } } /// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer). pub struct RLResult

{ /// The learned policy. pub policy: P, /// The renderer that can be used for follow up training and evaluation. pub renderer: Box, } /// Trait to fake variadic generics for train step metrics. pub trait AgentMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } /// Trait to fake variadic generics for train step text metrics. pub trait AgentTextMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } /// Trait to fake variadic generics for env step metrics. pub trait TrainMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } /// Trait to fake variadic generics for env step text metrics. pub trait TrainTextMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } /// Trait to fake variadic generics for episode metrics. pub trait EpisodeMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } /// Trait to fake variadic generics for episode text metrics. pub trait EpisodeTextMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: RLTraining) -> RLTraining; } macro_rules! gen_tuple { ($($M:ident),*) => { impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration for ($($M,)*) where $(::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.text_metric_train($M.clone());)* builder } } impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration for ($($M,)*) where $(::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + Numeric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.metric_train($M.clone());)* builder } } impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration for ($($M,)*) where $(::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.text_metric_agent($M.clone());)* builder } } impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration for ($($M,)*) where $(::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + Numeric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.metric_agent($M.clone());)* builder } } impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration for ($($M,)*) where $(EpisodeSummary: Adaptor<$M::Input> + 'static,)* $($M: Metric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.text_metric_episode($M.clone());)* builder } } impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration for ($($M,)*) where $(EpisodeSummary: Adaptor<$M::Input> + 'static,)* $($M: Metric + Numeric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: RLTraining, ) -> RLTraining { let ($($M,)*) = self; $(let builder = builder.metric_episode($M.clone());)* builder } } }; } gen_tuple!(M1); gen_tuple!(M1, M2); gen_tuple!(M1, M2, M3); gen_tuple!(M1, M2, M3, M4); gen_tuple!(M1, M2, M3, M4, M5); gen_tuple!(M1, M2, M3, M4, M5, M6); ================================================ FILE: crates/burn-train/src/learner/rl/strategy.rs ================================================ use std::sync::Arc; use crate::{ Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent, RLEventProcessorType, RLResult, metric::{processor::EventProcessorTraining, store::EventStoreClient}, }; /// Struct to minimise parameters passed to [RLStrategy::train]. pub struct RLComponents { /// The total number of environment steps. pub num_steps: usize, /// The step number from which to continue the training. pub checkpoint: Option, /// A checkpointer used to load and save learning checkpoints. pub checkpointer: Option>, /// Enables gradients accumulation. pub grad_accumulation: Option, /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early. pub interrupter: Interrupter, /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation. pub event_processor: RLEventProcessorType, /// A reference to an [EventStoreClient](EventStoreClient). pub event_store: Arc, /// Config for creating a summary of the learning pub summary: Option, } /// The strategy for reinforcement learning. #[derive(Clone)] pub enum RLStrategies { /// Training on one device OffPolicyStrategy(OffPolicyConfig), /// Training using a custom learning strategy Custom(CustomRLStrategy), } /// A reference to an implementation of [RLStrategy]. pub type CustomRLStrategy = Arc>; /// Provides the `fit` function for any learning strategy pub trait RLStrategy { /// Train the learner agent with this strategy. fn train( &self, mut learner_agent: RLC::LearningAgent, mut training_components: RLComponents, env_init: RLC::EnvInit, ) -> RLResult { let starting_epoch = match training_components.checkpoint { Some(checkpoint) => { if let Some(checkpointer) = &mut training_components.checkpointer { learner_agent = checkpointer.load_checkpoint( learner_agent, &Default::default(), checkpoint, ); } checkpoint + 1 } None => 1, }; let summary_config = training_components.summary.clone(); // Event processor start training training_components .event_processor .process_train(RLEvent::Start); // Training loop let (policy, mut event_processor) = self.train_loop( training_components, &mut learner_agent, starting_epoch, env_init, ); let summary = summary_config.and_then(|summary| summary.init().ok()); // Signal training end. For the TUI renderer, this handles the exit & return to main screen. // TODO: summary makes sense for RL? event_processor.process_train(RLEvent::End(summary)); // let model = model.valid(); let renderer = event_processor.renderer(); RLResult { policy, renderer } } /// Training loop for this strategy fn train_loop( &self, training_components: RLComponents, learner_agent: &mut RLC::LearningAgent, starting_epoch: usize, env_init: RLC::EnvInit, ) -> (RLC::Policy, RLEventProcessorType); } ================================================ FILE: crates/burn-train/src/learner/sequence.rs ================================================ use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput}; use crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor, Transaction}; use burn_ndarray::NdArray; /// Sequence prediction output adapted for multiple metrics. /// /// Supported metrics: /// - Accuracy /// - TopKAccuracy /// - Perplexity /// - Loss /// - CER /// - WER #[derive(new)] pub struct SequenceOutput { /// The loss. pub loss: Tensor, /// Raw logits. Shape: `[batch_size, seq_len, vocab_size]` pub logits: Tensor, /// Optional predicted token indices. Shape: `[batch_size, seq_length]`. /// If not provided, predictions default to argmax of `logits` along the last dimension. pub predictions: Option>, /// The target token indices. Shape: `[batch_size, seq_length]` pub targets: Tensor, } impl SequenceOutput { fn predicted_tokens(&self) -> Tensor { match &self.predictions { Some(preds) => preds.clone(), None => self.logits.clone().argmax(2).squeeze_dim::<2>(2), } } fn flat_logits(&self) -> Tensor { let [batch_size, seq_len, vocab_size] = self.logits.dims(); self.logits .clone() .reshape([batch_size * seq_len, vocab_size]) } fn flat_targets(&self) -> Tensor { let [batch_size, seq_len] = self.targets.dims(); self.targets.clone().reshape([batch_size * seq_len]) } } impl ItemLazy for SequenceOutput { type ItemSync = SequenceOutput; fn sync(self) -> Self::ItemSync { let device = &Default::default(); match self.predictions { Some(preds) => { let [logits, loss, targets, predictions] = Transaction::default() .register(self.logits) .register(self.loss) .register(self.targets) .register(preds) .execute() .try_into() .expect("Correct amount of tensor data"); SequenceOutput { logits: Tensor::from_data(logits, device), loss: Tensor::from_data(loss, device), targets: Tensor::from_data(targets, device), predictions: Some(Tensor::from_data(predictions, device)), } } None => { let [logits, loss, targets] = Transaction::default() .register(self.logits) .register(self.loss) .register(self.targets) .execute() .try_into() .expect("Correct amount of tensor data"); SequenceOutput { logits: Tensor::from_data(logits, device), loss: Tensor::from_data(loss, device), targets: Tensor::from_data(targets, device), predictions: None, } } } } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> LossInput { LossInput::new(self.loss.clone()) } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> CerInput { CerInput::new(self.predicted_tokens(), self.targets.clone()) } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> WerInput { WerInput::new(self.predicted_tokens(), self.targets.clone()) } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> AccuracyInput { AccuracyInput::new(self.flat_logits(), self.flat_targets()) } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> TopKAccuracyInput { TopKAccuracyInput::new(self.flat_logits(), self.flat_targets()) } } impl Adaptor> for SequenceOutput { fn adapt(&self) -> PerplexityInput { PerplexityInput::new(self.flat_logits(), self.flat_targets()) } } ================================================ FILE: crates/burn-train/src/learner/summary.rs ================================================ use core::cmp::Ordering; use std::{ collections::{HashMap, hash_map::Entry}, fmt::Display, path::{Path, PathBuf}, }; use crate::{ logger::FileMetricLogger, metric::store::{Aggregate, EventStore, LogEventStore, Split}, }; /// Contains the metric value at a given time. #[derive(Debug)] pub struct MetricEntry { /// The step at which the metric was recorded (i.e., epoch). pub step: usize, /// The metric value. pub value: f64, } /// Contains the summary of recorded values for a given metric. #[derive(Debug)] pub struct MetricSummary { /// The metric name. pub name: String, /// The metric entries. pub entries: Vec, } impl MetricSummary { fn collect( event_store: &mut E, metric: &str, split: &Split, num_epochs: usize, ) -> Option { let entries = (1..=num_epochs) .filter_map(|epoch| { event_store .find_metric(metric, epoch, Aggregate::Mean, split) .map(|value| MetricEntry { step: epoch, value }) }) .collect::>(); if entries.is_empty() { None } else { Some(Self { name: metric.to_string(), entries, }) } } } /// Contains the summary of recorded metrics for the training and validation steps. pub struct SummaryMetrics { /// Training metrics summary. pub train: Vec, /// Validation metrics summary. pub valid: Vec, /// Test metrics summary per test split tag. /// /// Each key corresponds to a `Split::Test(Some(tag))`. /// The empty string represents `Split::Test(None)`. pub test: HashMap>, } /// Detailed training summary. pub struct LearnerSummary { /// The number of epochs completed. pub epochs: usize, /// The summary of recorded metrics during training. pub metrics: SummaryMetrics, /// The model name (only recorded within the learner). pub(crate) model: Option, } impl LearnerSummary { /// Creates a new learner summary for the specified metrics. /// /// # Arguments /// /// * `directory` - The directory containing the training artifacts (checkpoints and logs). /// * `metrics` - The list of metrics to collect for the summary. pub fn new>(directory: impl AsRef, metrics: &[S]) -> Result { let directory = directory.as_ref(); if !directory.exists() { return Err(format!( "Artifact directory does not exist at: {}", directory.display() )); } let mut event_store = LogEventStore::default(); let train_split = Split::Train; let valid_split = Split::Valid; let logger = FileMetricLogger::new(directory); let test_split_root = logger.split_dir(&Split::Test(None)); if !logger.split_exists(&train_split) && !logger.split_exists(&valid_split) && test_split_root.is_none() { return Err(format!( "No training, validation or test artifacts found at: {}", directory.display() )); } // Number of recorded epochs let epochs = logger.epochs(); event_store.register_logger(logger); let train_summary = metrics .iter() .filter_map(|metric| { MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs) }) .collect::>(); let valid_summary = metrics .iter() .filter_map(|metric| { MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs) }) .collect::>(); let test_summary = match test_split_root { Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs), None => Default::default(), }; Ok(Self { epochs, metrics: SummaryMetrics { train: train_summary, valid: valid_summary, test: test_summary, }, model: None, }) } pub(crate) fn with_model(mut self, name: String) -> Self { self.model = Some(name); self } /// Merges another summary into this one, combining all metric entries. pub(crate) fn merge(mut self, other: LearnerSummary) -> Self { fn merge_metrics( base: Vec, incoming: Vec, ) -> Vec { let mut map: HashMap = base.into_iter().map(|m| (m.name.clone(), m)).collect(); for metric in incoming { match map.entry(metric.name.clone()) { Entry::Occupied(mut entry) => { entry.get_mut().entries.extend(metric.entries); } Entry::Vacant(entry) => { entry.insert(metric); } } } map.into_values().collect() } self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train); self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid); for (tag, metrics) in other.metrics.test { match self.metrics.test.entry(tag) { Entry::Occupied(mut entry) => { let current = std::mem::take(entry.get_mut()); let merged = merge_metrics(current, metrics); *entry.get_mut() = merged; } Entry::Vacant(entry) => { entry.insert(metrics); } } } if self.model != other.model { self.model = None; } self } } fn collect_test_split_metrics, S: AsRef>( root: P, metrics: &[S], event_store: &mut LogEventStore, epochs: usize, ) -> HashMap> { // Collect immediate child directories let dirs = match std::fs::read_dir(root) { Ok(entries) => entries .filter_map(|entry| { let entry = entry.ok()?; let file_type = entry.file_type().ok()?; if file_type.is_dir() { Some(entry.file_name().to_string_lossy().to_string()) } else { None } }) .collect::>(), Err(_) => Vec::new(), }; let mut map = HashMap::new(); if dirs.is_empty() { return map; } // Detect if all directories are epoch directories let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir); if all_epochs { // Single untagged test split let split = Split::Test(None); let summaries = metrics .iter() .filter_map(|metric| { MetricSummary::collect(event_store, metric.as_ref(), &split, epochs) }) .collect::>(); // Untagged marked with empty string map.insert("".to_string(), summaries); } else { // Tagged splits for tag in dirs { let split = Split::Test(Some(tag.clone().into())); let summaries = metrics .iter() .filter_map(|metric| { MetricSummary::collect(event_store, metric.as_ref(), &split, epochs) }) .collect::>(); map.insert(tag, summaries); } } map } impl Display for LearnerSummary { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Compute the max length for each column let mut max_split_len = 5; // "Train" let mut max_metric_len = "Metric".len(); for metric in self.metrics.train.iter() { max_metric_len = max_metric_len.max(metric.name.len()); } for metric in self.metrics.valid.iter() { max_metric_len = max_metric_len.max(metric.name.len()); } for (tag, metrics) in self.metrics.test.iter() { let split_name = if tag.is_empty() { "Test".to_string() } else { format!("Test ({tag})") }; max_split_len = max_split_len.max(split_name.len()); for metric in metrics { max_metric_len = max_metric_len.max(metric.name.len()); } } // Summary header writeln!( f, "{:=>width_symbol$} Learner Summary {:=>width_symbol$}", "", "", width_symbol = 24, )?; if let Some(model) = &self.model { writeln!(f, "Model:\n{model}")?; } writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?; // Metrics table header writeln!( f, "| {:width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|", "Split", "Metric", "", "", width_split = max_split_len, width_metric = max_metric_len, )?; // Table entries fn cmp_f64(a: &f64, b: &f64) -> Ordering { match (a.is_nan(), b.is_nan()) { (true, true) => Ordering::Equal, (true, false) => Ordering::Greater, (false, true) => Ordering::Less, _ => a.partial_cmp(b).unwrap(), } } fn fmt_val(val: f64) -> String { if val < 1e-2 { // Use scientific notation for small values which would otherwise be truncated format!("{val:<9.3e}") } else { format!("{val:<9.3}") } } let mut write_metrics_summary = |metrics: &[MetricSummary], split: String| -> std::fmt::Result { for metric in metrics.iter() { if metric.entries.is_empty() { continue; // skip metrics with no recorded values } // Compute the min & max for each metric let metric_min = metric .entries .iter() .min_by(|a, b| cmp_f64(&a.value, &b.value)) .unwrap(); let metric_max = metric .entries .iter() .max_by(|a, b| cmp_f64(&a.value, &b.value)) .unwrap(); writeln!( f, "| {:, } impl LearnerSummaryConfig { /// Create the learning summary. pub fn init(&self) -> Result { LearnerSummary::new(&self.directory, &self.metrics[..]) } } #[cfg(test)] mod tests { use super::*; #[test] #[should_panic = "Summary artifacts should exist"] fn test_artifact_dir_should_exist() { let dir = "/tmp/learner-summary-not-found"; let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist"); } #[test] #[should_panic = "Summary artifacts should exist"] fn test_train_valid_artifacts_should_exist() { let dir = "/tmp/test-learner-summary-empty"; std::fs::create_dir_all(dir).ok(); let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist"); } #[test] fn test_summary_should_be_empty() { let dir = Path::new("/tmp/test-learner-summary-empty-metrics"); std::fs::create_dir_all(dir).unwrap(); std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap(); std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap(); let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"]) .expect("Summary artifacts should exist"); assert_eq!(summary.epochs, 1); assert_eq!(summary.metrics.train.len(), 0); assert_eq!(summary.metrics.valid.len(), 0); std::fs::remove_dir_all(dir).unwrap(); } #[test] fn test_summary_should_be_collected() { let dir = Path::new("/tmp/test-learner-summary"); let train_dir = dir.join("train/epoch-1"); let valid_dir = dir.join("valid/epoch-1"); std::fs::create_dir_all(dir).unwrap(); std::fs::create_dir_all(&train_dir).unwrap(); std::fs::create_dir_all(&valid_dir).unwrap(); std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file"); std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file"); let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"]) .expect("Summary artifacts should exist"); assert_eq!(summary.epochs, 1); // Only Loss metric assert_eq!(summary.metrics.train.len(), 1); assert_eq!(summary.metrics.valid.len(), 1); // Aggregated train metric entries for 1 epoch let train_metric = &summary.metrics.train[0]; assert_eq!(train_metric.name, "Loss"); assert_eq!(train_metric.entries.len(), 1); let entry = &train_metric.entries[0]; assert_eq!(entry.step, 1); // epoch = 1 assert_eq!(entry.value, 1.5); // (1 + 2) / 2 // Aggregated valid metric entries for 1 epoch let valid_metric = &summary.metrics.valid[0]; assert_eq!(valid_metric.name, "Loss"); assert_eq!(valid_metric.entries.len(), 1); let entry = &valid_metric.entries[0]; assert_eq!(entry.step, 1); // epoch = 1 assert_eq!(entry.value, 1.0); std::fs::remove_dir_all(dir).unwrap(); } } ================================================ FILE: crates/burn-train/src/learner/supervised/mod.rs ================================================ mod paradigm; mod step; mod strategies; pub use paradigm::*; pub use step::*; pub use strategies::*; ================================================ FILE: crates/burn-train/src/learner/supervised/paradigm.rs ================================================ use crate::checkpoint::{ AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer, KeepLastNCheckpoints, MetricCheckpointingStrategy, }; use crate::components::{InferenceModelOutput, TrainingModelOutput}; use crate::learner::EarlyStoppingStrategy; use crate::learner::base::Interrupter; use crate::logger::{FileMetricLogger, MetricLogger}; use crate::metric::processor::{ AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining, }; use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split}; use crate::metric::{Adaptor, LossMetric, Metric, Numeric}; use crate::multi::MultiDeviceLearningStrategy; use crate::renderer::{MetricsRenderer, default_renderer}; use crate::single::SingleDeviceTrainingStrategy; use crate::{ ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller, InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent, LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy, }; use crate::{Learner, SupervisedLearningStrategy}; use burn_core::data::dataloader::DataLoader; use burn_core::module::{AutodiffModule, Module}; use burn_core::record::FileRecorder; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::Optimizer; use burn_optim::lr_scheduler::LrScheduler; use std::collections::BTreeSet; use std::path::{Path, PathBuf}; use std::sync::Arc; /// A reference to the training split [DataLoader](DataLoader). pub type TrainLoader = Arc, TrainingModelInput>>; /// A reference to the validation split [DataLoader](DataLoader). pub type ValidLoader = Arc, InferenceModelInput>>; /// The event processor type for supervised learning. pub type SupervisedTrainingEventProcessor = AsyncProcessorTraining< LearnerEvent>, LearnerEvent>, >; /// Structure to configure and launch supervised learning trainings. pub struct SupervisedTraining where LC: LearningComponentsTypes, { // Not that complex. Extracting into another type would only make it more confusing. #[allow(clippy::type_complexity)] checkpointers: Option<( AsyncCheckpointer, TrainingBackend>, AsyncCheckpointer, TrainingBackend>, AsyncCheckpointer, TrainingBackend>, )>, num_epochs: usize, checkpoint: Option, directory: PathBuf, grad_accumulation: Option, renderer: Option>, metrics: MetricsTraining, InferenceModelOutput>, event_store: LogEventStore, interrupter: Interrupter, tracing_logger: Option>, checkpointer_strategy: Box, early_stopping: Option, training_strategy: Option>, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order summary_metrics: BTreeSet, summary: bool, } impl SupervisedTraining> where B: AutodiffBackend, LR: LrScheduler + 'static, M: TrainStep + AutodiffModule + core::fmt::Display + 'static, M::InnerModule: InferenceStep, O: Optimizer + 'static, { /// Creates a new runner for a supervised training. /// /// # Arguments /// /// * `directory` - The directory to save the checkpoints. /// * `dataloader_train` - The dataloader for the training split. /// * `dataloader_valid` - The dataloader for the validation split. pub fn new( directory: impl AsRef, dataloader_train: Arc>, dataloader_valid: Arc< dyn DataLoader::Input>, >, ) -> Self { let directory = directory.as_ref().to_path_buf(); let experiment_log_file = directory.join("experiment.log"); Self { num_epochs: 1, checkpoint: None, checkpointers: None, directory, grad_accumulation: None, metrics: MetricsTraining::default(), event_store: LogEventStore::default(), renderer: None, interrupter: Interrupter::new(), tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new( experiment_log_file, ))), checkpointer_strategy: Box::new( ComposedCheckpointingStrategy::builder() .add(KeepLastNCheckpoints::new(2)) .add(MetricCheckpointingStrategy::new( &LossMetric::::new(), // default to valid loss Aggregate::Mean, Direction::Lowest, Split::Valid, )) .build(), ), early_stopping: None, training_strategy: None, summary_metrics: BTreeSet::new(), summary: false, dataloader_train, dataloader_valid, } } } impl SupervisedTraining { /// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one. /// /// # Arguments /// /// * `training_strategy` - The training strategy. pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy) -> Self { self.training_strategy = Some(training_strategy); self } /// Replace the default metric loggers with the provided ones. /// /// # Arguments /// /// * `logger` - The training logger. pub fn with_metric_logger(mut self, logger: ML) -> Self where ML: MetricLogger + 'static, { self.event_store.register_logger(logger); self } /// Update the checkpointing_strategy. pub fn with_checkpointing_strategy( mut self, strategy: CS, ) -> Self { self.checkpointer_strategy = Box::new(strategy); self } /// Replace the default CLI renderer with a custom one. /// /// # Arguments /// /// * `renderer` - The custom renderer. pub fn renderer(mut self, renderer: MR) -> Self where MR: MetricsRenderer + 'static, { self.renderer = Some(Box::new(renderer)); self } /// Register all metrics as numeric for the training and validation set. pub fn metrics>(self, metrics: Me) -> Self { metrics.register(self) } /// Register all metrics as text for the training and validation set. pub fn metrics_text>(self, metrics: Me) -> Self { metrics.register(self) } /// Register a training metric. pub fn metric_train(mut self, metric: Me) -> Self where as ItemLazy>::ItemSync: Adaptor, { self.metrics.register_train_metric(metric); self } /// Register a validation metric. pub fn metric_valid(mut self, metric: Me) -> Self where as ItemLazy>::ItemSync: Adaptor, { self.metrics.register_valid_metric(metric); self } /// Enable gradients accumulation. /// /// # Notes /// /// When you enable gradients accumulation, the gradients object used by the optimizer will be /// the sum of all gradients generated by each backward pass. It might be a good idea to /// reduce the learning to compensate. /// /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation` /// amount. pub fn grads_accumulation(mut self, accumulation: usize) -> Self { self.grad_accumulation = Some(accumulation); self } /// Register a [numeric](crate::metric::Numeric) training [metric](Metric). pub fn metric_train_numeric(mut self, metric: Me) -> Self where Me: Metric + Numeric + 'static, as ItemLazy>::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_train_metric_numeric(metric); self } /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric). pub fn metric_valid_numeric(mut self, metric: Me) -> Self where as ItemLazy>::ItemSync: Adaptor, { self.summary_metrics.insert(metric.name().to_string()); self.metrics.register_valid_metric_numeric(metric); self } /// The number of epochs the training should last. pub fn num_epochs(mut self, num_epochs: usize) -> Self { self.num_epochs = num_epochs; self } /// The epoch from which the training must resume. pub fn checkpoint(mut self, checkpoint: usize) -> Self { self.checkpoint = Some(checkpoint); self } /// Provides a handle that can be used to interrupt training. pub fn interrupter(&self) -> Interrupter { self.interrupter.clone() } /// Override the handle for stopping training with an externally provided handle pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self { self.interrupter = interrupter; self } /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the /// conditions are meet. pub fn early_stopping(mut self, strategy: Strategy) -> Self where Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static, { self.early_stopping = Some(Box::new(strategy)); self } /// By default, Rust logs are captured and written into /// `experiment.log`. If disabled, standard Rust log handling /// will apply. pub fn with_application_logger( mut self, logger: Option>, ) -> Self { self.tracing_logger = logger; self } /// Register a checkpointer that will save the [optimizer](Optimizer), the /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files. pub fn with_file_checkpointer(mut self, recorder: FR) -> Self where FR: FileRecorder<::Backend> + 'static, FR: FileRecorder< <::Backend as AutodiffBackend>::InnerBackend, > + 'static, { let checkpoint_dir = self.directory.join("checkpoint"); let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model"); let checkpointer_optimizer = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim"); let checkpointer_scheduler: FileCheckpointer = FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler"); self.checkpointers = Some(( AsyncCheckpointer::new(checkpointer_model), AsyncCheckpointer::new(checkpointer_optimizer), AsyncCheckpointer::new(checkpointer_scheduler), )); self } /// Enable the training summary report. /// /// The summary will be displayed after `.fit()`, when the renderer is dropped. pub fn summary(mut self) -> Self { self.summary = true; self } } impl SupervisedTraining { /// Launch this training with the given [Learner](Learner). pub fn launch(mut self, learner: Learner) -> LearningResult> { if self.tracing_logger.is_some() && let Err(e) = self.tracing_logger.as_ref().unwrap().install() { log::warn!("Failed to install the experiment logger: {e}"); } let renderer = self .renderer .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint)); if !self.event_store.has_loggers() { self.event_store .register_logger(FileMetricLogger::new(self.directory.clone())); } let event_store = Arc::new(EventStoreClient::new(self.event_store)); let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new( self.metrics, renderer, event_store.clone(), )); let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { LearningCheckpointer::new( model.with_interrupter(self.interrupter.clone()), optim.with_interrupter(self.interrupter.clone()), scheduler.with_interrupter(self.interrupter.clone()), self.checkpointer_strategy, ) }); let summary = if self.summary { Some(LearnerSummaryConfig { directory: self.directory, metrics: self.summary_metrics.into_iter().collect::>(), }) } else { None }; let components = TrainingComponents { checkpoint: self.checkpoint, checkpointer, interrupter: self.interrupter, early_stopping: self.early_stopping, event_processor, event_store, num_epochs: self.num_epochs, grad_accumulation: self.grad_accumulation, summary, }; // Default to single device based on model let training_strategy = self .training_strategy .unwrap_or(TrainingStrategy::SingleDevice( learner.model.devices()[0].clone(), )); match training_strategy { TrainingStrategy::SingleDevice(device) => { let single_device: SingleDeviceTrainingStrategy = SingleDeviceTrainingStrategy::new(device); single_device.train( learner, self.dataloader_train, self.dataloader_valid, components, ) } TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train( learner, self.dataloader_train, self.dataloader_valid, components, ), TrainingStrategy::MultiDevice(devices, multi_device_optim) => { let strategy: Box> = match devices.len() == 1 { true => Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone())), false => Box::new(MultiDeviceLearningStrategy::new( devices, multi_device_optim, )), }; strategy.train( learner, self.dataloader_train, self.dataloader_valid, components, ) } #[cfg(feature = "ddp")] TrainingStrategy::DistributedDataParallel { devices, config } => { use crate::ddp::DdpTrainingStrategy; let ddp = DdpTrainingStrategy::new(devices.clone(), config.clone()); ddp.train( learner, self.dataloader_train, self.dataloader_valid, components, ) } } } } /// Trait to fake variadic generics. pub trait MetricRegistration: Sized { /// Register the metrics. fn register(self, builder: SupervisedTraining) -> SupervisedTraining; } /// Trait to fake variadic generics. pub trait TextMetricRegistration: Sized { /// Register the metrics. fn register(self, builder: SupervisedTraining) -> SupervisedTraining; } macro_rules! gen_tuple { ($($M:ident),*) => { impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration for ($($M,)*) where $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: SupervisedTraining, ) -> SupervisedTraining { let ($($M,)*) = self; $(let builder = builder.metric_train($M.clone());)* $(let builder = builder.metric_valid($M);)* builder } } impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration for ($($M,)*) where $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $( as ItemLazy>::ItemSync: Adaptor<$M::Input>,)* $($M: Metric + Numeric + 'static,)* { #[allow(non_snake_case)] fn register( self, builder: SupervisedTraining, ) -> SupervisedTraining { let ($($M,)*) = self; $(let builder = builder.metric_train_numeric($M.clone());)* $(let builder = builder.metric_valid_numeric($M);)* builder } } }; } gen_tuple!(M1); gen_tuple!(M1, M2); gen_tuple!(M1, M2, M3); gen_tuple!(M1, M2, M3, M4); gen_tuple!(M1, M2, M3, M4, M5); gen_tuple!(M1, M2, M3, M4, M5, M6); ================================================ FILE: crates/burn-train/src/learner/supervised/step/mod.rs ================================================ /// The trainer module. pub mod train; ================================================ FILE: crates/burn-train/src/learner/supervised/step/train.rs ================================================ use crate::{LearningComponentsTypes, TrainingModel}; use crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput}; use burn_core::data::dataloader::DataLoaderIterator; use burn_core::data::dataloader::Progress; use burn_core::module::Module; use burn_core::prelude::DeviceOps; use burn_core::tensor::Device; use burn_core::tensor::backend::DeviceId; use std::sync::mpsc::{Receiver, Sender}; use std::thread::spawn; /// Multi devices train step. pub struct MultiDevicesTrainStep { workers: Vec>, receiver: Receiver>>, } struct Message { item: TI, model: M, } struct Worker { // Not that complex. Extracting into another type would only make it more confusing. #[allow(clippy::type_complexity)] sender_input: Sender, TrainingModelInput>>, device: Device>, } impl Worker { fn register(&self, item: TrainingModelInput, model: &TrainingModel) { let message = Message { item, model: model.clone(), }; self.sender_input.send(message).unwrap(); } // Not that complex. Extracting into another type would only make it more confusing. #[allow(clippy::type_complexity)] fn start( &self, sender_output: Sender>>, receiver_input: Receiver, TrainingModelInput>>, ) { let device = self.device.clone(); spawn(move || { loop { match receiver_input.recv() { Ok(item) => { let model = item.model.fork(&device); let output = model.step(item.item); let item = MultiTrainOutput { output, device: device.to_id(), }; sender_output.send(item).unwrap(); } Err(_err) => { log::info!("Closing thread on device {device:?}"); break; } } } }); } } /// Multiple output items. pub struct MultiTrainOutput { /// The training output. pub output: TrainOutput, /// The device on which the computing happened. pub device: DeviceId, } impl MultiDevicesTrainStep { /// Create a new multi devices train step. /// /// # Arguments /// /// * `devices` - Devices. /// /// # Returns /// /// MultiDevicesTrainStep instance. pub fn new(devices: &[Device>]) -> Self { let (sender_output, receiver_output) = std::sync::mpsc::channel(); let workers = devices .iter() .map(|device| { let (sender_input, receiver_input) = std::sync::mpsc::channel(); let worker = Worker { sender_input, device: device.clone(), }; worker.start(sender_output.clone(), receiver_input); worker }) .collect(); Self { workers, receiver: receiver_output, } } /// Collect outputs from workers for one step. /// /// # Arguments /// /// * `model` - Model. /// * `dataloaders` - The data loader for each worker. /// /// # Returns /// /// Outputs. pub fn step<'a>( &self, dataloaders: &mut [Box> + 'a>], model: &TrainingModel, ) -> (Vec>>, Progress) { let mut num_send = 0; let mut items_total = 0; let mut items_processed = 0; for (i, worker) in self.workers.iter().enumerate() { let dataloader = &mut dataloaders[i]; if let Some(item) = dataloader.next() { worker.register(item, model); num_send += 1; let progress = dataloader.progress(); items_total += progress.items_total; items_processed += progress.items_processed; } } let mut outputs = Vec::with_capacity(num_send); for _ in 0..num_send { let output = self.receiver.recv().unwrap(); outputs.push(output); } (outputs, Progress::new(items_processed, items_total)) } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/base.rs ================================================ use std::sync::Arc; #[cfg(feature = "ddp")] use burn_collective::CollectiveConfig; use burn_core::{module::AutodiffModule, prelude::Backend}; use crate::{ EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig, LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader, TrainingModel, ValidLoader, components::LearningComponentsTypes, metric::{ processor::{EventProcessorTraining, LearnerEvent}, store::EventStoreClient, }, }; type LearnerDevice = <::Backend as Backend>::Device; /// A reference to an implementation of SupervisedLearningStrategy. pub type CustomLearningStrategy = Arc>; #[derive(Clone, Copy, Debug)] /// Determine how the optimization is performed when training with multiple devices. pub enum MultiDeviceOptim { /// The optimization is done on an elected device. OptimMainDevice, /// The optimization is sharded across all devices. OptimSharded, } /// How should the learner run the learning for the model #[derive(Clone)] pub enum TrainingStrategy { /// Training on one device SingleDevice(LearnerDevice), /// Performs data-parallel distributed training where the optimization is /// done on an elected master device. MultiDevice(Vec>, MultiDeviceOptim), /// Training using a custom learning strategy Custom(CustomLearningStrategy), /// Training with input distributed across devices, each device has its own copy of the model. /// Collective ops are used to sync the gradients after each pass. #[cfg(feature = "ddp")] DistributedDataParallel { /// Devices on this node for the DDP devices: Vec>, /// The configuration for collective operations /// num_devices is ignored config: CollectiveConfig, }, } /// Constructor for a distributed data parallel (DDP) learning strategy #[cfg(feature = "ddp")] pub fn ddp( devices: Vec>, config: CollectiveConfig, ) -> TrainingStrategy { TrainingStrategy::DistributedDataParallel { devices, config } } impl Default for TrainingStrategy { fn default() -> Self { Self::SingleDevice(Default::default()) } } /// Struct to minimise parameters passed to [SupervisedLearningStrategy::train]. /// These components are used during training. pub struct TrainingComponents { /// The total number of epochs pub num_epochs: usize, /// The epoch number from which to continue the training. pub checkpoint: Option, /// A checkpointer used to load and save learner checkpoints. pub checkpointer: Option>, /// Enables gradients accumulation. pub grad_accumulation: Option, /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early. pub interrupter: Interrupter, /// Cloneable reference to an early stopping strategy. pub early_stopping: Option, /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation. pub event_processor: SupervisedTrainingEventProcessor, /// A reference to an [EventStoreClient](EventStoreClient). pub event_store: Arc, /// Config for creating a summary of the learning pub summary: Option, } /// Provides the `fit` function for any learning strategy pub trait SupervisedLearningStrategy { /// Train the learner's model with this strategy. fn train( &self, mut learner: Learner, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, mut training_components: TrainingComponents, ) -> LearningResult> { let starting_epoch = match training_components.checkpoint { Some(checkpoint) => { if let Some(checkpointer) = &mut training_components.checkpointer { learner = checkpointer.load_checkpoint(learner, &Default::default(), checkpoint); } checkpoint + 1 } None => 1, }; let summary_config = training_components.summary.clone(); // Event processor start training training_components .event_processor .process_train(LearnerEvent::Start); // Training loop let (model, mut event_processor) = self.fit( training_components, learner, dataloader_train, dataloader_valid, starting_epoch, ); let summary = summary_config.and_then(|summary| { summary .init() .map(|summary| summary.with_model(model.to_string())) .ok() }); // Signal training end. For the TUI renderer, this handles the exit & return to main screen. event_processor.process_train(LearnerEvent::End(summary)); let model = model.valid(); let renderer = event_processor.renderer(); LearningResult::> { model, renderer } } /// Training loop for this strategy fn fit( &self, training_components: TrainingComponents, learner: Learner, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, starting_epoch: usize, ) -> (TrainingModel, SupervisedTrainingEventProcessor); } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/ddp/README.md ================================================ ## DDP Distributed Data Parallel The DDP is a learning strategy that trains a replica of the model on each device. The DDP launches threads for each local device. Each thread on each node will run the model. After the forward and backward passes, the gradients are synced between all peers on all nodes with an `all-reduce` operation. While the DDP launches threads for each local device, it is the user's responsibility to launch the DDP on each node, and assure the collective configuration matches. ## Main device vs secondary devices The main device is responsible for validation, as well as event processing, which is used in the UI. The first device is chosen as the main device. ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs ================================================ use burn_collective::{PeerId, ReduceOperation}; use burn_core::data::dataloader::Progress; use burn_core::module::AutodiffModule; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::GradientsAccumulator; use burn_optim::GradientsParams; use std::marker::PhantomData; use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::{Arc, Mutex}; use crate::SupervisedTrainingEventProcessor; use crate::learner::base::Interrupter; use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::{ InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader, }; /// A validation epoch. #[derive(new)] pub struct DdpValidEpoch { dataloader: ValidLoader, } /// A training epoch. #[derive(new)] pub struct DdpTrainEpoch { dataloader: TrainLoader, grad_accumulation: Option, } impl DdpValidEpoch { /// Runs the validation epoch. /// /// # Arguments /// /// * `model` - The model to validate. /// * `processor` - The event processor to use. pub fn run( &self, model: &::TrainingModel, global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { let epoch = global_progress.items_processed; log::info!("Executing validation step for epoch {}", epoch); let model = model.valid(); let mut iterator = self.dataloader.iter(); let mut iteration = 0; while let Some(item) = iterator.next() { let progress = iterator.progress(); iteration += 1; let item = model.step(item); let item = TrainingItem::new( item, progress, global_progress.clone(), Some(iteration), None, ); processor.process_valid(LearnerEvent::ProcessedItem(item)); if interrupter.should_stop() { log::info!("Training interrupted."); break; } } processor.process_valid(LearnerEvent::EndEpoch(epoch)); } } impl DdpTrainEpoch { /// Runs the training epoch. /// /// # Arguments /// /// * `model` - The model to train. /// * `optim` - The optimizer to use. /// * `scheduler` - The learning rate scheduler to use. /// * `processor` - The event processor to use. /// /// # Returns /// /// The trained model and the optimizer. #[allow(clippy::too_many_arguments)] pub fn run( &self, learner: &mut Learner, global_progress: &Progress, processor: Arc>>, interrupter: &Interrupter, peer_id: PeerId, peer_count: usize, is_main: bool, ) { let epoch = global_progress.items_processed; log::info!("Executing training step for epoch {}", epoch,); let mut iterator = self.dataloader.iter(); let mut iteration = 0; let mut accumulator = GradientsAccumulator::new(); let mut accumulation_current = 0; let grads_syncer = GradsSyncer::< TrainingBackend, ::TrainingModel, >::new(false, peer_id); while let Some(item) = iterator.next() { for _ in 0..peer_count { iteration += 1; learner.lr_step(); } log::info!("Iteration {iteration}"); let mut progress = iterator.progress(); progress.items_processed *= peer_count; progress.items_total *= peer_count; let item = learner.train_step(item); match self.grad_accumulation { Some(accumulation) => { accumulator.accumulate(&learner.model(), item.grads); accumulation_current += 1; if accumulation <= accumulation_current { let grads = accumulator.grads(); // With double buffering, these are the previous iteration's gradients let grads = grads_syncer.sync(grads); if let Some(grads) = grads { learner.optimizer_step(grads); } accumulation_current = 0; } } None => { // With double buffering, these are the previous iteration's gradients let grads = grads_syncer.sync(item.grads); if let Some(grads) = grads { learner.optimizer_step(grads); } } } let item = TrainingItem::new( item.item, progress, global_progress.clone(), Some(iteration), Some(learner.lr_current()), ); { let mut processor = processor.lock().unwrap(); processor.process_train(LearnerEvent::ProcessedItem(item)); } if interrupter.should_stop() { log::info!("Training interrupted."); break; } } if is_main { let mut processor = processor.lock().unwrap(); processor.process_train(LearnerEvent::EndEpoch(epoch)); } } } /// Worker that is responsible for syncing gradients for the DDP worker. With double buffering, /// this allows for more optimization. struct GradsSyncer + 'static> { msg_send: SyncSender, // Optional because with double buffering, the first iteration yields no gradients. result_recv: Receiver>, _p: PhantomData<(B, M)>, } impl + 'static> GradsSyncer { fn new(double_buffering: bool, peer_id: PeerId) -> Self { let (msg_send, msg_recv) = std::sync::mpsc::sync_channel::(1); let (result_send, result_recv) = std::sync::mpsc::sync_channel::>(1); std::thread::spawn(move || { Self::run_worker(double_buffering, peer_id, result_send, msg_recv) }); Self { msg_send, result_recv, _p: PhantomData, } } fn sync(&self, grads: GradientsParams) -> Option { self.msg_send.send(grads).unwrap(); self.result_recv.recv().unwrap() } fn run_worker( double_buffering: bool, peer_id: PeerId, send: SyncSender>, recv: Receiver, ) { let mut grads_buffer = None; while let Ok(new_grads) = recv.recv() { // Sync grads with collective let new_grads = new_grads .all_reduce::(peer_id, ReduceOperation::Mean) .expect("DDP worker could not sync gradients!"); if double_buffering { let old_grads = grads_buffer.take(); grads_buffer = Some(new_grads); send.send(old_grads).unwrap(); } else { send.send(Some(new_grads)).unwrap(); } } // GradsSyncer dropped, channel closed, this thread can end } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/ddp/mod.rs ================================================ mod epoch; mod strategy; mod worker; pub use strategy::*; ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/ddp/strategy.rs ================================================ use core::panic; use std::sync::{Arc, Mutex}; use burn_collective::CollectiveConfig; use burn_core::tensor::Device; use burn_core::tensor::backend::DeviceOps; use crate::ddp::worker::DdpWorker; use crate::metric::store::EventStoreClient; use crate::{ EarlyStoppingStrategyRef, Interrupter, Learner, LearningComponentsTypes, SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader, }; use burn_core::data::dataloader::split::split_dataloader; #[derive(Clone)] pub(crate) struct WorkerComponents { /// The total number of epochs pub num_epochs: usize, /// Enables gradients accumulation. pub grad_accumulation: Option, /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early. pub interrupter: Interrupter, /// Cloneable reference to an early stopping strategy. pub early_stopping: Option, /// A reference to an [EventStoreClient](EventStoreClient). pub event_store: Arc, } pub struct DdpTrainingStrategy { devices: Vec>>, config: CollectiveConfig, } impl DdpTrainingStrategy { pub fn new(devices: Vec>>, config: CollectiveConfig) -> Self { let config = config.with_num_devices(devices.len()); Self { devices, config } } } impl SupervisedLearningStrategy for DdpTrainingStrategy { fn fit( &self, training_components: TrainingComponents, learner: Learner, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, starting_epoch: usize, ) -> (TrainingModel, SupervisedTrainingEventProcessor) { // The reference model is always on the first device provided. let main_device = self.devices.first().unwrap(); // One worker per device, so we use a fixed device strategy // for each (worker) data loader. This matches the expected device on the worker, so we // don't have to move the data between devices. let mut dataloaders_train = split_dataloader(dataloader_train, &self.devices); let dataloader_valid = dataloader_valid.to_device(main_device.inner()); let main_device = self.devices[0].clone(); let peer_count = self.devices.len(); let event_processor = Arc::new(Mutex::new(training_components.event_processor)); let interrupter = training_components.interrupter; let worker_components = WorkerComponents { num_epochs: training_components.num_epochs, grad_accumulation: training_components.grad_accumulation, interrupter: interrupter.clone(), early_stopping: training_components.early_stopping, event_store: training_components.event_store, }; // Start worker for main device // First training dataloader corresponds to main device let main_handle = DdpWorker::::start( 0.into(), main_device, learner.clone(), event_processor.clone(), worker_components.clone(), training_components.checkpointer, dataloaders_train.remove(0), Some(dataloader_valid), self.config.clone(), starting_epoch, peer_count, true, ); // Spawn other workers for the other devices, starting with peer id 1 let mut peer_id = 1; let mut secondary_workers = vec![]; for device in &self.devices[1..] { let handle = DdpWorker::::start( peer_id.into(), device.clone(), learner.clone(), event_processor.clone(), worker_components.clone(), None, dataloaders_train.remove(0), None, self.config.clone(), starting_epoch, peer_count, false, ); peer_id += 1; secondary_workers.push(handle); } // Wait for all devices to finish for worker in secondary_workers { worker .join() .expect("Distributed data parallel worker failed"); } // Main worker had the event processor let model = main_handle .join() .expect("Distributed data parallel main worker failed"); if interrupter.should_stop() { let reason = interrupter .get_message() .unwrap_or(String::from("Reason unknown")); log::info!("Training interrupted: {reason}"); } let Ok(event_processor) = Arc::try_unwrap(event_processor) else { panic!("Event processor still held!"); }; let Ok(event_processor) = event_processor.into_inner() else { panic!("Event processor lock poisoned"); }; (model, event_processor) } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs ================================================ use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch}; use crate::ddp::strategy::WorkerComponents; use crate::single::TrainingLoop; use crate::{ Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, ValidLoader, }; use burn_collective::{self, CollectiveConfig, PeerId}; use burn_core::tensor::Device; use burn_core::tensor::backend::AutodiffBackend; use std::sync::{Arc, Mutex}; use std::thread::JoinHandle; /// A worker runs the model, syncing gradients using collective operations. /// Event processing and validation is optional too. pub(crate) struct DdpWorker where LC: LearningComponentsTypes + Send + 'static, { peer_id: PeerId, device: Device>, learner: Learner, event_processor: Arc>>, components: WorkerComponents, checkpointer: Option>, dataloader_train: TrainLoader, dataloader_valid: Option>, collective_config: CollectiveConfig, starting_epoch: usize, peer_count: usize, is_main: bool, } impl DdpWorker where LC: LearningComponentsTypes + Send + 'static, { /// Starts a worker that runs the model in a data distributed parallel #[allow(clippy::too_many_arguments)] pub fn start( peer_id: PeerId, device: Device>, learner: Learner, event_processor: Arc>>, components: WorkerComponents, checkpointer: Option>, dataloader_train: TrainLoader, dataloader_valid: Option>, collective_config: CollectiveConfig, starting_epoch: usize, peer_count: usize, is_main: bool, ) -> JoinHandle<::TrainingModel> { let worker = Self { peer_id, device, learner, event_processor, components, checkpointer, dataloader_train, dataloader_valid, collective_config, starting_epoch, peer_count, is_main, }; std::thread::spawn(|| worker.fit()) } /// Fits the model, pub fn fit(mut self) -> ::TrainingModel { burn_collective::register::< as AutodiffBackend>::InnerBackend>( self.peer_id, self.device.clone(), self.collective_config.clone(), ) .expect("Couldn't register for collective operations!"); let num_epochs = self.components.num_epochs; let interrupter = self.components.interrupter; // Changed the train epoch to keep the dataloaders let epoch_train = DdpTrainEpoch::::new( self.dataloader_train.clone(), self.components.grad_accumulation, ); let epoch_valid = self .dataloader_valid .map(|dataloader| DdpValidEpoch::::new(dataloader)); self.learner.fork(&self.device); for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) { let epoch = training_progress.items_processed; epoch_train.run( &mut self.learner, &training_progress, self.event_processor.clone(), &interrupter, self.peer_id, self.peer_count, self.is_main, ); if interrupter.should_stop() { break; } // Validation if let Some(runner) = &epoch_valid { let mut event_processor = self.event_processor.lock().unwrap(); runner.run( &self.learner.model(), &training_progress, &mut event_processor, &interrupter, ); } if let Some(checkpointer) = &mut self.checkpointer { checkpointer.checkpoint(&self.learner, epoch, &self.components.event_store); } if let Some(early_stopping) = &mut self.components.early_stopping && early_stopping.should_stop(epoch, &self.components.event_store) { break; } } self.learner.model() } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/mod.rs ================================================ mod base; #[cfg(feature = "ddp")] pub(crate) mod ddp; pub(crate) mod multi; pub(crate) mod single; pub use base::*; ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs ================================================ use crate::learner::base::Interrupter; use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::train::MultiDevicesTrainStep; use crate::{ Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, }; use burn_core::data::dataloader::Progress; use burn_core::prelude::DeviceOps; use burn_core::tensor::Device; use burn_core::tensor::backend::DeviceId; use burn_optim::GradientsAccumulator; use burn_optim::MultiGradientsParams; use std::collections::HashMap; /// A training epoch. #[derive(new)] pub struct MultiDeviceTrainEpoch { dataloaders: Vec>, grad_accumulation: Option, } impl MultiDeviceTrainEpoch { /// Runs the training epoch on multiple devices. /// /// # Arguments /// /// * `model` - The model to train. /// * `optim` - The optimizer to use. /// * `lr_scheduler` - The learning rate scheduler to use. /// * `processor` - The event processor to use. /// * `devices` - The devices to use. /// /// # Returns /// /// The trained model and the optimizer. #[allow(clippy::too_many_arguments)] pub fn run( &self, learner: &mut Learner, global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, strategy: MultiDeviceOptim, ) { match strategy { MultiDeviceOptim::OptimMainDevice => self.run_optim_main( learner, global_progress, event_processor, interrupter, devices, ), MultiDeviceOptim::OptimSharded => self.run_optim_distr( learner, global_progress, event_processor, interrupter, devices, ), } } fn run_optim_main( &self, learner: &mut Learner, global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, ) { let epoch = global_progress.items_processed; log::info!( "Executing training step for epoch {} on devices {:?}", epoch, devices ); let mut iterators = self .dataloaders .iter() .map(|d| d.iter()) .collect::>(); let mut iteration = 0; let mut accumulator = GradientsAccumulator::new(); let mut accumulation_current = 0; let accumulation = self.grad_accumulation.unwrap_or(1); let step = MultiDevicesTrainStep::::new(&devices); // The main device is always the first in the list. let device_main = devices.first().expect("A minimum of one device.").clone(); loop { let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model()); if items.is_empty() { break; } learner.lr_step(); let mut progress_items = Vec::with_capacity(items.len()); for item in items.into_iter() { let grads = item.output.grads.to_device(&device_main, &learner.model()); accumulator.accumulate(&learner.model(), grads); progress_items.push(item.output.item); } accumulation_current += 1; if accumulation <= accumulation_current { let grads = accumulator.grads(); learner.optimizer_step(grads); accumulation_current = 0; } for item in progress_items { iteration += 1; let item = TrainingItem::new( item, progress.clone(), global_progress.clone(), Some(iteration), Some(learner.lr_current()), ); event_processor.process_train(LearnerEvent::ProcessedItem(item)); } if interrupter.should_stop() { break; } } event_processor.process_train(LearnerEvent::EndEpoch(epoch)); } fn run_optim_distr( &self, learner: &mut Learner, global_progress: &Progress, event_processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, devices: Vec>>, ) { let epoch = global_progress.items_processed; log::info!( "Executing training step for epoch {} on devices {:?}", epoch, devices ); let mut iterators = self .dataloaders .iter() .map(|d| d.iter()) .collect::>(); let mut iteration = 0; let mut accumulators = HashMap::< DeviceId, GradientsAccumulator<::TrainingModel>, >::new(); for device in devices.iter() { accumulators.insert(device.to_id(), GradientsAccumulator::new()); } let mut accumulation_current = 0; let accumulation = self.grad_accumulation.unwrap_or(1); let step = MultiDevicesTrainStep::::new(&devices); loop { let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model()); if items.is_empty() { break; } learner.lr_step(); let mut progress_items = Vec::with_capacity(items.len()); for item in items.into_iter() { let accumulator = accumulators.get_mut(&item.device).unwrap(); accumulator.accumulate(&learner.model(), item.output.grads); progress_items.push(item.output.item); } accumulation_current += 1; if accumulation <= accumulation_current { let mut grads = MultiGradientsParams::default(); for (device_id, accumulator) in accumulators.iter_mut() { let grad = accumulator.grads(); grads.grads.push((grad, *device_id)); } learner.optimizer_step_multi(grads); accumulation_current = 0; } for item in progress_items { iteration += 1; let item = TrainingItem::new( item, progress.clone(), global_progress.clone(), Some(iteration), Some(learner.lr_current()), ); event_processor.process_train(LearnerEvent::ProcessedItem(item)); } if interrupter.should_stop() { break; } } event_processor.process_train(LearnerEvent::EndEpoch(epoch)); } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/multi/mod.rs ================================================ pub(crate) mod epoch; mod strategy; pub use strategy::*; ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs ================================================ use crate::{ Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader, multi::epoch::MultiDeviceTrainEpoch, single::{TrainingLoop, epoch::SingleDeviceValidEpoch}, }; use burn_core::{ data::dataloader::split::split_dataloader, tensor::{Device, backend::DeviceOps}, }; pub struct MultiDeviceLearningStrategy { devices: Vec>>, optim: MultiDeviceOptim, } impl MultiDeviceLearningStrategy { pub fn new(devices: Vec>>, optim: MultiDeviceOptim) -> Self { Self { devices, optim } } } impl SupervisedLearningStrategy for MultiDeviceLearningStrategy { fn fit( &self, training_components: TrainingComponents, mut learner: Learner, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, starting_epoch: usize, ) -> (TrainingModel, SupervisedTrainingEventProcessor) { let main_device = self.devices.first().unwrap(); // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy // for each (worker) data loader. This matches the expected device on the worker, so we // don't have to move the data between devices. let dataloader_train = split_dataloader(dataloader_train, &self.devices); let dataloader_valid = dataloader_valid.to_device(main_device.inner()); learner.fork(main_device); let mut event_processor = training_components.event_processor; let mut checkpointer = training_components.checkpointer; let mut early_stopping = training_components.early_stopping; let epoch_train = MultiDeviceTrainEpoch::::new( dataloader_train.clone(), training_components.grad_accumulation, ); let epoch_valid: SingleDeviceValidEpoch = SingleDeviceValidEpoch::new(dataloader_valid.clone()); for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) { let epoch = training_progress.items_processed; epoch_train.run( &mut learner, &training_progress, &mut event_processor, &training_components.interrupter, self.devices.to_vec(), self.optim, ); if training_components.interrupter.should_stop() { let reason = training_components .interrupter .get_message() .unwrap_or(String::from("Reason unknown")); log::info!("Training interrupted: {reason}"); break; } // After OptimSharded training, model parameters are scattered across // devices. Fork back to main_device before single-device validation. if matches!(self.optim, MultiDeviceOptim::OptimSharded) { learner.fork(main_device); } epoch_valid.run( &learner, &training_progress, &mut event_processor, &training_components.interrupter, ); if let Some(checkpointer) = &mut checkpointer { checkpointer.checkpoint(&learner, epoch, &training_components.event_store); } if let Some(early_stopping) = &mut early_stopping && early_stopping.should_stop(epoch, &training_components.event_store) { break; } } (learner.model(), event_processor) } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/single/epoch.rs ================================================ use crate::learner::base::Interrupter; use crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem}; use crate::{ InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader, ValidLoader, }; use burn_core::data::dataloader::Progress; use burn_core::module::AutodiffModule; use burn_optim::GradientsAccumulator; /// A validation epoch. #[derive(new)] pub struct SingleDeviceValidEpoch { dataloader: ValidLoader, } /// A training epoch. #[derive(new)] pub struct SingleDeviceTrainEpoch { dataloader: TrainLoader, grad_accumulation: Option, } impl SingleDeviceValidEpoch { /// Runs the validation epoch. /// /// # Arguments /// /// * `model` - The model to validate. /// * `processor` - The event processor to use. pub fn run( &self, learner: &Learner, global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { let epoch = global_progress.items_processed; log::info!("Executing validation step for epoch {}", epoch); let model = learner.model().valid(); let mut iterator = self.dataloader.iter(); let mut iteration = 0; while let Some(item) = iterator.next() { let progress = iterator.progress(); iteration += 1; let item = model.step(item); let item = TrainingItem::new( item, progress, global_progress.clone(), Some(iteration), None, ); processor.process_valid(LearnerEvent::ProcessedItem(item)); if interrupter.should_stop() { break; } } processor.process_valid(LearnerEvent::EndEpoch(epoch)); } } impl SingleDeviceTrainEpoch { /// Runs the training epoch. /// /// # Arguments /// /// * `model` - The model to train. /// * `optim` - The optimizer to use. /// * `scheduler` - The learning rate scheduler to use. /// * `processor` - The event processor to use. /// /// # Returns /// /// The trained model and the optimizer. pub fn run( &self, learner: &mut Learner, global_progress: &Progress, processor: &mut SupervisedTrainingEventProcessor, interrupter: &Interrupter, ) { let epoch = global_progress.items_processed; log::info!("Executing training step for epoch {}", epoch,); // Single device / dataloader let mut iterator = self.dataloader.iter(); let mut iteration = 0; let mut accumulator = GradientsAccumulator::new(); let mut accumulation_current = 0; while let Some(item) = iterator.next() { iteration += 1; learner.lr_step(); log::info!("Iteration {iteration}"); let progress = iterator.progress(); let item = learner.train_step(item); match self.grad_accumulation { Some(accumulation) => { accumulator.accumulate(&learner.model(), item.grads); accumulation_current += 1; if accumulation <= accumulation_current { let grads = accumulator.grads(); learner.optimizer_step(grads); accumulation_current = 0; } } None => learner.optimizer_step(item.grads), } let item = TrainingItem::new( item.item, progress, global_progress.clone(), Some(iteration), Some(learner.lr_current()), ); processor.process_train(LearnerEvent::ProcessedItem(item)); if interrupter.should_stop() { break; } } processor.process_train(LearnerEvent::EndEpoch(epoch)); } } ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/single/mod.rs ================================================ pub(crate) mod epoch; mod strategy; pub use strategy::*; ================================================ FILE: crates/burn-train/src/learner/supervised/strategies/single/strategy.rs ================================================ use crate::{ Learner, LearningComponentsTypes, SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader, single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch}, }; use burn_core::{ data::dataloader::Progress, tensor::{Device, backend::DeviceOps}, }; /// Simplest learning strategy possible, with only a single devices doing both the training and /// validation. pub struct SingleDeviceTrainingStrategy { device: Device>, } impl SingleDeviceTrainingStrategy { pub fn new(device: Device>) -> Self { Self { device } } } #[derive(new)] pub(crate) struct TrainingLoop { next_iteration: usize, total_iteration: usize, } /// An iterator that returns the progress of the training. impl Iterator for TrainingLoop { type Item = Progress; fn next(&mut self) -> Option { if self.next_iteration > self.total_iteration { return None; } let progress = Progress { items_processed: self.next_iteration, items_total: self.total_iteration, }; self.next_iteration += 1; Some(progress) } } impl SupervisedLearningStrategy for SingleDeviceTrainingStrategy { fn fit( &self, training_components: TrainingComponents, mut learner: Learner, dataloader_train: TrainLoader, dataloader_valid: ValidLoader, starting_epoch: usize, ) -> (TrainingModel, SupervisedTrainingEventProcessor) { let dataloader_train = dataloader_train.to_device(&self.device); let dataloader_valid = dataloader_valid.to_device(self.device.inner()); learner.fork(&self.device); let mut event_processor = training_components.event_processor; let mut checkpointer = training_components.checkpointer; let mut early_stopping = training_components.early_stopping; let epoch_train: SingleDeviceTrainEpoch = SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation); let epoch_valid: SingleDeviceValidEpoch = SingleDeviceValidEpoch::new(dataloader_valid.clone()); for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) { let epoch = training_progress.items_processed; epoch_train.run( &mut learner, &training_progress, &mut event_processor, &training_components.interrupter, ); if training_components.interrupter.should_stop() { let reason = training_components .interrupter .get_message() .unwrap_or(String::from("Reason unknown")); log::info!("Training interrupted: {reason}"); break; } epoch_valid.run( &learner, &training_progress, &mut event_processor, &training_components.interrupter, ); if let Some(checkpointer) = &mut checkpointer { checkpointer.checkpoint(&learner, epoch, &training_components.event_store); } if let Some(early_stopping) = &mut early_stopping && early_stopping.should_stop(epoch, &training_components.event_store) { break; } } (learner.model(), event_processor) } } ================================================ FILE: crates/burn-train/src/learner/train_val.rs ================================================ use crate::{ItemLazy, renderer::MetricsRenderer}; use burn_core::module::AutodiffModule; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer}; /// A training output. pub struct TrainOutput { /// The gradients. pub grads: GradientsParams, /// The item. pub item: TO, } impl TrainOutput { /// Creates a new training output. /// /// # Arguments /// /// * `module` - The module. /// * `grads` - The gradients. /// * `item` - The item. /// /// # Returns /// /// A new training output. pub fn new>( module: &M, grads: B::Gradients, item: TO, ) -> Self { let grads = GradientsParams::from_grads(grads, module); Self { grads, item } } } /// Trait to be implemented for models to be able to be trained. /// /// The [step](TrainStep::step) method needs to be manually implemented for all structs. /// /// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the /// optimizer is used to update the model. This can be useful if you want to call custom mutable /// functions on your model (e.g., clipping the weights) before or after the optimizer is used. /// /// # Notes /// /// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must /// also implement the [AutodiffModule] trait, which is done automatically with the /// [Module](burn_core::module::Module) derive. pub trait TrainStep { /// Type of input for a step of the training stage. type Input: Send + 'static; /// Type of output for a step of the training stage. type Output: ItemLazy + 'static; /// Runs a step for training, which executes the forward and backward passes. /// /// # Arguments /// /// * `item` - The input for the model. /// /// # Returns /// /// The output containing the model output and the gradients. fn step(&self, item: Self::Input) -> TrainOutput; /// Optimize the current module with the provided gradients and learning rate. /// /// # Arguments /// /// * `optim`: Optimizer used for learning. /// * `lr`: The learning rate used for this step. /// * `grads`: The gradients of each parameter in the current model. /// /// # Returns /// /// The updated model. fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self where B: AutodiffBackend, O: Optimizer, Self: AutodiffModule, { optim.step(lr, self, grads) } /// Optimize the current module with the provided gradients and learning rate. /// /// # Arguments /// /// * `optim`: Optimizer used for learning. /// * `lr`: The learning rate used for this step. /// * `grads`: Multiple gradients associated to each parameter in the current model. /// /// # Returns /// /// The updated model. fn optimize_multi(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self where B: AutodiffBackend, O: Optimizer, Self: AutodiffModule, { optim.step_multi(lr, self, grads) } } /// Trait to be implemented for validating models. pub trait InferenceStep { /// Type of input for an inference step. type Input: Send + 'static; /// Type of output for an inference step. type Output: ItemLazy + 'static; /// Runs a validation step. /// /// # Arguments /// /// * `item` - The item to validate on. /// /// # Returns /// /// The validation output. fn step(&self, item: Self::Input) -> Self::Output; } /// The result of a training, containing the model along with the [renderer](MetricsRenderer). pub struct LearningResult { /// The model with the learned weights. pub model: M, /// The renderer that can be used for follow up training and evaluation. pub renderer: Box, } ================================================ FILE: crates/burn-train/src/lib.rs ================================================ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for training neural networks using the burn crate. #[macro_use] extern crate derive_new; /// The checkpoint module. pub mod checkpoint; pub(crate) mod components; /// Renderer modules to display metrics and training information. pub mod renderer; /// The logger module. pub mod logger; /// The metric module. pub mod metric; pub use metric::processor::*; mod learner; pub use learner::*; mod evaluator; pub use evaluator::*; pub use components::*; #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; #[cfg(test)] pub(crate) mod tests { use crate::TestBackend; use burn_core::{prelude::Tensor, tensor::Bool}; use std::default::Default; pub type TestAutodiffBackend = burn_autodiff::Autodiff; /// Probability of tp before adding errors pub const THRESHOLD: f64 = 0.5; #[derive(Debug, Default)] pub enum ClassificationType { #[default] Binary, Multiclass, Multilabel, } /// Sample x Class shaped matrix for use in /// classification metrics testing pub fn dummy_classification_input( classification_type: &ClassificationType, ) -> (Tensor, Tensor) { match classification_type { ClassificationType::Binary => { ( Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()), // targets Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()), // predictions @ threshold=0.5 // [[0], [0], [1], [0], [1]] ) } ClassificationType::Multiclass => { ( Tensor::from_data( [ [0.2, 0.8, 0.0], [0.3, 0.6, 0.1], [0.7, 0.25, 0.05], [0.1, 0.15, 0.8], [0.9, 0.03, 0.07], ], &Default::default(), ), Tensor::from_data( // targets [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]], // predictions @ top_k=1 // [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] // predictions @ top_k=2 // [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] &Default::default(), ), ) } ClassificationType::Multilabel => { ( Tensor::from_data( [ [0.1, 0.7, 0.6], [0.3, 0.9, 0.05], [0.8, 0.9, 0.4], [0.7, 0.5, 0.9], [1.0, 0.3, 0.2], ], &Default::default(), ), // targets Tensor::from_data( [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]], // predictions @ threshold=0.5 // [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] &Default::default(), ), ) } } } } ================================================ FILE: crates/burn-train/src/logger/async_logger.rs ================================================ use super::Logger; use std::sync::mpsc; enum Message { Log(T), End, Sync(mpsc::Sender<()>), } /// Async logger. pub struct AsyncLogger { sender: mpsc::Sender>, handler: Option>, } #[derive(new)] struct LoggerThread> { logger: L, receiver: mpsc::Receiver>, } impl LoggerThread where L: Logger, { fn run(mut self) { for item in self.receiver.iter() { match item { Message::Log(item) => { self.logger.log(item); } Message::End => { return; } Message::Sync(callback) => { callback .send(()) .expect("Can return result with the callback channel."); } } } } } impl AsyncLogger { /// Create a new async logger. pub fn new(logger: L) -> Self where L: Logger + 'static, { let (sender, receiver) = mpsc::channel(); let thread = LoggerThread::new(logger, receiver); let handler = Some(std::thread::spawn(move || thread.run())); Self { sender, handler } } /// Sync the async logger. pub(crate) fn sync(&self) { let (sender, receiver) = mpsc::channel(); self.sender .send(Message::Sync(sender)) .expect("Can send message to logger thread."); receiver .recv() .expect("Should sync, otherwise the thread is dead."); } } impl Logger for AsyncLogger { fn log(&mut self, item: T) { self.sender .send(Message::Log(item)) .expect("Can log using the logger thread."); } } impl Drop for AsyncLogger { fn drop(&mut self) { self.sender .send(Message::End) .expect("Can send the end message to the logger thread."); let handler = self.handler.take(); if let Some(handler) = handler { handler.join().expect("The logger thread should stop."); } } } ================================================ FILE: crates/burn-train/src/logger/base.rs ================================================ /// The logger trait. pub trait Logger: Send { /// Logs an item. /// /// # Arguments /// /// * `item` - The item. fn log(&mut self, item: T); } ================================================ FILE: crates/burn-train/src/logger/file.rs ================================================ use super::Logger; use std::{fs::File, io::Write, path::Path}; /// File logger. pub struct FileLogger { file: File, } impl FileLogger { /// Create a new file logger. /// /// # Arguments /// /// * `path` - The path. /// /// # Returns /// /// The file logger. pub fn new(path: impl AsRef) -> Self { let path = path.as_ref(); let mut options = std::fs::File::options(); let file = options .write(true) .truncate(true) .create(true) .open(path) .unwrap_or_else(|err| { panic!( "Should be able to create the new file '{}': {}", path.display(), err ) }); Self { file } } } impl Logger for FileLogger where T: std::fmt::Display, { fn log(&mut self, item: T) { writeln!(&mut self.file, "{item}").expect("Can log an item."); } } ================================================ FILE: crates/burn-train/src/logger/in_memory.rs ================================================ use super::Logger; /// In memory logger. #[derive(Default)] pub struct InMemoryLogger { pub(crate) values: Vec, } impl Logger for InMemoryLogger where T: std::fmt::Display, { fn log(&mut self, item: T) { self.values.push(item.to_string()); } } ================================================ FILE: crates/burn-train/src/logger/metric.rs ================================================ use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger}; use crate::metric::{ MetricDefinition, MetricEntry, MetricId, NumericEntry, store::{EpochSummary, MetricsUpdate, Split}, }; use std::{ collections::HashMap, fs, path::{Path, PathBuf}, }; const EPOCH_PREFIX: &str = "epoch-"; /// Metric logger. pub trait MetricLogger: Send { /// Logs an item. /// /// # Arguments /// /// * `update` - Update information for all registered metrics. /// * `epoch` - Current epoch. /// * `split` - Current dataset split. fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split); /// Read the logs for an epoch. fn read_numeric( &mut self, name: &str, epoch: usize, split: &Split, ) -> Result, String>; /// Logs the metric definition information (name, description, unit, etc.) fn log_metric_definition(&mut self, definition: MetricDefinition); /// Logs summary at the end of the epoch. fn log_epoch_summary(&mut self, summary: EpochSummary); } /// The file metric logger. pub struct FileMetricLogger { loggers: HashMap>, directory: PathBuf, metric_definitions: HashMap, is_eval: bool, last_epoch: Option, } impl FileMetricLogger { /// Create a new file metric logger. /// /// # Arguments /// /// * `directory` - The directory. /// /// # Returns /// /// The file metric logger. pub fn new(directory: impl AsRef) -> Self { Self { loggers: HashMap::new(), directory: directory.as_ref().to_path_buf(), metric_definitions: HashMap::default(), is_eval: false, last_epoch: None, } } /// Create a new file metric logger. /// /// # Arguments /// /// * `directory` - The directory. /// /// # Returns /// /// The file metric logger. pub fn new_eval(directory: impl AsRef) -> Self { Self { loggers: HashMap::new(), directory: directory.as_ref().to_path_buf(), metric_definitions: HashMap::default(), is_eval: true, last_epoch: None, } } pub(crate) fn split_exists(&self, split: &Split) -> bool { self.split_dir(split).is_some() } pub(crate) fn split_dir(&self, split: &Split) -> Option { let split_path = match split { Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()), other => self.directory.join(other.to_string()), }; (split_path.exists() && split_path.is_dir()).then_some(split_path) } pub(crate) fn is_epoch_dir>(dirname: P) -> bool { dirname.as_ref().starts_with(EPOCH_PREFIX) } /// Number of epochs recorded. pub(crate) fn epochs(&self) -> usize { if self.is_eval { log::warn!("Number of epochs not available when testing."); return 0; } let mut max_epoch = 0; // with split for path in fs::read_dir(&self.directory).unwrap() { let path = path.unwrap(); if fs::metadata(path.path()).unwrap().is_dir() { for split_path in fs::read_dir(path.path()).unwrap() { let split_path = split_path.unwrap(); if fs::metadata(split_path.path()).unwrap().is_dir() { let dir_name = split_path.file_name().into_string().unwrap(); if !dir_name.starts_with(EPOCH_PREFIX) { continue; } let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::().ok(); if let Some(epoch) = epoch && epoch > max_epoch { max_epoch = epoch; } } } } } max_epoch } fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf { let name = format!("{EPOCH_PREFIX}{epoch}"); match split { Split::Train | Split::Valid | Split::Test(None) => { self.directory.join(split.to_string()).join(name) } Split::Test(Some(tag)) => { let tag = format_tag(tag); self.directory.join(split.to_string()).join(tag).join(name) } } } fn eval_directory(&self, split: &Split) -> PathBuf { match split { Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(), Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)), } } fn file_path(&self, name: &str, epoch: Option, split: &Split) -> PathBuf { let directory = match epoch { Some(epoch) => self.train_directory(epoch, split), None => self.eval_directory(split), }; let name = name.replace(' ', "_"); let name = format!("{name}.log"); directory.join(name) } fn create_directory(&self, epoch: Option, split: &Split) { let directory = match epoch { Some(epoch) => self.train_directory(epoch, split), None => self.eval_directory(split), }; std::fs::create_dir_all(directory).ok(); } } impl FileMetricLogger { fn log_item(&mut self, item: &MetricEntry, epoch: Option, split: &Split) { let name = &self.metric_definitions.get(&item.metric_id).unwrap().name; let key = logger_key(name, split); let value = &item.serialized_entry.serialized; let logger = match self.loggers.get_mut(&key) { Some(val) => val, None => { self.create_directory(epoch, split); let file_path = self.file_path(name, epoch, split); let logger = FileLogger::new(file_path); let logger = AsyncLogger::new(logger); self.loggers.insert(key.clone(), logger); self.loggers .get_mut(&key) .expect("Can get the previously saved logger.") } }; logger.log(value.clone()); } } fn format_tag(tag: &str) -> String { tag.trim().replace(' ', "-").to_lowercase() } impl MetricLogger for FileMetricLogger { fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) { if !self.is_eval && self.last_epoch != Some(epoch) { self.loggers.clear(); self.last_epoch = Some(epoch); } let entries: Vec<_> = update .entries .iter() .chain( update .entries_numeric .iter() .map(|numeric_update| &numeric_update.entry), ) .cloned() .collect(); for item in entries.iter() { self.log_item(item, Some(epoch), split); } } fn read_numeric( &mut self, name: &str, epoch: usize, split: &Split, ) -> Result, String> { if let Some(value) = self.loggers.get(name) { value.sync() } let file_path = self.file_path(name, Some(epoch), split); let mut errors = false; let data = std::fs::read_to_string(file_path) .unwrap_or_default() .split('\n') .filter_map(|value| { if value.is_empty() { None } else { match NumericEntry::deserialize(value) { Ok(value) => Some(value), Err(err) => { log::error!("{err}"); errors = true; None } } } }) .collect(); if errors { Err("Parsing numeric entry errors".to_string()) } else { Ok(data) } } fn log_metric_definition(&mut self, definition: MetricDefinition) { self.metric_definitions .insert(definition.metric_id.clone(), definition); } fn log_epoch_summary(&mut self, _summary: EpochSummary) { if !self.is_eval { self.loggers.clear(); } } } fn logger_key(name: &str, split: &Split) -> String { format!("{name}_{split}") } /// In memory metric logger, useful when testing and debugging. #[derive(Default)] pub struct InMemoryMetricLogger { values: HashMap>, last_epoch: Option, metric_definitions: HashMap, } impl InMemoryMetricLogger { /// Create a new in-memory metric logger. pub fn new() -> Self { Self::default() } } impl MetricLogger for InMemoryMetricLogger { fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) { if self.last_epoch != Some(epoch) { self.values .values_mut() .for_each(|loggers| loggers.push(InMemoryLogger::default())); self.last_epoch = Some(epoch); } let entries: Vec<_> = update .entries .iter() .chain( update .entries_numeric .iter() .map(|numeric_update| &numeric_update.entry), ) .cloned() .collect(); for item in entries.iter() { let name = &self.metric_definitions.get(&item.metric_id).unwrap().name; let key = logger_key(name, split); if !self.values.contains_key(&key) { self.values .insert(key.to_string(), vec![InMemoryLogger::default()]); } let values = self.values.get_mut(&key).unwrap(); values .last_mut() .unwrap() .log(item.serialized_entry.serialized.clone()); } } fn read_numeric( &mut self, name: &str, epoch: usize, split: &Split, ) -> Result, String> { let key = logger_key(name, split); let values = match self.values.get(&key) { Some(values) => values, None => return Ok(Vec::new()), }; match values.get(epoch - 1) { Some(logger) => Ok(logger .values .iter() .filter_map(|value| NumericEntry::deserialize(value).ok()) .collect()), None => Ok(Vec::new()), } } fn log_metric_definition(&mut self, definition: MetricDefinition) { self.metric_definitions .insert(definition.metric_id.clone(), definition); } fn log_epoch_summary(&mut self, _summary: EpochSummary) {} } ================================================ FILE: crates/burn-train/src/logger/mod.rs ================================================ mod async_logger; mod base; mod file; mod in_memory; mod metric; pub use async_logger::*; pub use base::*; pub use file::*; pub use in_memory::*; pub use metric::*; ================================================ FILE: crates/burn-train/src/metric/acc.rs ================================================ use core::marker::PhantomData; use super::MetricMetadata; use super::state::{FormatOptions, NumericMetricState}; use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{ElementConversion, Int, Tensor}; /// The accuracy metric. #[derive(Clone)] pub struct AccuracyMetric { name: MetricName, state: NumericMetricState, pad_token: Option, _b: PhantomData, } /// The [accuracy metric](AccuracyMetric) input type. #[derive(new)] pub struct AccuracyInput { outputs: Tensor, targets: Tensor, } impl Default for AccuracyMetric { fn default() -> Self { Self::new() } } impl AccuracyMetric { /// Creates the metric. pub fn new() -> Self { Self { name: MetricName::new("Accuracy".to_string()), state: Default::default(), pad_token: Default::default(), _b: PhantomData, } } /// Sets the pad token. pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self } } impl Metric for AccuracyMetric { type Input = AccuracyInput; fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> SerializedEntry { let targets = input.targets.clone(); let outputs = input.outputs.clone(); let [batch_size, _n_classes] = outputs.dims(); let outputs = outputs.argmax(1).reshape([batch_size]); let accuracy = match self.pad_token { Some(pad_token) => { let mask = targets.clone().equal_elem(pad_token as i64); let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0); let num_pad = mask.float().sum(); let acc = matches.sum() / (num_pad.neg() + batch_size as f32); acc.into_scalar().elem::() } None => { outputs .equal(targets) .int() .sum() .into_scalar() .elem::() / batch_size as f64 } }; self.state.update( 100.0 * accuracy, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { super::NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for AccuracyMetric { fn value(&self) -> super::NumericEntry { self.state.current_value() } fn running_value(&self) -> super::NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn test_accuracy_without_padding() { let device = Default::default(); let mut metric = AccuracyMetric::::new(); let input = AccuracyInput::new( Tensor::from_data( [ [0.0, 0.2, 0.8], // 2 [1.0, 2.0, 0.5], // 1 [0.4, 0.1, 0.2], // 0 [0.6, 0.7, 0.2], // 1 ], &device, ), Tensor::from_data([2, 2, 1, 1], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } #[test] fn test_accuracy_with_padding() { let device = Default::default(); let mut metric = AccuracyMetric::::new().with_pad_token(3); let input = AccuracyInput::new( Tensor::from_data( [ [0.0, 0.2, 0.8, 0.0], // 2 [1.0, 2.0, 0.5, 0.0], // 1 [0.4, 0.1, 0.2, 0.0], // 0 [0.6, 0.7, 0.2, 0.0], // 1 [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count [0.0, 0.1, 0.2, 0.0], // Error on padding should not count [0.6, 0.0, 0.2, 0.0], // Error on padding should not count ], &device, ), Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } } ================================================ FILE: crates/burn-train/src/metric/auroc.rs ================================================ use core::f64; use core::marker::PhantomData; use super::MetricMetadata; use super::state::{FormatOptions, NumericMetricState}; use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{ElementConversion, Int, Tensor}; /// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification. #[derive(Clone)] pub struct AurocMetric { name: MetricName, state: NumericMetricState, _b: PhantomData, } /// The [AUROC metric](AurocMetric) input type. #[derive(new)] pub struct AurocInput { outputs: Tensor, targets: Tensor, } impl Default for AurocMetric { fn default() -> Self { Self::new() } } impl AurocMetric { /// Creates the metric. pub fn new() -> Self { Self { name: MetricName::new("AUROC".to_string()), state: Default::default(), _b: PhantomData, } } fn binary_auroc(&self, probabilities: &Tensor, targets: &Tensor) -> f64 { let n = targets.dims()[0]; let n_pos = targets.clone().sum().into_scalar().elem::() as usize; // Early return if we don't have both positive and negative samples if n_pos == 0 || n_pos == n { if n_pos == 0 { log::warn!("Metric cannot be computed because all target values are negative.") } else { log::warn!("Metric cannot be computed because all target values are positive.") } return 0.0; } let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]); let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]); let valid_pairs = pos_mask * neg_mask; let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n); let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n); let correct_order = prob_i.clone().greater(prob_j.clone()).int(); let ties = prob_i.equal(prob_j).int(); // Calculate AUC components let num_pairs = valid_pairs.clone().sum().into_scalar().elem::(); let correct_pairs = (correct_order * valid_pairs.clone()) .sum() .into_scalar() .elem::(); let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::(); (correct_pairs + 0.5 * tied_pairs) / num_pairs } } impl Metric for AurocMetric { type Input = AurocInput; fn update(&mut self, input: &AurocInput, _metadata: &MetricMetadata) -> SerializedEntry { let [batch_size, num_classes] = input.outputs.dims(); assert_eq!( num_classes, 2, "Currently only binary classification is supported" ); let probabilities = { let exponents = input.outputs.clone().exp(); let sum = exponents.clone().sum_dim(1); (exponents / sum) .select(1, Tensor::arange(1..2, &input.outputs.device())) .squeeze_dim(1) }; let area_under_curve = self.binary_auroc(&probabilities, &input.targets); self.state.update( 100.0 * area_under_curve, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } } impl Numeric for AurocMetric { fn value(&self) -> super::NumericEntry { self.state.current_value() } fn running_value(&self) -> super::NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn test_auroc() { let device = Default::default(); let mut metric = AurocMetric::::new(); let input = AurocInput::new( Tensor::from_data( [ [0.1, 0.9], // High confidence positive [0.7, 0.3], // Low confidence negative [0.6, 0.4], // Low confidence negative [0.2, 0.8], // High confidence positive ], &device, ), Tensor::from_data([1, 0, 0, 1], &device), // True labels ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(metric.value().current(), 100.0); } #[test] fn test_auroc_perfect_separation() { let device = Default::default(); let mut metric = AurocMetric::::new(); let input = AurocInput::new( Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device), Tensor::from_data([1, 0, 0, 1], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(metric.value().current(), 100.0); // Perfect AUC } #[test] fn test_auroc_random() { let device = Default::default(); let mut metric = AurocMetric::::new(); let input = AurocInput::new( Tensor::from_data( [ [0.5, 0.5], // Random predictions [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], ], &device, ), Tensor::from_data([1, 0, 0, 1], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(metric.value().current(), 50.0); } #[test] fn test_auroc_all_one_class() { let device = Default::default(); let mut metric = AurocMetric::::new(); let input = AurocInput::new( Tensor::from_data( [ [0.1, 0.9], // All positives predictions [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], ], &device, ), Tensor::from_data([1, 1, 1, 1], &device), // All positive class ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(metric.value().current(), 0.0); } #[test] #[should_panic(expected = "Currently only binary classification is supported")] fn test_auroc_multiclass_error() { let device = Default::default(); let mut metric = AurocMetric::::new(); let input = AurocInput::new( Tensor::from_data( [ [0.1, 0.2, 0.7], // More than 2 classes not supported [0.3, 0.5, 0.2], ], &device, ), Tensor::from_data([2, 1], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); } } ================================================ FILE: crates/burn-train/src/metric/base.rs ================================================ use std::sync::Arc; use burn_core::data::dataloader::Progress; use burn_optim::LearningRate; /// Metric metadata that can be used when computing metrics. pub struct MetricMetadata { /// The current progress. pub progress: Progress, /// The global progress of the training (e.g. epochs). pub global_progress: Progress, /// The current iteration. pub iteration: Option, /// The current learning rate. pub lr: Option, } impl MetricMetadata { /// Fake metric metadata #[cfg(test)] pub fn fake() -> Self { Self { progress: Progress { items_processed: 1, items_total: 1, }, global_progress: Progress { items_processed: 0, items_total: 1, }, iteration: Some(0), lr: None, } } } /// Metric id that can be used to compare metrics and retrieve entries of the same metric. /// For now we take the name as id to make sure that the same metric has the same id across different runs. #[derive(Debug, Clone, new, PartialEq, Eq, Hash)] pub struct MetricId { /// The metric id. id: Arc, } /// Metric attributes define the properties intrinsic to different types of metric. #[derive(Clone, Debug)] pub enum MetricAttributes { /// Numeric attributes. Numeric(NumericAttributes), /// No attributes. None, } /// Definition of a metric. #[derive(Clone, Debug)] pub struct MetricDefinition { /// The metric's id. pub metric_id: MetricId, /// The name of the metric. pub name: String, /// The description of the metric. pub description: Option, /// The attributes of the metric. pub attributes: MetricAttributes, } impl MetricDefinition { /// Create a new metric definition given the metric and a unique id. pub fn new(metric_id: MetricId, metric: &Me) -> Self { Self { metric_id, name: metric.name().to_string(), description: metric.description(), attributes: metric.attributes(), } } } /// Metric trait. /// /// # Notes /// /// Implementations should define their own input type only used by the metric. /// This is important since some conflict may happen when the model output is adapted for each /// metric's input type. pub trait Metric: Send + Sync + Clone { /// The input type of the metric. type Input; /// The parameterized name of the metric. /// /// This should be unique, so avoid using short generic names, prefer using the long name. /// /// For a metric that can exist at different parameters (e.g., top-k accuracy for different /// values of k), the name should be unique for each instance. fn name(&self) -> MetricName; /// A short description of the metric. fn description(&self) -> Option { None } /// Attributes of the metric. /// /// By default, metrics have no attributes. fn attributes(&self) -> MetricAttributes { MetricAttributes::None } /// Update the metric state and returns the current metric entry. fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry; /// Clear the metric state. fn clear(&mut self); } /// Type used to store metric names efficiently. pub type MetricName = Arc; /// Adaptor are used to transform types so that they can be used by metrics. /// /// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are /// registered with the specific learning paradigm (i.e. [SupervisedTraining](crate::SupervisedTraining)). pub trait Adaptor { /// Adapt the type to be passed to a [metric](Metric). fn adapt(&self) -> T; } impl Adaptor<()> for T { fn adapt(&self) {} } /// Attributes that describe intrinsic properties of a numeric metric. #[derive(Clone, Debug)] pub struct NumericAttributes { /// Optional unit (e.g. "%", "ms", "pixels") pub unit: Option, /// Whether larger values are better (true) or smaller are better (false). pub higher_is_better: bool, } impl From for MetricAttributes { fn from(attr: NumericAttributes) -> Self { MetricAttributes::Numeric(attr) } } impl Default for NumericAttributes { fn default() -> Self { Self { unit: None, higher_is_better: true, } } } /// Declare a metric to be numeric. /// /// This is useful to plot the values of a metric during training. pub trait Numeric { /// Returns the numeric value of the metric. fn value(&self) -> NumericEntry; /// Returns the current aggregated value of the metric over the global step (epoch). fn running_value(&self) -> NumericEntry; } /// Serialized form of a metric entry. #[derive(Debug, Clone, new)] pub struct SerializedEntry { /// The string to be displayed. pub formatted: String, /// The string to be saved. pub serialized: String, } /// Data type that contains the current state of a metric at a given time. #[derive(Debug, Clone)] pub struct MetricEntry { /// Id of the entry's metric. pub metric_id: MetricId, /// The serialized form of the entry. pub serialized_entry: SerializedEntry, } impl MetricEntry { /// Create a new metric. pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self { Self { metric_id, serialized_entry, } } } /// Numeric metric entry. #[derive(Debug, Clone)] pub enum NumericEntry { /// Single numeric value. Value(f64), /// Aggregated numeric (value, number of elements). Aggregated { /// The aggregated value of all entries. aggregated_value: f64, /// The number of entries present in the aggregated value. count: usize, }, } impl NumericEntry { /// Gets the current aggregated value of the metric. pub fn current(&self) -> f64 { match self { NumericEntry::Value(val) => *val, NumericEntry::Aggregated { aggregated_value, .. } => *aggregated_value, } } /// Returns a String representing the NumericEntry pub fn serialize(&self) -> String { match self { Self::Value(v) => v.to_string(), Self::Aggregated { aggregated_value, count, } => format!("{aggregated_value},{count}"), } } /// De-serializes a string representing a NumericEntry and returns a Result containing the corresponding NumericEntry. pub fn deserialize(entry: &str) -> Result { // Check for comma separated values let values = entry.split(',').collect::>(); let num_values = values.len(); if num_values == 1 { // Numeric value match values[0].parse::() { Ok(value) => Ok(NumericEntry::Value(value)), Err(err) => Err(err.to_string()), } } else if num_values == 2 { // Aggregated numeric (value, number of elements) let (value, numel) = (values[0], values[1]); match value.parse::() { Ok(value) => match numel.parse::() { Ok(numel) => Ok(NumericEntry::Aggregated { aggregated_value: value, count: numel, }), Err(err) => Err(err.to_string()), }, Err(err) => Err(err.to_string()), } } else { Err("Invalid number of values for numeric entry".to_string()) } } /// Compare this numeric metric's value with another one using the specified direction. pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool { (self.current() > other.current()) == higher_is_better } } /// Format a float with the given precision. Will use scientific notation if necessary. pub fn format_float(float: f64, precision: usize) -> String { let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0); match scientific_notation_threshold >= float { true => format!("{float:.precision$e}"), false => format!("{float:.precision$}"), } } ================================================ FILE: crates/burn-train/src/metric/cer.rs ================================================ use super::state::{FormatOptions, NumericMetricState}; use super::{MetricMetadata, SerializedEntry}; use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; use core::marker::PhantomData; use std::sync::Arc; /// Computes the edit distance (Levenshtein distance) between two sequences of integers. /// /// The edit distance is defined as the minimum number of single-element edits (insertions, /// deletions, or substitutions) required to change one sequence into the other. This /// implementation is optimized for space, using only two rows of the dynamic programming table. /// pub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize { let mut prev = (0..=prediction.len()).collect::>(); let mut curr = vec![0; prediction.len() + 1]; for (i, &r) in reference.iter().enumerate() { curr[0] = i + 1; for (j, &p) in prediction.iter().enumerate() { curr[j + 1] = if r == p { prev[j] // no operation needed } else { 1 + prev[j].min(prev[j + 1]).min(curr[j]) // substitution, insertion, deletion }; } core::mem::swap(&mut prev, &mut curr); } prev[prediction.len()] } /// Character error rate (CER) is defined as the edit distance (e.g. Levenshtein distance) between the predicted /// and reference character sequences, divided by the total number of characters in the reference. /// This metric is commonly used in tasks such as speech recognition, OCR, or text generation /// to quantify how closely the predicted output matches the ground truth at a character level. /// #[derive(Clone)] pub struct CharErrorRate { name: MetricName, state: NumericMetricState, pad_token: Option, _b: PhantomData, } /// The [character error rate metric](CharErrorRate) input type. #[derive(new)] pub struct CerInput { /// The predicted token sequences (as a 2-D tensor of token indices). pub outputs: Tensor, /// The target token sequences (as a 2-D tensor of token indices). pub targets: Tensor, } impl Default for CharErrorRate { fn default() -> Self { Self::new() } } impl CharErrorRate { /// Creates the metric. pub fn new() -> Self { Self { name: Arc::new("CER".to_string()), state: NumericMetricState::default(), pad_token: None, _b: PhantomData, } } /// Sets the pad token. pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self } } /// The [character error rate metric](CharErrorRate) implementation. impl Metric for CharErrorRate { type Input = CerInput; fn update(&mut self, input: &CerInput, _metadata: &MetricMetadata) -> SerializedEntry { let outputs = &input.outputs; let targets = &input.targets; let [batch_size, seq_len] = targets.dims(); let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token { // Create boolean masks for non-padding tokens. let output_mask = outputs.clone().not_equal_elem(pad as i64); let target_mask = targets.clone().not_equal_elem(pad as i64); let output_lengths_tensor = output_mask.int().sum_dim(1); let target_lengths_tensor = target_mask.int().sum_dim(1); ( output_lengths_tensor.to_data().to_vec::().unwrap(), target_lengths_tensor.to_data().to_vec::().unwrap(), ) } else { // If there's no padding, all sequences have the full length. ( vec![seq_len as i64; batch_size], vec![seq_len as i64; batch_size], ) }; let outputs_data = outputs.to_data().to_vec::().unwrap(); let targets_data = targets.to_data().to_vec::().unwrap(); let total_edit_distance: usize = (0..batch_size) .map(|i| { let start = i * seq_len; // Get pre-calculated lengths for the current sequence. let output_len = output_lengths[i] as usize; let target_len = target_lengths[i] as usize; let output_seq_slice = &outputs_data[start..(start + output_len)]; let target_seq_slice = &targets_data[start..(start + target_len)]; let output_seq: Vec = output_seq_slice.iter().map(|&x| x as i32).collect(); let target_seq: Vec = target_seq_slice.iter().map(|&x| x as i32).collect(); edit_distance(&target_seq, &output_seq) }) .sum(); let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::(); let value = if total_target_length > 0.0 { 100.0 * total_edit_distance as f64 / total_target_length } else { 0.0 }; self.state.update( value, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset(); } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { super::NumericAttributes { unit: Some("%".to_string()), higher_is_better: false, } .into() } } impl Numeric for CharErrorRate { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; /// Perfect match ⇒ CER = 0 %. #[test] fn test_cer_without_padding() { let device = Default::default(); let mut metric = CharErrorRate::::new(); // Batch size = 2, sequence length = 2 let preds = Tensor::from_data([[1, 2], [3, 4]], &device); let tgts = Tensor::from_data([[1, 2], [3, 4]], &device); metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake()); assert_eq!(0.0, metric.value().current()); } /// Two edits in four target tokens ⇒ 50 %. #[test] fn test_cer_without_padding_two_errors() { let device = Default::default(); let mut metric = CharErrorRate::::new(); // One substitution in each sequence. let preds = Tensor::from_data([[1, 2], [3, 5]], &device); let tgts = Tensor::from_data([[1, 3], [3, 4]], &device); metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake()); // 2 edits / 4 tokens = 50 % assert_eq!(50.0, metric.value().current()); } /// Same scenario as above, but with right-padding (token 9) ignored. #[test] fn test_cer_with_padding() { let device = Default::default(); let pad = 9_i64; let mut metric = CharErrorRate::::new().with_pad_token(pad as usize); // Each row has three columns, last one is the pad token. let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device); let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device); metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } /// `clear()` must reset the running statistics to zero. #[test] fn test_clear_resets_state() { let device = Default::default(); let mut metric = CharErrorRate::::new(); let preds = Tensor::from_data([[1, 2]], &device); let tgts = Tensor::from_data([[1, 3]], &device); // one error metric.update( &CerInput::new(preds.clone(), tgts.clone()), &MetricMetadata::fake(), ); assert!(metric.value().current() > 0.0); metric.clear(); assert!(metric.value().current().is_nan()); } } ================================================ FILE: crates/burn-train/src/metric/classification.rs ================================================ use std::num::NonZeroUsize; /// Necessary data for classification metrics. #[derive(Default, Debug, Clone)] pub struct ClassificationMetricConfig { pub decision_rule: DecisionRule, pub class_reduction: ClassReduction, } /// The prediction decision rule for classification metrics. #[derive(Debug, Clone)] pub enum DecisionRule { /// Consider a class predicted if its probability exceeds the threshold. Threshold(f64), /// Consider a class predicted correctly if it is within the top k predicted classes based on scores. TopK(NonZeroUsize), } impl Default for DecisionRule { fn default() -> Self { Self::Threshold(0.5) } } /// The reduction strategy for classification metrics. #[derive(Copy, Clone, Default, Debug)] pub enum ClassReduction { /// Computes the statistics over all classes before averaging Micro, /// Computes the statistics independently for each class before averaging #[default] Macro, } ================================================ FILE: crates/burn-train/src/metric/confusion_stats.rs ================================================ use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}; use burn_core::{ prelude::{Backend, Bool, Int, Tensor}, tensor::IndexingUpdateOp, }; use std::fmt::{self, Debug}; /// Input for confusion statistics error types. #[derive(new, Debug, Clone)] pub struct ConfusionStatsInput { /// Sample x Class Non thresholded normalized predictions. pub predictions: Tensor, /// Sample x Class one-hot encoded target. pub targets: Tensor, } impl From> for (Tensor, Tensor) { fn from(input: ConfusionStatsInput) -> Self { (input.predictions, input.targets) } } impl From<(Tensor, Tensor)> for ConfusionStatsInput { fn from(value: (Tensor, Tensor)) -> Self { Self::new(value.0, value.1) } } #[derive(Clone)] pub struct ConfusionStats { confusion_classes: Tensor, class_reduction: ClassReduction, } impl Debug for ConfusionStats { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let to_vec = |tensor_data: Tensor| { tensor_data .to_data() .to_vec::() .expect("A vector representation of the input Tensor is expected") }; let ratio_of_support_vec = |metric: Tensor| to_vec(self.clone().ratio_of_support(metric)); f.debug_struct("ConfusionStats") .field("tp", &ratio_of_support_vec(self.clone().true_positive())) .field("fp", &ratio_of_support_vec(self.clone().false_positive())) .field("tn", &ratio_of_support_vec(self.clone().true_negative())) .field("fn", &ratio_of_support_vec(self.clone().false_negative())) .field("support", &to_vec(self.clone().support())) .finish() } } impl ConfusionStats { /// Expects `predictions` to be normalized. pub fn new(input: &ConfusionStatsInput, config: &ClassificationMetricConfig) -> Self { let prediction_mask = match config.decision_rule { DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold), DecisionRule::TopK(top_k) => { let mask = input.predictions.zeros_like(); let indexes = input .predictions .clone() .argsort_descending(1) .narrow(1, 0, top_k.get()); let values = indexes.ones_like().float(); mask.scatter(1, indexes, values, IndexingUpdateOp::Add) .bool() } }; Self { confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2, class_reduction: config.class_reduction, } } /// sum over samples fn aggregate( sample_class_mask: Tensor, class_reduction: ClassReduction, ) -> Tensor { use ClassReduction::{Macro, Micro}; match class_reduction { Micro => sample_class_mask.float().sum(), Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0), } } pub fn true_positive(self) -> Tensor { Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction) } pub fn true_negative(self) -> Tensor { Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction) } pub fn false_positive(self) -> Tensor { Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction) } pub fn false_negative(self) -> Tensor { Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction) } pub fn positive(self) -> Tensor { self.clone().true_positive() + self.false_negative() } pub fn negative(self) -> Tensor { self.clone().true_negative() + self.false_positive() } pub fn predicted_positive(self) -> Tensor { self.clone().true_positive() + self.false_positive() } pub fn support(self) -> Tensor { self.clone().positive() + self.negative() } pub fn ratio_of_support(self, metric: Tensor) -> Tensor { metric / self.clone().support() } } #[cfg(test)] mod tests { use super::{ConfusionStats, ConfusionStatsInput}; use crate::{ TestBackend, metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, tests::{ClassificationType, THRESHOLD, dummy_classification_input}, }; use burn_core::prelude::TensorData; use rstest::{fixture, rstest}; use std::num::NonZeroUsize; fn top_k_config( top_k: NonZeroUsize, class_reduction: ClassReduction, ) -> ClassificationMetricConfig { ClassificationMetricConfig { decision_rule: DecisionRule::TopK(top_k), class_reduction, } } #[fixture] #[once] fn top_k_config_k1_micro() -> ClassificationMetricConfig { top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro) } #[fixture] #[once] fn top_k_config_k1_macro() -> ClassificationMetricConfig { top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro) } #[fixture] #[once] fn top_k_config_k2_micro() -> ClassificationMetricConfig { top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro) } #[fixture] #[once] fn top_k_config_k2_macro() -> ClassificationMetricConfig { top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro) } fn threshold_config( threshold: f64, class_reduction: ClassReduction, ) -> ClassificationMetricConfig { ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), class_reduction, } } #[fixture] #[once] fn threshold_config_micro() -> ClassificationMetricConfig { threshold_config(THRESHOLD, ClassReduction::Micro) } #[fixture] #[once] fn threshold_config_macro() -> ClassificationMetricConfig { threshold_config(THRESHOLD, ClassReduction::Macro) } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())] fn test_true_positive( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .true_positive() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())] fn test_true_negative( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .true_negative() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())] fn test_false_positive( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .false_positive() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())] fn test_false_negatives( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .false_negative() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())] fn test_positive( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .positive() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())] fn test_negative( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .negative() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())] #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())] #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())] #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())] fn test_predicted_positive( #[case] classification_type: ClassificationType, #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { let input: ConfusionStatsInput = dummy_classification_input(&classification_type).into(); ConfusionStats::new(&input, &config) .predicted_positive() .int() .into_data() .assert_eq(&TensorData::from(expected.as_slice()), true); } } ================================================ FILE: crates/burn-train/src/metric/cpu_temp.rs ================================================ use std::sync::Arc; /// CPU Temperature metric use super::MetricMetadata; use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry}; use systemstat::{Platform, System}; /// CPU Temperature in celsius degrees #[derive(Clone)] pub struct CpuTemperature { name: MetricName, temp_celsius: f32, sys: Arc, } impl CpuTemperature { /// Creates a new CPU temp metric pub fn new() -> Self { let name = Arc::new("CPU Temperature".to_string()); Self { name, temp_celsius: 0., sys: Arc::new(System::new()), } } } impl Default for CpuTemperature { fn default() -> Self { CpuTemperature::new() } } impl Metric for CpuTemperature { type Input = (); fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { match self.sys.cpu_temp() { Ok(temp) => self.temp_celsius = temp, Err(_) => self.temp_celsius = f32::NAN, } let formatted = match self.temp_celsius.is_nan() { true => format!("{}: NaN °C", self.name()), false => format!("{}: {:.2} °C", self.name(), self.temp_celsius), }; let raw = format!("{:.2}", self.temp_celsius); SerializedEntry::new(formatted, raw) } fn clear(&mut self) {} fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { super::NumericAttributes { unit: Some("°C".to_string()), higher_is_better: false, } .into() } } impl Numeric for CpuTemperature { fn value(&self) -> NumericEntry { NumericEntry::Value(self.temp_celsius as f64) } fn running_value(&self) -> NumericEntry { NumericEntry::Value(self.temp_celsius as f64) } } ================================================ FILE: crates/burn-train/src/metric/cpu_use.rs ================================================ use super::MetricMetadata; use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry}; use std::{ sync::Arc, time::{Duration, Instant}, }; use sysinfo::{CpuRefreshKind, RefreshKind, System}; /// General CPU Usage metric pub struct CpuUse { name: MetricName, last_refresh: Instant, refresh_frequency: Duration, sys: System, current: f64, } impl Clone for CpuUse { fn clone(&self) -> Self { Self { name: self.name.clone(), last_refresh: self.last_refresh, refresh_frequency: self.refresh_frequency, sys: System::new(), current: self.current, } } } impl CpuUse { /// Creates a new CPU metric pub fn new() -> Self { let mut sys = System::new(); let current = Self::refresh(&mut sys); let name = "CPU Usage".to_string(); Self { name: Arc::new(name), last_refresh: Instant::now(), refresh_frequency: Duration::from_millis(200), sys, current, } } fn refresh(sys: &mut System) -> f64 { sys.refresh_specifics( RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()), ); let cpus = sys.cpus(); let num_cpus = cpus.len(); let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64; use_percentage / num_cpus as f64 } } impl Default for CpuUse { fn default() -> Self { CpuUse::new() } } impl Metric for CpuUse { type Input = (); fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { if self.last_refresh.elapsed() >= self.refresh_frequency { self.current = Self::refresh(&mut self.sys); self.last_refresh = Instant::now(); } let formatted = format!("{}: {:.2} %", self.name(), self.current); let raw = format!("{:.2}", self.current); SerializedEntry::new(formatted, raw) } fn clear(&mut self) {} fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { super::NumericAttributes { unit: Some("%".to_string()), higher_is_better: false, } .into() } } impl Numeric for CpuUse { fn value(&self) -> NumericEntry { NumericEntry::Value(self.current) } fn running_value(&self) -> NumericEntry { NumericEntry::Value(self.current) } } ================================================ FILE: crates/burn-train/src/metric/cuda.rs ================================================ use std::sync::Arc; use super::MetricMetadata; use crate::metric::{Metric, MetricName, SerializedEntry}; use nvml_wrapper::Nvml; /// Track basic cuda infos. #[derive(Clone)] pub struct CudaMetric { name: MetricName, nvml: Arc>, } impl CudaMetric { /// Creates a new metric for CUDA. pub fn new() -> Self { Self { name: Arc::new("Cuda".to_string()), nvml: Arc::new(Nvml::init().map(Some).unwrap_or_else(|err| { log::warn!("Unable to initialize CUDA Metric: {err}"); None })), } } } impl Default for CudaMetric { fn default() -> Self { Self::new() } } impl Metric for CudaMetric { type Input = (); fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry { let not_available = || SerializedEntry::new("Unavailable".to_string(), "Unavailable".to_string()); let available = |nvml: &Nvml| { let mut formatted = String::new(); let mut raw_running = String::new(); let device_count = match nvml.device_count() { Ok(val) => val, Err(err) => { log::warn!("Unable to get the number of cuda devices: {err}"); return not_available(); } }; for index in 0..device_count { let device = match nvml.device_by_index(index) { Ok(val) => val, Err(err) => { log::warn!("Unable to get device {index}: {err}"); return not_available(); } }; let memory_info = match device.memory_info() { Ok(info) => info, Err(err) => { log::warn!("Unable to get memory info from device {index}: {err}"); return not_available(); } }; let used_gb = memory_info.used as f64 * 1e-9; let total_gb = memory_info.total as f64 * 1e-9; let memory_info_formatted = format!("{used_gb:.2}/{total_gb:.2} Gb"); let memory_info_raw = format!("{used_gb}/{total_gb}"); formatted = format!("{formatted} GPU #{index} - Memory {memory_info_formatted}"); raw_running = format!("{memory_info_raw} "); let utilization_rates = match device.utilization_rates() { Ok(rate) => rate, Err(err) => { log::warn!("Unable to get utilization rates from device {index}: {err}"); return not_available(); } }; let utilization_rate_formatted = format!("{}%", utilization_rates.gpu); formatted = format!("{formatted} - Usage {utilization_rate_formatted}"); // Power is the currency for perf/W. NVML reports milliwatts. if let Ok(power_mw) = device.power_usage() { let power_w = power_mw as f64 / 1000.0; formatted = format!("{formatted} - Power {power_w:.1} W"); } } SerializedEntry::new(formatted, raw_running) }; match self.nvml.as_ref() { Some(nvml) => available(nvml), None => not_available(), } } fn clear(&mut self) {} fn name(&self) -> MetricName { self.name.clone() } } ================================================ FILE: crates/burn-train/src/metric/fbetascore.rs ================================================ use crate::metric::{MetricName, Numeric}; use super::{ Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry, classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::cast::ToElement, }; use core::marker::PhantomData; use std::{num::NonZeroUsize, sync::Arc}; /// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric. /// /// The `beta` parameter represents the ratio of recall importance to precision importance. /// `beta > 1` gives more weight to recall, while `beta < 1` favors precision. #[derive(Clone)] pub struct FBetaScoreMetric { name: MetricName, state: NumericMetricState, _b: PhantomData, config: ClassificationMetricConfig, beta: f64, } impl Default for FBetaScoreMetric { fn default() -> Self { Self::new(Default::default(), Default::default()) } } impl FBetaScoreMetric { #[allow(dead_code)] fn new(config: ClassificationMetricConfig, beta: f64) -> Self { let name = Arc::new(format!( "FBetaScore ({}) @ {:?} [{:?}]", beta, config.decision_rule, config.class_reduction )); Self { name, config, beta, state: Default::default(), _b: PhantomData, } } /// F-beta score metric for binary classification. /// /// # Arguments /// /// * `beta` - Positive real factor to weight recall's importance. /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] pub fn binary(beta: f64, threshold: f64) -> Self { Self::new( ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), // binary classification results are the same independently of class_reduction ..Default::default() }, beta, ) } /// F-beta score metric for multiclass classification. /// /// # Arguments /// /// * `beta` - Positive real factor to weight recall's importance. /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self { Self::new( ClassificationMetricConfig { decision_rule: DecisionRule::TopK( NonZeroUsize::new(top_k).expect("top_k must be non-zero"), ), class_reduction, }, beta, ) } /// F-beta score metric for multi-label classification. /// /// # Arguments /// /// * `beta` - Positive real factor to weight recall's importance. /// * `threshold` - The threshold to transform a probability into a binary prediction. /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self { Self::new( ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), class_reduction, }, beta, ) } fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { use ClassReduction::{Macro, Micro}; let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric .clone() .contains_nan() .any() .into_scalar() .to_bool() { let nan_mask = aggregated_metric.clone().is_nan(); aggregated_metric = aggregated_metric .clone() .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1)) } aggregated_metric.mean() } }; avg_tensor.into_scalar().to_f64() } } impl Metric for FBetaScoreMetric { type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let [sample_size, _] = input.predictions.dims(); let cf_stats = ConfusionStats::new(input, &self.config); let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2)); let metric = self.class_average( scaled_true_positive.clone() / (scaled_true_positive + cf_stats.clone().false_negative() * self.beta.powi(2) + cf_stats.false_positive()), ); self.state.update( 100.0 * metric, sample_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for FBetaScoreMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::{ ClassReduction::{self, *}, FBetaScoreMetric, Metric, MetricMetadata, }; use crate::metric::Numeric; use crate::{ TestBackend, tests::{ClassificationType, THRESHOLD, dummy_classification_input}, }; use burn_core::tensor::TensorData; use burn_core::tensor::Tolerance; use rstest::rstest; #[rstest] #[case::binary_b1(1.0, THRESHOLD, 0.5)] #[case::binary_b2(2.0, THRESHOLD, 0.5)] fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = FBetaScoreMetric::binary(beta, threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)] #[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))] #[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)] #[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)] #[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)] #[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))] #[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)] #[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)] fn test_multiclass_fscore( #[case] beta: f64, #[case] class_reduction: ClassReduction, #[case] top_k: usize, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multiclass).into(); let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))] #[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)] #[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))] #[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)] fn test_multilabel_fscore( #[case] beta: f64, #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[test] fn test_parameterized_unique_name() { let metric_a = FBetaScoreMetric::::multiclass(0.5, 1, ClassReduction::Macro); let metric_b = FBetaScoreMetric::::multiclass(0.5, 2, ClassReduction::Macro); let metric_c = FBetaScoreMetric::::multiclass(0.5, 1, ClassReduction::Macro); assert_ne!(metric_a.name(), metric_b.name()); assert_eq!(metric_a.name(), metric_c.name()); let metric_a = FBetaScoreMetric::::binary(0.5, 0.5); let metric_b = FBetaScoreMetric::::binary(0.75, 0.5); assert_ne!(metric_a.name(), metric_b.name()); } } ================================================ FILE: crates/burn-train/src/metric/hamming.rs ================================================ use core::marker::PhantomData; use std::sync::Arc; use super::state::{FormatOptions, NumericMetricState}; use super::{MetricMetadata, SerializedEntry}; use crate::metric::{ Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry, }; use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend}; /// The hamming score, sometimes referred to as multi-label or label-based accuracy. #[derive(Clone)] pub struct HammingScore { name: MetricName, state: NumericMetricState, threshold: f32, sigmoid: bool, _b: PhantomData, } /// The [hamming score](HammingScore) input type. #[derive(new)] pub struct HammingScoreInput { outputs: Tensor, targets: Tensor, } impl HammingScore { /// Creates the metric. pub fn new() -> Self { Self::default() } fn update_name(&mut self) { self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold)); } /// Sets the threshold. pub fn with_threshold(mut self, threshold: f32) -> Self { self.threshold = threshold; self.update_name(); self } /// Sets the sigmoid activation function usage. pub fn with_sigmoid(mut self, sigmoid: bool) -> Self { self.sigmoid = sigmoid; self.update_name(); self } } impl Default for HammingScore { /// Creates a new metric instance with default values. fn default() -> Self { let threshold = 0.5; let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold)); Self { name, state: NumericMetricState::default(), threshold, sigmoid: false, _b: PhantomData, } } } impl Metric for HammingScore { type Input = HammingScoreInput; fn update( &mut self, input: &HammingScoreInput, _metadata: &MetricMetadata, ) -> SerializedEntry { let [batch_size, _n_classes] = input.outputs.dims(); let targets = input.targets.clone(); let mut outputs = input.outputs.clone(); if self.sigmoid { outputs = sigmoid(outputs); } let score = outputs .greater_elem(self.threshold) .equal(targets.bool()) .float() .mean() .into_scalar() .elem::(); self.state.update( 100.0 * score, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for HammingScore { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn test_hamming_score() { let device = Default::default(); let mut metric = HammingScore::::new(); let x = Tensor::from_data( [ [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1] [0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1] [0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1] [0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0] ], &device, ); let y = Tensor::from_data( [ [0, 1, 0, 1, 1], [0, 0, 0, 1, 1], [0, 0, 1, 0, 1], [0, 0, 1, 0, 0], ], &device, ); let _entry = metric.update( &HammingScoreInput::new(x.clone(), y.clone()), &MetricMetadata::fake(), ); assert_eq!(100.0, metric.value().current()); // Invert all targets: y = (1 - y) let y = y.neg().add_scalar(1); let _entry = metric.update( &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y) &MetricMetadata::fake(), ); assert_eq!(0.0, metric.value().current()); // Invert 5 target values -> 1 - (5/20) = 0.75 let y = Tensor::from_data( [ [0, 1, 1, 0, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 1], [0, 1, 1, 0, 0], ], &device, ); let _entry = metric.update( &HammingScoreInput::new(x, y), // invert targets (1 - y) &MetricMetadata::fake(), ); assert_eq!(75.0, metric.value().current()); } #[test] fn test_parameterized_unique_name() { let metric_a = HammingScore::::new().with_threshold(0.5); let metric_b = HammingScore::::new().with_threshold(0.75); let metric_c = HammingScore::::new().with_threshold(0.5); assert_ne!(metric_a.name(), metric_b.name()); assert_eq!(metric_a.name(), metric_c.name()); } } ================================================ FILE: crates/burn-train/src/metric/iteration.rs ================================================ use std::sync::Arc; use super::MetricMetadata; use super::SerializedEntry; use super::state::FormatOptions; use super::state::NumericMetricState; use crate::metric::MetricName; use crate::metric::Numeric; use crate::metric::{Metric, MetricAttributes, NumericAttributes, NumericEntry}; /// The loss metric. #[derive(Clone)] pub struct IterationSpeedMetric { name: MetricName, state: NumericMetricState, instant: Option, } impl Default for IterationSpeedMetric { fn default() -> Self { Self::new() } } impl IterationSpeedMetric { /// Create the metric. pub fn new() -> Self { Self { name: Arc::new("Iteration Speed".to_string()), state: Default::default(), instant: Default::default(), } } } impl Metric for IterationSpeedMetric { type Input = (); fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry { let raw = match self.instant { Some(val) => { // If iteration is not logged, compute the speed over the number of items processed. // 1 iteration should equal 1 item when iteration is not logged. metadata .iteration .unwrap_or(metadata.progress.items_processed) as f64 / val.elapsed().as_secs_f64() } None => { self.instant = Some(std::time::Instant::now()); 0.0 } }; self.state.update( raw, 1, FormatOptions::new(self.name()) .unit("iter/sec") .precision(2), ) } fn clear(&mut self) { self.instant = None; } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("iter/sec".to_string()), higher_is_better: true, } .into() } } impl Numeric for IterationSpeedMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/learning_rate.rs ================================================ use std::sync::Arc; use super::{ MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, state::{FormatOptions, NumericMetricState}, }; use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; /// Track the learning rate across iterations. #[derive(Clone)] pub struct LearningRateMetric { name: MetricName, state: NumericMetricState, } impl LearningRateMetric { /// Creates a new learning rate metric. pub fn new() -> Self { Self { name: Arc::new("Learning Rate".to_string()), state: NumericMetricState::new(), } } } impl Default for LearningRateMetric { fn default() -> Self { Self::new() } } impl Metric for LearningRateMetric { type Input = (); fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> SerializedEntry { let lr = metadata.lr.unwrap_or(0.0); self.state .update(lr, 1, FormatOptions::new(self.name()).precision(2)) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: false, } .into() } } impl Numeric for LearningRateMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/loss.rs ================================================ use std::sync::Arc; use super::MetricMetadata; use super::SerializedEntry; use super::state::FormatOptions; use super::state::NumericMetricState; use crate::metric::MetricName; use crate::metric::{Metric, MetricAttributes, Numeric, NumericAttributes, NumericEntry}; use burn_core::tensor::Tensor; use burn_core::tensor::backend::Backend; /// The loss metric. #[derive(Clone)] pub struct LossMetric { name: Arc, state: NumericMetricState, _b: B, } /// The [loss metric](LossMetric) input type. #[derive(new)] pub struct LossInput { tensor: Tensor, } impl Default for LossMetric { fn default() -> Self { Self::new() } } impl LossMetric { /// Create the metric. pub fn new() -> Self { Self { name: Arc::new("Loss".to_string()), state: NumericMetricState::default(), _b: Default::default(), } } } impl Metric for LossMetric { type Input = LossInput; fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let [batch_size] = loss.tensor.dims(); let loss = loss .tensor .clone() .mean() .into_data() .iter::() .next() .unwrap(); self.state.update( loss, batch_size, FormatOptions::new(self.name()).precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: false, } .into() } } impl Numeric for LossMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/memory_use.rs ================================================ /// RAM use metric use super::{MetricAttributes, MetricMetadata, NumericAttributes}; use crate::metric::{Metric, Numeric, NumericEntry, SerializedEntry}; use std::{ sync::Arc, time::{Duration, Instant}, }; use sysinfo::System; /// Memory information pub struct CpuMemory { name: Arc, last_refresh: Instant, refresh_frequency: Duration, sys: System, ram_bytes_total: u64, ram_bytes_used: u64, } impl Clone for CpuMemory { fn clone(&self) -> Self { Self { name: self.name.clone(), last_refresh: self.last_refresh, refresh_frequency: self.refresh_frequency, sys: System::new(), ram_bytes_total: self.ram_bytes_total, ram_bytes_used: self.ram_bytes_used, } } } impl CpuMemory { /// Creates a new memory metric pub fn new() -> Self { let mut metric = Self { name: Arc::new("CPU Memory".into()), last_refresh: Instant::now(), refresh_frequency: Duration::from_millis(200), sys: System::new(), ram_bytes_total: 0, ram_bytes_used: 0, }; metric.refresh(); metric } fn refresh(&mut self) { self.sys.refresh_memory(); self.last_refresh = Instant::now(); // bytes of RAM available self.ram_bytes_total = self.sys.total_memory(); // bytes of RAM in use self.ram_bytes_used = self.sys.used_memory(); } } impl Default for CpuMemory { fn default() -> Self { CpuMemory::new() } } impl Metric for CpuMemory { type Input = (); fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { if self.last_refresh.elapsed() >= self.refresh_frequency { self.refresh(); } let raw = bytes2gb(self.ram_bytes_used); let formatted = format!( "RAM Used: {:.2} / {:.2} Gb", raw, bytes2gb(self.ram_bytes_total), ); SerializedEntry::new(formatted, raw.to_string()) } fn clear(&mut self) {} fn name(&self) -> Arc { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("Gb".to_string()), higher_is_better: false, } .into() } } impl Numeric for CpuMemory { fn value(&self) -> NumericEntry { NumericEntry::Value(bytes2gb(self.ram_bytes_used)) } fn running_value(&self) -> NumericEntry { NumericEntry::Value(bytes2gb(self.ram_bytes_used)) } } fn bytes2gb(bytes: u64) -> f64 { bytes as f64 / 1e9 } ================================================ FILE: crates/burn-train/src/metric/mod.rs ================================================ /// State module. pub mod state; /// Module responsible to save and exposes data collected during training. pub mod store; /// Metrics module for vision tasks. #[cfg(feature = "vision")] pub mod vision; //Metrics for reinforcement learning. #[cfg(feature = "rl")] mod rl; #[cfg(feature = "rl")] pub use rl::*; // System metrics #[cfg(feature = "sys-metrics")] mod cpu_temp; #[cfg(feature = "sys-metrics")] mod cpu_use; #[cfg(feature = "sys-metrics")] mod cuda; #[cfg(feature = "sys-metrics")] mod memory_use; #[cfg(feature = "sys-metrics")] pub use cpu_temp::*; #[cfg(feature = "sys-metrics")] pub use cpu_use::*; #[cfg(feature = "sys-metrics")] pub use cuda::*; #[cfg(feature = "sys-metrics")] pub use memory_use::*; // Training metrics mod acc; mod auroc; mod base; mod cer; mod confusion_stats; mod fbetascore; mod hamming; mod iteration; mod learning_rate; mod loss; mod perplexity; mod precision; mod recall; mod top_k_acc; mod wer; pub use acc::*; pub use auroc::*; pub use base::*; pub use cer::*; pub use confusion_stats::ConfusionStatsInput; pub use fbetascore::*; pub use hamming::*; pub use iteration::*; pub use learning_rate::*; pub use loss::*; pub use perplexity::*; pub use precision::*; pub use recall::*; pub use top_k_acc::*; pub use wer::*; pub(crate) mod classification; pub(crate) mod processor; pub use crate::metric::classification::ClassReduction; // Expose `ItemLazy` so it can be implemented for custom types pub use processor::ItemLazy; ================================================ FILE: crates/burn-train/src/metric/perplexity.rs ================================================ use core::marker::PhantomData; use super::state::FormatOptions; use super::{MetricMetadata, NumericEntry, SerializedEntry, format_float}; use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{ElementConversion, Int, Tensor}; /// Custom state for perplexity metric that correctly accumulates negative log-likelihood. /// /// Unlike other metrics that can be averaged, perplexity requires special handling: /// - Accumulate total negative log-likelihood across all tokens /// - Accumulate total number of effective tokens /// - Compute perplexity as exp(total_nll / total_tokens) only at the end #[derive(Clone)] struct PerplexityState { /// Sum of negative log-likelihood across all tokens sum_nll: f64, /// Total number of effective tokens (excluding padding) total_tokens: usize, /// Current batch perplexity (for display purposes) current: f64, } impl PerplexityState { fn new() -> Self { Self { sum_nll: 0.0, total_tokens: 0, current: f64::NAN, } } fn reset(&mut self) { self.sum_nll = 0.0; self.total_tokens = 0; self.current = f64::NAN; } /// Update state with negative log-likelihood and token count from current batch fn update( &mut self, sum_log_prob: f64, effective_tokens: usize, format: FormatOptions, ) -> SerializedEntry { // sum_log_prob is already the sum of log probabilities (negative values) // We need to negate it to get negative log-likelihood let batch_nll = -sum_log_prob; // Accumulate across batches self.sum_nll += batch_nll; self.total_tokens += effective_tokens; // Compute current batch perplexity for display let batch_perplexity = if effective_tokens > 0 { (batch_nll / effective_tokens as f64).exp() } else { f64::INFINITY }; self.current = batch_perplexity; // Compute running epoch perplexity let epoch_perplexity = if self.total_tokens > 0 { (self.sum_nll / self.total_tokens as f64).exp() } else { f64::INFINITY }; // Format for display let (formatted_current, formatted_running) = match format.precision_value() { Some(precision) => ( format_float(batch_perplexity, precision), format_float(epoch_perplexity, precision), ), None => (format!("{batch_perplexity}"), format!("{epoch_perplexity}")), }; let formatted = match format.unit_value() { Some(unit) => { format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") } None => format!("epoch {formatted_running} - batch {formatted_current}"), }; // Serialize the state for aggregation let serialized = NumericEntry::Aggregated { aggregated_value: epoch_perplexity, count: self.total_tokens, } .serialize(); SerializedEntry::new(formatted, serialized) } fn value(&self) -> NumericEntry { let perplexity = if self.total_tokens > 0 { (self.sum_nll / self.total_tokens as f64).exp() } else { f64::INFINITY }; NumericEntry::Aggregated { aggregated_value: perplexity, count: self.total_tokens, } } fn running_value(&self) -> NumericEntry { self.value() } } /// The perplexity metric. /// /// Perplexity is a measure of how well a probability distribution or probability model /// predicts a sample. It's commonly used to evaluate language models. A lower perplexity /// indicates that the model is more confident in its predictions. /// /// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss: /// PPL = exp(H(p, q)) = exp(-1/N * Σ log(p(x_i))) /// /// where: /// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q /// - N is the number of tokens /// - p(x_i) is the predicted probability of the i-th token /// /// # Aggregation /// Unlike other metrics, perplexity cannot be simply averaged across batches. /// This implementation correctly accumulates the total negative log-likelihood and /// total token count across batches, then computes perplexity as exp(total_nll / total_tokens). #[derive(Clone)] pub struct PerplexityMetric { name: MetricName, state: PerplexityState, pad_token: Option, _b: PhantomData, } /// The [perplexity metric](PerplexityMetric) input type. #[derive(new)] pub struct PerplexityInput { /// Logits tensor of shape [batch_size * sequence_length, vocab_size] outputs: Tensor, /// Target tokens tensor of shape [batch_size * sequence_length] targets: Tensor, } impl Default for PerplexityMetric { fn default() -> Self { Self::new() } } impl PerplexityMetric { /// Creates the metric. pub fn new() -> Self { Self { name: MetricName::new("Perplexity".to_string()), state: PerplexityState::new(), pad_token: Default::default(), _b: PhantomData, } } /// Sets the pad token to exclude from perplexity calculation. /// /// When a pad token is set, predictions for padding tokens are masked out /// and do not contribute to the perplexity calculation. This is important /// for variable-length sequences where padding is used. pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self } } impl Metric for PerplexityMetric { type Input = PerplexityInput; fn update( &mut self, input: &PerplexityInput, _metadata: &MetricMetadata, ) -> SerializedEntry { let targets = input.targets.clone(); let outputs = input.outputs.clone(); let [total_tokens, _vocab_size] = outputs.dims(); // Convert logits to log probabilities using log_softmax for numerical stability let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1); // Gather the log probabilities for the target tokens let target_log_probs = log_probs .gather(1, targets.clone().unsqueeze_dim(1)) .squeeze_dim(1); let (sum_log_prob, effective_tokens) = match self.pad_token { Some(pad_token) => { // Create a mask for non-padding tokens let mask = targets.clone().not_equal_elem(pad_token as i64); // Apply mask to log probabilities (set padding log probs to 0) let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0); // Sum the log probabilities and count effective tokens let sum_log_prob = masked_log_probs.sum().into_scalar().elem::(); let effective_tokens = mask.int().sum().into_scalar().elem::() as usize; (sum_log_prob, effective_tokens) } None => { // No padding, use all tokens let sum_log_prob = target_log_probs.sum().into_scalar().elem::(); (sum_log_prob, total_tokens) } }; // Pass the sum_log_prob and effective_tokens to the state // The state will handle the correct accumulation and perplexity calculation self.state.update( sum_log_prob, effective_tokens, FormatOptions::new(self.name()).precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: false, } .into() } } impl Numeric for PerplexityMetric { fn value(&self) -> NumericEntry { self.state.value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn test_perplexity_perfect_prediction() { let device = Default::default(); let mut metric = PerplexityMetric::::new(); // Perfect prediction: target is always the highest probability class let input = PerplexityInput::new( Tensor::from_data( [ [10.0, 0.0, 0.0], // Very confident prediction for class 0 [0.0, 10.0, 0.0], // Very confident prediction for class 1 [0.0, 0.0, 10.0], // Very confident prediction for class 2 ], &device, ), Tensor::from_data([0, 1, 2], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); let perplexity = metric.value().current(); // Perfect predictions should result in very low perplexity (close to 1.0) assert!( perplexity < 1.1, "Perfect predictions should have low perplexity, got {}", perplexity ); } #[test] fn test_perplexity_uniform_prediction() { let device = Default::default(); let mut metric = PerplexityMetric::::new(); // Uniform prediction: all classes have equal probability let input = PerplexityInput::new( Tensor::from_data( [ [0.0, 0.0, 0.0], // Uniform distribution (after softmax) [0.0, 0.0, 0.0], // Uniform distribution (after softmax) [0.0, 0.0, 0.0], // Uniform distribution (after softmax) ], &device, ), Tensor::from_data([0, 1, 2], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); let perplexity = metric.value().current(); // Uniform distribution over 3 classes should have perplexity ≈ 3.0 assert!( (perplexity - 3.0).abs() < 0.1, "Uniform distribution perplexity should be ~3.0, got {}", perplexity ); } #[test] fn test_perplexity_with_padding() { let device = Default::default(); let mut metric = PerplexityMetric::::new().with_pad_token(3); let input = PerplexityInput::new( Tensor::from_data( [ [10.0, 0.0, 0.0, 0.0], // Good prediction for class 0 [0.0, 10.0, 0.0, 0.0], // Good prediction for class 1 [0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored [0.0, 0.0, 0.0, 1.0], // This is padding - should be ignored ], &device, ), Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token ); let _entry = metric.update(&input, &MetricMetadata::fake()); let perplexity = metric.value().current(); // Should only consider the first two predictions, both of which are confident assert!( perplexity < 1.1, "Good predictions with padding should have low perplexity, got {}", perplexity ); } #[test] fn test_perplexity_wrong_prediction() { let device = Default::default(); let mut metric = PerplexityMetric::::new(); // Wrong predictions: target class has very low probability let input = PerplexityInput::new( Tensor::from_data( [ [0.0, 10.0, 0.0], // Predicts class 1, but target is 0 [10.0, 0.0, 0.0], // Predicts class 0, but target is 1 [0.0, 0.0, 10.0], // Predicts class 2, but target is 0 ], &device, ), Tensor::from_data([0, 1, 0], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); let perplexity = metric.value().current(); // Wrong predictions should result in high perplexity assert!( perplexity > 10.0, "Wrong predictions should have high perplexity, got {}", perplexity ); } #[test] fn test_perplexity_multi_batch_aggregation() { let device = Default::default(); let mut metric = PerplexityMetric::::new(); // First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each) let input1 = PerplexityInput::new( Tensor::from_data( [ [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986) [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986) ], &device, ), Tensor::from_data([0, 1], &device), ); // Second batch: 1 token with uniform distribution let input2 = PerplexityInput::new( Tensor::from_data( [ [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986) ], &device, ), Tensor::from_data([2], &device), ); // Update with both batches let _entry1 = metric.update(&input1, &MetricMetadata::fake()); let _entry2 = metric.update(&input2, &MetricMetadata::fake()); let aggregated_perplexity = metric.value().current(); // For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986 // Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958 // Total tokens: 3 // Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0 assert!( (aggregated_perplexity - 3.0).abs() < 0.1, "Multi-batch aggregated perplexity should be ~3.0, got {}", aggregated_perplexity ); // Compare with single batch containing all data let mut single_batch_metric = PerplexityMetric::::new(); let single_input = PerplexityInput::new( Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device), Tensor::from_data([0, 1, 2], &device), ); let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake()); let single_batch_perplexity = single_batch_metric.value().current(); // Multi-batch and single-batch should give the same result assert!( (aggregated_perplexity - single_batch_perplexity).abs() < 0.01, "Multi-batch ({}) and single-batch ({}) perplexity should match", aggregated_perplexity, single_batch_perplexity ); } } ================================================ FILE: crates/burn-train/src/metric/precision.rs ================================================ use crate::metric::{MetricName, Numeric}; use super::{ Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry, classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::cast::ToElement, }; use core::marker::PhantomData; use std::{num::NonZeroUsize, sync::Arc}; /// The Precision Metric #[derive(Clone)] pub struct PrecisionMetric { name: MetricName, state: NumericMetricState, _b: PhantomData, config: ClassificationMetricConfig, } impl Default for PrecisionMetric { fn default() -> Self { Self::new(Default::default()) } } impl PrecisionMetric { fn new(config: ClassificationMetricConfig) -> Self { let state = Default::default(); let name = Arc::new(format!( "Precision @ {:?} [{:?}]", config.decision_rule, config.class_reduction )); Self { state, config, name, _b: Default::default(), } } /// Precision metric for binary classification. /// /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self::new(ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), // binary classification results are the same independently of class_reduction ..Default::default() }) } /// Precision metric for multiclass classification. /// /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self::new(ClassificationMetricConfig { decision_rule: DecisionRule::TopK( NonZeroUsize::new(top_k).expect("top_k must be non-zero"), ), class_reduction, }) } /// Precision metric for multi-label classification. /// /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary value. /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), class_reduction, }, ..Default::default() } } fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { use ClassReduction::{Macro, Micro}; let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric .clone() .contains_nan() .any() .into_scalar() .to_bool() { let nan_mask = aggregated_metric.clone().is_nan(); aggregated_metric = aggregated_metric .clone() .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1)) } aggregated_metric.mean() } }; avg_tensor.into_scalar().to_f64() } } impl Metric for PrecisionMetric { type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let [sample_size, _] = input.predictions.dims(); let cf_stats = ConfusionStats::new(input, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive()); self.state.update( 100.0 * metric, sample_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for PrecisionMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::{ ClassReduction::{self, *}, Metric, MetricMetadata, PrecisionMetric, }; use crate::metric::Numeric; use crate::{ TestBackend, tests::{ClassificationType, THRESHOLD, dummy_classification_input}, }; use burn_core::tensor::TensorData; use burn_core::tensor::Tolerance; use rstest::rstest; #[rstest] #[case::binary(THRESHOLD, 0.5)] fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = PrecisionMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)] #[case::multiclass_micro_k2(Micro, 2, 4.0/10.0)] #[case::multiclass_macro_k1(Macro, 1, (0.5 + 0.5 + 1.0)/3.0)] #[case::multiclass_macro_k2(Macro, 2, (0.5 + 1.0/4.0 + 0.5)/3.0)] fn test_multiclass_precision( #[case] class_reduction: ClassReduction, #[case] top_k: usize, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multiclass).into(); let mut metric = PrecisionMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)] #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)] fn test_multilabel_precision( #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); let mut metric = PrecisionMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[test] fn test_parameterized_unique_name() { let metric_a = PrecisionMetric::::multiclass(1, ClassReduction::Macro); let metric_b = PrecisionMetric::::multiclass(2, ClassReduction::Macro); let metric_c = PrecisionMetric::::multiclass(1, ClassReduction::Macro); assert_ne!(metric_a.name(), metric_b.name()); assert_eq!(metric_a.name(), metric_c.name()); let metric_a = PrecisionMetric::::binary(0.5); let metric_b = PrecisionMetric::::binary(0.75); assert_ne!(metric_a.name(), metric_b.name()); } } ================================================ FILE: crates/burn-train/src/metric/processor/async_wrapper.rs ================================================ use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation}; use super::EventProcessorTraining; use async_channel::{Receiver, Sender}; /// Event processor for the training process. pub struct AsyncProcessorTraining { sender: Sender>, } /// Event processor for the model evaluation. pub struct AsyncProcessorEvaluation { sender: Sender>, } struct WorkerTraining> { processor: P, rec: Receiver>, } struct WorkerEvaluation { processor: P, rec: Receiver>, } impl + 'static> WorkerTraining { pub fn start(processor: P, rec: Receiver>) { let mut worker = Self { processor, rec }; std::thread::Builder::new() .name("train-worker".into()) .spawn(move || { while let Ok(msg) = worker.rec.recv_blocking() { match msg { Message::Train(event) => worker.processor.process_train(event), Message::Valid(event) => worker.processor.process_valid(event), Message::Renderer(callback) => { callback.send_blocking(worker.processor.renderer()).unwrap(); return; } } } }) .unwrap(); } } impl WorkerEvaluation

{ pub fn start(processor: P, rec: Receiver>) { let mut worker = Self { processor, rec }; std::thread::Builder::new() .name("evel-worker".into()) .spawn(move || { while let Ok(event) = worker.rec.recv_blocking() { match event { EvalMessage::Test(event) => worker.processor.process_test(event), EvalMessage::Renderer(sender) => { sender.send_blocking(worker.processor.renderer()).unwrap(); return; } } } }) .unwrap(); } } impl AsyncProcessorTraining { /// Create an event processor for training. pub fn new + 'static>(processor: P) -> Self { let (sender, rec) = async_channel::bounded(1); WorkerTraining::start(processor, rec); Self { sender } } } impl AsyncProcessorEvaluation

{ /// Create an event processor for model evaluation. pub fn new(processor: P) -> Self { let (sender, rec) = async_channel::bounded(1); WorkerEvaluation::start(processor, rec); Self { sender } } } enum Message { Train(EventTrain), Valid(EventValid), Renderer(Sender>), } enum EvalMessage { Test(EvaluatorEvent), Renderer(Sender>), } impl EventProcessorTraining for AsyncProcessorTraining { fn process_train(&mut self, event: ET) { self.sender.send_blocking(Message::Train(event)).unwrap(); } fn process_valid(&mut self, event: EV) { self.sender.send_blocking(Message::Valid(event)).unwrap(); } fn renderer(self) -> Box { let (sender, rec) = async_channel::bounded(1); self.sender .send_blocking(Message::Renderer(sender)) .unwrap(); match rec.recv_blocking() { Ok(value) => value, Err(err) => panic!("{err:?}"), } } } impl EventProcessorEvaluation for AsyncProcessorEvaluation

{ type ItemTest = P::ItemTest; fn process_test(&mut self, event: EvaluatorEvent) { self.sender.send_blocking(EvalMessage::Test(event)).unwrap(); } fn renderer(self) -> Box { let (sender, rec) = async_channel::bounded(1); self.sender .send_blocking(EvalMessage::Renderer(sender)) .unwrap(); match rec.recv_blocking() { Ok(value) => value, Err(err) => panic!("{err:?}"), } } } ================================================ FILE: crates/burn-train/src/metric/processor/base.rs ================================================ use burn_core::data::dataloader::Progress; use burn_optim::LearningRate; use crate::{ LearnerSummary, renderer::{EvaluationName, MetricsRenderer}, }; /// Event happening during the training/validation process. pub enum LearnerEvent { /// Signal the start of the process (e.g., training start) Start, /// Signal that an item have been processed. ProcessedItem(TrainingItem), /// Signal the end of an epoch. EndEpoch(usize), /// Signal the end of the process (e.g., training end). End(Option), } /// Event happening during the evaluation process. pub enum EvaluatorEvent { /// Signal the start of the process (e.g., evaluation start) Start, /// Signal that an item have been processed. ProcessedItem(EvaluationName, EvaluationItem), /// Signal the end of the process (e.g., evaluation end). End(Option), } /// Items that are lazy are not ready to be processed by metrics. /// /// We want to sync them on a different thread to avoid blocking training. pub trait ItemLazy: Send { /// Item that is properly synced and ready to be processed by metrics. type ItemSync: Send; /// Sync the item. fn sync(self) -> Self::ItemSync; } /// Process events happening during training and validation. pub trait EventProcessorTraining: Send { /// Collect a training event. fn process_train(&mut self, event: TrainEvent); /// Collect a validation event. fn process_valid(&mut self, event: ValidEvent); /// Returns the renderer used for training. fn renderer(self) -> Box; } /// Process events happening during evaluation. pub trait EventProcessorEvaluation: Send { /// The test item. type ItemTest: ItemLazy; /// Collect a test event. fn process_test(&mut self, event: EvaluatorEvent); /// Returns the renderer used for evaluation. fn renderer(self) -> Box; } /// A learner item. #[derive(new)] pub struct TrainingItem { /// The item. pub item: T, /// The progress. pub progress: Progress, /// The global progress of the training (e.g. epochs). pub global_progress: Progress, /// The iteration, if it it different from the items processed. pub iteration: Option, /// The learning rate. pub lr: Option, } impl ItemLazy for TrainingItem { type ItemSync = TrainingItem; fn sync(self) -> Self::ItemSync { TrainingItem { item: self.item.sync(), progress: self.progress, global_progress: self.global_progress, iteration: self.iteration, lr: self.lr, } } } /// An evaluation item. #[derive(new)] pub struct EvaluationItem { /// The item. pub item: T, /// The progress. pub progress: Progress, /// The iteration, if it it different from the items processed. pub iteration: Option, } impl ItemLazy for EvaluationItem { type ItemSync = EvaluationItem; fn sync(self) -> Self::ItemSync { EvaluationItem { item: self.item.sync(), progress: self.progress, iteration: self.iteration, } } } impl ItemLazy for () { type ItemSync = (); fn sync(self) -> Self::ItemSync {} } ================================================ FILE: crates/burn-train/src/metric/processor/full.rs ================================================ use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining}; use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation}; use crate::metric::store::{EpochSummary, EventStoreClient, Split}; use crate::renderer::{ EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress, }; use std::sync::Arc; /// An [event processor](EventProcessorTraining) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). /// - Render metrics using a [metrics renderer](MetricsRenderer). pub struct FullEventProcessorTraining { metrics: MetricsTraining, renderer: Box, store: Arc, } /// An [event processor](EventProcessorEvaluation) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). /// - Render metrics using a [metrics renderer](MetricsRenderer). pub struct FullEventProcessorEvaluation { metrics: MetricsEvaluation, renderer: Box, store: Arc, } impl FullEventProcessorTraining { pub(crate) fn new( metrics: MetricsTraining, renderer: Box, store: Arc, ) -> Self { Self { metrics, renderer, store, } } fn progress_indicators(&self, progress: &TrainingProgress) -> Vec { let mut indicators = vec![]; indicators.push(ProgressType::Detailed { tag: String::from("Epoch"), progress: progress.global_progress.clone(), }); if let Some(iteration) = progress.iteration { indicators.push(ProgressType::Value { tag: String::from("Iteration"), value: iteration, }); }; if let Some(p) = &progress.progress { indicators.push(ProgressType::Detailed { tag: String::from("Items"), progress: p.clone(), }); }; indicators } } impl FullEventProcessorEvaluation { pub(crate) fn new( metrics: MetricsEvaluation, renderer: Box, store: Arc, ) -> Self { Self { metrics, renderer, store, } } fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec { let mut indicators = vec![]; if let Some(iteration) = progress.iteration { indicators.push(ProgressType::Value { tag: String::from("Iteration"), value: iteration, }); }; indicators.push(ProgressType::Detailed { tag: String::from("Items"), progress: progress.progress.clone(), }); indicators } } impl EventProcessorEvaluation for FullEventProcessorEvaluation { type ItemTest = T; fn process_test(&mut self, event: EvaluatorEvent) { match event { EvaluatorEvent::Start => { let definitions = self.metrics.metric_definitions(); self.store .add_event_train(crate::metric::store::Event::MetricsInit( definitions.clone(), )); definitions .iter() .for_each(|definition| self.renderer.register_metric(definition.clone())); } EvaluatorEvent::ProcessedItem(name, item) => { let item = item.sync(); let progress = (&item).into(); let metadata = (&item).into(); let update = self.metrics.update_test(&item, &metadata); self.store.add_event_test( crate::metric::store::Event::MetricsUpdate(update.clone()), name.name.clone(), ); update.entries.into_iter().for_each(|entry| { self.renderer .update_test(name.clone(), MetricState::Generic(entry)) }); update .entries_numeric .into_iter() .for_each(|numeric_update| { self.renderer.update_test( name.clone(), MetricState::Numeric( numeric_update.entry, numeric_update.numeric_entry, ), ) }); let indicators = self.progress_indicators(&progress); self.renderer.render_test(progress, indicators); } EvaluatorEvent::End(summary) => { self.renderer.on_test_end(summary).ok(); } } } fn renderer(self) -> Box { self.renderer } } impl EventProcessorTraining, LearnerEvent> for FullEventProcessorTraining { fn process_train(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => { let definitions = self.metrics.metric_definitions(); self.store .add_event_train(crate::metric::store::Event::MetricsInit( definitions.clone(), )); definitions .iter() .for_each(|definition| self.renderer.register_metric(definition.clone())); } LearnerEvent::ProcessedItem(item) => { let item = item.sync(); let progress = (&item).into(); let metadata = (&item).into(); let update = self.metrics.update_train(&item, &metadata); self.store .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); update .entries .into_iter() .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); update .entries_numeric .into_iter() .for_each(|numeric_update| { self.renderer.update_train(MetricState::Numeric( numeric_update.entry, numeric_update.numeric_entry, )) }); let indicators = self.progress_indicators(&progress); self.renderer.render_train(progress, indicators); } LearnerEvent::EndEpoch(epoch) => { self.store .add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new( epoch, Split::Train, ))); self.metrics.end_epoch_train(); } LearnerEvent::End(summary) => { self.renderer.on_train_end(summary).ok(); } } } fn process_valid(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => {} // no-op for now LearnerEvent::ProcessedItem(item) => { let item = item.sync(); let progress = (&item).into(); let metadata = (&item).into(); let update = self.metrics.update_valid(&item, &metadata); self.store .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); update .entries .into_iter() .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); update .entries_numeric .into_iter() .for_each(|numeric_update| { self.renderer.update_valid(MetricState::Numeric( numeric_update.entry, numeric_update.numeric_entry, )) }); let indicators = self.progress_indicators(&progress); self.renderer.render_valid(progress, indicators); } LearnerEvent::EndEpoch(epoch) => { self.store .add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new( epoch, Split::Valid, ))); self.metrics.end_epoch_valid(); } LearnerEvent::End(_) => {} // no-op for now } } fn renderer(self) -> Box { self.renderer } } ================================================ FILE: crates/burn-train/src/metric/processor/metrics.rs ================================================ use std::collections::HashMap; use super::{ItemLazy, TrainingItem}; use crate::{ EvaluationItem, metric::{ Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric, store::{MetricsUpdate, NumericMetricUpdate}, }, renderer::{EvaluationProgress, TrainingProgress}, }; pub(crate) struct MetricsTraining { train: Vec>>, valid: Vec>>, train_numeric: Vec>>, valid_numeric: Vec>>, metric_definitions: HashMap, } pub(crate) struct MetricsEvaluation { test: Vec>>, test_numeric: Vec>>, metric_definitions: HashMap, } impl Default for MetricsEvaluation { fn default() -> Self { Self { test: Default::default(), test_numeric: Default::default(), metric_definitions: HashMap::default(), } } } impl Default for MetricsTraining { fn default() -> Self { Self { train: Vec::default(), valid: Vec::default(), train_numeric: Vec::default(), valid_numeric: Vec::default(), metric_definitions: HashMap::default(), } } } impl MetricsEvaluation { /// Register a testing metric. pub(crate) fn register_test_metric(&mut self, metric: Me) where T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.test.push(Box::new(metric)) } /// Register a numeric testing metric. pub(crate) fn register_test_metric_numeric( &mut self, metric: Me, ) where T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.test_numeric.push(Box::new(metric)) } fn register_definition(&mut self, metric: &MetricWrapper) { self.metric_definitions.insert( metric.id.clone(), MetricDefinition::new(metric.id.clone(), &metric.metric), ); } /// Get metric definitions. pub(crate) fn metric_definitions(&mut self) -> Vec { self.metric_definitions.values().cloned().collect() } /// Update the testing information from the testing item. pub(crate) fn update_test( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.test.len()); let mut entries_numeric = Vec::with_capacity(self.test_numeric.len()); for metric in self.test.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.test_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } } impl MetricsTraining { /// Register a training metric. pub(crate) fn register_train_metric(&mut self, metric: Me) where T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.train.push(Box::new(metric)) } /// Register a validation metric. pub(crate) fn register_valid_metric(&mut self, metric: Me) where V::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.valid.push(Box::new(metric)) } /// Register a numeric training metric. pub(crate) fn register_train_metric_numeric( &mut self, metric: Me, ) where T::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.train_numeric.push(Box::new(metric)) } /// Register a numeric validation metric. pub(crate) fn register_valid_metric_numeric(&mut self, metric: Me) where V::ItemSync: Adaptor + 'static, Me: Metric + Numeric + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.valid_numeric.push(Box::new(metric)) } fn register_definition(&mut self, metric: &MetricWrapper) { self.metric_definitions.insert( metric.id.clone(), MetricDefinition::new(metric.id.clone(), &metric.metric), ); } /// Get metric definitions for all splits pub(crate) fn metric_definitions(&mut self) -> Vec { self.metric_definitions.values().cloned().collect() } /// Update the training information from the training item. pub(crate) fn update_train( &mut self, item: &TrainingItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.train.len()); let mut entries_numeric = Vec::with_capacity(self.train_numeric.len()); for metric in self.train.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.train_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Update the training information from the validation item. pub(crate) fn update_valid( &mut self, item: &TrainingItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.valid.len()); let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len()); for metric in self.valid.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.valid_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Signal the end of a training epoch. pub(crate) fn end_epoch_train(&mut self) { for metric in self.train.iter_mut() { metric.clear(); } for metric in self.train_numeric.iter_mut() { metric.clear(); } } /// Signal the end of a validation epoch. pub(crate) fn end_epoch_valid(&mut self) { for metric in self.valid.iter_mut() { metric.clear(); } for metric in self.valid_numeric.iter_mut() { metric.clear(); } } } impl From<&TrainingItem> for TrainingProgress { fn from(item: &TrainingItem) -> Self { Self { progress: Some(item.progress.clone()), global_progress: item.global_progress.clone(), iteration: item.iteration, } } } impl From<&EvaluationItem> for TrainingProgress { fn from(item: &EvaluationItem) -> Self { Self { progress: None, global_progress: item.progress.clone(), iteration: item.iteration, } } } impl From<&EvaluationItem> for EvaluationProgress { fn from(item: &EvaluationItem) -> Self { Self { progress: item.progress.clone(), iteration: item.iteration, } } } impl From<&TrainingItem> for MetricMetadata { fn from(item: &TrainingItem) -> Self { Self { progress: item.progress.clone(), global_progress: item.global_progress.clone(), iteration: item.iteration, lr: item.lr, } } } impl From<&EvaluationItem> for MetricMetadata { fn from(item: &EvaluationItem) -> Self { Self { progress: item.progress.clone(), global_progress: item.progress.clone(), iteration: item.iteration, lr: None, } } } pub(crate) trait NumericMetricUpdater: Send + Sync { fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate; fn clear(&mut self); } pub(crate) trait MetricUpdater: Send + Sync { fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry; fn clear(&mut self); } pub(crate) struct MetricWrapper { pub id: MetricId, pub metric: M, } impl MetricWrapper { pub fn new(metric: M) -> Self { Self { id: MetricId::new(metric.name()), metric, } } } impl NumericMetricUpdater for MetricWrapper where T: 'static, M: Metric + Numeric + 'static, T: Adaptor, { fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate { let serialized_entry = self.metric.update(&item.adapt(), metadata); let update = MetricEntry::new(self.id.clone(), serialized_entry); let numeric = self.metric.value(); let running = self.metric.running_value(); NumericMetricUpdate { entry: update, numeric_entry: numeric, running_entry: running, } } fn clear(&mut self) { self.metric.clear() } } impl MetricUpdater for MetricWrapper where T: 'static, M: Metric + 'static, T: Adaptor, { fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry { let serialized_entry = self.metric.update(&item.adapt(), metadata); MetricEntry::new(self.id.clone(), serialized_entry) } fn clear(&mut self) { self.metric.clear() } } ================================================ FILE: crates/burn-train/src/metric/processor/minimal.rs ================================================ use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining}; use crate::{ metric::store::{EpochSummary, EventStoreClient, Split}, renderer::cli::CliMetricsRenderer, }; use std::sync::Arc; /// An [event processor](EventProcessor) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). #[allow(dead_code)] #[derive(new)] pub(crate) struct MinimalEventProcessor { metrics: MetricsTraining, store: Arc, } impl EventProcessorTraining, LearnerEvent> for MinimalEventProcessor { fn process_train(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => { let definitions = self.metrics.metric_definitions(); self.store .add_event_train(crate::metric::store::Event::MetricsInit(definitions)); } LearnerEvent::ProcessedItem(item) => { let item = item.sync(); let metadata = (&item).into(); let update = self.metrics.update_train(&item, &metadata); self.store .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); } LearnerEvent::EndEpoch(epoch) => { self.metrics.end_epoch_train(); self.store .add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new( epoch, Split::Train, ))); } LearnerEvent::End(_summary) => {} // no-op for now } } fn process_valid(&mut self, event: LearnerEvent) { match event { LearnerEvent::Start => {} // no-op for now LearnerEvent::ProcessedItem(item) => { let item = item.sync(); let metadata = (&item).into(); let update = self.metrics.update_valid(&item, &metadata); self.store .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); } LearnerEvent::EndEpoch(epoch) => { self.metrics.end_epoch_valid(); self.store .add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new( epoch, Split::Valid, ))); } LearnerEvent::End(_) => {} // no-op for now } } fn renderer(self) -> Box { // TODO: Check for another default. Box::new(CliMetricsRenderer::new()) } } ================================================ FILE: crates/burn-train/src/metric/processor/mod.rs ================================================ mod async_wrapper; mod base; mod full; mod metrics; mod minimal; #[cfg(feature = "rl")] mod rl_metrics; #[cfg(feature = "rl")] mod rl_processor; pub use base::*; pub(crate) use full::*; pub(crate) use metrics::*; #[cfg(feature = "rl")] pub(crate) use rl_metrics::*; #[cfg(feature = "rl")] pub(crate) use rl_processor::*; #[cfg(test)] pub(crate) use minimal::*; pub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining}; #[cfg(test)] pub(crate) mod test_utils { use crate::metric::{ Adaptor, LossInput, processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem}, }; use burn_core::tensor::{ElementConversion, Tensor, backend::Backend}; use super::ItemLazy; impl ItemLazy for f64 { type ItemSync = f64; fn sync(self) -> Self::ItemSync { self } } impl Adaptor> for f64 { fn adapt(&self) -> LossInput { let device = B::Device::default(); LossInput::new(Tensor::from_data([self.elem::()], &device)) } } pub(crate) fn process_train( processor: &mut MinimalEventProcessor, value: f64, epoch: usize, ) { let dummy_progress = burn_core::data::dataloader::Progress { items_processed: 1, items_total: 10, }; let dummy_global_progress = burn_core::data::dataloader::Progress { items_processed: epoch, items_total: 3, }; let dummy_iteration = Some(1); processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new( value, dummy_progress, dummy_global_progress, dummy_iteration, None, ))); } pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { processor.process_train(LearnerEvent::EndEpoch(epoch)); processor.process_valid(LearnerEvent::EndEpoch(epoch)); } } ================================================ FILE: crates/burn-train/src/metric/processor/rl_metrics.rs ================================================ use std::collections::HashMap; use crate::{ EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater, metric::{ Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate, }, }; pub(crate) struct RLMetrics { train_step: Vec>>, env_step: Vec>>, env_step_valid: Vec>>, episode_end: Vec>>, episode_end_valid: Vec>>, train_step_numeric: Vec>>, env_step_numeric: Vec>>, env_step_valid_numeric: Vec>>, episode_end_numeric: Vec>>, episode_end_valid_numeric: Vec>>, metric_definitions: HashMap, } impl Default for RLMetrics { fn default() -> Self { Self { train_step: Vec::default(), env_step: Vec::default(), env_step_valid: Vec::default(), episode_end: Vec::default(), episode_end_valid: Vec::default(), train_step_numeric: Vec::default(), env_step_numeric: Vec::default(), env_step_valid_numeric: Vec::default(), episode_end_numeric: Vec::default(), episode_end_valid_numeric: Vec::default(), metric_definitions: HashMap::default(), } } } impl RLMetrics { /// Register a training metric. pub(crate) fn register_text_metric_agent(&mut self, metric: Me) where ES::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.env_step.push(Box::new(metric)) } /// Register a training metric. pub(crate) fn register_agent_metric(&mut self, metric: Me) where ES::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.env_step_numeric.push(Box::new(metric)) } /// Register a training metric. pub(crate) fn register_text_metric_train(&mut self, metric: Me) where TS::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.train_step.push(Box::new(metric)) } /// Register a training metric. pub(crate) fn register_metric_train(&mut self, metric: Me) where TS::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.train_step_numeric.push(Box::new(metric)) } /// Register a validation env-step metric. pub(crate) fn register_text_metric_agent_valid(&mut self, metric: Me) where ES::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.env_step_valid.push(Box::new(metric)) } /// Register a validation env-step numeric metric. pub(crate) fn register_agent_metric_valid(&mut self, metric: Me) where ES::ItemSync: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.env_step_valid_numeric.push(Box::new(metric)) } /// Register an episode-end metric. pub(crate) fn register_text_metric_episode(&mut self, metric: Me) where EpisodeSummary: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.episode_end.push(Box::new(metric)) } /// Register an episode-end numeric metric. pub(crate) fn register_episode_metric(&mut self, metric: Me) where EpisodeSummary: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.episode_end_numeric.push(Box::new(metric)) } /// Register an episode-end metric for validation. pub(crate) fn register_text_metric_episode_valid(&mut self, metric: Me) where EpisodeSummary: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.episode_end_valid.push(Box::new(metric)) } /// Register an episode-end numeric metric for validation. pub(crate) fn register_episode_metric_valid( &mut self, metric: Me, ) where EpisodeSummary: Adaptor + 'static, { let metric = MetricWrapper::new(metric); self.register_definition(&metric); self.episode_end_valid_numeric.push(Box::new(metric)) } fn register_definition(&mut self, metric: &MetricWrapper) { self.metric_definitions.insert( metric.id.clone(), MetricDefinition::new(metric.id.clone(), &metric.metric), ); } /// Get metric definitions for all splits pub(crate) fn metric_definitions(&mut self) -> Vec { self.metric_definitions.values().cloned().collect() } /// Update the training information from the training item. pub(crate) fn update_train_step( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.train_step.len()); let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len()); for metric in self.train_step.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.train_step_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Update the env-step metrics from an environment step item. pub(crate) fn update_env_step( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.env_step.len()); let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len()); for metric in self.env_step.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.env_step_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Update the env-step metrics for validation from an environment step item. pub(crate) fn update_env_step_valid( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.env_step_valid.len()); let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len()); for metric in self.env_step_valid.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.env_step_valid_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Update the episode-end metrics from an episode summary. pub(crate) fn update_episode_end( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.episode_end.len()); let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len()); for metric in self.episode_end.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.episode_end_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } /// Update the episode-end metrics for validation from an episode summary. pub(crate) fn update_episode_end_valid( &mut self, item: &EvaluationItem, metadata: &MetricMetadata, ) -> MetricsUpdate { let mut entries = Vec::with_capacity(self.episode_end_valid.len()); let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len()); for metric in self.episode_end_valid.iter_mut() { let state = metric.update(&item.item, metadata); entries.push(state); } for metric in self.episode_end_valid_numeric.iter_mut() { let numeric_update = metric.update(&item.item, metadata); entries_numeric.push(numeric_update); } MetricsUpdate::new(entries, entries_numeric) } } ================================================ FILE: crates/burn-train/src/metric/processor/rl_processor.rs ================================================ use std::sync::Arc; use crate::{ EpisodeSummary, EvaluationItem, EventProcessorTraining, ItemLazy, LearnerSummary, RLMetrics, metric::store::{Event, EventStoreClient, MetricsUpdate}, renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress}, }; /// Event happening during reinforcement learning. pub enum RLEvent { /// Signal the start of the process (e.g., learning starts). Start, /// Signal an agent's training step. TrainStep(EvaluationItem), /// Signal a timestep of the agent-environment interface. TimeStep(EvaluationItem), /// Signal an episode end. EpisodeEnd(EvaluationItem), /// Signal the end of the process (e.g., learning ends). End(Option), } /// Event happening during evaluation of a reinforcement learning's agent. pub enum AgentEvaluationEvent { /// Signal the start of the process (e.g., training start) Start, /// Signal a timestep of the agent-environment interface. TimeStep(EvaluationItem), /// Signal an episode end. EpisodeEnd(EvaluationItem), /// Signal the end of the process (e.g., training end). End, } /// An [event processor](EventProcessorTraining) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). /// - Render metrics using a [metrics renderer](MetricsRenderer). #[derive(new)] pub struct RLEventProcessor { metrics: RLMetrics, renderer: Box, store: Arc, } impl RLEventProcessor { fn progress_indicators(&self, progress: &TrainingProgress) -> Vec { let indicators = vec![ProgressType::Detailed { tag: String::from("Step"), progress: progress.global_progress.clone(), }]; indicators } fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec { let indicators = vec![ProgressType::Detailed { tag: String::from("Step"), progress: progress.global_progress.clone(), }]; indicators } } impl RLEventProcessor { fn process_update_train(&mut self, update: MetricsUpdate) { self.store .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); update .entries .into_iter() .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); update .entries_numeric .into_iter() .for_each(|numeric_update| { self.renderer.update_train(MetricState::Numeric( numeric_update.entry, numeric_update.numeric_entry, )) }); } fn process_update_valid(&mut self, update: MetricsUpdate) { self.store .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); update .entries .into_iter() .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); update .entries_numeric .into_iter() .for_each(|numeric_update| { self.renderer.update_valid(MetricState::Numeric( numeric_update.entry, numeric_update.numeric_entry, )) }); } } impl EventProcessorTraining, AgentEvaluationEvent> for RLEventProcessor { fn process_train(&mut self, event: RLEvent) { match event { RLEvent::Start => { let definitions = self.metrics.metric_definitions(); self.store .add_event_train(Event::MetricsInit(definitions.clone())); definitions .iter() .for_each(|definition| self.renderer.register_metric(definition.clone())); } RLEvent::TrainStep(item) => { let item = item.sync(); let metadata = (&item).into(); let update = self.metrics.update_train_step(&item, &metadata); self.process_update_train(update); } RLEvent::TimeStep(item) => { let item = item.sync(); let progress = (&item).into(); let metadata = (&item).into(); let update = self.metrics.update_env_step(&item, &metadata); self.process_update_train(update); let status = self.progress_indicators(&progress); self.renderer.render_train(progress, status); } RLEvent::EpisodeEnd(item) => { let item = item.sync(); let metadata = (&item).into(); let update = self.metrics.update_episode_end(&item, &metadata); self.process_update_train(update); } RLEvent::End(learner_summary) => { self.renderer.on_train_end(learner_summary).ok(); } } } fn process_valid(&mut self, event: AgentEvaluationEvent) { match event { AgentEvaluationEvent::Start => {} // no-op for now AgentEvaluationEvent::TimeStep(item) => { let item = item.sync(); let metadata = (&item).into(); let update = self.metrics.update_env_step_valid(&item, &metadata); self.process_update_valid(update); } AgentEvaluationEvent::EpisodeEnd(item) => { let item = item.sync(); let progress = (&item).into(); let metadata = (&item).into(); let update = self.metrics.update_episode_end_valid(&item, &metadata); self.process_update_valid(update); let status = self.progress_indicators_eval(&progress); self.renderer.render_valid(progress, status); } AgentEvaluationEvent::End => {} // no-op for now } } fn renderer(self) -> Box { self.renderer } } ================================================ FILE: crates/burn-train/src/metric/recall.rs ================================================ use crate::metric::{MetricName, Numeric}; use super::{ Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry, classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::cast::ToElement, }; use core::marker::PhantomData; use std::{num::NonZeroUsize, sync::Arc}; ///The Recall Metric #[derive(Clone)] pub struct RecallMetric { name: MetricName, state: NumericMetricState, _b: PhantomData, config: ClassificationMetricConfig, } impl Default for RecallMetric { fn default() -> Self { Self::new(Default::default()) } } impl RecallMetric { fn new(config: ClassificationMetricConfig) -> Self { let state = Default::default(); let name = Arc::new(format!( "Recall @ {:?} [{:?}]", config.decision_rule, config.class_reduction )); Self { state, config, name, _b: Default::default(), } } /// Recall metric for binary classification. /// /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary prediction. #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self::new(ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), // binary classification results are the same independently of class_reduction ..Default::default() }) } /// Recall metric for multiclass classification. /// /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self::new(ClassificationMetricConfig { decision_rule: DecisionRule::TopK( NonZeroUsize::new(top_k).expect("top_k must be non-zero"), ), class_reduction, }) } /// Recall metric for multi-label classification. /// /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary prediction. /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self::new(ClassificationMetricConfig { decision_rule: DecisionRule::Threshold(threshold), class_reduction, }) } fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { use ClassReduction::{Macro, Micro}; let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric .clone() .contains_nan() .any() .into_scalar() .to_bool() { let nan_mask = aggregated_metric.clone().is_nan(); aggregated_metric = aggregated_metric .clone() .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1)) } aggregated_metric.mean() } }; avg_tensor.into_scalar().to_f64() } } impl Metric for RecallMetric { type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let [sample_size, _] = input.predictions.dims(); let cf_stats = ConfusionStats::new(input, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive()); self.state.update( 100.0 * metric, sample_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for RecallMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::{ ClassReduction::{self, *}, Metric, MetricMetadata, RecallMetric, }; use crate::metric::Numeric; use crate::{ TestBackend, tests::{ClassificationType, THRESHOLD, dummy_classification_input}, }; use burn_core::tensor::{TensorData, Tolerance}; use rstest::rstest; #[rstest] #[case::binary(THRESHOLD, 0.5)] fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = RecallMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)] #[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)] #[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)] #[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)] fn test_multiclass_recall( #[case] class_reduction: ClassReduction, #[case] top_k: usize, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multiclass).into(); let mut metric = RecallMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[rstest] #[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)] #[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)] fn test_multilabel_recall( #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); let mut metric = RecallMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value().current()]) .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()) } #[test] fn test_parameterized_unique_name() { let metric_a = RecallMetric::::multiclass(1, ClassReduction::Macro); let metric_b = RecallMetric::::multiclass(2, ClassReduction::Macro); let metric_c = RecallMetric::::multiclass(1, ClassReduction::Macro); assert_ne!(metric_a.name(), metric_b.name()); assert_eq!(metric_a.name(), metric_c.name()); let metric_a = RecallMetric::::binary(0.5); let metric_b = RecallMetric::::binary(0.75); assert_ne!(metric_a.name(), metric_b.name()); } } ================================================ FILE: crates/burn-train/src/metric/rl/cum_reward.rs ================================================ use std::sync::Arc; use super::super::{ MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, state::{FormatOptions, NumericMetricState}, }; use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; /// Metric for the cumulative reward of the last completed episode. #[derive(Clone)] pub struct CumulativeRewardMetric { name: MetricName, state: NumericMetricState, } impl CumulativeRewardMetric { /// Creates a new episode length metric. pub fn new() -> Self { Self { name: Arc::new("Cum. Reward".to_string()), state: NumericMetricState::new(), } } } impl Default for CumulativeRewardMetric { fn default() -> Self { Self::new() } } /// The [CumulativeRewardMetric](CumulativeRewardMetric) input type. #[derive(new)] pub struct CumulativeRewardInput { cum_reward: f64, } impl Metric for CumulativeRewardMetric { type Input = CumulativeRewardInput; fn update( &mut self, item: &CumulativeRewardInput, _metadata: &MetricMetadata, ) -> SerializedEntry { self.state.update( item.cum_reward, 1, FormatOptions::new(self.name()).precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: true, } .into() } } impl Numeric for CumulativeRewardMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/rl/ep_len.rs ================================================ use std::sync::Arc; use super::super::{ MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, state::{FormatOptions, NumericMetricState}, }; use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; /// Metric for the length of the last completed episode. #[derive(Clone)] pub struct EpisodeLengthMetric { name: MetricName, state: NumericMetricState, } impl EpisodeLengthMetric { /// Creates a new episode length metric. pub fn new() -> Self { Self { name: Arc::new("Episode length".to_string()), state: NumericMetricState::new(), } } } impl Default for EpisodeLengthMetric { fn default() -> Self { Self::new() } } /// The [EpisodeLengthMetric](EpisodeLengthMetric) input type. #[derive(new)] pub struct EpisodeLengthInput { ep_len: f64, } impl Metric for EpisodeLengthMetric { type Input = EpisodeLengthInput; fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry { self.state .update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0)) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some(String::from("steps")), higher_is_better: true, } .into() } } impl Numeric for EpisodeLengthMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/rl/exploration_rate.rs ================================================ use std::sync::Arc; use super::super::{ MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, state::{FormatOptions, NumericMetricState}, }; use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; /// Metric for the length of the last completed episode. #[derive(Clone)] pub struct ExplorationRateMetric { name: MetricName, state: NumericMetricState, } impl ExplorationRateMetric { /// Creates a new episode length metric. pub fn new() -> Self { Self { name: Arc::new("Exploration rate".to_string()), state: NumericMetricState::new(), } } } impl Default for ExplorationRateMetric { fn default() -> Self { Self::new() } } /// The [ExplorationRateMetric](ExplorationRateMetric) input type. #[derive(new)] pub struct ExplorationRateInput { exploration_rate: f64, } impl Metric for ExplorationRateMetric { type Input = ExplorationRateInput; fn update( &mut self, item: &ExplorationRateInput, _metadata: &MetricMetadata, ) -> SerializedEntry { self.state.update( item.exploration_rate, 1, FormatOptions::new(self.name()).precision(3), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some(String::from("%")), higher_is_better: false, } .into() } } impl Numeric for ExplorationRateMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } ================================================ FILE: crates/burn-train/src/metric/rl/mod.rs ================================================ mod cum_reward; mod ep_len; mod exploration_rate; pub use cum_reward::*; pub use ep_len::*; pub use exploration_rate::*; ================================================ FILE: crates/burn-train/src/metric/state.rs ================================================ use std::sync::Arc; use crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float}; /// Useful utility to implement numeric metrics. /// /// # Notes /// /// The numeric metric store values inside floats. /// Even if some metric are integers, their mean are floats. #[derive(Clone)] pub struct NumericMetricState { sum: f64, count: usize, current: f64, current_count: usize, } /// Formatting options for the [numeric metric state](NumericMetricState). pub struct FormatOptions { name: Arc, unit: Option, precision: Option, } impl FormatOptions { /// Create the [formatting options](FormatOptions) with a name. pub fn new(name: MetricName) -> Self { Self { name: name.clone(), unit: None, precision: None, } } /// Specify the metric unit. pub fn unit(mut self, unit: &str) -> Self { self.unit = Some(unit.to_string()); self } /// Specify the floating point precision. pub fn precision(mut self, precision: usize) -> Self { self.precision = Some(precision); self } /// Get the metric name. pub fn name(&self) -> &Arc { &self.name } /// Get the metric unit. pub fn unit_value(&self) -> &Option { &self.unit } /// Get the precision. pub fn precision_value(&self) -> Option { self.precision } } impl NumericMetricState { /// Create a new [numeric metric state](NumericMetricState). pub fn new() -> Self { Self { sum: 0.0, count: 0, current: f64::NAN, current_count: 0, } } /// Reset the state. pub fn reset(&mut self) { self.sum = 0.0; self.count = 0; self.current = f64::NAN; self.current_count = 0; } /// Update the state. pub fn update( &mut self, value: f64, batch_size: usize, format: FormatOptions, ) -> SerializedEntry { self.sum += value * batch_size as f64; self.count += batch_size; self.current = value; self.current_count = batch_size; let value_current = value; let value_running = self.sum / self.count as f64; // Numeric metric state is an aggregated value let serialized = NumericEntry::Aggregated { aggregated_value: value_current, count: batch_size, } .serialize(); let (formatted_current, formatted_running) = match format.precision { Some(precision) => ( format_float(value_current, precision), format_float(value_running, precision), ), None => (format!("{value_current}"), format!("{value_running}")), }; // TODO: naming inconsistent with RL. let formatted = match format.unit { Some(unit) => { format!("epoch {formatted_running} {unit} - batch {formatted_current} {unit}") } None => format!("epoch {formatted_running} - batch {formatted_current}"), }; SerializedEntry::new(formatted, serialized) } /// Get the numeric value. pub fn current_value(&self) -> NumericEntry { NumericEntry::Aggregated { aggregated_value: self.current, count: self.current_count, } } /// Get the running aggregated value. pub fn running_value(&self) -> NumericEntry { NumericEntry::Aggregated { aggregated_value: self.sum / self.count as f64, count: self.count, } } } impl Default for NumericMetricState { fn default() -> Self { Self::new() } } ================================================ FILE: crates/burn-train/src/metric/store/aggregate.rs ================================================ use crate::{ logger::MetricLogger, metric::{NumericEntry, store::Split}, }; use std::collections::HashMap; use super::{Aggregate, Direction}; /// Type that can be used to fetch and use numeric metric aggregates. #[derive(Default, Debug)] pub(crate) struct NumericMetricsAggregate { value_for_each_epoch: HashMap, } #[derive(new, Hash, PartialEq, Eq, Debug)] struct Key { name: String, epoch: usize, split: Split, aggregate: Aggregate, } impl NumericMetricsAggregate { pub(crate) fn aggregate( &mut self, name: &str, epoch: usize, split: &Split, aggregate: Aggregate, loggers: &mut [Box], ) -> Option { let key = Key::new(name.to_string(), epoch, split.clone(), aggregate); if let Some(value) = self.value_for_each_epoch.get(&key) { return Some(*value); } let points = || { let mut errors = Vec::new(); for logger in loggers { match logger.read_numeric(name, epoch, split) { Ok(points) => return Ok(points), Err(err) => errors.push(err), }; } Err(errors.join(" ")) }; let points = points().expect("Can read values"); if points.is_empty() { return None; } // Accurately compute the aggregated value based on the *actual* number of points // since not all mini-batches are guaranteed to have the specified batch size let (sum, num_points) = points .into_iter() .map(|entry| match entry { NumericEntry::Value(v) => (v, 1), // Right now the mean is the only aggregate available, so we can assume that the sum // of an entry corresponds to (value * number of elements) NumericEntry::Aggregated { aggregated_value, count, } => (aggregated_value * count as f64, count), }) .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n)) .unwrap(); let value = match aggregate { Aggregate::Mean => sum / num_points as f64, }; self.value_for_each_epoch.insert(key, value); Some(value) } pub(crate) fn find_epoch( &mut self, name: &str, split: &Split, aggregate: Aggregate, direction: Direction, loggers: &mut [Box], ) -> Option { let mut data = Vec::new(); let mut current_epoch = 1; while let Some(value) = self.aggregate(name, current_epoch, split, aggregate, loggers) { data.push(value); current_epoch += 1; } if data.is_empty() { return None; } let mut current_value = match &direction { Direction::Lowest => f64::MAX, Direction::Highest => f64::MIN, }; for (i, value) in data.into_iter().enumerate() { match &direction { Direction::Lowest => { if value < current_value { current_value = value; current_epoch = i + 1; } } Direction::Highest => { if value > current_value { current_value = value; current_epoch = i + 1; } } } } Some(current_epoch) } } #[cfg(test)] mod tests { use std::sync::Arc; use crate::{ logger::{FileMetricLogger, InMemoryMetricLogger}, metric::{MetricDefinition, MetricEntry, MetricId, SerializedEntry, store::MetricsUpdate}, }; use super::*; struct TestLogger { logger: FileMetricLogger, epoch: usize, } const NAME: &str = "test-logger"; impl TestLogger { fn new() -> Self { Self { logger: FileMetricLogger::new("/tmp"), epoch: 1, } } fn log(&mut self, num: f64) { let entry = MetricEntry::new( MetricId::new(Arc::new(NAME.into())), SerializedEntry::new(num.to_string(), num.to_string()), ); let entries = Vec::from([entry]); let metrics_update = MetricsUpdate::new(entries, vec![]); self.logger.log(metrics_update, self.epoch, &Split::Train); } fn log_definition(&mut self) { let definition = MetricDefinition { metric_id: MetricId::new(Arc::new(NAME.into())), name: NAME.into(), attributes: crate::metric::MetricAttributes::None, description: None, }; self.logger.log_metric_definition(definition); } fn new_epoch(&mut self) { self.epoch += 1; } } #[test] fn should_find_epoch() { let mut logger = TestLogger::new(); let mut aggregate = NumericMetricsAggregate::default(); logger.log_definition(); logger.log(500.); // Epoch 1 logger.log(1000.); // Epoch 1 logger.new_epoch(); logger.log(200.); // Epoch 2 logger.log(1000.); // Epoch 2 logger.new_epoch(); logger.log(10000.); // Epoch 3 let value = aggregate .find_epoch( NAME, &Split::Train, Aggregate::Mean, Direction::Lowest, &mut [Box::new(logger.logger)], ) .unwrap(); assert_eq!(value, 2); } #[test] fn should_aggregate_numeric_entry() { let mut logger = InMemoryMetricLogger::default(); let mut aggregate = NumericMetricsAggregate::default(); let metric_name = Arc::new("Loss".to_string()); let metric_id = MetricId::new(metric_name.clone()); let definition = MetricDefinition { metric_id: metric_id.clone(), name: metric_name.to_string(), attributes: crate::metric::MetricAttributes::None, description: None, }; logger.log_metric_definition(definition); // Epoch 1 let loss_1 = 0.5; let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2 let entry = MetricEntry::new( metric_id.clone(), SerializedEntry::new(loss_1.to_string(), NumericEntry::Value(loss_1).serialize()), ); let entries = Vec::from([entry]); let metrics_update = MetricsUpdate::new(entries, vec![]); logger.log(metrics_update, 1, &Split::Train); let entry = MetricEntry::new( metric_id.clone(), SerializedEntry::new( loss_2.to_string(), NumericEntry::Aggregated { aggregated_value: loss_2, count: 2, } .serialize(), ), ); let entries = Vec::from([entry]); let metrics_update = MetricsUpdate::new(entries, vec![]); logger.log(metrics_update, 1, &Split::Train); let value = aggregate .aggregate( &metric_name, 1, &Split::Train, Aggregate::Mean, &mut [Box::new(logger)], ) .unwrap(); // Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875 assert_eq!(value, 1.0); } } ================================================ FILE: crates/burn-train/src/metric/store/base.rs ================================================ use std::sync::Arc; use crate::metric::{MetricDefinition, MetricEntry, NumericEntry}; /// Event happening during the training/validation process. pub enum Event { /// Signal the iniialization of the metrics MetricsInit(Vec), /// Signal that metrics have been updated. MetricsUpdate(MetricsUpdate), /// Signal the end of an epoch. EndEpoch(EpochSummary), } /// Contains all metric information. #[derive(new, Clone, Debug)] pub struct NumericMetricUpdate { /// Generic metric information. pub entry: MetricEntry, /// The numeric information. pub numeric_entry: NumericEntry, /// Numeric value averaged over the global step (epoch). pub running_entry: NumericEntry, } /// Contains all metric information. #[derive(new, Clone, Debug)] pub struct MetricsUpdate { /// Metrics information related to non-numeric metrics. pub entries: Vec, /// Metrics information related to numeric metrics. pub entries_numeric: Vec, } /// Summary information about a given epoch #[derive(new, Clone, Debug)] pub struct EpochSummary { /// Epoch number. pub epoch_number: usize, /// Dataset split (train, valid, test). pub split: Split, } /// Defines how training and validation events are collected and searched. /// /// This trait also exposes methods that uses the collected data to compute useful information. pub trait EventStore: Send { /// Collect a training/validation event. fn add_event(&mut self, event: Event, split: Split); /// Find the epoch following the given criteria from the collected data. fn find_epoch( &mut self, name: &str, aggregate: Aggregate, direction: Direction, split: &Split, ) -> Option; /// Find the metric value for the current epoch following the given criteria. fn find_metric( &mut self, name: &str, epoch: usize, aggregate: Aggregate, split: &Split, ) -> Option; } #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] /// How to aggregate the metric. pub enum Aggregate { /// Compute the average. Mean, } #[derive(Clone, Debug, Hash, PartialEq, Eq)] /// The split to use. pub enum Split { /// The training split. Train, /// The validation split. Valid, /// The testing split, which might be tagged. Test(Option>), } impl std::fmt::Display for Split { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Split::Train => write!(f, "train"), Split::Valid => write!(f, "valid"), Split::Test(_) => write!(f, "test"), } } } #[derive(Copy, Clone)] /// The direction of the query. pub enum Direction { /// Lower is better. Lowest, /// Higher is better. Highest, } ================================================ FILE: crates/burn-train/src/metric/store/client.rs ================================================ use super::EventStore; use super::{Aggregate, Direction, Event, Split}; use std::sync::Arc; use std::{sync::mpsc, thread::JoinHandle}; /// Type that allows to communicate with an [event store](EventStore). pub struct EventStoreClient { sender: mpsc::Sender, handler: Option>, } impl EventStoreClient { /// Create a new [event store](EventStore) client. pub(crate) fn new(store: C) -> Self where C: EventStore + 'static, { let (sender, receiver) = mpsc::channel(); let thread = WorkerThread::new(store, receiver); let handler = std::thread::spawn(move || thread.run()); let handler = Some(handler); Self { sender, handler } } } impl EventStoreClient { /// Add a training event to the [event store](EventStore). pub(crate) fn add_event_train(&self, event: Event) { self.sender .send(Message::OnEventTrain(event)) .expect("Can send event to event store thread."); } /// Add a validation event to the [event store](EventStore). pub(crate) fn add_event_valid(&self, event: Event) { self.sender .send(Message::OnEventValid(event)) .expect("Can send event to event store thread."); } /// Add a testing event to the [event store](EventStore). pub(crate) fn add_event_test(&self, event: Event, tag: Arc) { self.sender .send(Message::OnEventTest(event, tag)) .expect("Can send event to event store thread."); } /// Find the epoch following the given criteria from the collected data. pub fn find_epoch( &self, name: &str, aggregate: Aggregate, direction: Direction, split: &Split, ) -> Option { let (sender, receiver) = mpsc::sync_channel(1); self.sender .send(Message::FindEpoch( name.to_string(), aggregate, direction, split.clone(), sender, )) .expect("Can send event to event store thread."); match receiver.recv() { Ok(value) => value, Err(err) => panic!("Event store thread crashed: {err:?}"), } } /// Find the metric value for the current epoch following the given criteria. pub fn find_metric( &self, name: &str, epoch: usize, aggregate: Aggregate, split: &Split, ) -> Option { let (sender, receiver) = mpsc::sync_channel(1); self.sender .send(Message::FindMetric( name.to_string(), epoch, aggregate, split.clone(), sender, )) .expect("Can send event to event store thread."); match receiver.recv() { Ok(value) => value, Err(err) => panic!("Event store thread crashed: {err:?}"), } } } #[derive(new)] struct WorkerThread { store: S, receiver: mpsc::Receiver, } impl WorkerThread where C: EventStore, { fn run(mut self) { for item in self.receiver.iter() { match item { Message::End => { return; } Message::FindEpoch(name, aggregate, direction, split, callback) => { let response = self.store.find_epoch(&name, aggregate, direction, &split); callback .send(response) .expect("Can send response using callback channel."); } Message::FindMetric(name, epoch, aggregate, split, callback) => { let response = self.store.find_metric(&name, epoch, aggregate, &split); callback .send(response) .expect("Can send response using callback channel."); } Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), Message::OnEventTest(event, tag) => { self.store.add_event(event, Split::Test(Some(tag))) } } } } } enum Message { OnEventTest(Event, Arc), OnEventTrain(Event), OnEventValid(Event), End, FindEpoch( String, Aggregate, Direction, Split, mpsc::SyncSender>, ), FindMetric( String, usize, Aggregate, Split, mpsc::SyncSender>, ), } impl Drop for EventStoreClient { fn drop(&mut self) { self.sender .send(Message::End) .expect("Can send the end message to the event store thread."); let handler = self.handler.take(); if let Some(handler) = handler { handler.join().expect("The event store thread should stop."); } } } ================================================ FILE: crates/burn-train/src/metric/store/log.rs ================================================ use std::collections::HashMap; use super::{Aggregate, Direction, Event, EventStore, Split, aggregate::NumericMetricsAggregate}; use crate::logger::MetricLogger; #[derive(Default)] pub(crate) struct LogEventStore { loggers: Vec>, aggregate: NumericMetricsAggregate, epochs: HashMap, } impl EventStore for LogEventStore { fn add_event(&mut self, event: Event, split: Split) { let epoch = *self.epochs.entry(split.clone()).or_insert(1); match event { Event::MetricsInit(definitions) => { definitions.iter().for_each(|def| { self.loggers .iter_mut() .for_each(|logger| logger.log_metric_definition(def.clone())); }); } Event::MetricsUpdate(update) => { self.loggers .iter_mut() .for_each(|logger| logger.log(update.clone(), epoch, &split)); } Event::EndEpoch(summary) => { self.epochs.insert(split, summary.epoch_number + 1); self.loggers .iter_mut() .for_each(|logger| logger.log_epoch_summary(summary.clone())); } } } fn find_epoch( &mut self, name: &str, aggregate: Aggregate, direction: Direction, split: &Split, ) -> Option { self.aggregate .find_epoch(name, split, aggregate, direction, &mut self.loggers) } fn find_metric( &mut self, name: &str, epoch: usize, aggregate: Aggregate, split: &Split, ) -> Option { self.aggregate .aggregate(name, epoch, split, aggregate, &mut self.loggers) } } impl LogEventStore { /// Register a logger for metrics. pub(crate) fn register_logger(&mut self, logger: ML) { self.loggers.push(Box::new(logger)); } /// Returns whether any loggers are registered. pub(crate) fn has_loggers(&self) -> bool { !self.loggers.is_empty() } } ================================================ FILE: crates/burn-train/src/metric/store/mod.rs ================================================ pub(crate) mod aggregate; mod base; mod client; mod log; pub(crate) use self::log::*; pub use base::*; pub use client::*; ================================================ FILE: crates/burn-train/src/metric/top_k_acc.rs ================================================ use core::marker::PhantomData; use std::sync::Arc; use super::state::{FormatOptions, NumericMetricState}; use super::{MetricMetadata, SerializedEntry}; use crate::metric::{ Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry, }; use burn_core::tensor::backend::Backend; use burn_core::tensor::{ElementConversion, Int, Tensor}; /// The Top-K accuracy metric. /// /// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`). #[derive(Default, Clone)] pub struct TopKAccuracyMetric { name: Arc, k: usize, state: NumericMetricState, /// If specified, targets equal to this value will be considered padding and will not count /// towards the metric pad_token: Option, _b: PhantomData, } /// The [top-k accuracy metric](TopKAccuracyMetric) input type. #[derive(new)] pub struct TopKAccuracyInput { /// The outputs (batch_size, num_classes) outputs: Tensor, /// The labels (batch_size) targets: Tensor, } impl TopKAccuracyMetric { /// Creates the metric. pub fn new(k: usize) -> Self { Self { name: Arc::new(format!("Top-K Accuracy @ TopK({})", k)), k, ..Default::default() } } /// Sets the pad token. pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self } } impl Metric for TopKAccuracyMetric { type Input = TopKAccuracyInput; fn update( &mut self, input: &TopKAccuracyInput, _metadata: &MetricMetadata, ) -> SerializedEntry { let [batch_size, _n_classes] = input.outputs.dims(); let targets = input.targets.clone().to_device(&B::Device::default()); let outputs = input .outputs .clone() .argsort_descending(1) .narrow(1, 0, self.k) .to_device(&B::Device::default()) .reshape([batch_size, self.k]); let (targets, num_pad) = match self.pad_token { Some(pad_token) => { // we ignore the samples where the target is equal to the pad token let mask = targets.clone().equal_elem(pad_token as i64); let num_pad = mask.clone().int().sum().into_scalar().elem::(); (targets.clone().mask_fill(mask, -1_i64), num_pad) } None => (targets.clone(), 0_f64), }; let accuracy = targets .reshape([batch_size, 1]) .repeat_dim(1, self.k) .equal(outputs) .int() .sum() .into_scalar() .elem::() / (batch_size as f64 - num_pad); self.state.update( 100.0 * accuracy, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn clear(&mut self) { self.state.reset() } fn name(&self) -> MetricName { self.name.clone() } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: true, } .into() } } impl Numeric for TopKAccuracyMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; #[test] fn test_accuracy_without_padding() { let device = Default::default(); let mut metric = TopKAccuracyMetric::::new(2); let input = TopKAccuracyInput::new( Tensor::from_data( [ [0.0, 0.2, 0.8], // 2, 1 [1.0, 2.0, 0.5], // 1, 0 [0.4, 0.1, 0.2], // 0, 2 [0.6, 0.7, 0.2], // 1, 0 ], &device, ), Tensor::from_data([2, 2, 1, 1], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } #[test] fn test_accuracy_with_padding() { let device = Default::default(); let mut metric = TopKAccuracyMetric::::new(2).with_pad_token(3); let input = TopKAccuracyInput::new( Tensor::from_data( [ [0.0, 0.2, 0.8, 0.0], // 2, 1 [1.0, 2.0, 0.5, 0.0], // 1, 0 [0.4, 0.1, 0.2, 0.0], // 0, 2 [0.6, 0.7, 0.2, 0.0], // 1, 0 [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count [0.0, 0.1, 0.2, 0.0], // Error on padding should not count [0.6, 0.0, 0.2, 0.0], // Error on padding should not count ], &device, ), Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } #[test] fn test_parameterized_unique_name() { let metric_a = TopKAccuracyMetric::::new(2); let metric_b = TopKAccuracyMetric::::new(1); let metric_c = TopKAccuracyMetric::::new(2); assert_ne!(metric_a.name(), metric_b.name()); assert_eq!(metric_a.name(), metric_c.name()); } } ================================================ FILE: crates/burn-train/src/metric/vision/dice.rs ================================================ use crate::metric::{MetricAttributes, MetricName, SerializedEntry}; use super::super::{ Metric, MetricMetadata, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::{ElementConversion, Int, s}, }; use core::marker::PhantomData; /// Input type for the [DiceMetric]. /// /// # Type Parameters /// - `B`: Backend type. /// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4). pub struct DiceInput { /// Model outputs (predictions), as a tensor. outputs: Tensor, /// Ground truth targets, as a tensor. targets: Tensor, } impl DiceInput { /// Creates a new DiceInput with the given outputs and targets. /// /// Inputs are expected to have the dimensions `[B, C, ...]` /// where `B` is the batch size, `C` is the number of classes, /// and `...` represents additional dimensions (e.g., height, width for images). /// /// If `C` is more than 1, the first class (index 0) is considered the background. /// Additionally, one-hot encoding is the responsibility of the caller. /// /// # Arguments /// - `outputs`: The model outputs as a tensor. /// - `targets`: The ground truth targets as a tensor. /// /// # Returns /// A new instance of `DiceInput`. /// /// # Panics /// - If `D` is less than 3. /// - If `outputs` and `targets` do not have the same dimensions. /// - If `outputs` or `targets` do not have exactly `D` dimensions. /// - If `outputs` and `targets` do not have the same shape. pub fn new(outputs: Tensor, targets: Tensor) -> Self { assert!(D >= 3, "DiceInput requires at least 3 dimensions."); assert!( outputs.dims() == targets.dims(), "Outputs and targets must have the same dimensions. Got {:?} and {:?}", outputs.dims(), targets.dims() ); Self { outputs, targets } } } /// Configuration for the [DiceMetric]. #[derive(Debug, Clone, Copy)] pub struct DiceMetricConfig { /// Epsilon value to avoid division by zero. pub epsilon: f64, /// Whether to include the background class in the metric calculation. /// The background is assumed to be the first class (index 0). /// if `true`, will panic if there are fewer than 2 classes. pub include_background: bool, } impl Default for DiceMetricConfig { fn default() -> Self { Self { epsilon: 1e-7, include_background: false, } } } /// The Dice-Sorenson coefficient (DSC) for evaluating overlap between two binary masks. /// The DSC is defined as: /// `DSC = 2 * (|X ∩ Y|) / (|X| + |Y|)` /// where `X` is the model output and `Y` is the ground truth target. /// /// # Type Parameters /// - `B`: Backend type. /// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4). #[derive(Default, Clone)] pub struct DiceMetric { name: MetricName, /// Internal state for numeric metric aggregation. state: NumericMetricState, /// Marker for backend type. _b: PhantomData, /// Configuration for the metric. config: DiceMetricConfig, } impl DiceMetric { /// Creates a new Dice metric instance with default config. pub fn new() -> Self { Self::with_config(DiceMetricConfig::default()) } /// Creates a new Dice metric with a custom config. pub fn with_config(config: DiceMetricConfig) -> Self { let name = MetricName::new(format!("{D}D Dice Metric")); assert!(D >= 3, "DiceMetric requires at least 3 dimensions."); Self { name, config, ..Default::default() } } } impl Metric for DiceMetric { type Input = DiceInput; fn name(&self) -> MetricName { self.name.clone() } fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { // Dice coefficient: 2 * (|X ∩ Y|) / (|X| + |Y|) if item.outputs.dims() != item.targets.dims() { panic!( "Outputs and targets must have the same dimensions. Got {:?} and {:?}", item.outputs.dims(), item.targets.dims() ); } let dims = item.outputs.dims(); let batch_size = dims[0]; let n_classes = dims[1]; let mut outputs = item.outputs.clone(); let mut targets = item.targets.clone(); if !self.config.include_background && n_classes > 1 { // If not including background, we can ignore the first class outputs = outputs.slice(s![.., 1..]); targets = targets.slice(s![.., 1..]); } else if self.config.include_background && n_classes < 2 { // If including background, we need at least 2 classes panic!("Dice metric requires at least 2 classes when including background."); } let intersection = (outputs.clone() * targets.clone()).sum(); let outputs_sum = outputs.sum(); let targets_sum = targets.sum(); // Convert to f64 let intersection_val = intersection.into_scalar().elem::(); let outputs_sum_val = outputs_sum.into_scalar().elem::(); let targets_sum_val = targets_sum.into_scalar().elem::(); // Use epsilon from config let epsilon = self.config.epsilon; let dice = (2.0 * intersection_val + epsilon) / (outputs_sum_val + targets_sum_val + epsilon); self.state.update( dice, batch_size, FormatOptions::new(self.name()).precision(4), ) } /// Clears the metric state. fn clear(&mut self) { self.state.reset(); } fn attributes(&self) -> MetricAttributes { crate::metric::NumericAttributes { unit: None, higher_is_better: true, } .into() } } impl crate::metric::Numeric for DiceMetric { fn value(&self) -> crate::metric::NumericEntry { self.state.current_value() } fn running_value(&self) -> crate::metric::NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::{TestBackend, metric::Numeric}; use burn_core::tensor::{Shape, Tensor}; #[test] fn test_dice_perfect_overlap() { let device = Default::default(); let mut metric = DiceMetric::::new(); let input = DiceInput::new( Tensor::from_data([[[[1, 0], [1, 0]]]], &device), Tensor::from_data([[[[1, 0], [1, 0]]]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!((metric.value().current() - 1.0).abs() < 1e-6); } #[test] fn test_dice_no_overlap() { let device = Default::default(); let mut metric = DiceMetric::::new(); let input = DiceInput::new( Tensor::from_data([[[[1, 0], [1, 0]]]], &device), Tensor::from_data([[[[0, 1], [0, 1]]]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!(metric.value().current() < 1e-6); } #[test] fn test_dice_partial_overlap() { let device = Default::default(); let mut metric = DiceMetric::::new(); let input = DiceInput::new( Tensor::from_data([[[[1, 1], [0, 0]]]], &device), Tensor::from_data([[[[1, 0], [1, 0]]]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); // intersection = 1, sum = 2+2=4, dice = 2*1/4 = 0.5 assert!((metric.value().current() - 0.5).abs() < 1e-6); } #[test] fn test_dice_empty_masks() { let device = Default::default(); let mut metric = DiceMetric::::new(); let input = DiceInput::new( Tensor::from_data([[[[0, 0], [0, 0]]]], &device), Tensor::from_data([[[[0, 0], [0, 0]]]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!((metric.value().current() - 1.0).abs() < 1e-6); } #[test] fn test_dice_no_background() { let device = Default::default(); let mut metric = DiceMetric::::new(); let input = DiceInput::new( Tensor::ones(Shape::new([1, 1, 2, 2]), &device), Tensor::ones(Shape::new([1, 1, 2, 2]), &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!((metric.value().current() - 1.0).abs() < 1e-6); } #[test] fn test_dice_with_background() { let device = Default::default(); let config = DiceMetricConfig { epsilon: 1e-7, include_background: true, }; let mut metric = DiceMetric::::with_config(config); let input = DiceInput::new( Tensor::ones(Shape::new([1, 2, 2, 2]), &device), Tensor::ones(Shape::new([1, 2, 2, 2]), &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!((metric.value().current() - 1.0).abs() < 1e-6); } #[test] fn test_dice_ignored_background() { let device = Default::default(); let config = DiceMetricConfig { epsilon: 1e-7, include_background: false, }; let mut metric = DiceMetric::::with_config(config); let input = DiceInput::new( Tensor::ones(Shape::new([1, 2, 2, 2]), &device), Tensor::ones(Shape::new([1, 2, 2, 2]), &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert!((metric.value().current() - 1.0).abs() < 1e-6); } #[test] #[should_panic(expected = "DiceInput requires at least 3 dimensions.")] fn test_invalid_input_dimensions() { let device = Default::default(); // D = 2, should panic let _ = DiceInput::::new( Tensor::from_data([[0.0, 0.0]], &device), Tensor::from_data([[0.0, 0.0]], &device), ); } #[test] #[should_panic( expected = "Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]" )] fn test_mismatched_shape() { let device = Default::default(); // shapes differ let _ = DiceInput::::new( Tensor::from_data([[[[0.0; 2]; 2]; 1]; 1], &device), Tensor::from_data([[[[0.0; 3]; 2]; 1]; 1], &device), ); } #[test] #[should_panic(expected = "Dice metric requires at least 2 classes when including background.")] fn test_include_background_panic() { let device = Default::default(); let config = DiceMetricConfig { epsilon: 1e-7, include_background: true, }; let mut metric = DiceMetric::::with_config(config); let input = DiceInput::new( Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device), Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device), ); // n_classes = 2, should not panic let _entry = metric.update(&input, &MetricMetadata::fake()); let config = DiceMetricConfig { epsilon: 1e-7, include_background: true, }; let mut metric = DiceMetric::::with_config(config); let input = DiceInput::new( Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device), Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device), ); // n_classes = 1, should panic let _entry = metric.update(&input, &MetricMetadata::fake()); } } ================================================ FILE: crates/burn-train/src/metric/vision/dists/l2pool.rs ================================================ //! L2 Pooling layer for DISTS. //! //! L2 Pooling applies a Hanning window filter and computes the L2 norm //! across the pooling window. This is used in DISTS instead of MaxPooling. use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_nn::PaddingConfig2d; use burn_nn::conv::{Conv2d, Conv2dConfig}; /// L2 Pooling layer configuration. #[derive(Debug, Clone)] pub struct L2Pool2dConfig { /// Kernel size for pooling pub kernel_size: usize, /// Stride for pooling pub stride: usize, /// Padding for pooling pub padding: usize, } impl Default for L2Pool2dConfig { fn default() -> Self { Self { kernel_size: 5, stride: 2, padding: 2, } } } impl L2Pool2dConfig { /// Create a new L2Pool2d configuration. #[allow(dead_code)] pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self { Self { kernel_size, stride, padding, } } /// Initialize the L2Pool2d layer. pub fn init(&self, channels: usize, device: &B::Device) -> L2Pool2d { L2Pool2d::new( channels, self.kernel_size, self.stride, self.padding, device, ) } } /// L2 Pooling layer. /// /// Applies a 2D Hanning window filter followed by L2 normalization. /// This provides smoother downsampling compared to MaxPooling. #[derive(Module, Debug)] pub struct L2Pool2d { /// Depthwise convolution with Hanning kernel conv: Conv2d, } impl L2Pool2d { /// Create a new L2Pool2d layer with Hanning window kernel. pub fn new( channels: usize, kernel_size: usize, stride: usize, padding: usize, device: &B::Device, ) -> Self { // Create Hanning kernel let kernel = Self::create_hanning_kernel(channels, kernel_size, device); // Create depthwise convolution (groups = channels) let mut conv = Conv2dConfig::new([channels, channels], [kernel_size, kernel_size]) .with_stride([stride, stride]) .with_padding(PaddingConfig2d::Explicit( padding, padding, padding, padding, )) .with_groups(channels) .with_bias(false) .init(device); // Set the kernel weights to Hanning window conv.weight = burn::module::Param::from_tensor(kernel); Self { conv } } /// Create a Hanning kernel for depthwise convolution. /// Output shape: [channels, 1, kernel_size, kernel_size] fn create_hanning_kernel( channels: usize, kernel_size: usize, device: &B2::Device, ) -> Tensor { // Create 1D Hanning window let mut hanning_1d = Vec::with_capacity(kernel_size); for i in 0..kernel_size { let n = i as f32; let n_minus_1 = (kernel_size - 1) as f32; let value = if n_minus_1 == 0.0 { 1.0 } else { 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n / n_minus_1).cos()) }; hanning_1d.push(value); } // Create 2D Hanning window by outer product let mut hanning_2d = Vec::with_capacity(kernel_size * kernel_size); let mut sum = 0.0; for i in 0..kernel_size { for j in 0..kernel_size { let value = hanning_1d[i] * hanning_1d[j]; hanning_2d.push(value); sum += value; } } // Normalize for v in hanning_2d.iter_mut() { *v /= sum; } // Create tensor of shape [1, 1, kernel_size, kernel_size] let kernel_single = Tensor::::from_floats(hanning_2d.as_slice(), device).reshape([ 1, 1, kernel_size, kernel_size, ]); // Expand to [channels, 1, kernel_size, kernel_size] kernel_single.repeat_dim(0, channels) } /// Apply L2 pooling to the input tensor. /// /// # Arguments /// /// * `x` - Input tensor of shape `[batch, channels, height, width]` /// /// # Returns /// /// Pooled tensor with reduced spatial dimensions. pub fn forward(&self, x: Tensor) -> Tensor { // Square the input let x_sq = x.clone().mul(x); // Apply depthwise convolution with Hanning kernel let pooled = self.conv.forward(x_sq); // Take square root for L2 norm // Add small epsilon to avoid sqrt of negative numbers due to numerical errors pooled.clamp_min(1e-10).sqrt() } } ================================================ FILE: crates/burn-train/src/metric/vision/dists/metric.rs ================================================ //! DISTS (Deep Image Structure and Texture Similarity) metric. //! //! DISTS is a full-reference image quality assessment metric that combines //! structure and texture similarity using deep features from VGG16. //! //! Reference: "Image Quality Assessment: Unifying Structure and Texture Similarity" //! https://arxiv.org/abs/2004.07728 use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_nn::loss::Reduction; use super::vgg16_l2pool::Vgg16L2PoolExtractor; /// Channel counts for each stage: [input, stage1, stage2, stage3, stage4, stage5] const CHANNELS: [usize; 6] = [3, 64, 128, 256, 512, 512]; /// Small constant for numerical stability in structure similarity. const C1: f32 = 1e-6; /// Small constant for numerical stability in texture similarity. const C2: f32 = 1e-6; /// ImageNet normalization constants. const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225]; /// Image normalizer with pre-initialized mean and std tensors. /// /// This struct holds the mean and std tensors for normalization, /// avoiding the need to create them on each forward pass. #[derive(Module, Debug)] pub struct Normalizer { /// Mean tensor of shape [1, 3, 1, 1] for broadcasting. pub mean: Tensor, /// Std tensor of shape [1, 3, 1, 1] for broadcasting. pub std: Tensor, } impl Normalizer { /// Create a new ImageNet normalizer. pub fn imagenet(device: &B::Device) -> Self { // Shape: [1, 3, 1, 1] for broadcasting over [batch, channels, height, width] let mean = Tensor::from_floats( [[ [[IMAGENET_MEAN[0]]], [[IMAGENET_MEAN[1]]], [[IMAGENET_MEAN[2]]], ]], device, ); let std = Tensor::from_floats( [[ [[IMAGENET_STD[0]]], [[IMAGENET_STD[1]]], [[IMAGENET_STD[2]]], ]], device, ); Self { mean, std } } /// Normalize a tensor: (x - mean) / std pub fn normalize(&self, x: Tensor) -> Tensor { x.sub(self.mean.clone()).div(self.std.clone()) } } /// Configuration for DISTS metric. #[derive(Config, Debug)] pub struct DistsConfig { /// Whether to apply ImageNet normalization to input images. #[config(default = true)] pub normalize: bool, } impl DistsConfig { /// Initialize a DISTS module with default weights. pub fn init(&self, device: &B::Device) -> Dists { let total_channels: usize = CHANNELS.iter().sum(); // Initialize alpha and beta with constant value 0.1 for all channels let alpha_data: Vec = (0..total_channels).map(|_| 0.1).collect(); let beta_data: Vec = (0..total_channels).map(|_| 0.1).collect(); let normalizer = if self.normalize { Some(Normalizer::imagenet(device)) } else { None }; Dists { extractor: Vgg16L2PoolExtractor::new(device), alpha: Param::from_tensor(Tensor::from_floats(alpha_data.as_slice(), device)), beta: Param::from_tensor(Tensor::from_floats(beta_data.as_slice(), device)), normalizer, } } /// Initialize a DISTS module with pretrained weights. pub fn init_pretrained(&self, device: &B::Device) -> Dists { let dists = self.init(device); super::weights::load_pretrained_weights(dists) } } /// DISTS (Deep Image Structure and Texture Similarity) metric. /// /// Computes perceptual similarity between two images by combining /// structure similarity (based on spatial means) and texture similarity /// (based on variances and covariances) across VGG16 feature maps. /// /// # Example /// /// ```ignore /// use burn_train::metric::vision::{DistsConfig, Reduction}; /// /// let device = Default::default(); /// let dists = DistsConfig::new().init_pretrained(&device); /// /// let img1: Tensor = /* [batch, 3, H, W] */; /// let img2: Tensor = /* [batch, 3, H, W] */; /// /// let distance = dists.forward(img1, img2, Reduction::Mean); /// ``` #[derive(Module, Debug)] #[module(custom_display)] pub struct Dists { /// VGG16 feature extractor with L2 pooling pub(crate) extractor: Vgg16L2PoolExtractor, /// Learned weights for structure similarity (per channel) pub(crate) alpha: Param>, /// Learned weights for texture similarity (per channel) pub(crate) beta: Param>, /// Optional normalizer for input preprocessing pub(crate) normalizer: Option>, } impl ModuleDisplay for Dists { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { content .add("backbone", &"VGG16-L2Pool".to_string()) .add("normalize", &self.normalizer.is_some().to_string()) .optional() } } impl Dists { /// Compute DISTS distance with reduction. /// /// # Arguments /// /// * `input` - First image tensor of shape `[batch, 3, H, W]` /// * `target` - Second image tensor of shape `[batch, 3, H, W]` /// * `reduction` - How to reduce the output (Mean, Sum, or Auto) /// /// # Returns /// /// Scalar tensor of shape `[1]`. pub fn forward( &self, input: Tensor, target: Tensor, reduction: Reduction, ) -> Tensor { let distance = self.forward_no_reduction(input, target); match reduction { Reduction::Mean | Reduction::Auto | Reduction::BatchMean => distance.mean(), Reduction::Sum => distance.sum(), } } /// Compute DISTS distance without reduction. /// /// # Arguments /// /// * `input` - First image tensor of shape `[batch, 3, H, W]` /// * `target` - Second image tensor of shape `[batch, 3, H, W]` /// /// # Returns /// /// Per-sample distance tensor of shape `[batch]`. pub fn forward_no_reduction(&self, input: Tensor, target: Tensor) -> Tensor { let [batch, _, _, _] = input.dims(); // Preprocess inputs let (input, target) = self.preprocess(input, target); // Extract features from both images let feats_x = self.extractor.forward(input); let feats_y = self.extractor.forward(target); // Get alpha and beta weights let alpha = self.alpha.val(); let beta = self.beta.val(); // Compute weighted sum of alpha and beta for normalization let alpha_sum = alpha.clone().sum(); let beta_sum = beta.clone().sum(); let device = feats_x[0].device(); // Initialize accumulators let mut structure_dist = Tensor::::zeros([batch], &device); let mut texture_dist = Tensor::::zeros([batch], &device); let mut channel_offset = 0; // Compute similarity for each stage for (feat_x, feat_y) in feats_x.iter().zip(feats_y.iter()) { let [_b, c, _h, _w] = feat_x.dims(); // Get alpha and beta for this stage let alpha_stage = alpha.clone().narrow(0, channel_offset, c); let beta_stage = beta.clone().narrow(0, channel_offset, c); // Compute structure and texture similarity for this stage let (s_dist, t_dist) = self.compute_stage_similarity( feat_x.clone(), feat_y.clone(), alpha_stage, beta_stage, ); structure_dist = structure_dist.add(s_dist); texture_dist = texture_dist.add(t_dist); channel_offset += c; } // Normalize by sum of weights structure_dist = structure_dist.div(alpha_sum); texture_dist = texture_dist.div(beta_sum); // DISTS = 1 - (structure_similarity + texture_similarity) // Since we computed distances (1 - similarity), we return the sum structure_dist.add(texture_dist) } /// Compute structure and texture similarity for a single stage. fn compute_stage_similarity( &self, feat_x: Tensor, feat_y: Tensor, alpha: Tensor, beta: Tensor, ) -> (Tensor, Tensor) { let [batch, channels, height, width] = feat_x.dims(); let device = feat_x.device(); // Reshape to [batch, channels, H*W] for easier computation let x = feat_x.reshape([batch, channels, height * width]); let y = feat_y.reshape([batch, channels, height * width]); // Compute means: [batch, channels] (squeeze after mean_dim to remove the reduced dimension) let mean_x = x.clone().mean_dim(2).squeeze_dim::<2>(2); let mean_y = y.clone().mean_dim(2).squeeze_dim::<2>(2); // Compute structure similarity: (2*mean_x*mean_y + c1) / (mean_x^2 + mean_y^2 + c1) let c1 = Tensor::::full([batch, channels], C1, &device); let structure_sim = mean_x .clone() .mul(mean_y.clone()) .mul_scalar(2.0) .add(c1.clone()) .div( mean_x .clone() .mul(mean_x.clone()) .add(mean_y.clone().mul(mean_y.clone())) .add(c1), ); // Compute variances and covariance // var_x = E[x^2] - E[x]^2, clamped at 0 for numerical stability let var_x = x .clone() .mul(x.clone()) .mean_dim(2) .squeeze_dim::<2>(2) .sub(mean_x.clone().mul(mean_x.clone())) .clamp_min(0.0); let var_y = y .clone() .mul(y.clone()) .mean_dim(2) .squeeze_dim::<2>(2) .sub(mean_y.clone().mul(mean_y.clone())) .clamp_min(0.0); // cov_xy = E[xy] - E[x]E[y] let cov_xy = x .mul(y) .mean_dim(2) .squeeze_dim::<2>(2) .sub(mean_x.clone().mul(mean_y.clone())); // Compute texture similarity: (2*cov_xy + c2) / (var_x + var_y + c2) let c2 = Tensor::::full([batch, channels], C2, &device); let texture_sim = cov_xy .mul_scalar(2.0) .add(c2.clone()) .div(var_x.add(var_y).add(c2)); // Convert similarity to distance: 1 - similarity let structure_dist = Tensor::::ones([batch, channels], &device).sub(structure_sim); let texture_dist = Tensor::::ones([batch, channels], &device).sub(texture_sim); // Apply weights: [batch, channels] * [channels] -> [batch, channels] // Then sum over channels -> [batch] let weighted_structure = structure_dist .mul(alpha.unsqueeze_dim::<2>(0)) .sum_dim(1) .squeeze_dim::<1>(1); let weighted_texture = texture_dist .mul(beta.unsqueeze_dim::<2>(0)) .sum_dim(1) .squeeze_dim::<1>(1); (weighted_structure, weighted_texture) } /// Preprocess input images using the configured normalizer. fn preprocess( &self, input: Tensor, target: Tensor, ) -> (Tensor, Tensor) { match &self.normalizer { Some(normalizer) => { let input = normalizer.normalize(input); let target = normalizer.normalize(target); (input, target) } None => (input, target), } } } // ============================================================================= // Tests // ============================================================================= #[cfg(test)] mod tests { use super::*; use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem}; use burn_ndarray::NdArray; type TestBackend = NdArray; type FT = FloatElem; type TestTensor = Tensor; #[test] fn test_dists_identical_images_zero_distance() { let device = Default::default(); // Use random image instead of constant to avoid numerical edge cases let image = TestTensor::<4>::random( [1, 3, 64, 64], burn_core::tensor::Distribution::Uniform(0.0, 1.0), &device, ); let dists: Dists = DistsConfig::new().init(&device); let distance = dists.forward(image.clone(), image, Reduction::Mean); let expected = TensorData::from([0.0]); distance .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } #[test] fn test_dists_different_images_nonzero_distance() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let dists: Dists = DistsConfig::new().init(&device); let distance = dists.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() > 1e-6, "DISTS should be != 0 for different images" ); } #[test] fn test_dists_symmetry() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device); let dists: Dists = DistsConfig::new().init(&device); let distance_forward = dists.forward(image1.clone(), image2.clone(), Reduction::Mean); let distance_reverse = dists.forward(image2, image1, Reduction::Mean); distance_forward .into_data() .assert_approx_eq::(&distance_reverse.into_data(), Tolerance::default()); } #[test] fn test_dists_batch_processing() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device); let dists: Dists = DistsConfig::new().init(&device); let distance = dists.forward(image1, image2, Reduction::Mean); assert_eq!(distance.dims(), [1]); } #[test] fn test_dists_no_reduction() { let device = Default::default(); let batch_size = 4; let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device); let dists: Dists = DistsConfig::new().init(&device); let distance = dists.forward_no_reduction(image1, image2); assert_eq!(distance.dims(), [batch_size]); } #[test] fn display_dists() { let device = Default::default(); let dists: Dists = DistsConfig::new().init(&device); let display_str = format!("{dists}"); assert!(display_str.contains("Dists")); assert!(display_str.contains("VGG16-L2Pool")); } // ========================================================================= // Pretrained Weights Tests (requires network) // ========================================================================= /// Test DISTS pretrained weights download and loading. #[test] fn test_dists_pretrained() { let device = Default::default(); let dists: Dists = DistsConfig::new().init_pretrained(&device); // Test with identical images - should be ~0 // Use random image to avoid numerical edge cases with constant images let image = TestTensor::<4>::random( [1, 3, 64, 64], burn_core::tensor::Distribution::Uniform(0.0, 1.0), &device, ); let distance = dists.forward(image.clone(), image, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() < 1e-5, "Pretrained DISTS should be ~0 for identical images, got {}", distance_value ); // Test with different images - should be positive let image1 = TestTensor::<4>::random( [1, 3, 64, 64], burn_core::tensor::Distribution::Uniform(0.0, 0.3), &device, ); let image2 = TestTensor::<4>::random( [1, 3, 64, 64], burn_core::tensor::Distribution::Uniform(0.7, 1.0), &device, ); let distance = dists.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value > 0.0, "Pretrained DISTS should be > 0 for different images" ); } } ================================================ FILE: crates/burn-train/src/metric/vision/dists/mod.rs ================================================ //! DISTS (Deep Image Structure and Texture Similarity) metric. //! //! This module implements DISTS, a full-reference image quality assessment metric //! that combines structure and texture similarity using deep features. //! //! Reference: "Image Quality Assessment: Unifying Structure and Texture Similarity" //! https://arxiv.org/abs/2004.07728 mod l2pool; mod metric; mod vgg16_l2pool; mod weights; pub use metric::{Dists, DistsConfig}; ================================================ FILE: crates/burn-train/src/metric/vision/dists/vgg16_l2pool.rs ================================================ //! VGG16 feature extractor with L2 Pooling for DISTS. //! //! This module implements the VGG16 backbone used in DISTS, //! with L2Pooling replacing MaxPooling for smoother feature extraction. use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::activation::relu; use burn::tensor::backend::Backend; use burn_nn::PaddingConfig2d; use burn_nn::conv::{Conv2d, Conv2dConfig}; use super::l2pool::{L2Pool2d, L2Pool2dConfig}; /// VGG16 feature extractor with L2 Pooling for DISTS. /// /// Extracts features from 5 stages of VGG16, using L2Pooling /// instead of MaxPooling for smoother downsampling. /// /// Output channels per stage: [64, 128, 256, 512, 512] #[derive(Module, Debug)] pub struct Vgg16L2PoolExtractor { // Stage 1: 2 conv layers, 64 channels pub(crate) conv1_1: Conv2d, pub(crate) conv1_2: Conv2d, pub(crate) pool1: L2Pool2d, // Stage 2: 2 conv layers, 128 channels pub(crate) conv2_1: Conv2d, pub(crate) conv2_2: Conv2d, pub(crate) pool2: L2Pool2d, // Stage 3: 3 conv layers, 256 channels pub(crate) conv3_1: Conv2d, pub(crate) conv3_2: Conv2d, pub(crate) conv3_3: Conv2d, pub(crate) pool3: L2Pool2d, // Stage 4: 3 conv layers, 512 channels pub(crate) conv4_1: Conv2d, pub(crate) conv4_2: Conv2d, pub(crate) conv4_3: Conv2d, pub(crate) pool4: L2Pool2d, // Stage 5: 3 conv layers, 512 channels pub(crate) conv5_1: Conv2d, pub(crate) conv5_2: Conv2d, pub(crate) conv5_3: Conv2d, } impl Vgg16L2PoolExtractor { /// Create a new VGG16 feature extractor with L2 Pooling. pub fn new(device: &B::Device) -> Self { let pool_config = L2Pool2dConfig::default(); Self { // Stage 1 conv1_1: Conv2dConfig::new([3, 64], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv1_2: Conv2dConfig::new([64, 64], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), pool1: pool_config.init(64, device), // Stage 2 conv2_1: Conv2dConfig::new([64, 128], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv2_2: Conv2dConfig::new([128, 128], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), pool2: pool_config.init(128, device), // Stage 3 conv3_1: Conv2dConfig::new([128, 256], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv3_2: Conv2dConfig::new([256, 256], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv3_3: Conv2dConfig::new([256, 256], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), pool3: pool_config.init(256, device), // Stage 4 conv4_1: Conv2dConfig::new([256, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv4_2: Conv2dConfig::new([512, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv4_3: Conv2dConfig::new([512, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), pool4: pool_config.init(512, device), // Stage 5 conv5_1: Conv2dConfig::new([512, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv5_2: Conv2dConfig::new([512, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), conv5_3: Conv2dConfig::new([512, 512], [3, 3]) .with_padding(PaddingConfig2d::Same) .init(device), } } /// Extract features from all 5 stages. /// /// # Arguments /// /// * `x` - Input tensor of shape `[batch, 3, height, width]` /// /// # Returns /// /// Vector of 6 feature tensors: /// - Stage 0: Input image [batch, 3, H, W] /// - Stage 1: After conv1 [batch, 64, H/2, W/2] /// - Stage 2: After conv2 [batch, 128, H/4, W/4] /// - Stage 3: After conv3 [batch, 256, H/8, W/8] /// - Stage 4: After conv4 [batch, 512, H/16, W/16] /// - Stage 5: After conv5 [batch, 512, H/32, W/32] pub fn forward(&self, x: Tensor) -> Vec> { let mut features = Vec::with_capacity(6); // Stage 0: Input image features.push(x.clone()); // Stage 1 let x = relu(self.conv1_1.forward(x)); let x = relu(self.conv1_2.forward(x)); features.push(x.clone()); let x = self.pool1.forward(x); // Stage 2 let x = relu(self.conv2_1.forward(x)); let x = relu(self.conv2_2.forward(x)); features.push(x.clone()); let x = self.pool2.forward(x); // Stage 3 let x = relu(self.conv3_1.forward(x)); let x = relu(self.conv3_2.forward(x)); let x = relu(self.conv3_3.forward(x)); features.push(x.clone()); let x = self.pool3.forward(x); // Stage 4 let x = relu(self.conv4_1.forward(x)); let x = relu(self.conv4_2.forward(x)); let x = relu(self.conv4_3.forward(x)); features.push(x.clone()); let x = self.pool4.forward(x); // Stage 5 let x = relu(self.conv5_1.forward(x)); let x = relu(self.conv5_2.forward(x)); let x = relu(self.conv5_3.forward(x)); features.push(x); features } } ================================================ FILE: crates/burn-train/src/metric/vision/dists/weights.rs ================================================ //! Pretrained weights loading for DISTS. use burn_core as burn; use burn::tensor::backend::Backend; use burn_std::network::downloader::download_file_as_bytes; use burn_store::{ModuleSnapshot, PytorchStore}; use std::fs::{File, create_dir_all}; use std::io::Write; use std::path::PathBuf; use super::metric::Dists; /// URL for pretrained DISTS alpha/beta weights from the official repository. /// Reference: https://github.com/dingkeyan93/DISTS const DISTS_WEIGHTS_URL: &str = "https://github.com/dingkeyan93/DISTS/raw/master/DISTS_pytorch/weights.pt"; /// URL for ImageNet pretrained VGG16 backbone weights from PyTorch. const VGG16_IMAGENET_URL: &str = "https://download.pytorch.org/models/vgg16-397923af.pth"; /// Get the cache directory for DISTS weights. fn get_cache_dir() -> PathBuf { let cache_dir = dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset") .join("dists"); if !cache_dir.exists() { create_dir_all(&cache_dir).expect("Failed to create cache directory"); } cache_dir } /// Download file if not cached. fn download_if_needed(url: &str, cache_path: &PathBuf, message: &str) { if !cache_path.exists() { let bytes = download_file_as_bytes(url, message); let mut file = File::create(cache_path).expect("Failed to create cache file"); file.write_all(&bytes).expect("Failed to write weights"); } } /// Download and load pretrained weights into a DISTS module. /// /// This loads both: /// 1. ImageNet pretrained VGG16 backbone weights /// 2. DISTS trained alpha/beta weights /// /// Weights are cached in the user's cache directory to avoid re-downloading. /// /// # Arguments /// /// * `dists` - The DISTS module to load weights into. /// /// # Returns /// /// The DISTS module with loaded pretrained weights. pub fn load_pretrained_weights(mut dists: Dists) -> Dists { let cache_dir = get_cache_dir(); // Step 1: Download and load VGG16 ImageNet backbone weights let vgg_cache_path = cache_dir.join("vgg16_backbone.pth"); download_if_needed( VGG16_IMAGENET_URL, &vgg_cache_path, "Downloading VGG16 ImageNet weights for DISTS...", ); // Step 2: Download DISTS alpha/beta weights let dists_cache_path = cache_dir.join("dists_weights.pt"); download_if_needed( DISTS_WEIGHTS_URL, &dists_cache_path, "Downloading DISTS alpha/beta weights...", ); // Load VGG16 backbone weights first dists = load_vgg16_backbone_weights(dists, &vgg_cache_path); // Then load DISTS alpha/beta weights dists = load_dists_weights(dists, &dists_cache_path); dists } /// Load VGG16 ImageNet pretrained backbone weights. fn load_vgg16_backbone_weights(mut dists: Dists, cache_path: &PathBuf) -> Dists { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) .skip_enum_variants(true) // VGG16 features.X -> extractor.convY_Z .with_key_remapping(r"^features\.0\.", "extractor.conv1_1.") .with_key_remapping(r"^features\.2\.", "extractor.conv1_2.") .with_key_remapping(r"^features\.5\.", "extractor.conv2_1.") .with_key_remapping(r"^features\.7\.", "extractor.conv2_2.") .with_key_remapping(r"^features\.10\.", "extractor.conv3_1.") .with_key_remapping(r"^features\.12\.", "extractor.conv3_2.") .with_key_remapping(r"^features\.14\.", "extractor.conv3_3.") .with_key_remapping(r"^features\.17\.", "extractor.conv4_1.") .with_key_remapping(r"^features\.19\.", "extractor.conv4_2.") .with_key_remapping(r"^features\.21\.", "extractor.conv4_3.") .with_key_remapping(r"^features\.24\.", "extractor.conv5_1.") .with_key_remapping(r"^features\.26\.", "extractor.conv5_2.") .with_key_remapping(r"^features\.28\.", "extractor.conv5_3."); let result = dists.load_from(&mut store); if let Err(e) = result { log::warn!("Some VGG16 backbone weights could not be loaded: {:?}", e); } dists } /// Load DISTS trained alpha/beta weights. fn load_dists_weights(mut dists: Dists, cache_path: &PathBuf) -> Dists { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) .skip_enum_variants(true); let result = dists.load_from(&mut store); if let Err(e) = result { log::warn!("Some DISTS weights could not be loaded: {:?}", e); } dists } ================================================ FILE: crates/burn-train/src/metric/vision/lpips/alexnet.rs ================================================ //! AlexNet feature extractor for LPIPS. use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::activation::relu; use burn::tensor::backend::Backend; use burn_nn::PaddingConfig2d; use burn_nn::conv::{Conv2d, Conv2dConfig}; /// AlexNet feature extractor for LPIPS. /// /// Extracts features from 5 layers: /// - conv1: 64 channels (after ReLU) /// - conv2: 192 channels (after ReLU) /// - conv3: 384 channels (after ReLU) /// - conv4: 256 channels (after ReLU) /// - conv5: 256 channels (after ReLU) #[derive(Module, Debug)] pub struct AlexFeatureExtractor { /// Conv1: 3 -> 64, kernel 11x11, stride 4, padding 2 conv1: Conv2d, /// Conv2: 64 -> 192, kernel 5x5, stride 1, padding 2 conv2: Conv2d, /// Conv3: 192 -> 384, kernel 3x3, stride 1, padding 1 conv3: Conv2d, /// Conv4: 384 -> 256, kernel 3x3, stride 1, padding 1 conv4: Conv2d, /// Conv5: 256 -> 256, kernel 3x3, stride 1, padding 1 conv5: Conv2d, } impl AlexFeatureExtractor { /// Create a new AlexNet feature extractor. pub fn new(device: &B::Device) -> Self { Self { // Conv1: 3 -> 64, 11x11, stride 4, padding 2 conv1: Conv2dConfig::new([3, 64], [11, 11]) .with_stride([4, 4]) .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2)) .with_bias(true) .init(device), // Conv2: 64 -> 192, 5x5, stride 1, padding 2 conv2: Conv2dConfig::new([64, 192], [5, 5]) .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2)) .with_bias(true) .init(device), // Conv3: 192 -> 384, 3x3, stride 1, padding 1 conv3: Conv2dConfig::new([192, 384], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .with_bias(true) .init(device), // Conv4: 384 -> 256, 3x3, stride 1, padding 1 conv4: Conv2dConfig::new([384, 256], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .with_bias(true) .init(device), // Conv5: 256 -> 256, 3x3, stride 1, padding 1 conv5: Conv2dConfig::new([256, 256], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .with_bias(true) .init(device), } } /// Extract features from 5 AlexNet layers. pub fn forward(&self, x: Tensor) -> Vec> { let mut features = Vec::with_capacity(5); // Slice 1: Conv1 + ReLU let x = relu(self.conv1.forward(x)); features.push(x.clone()); // Slice 2: MaxPool + Conv2 + ReLU let x = max_pool2d_alex(x); let x = relu(self.conv2.forward(x)); features.push(x.clone()); // Slice 3: MaxPool + Conv3 + ReLU let x = max_pool2d_alex(x); let x = relu(self.conv3.forward(x)); features.push(x.clone()); // Slice 4: Conv4 + ReLU (no pooling) let x = relu(self.conv4.forward(x)); features.push(x.clone()); // Slice 5: Conv5 + ReLU (no pooling) let x = relu(self.conv5.forward(x)); features.push(x); features } } /// 3x3 max pooling with stride 2 (for AlexNet). fn max_pool2d_alex(x: Tensor) -> Tensor { burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], false) } ================================================ FILE: crates/burn-train/src/metric/vision/lpips/metric.rs ================================================ //! LPIPS (Learned Perceptual Image Patch Similarity) metric module. //! //! LPIPS measures perceptual similarity between images using deep features. //! Supports VGG16, AlexNet, and SqueezeNet as backbone networks. //! //! Reference: "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric" //! use burn_core as burn; use burn::config::Config; use burn::module::{Content, DisplaySettings, Module, ModuleDisplay}; use burn::tensor::Tensor; use burn::tensor::backend::Backend; use burn_nn::conv::{Conv2d, Conv2dConfig}; use burn_nn::loss::Reduction; use super::alexnet::AlexFeatureExtractor; use super::squeezenet::SqueezeFeatureExtractor; use super::vgg::VggFeatureExtractor; /// Network type for LPIPS. #[derive(Config, Debug, Copy, PartialEq, Eq)] pub enum LpipsNet { /// VGG16 network (default) Vgg, /// AlexNet network Alex, /// SqueezeNet network Squeeze, } /// Configuration for [Lpips](Lpips) metric module. /// /// # Example /// /// ```ignore /// use burn_train::metric::vision::{LpipsConfig, LpipsNet}; /// /// // VGG (default) /// let lpips_vgg = LpipsConfig::new().init(&device); /// /// // AlexNet /// let lpips_alex = LpipsConfig::new() /// .with_net(LpipsNet::Alex) /// .init(&device); /// /// // SqueezeNet /// let lpips_squeeze = LpipsConfig::new() /// .with_net(LpipsNet::Squeeze) /// .init(&device); /// ``` #[derive(Config, Debug)] pub struct LpipsConfig { /// Network type for feature extraction. #[config(default = "LpipsNet::Vgg")] pub net: LpipsNet, /// Whether to normalize input images to [-1, 1] range. /// Set to true if input is in [0, 1] range. #[config(default = true)] pub normalize: bool, } impl LpipsConfig { /// Initialize a new [Lpips](Lpips) module with pretrained weights. /// /// Downloads and loads official LPIPS pretrained weights from the /// PerceptualSimilarity repository. /// /// # Arguments /// /// * `device` - Device to create the module on. /// /// # Returns /// /// A new LPIPS module with pretrained weights loaded. /// /// # Example /// /// ```ignore /// use burn_train::metric::vision::{LpipsConfig, LpipsNet}; /// /// let lpips = LpipsConfig::new() /// .with_net(LpipsNet::Vgg) /// .init_pretrained(&device); /// ``` pub fn init_pretrained(&self, device: &B::Device) -> Lpips { let lpips = self.init(device); super::weights::load_pretrained_weights(lpips, self.net) } /// Initialize a new [Lpips](Lpips) module with random weights. /// /// # Arguments /// /// * `device` - Device to create the module on. /// /// # Returns /// /// A new LPIPS module with random weights. Use `init_pretrained` for accurate results. pub fn init(&self, device: &B::Device) -> Lpips { match self.net { LpipsNet::Vgg => { // Channel sizes for VGG16: [64, 128, 256, 512, 512] Lpips::Vgg(LpipsVgg { extractor: VggFeatureExtractor::new(device), lin0: Conv2dConfig::new([64, 1], [1, 1]) .with_bias(false) .init(device), lin1: Conv2dConfig::new([128, 1], [1, 1]) .with_bias(false) .init(device), lin2: Conv2dConfig::new([256, 1], [1, 1]) .with_bias(false) .init(device), lin3: Conv2dConfig::new([512, 1], [1, 1]) .with_bias(false) .init(device), lin4: Conv2dConfig::new([512, 1], [1, 1]) .with_bias(false) .init(device), normalize: self.normalize, }) } LpipsNet::Alex => { // Channel sizes for AlexNet: [64, 192, 384, 256, 256] Lpips::Alex(LpipsAlex { extractor: AlexFeatureExtractor::new(device), lin0: Conv2dConfig::new([64, 1], [1, 1]) .with_bias(false) .init(device), lin1: Conv2dConfig::new([192, 1], [1, 1]) .with_bias(false) .init(device), lin2: Conv2dConfig::new([384, 1], [1, 1]) .with_bias(false) .init(device), lin3: Conv2dConfig::new([256, 1], [1, 1]) .with_bias(false) .init(device), lin4: Conv2dConfig::new([256, 1], [1, 1]) .with_bias(false) .init(device), normalize: self.normalize, }) } LpipsNet::Squeeze => { // Channel sizes for SqueezeNet: [64, 128, 256, 384, 384, 512, 512] Lpips::Squeeze(LpipsSqueeze { extractor: SqueezeFeatureExtractor::new(device), lin0: Conv2dConfig::new([64, 1], [1, 1]) .with_bias(false) .init(device), lin1: Conv2dConfig::new([128, 1], [1, 1]) .with_bias(false) .init(device), lin2: Conv2dConfig::new([256, 1], [1, 1]) .with_bias(false) .init(device), lin3: Conv2dConfig::new([384, 1], [1, 1]) .with_bias(false) .init(device), lin4: Conv2dConfig::new([384, 1], [1, 1]) .with_bias(false) .init(device), lin5: Conv2dConfig::new([512, 1], [1, 1]) .with_bias(false) .init(device), lin6: Conv2dConfig::new([512, 1], [1, 1]) .with_bias(false) .init(device), normalize: self.normalize, }) } } } } /// LPIPS (Learned Perceptual Image Patch Similarity) metric module. /// /// Computes perceptual distance between two images using deep features. /// Supports VGG16, AlexNet, and SqueezeNet as backbone networks. /// /// # Example /// /// ```ignore /// use burn_train::metric::vision::{LpipsConfig, LpipsNet, Reduction}; /// /// let device = Default::default(); /// let lpips = LpipsConfig::new().init(&device); /// /// let img1: Tensor = /* [batch, 3, H, W] */; /// let img2: Tensor = /* [batch, 3, H, W] */; /// /// // Compute LPIPS distance /// let distance = lpips.forward(img1, img2, Reduction::Mean); /// ``` #[derive(Module, Debug)] #[allow(clippy::large_enum_variant)] #[module(custom_display)] pub enum Lpips { /// VGG16 backbone (5 feature layers) Vgg(LpipsVgg), /// AlexNet backbone (5 feature layers) Alex(LpipsAlex), /// SqueezeNet backbone (7 feature layers) Squeeze(LpipsSqueeze), } /// LPIPS with VGG16 backbone. #[derive(Module, Debug)] pub struct LpipsVgg { /// VGG feature extractor pub(crate) extractor: VggFeatureExtractor, /// Linear layers for each feature level pub(crate) lin0: Conv2d, pub(crate) lin1: Conv2d, pub(crate) lin2: Conv2d, pub(crate) lin3: Conv2d, pub(crate) lin4: Conv2d, /// Whether to normalize input pub(crate) normalize: bool, } /// LPIPS with AlexNet backbone. #[derive(Module, Debug)] pub struct LpipsAlex { /// AlexNet feature extractor pub(crate) extractor: AlexFeatureExtractor, /// Linear layers for each feature level pub(crate) lin0: Conv2d, pub(crate) lin1: Conv2d, pub(crate) lin2: Conv2d, pub(crate) lin3: Conv2d, pub(crate) lin4: Conv2d, /// Whether to normalize input pub(crate) normalize: bool, } /// LPIPS with SqueezeNet backbone. #[derive(Module, Debug)] pub struct LpipsSqueeze { /// SqueezeNet feature extractor pub(crate) extractor: SqueezeFeatureExtractor, /// Linear layers for each feature level pub(crate) lin0: Conv2d, pub(crate) lin1: Conv2d, pub(crate) lin2: Conv2d, pub(crate) lin3: Conv2d, pub(crate) lin4: Conv2d, pub(crate) lin5: Conv2d, pub(crate) lin6: Conv2d, /// Whether to normalize input pub(crate) normalize: bool, } impl LpipsVgg { /// Compute LPIPS distance without reduction using VGG backbone. pub fn forward_no_reduction(&self, input: Tensor, target: Tensor) -> Tensor { // Preprocess inputs let (input, target) = preprocess_inputs(input, target, self.normalize); // Extract features from both images let feats0 = self.extractor.forward(input); let feats1 = self.extractor.forward(target); // Compute distance for each layer using stack + sum let layer_distances: Vec> = vec![ compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1), compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1), compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1), compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1), compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1), ]; Tensor::cat(layer_distances, 1) .sum_dim(1) .squeeze_dim::<1>(1) } } impl LpipsAlex { /// Compute LPIPS distance without reduction using AlexNet backbone. pub fn forward_no_reduction(&self, input: Tensor, target: Tensor) -> Tensor { // Preprocess inputs let (input, target) = preprocess_inputs(input, target, self.normalize); // Extract features from both images let feats0 = self.extractor.forward(input); let feats1 = self.extractor.forward(target); // Compute distance for each layer using stack + sum let layer_distances: Vec> = vec![ compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1), compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1), compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1), compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1), compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1), ]; Tensor::cat(layer_distances, 1) .sum_dim(1) .squeeze_dim::<1>(1) } } impl LpipsSqueeze { /// Compute LPIPS distance without reduction using SqueezeNet backbone. pub fn forward_no_reduction(&self, input: Tensor, target: Tensor) -> Tensor { // Preprocess inputs let (input, target) = preprocess_inputs(input, target, self.normalize); // Extract features from both images let feats0 = self.extractor.forward(input); let feats1 = self.extractor.forward(target); // Compute distance for each layer using stack + sum (7 layers for SqueezeNet) let layer_distances: Vec> = vec![ compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1), compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1), compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1), compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1), compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1), compute_layer_distance(&feats0[5], &feats1[5], &self.lin5).unsqueeze_dim(1), compute_layer_distance(&feats0[6], &feats1[6], &self.lin6).unsqueeze_dim(1), ]; Tensor::cat(layer_distances, 1) .sum_dim(1) .squeeze_dim::<1>(1) } } impl ModuleDisplay for Lpips { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) .optional() } fn custom_content(&self, content: Content) -> Option { let (net_name, normalize) = match self { Lpips::Vgg(inner) => ("Vgg", inner.normalize), Lpips::Alex(inner) => ("Alex", inner.normalize), Lpips::Squeeze(inner) => ("Squeeze", inner.normalize), }; content .add("net", &net_name.to_string()) .add("normalize", &normalize.to_string()) .optional() } } impl Lpips { /// Compute LPIPS distance with reduction. /// /// # Arguments /// /// * `input` - First image tensor of shape `[batch, 3, H, W]` /// * `target` - Second image tensor of shape `[batch, 3, H, W]` /// * `reduction` - How to reduce the output (Mean, Sum, or Auto) /// /// # Returns /// /// Scalar tensor of shape `[1]`. /// /// # Shapes /// /// - input: `[batch, 3, H, W]` /// - target: `[batch, 3, H, W]` /// - output: `[1]` pub fn forward( &self, input: Tensor, target: Tensor, reduction: Reduction, ) -> Tensor { let distance = self.forward_no_reduction(input, target); match reduction { Reduction::Mean | Reduction::Auto | Reduction::BatchMean => distance.mean(), Reduction::Sum => distance.sum(), } } /// Compute LPIPS distance without reduction. /// /// # Arguments /// /// * `input` - First image tensor of shape `[batch, 3, H, W]` /// * `target` - Second image tensor of shape `[batch, 3, H, W]` /// /// # Returns /// /// Per-sample distance tensor of shape `[batch]`. /// /// # Shapes /// /// - input: `[batch, 3, H, W]` /// - target: `[batch, 3, H, W]` /// - output: `[batch]` pub fn forward_no_reduction(&self, input: Tensor, target: Tensor) -> Tensor { match self { Lpips::Vgg(inner) => inner.forward_no_reduction(input, target), Lpips::Alex(inner) => inner.forward_no_reduction(input, target), Lpips::Squeeze(inner) => inner.forward_no_reduction(input, target), } } } // ============================================================================= // Helper Functions // ============================================================================= /// Normalize tensor to unit norm along channel dimension. fn normalize_tensor(x: Tensor) -> Tensor { let norm = x.clone().mul(x.clone()).sum_dim(1).sqrt().clamp_min(1e-10); x.div(norm) } /// Apply ImageNet normalization used by PyTorch lpips. /// shift = [-.030, -.088, -.188], scale = [.458, .448, .450] /// output = (input - shift) / scale fn scaling_layer(x: Tensor) -> Tensor { let device = x.device(); let [batch, _, h, w] = x.dims(); // Create shift and scale tensors [1, 3, 1, 1] and broadcast let shift = Tensor::::from_floats([[-0.030], [-0.088], [-0.188]], &device) .reshape([1, 3, 1, 1]) .expand([batch, 3, h, w]); let scale = Tensor::::from_floats([[0.458], [0.448], [0.450]], &device) .reshape([1, 3, 1, 1]) .expand([batch, 3, h, w]); x.sub(shift).div(scale) } /// Compute normalized L2 distance for a single layer. fn compute_layer_distance( feat0: &Tensor, feat1: &Tensor, lin: &Conv2d, ) -> Tensor { // Normalize features (unit norm along channel dimension) let feat0_norm = normalize_tensor(feat0.clone()); let feat1_norm = normalize_tensor(feat1.clone()); // Compute squared difference let diff = feat0_norm.sub(feat1_norm); let diff_sq = diff.clone().mul(diff); // Apply linear layer (learned weights) // Shape: [batch, C, H, W] -> [batch, 1, H, W] let weighted = lin.forward(diff_sq); // Spatial average: compute mean over C, H, W dimensions // Shape: [batch, 1, H, W] -> [batch] let [batch, c, h, w] = weighted.dims(); // Reshape to [batch, c*h*w] then take mean over last dimension weighted .reshape([batch, c * h * w]) .mean_dim(1) .squeeze_dim::<1>(1) } /// Preprocess input images for LPIPS computation. fn preprocess_inputs( input: Tensor, target: Tensor, normalize: bool, ) -> (Tensor, Tensor) { // Normalize to [-1, 1] if needed let (input, target) = if normalize { ( input.mul_scalar(2.0).sub_scalar(1.0), target.mul_scalar(2.0).sub_scalar(1.0), ) } else { (input, target) }; // Apply ImageNet normalization (same as PyTorch lpips scaling_layer) (scaling_layer(input), scaling_layer(target)) } // ============================================================================= // Tests // ============================================================================= #[cfg(test)] mod tests { use super::*; use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem}; use burn_ndarray::NdArray; type TestBackend = NdArray; type FT = FloatElem; type TestTensor = Tensor; // ========================================================================= // Basic Functionality Tests // ========================================================================= /// Identical images should have LPIPS distance of 0. #[test] fn test_lpips_identical_images_zero_distance() { let device = Default::default(); let image = TestTensor::<4>::ones([1, 3, 32, 32], &device); let lpips: Lpips = LpipsConfig::new().init(&device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); // Identical images → distance = 0 let expected = TensorData::from([0.0]); distance .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Different images should have LPIPS distance != 0. /// Note: With random weights, distance can be negative, so we only check != 0. /// Non-negativity is tested with pretrained weights. #[test] fn test_lpips_different_images_nonzero_distance() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device); let lpips: Lpips = LpipsConfig::new().init(&device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() > 1e-6, "LPIPS should be != 0 for different images" ); } /// Test symmetry: LPIPS(a, b) == LPIPS(b, a). #[test] fn test_lpips_symmetry() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device); let lpips: Lpips = LpipsConfig::new().init(&device); let distance_forward = lpips.forward(image1.clone(), image2.clone(), Reduction::Mean); let distance_reverse = lpips.forward(image2, image1, Reduction::Mean); distance_forward .into_data() .assert_approx_eq::(&distance_reverse.into_data(), Tolerance::default()); } // ========================================================================= // Reduction Tests // ========================================================================= #[test] fn test_lpips_forward_mean_reduction() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device); let lpips: Lpips = LpipsConfig::new().init(&device); let distance = lpips.forward(image1, image2, Reduction::Mean); assert_eq!(distance.dims(), [1]); } #[test] fn test_lpips_forward_no_reduction() { let device = Default::default(); let batch_size = 4; let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device); let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device); let lpips: Lpips = LpipsConfig::new().init(&device); let distance = lpips.forward_no_reduction(image1, image2); assert_eq!(distance.dims(), [batch_size]); } // ========================================================================= // AlexNet Tests // ========================================================================= /// Test AlexNet LPIPS with identical images. #[test] fn test_lpips_alex_identical_images_zero_distance() { let device = Default::default(); let image = TestTensor::<4>::ones([1, 3, 64, 64], &device); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); let expected = TensorData::from([0.0]); distance .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Test AlexNet LPIPS with different images produces non-zero distance. #[test] fn test_lpips_alex_different_images_nonzero_distance() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; // Note: With random weights, non-negativity is not guaranteed. // We only check that different images produce a non-zero distance. assert!( distance_value.abs() > 1e-6, "LPIPS (Alex) should be != 0 for different images" ); } // ========================================================================= // SqueezeNet Tests // ========================================================================= /// Test SqueezeNet LPIPS with identical images. #[test] fn test_lpips_squeeze_identical_images_zero_distance() { let device = Default::default(); let image = TestTensor::<4>::ones([1, 3, 64, 64], &device); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); let expected = TensorData::from([0.0]); distance .into_data() .assert_approx_eq::(&expected, Tolerance::default()); } /// Test SqueezeNet LPIPS with different images produces non-zero distance. #[test] fn test_lpips_squeeze_different_images_nonzero_distance() { let device = Default::default(); let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; // Note: With random weights, non-negativity is not guaranteed. // We only check that different images produce a non-zero distance. assert!( distance_value.abs() > 1e-6, "LPIPS (Squeeze) should be != 0 for different images" ); } // ========================================================================= // Display Tests // ========================================================================= #[test] fn display_vgg() { let device = Default::default(); let lpips: Lpips = LpipsConfig::new().init(&device); let display_str = format!("{lpips}"); assert!(display_str.contains("Lpips")); assert!(display_str.contains("Vgg")); } #[test] fn display_alex() { let device = Default::default(); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device); let display_str = format!("{lpips}"); assert!(display_str.contains("Lpips")); assert!(display_str.contains("Alex")); } #[test] fn display_squeeze() { let device = Default::default(); let lpips: Lpips = LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device); let display_str = format!("{lpips}"); assert!(display_str.contains("Lpips")); assert!(display_str.contains("Squeeze")); } // ========================================================================= // Pretrained Weights Tests (requires network) // ========================================================================= /// Test VGG pretrained weights download and loading. #[test] fn test_lpips_pretrained_vgg() { let device = Default::default(); // This will download ~60MB of weights let lpips: Lpips = LpipsConfig::new() .with_net(LpipsNet::Vgg) .init_pretrained(&device); // Test with identical images - should be 0 let image = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() < 1e-5, "Pretrained LPIPS (VGG) should be ~0 for identical images, got {}", distance_value ); // Test with different images - should be positive let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value > 0.0, "Pretrained LPIPS (VGG) should be > 0 for different images, got {}", distance_value ); } /// Test AlexNet pretrained weights download and loading. #[test] fn test_lpips_pretrained_alex() { let device = Default::default(); let lpips: Lpips = LpipsConfig::new() .with_net(LpipsNet::Alex) .init_pretrained(&device); // Test with identical images let image = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() < 1e-5, "Pretrained LPIPS (Alex) should be ~0 for identical images, got {}", distance_value ); // Test with different images let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value > 0.0, "Pretrained LPIPS (Alex) should be > 0 for different images" ); } /// Test SqueezeNet pretrained weights download and loading. #[test] fn test_lpips_pretrained_squeeze() { let device = Default::default(); let lpips: Lpips = LpipsConfig::new() .with_net(LpipsNet::Squeeze) .init_pretrained(&device); // Test with identical images let image = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image.clone(), image, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value.abs() < 1e-5, "Pretrained LPIPS (Squeeze) should be ~0 for identical images, got {}", distance_value ); // Test with different images let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device); let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device); let distance = lpips.forward(image1, image2, Reduction::Mean); let distance_value = distance.into_data().to_vec::().unwrap()[0]; assert!( distance_value > 0.0, "Pretrained LPIPS (Squeeze) should be > 0 for different images, got {}", distance_value ); } } ================================================ FILE: crates/burn-train/src/metric/vision/lpips/mod.rs ================================================ //! LPIPS (Learned Perceptual Image Patch Similarity) metric module. //! //! LPIPS measures perceptual similarity between images using deep features. //! Supports VGG16, AlexNet, and SqueezeNet as backbone networks. //! //! Reference: "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric" //! mod alexnet; mod metric; mod squeezenet; mod vgg; mod weights; pub use metric::{Lpips, LpipsAlex, LpipsConfig, LpipsNet, LpipsSqueeze, LpipsVgg}; pub use weights::{get_backbone_weights_url, get_lpips_weights_url, load_pretrained_weights}; // Re-export feature extractors for advanced use cases pub use alexnet::AlexFeatureExtractor; pub use squeezenet::{FireModule, SqueezeFeatureExtractor}; pub use vgg::VggFeatureExtractor; ================================================ FILE: crates/burn-train/src/metric/vision/lpips/squeezenet.rs ================================================ //! SqueezeNet feature extractor for LPIPS. use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::activation::relu; use burn::tensor::backend::Backend; use burn_nn::PaddingConfig2d; use burn_nn::conv::{Conv2d, Conv2dConfig}; /// Fire module for SqueezeNet. /// /// A fire module consists of: /// - Squeeze layer: 1x1 conv to reduce channels /// - Expand layers: parallel 1x1 and 3x3 convs, concatenated #[derive(Module, Debug)] pub struct FireModule { /// Squeeze layer: 1x1 conv squeeze: Conv2d, /// Expand 1x1 conv expand1x1: Conv2d, /// Expand 3x3 conv expand3x3: Conv2d, } impl FireModule { /// Create a new Fire module. pub fn new( in_channels: usize, squeeze_channels: usize, expand1x1_channels: usize, expand3x3_channels: usize, device: &B::Device, ) -> Self { Self { squeeze: Conv2dConfig::new([in_channels, squeeze_channels], [1, 1]) .with_bias(true) .init(device), expand1x1: Conv2dConfig::new([squeeze_channels, expand1x1_channels], [1, 1]) .with_bias(true) .init(device), expand3x3: Conv2dConfig::new([squeeze_channels, expand3x3_channels], [3, 3]) .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1)) .with_bias(true) .init(device), } } /// Forward pass through fire module. pub fn forward(&self, x: Tensor) -> Tensor { let squeezed = relu(self.squeeze.forward(x)); let e1 = relu(self.expand1x1.forward(squeezed.clone())); let e3 = relu(self.expand3x3.forward(squeezed)); // Concatenate along channel dimension Tensor::cat(vec![e1, e3], 1) } } /// SqueezeNet 1.1 feature extractor for LPIPS. /// /// Extracts features from 7 layers: /// - After conv1+relu: 64 channels /// - After fire1+fire2: 128 channels /// - After fire3+fire4: 256 channels /// - After fire5: 384 channels /// - After fire6: 384 channels /// - After fire7: 512 channels /// - After fire8: 512 channels #[derive(Module, Debug)] pub struct SqueezeFeatureExtractor { /// Conv1: 3 -> 64, kernel 3x3, stride 2 conv1: Conv2d, /// Fire1: 64 -> 128 (squeeze=16, expand=64+64) fire1: FireModule, /// Fire2: 128 -> 128 (squeeze=16, expand=64+64) fire2: FireModule, /// Fire3: 128 -> 256 (squeeze=32, expand=128+128) fire3: FireModule, /// Fire4: 256 -> 256 (squeeze=32, expand=128+128) fire4: FireModule, /// Fire5: 256 -> 384 (squeeze=48, expand=192+192) fire5: FireModule, /// Fire6: 384 -> 384 (squeeze=48, expand=192+192) fire6: FireModule, /// Fire7: 384 -> 512 (squeeze=64, expand=256+256) fire7: FireModule, /// Fire8: 512 -> 512 (squeeze=64, expand=256+256) fire8: FireModule, } impl SqueezeFeatureExtractor { /// Create a new SqueezeNet feature extractor. pub fn new(device: &B::Device) -> Self { Self { // Conv1: 3 -> 64, 3x3, stride 2 conv1: Conv2dConfig::new([3, 64], [3, 3]) .with_stride([2, 2]) .with_bias(true) .init(device), // Fire modules (SqueezeNet 1.1 configuration) fire1: FireModule::new(64, 16, 64, 64, device), // -> 128 fire2: FireModule::new(128, 16, 64, 64, device), // -> 128 fire3: FireModule::new(128, 32, 128, 128, device), // -> 256 fire4: FireModule::new(256, 32, 128, 128, device), // -> 256 fire5: FireModule::new(256, 48, 192, 192, device), // -> 384 fire6: FireModule::new(384, 48, 192, 192, device), // -> 384 fire7: FireModule::new(384, 64, 256, 256, device), // -> 512 fire8: FireModule::new(512, 64, 256, 256, device), // -> 512 } } /// Extract features from 7 SqueezeNet layers. pub fn forward(&self, x: Tensor) -> Vec> { let mut features = Vec::with_capacity(7); // Slice 1: Conv1 + ReLU (64 channels) let x = relu(self.conv1.forward(x)); features.push(x.clone()); // Slice 2: MaxPool + Fire1 + Fire2 (128 channels) let x = max_pool2d_squeeze(x); let x = self.fire1.forward(x); let x = self.fire2.forward(x); features.push(x.clone()); // Slice 3: MaxPool + Fire3 + Fire4 (256 channels) let x = max_pool2d_squeeze(x); let x = self.fire3.forward(x); let x = self.fire4.forward(x); features.push(x.clone()); // Slice 4: MaxPool + Fire5 (384 channels) let x = max_pool2d_squeeze(x); let x = self.fire5.forward(x); features.push(x.clone()); // Slice 5: Fire6 (384 channels) let x = self.fire6.forward(x); features.push(x.clone()); // Slice 6: Fire7 (512 channels) let x = self.fire7.forward(x); features.push(x.clone()); // Slice 7: Fire8 (512 channels) let x = self.fire8.forward(x); features.push(x); features } } /// 3x3 max pooling with stride 2, ceil mode (for SqueezeNet). fn max_pool2d_squeeze(x: Tensor) -> Tensor { burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], true) } ================================================ FILE: crates/burn-train/src/metric/vision/lpips/vgg.rs ================================================ //! VGG16 feature extractor for LPIPS. use burn_core as burn; use burn::module::Module; use burn::tensor::Tensor; use burn::tensor::activation::relu; use burn::tensor::backend::Backend; use burn_nn::PaddingConfig2d; use burn_nn::conv::{Conv2d, Conv2dConfig}; /// VGG16 feature extractor for LPIPS. /// /// Extracts features from 5 layers: /// - conv1_2: 64 channels /// - conv2_2: 128 channels /// - conv3_3: 256 channels /// - conv4_3: 512 channels /// - conv5_3: 512 channels #[derive(Module, Debug)] pub struct VggFeatureExtractor { // Block 1 conv1_1: Conv2d, conv1_2: Conv2d, // Block 2 conv2_1: Conv2d, conv2_2: Conv2d, // Block 3 conv3_1: Conv2d, conv3_2: Conv2d, conv3_3: Conv2d, // Block 4 conv4_1: Conv2d, conv4_2: Conv2d, conv4_3: Conv2d, // Block 5 conv5_1: Conv2d, conv5_2: Conv2d, conv5_3: Conv2d, } impl VggFeatureExtractor { /// Create a new VGG16 feature extractor. pub fn new(device: &B::Device) -> Self { let conv_config = |in_ch, out_ch| { Conv2dConfig::new([in_ch, out_ch], [3, 3]) .with_padding(PaddingConfig2d::Same) .with_bias(true) }; Self { // Block 1: 3 -> 64 conv1_1: conv_config(3, 64).init(device), conv1_2: conv_config(64, 64).init(device), // Block 2: 64 -> 128 conv2_1: conv_config(64, 128).init(device), conv2_2: conv_config(128, 128).init(device), // Block 3: 128 -> 256 conv3_1: conv_config(128, 256).init(device), conv3_2: conv_config(256, 256).init(device), conv3_3: conv_config(256, 256).init(device), // Block 4: 256 -> 512 conv4_1: conv_config(256, 512).init(device), conv4_2: conv_config(512, 512).init(device), conv4_3: conv_config(512, 512).init(device), // Block 5: 512 -> 512 conv5_1: conv_config(512, 512).init(device), conv5_2: conv_config(512, 512).init(device), conv5_3: conv_config(512, 512).init(device), } } /// Extract features from 5 VGG layers. pub fn forward(&self, x: Tensor) -> Vec> { let mut features = Vec::with_capacity(5); // Block 1 let x = relu(self.conv1_1.forward(x)); let x = relu(self.conv1_2.forward(x)); features.push(x.clone()); let x = max_pool2d(x); // Block 2 let x = relu(self.conv2_1.forward(x)); let x = relu(self.conv2_2.forward(x)); features.push(x.clone()); let x = max_pool2d(x); // Block 3 let x = relu(self.conv3_1.forward(x)); let x = relu(self.conv3_2.forward(x)); let x = relu(self.conv3_3.forward(x)); features.push(x.clone()); let x = max_pool2d(x); // Block 4 let x = relu(self.conv4_1.forward(x)); let x = relu(self.conv4_2.forward(x)); let x = relu(self.conv4_3.forward(x)); features.push(x.clone()); let x = max_pool2d(x); // Block 5 let x = relu(self.conv5_1.forward(x)); let x = relu(self.conv5_2.forward(x)); let x = relu(self.conv5_3.forward(x)); features.push(x); features } } /// 2x2 max pooling with stride 2. fn max_pool2d(x: Tensor) -> Tensor { burn_core::tensor::module::max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], false) } ================================================ FILE: crates/burn-train/src/metric/vision/lpips/weights.rs ================================================ //! Pretrained weights loading for LPIPS. use burn_core as burn; use burn::tensor::backend::Backend; use burn_std::network::downloader::download_file_as_bytes; use burn_store::{ModuleSnapshot, PytorchStore}; use std::fs::{File, create_dir_all}; use std::io::Write; use std::path::PathBuf; use super::metric::{Lpips, LpipsNet}; /// URLs for pretrained LPIPS linear layer weights from the official repository. /// Reference: https://github.com/richzhang/PerceptualSimilarity const LPIPS_VGG_URL: &str = "https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/vgg.pth"; const LPIPS_ALEX_URL: &str = "https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/alex.pth"; const LPIPS_SQUEEZE_URL: &str = "https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/squeeze.pth"; /// URLs for ImageNet pretrained backbone weights from PyTorch. const VGG16_IMAGENET_URL: &str = "https://download.pytorch.org/models/vgg16-397923af.pth"; const ALEXNET_IMAGENET_URL: &str = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"; const SQUEEZENET_IMAGENET_URL: &str = "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth"; /// Get the download URL for LPIPS linear layer weights. pub fn get_lpips_weights_url(net: LpipsNet) -> &'static str { match net { LpipsNet::Vgg => LPIPS_VGG_URL, LpipsNet::Alex => LPIPS_ALEX_URL, LpipsNet::Squeeze => LPIPS_SQUEEZE_URL, } } /// Get the download URL for backbone ImageNet weights. pub fn get_backbone_weights_url(net: LpipsNet) -> &'static str { match net { LpipsNet::Vgg => VGG16_IMAGENET_URL, LpipsNet::Alex => ALEXNET_IMAGENET_URL, LpipsNet::Squeeze => SQUEEZENET_IMAGENET_URL, } } /// Get the cache directory for LPIPS weights. fn get_cache_dir() -> PathBuf { let cache_dir = dirs::cache_dir() .expect("Could not get cache directory") .join("burn-dataset") .join("lpips"); if !cache_dir.exists() { create_dir_all(&cache_dir).expect("Failed to create cache directory"); } cache_dir } /// Download file if not cached and return the cache path. fn download_if_needed(url: &str, cache_path: &PathBuf, message: &str) { if !cache_path.exists() { let bytes = download_file_as_bytes(url, message); let mut file = File::create(cache_path).expect("Failed to create cache file"); file.write_all(&bytes).expect("Failed to write weights"); } } /// Download and load pretrained weights into an LPIPS module. /// /// This loads both: /// 1. ImageNet pretrained backbone weights (VGG16/AlexNet/SqueezeNet) /// 2. LPIPS trained linear layer weights /// /// Weights are cached in the user's cache directory to avoid re-downloading. /// /// # Arguments /// /// * `lpips` - The LPIPS module to load weights into. /// * `net` - The network type (determines which weights to download). /// /// # Returns /// /// The LPIPS module with loaded pretrained weights. pub fn load_pretrained_weights(mut lpips: Lpips, net: LpipsNet) -> Lpips { let cache_dir = get_cache_dir(); // Step 1: Load backbone ImageNet weights let backbone_url = get_backbone_weights_url(net); let backbone_cache_path = cache_dir.join(format!("{:?}_backbone.pth", net).to_lowercase()); let backbone_message = match net { LpipsNet::Vgg => "Downloading VGG16 ImageNet weights...", LpipsNet::Alex => "Downloading AlexNet ImageNet weights...", LpipsNet::Squeeze => "Downloading SqueezeNet ImageNet weights...", }; download_if_needed(backbone_url, &backbone_cache_path, backbone_message); // Step 2: Load LPIPS linear layer weights let lpips_url = get_lpips_weights_url(net); let lpips_cache_path = cache_dir.join(format!("{:?}_lpips.pth", net).to_lowercase()); let lpips_message = match net { LpipsNet::Vgg => "Downloading LPIPS VGG weights...", LpipsNet::Alex => "Downloading LPIPS AlexNet weights...", LpipsNet::Squeeze => "Downloading LPIPS SqueezeNet weights...", }; download_if_needed(lpips_url, &lpips_cache_path, lpips_message); // Load backbone weights first lpips = load_backbone_weights(lpips, &backbone_cache_path); // Then load LPIPS linear layer weights lpips = load_lpips_weights(lpips, &lpips_cache_path); lpips } /// Load ImageNet pretrained backbone weights. fn load_backbone_weights(lpips: Lpips, cache_path: &PathBuf) -> Lpips { // Load directly into the inner struct to avoid enum variant issues match lpips { Lpips::Vgg(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) // VGG16 features.X -> extractor.convY_Z .with_key_remapping(r"^features\.0\.", "extractor.conv1_1.") .with_key_remapping(r"^features\.2\.", "extractor.conv1_2.") .with_key_remapping(r"^features\.5\.", "extractor.conv2_1.") .with_key_remapping(r"^features\.7\.", "extractor.conv2_2.") .with_key_remapping(r"^features\.10\.", "extractor.conv3_1.") .with_key_remapping(r"^features\.12\.", "extractor.conv3_2.") .with_key_remapping(r"^features\.14\.", "extractor.conv3_3.") .with_key_remapping(r"^features\.17\.", "extractor.conv4_1.") .with_key_remapping(r"^features\.19\.", "extractor.conv4_2.") .with_key_remapping(r"^features\.21\.", "extractor.conv4_3.") .with_key_remapping(r"^features\.24\.", "extractor.conv5_1.") .with_key_remapping(r"^features\.26\.", "extractor.conv5_2.") .with_key_remapping(r"^features\.28\.", "extractor.conv5_3."); if let Err(e) = inner.load_from(&mut store) { log::warn!("Some VGG backbone weights could not be loaded: {:?}", e); } Lpips::Vgg(inner) } Lpips::Alex(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) // AlexNet features.X -> extractor.convY .with_key_remapping(r"^features\.0\.", "extractor.conv1.") .with_key_remapping(r"^features\.3\.", "extractor.conv2.") .with_key_remapping(r"^features\.6\.", "extractor.conv3.") .with_key_remapping(r"^features\.8\.", "extractor.conv4.") .with_key_remapping(r"^features\.10\.", "extractor.conv5."); if let Err(e) = inner.load_from(&mut store) { log::warn!("Some AlexNet backbone weights could not be loaded: {:?}", e); } Lpips::Alex(inner) } Lpips::Squeeze(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) // SqueezeNet features.X -> extractor.* .with_key_remapping(r"^features\.0\.", "extractor.conv1.") .with_key_remapping(r"^features\.3\.", "extractor.fire1.") .with_key_remapping(r"^features\.4\.", "extractor.fire2.") .with_key_remapping(r"^features\.6\.", "extractor.fire3.") .with_key_remapping(r"^features\.7\.", "extractor.fire4.") .with_key_remapping(r"^features\.9\.", "extractor.fire5.") .with_key_remapping(r"^features\.10\.", "extractor.fire6.") .with_key_remapping(r"^features\.11\.", "extractor.fire7.") .with_key_remapping(r"^features\.12\.", "extractor.fire8."); if let Err(e) = inner.load_from(&mut store) { log::warn!( "Some SqueezeNet backbone weights could not be loaded: {:?}", e ); } Lpips::Squeeze(inner) } } } /// Load LPIPS trained linear layer weights. fn load_lpips_weights(lpips: Lpips, cache_path: &PathBuf) -> Lpips { // Load directly into the inner struct to avoid enum variant issues match lpips { Lpips::Vgg(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) .with_key_remapping(r"^lin0\.model\.1\.", "lin0.") .with_key_remapping(r"^lin1\.model\.1\.", "lin1.") .with_key_remapping(r"^lin2\.model\.1\.", "lin2.") .with_key_remapping(r"^lin3\.model\.1\.", "lin3.") .with_key_remapping(r"^lin4\.model\.1\.", "lin4."); if let Err(e) = inner.load_from(&mut store) { log::warn!("Some VGG LPIPS weights could not be loaded: {:?}", e); } Lpips::Vgg(inner) } Lpips::Alex(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) .with_key_remapping(r"^lin0\.model\.1\.", "lin0.") .with_key_remapping(r"^lin1\.model\.1\.", "lin1.") .with_key_remapping(r"^lin2\.model\.1\.", "lin2.") .with_key_remapping(r"^lin3\.model\.1\.", "lin3.") .with_key_remapping(r"^lin4\.model\.1\.", "lin4."); if let Err(e) = inner.load_from(&mut store) { log::warn!("Some AlexNet LPIPS weights could not be loaded: {:?}", e); } Lpips::Alex(inner) } Lpips::Squeeze(mut inner) => { let mut store = PytorchStore::from_file(cache_path) .allow_partial(true) .with_key_remapping(r"^lin0\.model\.1\.", "lin0.") .with_key_remapping(r"^lin1\.model\.1\.", "lin1.") .with_key_remapping(r"^lin2\.model\.1\.", "lin2.") .with_key_remapping(r"^lin3\.model\.1\.", "lin3.") .with_key_remapping(r"^lin4\.model\.1\.", "lin4.") .with_key_remapping(r"^lin5\.model\.1\.", "lin5.") .with_key_remapping(r"^lin6\.model\.1\.", "lin6."); if let Err(e) = inner.load_from(&mut store) { log::warn!("Some SqueezeNet LPIPS weights could not be loaded: {:?}", e); } Lpips::Squeeze(inner) } } } ================================================ FILE: crates/burn-train/src/metric/vision/mod.rs ================================================ mod dice; mod dists; mod lpips; mod ms_ssim; mod psnr; mod ssim; pub use dice::*; pub use dists::*; pub use lpips::*; pub use ms_ssim::*; pub use psnr::*; pub use ssim::*; ================================================ FILE: crates/burn-train/src/metric/vision/ms_ssim.rs ================================================ use crate::metric::{ Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry, SerializedEntry, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Int, Tensor}, tensor::{ ElementConversion, module::{avg_pool2d, conv2d}, ops::{ConvOptions, PadMode}, }, }; use core::marker::PhantomData; /// Input type for the [MsSsimMetric]. /// /// Both tensors must have shape `[N, C, H, W]`: /// - `N`: Batch size /// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.) /// - `H`: Height /// - `W`: Width /// /// # Important /// The image dimensions must be sufficiently large to accommodate the multi-scale /// computation. Specifically, for the default 5 scales used by Burn, the image dimensions /// should be at least `kernel_size * 2^(scales-1)` (e.g., 11 × 2^4 = 11 * 16 = 176 for default kernel size). /// If your images are smaller, reduce the kernel size or number of scales. /// /// # Example /// ```rust,ignore /// // Create input for RGB images /// let outputs: Tensor = /* tensor */; /// let targets: Tensor = /* tensor */; /// let input = MsSsimInput::new(outputs, targets); /// ``` pub struct MsSsimInput { /// Model outputs with shape [N, C, H, W]. outputs: Tensor, /// Ground truth targets with shape [N, C, H, W]. targets: Tensor, } impl MsSsimInput { /// Creates a new MsSsimInput with the given outputs and targets. /// /// # Arguments /// - `outputs`: The model output images with shape [N, C, H, W]. /// - `targets`: The ground truth images with shape [N, C, H, W]. /// /// # Returns /// A new instance of `MsSsimInput`. /// /// # Panics /// - If `outputs` and `targets` do not have the same shape. pub fn new(outputs: Tensor, targets: Tensor) -> Self { assert!( outputs.dims() == targets.dims(), "Shape mismatch: outputs {:?} targets {:?}", outputs.dims(), targets.dims() ); Self { outputs, targets } } } /// Configuration for the [MsSsimMetric]. #[derive(Debug, Clone)] pub struct MsSsimMetricConfig { /// A parameter of SSIM used to stabilize the luminance comparison. /// Default is 0.01. pub k1: f32, /// A parameter of SSIM used to stabilize the contrast comparison. /// Default is 0.03. pub k2: f32, /// The range of the pixel values in images which can be computed as following: /// `let pixel_range = max_pixel_val - min_pixel_val;` /// where `max_pixel_val` is the maximum possible pixel value and `min_pixel_val` /// is the minimum possible pixel value. /// /// - For normalized images in range [0, 1], it should be set to `1.0 - 0.0 = 1.0` /// - For normalized images in range [-1, 1], it should be set to `1.0 - (-1.0) = 2.0` /// - For 8-bit images in range [0, 255], it should be set to `255.0 - 0.0 = 255.0` pub pixel_range: f32, /// The MS-SSIM metric involves applying convolution to the input tensors using a Gaussian kernel. /// This is the kernel size of the Gaussian kernel. Default is 11. pub kernel_size: usize, /// The MS-SSIM metric involves applying convolution to the input tensors using a Gaussian kernel. /// This is the standard deviation of the Gaussian kernel. Default is 1.5. pub sigma: f32, /// The number of channels in the input images (e.g., 1 for grayscale, 3 for RGB). /// This is used to create the appropriate convolution kernels. Default is 3. pub channels: usize, /// The weights/betas for each scale in the MS-SSIM computation. /// The length of this vector determines the number of scales. /// Default is \[0.0448, 0.2856, 0.3001, 0.2363, 0.1333\] (5 scales). pub betas: Vec, } impl MsSsimMetricConfig { /// Creates a configuration with the specified data range and default parameters. /// /// # Default parameters /// - k1: 0.01 /// - k2: 0.03 /// - kernel_size: 11 /// - sigma: 1.5 /// - channels: 3 /// /// # Panics /// - If `pixel_range` is not positive. /// /// # Example /// ```rust,ignore /// // For normalized RGB images [0, 1] /// let config1 = MsSsimMetricConfig::new(1.0); /// /// // For 8-bit images [0, 255] /// let config2 = MsSsimMetricConfig::new(255.0); /// /// // For grayscale with custom kernel /// let config3 = MsSsimMetricConfig::new(1.0) /// .with_channels(1) /// .with_kernel_size(7); /// ``` pub fn new(pixel_range: f32) -> Self { assert!(pixel_range > 0.0, "pixel_range must be positive"); Self { k1: 0.01, k2: 0.03, pixel_range, kernel_size: 11, sigma: 1.5, channels: 3, betas: vec![0.0448, 0.2856, 0.3001, 0.2363, 0.1333], } } /// Sets custom values for the k1 and k2 parameters of MS-SSIM which are /// used for numerical stability. /// /// # Default values /// - k1: 0.01 /// - k2: 0.03 /// /// # Panics /// - If `k1` or `k2` is not positive. pub fn with_k1_k2(mut self, k1: f32, k2: f32) -> Self { assert!(k1 > 0.0, "k1 must be positive"); assert!(k2 > 0.0, "k2 must be positive"); self.k1 = k1; self.k2 = k2; self } /// Sets a custom kernel size for the Gaussian kernel used in MS-SSIM. The /// kernel size must be a positive odd number. /// /// # Default value /// - kernel_size: 11 /// /// # Panics /// - If `kernel_size` is not a positive odd number. pub fn with_kernel_size(mut self, kernel_size: usize) -> Self { assert!( kernel_size > 0 && kernel_size % 2 == 1, "kernel_size must be positive and an odd number" ); self.kernel_size = kernel_size; self } /// Sets a custom sigma (standard deviation) for the Gaussian kernel used in MS-SSIM. /// /// # Default value /// - sigma: 1.5 /// /// # Panics /// - If `sigma` is not positive. pub fn with_sigma(mut self, sigma: f32) -> Self { assert!(sigma > 0.0, "sigma must be a positive number"); self.sigma = sigma; self } /// Sets the number of channels for the input images. /// /// This affects the shape of the pre-computed convolution kernels. /// Change this if working with grayscale (1) or multispectral images (>3). /// /// # Default value /// - channels: 3 /// /// # Panics /// - If `channels` is 0. pub fn with_channels(mut self, channels: usize) -> Self { assert!(channels > 0, "channels must be a positive number"); self.channels = channels; self } /// Sets custom betas for the scales. The length of the betas vector /// determines the number of scales used in the MS-SSIM computation. /// If you want to make different parameter settings comparable, the betas /// vector should sum to 1 as per the original paper. However, note /// that this is not a strict requirement. /// /// # Default value /// - betas: `[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]` (5 scales) /// /// # Panics /// - If `betas` is empty. /// - If not all values in `betas` are positive. pub fn with_betas(mut self, betas: Vec) -> Self { assert!(!betas.is_empty(), "betas vector cannot be empty"); assert!( betas.iter().all(|&b| b >= 0.0), "All beta values must be non-negative" ); self.betas = betas; self } } /// Multi-Scale Structural Similarity Index (MS-SSIM) metric for image quality assessment. /// /// MS-SSIM extends the single-scale [SSIM](crate::metric::vision::SsimMetric) by computing /// the index at multiple resolutions (scales) and combining them using weighted averaging. /// This approach better correlates with human visual perception, especially for /// high-resolution images where fine details and texture variations are important. /// /// # Algorithm Overview /// /// MS-SSIM computes structural similarity across M scales (M=5 in Burn): /// /// 1. **Contrast** and **Structure** components are computed at every scale /// 2. **Luminance** is computed only at the coarsest (last) scale /// 3. Between scales, images are downsampled by a factor of 2 using average pooling /// /// The final metric is computed as: /// ```text /// MS-SSIM = L_M^{α_M} × ∏_{j=1}^M (C_j^{β_j} × S_j^{γ_j}) /// ``` /// /// Where: /// - `L_M` is luminance at the last scale (M) /// - `C_j` is contrast at scale j: `(2σ_xσ_y + C2) / (σ_x² + σ_y² + C2)` /// - `S_j` is structure at scale j: `(σ_xy + C3) / (σ_xσ_y + C3)` /// - `α_M, β_j, γ_j` are weights from Wang et al. (\[0.0448, 0.2856, 0.3001, 0.2363, 0.1333\]) /// /// # Notes /// /// - This implementation uses separable Gaussian convolution for efficiency (reduces complexity from O(K^2) to O(2K) per pixel) /// - Gaussian kernels are pre-computed during initialization to avoid redundant computation /// - The metric requires images to be large enough to survive the downsampling operations /// /// # Value Range /// /// MS-SSIM values typically range from 0 to 1, where: /// - 1.0 indicates perfect structural similarity (identical images) /// - 0.0 indicates no structural similarity /// - Values are usually positive due to the stability constants (C1, C2, C3) /// /// # References /// /// [Multi-scale Structural Similarity for Image Quality Assessment](https://www.cns.nyu.edu/pub/eero/wang03b.pdf) #[derive(Clone)] pub struct MsSsimMetric { name: MetricName, /// Internal state for numeric metric aggregation. state: NumericMetricState, /// Marker for backend type. _b: PhantomData, /// Configuration for the metric. config: MsSsimMetricConfig, /// Pre-computed horizontal Gaussian kernel with shape [C, 1, 1, K] horizontal_kernel: Tensor, /// Pre-computed vertical Gaussian kernel with shape [C, 1, K, 1] vertical_kernel: Tensor, } impl MsSsimMetric { /// Creates a new MS-SSIM metric with the given configuration. /// /// # Arguments /// - `config`: Configuration for the metric (data range, kernel size, etc.) /// - `device`: Device to place the Gaussian kernels on /// /// # Note /// The default metric name format is "MS-SSIM (pr={}, k={}, σ={})" /// where pr is the pixel range, k is the kernel size, and σ is the /// standard deviation. /// /// # Example /// ```ignore /// let config = MsSsimMetricConfig::new(1.0).with_channels(1); // Grayscale /// let metric = MsSsimMetric::::new(config, &device); /// ``` pub fn new(config: MsSsimMetricConfig, device: &B::Device) -> Self { let kernel = Self::create_1d_gaussian_kernel(&config, device); let size = config.kernel_size; // Create horizontal kernel: shape [C, 1, 1, K] for depthwise conv let horizontal_kernel = kernel .clone() .reshape([1, 1, 1, size]) .repeat_dim(0, config.channels); // Create vertical kernel: shape [C, 1, K, 1] for depthwise conv let vertical_kernel = kernel .reshape([1, 1, size, 1]) .repeat_dim(0, config.channels); Self { name: MetricName::new(format!( "MS-SSIM (pr={}, k={}, σ={})", config.pixel_range, config.kernel_size, config.sigma )), state: NumericMetricState::default(), _b: PhantomData, config, horizontal_kernel, vertical_kernel, } } /// Overrides the default metric name. /// /// # Example /// ```ignore /// let metric = MsSsimMetric::::new(config, &device) /// .with_name("Custom MS-SSIM Name"); /// ``` pub fn with_name(mut self, name: &str) -> Self { self.name = MetricName::new(name.to_string()); self } /// Creates a normalized 1D Gaussian kernel as a tensor where the kernel values sum to 1.0. fn create_1d_gaussian_kernel(config: &MsSsimMetricConfig, device: &B::Device) -> Tensor { let size = config.kernel_size as i64; let sigma = config.sigma; let center = (size / 2) as f32; let one_to_size_tensor = Tensor::::arange(0..size, device).float(); let x_vals = one_to_size_tensor.sub_scalar(center); // Gaussian: exp(-x² / 2σ²) let x_squared = x_vals.clone().mul(x_vals); let x_squared_div_2_sigma_squared = x_squared.div_scalar(2.0 * sigma * sigma); let unnormalized_kernel = x_squared_div_2_sigma_squared.neg().exp(); let kernel_vals_sum = unnormalized_kernel.clone().sum(); unnormalized_kernel.div(kernel_vals_sum) } /// Applies separable Gaussian convolution using pre-computed kernels. /// /// Performs two 1D convolutions (horizontal then vertical) which is /// computationally cheaper than a single 2D convolution. /// /// # Arguments /// - `input`: Tensor of shape [N, C, H, W] fn gaussian_separable_conv(&self, input: Tensor) -> Tensor { let padding = self.config.kernel_size / 2; let h_kernel = self.horizontal_kernel.clone(); let v_kernel = self.vertical_kernel.clone(); // Apply reflect padding to all 4 sides of the input tensor before convolution // Format: (left, right, top, bottom) let padded_input = input.pad((padding, padding, padding, padding), PadMode::Reflect); let h_conv_options = ConvOptions::new([1, 1], [0, 0], [1, 1], self.config.channels); let v_conv_options = ConvOptions::new([1, 1], [0, 0], [1, 1], self.config.channels); let input_after_h_conv = conv2d(padded_input, h_kernel, None, h_conv_options); conv2d(input_after_h_conv, v_kernel, None, v_conv_options) } } impl Metric for MsSsimMetric { type Input = MsSsimInput; fn name(&self) -> MetricName { self.name.clone() } fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let dims = item.outputs.dims(); let scales = self.config.betas.len(); assert_eq!( dims[1], self.config.channels, "Input has {} channels but metric was configured for {}", dims[1], self.config.channels ); // Verify minimum size for the given number of scales // After (scales - 1) downsamples, size is original / 2^(scales-1) // We need kernel_size at that scale let downsample_ops_num = scales.saturating_sub(1) as u32; let min_size = self.config.kernel_size * (2usize.pow(downsample_ops_num)); let h = dims[2]; let w = dims[3]; assert!( h >= min_size && w >= min_size, "Image dimensions (H={}, W={}) must be at least {} to support {} scales of MS-SSIM \ with kernel_size={}. Either increase image size, reduce kernel_size, or reduce the number of scales (betas).", h, w, min_size, scales, self.config.kernel_size ); let mut x = item.outputs.clone(); let mut y = item.targets.clone(); let betas = &self.config.betas; // Compute c1 = (k1 * L)^2 and c2 = (k2 * L)^2, c3 = c2/2 let c1 = (self.config.k1 * self.config.pixel_range).powi(2); let c2 = (self.config.k2 * self.config.pixel_range).powi(2); // Initialize accumulator to 1 for update via multiplication // Shape: [N, C] let batch_size = dims[0]; let channels = dims[1]; let mut ms_ssim_tensor = Tensor::::ones([batch_size, channels], &item.outputs.device()); for (j, beta_j) in betas.iter().enumerate() { // Compute mu_x and mu_y let mu_x = self.gaussian_separable_conv(x.clone()); let mu_y = self.gaussian_separable_conv(y.clone()); let square_of_mu_x = mu_x.clone() * mu_x.clone(); let square_of_mu_y = mu_y.clone() * mu_y.clone(); // Var(X) = E(X^2) - E(X)^2 let mu_of_x_squared = self.gaussian_separable_conv(x.clone() * x.clone()); let mu_of_y_squared = self.gaussian_separable_conv(y.clone() * y.clone()); let var_x = (mu_of_x_squared - square_of_mu_x.clone()).clamp_min(0.0); let var_y = (mu_of_y_squared - square_of_mu_y.clone()).clamp_min(0.0); // Cov(X, Y) = E(XY) - E(X)E(Y) let mu_of_xy = self.gaussian_separable_conv(x.clone() * y.clone()); let cov_xy = mu_of_xy - (mu_x.clone() * mu_y.clone()); // Compute cs_map = (2σxy + C2) / (σx² + σy² + C2) // This is mathematically equivalent to c(x,y) * s(x,y) when C3 = C2 / 2 let contrast_structure = (cov_xy * 2.0 + c2) / (var_x + var_y + c2); // Include luminance at the last scale if j == betas.len() - 1 { // Compute l(x, y) = (2μxμy + C1) / (μx² + μy² + C1) let luminance: Tensor = (2 * mu_x * mu_y + c1) / (square_of_mu_x + square_of_mu_y + c1); let ssim = luminance * contrast_structure; let ssim_spatial_mean = ssim.mean_dims(&[2, 3]).reshape([batch_size, channels]); // Clamp to avoid negative values before raising to power (prevents NaNs) let ssim_mean_clamped = ssim_spatial_mean.clamp_min(0.0); ms_ssim_tensor = ms_ssim_tensor * ssim_mean_clamped.powf_scalar(*beta_j); } else { let contrast_structure_spatial_mean = contrast_structure .mean_dims(&[2, 3]) .reshape([batch_size, channels]); // Clamp to avoid negative values before raising to power (prevents NaNs) let c_s_mean_clamped = contrast_structure_spatial_mean.clamp_min(0.0); ms_ssim_tensor = ms_ssim_tensor * c_s_mean_clamped.powf_scalar(*beta_j); x = avg_pool2d(x, [2, 2], [2, 2], [0, 0], false, false); y = avg_pool2d(y, [2, 2], [2, 2], [0, 0], false, false); } } let ms_ssim_per_image = ms_ssim_tensor.mean_dim(1); let avg_ms_ssim = ms_ssim_per_image.mean().into_scalar().elem::(); self.state.update( avg_ms_ssim, batch_size, FormatOptions::new(self.name()).precision(4), ) } /// Clears the metric state. fn clear(&mut self) { self.state.reset(); } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: true, } .into() } } impl Numeric for MsSsimMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::{TestBackend, metric::Numeric}; use burn_core::tensor::Distribution; fn test_config() -> MsSsimMetricConfig { // Use small kernel and single channel for testing // With kernel_size=3, we need images >= 3*16=48 MsSsimMetricConfig::new(1.0) .with_kernel_size(3) .with_sigma(1.0) .with_channels(1) } #[test] fn test_ms_ssim_perfect_similarity() { // Identical images should give MS-SSIM = 1.0 let device = Default::default(); let outputs = Tensor::::from_data( [[[ [0.5_f32; 64]; 64 // 64x64 constant image ]]], &device, ); let targets = outputs.clone(); let mut metric = MsSsimMetric::::new(test_config(), &device); let input = MsSsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ms_ssim = metric.value().current(); assert!( ms_ssim > 0.99, "MS-SSIM for identical images should be 1.0, got {}", ms_ssim ); } #[test] fn test_ms_ssim_completely_different() { // Black vs white images should give very low MS-SSIM (close to 0.0) let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 256, 256], &device); let targets = Tensor::::ones([1, 1, 256, 256], &device); let mut metric = MsSsimMetric::::new(test_config(), &device); let input = MsSsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ms_ssim = metric.value().current(); assert!( (ms_ssim - 0.3).abs() < 0.01, "MS-SSIM for black vs white should be low (around 0.3), got {}", ms_ssim ); } #[test] fn test_ms_ssim_similar_images() { // Small perturbation should give high MS-SSIM (close to 1.0) let device = Default::default(); let outputs = Tensor::::full([1, 1, 64, 64], 0.5, &device); let targets = Tensor::::full([1, 1, 64, 64], 0.52, &device); let mut metric = MsSsimMetric::::new(test_config(), &device); let input = MsSsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ms_ssim = metric.value().current(); assert!( ms_ssim > 0.95, "MS-SSIM for very similar images should be close to 1.0, got {}", ms_ssim ); } #[test] fn test_ms_ssim_batch_averaging() { let device = Default::default(); // Batch of 2: one identical, one different let outputs = Tensor::::from_data( [ [[[0.5_f32; 64]; 64]], // Image 1: constant 0.5 [[[0.0_f32; 64]; 64]], // Image 2: constant 0.0 (black) ], &device, ); let targets = Tensor::::from_data( [ [[[0.5_f32; 64]; 64]], // Image 1: identical [[[1.0_f32; 64]; 64]], // Image 2: white (opposite) ], &device, ); let mut metric = MsSsimMetric::::new(test_config(), &device); let input = MsSsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ms_ssim = metric.value().current(); // Average of ~1.0 and ~0.292 should be around 0.64 assert!( (ms_ssim - 0.64).abs() < 0.02, "Average MS-SSIM should be around 0.64, got {}", ms_ssim ); } #[test] fn test_ms_ssim_multichannel() { let device = Default::default(); // Test with 3 channels (RGB) let config = MsSsimMetricConfig::new(1.0) .with_kernel_size(3) .with_sigma(1.0) .with_channels(3); let outputs = Tensor::::random( [2, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device, ); let targets = outputs.clone(); let mut metric = MsSsimMetric::::new(config, &device); let input = MsSsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ms_ssim = metric.value().current(); assert!( ms_ssim > 0.99, "MS-SSIM for identical RGB images should be 1.0, got {}", ms_ssim ); } #[test] fn test_ms_ssim_running_average() { let device = Default::default(); let mut metric = MsSsimMetric::::new(test_config(), &device); // First update: identical (1.0) let img1 = Tensor::::full([1, 1, 64, 64], 0.5, &device); let input1 = MsSsimInput::new(img1.clone(), img1); metric.update(&input1, &MetricMetadata::fake()); assert!( metric.value().current() > 0.99, "First update should be approximately 1.0" ); // Second update: different (~0.29) let black = Tensor::::zeros([1, 1, 64, 64], &device); let white = Tensor::::ones([1, 1, 64, 64], &device); let input2 = MsSsimInput::new(black, white); metric.update(&input2, &MetricMetadata::fake()); let running = metric.running_value().current(); assert!( (running - 0.64).abs() < 0.02, "Running average should be approximately 0.64, got {}", running ); } #[test] fn test_ms_ssim_single_scale_small_image() { let device = Default::default(); // Default 5 scales with kernel_size=11 requires a 176x176 image. // With a single scale, the minimum required size drops to // just 11x11 (kernel_size * 2^0). let config = MsSsimMetricConfig::new(1.0) .with_channels(1) .with_betas(vec![1.0]); // 1 scale let mut metric = MsSsimMetric::::new(config, &device); // Create a 16x16 image. This would normally panic with 5 scales, // but should succeed with 1 scale. let outputs = Tensor::::zeros([1, 1, 16, 16], &device); let targets = outputs.clone(); let input = MsSsimInput::new(outputs, targets); // This should not panic let _ = metric.update(&input, &MetricMetadata::fake()); // Identical images should still yield ~1.0 let ms_ssim = metric.value().current(); assert!( ms_ssim > 0.99, "1-scale MS-SSIM for identical images should be 1.0, got {}", ms_ssim ); } #[test] fn test_ssim_symmetry() { // MS-SSIM(x, y) should equal MS-SSIM(y, x) // Symmetry is one of the mathematical properties of MS-SSIM let device = Default::default(); let config = MsSsimMetricConfig::new(1.0) .with_kernel_size(3) .with_sigma(1.0) .with_channels(3); let img1 = Tensor::::random( [2, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device, ); let img2 = Tensor::::random( [2, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device, ); let mut metric1 = MsSsimMetric::::new(config.clone(), &device); let input1 = MsSsimInput::new(img1.clone(), img2.clone()); let _entry = metric1.update(&input1, &MetricMetadata::fake()); let ms_ssim1 = metric1.value().current(); let mut metric2 = MsSsimMetric::::new(config, &device); let input2 = MsSsimInput::new(img2, img1); let _entry = metric2.update(&input2, &MetricMetadata::fake()); let ms_ssim2 = metric2.value().current(); assert!( (ms_ssim1 - ms_ssim2).abs() < 0.001, "MS-SSIM should be symmetric: MS-SSIM(x,y)={} vs MS-SSIM(y,x)={}", ms_ssim1, ms_ssim2 ); } #[test] fn test_ms_ssim_clear() { let device = Default::default(); let mut metric = MsSsimMetric::::new(test_config(), &device); let img = Tensor::::full([1, 1, 64, 64], 0.5, &device); let input = MsSsimInput::new(img.clone(), img); metric.update(&input, &MetricMetadata::fake()); assert!(metric.value().current() > 0.99); metric.clear(); assert!(metric.running_value().current().is_nan()); } #[test] fn test_ms_ssim_custom_name() { let device = Default::default(); let config = MsSsimMetricConfig::new(1.0); let metric = MsSsimMetric::::new(config, &device).with_name("CustomMS-SSIM"); assert_eq!(metric.name().to_string(), "CustomMS-SSIM"); } #[test] fn test_ms_ssim_default_name() { let device = Default::default(); let config = MsSsimMetricConfig::new(255.0); let metric = MsSsimMetric::::new(config, &device); assert_eq!(metric.name().to_string(), "MS-SSIM (pr=255, k=11, σ=1.5)"); } #[test] fn test_ms_ssim_attributes() { let device = Default::default(); let config = MsSsimMetricConfig::new(1.0); let metric = MsSsimMetric::::new(config, &device); match metric.attributes() { MetricAttributes::Numeric(attrs) => { assert!(attrs.higher_is_better); assert_eq!(attrs.unit, None); } _ => panic!("Expected numeric attributes"), } } #[test] #[should_panic(expected = "Shape mismatch")] fn test_ms_ssim_shape_mismatch() { let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 64, 64], &device); let targets = Tensor::::zeros([1, 1, 32, 32], &device); let _ = MsSsimInput::new(outputs, targets); } #[test] #[should_panic(expected = "k1 must be positive")] fn test_ms_ssim_negative_k1() { let _ = MsSsimMetricConfig::new(1.0).with_k1_k2(-0.01, 0.03); } #[test] #[should_panic(expected = "k2 must be positive")] fn test_ms_ssim_negative_k2() { let _ = MsSsimMetricConfig::new(1.0).with_k1_k2(0.01, -0.03); } #[test] #[should_panic(expected = "pixel_range must be positive")] fn test_ms_ssim_negative_data_range() { let _ = MsSsimMetricConfig::new(-1.0); } #[test] #[should_panic(expected = "pixel_range must be positive")] fn test_ms_ssim_zero_data_range() { let _ = MsSsimMetricConfig::new(0.0); } #[test] #[should_panic(expected = "kernel_size must be positive and an odd number")] fn test_ms_ssim_even_kernel_size() { let _ = MsSsimMetricConfig::new(1.0).with_kernel_size(10); } #[test] #[should_panic(expected = "kernel_size must be positive and an odd number")] fn test_ms_ssim_zero_kernel_size() { let _ = MsSsimMetricConfig::new(1.0).with_kernel_size(0); } #[test] #[should_panic(expected = "sigma must be a positive number")] fn test_ms_ssim_negative_sigma() { let _ = MsSsimMetricConfig::new(1.0).with_sigma(-1.5); } #[test] #[should_panic(expected = "sigma must be a positive number")] fn test_ms_ssim_zero_sigma() { let _ = MsSsimMetricConfig::new(1.0).with_sigma(0.0); } #[test] #[should_panic(expected = "channels must be a positive number")] fn test_ms_ssim_zero_channels() { let _ = MsSsimMetricConfig::new(1.0).with_channels(0); } #[test] #[should_panic(expected = "betas vector cannot be empty")] fn test_ms_ssim_empty_betas() { let _ = MsSsimMetricConfig::new(1.0).with_betas(vec![]); } #[test] #[should_panic(expected = "All beta values must be non-negative")] fn test_ms_ssim_negative_betas() { let _ = MsSsimMetricConfig::new(1.0).with_betas(vec![0.3, 0.3, -0.1, 0.5]); } #[test] #[should_panic(expected = "Image dimensions")] fn test_ms_ssim_image_too_small() { let device = Default::default(); // 3 scales with kernel_size=11 requires 44x44 minimum (11 * 2^2) let config = MsSsimMetricConfig::new(1.0).with_betas(vec![0.5, 0.3, 0.2]); let mut metric = MsSsimMetric::::new(config, &device); let outputs = Tensor::::zeros([1, 3, 32, 32], &device); // Too small (32 < 44) let targets = outputs.clone(); let input = MsSsimInput::new(outputs, targets); let _ = metric.update(&input, &MetricMetadata::fake()); } } ================================================ FILE: crates/burn-train/src/metric/vision/psnr.rs ================================================ use crate::metric::{ Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry, SerializedEntry, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::ElementConversion, }; use core::marker::PhantomData; use std::f64::consts::LN_10; /// Input type for the [PsnrMetric]. /// /// Both tensors must have shape `[N, C, H, W]`: /// - `N`: Batch size /// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.) /// - `H`: Height /// - `W`: Width pub struct PsnrInput { /// Model output (predictions/reconstructions) images with shape `[N, C, H, W]`. outputs: Tensor, /// Ground truth images with shape `[N, C, H, W]`. targets: Tensor, } impl PsnrInput { /// Creates a new PsnrInput with the given outputs and targets. /// /// Inputs are expected to have the dimensions `[N, C, H, W]` /// where `N` is the batch size, `C` is the number of channels, /// `H` is the height of the image, and `W` is the width of the image. /// /// # Arguments /// - `outputs`: The model output images with shape `[N, C, H, W]`. /// - `targets`: The ground truth images with shape `[N, C, H, W]`. /// /// # Returns /// A new instance of `PsnrInput`. /// /// # Panics /// - If `outputs` and `targets` do not have the same shape. pub fn new(outputs: Tensor, targets: Tensor) -> Self { assert!( outputs.dims() == targets.dims(), "Shape mismatch: outputs {:?}, targets {:?}", outputs.dims(), targets.dims() ); Self { outputs, targets } } } /// Configuration for the [PsnrMetric]. #[derive(Debug, Clone, Copy)] pub struct PsnrMetricConfig { /// Maximum possible pixel value. /// - Use `1.0` for normalized images in range \[0, 1\] /// - Use `255.0` for 8-bit images in range \[0, 255\] pub max_pixel_val: f64, /// Epsilon value for numerical stability when MSE is very small or zero. /// /// When MSE falls below this threshold, it is clamped to `epsilon`, /// resulting in a maximum PSNR of approximately `10 * log10(max_pixel_val² / epsilon)` dB. /// /// Default is `1e-10`, which yields ~100 dB for perfect reconstruction with `max_pixel_val = 1.0`. pub epsilon: f64, } impl PsnrMetricConfig { /// Creates a configuration with the specified maximum pixel value. /// /// # Example /// ```ignore /// // Normalized images [0, 1] /// let config = PsnrMetricConfig::new(1.0); /// /// // 8-bit images [0, 255] /// let config = PsnrMetricConfig::new(255.0); /// // Also set a custom epsilon value /// let config = PsnrMetricConfig::new(255.0).with_epsilon(1e-8); /// ``` pub fn new(max_pixel_val: f64) -> Self { assert!(max_pixel_val > 0.0, "max_pixel_val must be positive"); Self { max_pixel_val, epsilon: 1e-10, } } /// Sets a custom epsilon for numerical stability near zero MSE pub fn with_epsilon(mut self, epsilon: f64) -> Self { assert!(epsilon > 0.0, "epsilon must be positive"); self.epsilon = epsilon; self } } /// The peak signal-to-noise ratio (PSNR) metric for image quality assessment. /// /// PSNR is commonly used to measure the quality of reconstructed images /// compared to the original. Higher values (in dB) indicate better quality. /// /// # Formula /// ```text /// PSNR = 10 * log10(MAX^2 / MSE) /// ``` /// where MAX is the maximum possible pixel value and MSE is the mean squared error. /// /// # Note /// - PSNR is computed for each image first, and then it is averaged across all the images in the batch. /// - For perfect reconstruction (MSE = 0), the MSE is clamped to `epsilon` to avoid division by zero, /// yielding a maximum PSNR of `10 * log10(MAX^2 / epsilon)` dB. #[derive(Clone)] pub struct PsnrMetric { name: MetricName, /// Internal state for numeric metric aggregation. state: NumericMetricState, /// Marker for backend type. _b: PhantomData, /// Configuration for the metric. config: PsnrMetricConfig, } impl PsnrMetric { /// Creates a new PSNR metric with the given configuration. /// /// # Example /// ```ignore /// let config = PsnrMetricConfig::new(1.0); /// let metric = PsnrMetric::::new(config); /// ``` pub fn new(config: PsnrMetricConfig) -> Self { Self { name: MetricName::new(format!("PSNR@{}", config.max_pixel_val)), state: NumericMetricState::default(), config, _b: PhantomData, } } /// Overrides the default metric name which is `PSNR@{max_pixel_val}`. /// /// Examples names: /// - `PSNR@1.0` /// - `PSNR@255.0` /// /// Use this method to provide a custom name. pub fn with_name(mut self, name: &str) -> Self { self.name = MetricName::new(name.to_string()); self } } impl Metric for PsnrMetric { type Input = PsnrInput; fn name(&self) -> MetricName { self.name.clone() } fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let dims = item.outputs.dims(); let batch_size = dims[0]; let outputs = item.outputs.clone(); let targets = item.targets.clone(); // Compute per-image MSE by reducing over all dimensions except batch (dims 1, 2, 3) // Resulting shape: [N, 1, 1, 1] let diff = outputs.sub(targets); let mse_per_image = diff.powi_scalar(2).mean_dims(&[1, 2, 3]); // Flatten to shape: [N] let mse_flat = mse_per_image.flatten::<1>(0, 3); // Clamp MSE to avoid division by 0 in the expression (MAX^2 / MSE) let mse_clamped = mse_flat.clamp_min(self.config.epsilon); let max_squared = self.config.max_pixel_val * self.config.max_pixel_val; // Compute PSNR for each image and accumulate // PSNR value in dB (using the change of base formula): // 10 * log10(MAX^2 / MSE) = 10 * ln(MAX^2 / MSE) / ln(10) // = ln(MAX^2 / MSE) * (10 / ln(10)) let psnr_per_image = mse_clamped .recip() .mul_scalar(max_squared) .log() .mul_scalar(10.0 / LN_10); let avg_psnr = psnr_per_image.mean().into_scalar().elem::(); self.state.update( avg_psnr, batch_size, FormatOptions::new(self.name()).unit("dB").precision(2), ) } /// Clears the metric state. fn clear(&mut self) { self.state.reset(); } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("dB".to_string()), higher_is_better: true, } .into() } } impl Numeric for PsnrMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::{TestBackend, metric::Numeric}; use burn_core::tensor::TensorData; #[test] fn test_psnr_perfect_reconstruction() { // When outputs exactly match targets, PSNR should be very high // (limited by epsilon clamping to ~100 dB with default epsilon=1e-10) let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[[1.0_f32, 0.5], [0.25, 0.75]]]]), &device, ); let targets = outputs.clone(); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); // With epsilon = 1e-10 and max=1.0: // PSNR = 10 * log10(1.0 / 1e-10) = 100 dB let psnr = metric.value().current(); assert!( psnr >= 99.0, "PSNR for perfect reconstruction should be ~100 dB, got {} dB", psnr ); } #[test] fn test_psnr_constant_error() { // Constant error of 0.1 across all pixels // MSE = 0.01, PSNR = 10 * log10(1.0 / 0.01) = 20 dB let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]), &device, ); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); assert!( (psnr - 20.0).abs() < 0.01, "Expected PSNR ~20 dB, got {} dB", psnr ); } #[test] fn test_psnr_varying_error() { // Errors: 0.1, 0.2, 0.3, 0.4 → squared: 0.01, 0.04, 0.09, 0.16 // MSE = 0.075, PSNR = 10 * log10(1.0 / 0.075) ≈ 11.249 dB let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[[0.1_f32, 0.2], [0.3, 0.4]]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]), &device, ); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 10.0 * (1.0_f64 / 0.075).log10(); assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{:.3} dB, got {} dB", expected_psnr, psnr ); } #[test] fn test_psnr_max_pixel_255() { // Test with 8-bit image range [0, 255] // Error = 10 everywhere, MSE = 100 // PSNR = 10 * log10(255^2 / 100) ≈ 28.13 dB let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[[10.0_f32, 10.0], [10.0, 10.0]]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]), &device, ); let config = PsnrMetricConfig::new(255.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 10.0 * (255.0_f64 * 255.0 / 100.0).log10(); assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{:.3} dB, got {} dB", expected_psnr, psnr ); } #[test] fn test_psnr_batch_averaging() { // Batch of 2 images with different MSEs // Image 1: error 0.1 → MSE = 0.01 → PSNR = 20 dB // Image 2: error 0.01 → MSE = 0.0001 → PSNR = 40 dB // Average PSNR = 30 dB let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([ [[[0.1_f32, 0.1], [0.1, 0.1]]], [[[0.01_f32, 0.01], [0.01, 0.01]]], ]), &device, ); let targets = Tensor::::from_data( TensorData::from([ [[[0.0_f32, 0.0], [0.0, 0.0]]], [[[0.0_f32, 0.0], [0.0, 0.0]]], ]), &device, ); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 30.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected average PSNR ~{} dB, got {} dB", expected_psnr, psnr ); } #[test] fn test_psnr_multichannel() { // Test with 3 channels (RGB-like) // All channels have constant error 0.1 → MSE = 0.01 → PSNR = 20 dB let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[ [[0.1_f32, 0.1], [0.1, 0.1]], [[0.1_f32, 0.1], [0.1, 0.1]], [[0.1_f32, 0.1], [0.1, 0.1]], ]]), &device, ); let targets = Tensor::::zeros([1, 3, 2, 2], &device); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 20.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{} dB, got {} dB", expected_psnr, psnr ); } #[test] fn test_psnr_running_average() { // Test running average across multiple updates let device = Default::default(); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); // First update: error 0.1 → MSE = 0.01 → PSNR = 20 dB let outputs1 = Tensor::::from_data( TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]), &device, ); let targets1 = Tensor::::zeros([1, 1, 2, 2], &device); let input1 = PsnrInput::new(outputs1, targets1); let _entry = metric.update(&input1, &MetricMetadata::fake()); let psnr1 = metric.value().current(); let expected_psnr1 = 20.0; assert!( (psnr1 - expected_psnr1).abs() < 0.01, "First update PSNR should be ~{} dB, got {} dB", expected_psnr1, psnr1 ); // Second update: error 0.01 → MSE = 0.0001 → PSNR = 40 dB let outputs2 = Tensor::::from_data( TensorData::from([[[[0.01_f32, 0.01], [0.01, 0.01]]]]), &device, ); let targets2 = Tensor::::zeros([1, 1, 2, 2], &device); let input2 = PsnrInput::new(outputs2, targets2); let _entry = metric.update(&input2, &MetricMetadata::fake()); // Running average: (20 + 40) / 2 = 30 dB let running_avg_psnr = metric.running_value().current(); let expected_running_avg_psnr = 30.0; assert!( (running_avg_psnr - expected_running_avg_psnr).abs() < 0.01, "Running average should be ~{} dB, got {} dB", expected_running_avg_psnr, running_avg_psnr ); } #[test] fn test_psnr_clear() { // Error 0.1 → MSE = 0.01 → PSNR = 20 dB let device = Default::default(); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let outputs = Tensor::::from_data( TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]), &device, ); let targets = Tensor::::zeros([1, 1, 2, 2], &device); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 20.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{} dB, got {} dB", expected_psnr, psnr ); // Clear and verify reset metric.clear(); let psnr = metric.running_value().current(); assert!(psnr.is_nan(), "Expected NaN after clear, got {} dB", psnr) } #[test] fn test_psnr_custom_name() { let config = PsnrMetricConfig::new(1.0); let metric = PsnrMetric::::new(config).with_name("CustomPSNR"); assert_eq!(metric.name().to_string(), "CustomPSNR"); } #[test] fn test_psnr_custom_epsilon() { let device = Default::default(); // With a larger epsilon, perfect reconstruction gives lower PSNR let config = PsnrMetricConfig::new(1.0).with_epsilon(0.01); let mut metric = PsnrMetric::::new(config); let outputs = Tensor::::from_data( TensorData::from([[[[0.5_f32, 0.5], [0.5, 0.5]]]]), &device, ); let targets = outputs.clone(); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); // With epsilon = 0.01, PSNR = 10 * log10(1.0 / 0.01) = 20 dB let psnr = metric.value().current(); let expected_psnr = 20.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{} dB with epsilon=0.01, got {}", expected_psnr, psnr ); } #[test] fn test_psnr_negative_errors() { // Test that negative differences (target > output) work correctly let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]), &device, ); let targets = Tensor::::from_data( TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]), &device, ); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); // Same MSE as positive errors (0.01), so PSNR = 20 dB let psnr = metric.value().current(); let expected_psnr = 20.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{} dB, got {}", expected_psnr, psnr ); } #[test] fn test_psnr_large_batch() { // Test with a larger batch to verify batch dimension handling let device = Default::default(); let batch_size = 8; // All images have constant error 0.1 → MSE = 0.01 → PSNR = 20 dB let outputs = Tensor::::full([batch_size, 3, 4, 4], 0.1, &device); let targets = Tensor::::zeros([batch_size, 3, 4, 4], &device); let config = PsnrMetricConfig::new(1.0); let mut metric = PsnrMetric::::new(config); let input = PsnrInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let psnr = metric.value().current(); let expected_psnr = 20.0; assert!( (psnr - expected_psnr).abs() < 0.01, "Expected PSNR ~{} dB, got {}", expected_psnr, psnr ); } #[test] fn test_psnr_attributes() { let config = PsnrMetricConfig::new(1.0); let metric = PsnrMetric::::new(config); let attrs = metric.attributes(); match attrs { MetricAttributes::Numeric(numeric_attrs) => { assert_eq!(numeric_attrs.unit, Some("dB".to_string())); assert!(numeric_attrs.higher_is_better); } _ => panic!("Expected numeric attributes"), } } #[test] #[should_panic(expected = "Shape mismatch")] fn test_psnr_shape_mismatch() { let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 2, 2], &device); let targets = Tensor::::zeros([1, 1, 3, 3], &device); let _ = PsnrInput::new(outputs, targets); } #[test] #[should_panic(expected = "max_pixel_val must be positive")] fn test_psnr_negative_max_pixel_val() { let _ = PsnrMetricConfig::new(-1.0); } #[test] #[should_panic(expected = "max_pixel_val must be positive")] fn test_psnr_zero_max_pixel_val() { let _ = PsnrMetricConfig::new(0.0); } #[test] #[should_panic(expected = "epsilon must be positive")] fn test_psnr_negative_epsilon() { let _ = PsnrMetricConfig::new(1.0).with_epsilon(-1e-10); } #[test] #[should_panic(expected = "epsilon must be positive")] fn test_psnr_zero_epsilon() { let _ = PsnrMetricConfig::new(1.0).with_epsilon(0.0); } } ================================================ FILE: crates/burn-train/src/metric/vision/ssim.rs ================================================ use crate::metric::{ Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry, SerializedEntry, state::{FormatOptions, NumericMetricState}, }; use burn_core::{ prelude::{Backend, Tensor}, tensor::{ElementConversion, module::conv2d, ops::ConvOptions}, }; use core::marker::PhantomData; /// Input type for the [SsimMetric]. /// /// Both tensors must have shape `[N, C, H, W]`: /// - `N`: Batch size /// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.) /// - `H`: Height /// - `W`: Width pub struct SsimInput { /// Model output (predictions/reconstructions) images with shape [N, C, H, W]. outputs: Tensor, /// Ground truth images with shape [N, C, H, W]. targets: Tensor, } impl SsimInput { /// Creates a new SsimInput with the given outputs and targets. /// /// Inputs are expected to have the dimensions `[N, C, H, W]` /// where `N` is the batch size, `C` is the number of channels, /// `H` is the height of the image, and `W` is the width of the image. /// /// # Arguments /// - `outputs`: The model output images with shape [N, C, H, W]. /// - `targets`: The ground truth images with shape [N, C, H, W]. /// /// # Returns /// A new instance of `SsimInput`. /// /// # Panics /// - If `outputs` and `targets` do not have the same shape. pub fn new(outputs: Tensor, targets: Tensor) -> Self { assert!( outputs.dims() == targets.dims(), "Shape mismatch: outputs {:?}, targets {:?}", outputs.dims(), targets.dims() ); Self { outputs, targets } } } /// Configuration for the [SsimMetric]. #[derive(Debug, Clone, Copy)] pub struct SsimMetricConfig { /// The range of the pixel values in images which can be computed as following: /// `let pixel_range = max_pixel_val - min_pixel_val;` /// where `max_pixel_val` is the maximum possible pixel value and `min_pixel_val` /// is the minimum possible pixel value. /// /// - For normalized images in range [0, 1], it should be set to `1.0 - 0.0 = 1.0` /// - For normalized images in range [-1, 1], it should be set to `1.0 - (-1.0) = 2.0` /// - For 8-bit images in range [0, 255], it should be set to `255.0 - 0.0 = 255.0` pub pixel_range: f32, /// A parameter of SSIM used to stabilize the luminance comparison. /// Default is 0.01. pub k1: f32, /// A parameter of SSIM used to stabilize the contrast comparison. /// Default is 0.03. pub k2: f32, /// The SSIM metric involves applying convolution to the input tensors using a Gaussian kernel. /// This is the kernel size of the Gaussian kernel. Default is 11. pub kernel_size: usize, /// The SSIM metric involves applying convolution to the input tensors using a Gaussian kernel. /// This is the standard deviation of the Gaussian kernel. Default is 1.5. pub sigma: f32, } impl SsimMetricConfig { /// Creates a configuration with the specified data range and default parameters. /// /// # Default parameters /// - k1: 0.01 /// - k2: 0.03 /// - kernel_size: 11 /// - sigma: 1.5 /// /// # Panics /// - If `pixel_range` is not positive. /// /// # Example /// ```ignore /// // Normalized images [0, 1] /// let config1 = SsimMetricConfig::new(1.0); /// /// // 8-bit images [0, 255] /// let config2 = SsimMetricConfig::new(255.0); /// /// // Also set custom values for k1 and k2 /// let config3 = SsimMetricConfig::new(1.0).with_k1_k2(0.015, 0.025); /// /// // Also set a custom value for window size /// config3.with_kernel_size(13); /// ``` pub fn new(pixel_range: f32) -> Self { assert!(pixel_range > 0.0, "pixel_range must be positive"); Self { pixel_range: pixel_range, k1: 0.01, k2: 0.03, kernel_size: 11, sigma: 1.5, } } /// Sets a custom value for the k1 and k2 parameters of SSIM which are /// used for numerical stability. /// /// # Default values /// - k1: 0.01 /// - k2: 0.03 /// /// # Panics /// - If `k1` or `k2` is not positive. pub fn with_k1_k2(mut self, k1: f32, k2: f32) -> Self { assert!(k1 > 0.0, "k1 must be positive"); assert!(k2 > 0.0, "k2 must be positive"); self.k1 = k1; self.k2 = k2; self } /// Sets a custom window size for the Gaussian kernel used in SSIM. The /// window size must be a positive odd number. /// /// # Default value /// - kernel_size: 11 /// /// # Panics /// - If `kernel_size` is not a positive odd number. pub fn with_kernel_size(mut self, kernel_size: usize) -> Self { assert!( kernel_size > 0 && kernel_size % 2 == 1, "kernel_size must be positive and an odd number" ); self.kernel_size = kernel_size; self } /// Sets a custom sigma (standard deviation) for the Gaussian kernel used in SSIM. /// /// # Default value /// - sigma: 1.5 /// /// # Panics /// - If `sigma` is not positive. pub fn with_sigma(mut self, sigma: f32) -> Self { assert!(sigma > 0.0, "sigma must be positive"); self.sigma = sigma; self } } /// The SSIM (structural similarity index measure) metric for image quality assessment. /// /// SSIM measures the perceived quality of images by comparing luminance, /// contrast, and structure. Values range from -1 to 1, where 1 indicates /// perfect structural similarity. /// /// # Formula /// ```text /// SSIM(x, y) = (2μxμy + C1)(2σxy + C2) / (μx² + μy² + C1)(σx² + σy² + C2) /// ``` /// /// # Note /// - This implementation uses separable Gaussian convolution for efficiency. Instead of a /// single 2D convolution with a K by K kernel, it applies two 1D convolutions (horizontal /// then vertical). This reduces the computational complexity from O(K^2) to O(2K) per pixel. /// - SSIM is computed for each image first, and then it is averaged across all the images in the batch. #[derive(Clone)] pub struct SsimMetric { name: MetricName, /// Internal state for numeric metric aggregation. state: NumericMetricState, /// Marker for backend type. _b: PhantomData, /// Configuration for the metric. config: SsimMetricConfig, } impl SsimMetric { /// Creates a new SSIM metric with the given configuration. /// /// # Note /// The metric name format is "SSIM (dr={}, w={}, σ={})" /// where dr is the data range, w is the window size, sigma is the /// standard deviation. For example, the metric name might be /// "SSIM (dr=1.0, w=11, σ=1.5)". /// /// # Example /// ```ignore /// let ssim_config = SsimMetricConfig::new(1.0); /// let ssim_metric = SsimMetric::::new(ssim_config); /// ``` pub fn new(config: SsimMetricConfig) -> Self { Self { name: MetricName::new(format!( "SSIM (dr={}, w={}, σ={})", config.pixel_range, config.kernel_size, config.sigma, )), state: NumericMetricState::default(), config, _b: PhantomData, } } /// Overrides the default metric name which is "SSIM". pub fn with_name(mut self, name: &str) -> Self { self.name = MetricName::new(name.to_string()); self } /// Creates a 1D Gaussian kernel as a tensor. /// /// Returns a normalized kernel where all values sum to 1. /// The returned kernel will be reshaped by the `gaussian_conv_separable` /// associated function later. fn create_1d_gaussian_kernel(&self) -> Vec { let size = self.config.kernel_size; let sigma = self.config.sigma; let center = (size / 2) as f32; let mut kernel = vec![0.0f32; size]; let mut sum = 0.0f32; for (i, v) in kernel.iter_mut().enumerate() { let x = i as f32 - center; let value = (-(x * x) / (2.0 * sigma * sigma)).exp(); *v = value; sum += value; } // Normalize so values sum to 1 for v in kernel.iter_mut() { *v /= sum; } kernel } /// Applies separable convolution using two 1D Gaussian kernels. /// /// # Arguments /// - `inputs`: Tensor of shape [N, C, H, W] /// - `kernel_1d`: The 1D Gaussian kernel values /// - `channels`: Number of channels for depthwise convolution. fn gaussian_conv_separable( &self, input: Tensor, kernel_1d: &[f32], channels: usize, device: &B::Device, ) -> Tensor { let size = self.config.kernel_size; let padding = size / 2; // Create horizontal kernel: shape [C, 1, 1, K] let horizontal_kernel = Tensor::::from_floats(kernel_1d, device) .reshape([1, 1, 1, size]) // [1, 1, 1, K] .repeat_dim(0, channels); // [C, 1, 1, K] let vertical_kernel = Tensor::::from_floats(kernel_1d, device) .reshape([1, 1, size, 1]) // [1, 1, K, 1] .repeat_dim(0, channels); // [C, 1, K, 1] // Apply horizontal convolution let horizontal_conv_options = ConvOptions::new([1, 1], [0, padding], [1, 1], channels); let input_after_horizontal_conv = conv2d(input, horizontal_kernel, None, horizontal_conv_options); // Apply vertical convolution let vertical_conv_options = ConvOptions::new([1, 1], [padding, 0], [1, 1], channels); conv2d( input_after_horizontal_conv, vertical_kernel, None, vertical_conv_options, ) } } impl Metric for SsimMetric { type Input = SsimInput; fn name(&self) -> MetricName { self.name.clone() } fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry { let dims = item.outputs.dims(); let batch_size = dims[0]; let channels = dims[1]; let device = item.outputs.device(); let img_height = dims[2]; let img_width = dims[3]; assert!( img_height >= self.config.kernel_size && img_width >= self.config.kernel_size, "Image dimensions (H={}, W={}) must be >= kernel_size ({})", img_height, img_width, self.config.kernel_size ); // Constants in SSIM formula used for numerical stability let c1 = (self.config.k1 * self.config.pixel_range).powi(2); let c2 = (self.config.k2 * self.config.pixel_range).powi(2); // Create 1D Gaussian kernel to apply separable convolutions twice (horizontally and vertically) let kernel_1d = self.create_1d_gaussian_kernel(); // Compute mu_x and mu_y, their product and squares let x = item.outputs.clone(); let y = item.targets.clone(); let mu_x = self.gaussian_conv_separable(x.clone(), &kernel_1d, channels, &device); let mu_y = self.gaussian_conv_separable(y.clone(), &kernel_1d, channels, &device); let mu_x_mu_y = mu_x.clone() * mu_y.clone(); let square_of_mu_x = mu_x.clone() * mu_x.clone(); let square_of_mu_y = mu_y.clone() * mu_y.clone(); // Compute var_x, var_y (which are the same as (sigma_x)^2 and (sigma_y)^2): // Var(X) = E[X^2] - E[X]^2 // var_x = mu_of_x_squared - (mu_x * mu_x) let mu_of_x_squared = self.gaussian_conv_separable(x.clone() * x.clone(), &kernel_1d, channels, &device); let mu_of_y_squared = self.gaussian_conv_separable(y.clone() * y.clone(), &kernel_1d, channels, &device); let var_x = (mu_of_x_squared - square_of_mu_x.clone()).clamp_min(0.0); let var_y = (mu_of_y_squared - square_of_mu_y.clone()).clamp_min(0.0); // Compute the sample covariance of x and y: sigma_xy // Cov(X, Y) = E[XY] - E[X]E[Y] // sigma_xy = mu_xy - (mu_x * mu_y) let mu_xy = self.gaussian_conv_separable(x * y, &kernel_1d, channels, &device); let sigma_xy = mu_xy - mu_x_mu_y.clone(); // Compute SSIM: // SSIM(x, y) = (2μxμy + C1)(2σxy + C2) / (μx² + μy² + C1)(σx² + σy² + C2) let numerator = (mu_x_mu_y.mul_scalar(2.0_f32) + c1) * (sigma_xy.mul_scalar(2.0_f32) + c2); let denominator = (square_of_mu_x + square_of_mu_y + c1) * (var_x + var_y + c2); let ssim_tensor = numerator / denominator; // Average SSIM across all dimensions to get a single scalar value let ssim_per_image = ssim_tensor.mean_dims(&[1, 2, 3]); let avg_ssim = ssim_per_image.mean().into_scalar().elem::(); self.state.update( avg_ssim, batch_size, FormatOptions::new(self.name()).precision(4), ) } /// Clears the metric state. fn clear(&mut self) { self.state.reset(); } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: None, higher_is_better: true, } .into() } } impl Numeric for SsimMetric { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] #[allow(clippy::manual_range_contains)] mod tests { use super::*; use crate::{TestBackend, metric::Numeric}; use burn_core::tensor::{Distribution, Shape, TensorData}; fn test_config() -> SsimMetricConfig { SsimMetricConfig::new(1.0) .with_kernel_size(3) .with_sigma(1.0) } #[test] fn test_ssim_perfect_similarity() { // When outputs exactly match targets, SSIM should be 1.0 let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[[ [0.1_f32, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9], ]]]), &device, ); let targets = outputs.clone(); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "SSIM for identical images should be 1.0, got {}", ssim ); } #[test] fn test_ssim_completely_different() { // Constant black vs constant white // With constant images: SSIM = (2*mu_x*mu_y + C1) / (mu_x^2 + mu_y^2 + C1) // For x=0, y=1 with C1=(0.01)^2=0.0001: SSIM ≈ 0.0001 / (1 + 0.00001) = 0.00009999 let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 4, 4], &device); let targets = Tensor::::ones([1, 1, 4, 4], &device); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( ssim < 0.0001, "SSIM for black vs white images should be very low, got {}", ssim ); } #[test] fn test_ssim_similar_images() { // Small perturbation should give high SSIM let device = Default::default(); let outputs = Tensor::::full([1, 1, 4, 4], 0.5, &device); let targets = Tensor::::full([1, 1, 4, 4], 0.51, &device); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( ssim > 0.99, "SSIM for very similar images should be close to 1.0, got {}", ssim ); } #[test] fn test_ssim_batch_averaging() { // Batch of 2 images: // Image 1: identical (SSIM = 1.0) // Image 2: black vs white (SSIM ≈ 0) let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([ [[ [0.5_f32, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], ]], [[ [0.0_f32, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], ]], ]), &device, ); let targets = Tensor::::from_data( TensorData::from([ [[ [0.5_f32, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5], ]], [[ [1.0_f32, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], ]], ]), &device, ); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); // Average of ~1.0 and ~0.0 should be around 0.5 assert!( ssim > 0.49 && ssim < 0.51, "Average SSIM should be around 0.5, got {}", ssim ); } #[test] fn test_ssim_multichannel() { // Test with 3 channels (e.g., RGB) let device = Default::default(); let outputs = Tensor::::from_data( TensorData::from([[ [ [0.5_f32, 0.6, 0.7, 0.8], [0.4, 0.5, 0.6, 0.7], [0.3, 0.4, 0.5, 0.6], [0.2, 0.3, 0.4, 0.5], ], [ [0.3_f32, 0.4, 0.5, 0.6], [0.2, 0.3, 0.4, 0.5], [0.1, 0.2, 0.3, 0.4], [0.0, 0.1, 0.2, 0.3], ], [ [0.7_f32, 0.8, 0.9, 1.0], [0.6, 0.7, 0.8, 0.9], [0.5, 0.6, 0.7, 0.8], [0.4, 0.5, 0.6, 0.7], ], ]]), &device, ); let targets = outputs.clone(); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "SSIM for identical RGB images should be 1.0, got {}", ssim ); } #[test] fn test_ssim_symmetry() { // SSIM(x, y) should equal SSIM(y, x) // Symmetry is one of the mathematical properties of SSIM let device = Default::default(); let img1 = Tensor::::from_data( TensorData::from([[[ [0.1_f32, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9], ]]]), &device, ); let img2 = Tensor::::from_data( TensorData::from([[[ [0.2_f32, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9], [0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0], ]]]), &device, ); let config = test_config(); let mut metric1 = SsimMetric::::new(config); let input1 = SsimInput::new(img1.clone(), img2.clone()); let _entry = metric1.update(&input1, &MetricMetadata::fake()); let ssim1 = metric1.value().current(); let mut metric2 = SsimMetric::::new(config); let input2 = SsimInput::new(img2, img1); let _entry = metric2.update(&input2, &MetricMetadata::fake()); let ssim2 = metric2.value().current(); assert!( (ssim1 - ssim2).abs() < 0.001, "SSIM should be symmetric: SSIM(x,y)={} vs SSIM(y,x)={}", ssim1, ssim2 ); } #[test] fn test_ssim_range() { // SSIM values should be in [-1, 1] range let device = Default::default(); let shape = Shape::new([1, 1, 11, 11]); let distribution = Distribution::Uniform(0.0, 1.0); let outputs = Tensor::::random(shape.clone(), distribution, &device); let targets = Tensor::::random(shape, distribution, &device); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( ssim >= -1.0 && ssim <= 1.0, "SSIM should be in range [-1, 1], got {}", ssim ); } #[test] fn test_ssim_running_average() { let device = Default::default(); let mut metric = SsimMetric::::new(test_config()); // First update: identical images (SSIM = 1.0) let outputs1 = Tensor::::from_data( TensorData::from([[[ [0.5_f32, 0.6, 0.7, 0.8], [0.4, 0.5, 0.6, 0.7], [0.3, 0.4, 0.5, 0.6], [0.2, 0.3, 0.4, 0.5], ]]]), &device, ); let targets1 = outputs1.clone(); let input1 = SsimInput::new(outputs1, targets1); let _entry = metric.update(&input1, &MetricMetadata::fake()); let ssim1 = metric.value().current(); assert!( (ssim1 - 1.0).abs() < 0.001, "First update SSIM should be ~1.0, got {}", ssim1 ); // Second update: very different images (SSIM close to 0) let outputs2 = Tensor::::zeros([1, 1, 4, 4], &device); let targets2 = Tensor::::ones([1, 1, 4, 4], &device); let input2 = SsimInput::new(outputs2, targets2); let _entry = metric.update(&input2, &MetricMetadata::fake()); // Running average should be around 0.5 let running_avg = metric.running_value().current(); assert!( running_avg > 0.49 && running_avg < 0.51, "Running average should be around 0.5, got {}", running_avg ); } #[test] fn test_ssim_clear() { let device = Default::default(); let mut metric = SsimMetric::::new(test_config()); let outputs = Tensor::::from_data( TensorData::from([[[ [0.5_f32, 0.6, 0.7, 0.8], [0.4, 0.5, 0.6, 0.7], [0.3, 0.4, 0.5, 0.6], [0.2, 0.3, 0.4, 0.5], ]]]), &device, ); let targets = outputs.clone(); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "Expected SSIM ~1.0, got {}", ssim ); // Clear and verify reset metric.clear(); let ssim = metric.running_value().current(); assert!(ssim.is_nan(), "Expected NaN after clear, got {}", ssim); } #[test] fn test_ssim_custom_name() { let config = SsimMetricConfig::new(1.0); let metric = SsimMetric::::new(config).with_name("CustomSSIM"); assert_eq!(metric.name().to_string(), "CustomSSIM"); let metric = SsimMetric::::new(test_config()); assert_eq!(metric.name().to_string(), "SSIM (dr=1, w=3, σ=1)"); let config = SsimMetricConfig::new(255.0); let metric = SsimMetric::::new(config); assert_eq!(metric.name().to_string(), "SSIM (dr=255, w=11, σ=1.5)"); } #[test] fn test_ssim_pixel_range_255() { // Test with 8-bit image range [0, 255] let device = Default::default(); let shape = Shape::new([1, 1, 10, 10]); let distribution = Distribution::Uniform(0.0, 255.0); let outputs = Tensor::::random(shape.clone(), distribution, &device); let targets = outputs.clone(); let config = SsimMetricConfig::new(255.0).with_kernel_size(3); let mut metric = SsimMetric::::new(config); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "SSIM for identical 8-bit images should be 1.0, got {}", ssim ); } #[test] fn test_ssim_large_batch() { let device = Default::default(); let shape = Shape::new([20, 3, 30, 30]); let distribution = Distribution::Uniform(0.0, 1.0); let outputs = Tensor::::random(shape, distribution, &device); let targets = outputs.clone(); let mut metric = SsimMetric::::new(test_config()); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "SSIM for identical batch should be 1.0, got {}", ssim ); } #[test] fn test_ssim_default_kernel_size() { // Test with default kernel_size=11, need images >= 11x11 let device = Default::default(); let shape = Shape::new([1, 1, 1080, 1920]); let distribution = Distribution::Uniform(0.0, 1.0); let outputs = Tensor::::random(shape, distribution, &device); let targets = outputs.clone(); let config = SsimMetricConfig::new(1.0); // default kernel_size=11 let mut metric = SsimMetric::::new(config); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); let ssim = metric.value().current(); assert!( (ssim - 1.0).abs() < 0.001, "SSIM with default window size should work and SSIM should be ~0.0, got {}", ssim ); } #[test] fn test_ssim_attributes() { let config = SsimMetricConfig::new(1.0); let metric = SsimMetric::::new(config); let attrs = metric.attributes(); match attrs { MetricAttributes::Numeric(numeric_attrs) => { assert_eq!(numeric_attrs.unit, None); assert!(numeric_attrs.higher_is_better); } _ => panic!("Expected numeric attributes"), } } #[test] #[should_panic(expected = "Shape mismatch")] fn test_ssim_shape_mismatch() { let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 4, 4], &device); let targets = Tensor::::zeros([1, 1, 5, 5], &device); let _ = SsimInput::new(outputs, targets); } #[test] #[should_panic(expected = "Image dimensions (H=4, W=4) must be >= kernel_size (11)")] fn test_ssim_image_too_small() { let device = Default::default(); let outputs = Tensor::::zeros([1, 1, 4, 4], &device); let targets = outputs.clone(); // Default kernel_size=11, but image is only 4x4 let config = SsimMetricConfig::new(1.0); let mut metric = SsimMetric::::new(config); let input = SsimInput::new(outputs, targets); let _entry = metric.update(&input, &MetricMetadata::fake()); } #[test] fn test_ssim_valid_k1_k2() { let config = SsimMetricConfig::new(1.0).with_k1_k2(0.015, 0.035); assert!( config.k1 == 0.015 && config.k2 == 0.035, "Expected k1=0.015 and k2=0.035, got k1={} and k2={}", config.k1, config.k2 ); } #[test] #[should_panic(expected = "pixel_range must be positive")] fn test_ssim_negative_pixel_range() { let _ = SsimMetricConfig::new(-1.0); } #[test] #[should_panic(expected = "pixel_range must be positive")] fn test_ssim_zero_pixel_range() { let _ = SsimMetricConfig::new(0.0); } #[test] #[should_panic(expected = "k1 must be positive")] fn test_ssim_negative_k1() { let _ = SsimMetricConfig::new(1.0).with_k1_k2(-0.01, 0.03); } #[test] #[should_panic(expected = "k2 must be positive")] fn test_ssim_negative_k2() { let _ = SsimMetricConfig::new(1.0).with_k1_k2(0.01, -0.03); } #[test] #[should_panic(expected = "kernel_size must be positive and an odd number")] fn test_ssim_even_kernel_size() { let _ = SsimMetricConfig::new(1.0).with_kernel_size(10); } #[test] #[should_panic(expected = "kernel_size must be positive and an odd number")] fn test_ssim_zero_kernel_size() { let _ = SsimMetricConfig::new(1.0).with_kernel_size(0); } #[test] #[should_panic(expected = "sigma must be positive")] fn test_ssim_negative_sigma() { let _ = SsimMetricConfig::new(1.0).with_sigma(-1.5); } #[test] #[should_panic(expected = "sigma must be positive")] fn test_ssim_zero_sigma() { let _ = SsimMetricConfig::new(1.0).with_sigma(0.0); } } ================================================ FILE: crates/burn-train/src/metric/wer.rs ================================================ use super::cer::edit_distance; use super::state::{FormatOptions, NumericMetricState}; use super::{MetricMetadata, SerializedEntry}; use crate::metric::{ Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry, }; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; use core::marker::PhantomData; use std::sync::Arc; // The edit_distance function remains the same as it calculates the Levenshtein distance // between two sequences. The "units" within the sequences will now be treated as words. /// The word error rate (WER) metric, similar to the CER, is defined as the edit distance (e.g. Levenshtein distance) between the predicted /// and reference word sequences, divided by the total number of words in the reference. Here, the "units" within the sequences are words. /// #[derive(Clone)] pub struct WordErrorRate { name: MetricName, state: NumericMetricState, pad_token: Option, _b: PhantomData, } /// The [word error rate metric](WordErrorRate) input type. #[derive(new)] pub struct WerInput { /// The predicted token sequences (as a 2-D tensor of token indices). pub outputs: Tensor, /// The target token sequences (as a 2-D tensor of token indices). pub targets: Tensor, } impl Default for WordErrorRate { fn default() -> Self { Self::new() } } impl WordErrorRate { /// Creates the metric. pub fn new() -> Self { Self { name: Arc::new("WER".to_string()), state: NumericMetricState::default(), pad_token: None, _b: PhantomData, } } /// Sets the pad token. pub fn with_pad_token(mut self, index: usize) -> Self { self.pad_token = Some(index); self } } impl Metric for WordErrorRate { type Input = WerInput; fn update(&mut self, input: &WerInput, _metadata: &MetricMetadata) -> SerializedEntry { let outputs = input.outputs.clone(); let targets = input.targets.clone(); let [batch_size, seq_len] = targets.dims(); let outputs_data = outputs .to_data() .to_vec::() .expect("Failed to convert outputs to Vec"); let targets_data = targets .to_data() .to_vec::() .expect("Failed to convert targets to Vec"); let pad_token = self.pad_token; let mut total_edit_distance = 0.0; let mut total_target_length = 0.0; // Process each sequence in the batch for i in 0..batch_size { let start = i * seq_len; let end = (i + 1) * seq_len; let output_seq = &outputs_data[start..end]; let target_seq = &targets_data[start..end]; // Handle padding and map elements to i32. // These sequences now represent "words" (token IDs). let output_seq_no_pad = match pad_token { Some(pad) => output_seq .iter() .take_while(|&&x| x != pad as i64) .map(|&x| x as i32) .collect::>(), None => output_seq.iter().map(|&x| x as i32).collect(), }; let target_seq_no_pad = match pad_token { Some(pad) => target_seq .iter() .take_while(|&&x| x != pad as i64) .map(|&x| x as i32) .collect::>(), None => target_seq.iter().map(|&x| x as i32).collect(), }; let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad); total_edit_distance += ed as f64; total_target_length += target_seq_no_pad.len() as f64; } // Compute current WER value as a percentage let value = if total_target_length > 0.0 { 100.0 * total_edit_distance / total_target_length } else { 0.0 }; self.state.update( value, batch_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } fn name(&self) -> MetricName { self.name.clone() } fn clear(&mut self) { self.state.reset(); } fn attributes(&self) -> MetricAttributes { NumericAttributes { unit: Some("%".to_string()), higher_is_better: false, } .into() } } impl Numeric for WordErrorRate { fn value(&self) -> NumericEntry { self.state.current_value() } fn running_value(&self) -> NumericEntry { self.state.running_value() } } #[cfg(test)] mod tests { use super::*; use crate::TestBackend; /// Perfect match => WER = 0 %. #[test] fn test_wer_without_padding() { let device = Default::default(); let mut metric = WordErrorRate::::new(); // Batch size = 2, sequence length = 2 let preds = Tensor::from_data([[1, 2], [3, 4]], &device); let tgts = Tensor::from_data([[1, 2], [3, 4]], &device); metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake()); assert_eq!(0.0, metric.value().current()); } /// Two word edits in four target words => 50 %. #[test] fn test_wer_without_padding_two_errors() { let device = Default::default(); let mut metric = WordErrorRate::::new(); // One substitution in each sequence. // Sequence 1: target [1, 3], pred [1, 2] -> 1 error (3 vs 2) // Sequence 2: target [3, 4], pred [3, 5] -> 1 error (4 vs 5) let preds = Tensor::from_data([[1, 2], [3, 5]], &device); let tgts = Tensor::from_data([[1, 3], [3, 4]], &device); metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake()); // Total errors = 2, Total target words = 4. WER = (2/4) * 100 = 50 % assert_eq!(50.0, metric.value().current()); } /// Same scenario as above, but with right-padding (token 9) ignored. #[test] fn test_wer_with_padding() { let device = Default::default(); let pad = 9_i64; let mut metric = WordErrorRate::::new().with_pad_token(pad as usize); // Each row has three columns, last one is the pad token. // Target sequences after removing pad: [1, 3] and [3, 4] (total length 4) // Predicted sequences after removing pad: [1, 2] and [3, 5] let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device); let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device); metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake()); assert_eq!(50.0, metric.value().current()); } /// `clear()` must reset the running statistics to NaN. #[test] fn test_clear_resets_state() { let device = Default::default(); let mut metric = WordErrorRate::::new(); let preds = Tensor::from_data([[1, 2]], &device); let tgts = Tensor::from_data([[1, 3]], &device); // one error metric.update( &WerInput::new(preds.clone(), tgts.clone()), &MetricMetadata::fake(), ); assert!(metric.value().current() > 0.0); metric.clear(); assert!(metric.value().current().is_nan()); } } ================================================ FILE: crates/burn-train/src/renderer/base.rs ================================================ use std::sync::Arc; use crate::{ LearnerSummary, metric::{MetricDefinition, MetricEntry, NumericEntry}, }; use burn_core::data::dataloader::Progress; /// Trait for rendering metrics. pub trait MetricsRendererTraining: Send + Sync { /// Updates the training metric state. /// /// # Arguments /// /// * `state` - The metric state. fn update_train(&mut self, state: MetricState); /// Updates the validation metric state. /// /// # Arguments /// /// * `state` - The metric state. fn update_valid(&mut self, state: MetricState); /// Renders the training progress. /// /// # Arguments /// /// * `item` - The training progress. fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec); /// Renders the validation progress. /// /// # Arguments /// /// * `item` - The validation progress. fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec); /// Callback method invoked when training ends, whether it /// completed successfully or was interrupted. /// /// # Returns /// /// A result indicating whether the end-of-training actions were successful. fn on_train_end( &mut self, summary: Option, ) -> Result<(), Box> { default_summary_action(summary); Ok(()) } } /// A renderer that can be used for both training and evaluation. pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining { /// Keep the renderer from automatically closing, requiring manual action to close it. fn manual_close(&mut self); /// Register a new metric. fn register_metric(&mut self, definition: MetricDefinition); } #[derive(Clone)] /// The name of an evaluation. /// /// This is going to group metrics together for easier analysis. pub struct EvaluationName { pub(crate) name: Arc, } impl EvaluationName { /// Creates a new evaluation name. pub fn new(s: S) -> Self { Self { name: Arc::new(format!("{s}")), } } /// Returns the evaluation name. pub fn as_str(&self) -> &str { &self.name } } impl core::fmt::Display for EvaluationName { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str(&self.name) } } /// Trait for rendering metrics. pub trait MetricsRendererEvaluation: Send + Sync { /// Updates the testing metric state. /// /// # Arguments /// /// * `state` - The metric state. fn update_test(&mut self, name: EvaluationName, state: MetricState); /// Renders the testing progress. /// /// # Arguments /// /// * `item` - The training progress. fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec); /// Callback method invoked when testing ends, whether it /// completed successfully or was interrupted. /// /// # Returns /// /// A result indicating whether the end-of-testing actions were successful. fn on_test_end( &mut self, summary: Option, ) -> Result<(), Box> { default_summary_action(summary); Ok(()) } } /// The state of a metric. #[derive(Debug)] pub enum MetricState { /// A generic metric. Generic(MetricEntry), /// A numeric metric. Numeric(MetricEntry, NumericEntry), } /// Training progress. #[derive(Debug)] pub struct TrainingProgress { /// The progress. pub progress: Option, /// The progress of the whole training. pub global_progress: Progress, /// The iteration, if it differs from the items processed. pub iteration: Option, } /// Evaluation progress. #[derive(Debug)] pub struct EvaluationProgress { /// The progress. pub progress: Progress, /// The iteration, if it is different from the processed items. pub iteration: Option, } impl From<&EvaluationProgress> for TrainingProgress { fn from(value: &EvaluationProgress) -> Self { TrainingProgress { progress: None, global_progress: value.progress.clone(), iteration: value.iteration, } } } impl TrainingProgress { /// Creates a new empty training progress. pub fn none() -> Self { Self { progress: None, global_progress: Progress { items_processed: 0, items_total: 0, }, iteration: None, } } } /// Type of progress indicators. pub enum ProgressType { /// Detailed progress. Detailed { /// The tag. tag: String, /// The progress. progress: Progress, }, /// Simple value. Value { /// The tag. tag: String, /// The value. value: usize, }, } fn default_summary_action(summary: Option) { if let Some(summary) = summary { println!("{summary}"); } } ================================================ FILE: crates/burn-train/src/renderer/cli.rs ================================================ use crate::renderer::{ EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation, MetricsRendererTraining, ProgressType, TrainingProgress, }; /// A simple renderer for when the cli feature is not enabled. pub struct CliMetricsRenderer; #[allow(clippy::new_without_default)] impl CliMetricsRenderer { /// Create a new instance. pub fn new() -> Self { Self {} } } impl MetricsRendererTraining for CliMetricsRenderer { fn update_train(&mut self, _state: MetricState) {} fn update_valid(&mut self, _state: MetricState) {} fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec) { println!("{item:?}"); } fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec) { println!("{item:?}"); } } impl MetricsRendererEvaluation for CliMetricsRenderer { fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec) { println!("{item:?}"); } fn update_test(&mut self, _name: super::EvaluationName, _state: MetricState) {} } impl MetricsRenderer for CliMetricsRenderer { fn manual_close(&mut self) { // Nothing to do. } fn register_metric(&mut self, _definition: crate::metric::MetricDefinition) {} } ================================================ FILE: crates/burn-train/src/renderer/mod.rs ================================================ #[cfg(feature = "tui")] use std::io::IsTerminal; mod base; pub use base::*; pub(crate) mod cli; pub use cli::*; /// The tui renderer #[cfg(feature = "tui")] pub mod tui; use crate::Interrupter; /// Return the default metrics renderer. /// /// This can be either: /// - `TuiMetricsRenderer`, when the `tui` feature is enabled and `stdout` is /// a terminal, or /// - `CliMetricsRenderer`, when the `tui` feature is not enabled, or `stdout` /// is not a terminal. #[allow(unused_variables)] pub(crate) fn default_renderer( interuptor: Interrupter, checkpoint: Option, ) -> Box { #[cfg(feature = "tui")] if std::io::stdout().is_terminal() { return Box::new(tui::TuiMetricsRendererWrapper::new(interuptor, checkpoint)); } Box::new(CliMetricsRenderer::new()) } ================================================ FILE: crates/burn-train/src/renderer/tui/base.rs ================================================ use std::sync::Arc; use super::{ ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView, }; use ratatui::{ prelude::{Constraint, Direction, Layout, Rect}, style::Color, }; #[derive(new)] pub(crate) struct MetricsView<'a> { metric_numeric: NumericMetricView<'a>, metric_text: TextMetricView, progress: ProgressBarView, controls: ControlsView, status: StatusView, } impl MetricsView<'_> { pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { let chunks = Layout::default() .direction(Direction::Vertical) .constraints([Constraint::Min(16), Constraint::Max(4)].as_ref()) .split(size); let size_other = chunks[0]; let size_progress = chunks[1]; let chunks = Layout::default() .direction(Direction::Horizontal) .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref()) .split(size_other); let size_other = chunks[0]; let size_metric_numeric = chunks[1]; let chunks = Layout::default() .direction(Direction::Vertical) .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref()) .split(size_other); let size_controls = chunks[0]; let size_metric_text = chunks[1]; let size_status = chunks[2]; self.metric_numeric.render(frame, size_metric_numeric); self.metric_text.render(frame, size_metric_text); self.controls.render(frame, size_controls); self.progress.render(frame, size_progress); self.status.render(frame, size_status); } } #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub(crate) enum TuiSplit { Train, Valid, Test, } #[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord)] pub(crate) enum TuiGroup { Default, Named(Arc), } #[derive(new, Hash, Clone, PartialEq, Eq, PartialOrd, Ord)] pub(crate) struct TuiTag { pub(crate) split: TuiSplit, pub(crate) group: TuiGroup, } impl core::fmt::Display for TuiTag { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.group { TuiGroup::Default => f.write_fmt(format_args!("{}", self.split)), TuiGroup::Named(group) => f.write_fmt(format_args!("{} - {}", self.split, group)), } } } impl core::fmt::Display for TuiGroup { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TuiGroup::Default => f.write_str(""), TuiGroup::Named(group) => f.write_fmt(format_args!("{group} ")), } } } impl core::fmt::Display for TuiSplit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TuiSplit::Train => f.write_str("Train"), TuiSplit::Valid => f.write_str("Valid"), TuiSplit::Test => f.write_str("Test"), } } } impl TuiSplit { pub(crate) fn color(&self) -> Color { match self { TuiSplit::Train => Color::LightRed, TuiSplit::Valid => Color::LightBlue, TuiSplit::Test => Color::LightGreen, } } } ================================================ FILE: crates/burn-train/src/renderer/tui/controls.rs ================================================ use super::TerminalFrame; use ratatui::{ prelude::{Alignment, Rect}, style::{Color, Style, Stylize}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, }; /// Controls view. pub(crate) struct ControlsView; impl ControlsView { /// Render the view. pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { let lines = vec![ vec![ Span::from(" Quit : ").yellow().bold(), Span::from("q ").bold(), Span::from(" Stop the training.").italic(), ], vec![ Span::from(" Plots Metrics : ").yellow().bold(), Span::from("⬅ ➡").bold(), Span::from(" Switch between metrics.").italic(), ], vec![ Span::from(" Plots Type : ").yellow().bold(), Span::from("⬆ ⬇").bold(), Span::from(" Switch between types.").italic(), ], ]; let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::>()) .alignment(Alignment::Left) .wrap(Wrap { trim: false }) .style(Style::default().fg(Color::Gray)) .block( Block::default() .borders(Borders::ALL) .style(Style::default().fg(Color::Gray)) .title_alignment(Alignment::Left) .title("Controls"), ); frame.render_widget(paragraph, size); } } ================================================ FILE: crates/burn-train/src/renderer/tui/full_history.rs ================================================ use super::PlotAxes; use crate::{ metric::NumericEntry, renderer::tui::{TuiSplit, TuiTag}, }; use ratatui::{ style::{Color, Style}, symbols, widgets::{Bar, Dataset, GraphType}, }; use std::collections::BTreeMap; /// A plot that shows the full history at a reduced resolution. pub(crate) struct FullHistoryPlot { pub(crate) axes: PlotAxes, points: BTreeMap, max_samples: usize, max_samples_ratio: BTreeMap, next_x_state: usize, } struct FullHistoryPoints { min_x: f64, max_x: f64, min_y: f64, max_y: f64, avg_sum: f64, avg_counter: f64, points: Vec<(f64, f64)>, max_samples: usize, step_size: usize, } impl FullHistoryPlot { /// Create a new history plot. pub(crate) fn new(max_samples: usize) -> Self { Self { points: BTreeMap::default(), axes: PlotAxes::default(), max_samples, max_samples_ratio: BTreeMap::default(), next_x_state: 0, } } /// Update the maximum amount of sample to display for the validation points. /// /// This is necessary if we want the validation line to have the same point density as the /// training line. pub(crate) fn update_max_sample(&mut self, split: TuiSplit, ratio: f64) { self.max_samples_ratio.insert(split, ratio); self.points .iter_mut() .filter(|(tag, _)| tag.split == split) .for_each(|(_, points)| { points.max_samples = (self.max_samples as f64 * ratio) as usize; }); } /// Register a training data point. pub(crate) fn push(&mut self, tag: TuiTag, data: NumericEntry) { let x_current = self.next_x(); let points = match self.points.get_mut(&tag) { Some(val) => val, None => { let max_samples = self .max_samples_ratio .get(&tag.split) .map(|ratio| (*ratio * self.max_samples as f64) as usize) .unwrap_or(self.max_samples); self.points .insert(tag.clone(), FullHistoryPoints::new(max_samples)); self.points.get_mut(&tag).unwrap() } }; points.push((x_current, data)); self.update_bounds(); } pub(crate) fn datasets(&self) -> Vec> { let mut datasets = Vec::with_capacity(2); for (tag, points) in self.points.iter() { datasets.push(points.dataset(format!("{tag}"), tag.split.color())); } datasets } pub(crate) fn bars(&self, max: u64, bar_width: &mut usize) -> Vec> { let mut bars = Vec::new(); for (tag, points) in self.points.iter() { if let Some((bar, width)) = points.bar(tag, max) { *bar_width = usize::max(*bar_width, width); bars.push(bar); } } bars } fn next_x(&mut self) -> f64 { let value = self.next_x_state; self.next_x_state += 1; value as f64 } fn update_bounds(&mut self) { let (mut x_min, mut x_max) = (f64::MAX, f64::MIN); let (mut y_min, mut y_max) = (f64::MAX, f64::MIN); for points in self.points.values() { x_min = f64::min(x_min, points.min_x); x_max = f64::max(x_max, points.max_x); y_min = f64::min(y_min, points.min_y); y_max = f64::max(y_max, points.max_y); } self.axes.update_bounds((x_min, x_max), (y_min, y_max)); } } impl FullHistoryPoints { fn new(max_samples: usize) -> Self { Self { min_x: 0., max_x: 0., min_y: f64::MAX, max_y: f64::MIN, avg_sum: 0.0, avg_counter: 0.0, points: Vec::with_capacity(max_samples), max_samples, step_size: 1, } } fn push(&mut self, (x, y): (f64, NumericEntry)) { if !(x as usize).is_multiple_of(self.step_size) { return; } let y = match y { NumericEntry::Value(val) => { self.avg_sum += val; self.avg_counter += 1.0; val } NumericEntry::Aggregated { aggregated_value, count, } => { self.avg_sum += aggregated_value * count as f64; self.avg_counter += count as f64; aggregated_value } }; if x > self.max_x { self.max_x = x; } if x < self.min_x { self.min_x = x; } if y > self.max_y { self.max_y = y; } if y < self.min_y { self.min_y = y } self.points.push((x, y)); if self.points.len() > self.max_samples { self.resize(); } } /// We keep only half the points and we double the step size. /// /// This ensure that we have the same amount of points across the X axis. fn resize(&mut self) { let mut points = Vec::with_capacity(self.max_samples / 2); let mut max_x = f64::MIN; let mut max_y = f64::MIN; let mut min_x = f64::MAX; let mut min_y = f64::MAX; for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() { if i % 2 == 0 { if x > max_x { max_x = x; } if x < min_x { min_x = x; } if y > max_y { max_y = y; } if y < min_y { min_y = y; } points.push((x, y)); } } self.points = points; self.step_size *= 2; self.min_x = min_x; self.max_x = max_x; self.min_y = min_y; self.max_y = max_y; } fn dataset<'a>(&'a self, name: String, color: Color) -> Dataset<'a> { Dataset::default() .name(name) .marker(symbols::Marker::Braille) .style(Style::default().fg(color).bold()) .graph_type(GraphType::Line) .data(&self.points) } fn bar<'a>(&'a self, tag: &TuiTag, max: u64) -> Option<(Bar<'a>, usize)> { if self.avg_sum == 0.0 { return None; } let label = format!("{tag}"); let width = usize::max(label.len(), 7); // 7 min width let factor = max as f64; let avg = self.avg_sum / self.avg_counter; Some(( Bar::default() .value((avg * factor) as u64) .style(tag.split.color()) .text_value(format!("{:.2}", avg)) .label(label), width, )) } } #[cfg(test)] mod tests { use super::*; use crate::renderer::tui::{TuiGroup, TuiSplit}; #[test] fn test_points() { let mut chart = FullHistoryPlot::new(10); let tag_train = TuiTag::new(TuiSplit::Train, TuiGroup::Default); let tag_valid = TuiTag::new(TuiSplit::Valid, TuiGroup::Default); chart.update_max_sample(tag_valid.split, 0.6); for i in 0..100 { chart.push(tag_train.clone(), NumericEntry::Value(i as f64)); } for i in 0..60 { chart.push(tag_valid.clone(), NumericEntry::Value(i as f64)); } let expected_train = vec![ (0.0, 0.0), (16.0, 16.0), (32.0, 32.0), (48.0, 48.0), (64.0, 64.0), (80.0, 80.0), (96.0, 96.0), ]; let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)]; assert_eq!( chart.points.get(&tag_train).unwrap().points, expected_train, "Expected train data points" ); assert_eq!( chart.points.get(&tag_valid).unwrap().points, expected_valid, "Expected valid data points" ); } } ================================================ FILE: crates/burn-train/src/renderer/tui/metric_numeric.rs ================================================ use crate::{ metric::{MetricName, NumericEntry}, renderer::{EvaluationProgress, TrainingProgress, tui::TuiTag}, }; use super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame, TuiSplit}; use ratatui::{ crossterm::event::{Event, KeyCode, KeyEventKind}, prelude::{Alignment, Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style, Stylize}, text::Line, widgets::{ Axis, BarChart, BarGroup, Block, Borders, Chart, LegendPosition, Padding, Paragraph, Tabs, }, }; use std::collections::BTreeMap; /// 1000 seems to be required to see some improvement. const MAX_NUM_SAMPLES_RECENT: usize = 1000; /// 250 seems to be the right resolution when plotting all history. /// Otherwise, there is too much points and the lines arent't smooth enough. const MAX_NUM_SAMPLES_FULL: usize = 250; /// Numeric metrics state that handles creating plots. #[derive(Default)] pub(crate) struct NumericMetricsState { data: BTreeMap, names: Vec, selected: usize, kind: PlotKind, num_samples_train: Option, num_samples_valid: Option, num_samples_test: Option, epoch: usize, } /// The kind of plot to display. #[derive(Default, Clone, Copy)] pub(crate) enum PlotKind { /// Display the full history of the metric with reduced resolution. #[default] Full, /// Display only the recent history of the metric, but with more resolution. Recent, Summary, } impl NumericMetricsState { /// Register a new training value for the metric with the given name. pub(crate) fn push(&mut self, tag: TuiTag, name: MetricName, data: NumericEntry) { if let Some((recent, full)) = self.data.get_mut(name.as_ref()) { recent.push(tag.clone(), data.current()); full.push(tag, data); } else { let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT); let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL); recent.push(tag.clone(), data.current()); full.push(tag, data); self.names.push(name.clone()); self.data.insert(name, (recent, full)); } } /// Update the state with the training progress. pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) { self.epoch = progress.global_progress.items_processed; if self.num_samples_train.is_some() { return; } // If the training only has the notion of global progress, num_samples_train remains None. self.num_samples_train = progress.progress.as_ref().map(|p| p.items_total); } /// Update the state with the validation progress. pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) { if self.num_samples_valid.is_some() { return; } // If num_samples_train is None, keep the default max_samples for validation. if let Some(num_sample_train) = self.num_samples_train { for (_, (_recent, full)) in self.data.iter_mut() { let ratio = match &progress.progress { Some(p) => p.items_total as f64 / num_sample_train as f64, None => progress.global_progress.items_total as f64 / num_sample_train as f64, }; full.update_max_sample(TuiSplit::Valid, ratio); } } self.epoch = progress.global_progress.items_processed; self.num_samples_valid = progress.progress.as_ref().map(|p| p.items_total); } /// Update the state with the testing progress. pub(crate) fn update_progress_test(&mut self, progress: &EvaluationProgress) { if self.num_samples_test.is_some() { return; } if let Some(num_sample_train) = self.num_samples_train { for (_, (_recent, full)) in self.data.iter_mut() { let ratio = progress.progress.items_total as f64 / num_sample_train as f64; full.update_max_sample(TuiSplit::Test, ratio); } } self.num_samples_test = Some(progress.progress.items_total); } /// Create a view to display the numeric metrics. pub(crate) fn view(&self) -> NumericMetricView<'_> { match self.names.is_empty() { true => NumericMetricView::None, false => match self.kind { PlotKind::Summary => { NumericMetricView::BarPlots(&self.names, self.selected, self.bar_chart()) } _ => NumericMetricView::LinePlots( &self.names, self.selected, self.line_chart(), self.kind, ), }, } } /// Handle the current event. pub(crate) fn on_event(&mut self, event: &Event) { if let Event::Key(key) = event { match key.kind { KeyEventKind::Release | KeyEventKind::Repeat => (), #[cfg(target_os = "windows")] // Fix the double toggle on Windows. KeyEventKind::Press => return, #[cfg(not(target_os = "windows"))] KeyEventKind::Press => (), } match key.code { KeyCode::Right => self.next_metric(), KeyCode::Left => self.previous_metric(), KeyCode::Up => self.switch_kind(), KeyCode::Down => self.switch_kind(), _ => {} } } } fn switch_kind(&mut self) { self.kind = match self.kind { PlotKind::Full => PlotKind::Recent, PlotKind::Recent => PlotKind::Summary, PlotKind::Summary => PlotKind::Full, }; } fn next_metric(&mut self) { self.selected = (self.selected + 1) % { let this = &self; this.data.len() }; } fn previous_metric(&mut self) { if self.selected > 0 { self.selected -= 1; } else { self.selected = ({ let this = &self; this.data.len() }) - 1; } } fn line_chart<'a>(&'a self) -> Chart<'a> { let name = self.names.get(self.selected).unwrap(); let (recent, full) = self.data.get(name).unwrap(); let (datasets, axes) = match self.kind { PlotKind::Full => (full.datasets(), &full.axes), PlotKind::Recent => (recent.datasets(), &recent.axes), _ => unreachable!(), }; Chart::<'a>::new(datasets) .block(Block::default()) .x_axis( Axis::default() .style(Style::default().fg(Color::DarkGray)) .title("Iteration") .labels(axes.labels_x.clone().into_iter().map(|s| s.bold())) .bounds(axes.bounds_x), ) .y_axis( Axis::default() .style(Style::default().fg(Color::DarkGray)) .labels(axes.labels_y.clone().into_iter().map(|s| s.bold())) .bounds(axes.bounds_y), ) .legend_position(Some(LegendPosition::Right)) } fn bar_chart<'a>(&'a self) -> BarChart<'a> { let name = self.names.get(self.selected).unwrap(); let (_recent, full) = self.data.get(name).unwrap(); let mut bar_width = 0; let bars = full.bars(100, &mut bar_width); let data = BarGroup::default().bars(&bars); BarChart::default() .block(Block::default().padding(Padding::new(2, 2, 2, 0))) .bar_width(bar_width as u16) .bar_gap(2) .data(data) } } #[allow(clippy::large_enum_variant)] #[derive(new)] pub(crate) enum NumericMetricView<'a> { LinePlots(&'a [MetricName], usize, Chart<'a>, PlotKind), BarPlots(&'a [MetricName], usize, BarChart<'a>), None, } impl NumericMetricView<'_> { pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { match self { Self::LinePlots(titles, selected, chart, kind) => { let block = Block::default() .borders(Borders::ALL) .title("Plots") .title_alignment(Alignment::Left); let size_new = block.inner(size); frame.render_widget(block, size); let size = size_new; let chunks = Layout::default() .direction(Direction::Vertical) .constraints( [ Constraint::Length(2), Constraint::Length(1), Constraint::Min(0), ] .as_ref(), ) .split(size); let tabs = Tabs::new( titles .iter() .map(|i| Line::from(vec![i.to_string().yellow()])), ) .select(selected) .style(Style::default()) .highlight_style( Style::default() .add_modifier(Modifier::BOLD) .add_modifier(Modifier::UNDERLINED) .fg(Color::LightYellow), ); let title = match kind { PlotKind::Full => "Full History", PlotKind::Recent => "Recent History", _ => unreachable!(), }; let plot_type = Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); frame.render_widget(tabs, chunks[0]); frame.render_widget(plot_type, chunks[1]); frame.render_widget(chart, chunks[2]); } Self::BarPlots(titles, selected, chart) => { let block = Block::default() .borders(Borders::ALL) .title("Summary") .title_alignment(Alignment::Left); let size_new = block.inner(size); frame.render_widget(block, size); let size = size_new; let chunks = Layout::default() .direction(Direction::Vertical) .constraints([ Constraint::Length(2), Constraint::Length(1), Constraint::Min(0), ]) .split(size); let tabs = Tabs::new( titles .iter() .map(|i| Line::from(vec![i.to_string().yellow()])), ) .select(selected) .style(Style::default()) .highlight_style( Style::default() .add_modifier(Modifier::BOLD) .add_modifier(Modifier::UNDERLINED) .fg(Color::LightYellow), ); let title = "Summary"; let plot_type = Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center); frame.render_widget(tabs, chunks[0]); frame.render_widget(plot_type, chunks[1]); frame.render_widget(chart, chunks[2]); } Self::None => {} }; } } ================================================ FILE: crates/burn-train/src/renderer/tui/metric_text.rs ================================================ use super::TerminalFrame; use crate::{ metric::{MetricEntry, MetricName}, renderer::tui::{TuiGroup, TuiSplit}, }; use ratatui::{ prelude::{Alignment, Rect}, style::{Color, Style, Stylize}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, }; use std::{collections::BTreeMap, sync::Arc}; #[derive(Default)] pub(crate) struct TextMetricsState { data: BTreeMap, names: Vec, } struct MetricGroup { groups: BTreeMap, } impl MetricGroup { fn new(group: TuiGroup, metric: MetricSplits) -> Self { Self { groups: BTreeMap::from_iter(Some((group, metric))), } } fn update(&mut self, split: TuiSplit, group: TuiGroup, metric: MetricEntry) { match self.groups.get_mut(&group) { Some(value) => value.update(split, metric), None => { let value = MetricSplits::new(split, metric); self.groups.insert(group, value); } } } } struct MetricSplits { splits: BTreeMap, } impl MetricSplits { fn new(split: TuiSplit, metric: MetricEntry) -> Self { Self { splits: BTreeMap::from_iter(Some((split, metric))), } } fn update(&mut self, split: TuiSplit, metric: MetricEntry) { self.splits.insert(split, metric); } } impl TextMetricsState { pub(crate) fn update( &mut self, split: TuiSplit, group: TuiGroup, metric: MetricEntry, name: Arc, ) { if let Some(existing) = self.data.get_mut(name.as_ref()) { existing.update(split, group, metric); } else { let key = name.clone(); let value = MetricSplits::new(split, metric); self.names.push(key.clone()); self.data .insert(key.to_string(), MetricGroup::new(group, value)); } } pub(crate) fn view(&self) -> TextMetricView { TextMetricView::new(&self.names, &self.data) } } pub(crate) struct TextMetricView { lines: Vec>>, } impl TextMetricView { fn new(names: &[MetricName], data: &BTreeMap) -> Self { let mut lines = Vec::with_capacity(names.len() * 4); let start_line = |title: &str| vec![Span::from(format!(" {title} ")).bold().yellow()]; let format_line = |group: &TuiGroup, split: &TuiSplit, formatted: &str| { vec![ Span::from(format!(" {group}{split} ")).bold(), Span::from(formatted.to_string()).italic(), ] }; for name in names { lines.push(start_line(name)); let entry = data.get(name.as_ref()).unwrap(); for (name, group) in entry.groups.iter() { for (split, entry) in group.splits.iter() { lines.push(format_line(name, split, &entry.serialized_entry.formatted)); } } lines.push(vec![Span::from("")]); } Self { lines } } pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) .alignment(Alignment::Left) .wrap(Wrap { trim: false }) .block(Block::default().borders(Borders::ALL).title("Metrics")) .style(Style::default().fg(Color::Gray)); frame.render_widget(paragraph, size); } } ================================================ FILE: crates/burn-train/src/renderer/tui/mod.rs ================================================ mod base; mod controls; mod full_history; mod metric_numeric; mod metric_text; mod plot_utils; mod popup; mod progress; mod recent_history; mod renderer; mod status; pub(crate) use base::*; pub(crate) use controls::*; pub(crate) use full_history::*; pub(crate) use metric_numeric::*; pub(crate) use metric_text::*; pub(crate) use plot_utils::*; pub(crate) use popup::*; pub(crate) use progress::*; pub(crate) use recent_history::*; pub use renderer::*; pub(crate) use status::*; ================================================ FILE: crates/burn-train/src/renderer/tui/plot_utils.rs ================================================ use crate::metric::format_float; const AXIS_TITLE_PRECISION: usize = 2; /// The data describing both X and Y axes. pub(crate) struct PlotAxes { pub(crate) labels_x: Vec, pub(crate) labels_y: Vec, pub(crate) bounds_x: [f64; 2], pub(crate) bounds_y: [f64; 2], } impl Default for PlotAxes { fn default() -> Self { Self { bounds_x: [f64::MAX, f64::MIN], bounds_y: [f64::MAX, f64::MIN], labels_x: Vec::new(), labels_y: Vec::new(), } } } impl PlotAxes { /// Update the bounds based on the min max of each X and Y axes with both train and valid data. pub(crate) fn update_bounds(&mut self, (x_min, x_max): (f64, f64), (y_min, y_max): (f64, f64)) { self.bounds_x = [x_min, x_max]; self.bounds_y = [y_min, y_max]; // We know x are integers. self.labels_x = vec![format!("{x_min}"), format!("{x_max}")]; self.labels_y = vec![ format_float(y_min, AXIS_TITLE_PRECISION), format_float(y_max, AXIS_TITLE_PRECISION), ]; } } ================================================ FILE: crates/burn-train/src/renderer/tui/popup.rs ================================================ use ratatui::{ crossterm::event::{Event, KeyCode}, prelude::{Alignment, Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style, Stylize}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, }; use super::TerminalFrame; /// Popup callback function. pub(crate) trait CallbackFn: Send + Sync { /// Call the function and return if the popup state should be reset. fn call(&self) -> bool; } /// Popup callback. pub(crate) struct Callback { title: String, description: String, trigger: char, callback: Box, } impl Callback { /// Create a new popup. pub(crate) fn new(title: T, description: D, trigger: char, callback: C) -> Self where T: Into, D: Into, C: CallbackFn + 'static, { Self { title: title.into(), description: description.into(), trigger, callback: Box::new(callback), } } } /// Popup state. pub(crate) enum PopupState { Empty, Full(String, Vec), } impl PopupState { /// If the popup is empty. pub(crate) fn is_empty(&self) -> bool { matches!(&self, PopupState::Empty) } /// Handle popup events. pub(crate) fn on_event(&mut self, event: &Event) { let mut reset = false; match self { PopupState::Empty => {} PopupState::Full(_, callbacks) => { for callback in callbacks.iter() { if let Event::Key(key) = event && let KeyCode::Char(key) = &key.code && &callback.trigger == key && callback.callback.call() { reset = true; } } } }; if reset { *self = Self::Empty; } } /// Create the popup view. pub(crate) fn view(&self) -> Option> { match self { PopupState::Empty => None, PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)), } } } #[derive(new)] pub(crate) struct PopupView<'a> { title: &'a String, callbacks: &'a [Callback], } impl<'a> PopupView<'a> { /// Render the view. pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) { let lines = self .callbacks .iter() .flat_map(|callback| { vec![ Line::from(vec![ Span::from(format!("[{}] ", callback.trigger)).bold(), Span::from(format!("{} ", callback.title)).yellow().bold(), ]), Line::from(Span::from("")), Line::from(Span::from(callback.description.to_string()).italic()), Line::from(Span::from("")), ] }) .collect::>(); let paragraph = Paragraph::new(lines) .alignment(Alignment::Left) .wrap(Wrap { trim: false }) .style(Style::default().fg(Color::Gray)) .block( Block::default() .borders(Borders::ALL) .title_alignment(Alignment::Center) .style(Style::default().fg(Color::Gray)) .title(Span::styled( self.title, Style::default().add_modifier(Modifier::BOLD), )), ); let area = centered_percent(20, size, Direction::Horizontal); let area = centered_percent(20, area, Direction::Vertical); frame.render_widget(paragraph, area); } } /// The percent represents the amount of space that will be taken by each side. fn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect { let center = 100 - (percent * 2); Layout::default() .direction(direction) .constraints([ Constraint::Percentage(percent), Constraint::Percentage(center), Constraint::Percentage(percent), ]) .split(size)[1] } ================================================ FILE: crates/burn-train/src/renderer/tui/progress.rs ================================================ use super::TerminalFrame; use crate::renderer::{EvaluationProgress, TrainingProgress, tui::TuiSplit}; use ratatui::{ prelude::{Alignment, Constraint, Direction, Layout, Rect}, style::{Color, Style, Stylize}, text::{Line, Span}, widgets::{Block, Borders, Gauge, Paragraph}, }; use std::time::{Duration, Instant}; /// Simple progress bar for the training. /// /// We currently ignore the time taken for the validation part. pub(crate) struct ProgressBarState { progress_total: f64, // Progress for total execution. progress_task: f64, // Progress for current task. split: TuiSplit, starting_epoch: usize, estimate: ProgressEstimate, } const MINUTE: u64 = 60; const HOUR: u64 = 60 * 60; const DAY: u64 = 24 * 60 * 60; impl ProgressBarState { pub fn new(checkpoint: Option) -> Self { Self { progress_total: 0.0, progress_task: 0.0, split: TuiSplit::Train, estimate: ProgressEstimate::new(), starting_epoch: checkpoint.unwrap_or(0), } } /// Update the training progress. pub(crate) fn update_train(&mut self, progress: &TrainingProgress) { self.progress_total = calculate_progress(progress, 0, 0); let local_progress = progress .progress .as_ref() .unwrap_or(&progress.global_progress); self.progress_task = local_progress.items_processed as f64 / local_progress.items_total as f64; self.estimate.update(progress, self.starting_epoch); self.split = TuiSplit::Train; } /// Update the validation progress. pub(crate) fn update_valid(&mut self, progress: &TrainingProgress) { // We don't use the validation for the total progress yet. let local_progress = progress .progress .as_ref() .unwrap_or(&progress.global_progress); self.progress_task = local_progress.items_processed as f64 / local_progress.items_total as f64; self.split = TuiSplit::Valid; } /// Update the testing progress. pub(crate) fn update_test(&mut self, progress: &EvaluationProgress) { // We don't use the testing for the total progress yet. self.progress_task = progress.progress.items_processed as f64 / progress.progress.items_total as f64; self.split = TuiSplit::Test; } /// Create a view for the current progress. pub(crate) fn view(&self) -> ProgressBarView { const NO_ETA: &str = "---"; let eta = match self.estimate.secs() { Some(eta) => format_eta(eta), None => NO_ETA.to_string(), }; ProgressBarView::new( self.progress_total, self.progress_task, self.split.color(), eta, ) } } #[derive(new)] pub(crate) struct ProgressBarView { progress: f64, progress_task: f64, color_task: Color, eta: String, } impl ProgressBarView { /// Render the view. pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) { let block = Block::default() .borders(Borders::ALL) .title("Progress") .title_alignment(Alignment::Left); let size_new = block.inner(size); frame.render_widget(block, size); let size = size_new; let chunks = Layout::default() .direction(Direction::Vertical) .constraints([Constraint::Ratio(1, 2), Constraint::Ratio(1, 2)].as_ref()) .split(size); let size_task = chunks[0]; let size_total = chunks[1]; let calculate_size = |size: Rect| { Layout::default() .direction(Direction::Horizontal) .constraints( [ Constraint::Length(1), // Empty space Constraint::Min(0), Constraint::Length(self.eta.len() as u16 + 4), ] .as_ref(), ) .split(size) }; let chunks = calculate_size(size_total); let size_gauge_total = chunks[1]; let size_eta = chunks[2]; let chunks = calculate_size(size_task); let size_gauge_task = chunks[1]; let progress_total = Gauge::default() .gauge_style(Style::default().fg(Color::Yellow)) .ratio(self.progress.min(1.0)); let progress_task = Gauge::default() .gauge_style(Style::default().fg(self.color_task)) .ratio(self.progress_task.min(1.0)); let eta = Paragraph::new(Line::from(vec![ Span::from(" ("), Span::from(self.eta).italic(), Span::from(") "), ])); frame.render_widget(progress_task, size_gauge_task); frame.render_widget(progress_total, size_gauge_total); frame.render_widget(eta, size_eta); } } struct ProgressEstimate { started: Instant, started_after_warmup: Option, warmup_num_items: usize, progress: f64, } impl ProgressEstimate { fn new() -> Self { Self { started: Instant::now(), started_after_warmup: None, warmup_num_items: 0, progress: 0.0, } } fn secs(&self) -> Option { let eta = self.started_after_warmup?.elapsed(); let total_estimated = (eta.as_secs() as f64) / self.progress; if total_estimated.is_normal() { let remaining = 1.0 - self.progress; let eta = (total_estimated * remaining) as u64; Some(eta) } else { None } } fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) { if self.started_after_warmup.is_some() { self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); return; } const WARMUP_NUM_ITERATION: usize = 10; // When the training has started since 30 seconds. if self.started.elapsed() > Duration::from_secs(30) { self.init(progress, starting_epoch); return; } // When the training has started since at least 10 seconds and completed 10 iterations. if progress.iteration >= Some(WARMUP_NUM_ITERATION) && self.started.elapsed() > Duration::from_secs(10) { self.init(progress, starting_epoch); } } fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) { let epoch = progress.global_progress.items_processed - starting_epoch; self.warmup_num_items = match &progress.progress { Some(local_progress) => { let epoch_items = (epoch - 1) * local_progress.items_total; let iteration_items = local_progress.items_processed; epoch_items + iteration_items } None => epoch, }; self.started_after_warmup = Some(Instant::now()); self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items); } } fn calculate_progress( progress: &TrainingProgress, starting_epoch: usize, ignore_num_items: usize, ) -> f64 { let epoch_total = progress.global_progress.items_total - starting_epoch; let epoch = progress.global_progress.items_processed - starting_epoch; match &progress.progress { Some(local_progress) => { let total_items = local_progress.items_total * epoch_total; let epoch_items = (epoch - 1) * local_progress.items_total; let iteration_items = local_progress.items_processed; let num_items = epoch_items + iteration_items - ignore_num_items; num_items as f64 / total_items as f64 } None => epoch as f64 / epoch_total as f64, } } fn format_eta(eta_secs: u64) -> String { let seconds = eta_secs % 60; let minutes = eta_secs / MINUTE % 60; let hours = eta_secs / HOUR % 24; let days = eta_secs / DAY; if days > 1 { format!("{days} days") } else if days == 1 { "1 day".to_string() } else if hours > 1 { format!("{hours} hours") } else if hours == 1 { "1 hour".to_string() } else if minutes > 1 { format!("{minutes} mins") } else if minutes == 1 { "1 min".to_string() } else if seconds > 1 { format!("{seconds} secs") } else { "1 sec".to_string() } } #[cfg(test)] mod tests { use super::*; use burn_core::data::dataloader::Progress; #[test] fn test_format_eta() { assert_eq!("55 secs", format_eta(55), "Less than 1 minutes"); assert_eq!("1 min", format_eta(61), "More than 1 minutes"); assert_eq!("2 mins", format_eta(2 * 61), "More than 2 minutes"); assert_eq!("1 hour", format_eta(3601), "More than 1 hour"); assert_eq!("2 hours", format_eta(2 * 3601), "More than 2 hour"); assert_eq!("1 day", format_eta(24 * 3601), "More than 1 day"); assert_eq!("2 days", format_eta(48 * 3601), "More than 2 day"); } #[test] fn calculate_progress_for_eta() { let half = Progress { items_processed: 5, items_total: 10, }; let global_progress = Progress { items_processed: 9, items_total: 10, }; let progress = TrainingProgress { progress: Some(half), global_progress, iteration: Some(500), }; let starting_epoch = 8; let progress = calculate_progress(&progress, starting_epoch, 0); // Two epochs remaining while the first is half done. assert_eq!(0.25, progress); } #[test] fn calculate_progress_for_eta_with_warmup() { let half = Progress { items_processed: 110, items_total: 1000, }; let global_progress = Progress { items_processed: 9, items_total: 10, }; let progress = TrainingProgress { progress: Some(half), global_progress, iteration: Some(500), }; let starting_epoch = 8; let progress = calculate_progress(&progress, starting_epoch, 10); // Two epochs remaining while the first is half done. assert_eq!(0.05, progress); } } ================================================ FILE: crates/burn-train/src/renderer/tui/recent_history.rs ================================================ use super::PlotAxes; use crate::renderer::tui::TuiTag; use ratatui::{ style::{Color, Style}, symbols, widgets::{Dataset, GraphType}, }; use std::collections::BTreeMap; const FACTOR_BEFORE_RESIZE: usize = 2; /// A plot that shows the recent history at full resolution. pub(crate) struct RecentHistoryPlot { pub(crate) axes: PlotAxes, points: BTreeMap, max_samples: usize, } struct RecentHistoryPoints { min_x: f64, max_x: f64, min_y: f64, max_y: f64, cursor: usize, points: Vec<(f64, f64)>, max_samples: usize, factor_before_resize: usize, } impl RecentHistoryPlot { pub(crate) fn new(max_samples: usize) -> Self { Self { axes: PlotAxes::default(), points: BTreeMap::default(), max_samples, } } pub(crate) fn push(&mut self, tag: TuiTag, data: f64) { if !self.points.contains_key(&tag) { self.points .insert(tag.clone(), RecentHistoryPoints::new(self.max_samples)); } let (x_min, x_current) = self.point_x(); for (s, entry) in self.points.iter_mut() { if s == &tag { entry.push((x_current, data)); } entry.update_cursor(x_min); } self.update_bounds(); } pub(crate) fn datasets(&self) -> Vec> { let mut datasets = Vec::new(); for (tag, points) in self.points.iter() { datasets.push(points.dataset(format!("{tag}"), tag.split.color())); } datasets } fn point_x(&mut self) -> (f64, f64) { let mut x_current = f64::MIN; let mut x_min = f64::MAX; for point in self.points.values() { x_current = f64::max(x_current, point.max_x); x_min = f64::min(x_min, point.min_x); } if x_current - x_min >= self.max_samples as f64 { x_min += 1.0; } (x_min, x_current + 1.0) } fn update_bounds(&mut self) { let (mut x_min, mut x_max) = (f64::MAX, f64::MIN); let (mut y_min, mut y_max) = (f64::MAX, f64::MIN); for points in self.points.values() { x_min = f64::min(x_min, points.min_x); x_max = f64::max(x_max, points.max_x); y_min = f64::min(y_min, points.min_y); y_max = f64::max(y_max, points.max_y); } self.axes.update_bounds((x_min, x_max), (y_min, y_max)); } } impl RecentHistoryPoints { fn new(max_samples: usize) -> Self { let factor_before_resize = FACTOR_BEFORE_RESIZE; Self { min_x: 0., max_x: 0., min_y: f64::MAX, max_y: f64::MIN, points: Vec::with_capacity(factor_before_resize * max_samples), cursor: 0, max_samples, factor_before_resize, } } fn push(&mut self, (x, y): (f64, f64)) { if x > self.max_x { self.max_x = x; } if x < self.min_x { self.min_x = x; } if y > self.max_y { self.max_y = y; } if y < self.min_y { self.min_y = y } self.points.push((x, y)); } fn update_cursor(&mut self, min_x: f64) { if self.min_x >= min_x { return; } self.min_x = min_x; let mut update_y_max = false; let mut update_y_min = false; while let Some((x, y)) = self.points.get(self.cursor) { if *x >= self.min_x { break; } if *y == self.max_y { update_y_max = true } if *y == self.min_y { update_y_min = true; } self.cursor += 1; } if update_y_max { self.max_y = self.calculate_max_y(); } if update_y_min { self.min_y = self.calculate_min_y(); } if self.points.len() >= self.max_samples * self.factor_before_resize { self.resize(); } } fn slice(&self) -> &[(f64, f64)] { &self.points[self.cursor..self.points.len()] } fn calculate_max_y(&self) -> f64 { let mut max_y = f64::MIN; for (_x, y) in self.slice() { max_y = f64::max(max_y, *y); } max_y } fn calculate_min_y(&self) -> f64 { let mut min_y = f64::MAX; for (_x, y) in self.slice() { if *y < min_y { min_y = *y; } } min_y } fn resize(&mut self) { let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize); for i in self.cursor..self.points.len() { points.push(self.points[i]); } self.points = points; self.cursor = 0; } fn dataset<'a>(&'a self, name: String, color: Color) -> Dataset<'a> { let data = &self.points[self.cursor..self.points.len()]; Dataset::default() .name(name) .marker(symbols::Marker::Braille) .style(Style::default().fg(color).bold()) .graph_type(GraphType::Scatter) .data(data) } } #[cfg(test)] mod tests { use crate::renderer::tui::{TuiGroup, TuiSplit}; use super::*; #[test] fn test_push_update_bounds_max_y() { let mut chart = RecentHistoryPlot::new(2); let tag = TuiTag::new(TuiSplit::Train, TuiGroup::Default); chart.push(tag.clone(), 15.0); chart.push(tag.clone(), 10.0); chart.push(tag.clone(), 14.0); assert_eq!(chart.axes.bounds_y[1], 15.); chart.push(tag, 10.0); assert_eq!(chart.axes.bounds_y[1], 14.); } #[test] fn test_push_update_bounds_min_y() { let mut chart = RecentHistoryPlot::new(2); let tag = TuiTag::new(TuiSplit::Train, TuiGroup::Default); chart.push(tag.clone(), 5.0); chart.push(tag.clone(), 10.0); chart.push(tag.clone(), 14.0); assert_eq!(chart.axes.bounds_y[0], 5.); chart.push(tag, 10.0); assert_eq!(chart.axes.bounds_y[0], 10.); } } ================================================ FILE: crates/burn-train/src/renderer/tui/renderer.rs ================================================ use crate::metric::{MetricDefinition, MetricId}; use crate::renderer::tui::TuiSplit; use crate::renderer::{ EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation, ProgressType, TrainingProgress, }; use crate::renderer::{MetricsRendererTraining, tui::NumericMetricsState}; use crate::{Interrupter, LearnerSummary}; use ratatui::{ Terminal, crossterm::{ event::{self, Event, KeyCode}, execute, terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode}, }, prelude::*, }; use std::collections::HashMap; use std::panic::{set_hook, take_hook}; use std::sync::mpsc::{Receiver, Sender}; use std::sync::{Arc, Mutex, mpsc}; use std::thread::JoinHandle; use std::{ error::Error, io::{self, Stdout}, time::{Duration, Instant}, }; use super::{ Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState, TextMetricsState, TuiGroup, TuiTag, }; /// The current terminal backend. pub(crate) type TerminalBackend = CrosstermBackend; /// The current terminal frame. pub(crate) type TerminalFrame<'a> = ratatui::Frame<'a>; type PanicHook = Box) + 'static + Sync + Send>; const MAX_REFRESH_RATE_MILLIS: u64 = 100; enum TuiRendererEvent { MetricRegistration(MetricDefinition), MetricsUpdate((TuiSplit, TuiGroup, MetricState)), StatusUpdateTrain((TuiSplit, TrainingProgress, Vec)), StatusUpdateTest((EvaluationProgress, Vec)), ProcessEnd { summary: Option, /// Interrupter reset. reset: bool, }, ManualClose, Close, Persistent, } /// The terminal UI metrics renderer. pub struct TuiMetricsRendererWrapper { sender: mpsc::Sender, interrupter: Interrupter, handle_join: Option>, kill_signal: Arc>>, } impl TuiMetricsRendererWrapper { /// Create a new terminal UI renderer. pub fn new(interrupter: Interrupter, checkpoint: Option) -> Self { let (sender, receiver) = mpsc::channel(); let (kill_signal_sender, kill_signal_receiver) = mpsc::channel(); let interrupter_clone = interrupter.clone(); let handle_join = std::thread::Builder::new() .name("train-renderer".into()) .spawn(move || { let mut renderer = TuiMetricsRenderer::new(interrupter_clone, checkpoint, kill_signal_sender); let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); loop { match receiver.try_recv() { Ok(event) => renderer.handle_event(event), Err(mpsc::TryRecvError::Empty) => (), Err(mpsc::TryRecvError::Disconnected) => { log::error!("Renderer thread disconnected."); break; } } // Render if renderer.last_update.elapsed() >= tick_rate && let Err(err) = renderer.render() { log::error!("Render error: {err}"); break; } if (renderer.manual_close && renderer.interrupter.should_stop()) || renderer.close { break; } } }) .unwrap(); Self { sender, interrupter, handle_join: Some(handle_join), kill_signal: Arc::new(Mutex::new(kill_signal_receiver)), } } fn send_event(&self, event: TuiRendererEvent) { if self.kill_signal.lock().unwrap().try_recv().is_ok() { panic!("Killing training from user input.") } if let Err(e) = self.sender.send(event) { log::warn!("Failed to send TUI event: {e}"); } } /// Set the renderer to persistent mode. pub fn persistent(self) -> Self { self.send_event(TuiRendererEvent::Persistent); self } } struct TuiMetricsRenderer { terminal: Terminal, last_update: std::time::Instant, progress: ProgressBarState, metric_definitions: HashMap, metrics_numeric: NumericMetricsState, metrics_text: TextMetricsState, status: StatusState, interrupter: Interrupter, popup: PopupState, previous_panic_hook: Option>, persistent: bool, manual_close: bool, close: bool, summary: Option, kill_signal: Sender<()>, } impl MetricsRendererEvaluation for TuiMetricsRendererWrapper { fn update_test(&mut self, name: EvaluationName, state: MetricState) { self.send_event(TuiRendererEvent::MetricsUpdate(( TuiSplit::Test, TuiGroup::Named(name.name), state, ))); } fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec) { self.send_event(TuiRendererEvent::StatusUpdateTest(( item, progress_indicators, ))); } fn on_test_end(&mut self, summary: Option) -> Result<(), Box> { // Update the summary self.send_event(TuiRendererEvent::ProcessEnd { summary, reset: false, }); Ok(()) } } impl MetricsRenderer for TuiMetricsRendererWrapper { fn manual_close(&mut self) { self.send_event(TuiRendererEvent::ManualClose); let _ = self.handle_join.take().unwrap().join(); } fn register_metric(&mut self, definition: MetricDefinition) { self.send_event(TuiRendererEvent::MetricRegistration(definition)); } } impl MetricsRendererTraining for TuiMetricsRendererWrapper { fn update_train(&mut self, state: MetricState) { self.send_event(TuiRendererEvent::MetricsUpdate(( TuiSplit::Train, TuiGroup::Default, state, ))); } fn update_valid(&mut self, state: MetricState) { self.send_event(TuiRendererEvent::MetricsUpdate(( TuiSplit::Valid, TuiGroup::Default, state, ))); } fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec) { self.send_event(TuiRendererEvent::StatusUpdateTrain(( TuiSplit::Train, item, progress_indicators, ))); } fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec) { self.send_event(TuiRendererEvent::StatusUpdateTrain(( TuiSplit::Valid, item, progress_indicators, ))); } fn on_train_end(&mut self, summary: Option) -> Result<(), Box> { // Reset for following steps. self.interrupter.reset(); // Update the summary self.send_event(TuiRendererEvent::ProcessEnd { summary, reset: true, }); Ok(()) } } impl Drop for TuiMetricsRendererWrapper { fn drop(&mut self) { if !std::thread::panicking() { self.send_event(TuiRendererEvent::Close); let _ = self.handle_join.take().unwrap().join(); } } } impl TuiMetricsRenderer { fn update_metric(&mut self, split: TuiSplit, group: TuiGroup, state: MetricState) { match state { MetricState::Generic(entry) => { let name = self .metric_definitions .get(&entry.metric_id) .unwrap() .name .clone() .into(); self.metrics_text.update(split, group, entry, name); } MetricState::Numeric(entry, value) => { let name: Arc = self .metric_definitions .get(&entry.metric_id) .unwrap() .name .clone() .into(); self.metrics_numeric .push(TuiTag::new(split, group.clone()), name.clone(), value); self.metrics_text.update(split, group, entry, name); } }; } pub fn new( interrupter: Interrupter, checkpoint: Option, kill_signal: Sender<()>, ) -> Self { let mut stdout = io::stdout(); execute!(stdout, EnterAlternateScreen).unwrap(); enable_raw_mode().unwrap(); let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap(); // Reset the terminal to raw mode on panic before running the panic handler // This prevents that the panic message is not visible for the user. let previous_panic_hook = Arc::new(take_hook()); set_hook(Box::new({ let previous_panic_hook = previous_panic_hook.clone(); move |panic_info| { let _ = disable_raw_mode(); let _ = execute!(io::stdout(), LeaveAlternateScreen); previous_panic_hook(panic_info); } })); Self { terminal, last_update: Instant::now(), progress: ProgressBarState::new(checkpoint), metric_definitions: HashMap::default(), metrics_numeric: NumericMetricsState::default(), metrics_text: TextMetricsState::default(), status: StatusState::default(), interrupter, popup: PopupState::Empty, previous_panic_hook: Some(previous_panic_hook), persistent: false, manual_close: false, close: false, summary: None, kill_signal, } } fn handle_event(&mut self, event: TuiRendererEvent) { match event { TuiRendererEvent::MetricRegistration(definition) => { self.metric_definitions .insert(definition.metric_id.clone(), definition); } TuiRendererEvent::MetricsUpdate((split, group, state)) => { self.update_metric(split, group, state); } TuiRendererEvent::StatusUpdateTrain((split, item, status)) => match split { TuiSplit::Train => { self.progress.update_train(&item); self.metrics_numeric.update_progress_train(&item); self.status.update_train(status); } TuiSplit::Valid => { self.progress.update_valid(&item); self.metrics_numeric.update_progress_valid(&item); self.status.update_valid(status); } _ => (), }, TuiRendererEvent::StatusUpdateTest((item, status)) => { self.progress.update_test(&item); self.metrics_numeric.update_progress_test(&item); self.status.update_test(status); } TuiRendererEvent::ProcessEnd { summary, reset } => { match (self.summary.take(), summary) { (None, Some(summary)) => { self.summary = Some(summary); } (Some(current), Some(other)) => self.summary = Some(current.merge(other)), (_, _) => { /* nothing to update */ } } if reset { self.interrupter.reset(); } } TuiRendererEvent::ManualClose => self.manual_close = true, TuiRendererEvent::Persistent => self.persistent = true, TuiRendererEvent::Close => self.close = true, } } fn render(&mut self) -> Result<(), Box> { self.draw()?; self.handle_user_input()?; self.last_update = Instant::now(); Ok(()) } fn draw(&mut self) -> Result<(), Box> { self.terminal.draw(|frame| { let size = frame.area(); match self.popup.view() { Some(view) => view.render(frame, size), None => { let view = MetricsView::new( self.metrics_numeric.view(), self.metrics_text.view(), self.progress.view(), ControlsView, self.status.view(), ); view.render(frame, size); } }; })?; Ok(()) } fn handle_user_input(&mut self) -> Result<(), Box> { while event::poll(Duration::from_secs(0))? { let event = event::read()?; self.popup.on_event(&event); if self.popup.is_empty() { self.metrics_numeric.on_event(&event); if let Event::Key(key) = event && let KeyCode::Char('q') = key.code { self.popup = PopupState::Full( "Quit".to_string(), vec![ Callback::new( "Stop the training.", "Stop the training immediately. This will break from the \ training loop, but any remaining code after the loop will be \ executed.", 's', QuitPopupAccept(self.interrupter.clone()), ), Callback::new( "Stop the training immediately.", "Kill the program. This will create a panic! which will make \ the current training fails. Any code following the training \ won't be executed.", 'k', KillPopupAccept(self.kill_signal.clone()), ), Callback::new( "Cancel", "Cancel the action, continue the training.", 'c', PopupCancel, ), ], ); } } } Ok(()) } fn handle_post_training(&mut self) -> Result<(), Box> { self.popup = PopupState::Full( "Training is done".to_string(), vec![Callback::new( "Training Done", "Press 'x' to close this popup. Press 'q' to exit the application after the \ popup is closed.", 'x', PopupCancel, )], ); self.draw().ok(); loop { if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) { match event::read() { Ok(event @ Event::Key(key)) => { if self.popup.is_empty() { self.metrics_numeric.on_event(&event); if let KeyCode::Char('q') = key.code { break; } } else { self.popup.on_event(&event); } self.draw().ok(); } Ok(Event::Resize(..)) => { self.draw().ok(); } Err(err) => { eprintln!("Error reading event: {err}"); break; } _ => continue, } } } Ok(()) } // Reset the terminal back to raw mode. fn reset(&mut self) -> Result<(), Box> { // If previous panic hook has already been re-instated, then the terminal was already reset. if self.previous_panic_hook.is_some() { if self.persistent && let Err(err) = self.handle_post_training() { eprintln!("Error in post-training handling: {err}"); } disable_raw_mode()?; execute!(self.terminal.backend_mut(), LeaveAlternateScreen)?; self.terminal.show_cursor()?; // Reinstall the previous panic hook let _ = take_hook(); if let Some(previous_panic_hook) = Arc::into_inner(self.previous_panic_hook.take().unwrap()) { set_hook(previous_panic_hook); } } Ok(()) } } struct QuitPopupAccept(Interrupter); struct KillPopupAccept(Sender<()>); struct PopupCancel; impl CallbackFn for KillPopupAccept { fn call(&self) -> bool { self.0.send(()).unwrap(); panic!("Killing training from user input."); } } impl CallbackFn for QuitPopupAccept { fn call(&self) -> bool { self.0.stop(Some("Stopping training from user input.")); true } } impl CallbackFn for PopupCancel { fn call(&self) -> bool { true } } impl Drop for TuiMetricsRenderer { fn drop(&mut self) { // Reset the terminal back to raw mode. This can be skipped during // panicking because the panic hook has already reset the terminal if !std::thread::panicking() { self.reset().unwrap(); if let Some(summary) = &self.summary { println!("{summary}"); log::info!("{summary}"); } } } } ================================================ FILE: crates/burn-train/src/renderer/tui/status.rs ================================================ use crate::renderer::ProgressType; use super::TerminalFrame; use ratatui::{ prelude::{Alignment, Rect}, style::{Color, Style, Stylize}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, }; /// Show the training status with various information. pub(crate) struct StatusState { progress_indicators: Vec, mode: Mode, } enum Mode { Valid, Train, Evaluation, } impl Default for StatusState { fn default() -> Self { Self { progress_indicators: vec![], mode: Mode::Train, } } } impl StatusState { /// Update the training information. pub(crate) fn update_train(&mut self, progress_indicators: Vec) { self.progress_indicators = progress_indicators; self.mode = Mode::Train; } /// Update the validation information. pub(crate) fn update_valid(&mut self, progress_indicators: Vec) { self.progress_indicators = progress_indicators; self.mode = Mode::Valid; } /// Update the testing information. pub(crate) fn update_test(&mut self, progress_indicators: Vec) { self.progress_indicators = progress_indicators; self.mode = Mode::Evaluation; } /// Create a view. pub(crate) fn view(&self) -> StatusView { StatusView::new(&self.progress_indicators, &self.mode) } } pub(crate) struct StatusView { lines: Vec>>, } impl StatusView { fn new(progress_indicators: &[ProgressType], mode: &Mode) -> Self { let title = |title: &str| Span::from(format!(" {title} ")).bold().yellow(); let value = |value: String| Span::from(value).italic(); let mode = match mode { Mode::Valid => "Validating", Mode::Train => "Training", Mode::Evaluation => "Evaluation", }; let width = progress_indicators .iter() .map(|p| match p { ProgressType::Detailed { tag, .. } => tag.len(), ProgressType::Value { tag, .. } => tag.len(), }) .max() .unwrap_or(4); let mut lines = vec![vec![ title(&format!("{: lines.push(vec![ title(&format!("{: lines.push(vec![ title(&format!("{: , size: Rect) { let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::>()) .alignment(Alignment::Left) .block(Block::default().borders(Borders::ALL).title("Status")) .wrap(Wrap { trim: false }) .style(Style::default().fg(Color::Gray)); frame.render_widget(paragraph, size); } } ================================================ FILE: crates/burn-vision/Cargo.toml ================================================ [package] authors = [ "nathanielsimard ", "wingertge ", ] categories = ["science"] description = "Vision processing operations for burn tensors" documentation = "https://docs.rs/burn-vision" edition.workspace = true keywords = ["deep-learning", "machine-learning", "gpu"] license.workspace = true name = "burn-vision" readme.workspace = true repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision" version.workspace = true [lints] workspace = true [features] default = ["ndarray", "cubecl-backend", "fusion", "std"] std = ["aligned-vec/std"] tracing = [ "burn-cubecl?/tracing", "burn-fusion?/tracing", "burn-ir/tracing", "burn-ndarray?/tracing", "burn-tch?/tracing", "burn-tensor/tracing", "cubecl/tracing", ] cubecl-backend = ["cubecl", "burn-cubecl"] fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] # Test features test-cpu = [] test-cuda = ["cubecl-backend", ] test-wgpu = ["cubecl-backend", ] test-vulkan = ["burn-wgpu/vulkan", "test-wgpu"] test-metal = ["burn-wgpu/metal", "test-wgpu"] [dependencies] aligned-vec = { version = "0.6", default-features = false } bon = { workspace = true } burn-cubecl = { path = "../burn-cubecl", version = "=0.21.0-pre.2", optional = true } burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true } burn-ir = { path = "../burn-ir", version = "=0.21.0-pre.2" } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2", optional = true } burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true } burn-tensor = { path = "../burn-tensor", version = "=0.21.0-pre.2" } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "=0.21.0-pre.2", optional = true } bytemuck = { workspace = true } cubecl = { workspace = true, optional = true } derive-new = { workspace = true } half = { workspace = true } image = { version = "0.25" } macerator = { workspace = true } ndarray = { workspace = true } num-traits = { workspace = true } paste = { workspace = true } serde = { workspace = true } [dev-dependencies] burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" } burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", default-features = false } cubecl = { workspace = true } ================================================ FILE: crates/burn-vision/src/backends/cpu/base.rs ================================================ pub trait MinMax { fn min(self, other: Self) -> Self; fn max(self, other: Self) -> Self; } macro_rules! impl_minmax { ($ty: ty) => { impl MinMax for $ty { fn min(self, other: Self) -> Self { Ord::min(self, other) } fn max(self, other: Self) -> Self { Ord::max(self, other) } } }; ($($ty: ty),*) => { $(impl_minmax!($ty);)* } } impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64); impl MinMax for f32 { fn min(self, other: Self) -> Self { self.min(other) } fn max(self, other: Self) -> Self { self.max(other) } } impl MinMax for f64 { fn min(self, other: Self) -> Self { self.min(other) } fn max(self, other: Self) -> Self { self.max(other) } } ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs ================================================ no_analyze!{{ use centerLabels::*;let mut label = entry; while let Some(next) = (|label| -> Option { match label { NODE_1=> { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_2); } else { return Some(NODE_3); } } NODE_3=> { if (*img_row01.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_2); } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(cl_tree_1); } } NODE_4=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_5); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_3); } } } NODE_6=> { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_2); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_7); } } else { return Some(NODE_1); } } NODE_2=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { return Some(NODE_4); } } NODE_7=> { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_5); } } NODE_5=> { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } } NODE_8=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } else { return Some(NODE_9); } } NODE_10=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_11); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } } else { if (*img_row11.add((c - 1) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_8); } else { return Some(NODE_11); } } } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_12); } else { return Some(NODE_12); } } } } NODE_11=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_4); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_3); } } NODE_13=> { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(cl_tree_11); } } NODE_9=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } } else { return Some(NODE_11); } } NODE_12=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_10); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_9); } } NODE_14=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_4); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_10); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_9); } } } } NODE_15=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_13); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_5); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_7); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(cl_tree_3); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_3); } } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_10); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(cl_tree_9); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_9); } } } } } NODE_16=> { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_8); } else { return Some(NODE_2); } } else { return Some(NODE_17); } } else { return Some(NODE_1); } } NODE_18=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } NODE_19=> { if (*img_row11.add((c + 2) as usize)).to_bool() { return Some(NODE_20); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_8); } } NODE_21=> { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { if (*img_row11.add((c + 2) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_3); } } } else { return Some(NODE_3); } } NODE_22=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } } NODE_23=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } } NODE_24=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } NODE_17=> { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_7); } } NODE_25=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_18); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_8); } } NODE_20=> { if (*img_row12.add((c + 1) as usize)).to_bool() { return Some(NODE_26); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } NODE_27=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } NODE_28=> { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_22); } else { return Some(NODE_19); } } NODE_26=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } } NODE_29=> { if (*img_row11.add((c + 2) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_8); } } NODE_30=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_8); } } NODE_31=> { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_23); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_19); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_12); } } } NODE_32=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_33); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { if (*img_row11.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_34); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } } else { if (*img_row11.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_35); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_4); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_3); } } } NODE_36=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { return Some(NODE_33); } else { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_34); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } } else { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_35); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_3); } } } NODE_37=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_18); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_8); } } NODE_33=> { if (*img_row12.add((c - 1) as usize)).to_bool() { return Some(NODE_26); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(cl_tree_5); } } NODE_38=> { if (*img_row12.add((c - 1) as usize)).to_bool() { return Some(NODE_22); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } NODE_39=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_10); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_10); } } NODE_35=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_4); } } NODE_40=> { if (*img_row12.add((c - 1) as usize)).to_bool() { return Some(NODE_23); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } NODE_34=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_5); } } cl_tree_0 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_0); } else { return Some(cl_break_1_0); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_14); } else { return Some(NODE_6); } } cl_tree_1 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_1); } else { return Some(cl_break_1_1); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_15); } else { return Some(NODE_6); } } cl_tree_2 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_2); } else { return Some(cl_break_1_2); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_10); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_8); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { return Some(NODE_1); } } } cl_tree_3 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_3); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_29); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_12); } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } else { return Some(NODE_29); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { return Some(NODE_21); } } } cl_tree_4 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_4); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_27); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_25); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_12); } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_24); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } else { return Some(NODE_25); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { return Some(NODE_21); } } } cl_tree_5 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_5); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_30); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_12); } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { return Some(NODE_30); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_6); } else { if (*img_row11.add((c + 2) as usize)).to_bool() { return Some(NODE_5); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } } } else { return Some(NODE_3); } } } } cl_tree_6 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_6); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_31); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_28); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { return Some(NODE_1); } } } cl_tree_7 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_4); } else { return Some(cl_break_1_7); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_10); } else { return Some(NODE_15); } } else { return Some(NODE_16); } } cl_tree_8 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_8); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_27); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_37); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_12); } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_24); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } else { return Some(NODE_37); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } } else { return Some(NODE_21); } } } cl_tree_9 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_5); } else { return Some(cl_break_1_9); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_9); } else { return Some(NODE_12); } } } else { return Some(NODE_14); } } else { return Some(NODE_16); } } cl_tree_10 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_6); } else { return Some(cl_break_1_10); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_40); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_36); } else { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_39); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_9); } } } } else { return Some(NODE_14); } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_38); } else { return Some(NODE_36); } } else { return Some(NODE_2); } } else { return Some(NODE_17); } } else { return Some(NODE_1); } } } cl_tree_11 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_7); } else { return Some(cl_break_1_11); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_31); } else { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_31); } else { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_11); } else { return Some(NODE_13); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_5); } else { return Some(NODE_7); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(cl_tree_3); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_10); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(cl_tree_9); } } } } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_28); } else { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_22); } else { if (*img_row11.add((c + 2) as usize)).to_bool() { return Some(NODE_20); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(cl_tree_4); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_3); } } } } else { return Some(NODE_2); } } } else { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } else { if (*img_row00.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_7); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(cl_tree_7); } } } } else { return Some(NODE_1); } } } cl_tree_12 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_8); } else { return Some(cl_break_1_12); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_40); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_11); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_32); } else { if (*img_row11.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_39); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_10); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(cl_tree_9); } } } } else { return Some(NODE_14); } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_38); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(cl_tree_6); } } else { return Some(NODE_32); } } else { return Some(NODE_2); } } else { return Some(NODE_17); } } else { return Some(NODE_1); } } } NODE_41=> { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); } else { return Some(NODE_42); } } NODE_43=> { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } NODE_42=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_44=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } NODE_45=> { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } else { *img_labels_row00.add(c as usize) = 0.elem(); } } NODE_46=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } NODE_47=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } NODE_48=> { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } cl_break_0_0 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_47); } else { return Some(NODE_48); } return None;} cl_break_0_1 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_44); } else { return Some(NODE_48); } return None;} cl_break_0_2 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_41); } else { return Some(NODE_43); } return None;} cl_break_0_3 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { return Some(NODE_43); } return None;} cl_break_0_4 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_41); } else { return Some(NODE_44); } } else { return Some(NODE_45); } return None;} cl_break_0_5 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_42); } else { return Some(NODE_47); } } else { return Some(NODE_45); } return None;} cl_break_0_6 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_46); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_47); } } else { return Some(NODE_45); } return None;} cl_break_0_7 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row00.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } else { *img_labels_row00.add(c as usize) = 0.elem(); } } return None;} cl_break_0_8 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_46); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_47); } } else { return Some(NODE_45); } return None;} NODE_49=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } NODE_50=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_49); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_51=> { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_52); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_52=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } } NODE_53=> { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_54); } else { return Some(NODE_55); } } else { return Some(NODE_56); } } NODE_55=> { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } NODE_54=> { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_57); } else { return Some(NODE_58); } } NODE_59=> { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } else { return Some(NODE_60); } } NODE_61=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_62=> { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { return Some(NODE_63); } else { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_49); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } } else { return Some(NODE_58); } } NODE_63=> { if (*img_row12.add((c - 1) as usize)).to_bool() { return Some(NODE_52); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } NODE_64=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_65); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_65=> { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_49); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } NODE_66=> { if (*img_row01.add((c - 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_63); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_65); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } } else { return Some(NODE_58); } } NODE_67=> { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_58); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } else { return Some(NODE_56); } } NODE_56=> { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_58); } else { return Some(NODE_60); } } NODE_58=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } NODE_68=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); } else { return Some(NODE_69); } } } NODE_70=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_71); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } } NODE_57=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { return Some(NODE_69); } } NODE_60=> { if (*img_row01.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } NODE_71=> { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); } } NODE_69=> { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } cl_break_1_0 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_58); } else { return Some(NODE_67); } return None;} cl_break_1_1 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_70); } else { return Some(NODE_67); } return None;} cl_break_1_2 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_68); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_57); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_56); } } return None;} cl_break_1_3 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_61); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_61); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_59); } } return None;} cl_break_1_4 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_50); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_50); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_59); } } return None;} cl_break_1_5 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { return Some(NODE_60); } } } return None;} cl_break_1_6 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_51); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_51); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_56); } } return None;} cl_break_1_7 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_68); } else { return Some(NODE_70); } } else { return Some(NODE_53); } return None;} cl_break_1_8 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_64); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_64); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_59); } } return None;} cl_break_1_9 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_54); } else { return Some(NODE_53); } return None;} cl_break_1_10 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_62); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_62); } else { return Some(NODE_55); } } else { return Some(NODE_56); } } return None;} cl_break_1_11 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_51); } else { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_51); } else { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { return Some(NODE_71); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } } } } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_51); } else { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_51); } else { return Some(NODE_58); } } } else { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row00.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } } else { return Some(NODE_56); } } return None;} cl_break_1_12 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_66); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_66); } else { return Some(NODE_55); } } else { return Some(NODE_56); } } return None;} }; None})(label) { label = next; } }} ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs ================================================ no_analyze!{{ use firstLabels::*;let mut label = entry; while let Some(next) = (|label| -> Option { match label { NODE_72=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_1); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_2); } } NODE_73=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(fl_tree_1); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(fl_tree_2); } } NODE_74=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_1); } else { if (*img_row01.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_1); } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(fl_tree_0); } } } fl_tree_0 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_72); } else { if (*img_row01.add((c) as usize)).to_bool() { return Some(NODE_72); } else { return Some(NODE_74); } } } fl_tree_1 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_73); } else { if (*img_row01.add((c) as usize)).to_bool() { return Some(NODE_73); } else { return Some(NODE_74); } } } fl_tree_2 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { return Some(NODE_73); } else { return Some(NODE_72); } } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(fl_tree_1); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_1); } } else { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(fl_tree_2); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(fl_tree_2); } } } else { return Some(NODE_74); } } } NODE_75=> { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } fl_break_0_0 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } return None;} fl_break_0_1 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } return None;} fl_break_0_2 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_75); } else { if (*img_row01.add((c) as usize)).to_bool() { return Some(NODE_75); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } return None;} NODE_76=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { if (*img_row01.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } } NODE_77=> { if (*img_row01.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } fl_break_1_0 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { return Some(NODE_76); } } return None;} fl_break_1_1 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row01.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { return Some(NODE_76); } } return None;} fl_break_1_2 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_77); } else { if (*img_row01.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_77); } else { return Some(NODE_77); } } else { return Some(NODE_76); } } return None;} fl_ => {}, }; None})(label) { label = next; } }} ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs ================================================ /// Workaround for rust-analyzer bug that causes invalid errors on the `include!`. macro_rules! no_analyze { ($tokens:tt) => { $tokens }; } pub(crate) use no_analyze; #[allow(non_snake_case, non_camel_case_types, unused)] pub enum centerLabels { NODE_1, NODE_2, NODE_3, NODE_4, NODE_5, NODE_6, NODE_7, NODE_8, NODE_9, NODE_10, NODE_11, NODE_12, NODE_13, NODE_14, NODE_15, NODE_16, NODE_17, NODE_18, NODE_19, NODE_20, NODE_21, NODE_22, NODE_23, NODE_24, NODE_25, NODE_26, NODE_27, NODE_28, NODE_29, NODE_30, NODE_31, NODE_32, NODE_33, NODE_34, NODE_35, NODE_36, NODE_37, NODE_38, NODE_39, NODE_40, NODE_41, NODE_42, NODE_43, NODE_44, NODE_45, NODE_46, NODE_47, NODE_48, NODE_49, NODE_50, NODE_51, NODE_52, NODE_53, NODE_54, NODE_55, NODE_56, NODE_57, NODE_58, NODE_59, NODE_60, NODE_61, NODE_62, NODE_63, NODE_64, NODE_65, NODE_66, NODE_67, NODE_68, NODE_69, NODE_70, NODE_71, cl_tree_0, cl_tree_1, cl_tree_2, cl_tree_3, cl_tree_4, cl_tree_5, cl_tree_6, cl_tree_7, cl_tree_8, cl_tree_9, cl_tree_10, cl_tree_11, cl_tree_12, cl_break_0_0, cl_break_0_1, cl_break_0_2, cl_break_0_3, cl_break_0_4, cl_break_0_5, cl_break_0_6, cl_break_0_7, cl_break_0_8, cl_break_1_0, cl_break_1_1, cl_break_1_2, cl_break_1_3, cl_break_1_4, cl_break_1_5, cl_break_1_6, cl_break_1_7, cl_break_1_8, cl_break_1_9, cl_break_1_10, cl_break_1_11, cl_break_1_12, } #[allow(non_snake_case, non_camel_case_types, unused)] pub enum firstLabels { NODE_72, NODE_73, NODE_74, NODE_75, NODE_76, NODE_77, fl_tree_0, fl_tree_1, fl_tree_2, fl_break_0_0, fl_break_0_1, fl_break_0_2, fl_break_1_0, fl_break_1_1, fl_break_1_2, fl_, } #[allow(non_snake_case, non_camel_case_types, unused)] pub enum lastLabels { NODE_78, NODE_79, NODE_80, NODE_81, NODE_82, NODE_83, NODE_84, NODE_85, NODE_86, NODE_87, NODE_88, NODE_89, NODE_90, NODE_91, NODE_92, ll_tree_0, ll_tree_1, ll_tree_2, ll_tree_3, ll_tree_4, ll_tree_5, ll_tree_6, ll_tree_7, ll_break_0_0, ll_break_0_1, ll_break_0_2, ll_break_0_3, ll_break_1_0, ll_break_1_1, ll_break_1_2, ll_break_1_3, ll_break_1_4, ll_break_1_5, ll_break_1_6, ll_break_1_7, ll_, } #[allow(non_snake_case, non_camel_case_types, unused)] pub enum singleLabels { NODE_93, NODE_94, sl_tree_0, sl_tree_1, sl_break_0_0, sl_break_0_1, sl_break_1_0, sl_break_1_1, sl_, } ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs ================================================ no_analyze!{{ use lastLabels::*;let mut label = entry; while let Some(next) = (|label| -> Option { match label { NODE_78=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(ll_tree_4); } } NODE_79=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(ll_tree_6); } } NODE_80=> { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(ll_tree_6); } } NODE_81=> { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_82); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_3); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(ll_tree_2); } } } NODE_83=> { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_5); } else { if (*img_row11.add((c + 2) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(ll_tree_2); } } } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(ll_tree_1); } } NODE_84=> { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_5); } else { return Some(NODE_81); } } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(ll_tree_1); } } NODE_82=> { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_4); } } NODE_85=> { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver); return Some(ll_tree_4); } } NODE_86=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_6); } } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_7); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_0); } } } ll_tree_0 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_81); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_0); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(ll_tree_0); } } } } else { return Some(NODE_84); } } ll_tree_1 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_80); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_82); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_85); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_3); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(ll_tree_2); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(ll_tree_2); } } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_0); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(ll_tree_0); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(ll_tree_0); } } } } } else { return Some(NODE_84); } } ll_tree_2 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_6); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_7); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_0); } } } else { return Some(NODE_83); } } ll_tree_3 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_79); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_6); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_78); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_7); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_0); } } } else { return Some(NODE_83); } } ll_tree_4 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize); return Some(ll_tree_4); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_7); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_0); } } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_5); } else { if (*img_row11.add((c + 2) as usize)).to_bool() { return Some(NODE_82); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_3); } } } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(ll_tree_1); } } } ll_tree_5 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } } if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_86); } else { return Some(NODE_84); } } ll_tree_6 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_86); } else { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_6); } else { return Some(NODE_80); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { return Some(NODE_82); } else { return Some(NODE_85); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_3); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(ll_tree_2); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); return Some(ll_tree_0); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); return Some(ll_tree_0); } } } } } else { return Some(NODE_84); } } ll_tree_7 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_79); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_6); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); return Some(ll_tree_6); } } else { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 2) as usize)).to_bool() { if (*img_row12.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_78); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver); return Some(ll_tree_4); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_7); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(ll_tree_0); } } } else { return Some(NODE_83); } } ll_break_0_0 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} ll_break_0_1 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} ll_break_0_2 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} ll_break_0_3 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } } } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} NODE_87=> { if (*img_row00.add((c + 1) as usize)).to_bool() { return Some(NODE_88); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } NODE_88=> { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } NODE_89=> { if (*img_row12.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); } } NODE_90=> { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } NODE_91=> { if (*img_row00.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } else { *img_labels_row00.add(c as usize) = 0.elem(); } } NODE_92=> { if (*img_row12.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver); } } ll_break_1_0 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_88); } else { return Some(NODE_87); } return None;} ll_break_1_1 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { return Some(NODE_92); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { if (*img_row11.add((c - 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = solver.new_label(); } } } } else { return Some(NODE_87); } return None;} ll_break_1_2 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_91); } return None;} ll_break_1_3 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { return Some(NODE_89); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_91); } return None;} ll_break_1_4 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } return None;} ll_break_1_5 => { if (*img_row00.add((c) as usize)).to_bool() { return Some(NODE_90); } else { return Some(NODE_87); } return None;} ll_break_1_6 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c - 1) as usize)).to_bool() { return Some(NODE_90); } else { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { return Some(NODE_92); } } else { if (*img_row11.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize); } else { *img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize); } } } } else { return Some(NODE_87); } return None;} ll_break_1_7 => { if (*img_row00.add((c) as usize)).to_bool() { if (*img_row11.add((c + 1) as usize)).to_bool() { if (*img_row12.add((c) as usize)).to_bool() { if (*img_row11.add((c - 2) as usize)).to_bool() { return Some(NODE_89); } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver); } } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } } else { return Some(NODE_91); } return None;} ll_ => {}, }; None})(label) { label = next; } }} ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs ================================================ no_analyze!{{ use singleLabels::*;let mut label = entry; while let Some(next) = (|label| -> Option { match label { NODE_93=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(sl_tree_1); } else { *img_labels_row00.add(c as usize) = 0.elem(); return Some(sl_tree_0); } } sl_tree_0 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(sl_tree_1); } else { *img_labels_row00.add(c as usize) = solver.new_label(); return Some(sl_tree_0); } } else { return Some(NODE_93); } } sl_tree_1 => { if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } } if (*img_row00.add((c) as usize)).to_bool() { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(sl_tree_1); } else { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); return Some(sl_tree_0); } } else { return Some(NODE_93); } } sl_break_0_0 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} sl_break_0_1 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { *img_labels_row00.add(c as usize) = 0.elem(); } return None;} NODE_94=> { if (*img_row00.add((c + 1) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { *img_labels_row00.add(c as usize) = 0.elem(); } } sl_break_1_0 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = solver.new_label(); } else { return Some(NODE_94); } return None;} sl_break_1_1 => { if (*img_row00.add((c) as usize)).to_bool() { *img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize); } else { return Some(NODE_94); } return None;} sl_ => {}, }; None})(label) { label = next; } }} ================================================ FILE: crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs ================================================ //! Spaghetti algorithm for connected component labeling //! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana, //! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling," //! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019. //! //! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN) //! as described in //! //! F. Bolelli, S. Allegretti, C. Grana. //! "One DAG to Rule Them All." //! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021 #![allow( unreachable_code, clippy::collapsible_else_if, clippy::if_same_then_else )] use std::cmp::Ordering; use burn_tensor::{Element, ElementComparison, ElementConversion, cast::ToElement}; use ndarray::{Array2, Axis, s}; #[allow(non_snake_case)] mod Spaghetti_forest_labels; pub(crate) use Spaghetti_forest_labels::*; use crate::Connectivity; use super::{Solver, StatsOp, max_labels}; pub fn process( img_arr: Array2, stats: &mut impl StatsOp

Burn MNIST Inference Demo

Draw a digit here Cropped and scaled Probability result
================================================ FILE: examples/mnist-inference-web/index.js ================================================ /** * * This demo is part of Burn project: https://github.com/tracel-ai/burn * * Released under a dual license: * https://github.com/tracel-ai/burn/blob/main/LICENSE-MIT * https://github.com/tracel-ai/burn/blob/main/LICENSE-APACHE * */ /** * Auto crops the image, scales to 28x28 pixel image, and returns as grayscale image. * @param {object} mainContext - The 2d context of the source canvas. * @param {object} cropContext - The 2d context of an intermediate hidden canvas. * @param {object} scaledContext - The 2d context of the destination 28x28 canvas. */ export function cropScaleGetImageData(mainContext, cropContext, scaledContext) { const cropEl = cropContext.canvas; // Get the auto-cropped image data and put into the intermediate/hidden canvas cropContext.fillStyle = "rgba(255, 255, 255, 255)"; // white non-transparent color cropContext.fillRect(0, 0, cropEl.width, cropEl.height); cropContext.save(); const [w, h, croppedImage] = cropImageFromCanvas(mainContext); cropEl.width = Math.max(w, h) * 1.2; cropEl.height = Math.max(w, h) * 1.2; const leftPadding = (cropEl.width - w) / 2; const topPadding = (cropEl.height - h) / 2; cropContext.putImageData(croppedImage, leftPadding, topPadding); // Copy image data to scale 28x28 canvas scaledContext.save(); scaledContext.clearRect(0, 0, scaledContext.canvas.height, scaledContext.canvas.width); scaledContext.fillStyle = "rgba(255, 255, 255, 255)"; // white non-transparent color scaledContext.fillRect(0, 0, cropEl.width, cropEl.height); scaledContext.scale(28.0 / cropContext.canvas.width, 28.0 / cropContext.canvas.height); scaledContext.drawImage(cropEl, 0, 0); // Extract image data and convert into single value (greyscale) array const data = rgba2gray(scaledContext.getImageData(0, 0, 28, 28).data); scaledContext.restore(); return data; } /** * Converts RGBA image data from canvas to grayscale (0 is white & 255 is black). * @param {int[]} - Image data. */ export function rgba2gray(data) { let converted = new Float32Array(data.length / 4); // Data is stored as [r0,g0,b0,a0, ... r[n],g[n],b[n],a[n]] where n is number of pixels. for (let i = 0; i < data.length; i += 4) { let r = 255 - data[i]; // red let g = 255 - data[i + 1]; // green let b = 255 - data[i + 2]; // blue let a = 255 - data[i + 3]; // alpha // Use RGB grayscale coefficients (https://imagej.nih.gov/ij/docs/menus/image.html) let y = 0.299 * r + 0.587 * g + 0.114 * b; converted[i / 4] = y; // 4 times fewer data points but the same number of pixels. } return converted; } /** * Auto crops a canvas images and returns its image data. * @param {object} ctx - canvas 2d context. * src: https://stackoverflow.com/a/22267731 */ export function cropImageFromCanvas(ctx) { let canvas = ctx.canvas, w = canvas.width, h = canvas.height, pix = { x: [], y: [] }, imageData = ctx.getImageData(0, 0, canvas.width, canvas.height), x, y, index; for (y = 0; y < h; y++) { for (x = 0; x < w; x++) { index = (y * w + x) * 4; let r = imageData.data[index]; let g = imageData.data[index + 1]; let b = imageData.data[index + 2]; // On some browsers the canvas has a grey border which prevents cropping if we do min != 255 if (Math.min(r, g, b) < 240) { pix.x.push(x); pix.y.push(y); } } } pix.x.sort(function (a, b) { return a - b; }); pix.y.sort(function (a, b) { return a - b; }); let n = pix.x.length - 1; w = 1 + pix.x[n] - pix.x[0]; h = 1 + pix.y[n] - pix.y[0]; return [w, h, ctx.getImageData(pix.x[0], pix.y[0], w, h, { willReadFrequently: true })]; } /** * Truncates number to a given decimal position * @param {number} num - Number to truncate. * @param {number} fixed - Decimal positions. * src: https://stackoverflow.com/a/11818658 */ export function toFixed(num, fixed) { const re = new RegExp('^-?\\d+(?:\.\\d{0,' + (fixed || -1) + '})?'); return num.toString().match(re)[0]; } /** * Looks up element by an id. * @param {string} - Element id. */ export function $(id) { return document.getElementById(id); } /** * Helper function that builds a chart using Chart.js library. * @param {object} chartEl - Chart canvas element. * * NOTE: Assumes chart.js is loaded into the global. */ export function chartConfigBuilder(chartEl) { Chart.register(ChartDataLabels); return new Chart(chartEl, { plugins: [ChartDataLabels], type: "bar", data: { labels: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], datasets: [ { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], borderWidth: 0, fill: true, backgroundColor: "#247ABF", }, ], }, options: { responsive: false, maintainAspectRatio: false, animation: true, plugins: { legend: { display: false, }, tooltip: { enabled: true, }, datalabels: { color: "white", formatter: function (value, context) { return toFixed(value, 2); }, }, }, scales: { y: { beginAtZero: true, max: 1.0, }, }, }, }); } ================================================ FILE: examples/mnist-inference-web/run-server.sh ================================================ #!/usr/bin/env bash # Opening index.html file directly by a browser does not work because of # the security restrictions by the browser. Viewing the HTML file will fail with # this error message: # ``` # Access to script at # 'file:///Users/user/Projects/burn-mac/examples/mnist-inference-web/pkg/mnist_inference_web.js' # from origin 'null' has been blocked by CORS policy: # Cross origin requests are only supported for protocol schemes: # http, data, isolated-app, chrome-extension, chrome, https, chrome-untrusted. # ``` # So that's why running a local HTTP server is needed. if ! command -v python3 &> /dev/null then echo "python3 could not be found. Running server requires python3." exit fi echo "Running local python HTTP server on port 8000 ..." python3 -m http.server 8000 ================================================ FILE: examples/mnist-inference-web/src/lib.rs ================================================ #![cfg_attr(not(test), no_std)] pub mod model; pub mod state; pub mod web; extern crate alloc; ================================================ FILE: examples/mnist-inference-web/src/model.rs ================================================ use burn::{ nn::{ BatchNorm, PaddingConfig2d, pool::{MaxPool2d, MaxPool2dConfig}, }, prelude::*, }; #[derive(Module, Debug)] pub struct Model { conv1: ConvBlock, conv2: ConvBlock, dropout: nn::Dropout, fc1: nn::Linear, fc2: nn::Linear, fc3: nn::Linear, activation: nn::Gelu, } const NUM_CLASSES: usize = 10; impl Model { pub fn new(device: &B::Device) -> Self { let conv1 = ConvBlock::new([1, 64], [3, 3], device, true); // out: max_pool -> [Batch,32,13,13] let conv2 = ConvBlock::new([64, 64], [3, 3], device, true); // out: max_pool -> [Batch,64,5,5] let hidden_size = 64 * 5 * 5; let fc1 = nn::LinearConfig::new(hidden_size, 128).init(device); let fc2 = nn::LinearConfig::new(128, 128).init(device); let fc3 = nn::LinearConfig::new(128, NUM_CLASSES).init(device); let dropout = nn::DropoutConfig::new(0.25).init(); Self { conv1, conv2, dropout, fc1, fc2, fc3, activation: nn::Gelu::new(), } } pub fn forward(&self, input: Tensor) -> Tensor { let [batch_size, height, width] = input.dims(); let x = input.reshape([batch_size, 1, height, width]).detach(); let x = self.conv1.forward(x); let x = self.conv2.forward(x); let [batch_size, channels, height, width] = x.dims(); let x = x.reshape([batch_size, channels * height * width]); let x = self.fc1.forward(x); let x = self.activation.forward(x); let x = self.dropout.forward(x); let x = self.fc2.forward(x); let x = self.activation.forward(x); let x = self.dropout.forward(x); self.fc3.forward(x) } } #[derive(Module, Debug)] pub struct ConvBlock { conv: nn::conv::Conv2d, norm: BatchNorm, pool: Option, activation: nn::Relu, } impl ConvBlock { pub fn new( channels: [usize; 2], kernel_size: [usize; 2], device: &B::Device, pool: bool, ) -> Self { let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) .with_padding(PaddingConfig2d::Valid) .init(device); let norm = nn::BatchNormConfig::new(channels[1]).init(device); let pool = if pool { Some(MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init()) } else { None }; Self { conv, norm, pool, activation: nn::Relu::new(), } } pub fn forward(&self, input: Tensor) -> Tensor { let x = self.conv.forward(input); let x = self.norm.forward(x); let x = self.activation.forward(x); if let Some(pool) = &self.pool { pool.forward(x) } else { x } } } ================================================ FILE: examples/mnist-inference-web/src/state.rs ================================================ use crate::model::Model; use burn::{ module::Module, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, }; #[cfg(feature = "wgpu")] use burn::backend::wgpu::{Wgpu, WgpuDevice, graphics::AutoGraphicsApi, init_setup_async}; #[cfg(feature = "wgpu")] pub type Backend = Wgpu; #[cfg(all(feature = "ndarray", not(feature = "wgpu")))] pub type Backend = burn::backend::ndarray::NdArray; static STATE_ENCODED: &[u8] = include_bytes!("../model.bin"); /// Builds and loads trained parameters into the model. pub async fn build_and_load_model() -> Model { #[cfg(feature = "wgpu")] init_setup_async::(&WgpuDevice::default(), Default::default()).await; let model: Model = Model::new(&Default::default()); let record = BinBytesRecorder::::default() .load(STATE_ENCODED, &Default::default()) .expect("Failed to decode state"); model.load_record(record) } ================================================ FILE: examples/mnist-inference-web/src/web.rs ================================================ #![allow(clippy::new_without_default)] use alloc::string::String; use js_sys::Array; #[cfg(target_family = "wasm")] use wasm_bindgen::prelude::*; use crate::model::Model; use crate::state::{Backend, build_and_load_model}; use burn::tensor::Tensor; #[cfg_attr(target_family = "wasm", wasm_bindgen(start))] pub fn start() { console_error_panic_hook::set_once(); } /// Mnist structure that corresponds to JavaScript class. /// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html) #[cfg_attr(target_family = "wasm", wasm_bindgen)] pub struct Mnist { model: Option>, } #[cfg_attr(target_family = "wasm", wasm_bindgen)] impl Mnist { /// Constructor called by JavaScripts with the new keyword. #[cfg_attr(target_family = "wasm", wasm_bindgen(constructor))] pub fn new() -> Self { console_error_panic_hook::set_once(); Self { model: None } } /// Returns the inference results. /// /// This method is called from JavaScript via generated wrapper code by wasm-bindgen. /// /// # Arguments /// /// * `input` - A f32 slice of input 28x28 image /// /// See bindgen support types for passing and returning arrays: /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html) /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html) /// pub async fn inference(&mut self, input: &[f32]) -> Result { if self.model.is_none() { self.model = Some(build_and_load_model().await); } let model = self.model.as_ref().unwrap(); let device = Default::default(); // Reshape from the 1D array to 3d tensor [batch, height, width] let input = Tensor::::from_floats(input, &device).reshape([1, 28, 28]); // Normalize input: make between [0,1] and make the mean=0 and std=1 // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 let input = ((input / 255) - 0.1307) / 0.3081; // Run the tensor input through the model let output: Tensor = model.forward(input); // Convert the model output into probability distribution using softmax formula let output = burn::tensor::activation::softmax(output, 1); // Flatten output tensor with [1, 10] shape into boxed slice of [f32] let output = output.into_data_async().await.unwrap(); let array = Array::new(); for value in output.iter::() { array.push(&value.into()); } Ok(array) } } ================================================ FILE: examples/modern-lstm/Cargo.toml ================================================ [package] edition.workspace = true name = "modern-lstm" version = "0.5.0" [lints] workspace = true [features] cuda = ["burn/cuda"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] [dependencies] burn = { path = "../../crates/burn", features = ["train"] } # Random number generator rand = { workspace = true, features = ["thread_rng"] } rand_distr = { workspace = true } # Serialization serde = { workspace = true, features = ["std", "derive"] } # Organise the results in dataframe planus = { workspace = true } polars = { workspace = true } ================================================ FILE: examples/modern-lstm/README.md ================================================ # Advanced LSTM Implementation with Burn A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined weight matrices for the input and hidden states, based on the [PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch). `LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM variants differ by `bidirectional` and `num_layers` settings: - LSTM: `num_layers = 1` and `bidirectional = false` - Stacked LSTM: `num_layers > 1` and `bidirectional = false` - Bidirectional LSTM: `num_layers = 1` and `bidirectional = true` - Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true` This implementation is complementary to Burn's official LSTM, users can choose either one depends on the project's specific needs. ## Usage ## Training ```sh # Cuda backend cargo run --example lstm-train --release --features cuda # Wgpu backend cargo run --example lstm-train --release --features wgpu # Tch GPU backend export TORCH_CUDA_VERSION=cu128 # Set the cuda version cargo run --example lstm-train --release --features tch-gpu # Tch CPU backend cargo run --example lstm-train --release --features tch-cpu # NdArray backend (CPU) cargo run --example lstm-train --release --features ndarray cargo run --example lstm-train --release --features ndarray-blas-openblas cargo run --example lstm-train --release --features ndarray-blas-netlib ``` ### Inference ```sh cargo run --example lstm-infer --release --features cuda ``` ================================================ FILE: examples/modern-lstm/examples/lstm-infer.rs ================================================ use burn::tensor::backend::Backend; pub fn launch(device: B::Device) { modern_lstm::inference::infer::("/tmp/modern-lstm", device); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::launch; pub fn run() { launch::(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { launch::(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; use burn::backend::wgpu::Wgpu; pub fn run() { launch::(Default::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::launch; use burn::backend::Cuda; pub fn run() { launch::(Default::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); } ================================================ FILE: examples/modern-lstm/examples/lstm-train.rs ================================================ use burn::{ grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend, }; use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig}; pub fn launch(device: B::Device) { let config = TrainingConfig::new( LstmNetworkConfig::new(), // Gradient clipping via optimizer config AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))), ); modern_lstm::training::train::("/tmp/modern-lstm", config, device); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::{ Autodiff, ndarray::{NdArray, NdArrayDevice}, }; use crate::launch; pub fn run() { launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::launch; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::launch; pub fn run() { launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; use burn::backend::{Autodiff, wgpu::Wgpu}; pub fn run() { launch::>(Default::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::launch; use burn::backend::{Autodiff, Cuda, cuda::CudaDevice}; pub fn run() { launch::>(CudaDevice::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); } ================================================ FILE: examples/modern-lstm/src/dataset.rs ================================================ use burn::{ data::{ dataloader::batcher::Batcher, dataset::{Dataset, InMemDataset}, }, prelude::*, }; use rand::RngExt; use rand_distr::{Distribution, Normal}; use serde::{Deserialize, Serialize}; // Dataset parameters pub const NUM_SEQUENCES: usize = 1000; pub const SEQ_LENGTH: usize = 10; pub const NOISE_LEVEL: f32 = 0.1; pub const RANDOM_SEED: u64 = 5; // Generate a sequence where each number is the sum of previous two numbers plus noise #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SequenceDatasetItem { pub sequence: Vec, pub target: f32, } impl SequenceDatasetItem { pub fn new(seq_length: usize, noise_level: f32) -> Self { // Start with two random numbers between 0 and 1 let mut seq = vec![rand::rng().random(), rand::rng().random()]; // Generate sequence for _i in 0..seq_length { // Next number is sum of previous two plus noise let normal = Normal::new(0.0, noise_level).unwrap(); let next_val = seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::rng()); seq.push(next_val); } Self { // Convert to sequence and target sequence: seq[0..seq.len() - 1].to_vec(), // All but last target: seq[seq.len() - 1], // Last value } } } // Custom Dataset for Sequence Data pub struct SequenceDataset { dataset: InMemDataset, } impl SequenceDataset { pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self { let mut items = vec![]; for _i in 0..num_sequences { items.push(SequenceDatasetItem::new(seq_length, noise_level)); } let dataset = InMemDataset::new(items); Self { dataset } } } impl Dataset for SequenceDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } #[derive(Clone, Debug, Default)] pub struct SequenceBatcher {} #[derive(Clone, Debug)] pub struct SequenceBatch { pub sequences: Tensor, // [batch_size, seq_length, input_size] pub targets: Tensor, // [batch_size, 1] } impl Batcher> for SequenceBatcher { fn batch(&self, items: Vec, device: &B::Device) -> SequenceBatch { let mut sequences: Vec> = Vec::new(); for item in items.iter() { let seq_tensor = Tensor::::from_floats(item.sequence.as_slice(), device); // Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations sequences.push(seq_tensor.unsqueeze_dims(&[-1])); } let sequences = Tensor::stack(sequences, 0); let targets = items .iter() .map(|item| Tensor::::from_floats([item.target], device)) .collect(); let targets = Tensor::stack(targets, 0); SequenceBatch { sequences, targets } } } ================================================ FILE: examples/modern-lstm/src/inference.rs ================================================ use crate::{ dataset::{ NOISE_LEVEL, NUM_SEQUENCES, SEQ_LENGTH, SequenceBatcher, SequenceDataset, SequenceDatasetItem, }, model::LstmNetwork, training::TrainingConfig, }; use burn::{ data::{dataloader::batcher::Batcher, dataset::Dataset}, prelude::*, record::{CompactRecorder, Recorder}, }; use polars::prelude::*; pub fn infer(artifact_dir: &str, device: B::Device) { // Loading model let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) .expect("Config should exist for the model; run train first"); let record = CompactRecorder::new() .load(format!("{artifact_dir}/model").into(), &device) .expect("Trained model should exist; run train first"); let model: LstmNetwork = config.model.init(&device).load_record(record); let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL); let items: Vec = dataset.iter().collect(); let batcher = SequenceBatcher::default(); // Put all items in one batch let batch = batcher.batch(items, &device); let predicted = model.forward(batch.sequences, None); let targets = batch.targets; let predicted = predicted.squeeze_dim::<1>(1).into_data(); let expected = targets.squeeze_dim::<1>(1).into_data(); // Display the predicted vs expected values let results = df![ "predicted" => &predicted.to_vec::().unwrap(), "expected" => &expected.to_vec::().unwrap(), ] .unwrap(); println!("{}", &results.head(Some(10))); } ================================================ FILE: examples/modern-lstm/src/lib.rs ================================================ pub mod dataset; pub mod inference; pub mod model; pub mod training; ================================================ FILE: examples/modern-lstm/src/model.rs ================================================ use burn::{ nn::{ Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig, LstmState, Sigmoid, Tanh, }, prelude::*, }; /// LSTM Cell implementation with layer normalization. /// /// Mathematical formulation of LSTM: /// f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # Forget gate /// i_t = σ(W_i · [h_{t-1}, x_t] + b_i] # Input gate /// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g] # Candidate cell state /// o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # Output gate /// /// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t # New cell state /// h_t = o_t ⊙ tanh(c_t) # New hidden state /// /// where: /// - σ is the sigmoid function /// - ⊙ is the element-wise multiplication /// - [h_{t-1}, x_t] represents concatenation #[derive(Module, Debug)] pub struct LstmCell { pub hidden_size: usize, // Combined weight matrices for efficiency // weight_ih layer uses combined weights for [i_t, f_t, g_t, o_t] for input x_t // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1} pub weight_ih: Linear, pub weight_hh: Linear, // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM. pub norm_x: LayerNorm, // Normalize gate pre-activations pub norm_h: LayerNorm, // Normalize hidden state pub norm_c: LayerNorm, // Normalize cell state pub dropout: Dropout, } /// Configuration to create a Lstm module using the init function. #[derive(Config, Debug)] pub struct LstmCellConfig { // The size of the input features pub input_size: usize, // The size of the hidden state pub hidden_size: usize, // The number of hidden layers pub dropout: f64, } impl LstmCellConfig { // Initialize parameters using best practices: // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn) // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training #[allow(clippy::single_range_in_vec_init)] pub fn init(&self, device: &B::Device) -> LstmCell { let initializer = Initializer::XavierNormal { gain: 1.0 }; let init_bias = Tensor::::ones([self.hidden_size], device); let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size) .with_initializer(initializer.clone()) .init(device); // Set forget gate bias to 1.0 (helps with learning long sequences) let bias = weight_ih .bias .clone() .unwrap() .val() .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone()); weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias)); let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size) .with_initializer(initializer) .init(device); let bias = weight_hh .bias .clone() .unwrap() .val() .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias); weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias)); LstmCell { hidden_size: self.hidden_size, weight_ih, weight_hh, norm_x: LayerNormConfig::new(4 * self.hidden_size).init(device), norm_h: LayerNormConfig::new(self.hidden_size).init(device), norm_c: LayerNormConfig::new(self.hidden_size).init(device), dropout: DropoutConfig::new(self.dropout).init(), } } } impl LstmCell { /// Forward pass of LSTM cell. /// Args: /// x: Input tensor of shape (batch_size, input_size) /// state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size) /// Returns: /// Tuple of (h_t, c_t) representing new hidden and cell states pub fn forward(&self, x: Tensor, state: LstmState) -> LstmState { let (h_prev, c_prev) = (state.hidden, state.cell); // Combined matrix multiplication for all gates // Shape: (batch_size, 4 * hidden_size) let gates_x = self.weight_ih.forward(x); // Transform input let gates_h = self.weight_hh.forward(h_prev); // Transform previous hidden state // Apply layer normalization let gates_x = self.norm_x.forward(gates_x); // Combined gate pre-activations let gates = gates_x + gates_h; // Split into individual gates // Each gate shape: (batch_size, hidden_size) let gates = gates.chunk(4, 1); let i_gate = gates[0].clone(); let f_gate = gates[1].clone(); let g_gate = gates[2].clone(); let o_gate = gates[3].clone(); // Apply gate non-linearities let i_t = Sigmoid::new().forward(i_gate); let f_t = Sigmoid::new().forward(f_gate); let g_t = Tanh::new().forward(g_gate); let o_t = Sigmoid::new().forward(o_gate); // Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t let c_t = f_t * c_prev + i_t * g_t; let c_t = self.norm_c.forward(c_t); // Update cell state: h_t = o_t ⊙ tanh(c_t) let h_t = o_t * Tanh::new().forward(c_t.clone()); let h_t = self.norm_h.forward(h_t); let h_t = self.dropout.forward(h_t); LstmState::new(h_t, c_t) } // Initialize cell state and hidden state if provided or with zeros pub fn init_state(&self, batch_size: usize, device: &B::Device) -> LstmState { let cell = Tensor::zeros([batch_size, self.hidden_size], device); let hidden = Tensor::zeros([batch_size, self.hidden_size], device); LstmState::new(cell, hidden) } } /// Stacked LSTM implementation supporting multiple layers /// Each layer processes the output of the previous layer #[derive(Module, Debug)] pub struct StackedLstm { pub layers: Vec>, } #[derive(Config, Debug)] pub struct StackedLstmConfig { pub input_size: usize, pub hidden_size: usize, pub num_layers: usize, pub dropout: f64, } impl StackedLstmConfig { pub fn init(&self, device: &B::Device) -> StackedLstm { let mut layers: Vec> = vec![]; // Create list of LSTM cells, one for each layer for i in 0..self.num_layers { if i == 0 { if i < self.num_layers - 1 { layers.push( LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout) .init(device), ); } else { // No dropout on last layer layers.push( LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device), ); } } else if i < self.num_layers - 1 { layers.push( LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout) .init(device), ); } else { // No dropout on last layer layers.push( LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device), ); } } StackedLstm { layers } } } impl StackedLstm { /// Process input sequence through stacked LSTM layers. /// /// Args: /// x: Input tensor of shape (batch_size, seq_length, input_size) /// states: Optional initial states for each layer /// /// Returns: /// Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size) /// and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size) pub fn forward( &self, x: Tensor, states: Option>>, ) -> (Tensor, Vec>) { let [batch_size, seq_length, _] = x.dims(); let device = x.device(); let mut states = match states { None => { let mut temp: Vec> = vec![]; for layer in self.layers.iter() { temp.push(layer.init_state(batch_size, &device)); } temp } _ => states.unwrap(), }; let mut layer_outputs = vec![]; for t in 0..seq_length { let mut input_t = x.clone().slice(s![.., t..t + 1, ..]).squeeze_dim::<2>(1); for (i, lstm_cell) in self.layers.iter().enumerate() { let mut state: LstmState = LstmState::new(states[i].cell.clone(), states[i].hidden.clone()); state = lstm_cell.forward(input_t, state); input_t = state.hidden.clone(); states[i] = state; } layer_outputs.push(input_t); } // Stack output along sequence dimension let output = Tensor::stack(layer_outputs, 1); (output, states) } } /// Complete LSTM network with bidirectional support. /// /// In bidirectional mode: /// - Forward LSTM processes sequence from left to right /// - Backward LSTM processes sequence from right to left /// - Outputs are concatenated for final prediction #[derive(Module, Debug)] pub struct LstmNetwork { // Forward direction LSTM pub stacked_lstm: StackedLstm, // Optional backward direction LSTM for bidirectional processing pub reverse_lstm: Option>, pub dropout: Dropout, pub fc: Linear, } #[derive(Config, Debug)] pub struct LstmNetworkConfig { #[config(default = 1)] pub input_size: usize, // Single feature (number sequence) #[config(default = 32)] pub hidden_size: usize, // Size of LSTM hidden state #[config(default = 2)] pub num_layers: usize, // Number of LSTM layers #[config(default = 1)] pub output_size: usize, // Predict one number #[config(default = 0.1)] pub dropout: f64, #[config(default = true)] pub bidirectional: bool, // Use bidirectional LSTM } impl LstmNetworkConfig { pub fn init(&self, device: &B::Device) -> LstmNetwork { // Forward direction LSTM let stacked_lstm = StackedLstmConfig::new( self.input_size, self.hidden_size, self.num_layers, self.dropout, ) .init(device); // Optional backward direction LSTM for bidirectional processing let (reverse_lstm, hidden_size) = if self.bidirectional { let lstm = StackedLstmConfig::new( self.input_size, self.hidden_size, self.num_layers, self.dropout, ) .init(device); (Some(lstm), 2 * self.hidden_size) } else { (None, self.hidden_size) }; let fc = LinearConfig::new(hidden_size, self.output_size).init(device); let dropout = DropoutConfig::new(self.dropout).init(); LstmNetwork { stacked_lstm, reverse_lstm, dropout, fc, } } } impl LstmNetwork { /// Forward pass of the network. /// /// For bidirectional processing: /// 1. Process sequence normally with forward LSTM /// 2. Process reversed sequence with backward LSTM /// 3. Concatenate both outputs /// 4. Apply final linear transformation /// /// Args: /// x: Input tensor of shape (batch_size, seq_length, input_size) /// states: Optional initial states /// /// Returns: /// Output tensor of shape (batch_size, output_size) pub fn forward(&self, x: Tensor, states: Option>>) -> Tensor { let seq_length = x.dims()[1]; // Forward direction let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states); output = match &self.reverse_lstm { Some(reverse_lstm) => { //Process sequence in reverse direction let (mut reverse_output, _states) = reverse_lstm.forward(x.flip([1]), None); // Flip back to align with forward sequence reverse_output = reverse_output.flip([1]); // Concatenate forward and backward outputs along the feature dimension output = Tensor::cat(vec![output, reverse_output], 2); output } None => output, }; // Apply dropout before final layer output = self.dropout.forward(output); // Use final timestep output for prediction self.fc.forward( output .slice(s![.., seq_length - 1..seq_length, ..]) .squeeze_dim::<2>(1), ) } } ================================================ FILE: examples/modern-lstm/src/training.rs ================================================ use crate::dataset::{ NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH, SequenceBatcher, SequenceDataset, }; use crate::model::{LstmNetwork, LstmNetworkConfig}; use burn::{ data::dataloader::DataLoaderBuilder, module::AutodiffModule, nn::loss::{MseLoss, Reduction::Mean}, optim::{AdamConfig, GradientsParams, Optimizer}, prelude::*, record::CompactRecorder, tensor::backend::AutodiffBackend, }; #[derive(Config, Debug)] pub struct TrainingConfig { pub model: LstmNetworkConfig, pub optimizer: AdamConfig, #[config(default = 30)] pub num_epochs: usize, #[config(default = 32)] pub batch_size: usize, #[config(default = 2)] pub num_workers: usize, #[config(default = 1e-3)] pub lr: f64, } // Create the directory to save the model and model config fn create_artifact_dir(artifact_dir: &str) { // Remove existing artifacts std::fs::remove_dir_all(artifact_dir).ok(); std::fs::create_dir_all(artifact_dir).ok(); } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { create_artifact_dir(artifact_dir); // Save training config config .save(format!("{artifact_dir}/config.json")) .expect("Config should be saved successfully"); B::seed(&device, RANDOM_SEED); // Create the model and optimizer let mut model = config.model.init::(&device); let mut optim = config.optimizer.init::>(); // Create the batcher let batcher = SequenceBatcher::default(); // Create the dataloaders let dataloader_train = DataLoaderBuilder::new(batcher.clone()) .batch_size(config.batch_size) .shuffle(RANDOM_SEED) .num_workers(config.num_workers) .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL)); let dataloader_valid = DataLoaderBuilder::new(batcher) .batch_size(config.batch_size) .shuffle(RANDOM_SEED) .num_workers(config.num_workers) // 20% size of training .build(SequenceDataset::new( NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL, )); let train_num_items = dataloader_train.num_items(); let valid_num_items = dataloader_valid.num_items(); println!("Starting training..."); // Iterate over our training for X epochs for epoch in 1..config.num_epochs + 1 { // Initialize the training and validation metrics at the start of each epoch let mut train_losses = vec![]; let mut train_loss = 0.0; let mut valid_losses = vec![]; let mut valid_loss = 0.0; // Implement our training loop for batch in dataloader_train.iter() { let output = model.forward(batch.sequences, None); let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); train_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; // Gradients for the current backward pass let grads = loss.backward(); // Gradients linked to each parameter of the model let grads = GradientsParams::from_grads(grads, &model); // Update the model using the optimizer model = optim.step(config.lr, model, grads); } // The averaged train loss per epoch let avg_train_loss = train_loss / train_num_items as f32; train_losses.push(avg_train_loss); // Get the model without autodiff let valid_model = model.valid(); // Implement our validation loop for batch in dataloader_valid.iter() { let output = valid_model.forward(batch.sequences, None); let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); valid_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; } // The averaged train loss per epoch let avg_valid_loss = valid_loss / valid_num_items as f32; valid_losses.push(avg_valid_loss); // Display the averaged training and validation metrics every 10 epochs if (epoch + 1) % 5 == 0 { println!( "Epoch {}/{}, Avg Loss {:.4}, Avg Val Loss: {:.4}", epoch + 1, config.num_epochs, avg_train_loss, avg_valid_loss, ); } } // Save the trained model model .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) .expect("Trained model should be saved successfully"); } ================================================ FILE: examples/multi-gpus/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] edition.workspace = true license.workspace = true name = "multi-gpus" publish = false version.workspace = true [lints] workspace = true [features] default = [] f16 = [] flex32 = [] tch-gpu = ["burn/tch"] cuda = ["burn/cuda"] rocm = ["burn/rocm"] [dependencies] # Burn burn = { path = "../../crates/burn", features = [ "autotune", "fusion", "collective", "train", "std", ], default-features = false } text-classification = { path = "../text-classification" } ================================================ FILE: examples/multi-gpus/examples/multi-gpus.rs ================================================ fn main() { #[cfg(feature = "cuda")] multi_gpus::run::(); #[cfg(feature = "rocm")] multi_gpus::run::(); #[cfg(feature = "tch-gpu")] multi_gpus::run::(); } ================================================ FILE: examples/multi-gpus/src/lib.rs ================================================ use burn::{ backend::Autodiff, collective::{self, CollectiveConfig, PeerId, ReduceOperation}, data::{dataloader::DataLoaderBuilder, dataset::transform::PartialDataset}, nn::transformer::TransformerEncoderConfig, optim::{GradientsParams, Optimizer, SgdConfig}, prelude::*, tensor::{ TensorPrimitive, backend::{AutodiffBackend, DeviceId}, }, }; use std::{ sync::{Arc, mpsc::SyncSender}, time::Instant, }; use text_classification::{ AgNewsDataset, TextClassificationDataset, data::{TextClassificationBatcher, Tokenizer}, model::TextClassificationModel, }; pub fn run() { let type_id = 0; let num_devices = B::Device::device_count(type_id); let devices = (0..num_devices) .map(|i| B::Device::from_id(DeviceId::new(type_id, i as u32))) .collect(); run_with::(devices); } fn run_with(devices: Vec) { for strategy in [ collective::AllReduceStrategy::Tree(2), collective::AllReduceStrategy::Ring, collective::AllReduceStrategy::Centralized, ] { println!("[Gradient Update - {strategy:?}] starting ..."); let start = Instant::now(); task_grad_all_reduce::>(devices.clone(), strategy); println!( "[Gradient Update - {strategy:?}] took {:?}", start.elapsed() ); } for strategy in [ collective::AllReduceStrategy::Centralized, collective::AllReduceStrategy::Ring, collective::AllReduceStrategy::Tree(2), ] { println!("[All Reduce - {strategy:?}] starting ..."); let start = Instant::now(); task_all_reduce::(devices.clone(), 420, strategy); println!("[All Reduce - {strategy:?}] took {:?}", start.elapsed()); } task_naive_aggregation::(devices.clone(), 100); } fn task_naive_aggregation(mut devices: Vec, num_iterations: usize) { let aggregation_device = devices.pop().unwrap(); let shape = [8, 4096, 4096]; let (sender, receiver) = std::sync::mpsc::sync_channel(devices.len()); fn compute(input: Tensor) -> Tensor { let log = input.clone() + 1.0; input.matmul(log) } let mut handles = devices .into_iter() .map(|device| { let sender = sender.clone(); std::thread::spawn(move || { let input = Tensor::::random(shape, burn::tensor::Distribution::Default, &device); for _ in 0..num_iterations { let new = compute(input.clone()); sender.send(new.clone()).unwrap(); } }) }) .collect::>(); handles.push(std::thread::spawn(move || { let mut input = Tensor::::random( shape, burn::tensor::Distribution::Default, &aggregation_device, ); while let Ok(tensor) = receiver.recv() { let main = tensor.to_device(&aggregation_device); let value = main.clone().sum().into_scalar().elem::(); input = input + main / 2; println!("{value:?}"); assert_ne!(value, 0.0); } })); for handle in handles { handle.join().unwrap(); } } fn task_all_reduce( devices: Vec, num_iterations: usize, strategy: collective::AllReduceStrategy, ) { let num_devices = devices.len(); let batch = 32; let shape_signal = [batch, 2048, 2048]; let shape_weights = [1, 2048, 2048]; fn compute(weights: Tensor, signal: Tensor) -> Tensor { weights.matmul(signal) } let handles = devices .into_iter() .enumerate() .map(|(id, device)| { std::thread::spawn(move || { let mut weights = Tensor::::random( shape_weights, burn::tensor::Distribution::Default, &device, ) - 0.5; let id = PeerId::from(id); let config = CollectiveConfig::default() .with_num_devices(num_devices) .with_local_all_reduce_strategy(strategy); collective::register::(id, device.clone(), config).unwrap(); for i in 0..num_iterations { let signal = Tensor::::random( shape_signal, burn::tensor::Distribution::Default, &device, ) - 0.5; let signal = compute(weights, signal); let weights_update = signal.mean_dim(0); let result = collective::all_reduce::( id, weights_update.into_primitive().tensor(), ReduceOperation::Mean, ) .unwrap(); weights = Tensor::from_primitive(TensorPrimitive::Float(result)); let val = weights.clone().sum().into_scalar().elem::(); if id == PeerId::from(0) { println!("Iter {i} => {val}"); } } collective::finish_collective::(id).unwrap(); }) }) .collect::>(); for handle in handles { handle.join().unwrap(); } } fn task_grad_all_reduce( devices: Vec, strategy: collective::AllReduceStrategy, ) { let num_devices = devices.len(); let seq_length = nn::attention::SeqLengthOption::Fixed(512); let batch_size = 32; let config = TransformerEncoderConfig::new(256, 1024, 8, 4); let dataset = text_classification::AgNewsDataset::train(); let tokenizer = Arc::new(text_classification::data::BertCasedTokenizer::default()); let model_config = text_classification::model::TextClassificationModelConfig::new( config, AgNewsDataset::num_classes(), tokenizer.vocab_size(), seq_length, ); let datasets = PartialDataset::split(dataset, devices.len()); let model_main = model_config.init(&devices[0]); let handles = devices .into_iter() .zip(datasets) .enumerate() .map(|(id, (device, dataset))| { let model_main = model_main.clone(); let tokenizer = tokenizer.clone(); std::thread::spawn(move || { println!("[{id}] Running on device {device:?}"); let mut model = model_main.fork(&device); let batcher = TextClassificationBatcher::new(tokenizer, seq_length); let dataloader_train = DataLoaderBuilder::new(batcher) .batch_size(batch_size) .set_device(device.clone()) .build(dataset); let syncher = GradSyncer::start::( CollectiveConfig::default() .with_num_devices(num_devices) .with_local_all_reduce_strategy(strategy), device.clone(), PeerId::from(id), ); let mut optim = SgdConfig::new().init::>(); for (i, batch) in dataloader_train.iter().enumerate() { let output = model.forward(batch); let loss: Tensor = output.loss.clone(); let grads = loss.backward(); let loss = loss.into_scalar().elem::(); let grads = GradientsParams::from_grads(grads, &model); let grads = syncher.sync(grads); if let Some(grads) = grads { model = optim.step(1.0e-5, model, grads); } println!("[{id}] Iter {i} => {loss}"); } }) }) .collect::>(); for handle in handles { handle.join().unwrap(); } } struct GradSyncer { sender: SyncSender, } struct Message { callback: SyncSender>, grads: GradientsParams, } impl GradSyncer { fn start(config: CollectiveConfig, device: Device, id: PeerId) -> Self { let (sender, receiver) = std::sync::mpsc::sync_channel::(8); std::thread::spawn(move || { println!("[{id}] Register collective operation {config:?}"); collective::register::(id, device, config).unwrap(); let num_stages = 4; let mut buffers: Vec = Vec::new(); while let Ok(msg) = receiver.recv() { let grads = msg .grads .all_reduce::(id, ReduceOperation::Mean) .unwrap(); buffers.push(grads); let result = if buffers.len() >= num_stages { Some(buffers.remove(0)) } else { None }; msg.callback.send(result).unwrap(); } collective::finish_collective::(id).unwrap(); }); Self { sender } } fn sync(&self, grads: GradientsParams) -> Option { let (sender, receiver) = std::sync::mpsc::sync_channel(1); let msg = Message { callback: sender, grads, }; self.sender.send(msg).unwrap(); receiver.recv().unwrap() } } ================================================ FILE: examples/notebook/README.md ================================================ # Jupyter Notebook Examples with Burn This directory includes Jupyter Notebook examples showcasing the usage of the Burn deep learning framework in Rust through [Evcxr Jupyter](https://github.com/evcxr/evcxr/blob/main/evcxr_jupyter/README.md). The examples are systematically organized based on the specific Burn features they illustrate. ## Viewing Options You can explore the examples in different ways: - **Notebook Viewer:** If you prefer not to set up the entire crate package, you can view the examples in a notebook viewer or run them to see images and other media outputs. - **Visual Studio Code (vscode):** If you're using vscode, you already have access to a built-in notebook viewer, enabling you to open and interact with the notebook files directly. For other editors, you can utilize the [Jupyter Notebook Viewer](https://nbviewer.jupyter.org/). ## Getting Started with Rust and Evcxr To execute the Rust code within the notebooks, you must install the Evcxr kernel. Here's how to get started: ### Install Evcxr Kernel 1. **Build Evcxr Kernel:** Install the required package with the following command: ```shell cargo install evcxr_jupyter ``` 2. **Install and Register the Kernel to Jupyter:** ```shell evcxr_jupyter --install ``` ### Open and Run Notebooks Once the kernel is installed, you can open the notebook files in your preferred editor and run the code. Ensure that the kernel is set to `Rust` within the notebook for proper execution. ## Additional Reading Resources - [Notebook Special Commands for Evcxr](https://github.com/evcxr/evcxr/blob/main/COMMON.md): Learn about the unique commands and functionalities offered by Evcxr for a more efficient workflow with Jupyter Notebooks. ================================================ FILE: examples/notebook/autodiff.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Autodifferentiation and Gradient Descent in Burn\n", "\n", "This notebook demonstrates how to use automatic differentiation in Burn to compute gradients and implement gradient descent." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [], "source": [ "// Dependency declarations\n", ":dep burn = {path = \"../../crates/burn\"}\n", ":dep burn-ndarray = {path = \"../../crates/burn-ndarray\"}\n", ":dep burn-autodiff = {path = \"../../crates/burn-autodiff\"}\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [], "source": [ "// Import packages\n", "use burn::prelude::*;\n", "use burn_autodiff::Autodiff;\n", "use burn_ndarray::NdArray;\n", "\n", "// Type alias: Autodiff enables automatic differentiation\n", "type B = Autodiff>;\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Understanding require_grad()\n", "\n", "In Burn, tensors can be marked for gradient tracking using `.require_grad()`. This tells the framework to track operations on this tensor so gradients can be computed later." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Regular tensor x: Tensor {\n", " data:\n", "[1.0, 2.0, 3.0, 4.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Tensor y with require_grad: Tensor {\n", " data:\n", "[1.0, 2.0, 3.0, 4.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "result = sum(y * 2) = Tensor {\n", " data:\n", "[20.0],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "let device = ::Device::default();\n", "\n", "// Create a regular tensor - no gradient tracking\n", "let x: Tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);\n", "println!(\"Regular tensor x: {}\", x);\n", "\n", "// Create a tensor that requires gradient computation\n", "let y: Tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\n", "println!(\"Tensor y with require_grad: {}\", y);\n", "\n", "// Now let's do some operations on y\n", "let z = y.clone() * 2.0;\n", "let result = z.sum();\n", "println!(\"result = sum(y * 2) = {}\", result);\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Computing Gradients with backward()\n", "\n", "The `.backward()` method computes the gradients of all tensors that have `require_grad()` set. It returns a gradients object that holds the computed gradients." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = Tensor {\n", " data:\n", "[1.0, 2.0, 3.0, 4.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "dy/dx = Tensor {\n", " data:\n", "[2.0, 2.0, 2.0, 2.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Example: y = [1, 2, 3, 4]\n", "// z = y * 2 = [2, 4, 6, 8]\n", "// result = sum(z) = 20\n", "//\n", "// d(result)/d(y) = d(result)/dz * dz/dy = 1 * 2 = [2, 2, 2, 2]\n", "\n", "let device = ::Device::default();\n", "let y: Tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\n", "let z = y.clone() * 2.0;\n", "let result = z.sum();\n", "\n", "// Compute gradients\n", "let grads = result.backward();\n", "\n", "// Get gradient for y\n", "let y_grad = y.grad(&grads).unwrap();\n", "println!(\"y = {}\", y);\n", "println!(\"d(result)/dy = {}\", y_grad);\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. More Complex Example: Quadratic Function\n", "Let's compute the gradient of a more complex function: f(x) = x²\n", "\n", "The derivative is: f'(x) = 2x" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x = Tensor {\n", " data:\n", "[1.0, 2.0, 3.0, 4.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "x^2 = Tensor {\n", " data:\n", "[1.0, 4.0, 9.0, 16.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "d(x^2)/dx = Tensor {\n", " data:\n", "[2.0, 4.0, 6.0, 8.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Expected: [2, 4, 6, 8]\n" ] } ], "source": [ "// f(x) = x^2\n", "// f'(x) = 2x\n", "\n", "let device = ::Device::default();\n", "let x: Tensor = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\n", "let y = x.clone().powf_scalar(2.0);\n", "let result = y.clone().sum();\n", "\n", "let grads = result.backward();\n", "let x_grad = x.grad(&grads).unwrap();\n", "\n", "println!(\"x = {}\", x);\n", "println!(\"x^2 = {}\", y);\n", "println!(\"d(x^2)/dx = {}\", x_grad);\n", "\n", "// Verify: d(x^2)/dx should be [2, 4, 6, 8]\n", "println!(\"Expected: [2, 4, 6, 8]\");\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Chain Rule Example\n", "\n", "Let's verify the chain rule: f(g(x))' = f'(g(x)) * g'(x)\n", "\n", "Example: y = sin(x²), we want dy/dx\n", "\n", "Let u = x², y = sin(u)\n", "dy/du = cos(u), du/dx = 2x\n", "dy/dx = cos(x²) * 2x" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x = Tensor {\n", " data:\n", "[0.0, 1.0, 2.0, 3.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "y = sin(x^2) = Tensor {\n", " data:\n", "[0.0, 0.84147096, -0.7568025, 0.4121185],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "dy/dx = Tensor {\n", " data:\n", "[0.0, 1.0806046, -2.6145744, -5.4667816],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Expected (cos(x^2) * 2x): Tensor {\n", " data:\n", "[0.0, 1.0806046, -2.6145744, -5.4667816],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"autodiff\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// y = sin(x^2)\n", "// dy/dx = cos(x^2) * 2x\n", "\n", "let device = ::Device::default();\n", "let x: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device).require_grad();\n", "\n", "// Forward pass\n", "let x_squared = x.clone().powf_scalar(2.0);\n", "let y = x_squared.sin();\n", "let result = y.clone().sum();\n", "\n", "// Backward pass\n", "let grads = result.backward();\n", "let x_grad = x.grad(&grads).unwrap();\n", "\n", "println!(\"x = {}\", x);\n", "println!(\"y = sin(x^2) = {}\", y);\n", "println!(\"dy/dx = {}\", x_grad);\n", "\n", "// Verify manually: cos(x^2) * 2x\n", "let expected_grad = x.clone().powf_scalar(2.0).cos() * (x.clone() * 2.0);\n", "println!(\"Expected (cos(x^2) * 2x): {}\", expected_grad);\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Gradient Descent from Scratch\n", "\n", "Now let's implement the classic gradient descent algorithm to find the minimum of a function.\n", "\n", "We'll minimize: f(x) = (x - 3)²\n", "\n", "The minimum is at x = 3, where f(x) = 0" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting gradient descent to minimize (x - 3)^2\n", "Expected minimum: x = 3\n", "---\n", "Iteration 0: x = 0.0000, loss = 9.0000\n", "Iteration 1: x = 0.6000, loss = 5.7600\n", "Iteration 2: x = 1.0800, loss = 3.6864\n", "Iteration 3: x = 1.4640, loss = 2.3593\n", "Iteration 4: x = 1.7712, loss = 1.5099\n", "Iteration 5: x = 2.0170, loss = 0.9664\n", "Iteration 6: x = 2.2136, loss = 0.6185\n", "Iteration 7: x = 2.3709, loss = 0.3958\n", "Iteration 8: x = 2.4967, loss = 0.2533\n", "Iteration 9: x = 2.5973, loss = 0.1621\n", "Iteration 10: x = 2.6779, loss = 0.1038\n", "Iteration 11: x = 2.7423, loss = 0.0664\n", "Iteration 12: x = 2.7938, loss = 0.0425\n", "Iteration 13: x = 2.8351, loss = 0.0272\n", "Iteration 14: x = 2.8681, loss = 0.0174\n", "Iteration 15: x = 2.8944, loss = 0.0111\n", "Iteration 16: x = 2.9156, loss = 0.0071\n", "Iteration 17: x = 2.9324, loss = 0.0046\n", "Iteration 18: x = 2.9460, loss = 0.0029\n", "Iteration 19: x = 2.9568, loss = 0.0019\n", "---\n", "Final x = 2.9654\n" ] } ], "source": [ "// Target: minimize f(x) = (x - 3)^2\n", "// This has minimum at x = 3\n", "\n", "fn loss(x: &Tensor) -> Tensor {\n", " // f(x) = (x - 3)^2\n", " (x.clone() - 3.0).powf_scalar(2.0)\n", "}\n", "\n", "let device = ::Device::default();\n", "// Start from x = 0\n", "let mut x_val: f32 = 0.0;\n", "\n", "let learning_rate: f32 = 0.1;\n", "\n", "println!(\"Starting gradient descent to minimize (x - 3)^2\");\n", "println!(\"Expected minimum: x = 3\");\n", "println!(\"---\");\n", "\n", "for i in 0..20 {\n", " // Create a new tensor with current x value and require gradients\n", " let x = Tensor::::from_floats([x_val], &device).require_grad();\n", " \n", " // Forward pass\n", " let loss_value = loss(&x);\n", " \n", " // Get loss as f32 for printing\n", " let loss_scalar: f32 = loss_value.clone().into_scalar().elem::();\n", " \n", " println!(\"Iteration {}: x = {:.4}, loss = {:.4}\", i, x_val, loss_scalar);\n", "\n", " // Backward pass\n", " let grads = loss_value.backward();\n", " let grad = x.grad(&grads).unwrap();\n", " \n", " // Update: x = x - learning_rate * gradient\n", " let grad_val: f32 = grad.into_scalar().elem::();\n", " x_val = x_val - grad_val * learning_rate;\n", "}\n", "\n", "println!(\"---\");\n", "println!(\"Final x = {:.4}\", x_val);\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Linear Regression with Gradient Descent\n", "\n", "Let's use gradient descent to fit a simple linear regression model: y = wx + b\n", "\n", "We'll generate synthetic data where the true relationship is y = 2x + 1" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated 100 data points\n", "True relationship: y = 2x + 1\n", "First 5 x values: [0.0, 0.1, 0.2, 0.3, 0.4]\n", "First 5 y values: [0.87993187, 0.98804677, 1.5366085, 1.7324162, 1.653858]\n" ] } ], "source": [ "use burn::tensor::{Distribution, TensorData};\n", "\n", "let device = ::Device::default();\n", "// Generate synthetic data: y = 2x + 1 + noise\n", "let num_samples = 100;\n", "let x_data = TensorData::new((0..num_samples).map(|i| i as f32 / 10.0).collect(), [num_samples, 1]);\n", "// Generate noise using Burn's random tensor\n", "let noise = Tensor::::random([num_samples, 1], Distribution::Uniform(-0.25, 0.25), &device);\n", "\n", "let x = Tensor::::from(x_data);\n", "let y: Tensor = 2 * x.clone() + 1 + noise;\n", "\n", "println!(\"Generated {} data points\", num_samples);\n", "println!(\"True relationship: y = 2x + 1\");\n", "println!(\"First 5 x values: {}\", x.clone().slice([0..5, 0..1]));\n", "println!(\"First 5 y values: {}\", y.clone().slice([0..5, 0..1]));\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "vscode": { "languageId": "rust" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training linear regression with gradient descent...\n", "Initial w = 0.5000, b = 0.5000\n", "Epoch 0: loss = 81.7705, w = 1.5358, b = 0.6586\n", "Epoch 20: loss = 0.0365, w = 2.0384, b = 0.7594\n", "Epoch 40: loss = 0.0341, w = 2.0351, b = 0.7810\n", "Epoch 60: loss = 0.0321, w = 2.0322, b = 0.8006\n", "Epoch 80: loss = 0.0305, w = 2.0295, b = 0.8184\n", "---\n", "Final: w = 2.0272, b = 0.8336\n", "True: w = 2.0, b = 1.0\n" ] } ], "source": [ "// Initialize weights randomly\n", "let device = ::Device::default();\n", "let mut w_val: f32 = 0.5; // Start with reasonable initial values\n", "let mut b_val: f32 = 0.5;\n", "\n", "let learning_rate: f32 = 0.01;\n", "let num_epochs = 100;\n", "\n", "println!(\"Training linear regression with gradient descent...\");\n", "println!(\"Initial w = {:.4}, b = {:.4}\", w_val, b_val);\n", "\n", "for epoch in 0..num_epochs {\n", " // Create tensors with current parameter values\n", " let w = Tensor::::from_floats([[w_val]], &device).require_grad();\n", " let b = Tensor::::from_floats([[b_val]], &device).require_grad();\n", " \n", " // Forward pass: y_pred = w * x + b\n", " let y_pred = x.clone().matmul(w.clone()) + b.clone();\n", " \n", " // Compute loss: MSE = (1/n) * sum((y_pred - y)^2)\n", " let loss = (y_pred.clone() - y.clone()).powf_scalar(2.0).mean();\n", " \n", " // Backward pass\n", " let grads = loss.backward();\n", " let w_grad = w.grad(&grads).unwrap();\n", " let b_grad = b.grad(&grads).unwrap();\n", " \n", " // Update weights\n", " let w_grad_val: f32 = w_grad.into_scalar().elem::();\n", " let b_grad_val: f32 = b_grad.into_scalar().elem::();\n", " w_val = w_val - w_grad_val * learning_rate;\n", " b_val = b_val - b_grad_val * learning_rate;\n", " \n", " if epoch % 20 == 0 {\n", " let loss_val: f32 = loss.clone().into_scalar().elem::();\n", " println!(\"Epoch {:3}: loss = {:.4}, w = {:.4}, b = {:.4}\", epoch, loss_val, w_val, b_val);\n", " }\n", "}\n", "\n", "println!(\"---\");\n", "println!(\"Final: w = {:.4}, b = {:.4}\", w_val, b_val);\n", "println!(\"True: w = 2.0, b = 1.0\");\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In this notebook, we covered:\n", "\n", "- **require_grad()**: Mark tensors for gradient tracking\n", "- **backward()**: Compute gradients automatically using reverse-mode autodiff\n", "- **grad()**: Retrieve computed gradients\n", "- **Gradient Descent**: Implemented from scratch to minimize a quadratic function\n", "- **Linear Regression**: Used gradient descent to fit a linear model to data\n", "\n", "These concepts are the foundation of neural network training in Burn!" ] } ], "metadata": { "kernelspec": { "display_name": "Rust", "language": "rust", "name": "rust" }, "language_info": { "codemirror_mode": "rust", "file_extension": ".rs", "mimetype": "text/rust", "name": "Rust", "pygment_lexer": "rust", "version": "" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/notebook/basic-tensor-op.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tensor Operations in Burn\n", "\n", "This notebook demonstrates basic tensor operations in Burn, a deep learning framework written in Rust." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "\n", "// Dependency declarations for the notebook.\n", "// The syntax is similar to Cargo.toml. Just prefix with :dep\n", "\n", ":dep burn = {path = \"../../crates/burn\"}\n", ":dep burn-ndarray = {path = \"../../crates/burn-ndarray\"}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Import packages\n", "use burn::prelude::*;\n", "use burn_ndarray::NdArray;\n", "\n", "// Type alias for the backend (using CPU/NdArray)\n", "type B = NdArray;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Tensor Creation" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Empty tensor shape: Shape { dims: [2, 3, 4] }\n", "Zeros tensor: Tensor {\n", " data:\n", "[[0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0]],\n", " shape: [3, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Ones tensor: Tensor {\n", " data:\n", "[[1.0, 1.0, 1.0, 1.0],\n", " [1.0, 1.0, 1.0, 1.0]],\n", " shape: [2, 4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Full tensor (7.0): Tensor {\n", " data:\n", "[[7.0, 7.0, 7.0],\n", " [7.0, 7.0, 7.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "let device = ::Device::default();\n", "\n", "// Create an empty tensor (uninitialized values)\n", "let empty: Tensor = Tensor::empty([2, 3, 4], &device);\n", "println!(\"Empty tensor shape: {:?}\", empty.shape());\n", "\n", "// Create a tensor filled with zeros\n", "let zeros: Tensor = Tensor::zeros([3, 3], &device);\n", "println!(\"Zeros tensor: {}\", zeros);\n", "\n", "// Create a tensor filled with ones\n", "let ones: Tensor = Tensor::ones([2, 4], &device);\n", "println!(\"Ones tensor: {}\", ones);\n", "\n", "// Create a tensor filled with a specific value\n", "let full: Tensor = Tensor::full([2, 3], 7.0, &device);\n", "println!(\"Full tensor (7.0): {}\", full);" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From slice:\n", "Tensor {\n", " data:\n", "[[1.0, 2.0, 3.0],\n", " [4.0, 5.0, 6.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Random tensor: Tensor {\n", " data:\n", "[0.32371014, 0.41100568, 0.94457513, 0.8408601, 0.42262083],\n", " shape: [5],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Normal distribution: Tensor {\n", " data:\n", "[-0.22402725, 1.8367178, -1.1049407, -0.6302627, 1.1106112],\n", " shape: [5],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Uniform [0, 10): Tensor {\n", " data:\n", "[8.110331, 7.335061, 9.858947, 6.0834813, 3.6619747],\n", " shape: [5],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Create a tensor from a slice of values\n", "let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];\n", "let from_slice = Tensor::::from_floats(data, &device).reshape([2, 3]);\n", "println!(\"From slice:\\n{}\", from_slice);\n", "\n", "// Create a random tensor\n", "use burn::tensor::Distribution;\n", "let random: Tensor = Tensor::random([5], Distribution::Default, &device);\n", "println!(\"Random tensor: {}\", random);\n", "\n", "// Create a tensor with normal distribution\n", "let normal: Tensor = Tensor::random([5], Distribution::Normal(0.0, 1.0), &device);\n", "println!(\"Normal distribution: {}\", normal);\n", "\n", "// Create a tensor with uniform distribution in range [0, 10)\n", "let uniform: Tensor = Tensor::random([5], Distribution::Uniform(0.0, 10.0), &device);\n", "println!(\"Uniform [0, 10): {}\", uniform);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Shape Operations" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original (2x3):\n", "Tensor {\n", " data:\n", "[[1.0, 2.0, 3.0],\n", " [4.0, 5.0, 6.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Reshaped (1x2x3): Tensor {\n", " data:\n", "[[[1.0, 2.0, 3.0],\n", " [4.0, 5.0, 6.0]]],\n", " shape: [1, 2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Flattened: Tensor {\n", " data:\n", "[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],\n", " shape: [6],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Reshape tensor - change the dimensions without changing the data\n", "let tensor = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\n", "println!(\"Original (2x3):\\n{}\", tensor);\n", "\n", "let reshaped: Tensor = tensor.clone().reshape([1, 2, 3]);\n", "println!(\"Reshaped (1x2x3): {}\", reshaped);\n", "\n", "// Flatten - reshape to 1D\n", "let flat: Tensor = tensor.flatten(0, 1);\n", "println!(\"Flattened: {}\", flat);" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original:\n", "Tensor {\n", " data:\n", "[[1.0, 2.0],\n", " [3.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Transposed:\n", "Tensor {\n", " data:\n", "[[1.0, 3.0],\n", " [2.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Using .t():\n", "Tensor {\n", " data:\n", "[[1.0, 3.0],\n", " [2.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Transpose - swap dimensions\n", "let tensor = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n", "println!(\"Original:\\n{}\", tensor);\n", "\n", "let transposed = tensor.clone().transpose();\n", "println!(\"Transposed:\\n{}\", transposed);\n", "\n", "// Also .t() works for 2D tensors\n", "let t = tensor.t();\n", "println!(\"Using .t():\\n{}\", t);" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before squeeze [1,1,2]: shape = Shape { dims: [1, 1, 2] }\n", "After squeeze: shape = Shape { dims: [2] }\n", "Before unsqueeze [2,2]: shape = Shape { dims: [2, 2] }\n", "After unsqueeze: shape = Shape { dims: [1, 2, 2] }\n" ] } ], "source": [ "// Squeeze - remove dimensions of size 1\n", "let tensor = Tensor::::from_floats([1.0, 2.0], &device).reshape([1, 1, 2]);\n", "println!(\"Before squeeze [1,1,2]: shape = {:?}\", tensor.shape());\n", "\n", "let squeezed = tensor.squeeze::<1>();\n", "println!(\"After squeeze: shape = {:?}\", squeezed.shape());\n", "\n", "// Unsqueeze - add a dimension of size 1 at specified position\n", "let tensor = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n", "println!(\"Before unsqueeze [2,2]: shape = {:?}\", tensor.shape());\n", "\n", "let unsqueezed = tensor.unsqueeze::<3>();\n", "println!(\"After unsqueeze: shape = {:?}\", unsqueezed.shape());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Indexing and Slicing" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original tensor:\n", "Tensor {\n", " data:\n", "[[1.0, 2.0, 3.0, 4.0],\n", " [5.0, 6.0, 7.0, 8.0],\n", " [9.0, 10.0, 11.0, 12.0]],\n", " shape: [3, 4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Create a tensor for indexing examples\n", "let tensor = Tensor::::from_floats(\n", " [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],\n", "&device\n", ").reshape([3, 4]);\n", "println!(\"Original tensor:\\n{}\", tensor);" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sliced [1..3, 1..4]:\n", "Tensor {\n", " data:\n", "[[6.0, 7.0, 8.0],\n", " [10.0, 11.0, 12.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Row 1: Tensor {\n", " data:\n", "[[5.0, 6.0, 7.0, 8.0]],\n", " shape: [1, 4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Column 2: Tensor {\n", " data:\n", "[[3.0],\n", " [7.0],\n", " [11.0]],\n", " shape: [3, 1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Slice tensor - select a portion using ranges\n", "// Get rows 1-2 (index 1 to end), columns 1-3 (index 1 to 3)\n", "let sliced = tensor.clone().slice([1..3, 1..4]);\n", "println!(\"Sliced [1..3, 1..4]:\\n{}\", sliced);\n", "\n", "// Get single row\n", "let row = tensor.clone().slice([1..2, 0..4]);\n", "println!(\"Row 1: {}\", row);\n", "\n", "// Get single column\n", "let col = tensor.slice([0..3, 2..3]);\n", "println!(\"Column 2: {}\", col);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Basic Math Operations" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = Tensor {\n", " data:\n", "[[1.0, 2.0],\n", " [3.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "b = Tensor {\n", " data:\n", "[[5.0, 6.0],\n", " [7.0, 8.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a + b = Tensor {\n", " data:\n", "[[6.0, 8.0],\n", " [10.0, 12.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a - b = Tensor {\n", " data:\n", "[[-4.0, -4.0],\n", " [-4.0, -4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a * b = Tensor {\n", " data:\n", "[[5.0, 12.0],\n", " [21.0, 32.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a / b = Tensor {\n", " data:\n", "[[0.2, 0.33333334],\n", " [0.42857143, 0.5]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "let a = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n", "let b = Tensor::::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);\n", "\n", "println!(\"a = {}\", a);\n", "println!(\"b = {}\", b);\n", "\n", "// Addition\n", "let c = a.clone() + b.clone();\n", "println!(\"a + b = {}\", c);\n", "\n", "// Subtraction\n", "let c = a.clone() - b.clone();\n", "println!(\"a - b = {}\", c);\n", "\n", "// Multiplication (element-wise)\n", "let c = a.clone() * b.clone();\n", "println!(\"a * b = {}\", c);\n", "\n", "// Division (element-wise)\n", "let c = a.clone() / b.clone();\n", "println!(\"a / b = {}\", c);" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = Tensor {\n", " data:\n", "[[1.0, 2.0],\n", " [3.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a + 10 = Tensor {\n", " data:\n", "[[11.0, 12.0],\n", " [13.0, 14.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a * 2 = Tensor {\n", " data:\n", "[[2.0, 4.0],\n", " [6.0, 8.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Scalar operations\n", "let a = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n", "\n", "println!(\"a = {}\", a);\n", "\n", "// Add scalar\n", "let c = a.clone() + 10.0;\n", "println!(\"a + 10 = {}\", c);\n", "\n", "// Multiply scalar\n", "let c = a.clone() * 2.0;\n", "println!(\"a * 2 = {}\", c);" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = Tensor {\n", " data:\n", "[[1.0, 2.0],\n", " [3.0, 4.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "b = Tensor {\n", " data:\n", "[[5.0, 6.0],\n", " [7.0, 8.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a @ b (matmul) = Tensor {\n", " data:\n", "[[19.0, 22.0],\n", " [43.0, 50.0]],\n", " shape: [2, 2],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Matrix multiplication\n", "let a = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\n", "let b = Tensor::::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);\n", "\n", "println!(\"a = {}\", a);\n", "println!(\"b = {}\", b);\n", "\n", "let result = a.matmul(b);\n", "println!(\"a @ b (matmul) = {}\", result);\n", "\n", "// Verify (rows of a · columns of b): row1 [1,2] · col1 [5,7] = 1*5+2*7 = 19, row1 [1,2] · col2 [6,8] = 1*6+2*8 = 22\n", "// row2 [3,4] · col1 [5,7] = 3*5+4*7 = 43, row2 [3,4] · col2 [6,8] = 3*6+4*8 = 50" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Element-wise Math Functions" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = Tensor {\n", " data:\n", "[0.0, 1.0, 2.0],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "exp(a) = Tensor {\n", " data:\n", "[1.0, 2.7182817, 7.389056],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "log(a + 1) = Tensor {\n", " data:\n", "[0.0, 0.6931472, 1.0986123],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a.powf(2) = Tensor {\n", " data:\n", "[0.0, 1.0, 4.0],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a.powf(0.5) = Tensor {\n", " data:\n", "[0.0, 1.0, 1.4142135],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "let a: Tensor = Tensor::from_floats([0.0, 1.0, 2.0], &device);\n", "\n", "println!(\"a = {}\", a);\n", "\n", "// Exponential\n", "println!(\"exp(a) = {}\", a.clone().exp());\n", "\n", "// Natural logarithm\n", "println!(\"log(a + 1) = {}\", (a.clone() + 1.0).log());\n", "\n", "// Power\n", "println!(\"a.powf(2) = {}\", a.clone().powf_scalar(2.0));\n", "println!(\"a.powf(0.5) = {}\", a.clone().powf_scalar(0.5));" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "angles = Tensor {\n", " data:\n", "[0.0, 0.7853982, 1.5707964],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "sin(angles) = Tensor {\n", " data:\n", "[0.0, 0.70710677, 1.0],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "cos(angles) = Tensor {\n", " data:\n", "[1.0, 0.70710677, -4.371139e-8],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "tan(angles) = Tensor {\n", " data:\n", "[0.0, 1.0, -22877332.0],\n", " shape: [3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Trigonometric functions\n", "let angles: Tensor = Tensor::from_floats([0.0, std::f32::consts::PI / 4.0, std::f32::consts::PI / 2.0], &device);\n", "\n", "println!(\"angles = {}\", angles);\n", "println!(\"sin(angles) = {}\", angles.clone().sin());\n", "println!(\"cos(angles) = {}\", angles.clone().cos());\n", "println!(\"tan(angles) = {}\", angles.clone().tan());" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Reduction Operations" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor:\n", "Tensor {\n", " data:\n", "[[1.0, 2.0, 3.0],\n", " [4.0, 5.0, 6.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Sum: Tensor {\n", " data:\n", "[21.0],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Mean: Tensor {\n", " data:\n", "[3.5],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Product: Tensor {\n", " data:\n", "[720.0],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Max: Tensor {\n", " data:\n", "[6.0],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Min: Tensor {\n", " data:\n", "[1.0],\n", " shape: [1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "let tensor = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\n", "println!(\"Tensor:\\n{}\", tensor);\n", "\n", "// Sum all elements\n", "println!(\"Sum: {}\", tensor.clone().sum());\n", "\n", "// Mean of all elements\n", "println!(\"Mean: {}\", tensor.clone().mean());\n", "\n", "// Product of all elements\n", "println!(\"Product: {}\", tensor.clone().prod());\n", "\n", "// Maximum and minimum\n", "println!(\"Max: {}\", tensor.clone().max());\n", "println!(\"Min: {}\", tensor.clone().min());" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor:\n", "Tensor {\n", " data:\n", "[[1.0, 2.0, 3.0],\n", " [4.0, 5.0, 6.0]],\n", " shape: [2, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Sum dim 0: Tensor {\n", " data:\n", "[[5.0, 7.0, 9.0]],\n", " shape: [1, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Sum dim 1: Tensor {\n", " data:\n", "[[6.0],\n", " [15.0]],\n", " shape: [2, 1],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Mean dim 0: Tensor {\n", " data:\n", "[[2.5, 3.5, 4.5]],\n", " shape: [1, 3],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Reduce along specific dimensions\n", "let tensor = Tensor::::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\n", "println!(\"Tensor:\\n{}\", tensor);\n", "\n", "// Sum along dimension 0 (columns)\n", "println!(\"Sum dim 0: {}\", tensor.clone().sum_dim(0));\n", "\n", "// Sum along dimension 1 (rows)\n", "println!(\"Sum dim 1: {}\", tensor.clone().sum_dim(1));\n", "\n", "// Mean along dimension 0\n", "println!(\"Mean dim 0: {}\", tensor.clone().mean_dim(0));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Comparison and Selection" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = Tensor {\n", " data:\n", "[1.0, 5.0, 3.0, 8.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "b = Tensor {\n", " data:\n", "[4.0, 2.0, 6.0, 7.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "a > b: Tensor {\n", " data:\n", "[false, true, false, true],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Bool\",\n", " dtype: \"bool\",\n", "}\n", "a < b: Tensor {\n", " data:\n", "[true, false, true, false],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Bool\",\n", " dtype: \"bool\",\n", "}\n", "a == b: Tensor {\n", " data:\n", "[false, false, false, false],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Bool\",\n", " dtype: \"bool\",\n", "}\n" ] } ], "source": [ "let a: Tensor = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);\n", "let b: Tensor = Tensor::from_floats([4.0, 2.0, 6.0, 7.0], &device);\n", "\n", "println!(\"a = {}\", a);\n", "println!(\"b = {}\", b);\n", "\n", "// Element-wise comparison returns a boolean tensor\n", "let greater = a.clone().greater(b.clone());\n", "println!(\"a > b: {}\", greater);\n", "\n", "let less = a.clone().lower(b.clone());\n", "println!(\"a < b: {}\", less);\n", "\n", "let equal = a.clone().equal(b.clone());\n", "println!(\"a == b: {}\", equal);" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original: Tensor {\n", " data:\n", "[1.0, 5.0, 3.0, 8.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Where > 4, replace with 0: Tensor {\n", " data:\n", "[1.0, 0.0, 3.0, 0.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n", "Where > 4, replace with -1: Tensor {\n", " data:\n", "[1.0, -1.0, 3.0, -1.0],\n", " shape: [4],\n", " device: Cpu,\n", " backend: \"ndarray\",\n", " kind: \"Float\",\n", " dtype: \"f32\",\n", "}\n" ] } ], "source": [ "// Conditional selection\n", "let a: Tensor = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);\n", "\n", "// mask_where: where condition is true, use replacement value, else keep original value\n", "let condition = a.clone().greater_elem(4.0);\n", "let result = a.clone().mask_where(condition, Tensor::zeros([4], &device));\n", "println!(\"Original: {}\", a);\n", "println!(\"Where > 4, replace with 0: {}\", result);\n", "\n", "// mask_fill: simpler - just replace values matching condition\n", "let result = a.clone().mask_fill(a.clone().greater_elem(4.0), -1.0);\n", "println!(\"Where > 4, replace with -1: {}\", result);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In this notebook, we covered:\n", "- **Tensor Creation**: empty, zeros, ones, full, from_floats, random\n", "- **Shape Operations**: reshape, transpose, flatten, squeeze, unsqueeze\n", "- **Indexing and Slicing**: slice operation with ranges\n", "- **Math Operations**: add, sub, mul, div, matmul\n", "- **Element-wise Functions**: exp, log, powf_scalar, sin, cos, tan\n", "- **Reduction Operations**: sum, mean, prod, max, min\n", "- **Comparison**: greater, lower, equal, mask_where, mask_fill\n" ] } ], "metadata": { "kernelspec": { "display_name": "Rust", "language": "rust", "name": "rust" }, "language_info": { "codemirror_mode": "rust", "file_extension": ".rs", "mimetype": "text/rust", "name": "rust", "pygment_lexer": "rust", "version": "" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: examples/server/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] edition.workspace = true license.workspace = true name = "server" publish = false version.workspace = true [lints] workspace = true [features] default = ["webgpu"] cuda = ["burn/cuda"] webgpu = ["burn/webgpu"] vulkan = ["burn/vulkan"] ndarray = ["burn/ndarray"] [dependencies] cfg-if = { workspace = true } burn = { path = "../../crates/burn", version = "=0.21.0-pre.2", features = ["server"] } cubecl = { workspace = true } ================================================ FILE: examples/server/cubecl.toml ================================================ [profiling] logger = { log = "info", level = "disabled" } [autotune] logger = { log = "info", level = "disabled" } # logger = { log = "info", level = "full" } [compilation] logger = { log = "info", level = "disabled" } # logger = { log = "info", level = "full" } cache = "target" ================================================ FILE: examples/server/examples/server.rs ================================================ fn main() { server::start(); } ================================================ FILE: examples/server/src/lib.rs ================================================ #![recursion_limit = "141"] pub fn start() { let port = std::env::var("REMOTE_BACKEND_PORT") .map(|port| match port.parse::() { Ok(val) => val, Err(err) => panic!("Invalid port, got {port} with error {err}"), }) .unwrap_or(3000); cfg_if::cfg_if! { if #[cfg(feature = "ndarray")]{ burn::server::start_websocket::(Default::default(), port); } else if #[cfg(feature = "cuda")]{ burn::server::start_websocket::(Default::default(), port); } else if #[cfg(feature = "webgpu")] { burn::server::start_websocket::(Default::default(), port); } else if #[cfg(feature = "vulkan")] { burn::server::start_websocket::(Default::default(), port); } else { panic!("No backend selected, can't start server on port {port}"); } } } ================================================ FILE: examples/simple-regression/Cargo.toml ================================================ [package] authors = ["aasheeshsingh "] edition.workspace = true license.workspace = true name = "simple-regression" publish = false version.workspace = true [lints] workspace = true [features] default = ["burn/dataset", "burn/sqlite-bundled"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] remote = ["burn/remote"] [dependencies] burn = {path = "../../crates/burn", features=["train"]} # Serialization log = {workspace = true} serde = {workspace = true, features = ["std", "derive"]} # Displaying results textplots = "0.8.7" rgb = "0.8.52" ================================================ FILE: examples/simple-regression/README.md ================================================ # Regression The example shows you how to: - Define a custom dataset for regression problems. We implement the [California Housing Dataset](https://huggingface.co/datasets/gvlassis/california_housing) from HuggingFace hub. The dataset is also available as part of toy regression datasets in sklearn[datasets](https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset). - Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature scaling. - Define a Simple NN model for regression using Burn Modules. > **Note** > This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index) > library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/) > installed on your computer. The example can be run like so: ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. echo "Using ndarray backend" cargo run --example regression --release --features ndarray # CPU NdArray Backend - f32 - single thread cargo run --example regression --release --features ndarray-blas-openblas # CPU NdArray Backend - f32 - blas with openblas cargo run --example regression --release --features ndarray-blas-netlib # CPU NdArray Backend - f32 - blas with netlib echo "Using tch backend" export TORCH_CUDA_VERSION=cu128 # Set the cuda version cargo run --example regression --release --features tch-gpu # GPU Tch Backend - f32 cargo run --example regression --release --features tch-cpu # CPU Tch Backend - f32 echo "Using wgpu backend" cargo run --example regression --release --features wgpu ``` ================================================ FILE: examples/simple-regression/examples/regression.rs ================================================ use burn::{backend::Autodiff, tensor::backend::Backend}; use simple_regression::{inference, training}; static ARTIFACT_DIR: &str = "/tmp/burn-example-regression"; #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::ndarray::{NdArray, NdArrayDevice}; pub fn run() { let device = NdArrayDevice::Cpu; super::run::(device.clone()); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; super::run::(device); } } #[cfg(feature = "wgpu")] mod wgpu { use burn::backend::wgpu::{Wgpu, WgpuDevice}; pub fn run() { let device = WgpuDevice::default(); super::run::(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use simple_regression::training; pub fn run() { let device = LibTorchDevice::Cpu; super::run::(device); } } #[cfg(feature = "remote")] mod remote { use burn::backend::{RemoteBackend, remote::RemoteDevice}; pub fn run() { let device = RemoteDevice::default(); super::run::(device); } } /// Train a regression model and predict results on a number of samples. pub fn run(device: B::Device) { training::run::>(ARTIFACT_DIR, device.clone()); inference::infer::(ARTIFACT_DIR, device) } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "remote")] remote::run(); } ================================================ FILE: examples/simple-regression/src/dataset.rs ================================================ use burn::{ data::{ dataloader::batcher::Batcher, dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset}, }, prelude::*, }; pub const NUM_FEATURES: usize = 8; // Pre-computed statistics for the housing dataset features const FEATURES_MIN: [f32; NUM_FEATURES] = [0.4999, 1., 0.8461, 0.375, 3., 0.6923, 32.54, -124.35]; const FEATURES_MAX: [f32; NUM_FEATURES] = [ 15., 52., 141.9091, 34.0667, 35682., 1243.3333, 41.95, -114.31, ]; #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct HousingDistrictItem { /// Median income #[serde(rename = "MedInc")] pub median_income: f32, /// Median house age #[serde(rename = "HouseAge")] pub house_age: f32, /// Average number of rooms per household #[serde(rename = "AveRooms")] pub avg_rooms: f32, /// Average number of bedrooms per household #[serde(rename = "AveBedrms")] pub avg_bedrooms: f32, /// Block group population #[serde(rename = "Population")] pub population: f32, /// Average number of household members #[serde(rename = "AveOccup")] pub avg_occupancy: f32, /// Block group latitude #[serde(rename = "Latitude")] pub latitude: f32, /// Block group longitude #[serde(rename = "Longitude")] pub longitude: f32, /// Median house value (in 100 000$) #[serde(rename = "MedHouseVal")] pub median_house_value: f32, } pub struct HousingDataset { dataset: SqliteDataset, } impl Dataset for HousingDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index) } fn len(&self) -> usize { self.dataset.len() } } impl HousingDataset { pub fn train() -> Self { Self::new("train") } pub fn validation() -> Self { Self::new("validation") } pub fn test() -> Self { Self::new("test") } pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("gvlassis/california_housing") .dataset(split) .unwrap(); Self { dataset } } } /// Normalizer for the housing dataset. #[derive(Clone, Debug)] pub struct Normalizer { pub min: Tensor, pub max: Tensor, } impl Normalizer { /// Creates a new normalizer. pub fn new(device: &B::Device, min: &[f32], max: &[f32]) -> Self { let min = Tensor::::from_floats(min, device).unsqueeze(); let max = Tensor::::from_floats(max, device).unsqueeze(); Self { min, max } } /// Normalizes the input image according to the housing dataset min/max. pub fn normalize(&self, input: Tensor) -> Tensor { (input - self.min.clone()) / (self.max.clone() - self.min.clone()) } /// Returns a new normalizer on the given device. pub fn to_device(&self, device: &B::Device) -> Self { Self { min: self.min.clone().to_device(device), max: self.max.clone().to_device(device), } } } #[derive(Clone, Debug)] pub struct HousingBatcher { normalizer: Normalizer, } #[derive(Clone, Debug)] pub struct HousingBatch { pub inputs: Tensor, pub targets: Tensor, } impl HousingBatcher { pub fn new(device: B::Device) -> Self { Self { normalizer: Normalizer::new(&device, &FEATURES_MIN, &FEATURES_MAX), } } } impl Batcher> for HousingBatcher { fn batch(&self, items: Vec, device: &B::Device) -> HousingBatch { let mut inputs: Vec> = Vec::new(); for item in items.iter() { let input_tensor = Tensor::::from_floats( [ item.median_income, item.house_age, item.avg_rooms, item.avg_bedrooms, item.population, item.avg_occupancy, item.latitude, item.longitude, ], device, ); inputs.push(input_tensor.unsqueeze()); } let inputs = Tensor::cat(inputs, 0); let inputs = self.normalizer.to_device(device).normalize(inputs); let targets = items .iter() .map(|item| Tensor::::from_floats([item.median_house_value], device)) .collect(); let targets = Tensor::cat(targets, 0); HousingBatch { inputs, targets } } } ================================================ FILE: examples/simple-regression/src/inference.rs ================================================ use burn::{ data::{dataloader::batcher::Batcher, dataset::Dataset}, module::Module, record::{NoStdTrainingRecorder, Recorder}, tensor::backend::Backend, }; use rgb::RGB8; use textplots::{Chart, ColorPlot, Shape}; use crate::{ dataset::{HousingBatcher, HousingDataset, HousingDistrictItem}, model::{RegressionModelConfig, RegressionModelRecord}, }; pub fn infer(artifact_dir: &str, device: B::Device) { let record: RegressionModelRecord = NoStdTrainingRecorder::new() .load(format!("{artifact_dir}/model").into(), &device) .expect("Trained model should exist; run train first"); let model = RegressionModelConfig::new() .init(&device) .load_record(record); // Use a sample of 1000 items from the test split let dataset = HousingDataset::test(); let items: Vec = dataset.iter().take(1000).collect(); let batcher = HousingBatcher::new(device.clone()); let batch = batcher.batch(items.clone(), &device); let predicted = model.forward(batch.inputs); let targets = batch.targets; // Display the predicted vs expected values let predicted = predicted.squeeze_dim::<1>(1).into_data(); let expected = targets.into_data(); let points = predicted .iter::() .zip(expected.iter::()) .collect::>(); println!("Predicted vs. Expected Median House Value (in 100,000$)"); Chart::new_with_y_range(120, 60, 0., 5., 0., 5.) .linecolorplot( &Shape::Points(&points), RGB8 { r: 255, g: 85, b: 85, }, ) .display(); // Print a single numeric value as an example println!("Predicted {} Expected {}", points[0].0, points[0].1); } ================================================ FILE: examples/simple-regression/src/lib.rs ================================================ pub mod dataset; pub mod inference; pub mod model; pub mod training; ================================================ FILE: examples/simple-regression/src/model.rs ================================================ use crate::dataset::{HousingBatch, NUM_FEATURES}; use burn::{ nn::{ Linear, LinearConfig, Relu, loss::{MseLoss, Reduction::Mean}, }, prelude::*, tensor::backend::AutodiffBackend, train::{InferenceStep, RegressionOutput, TrainOutput, TrainStep}, }; #[derive(Module, Debug)] pub struct RegressionModel { input_layer: Linear, output_layer: Linear, activation: Relu, } #[derive(Config, Debug)] pub struct RegressionModelConfig { #[config(default = 64)] pub hidden_size: usize, } impl RegressionModelConfig { pub fn init(&self, device: &B::Device) -> RegressionModel { let input_layer = LinearConfig::new(NUM_FEATURES, self.hidden_size) .with_bias(true) .init(device); let output_layer = LinearConfig::new(self.hidden_size, 1) .with_bias(true) .init(device); RegressionModel { input_layer, output_layer, activation: Relu::new(), } } } impl RegressionModel { pub fn forward(&self, input: Tensor) -> Tensor { let x = self.input_layer.forward(input); let x = self.activation.forward(x); self.output_layer.forward(x) } pub fn forward_step(&self, item: HousingBatch) -> RegressionOutput { let targets: Tensor = item.targets.unsqueeze_dim(1); let output: Tensor = self.forward(item.inputs); let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean); RegressionOutput { loss, output, targets, } } } impl TrainStep for RegressionModel { type Input = HousingBatch; type Output = RegressionOutput; fn step(&self, item: HousingBatch) -> TrainOutput> { let item = self.forward_step(item); TrainOutput::new(self, item.loss.backward(), item) } } impl InferenceStep for RegressionModel { type Input = HousingBatch; type Output = RegressionOutput; fn step(&self, item: HousingBatch) -> RegressionOutput { self.forward_step(item) } } ================================================ FILE: examples/simple-regression/src/training.rs ================================================ use crate::dataset::{HousingBatcher, HousingDataset}; use crate::model::RegressionModelConfig; use burn::optim::AdamConfig; use burn::train::{Learner, SupervisedTraining}; use burn::{ data::{dataloader::DataLoaderBuilder, dataset::Dataset}, prelude::*, record::{CompactRecorder, NoStdTrainingRecorder}, tensor::backend::AutodiffBackend, train::metric::LossMetric, }; #[derive(Config, Debug)] pub struct ExpConfig { #[config(default = 100)] pub num_epochs: usize, #[config(default = 2)] pub num_workers: usize, #[config(default = 1337)] pub seed: u64, pub optimizer: AdamConfig, #[config(default = 256)] pub batch_size: usize, } fn create_artifact_dir(artifact_dir: &str) { // Remove existing artifacts before to get an accurate learner summary std::fs::remove_dir_all(artifact_dir).ok(); std::fs::create_dir_all(artifact_dir).ok(); } pub fn run(artifact_dir: &str, device: B::Device) { create_artifact_dir(artifact_dir); // Config let optimizer = AdamConfig::new(); let config = ExpConfig::new(optimizer); let model = RegressionModelConfig::new().init(&device); B::seed(&device, config.seed); // Define train/valid datasets and dataloaders let train_dataset = HousingDataset::train(); let valid_dataset = HousingDataset::validation(); println!("Train Dataset Size: {}", train_dataset.len()); println!("Valid Dataset Size: {}", valid_dataset.len()); let batcher_train = HousingBatcher::::new(device.clone()); let batcher_test = HousingBatcher::::new(device.clone()); let dataloader_train = DataLoaderBuilder::new(batcher_train) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(train_dataset); let dataloader_test = DataLoaderBuilder::new(batcher_test) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(valid_dataset); // Model let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .num_epochs(config.num_epochs) .summary(); let result = training.launch(Learner::new(model, config.optimizer.init(), 1e-3)); config .save(format!("{artifact_dir}/config.json").as_str()) .unwrap(); result .model .save_file( format!("{artifact_dir}/model"), &NoStdTrainingRecorder::new(), ) .expect("Failed to save trained model"); } ================================================ FILE: examples/text-classification/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] edition.workspace = true license.workspace = true name = "text-classification" publish = false version.workspace = true [lints] workspace = true [features] default = [] f16 = [] flex32 = [] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] vulkan = ["burn/vulkan"] remote = ["burn/remote"] cuda = ["burn/cuda"] rocm = ["burn/rocm"] metal = ["burn/metal"] ddp = ["burn/collective"] [dependencies] # Burn burn = { path = "../../crates/burn", features = [ "train", "tui", "sqlite-bundled", "metrics", "ndarray", "autotune", # "fusion", "std", ], default-features = false } log = { workspace = true } # Tokenizer tokenizers = { version = "0.22.2", default-features = false, features = [ "onig", "http", ] } # Utils derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } ================================================ FILE: examples/text-classification/README.md ================================================ # Text Classification This project provides an example implementation for training and inferencing text classification models on AG News and DbPedia datasets using the Rust-based Burn Deep Learning Library. > **Note** > This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index) > library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/) > installed on your computer. ## Dataset Details - AG News: The AG News dataset is a collection of news articles from more than 2000 news sources. This library helps you load and process this dataset, categorizing articles into four classes: "World", "Sports", "Business", and "Technology". - DbPedia: The DbPedia dataset is a large multi-class text classification dataset extracted from Wikipedia. This library helps you load and process this dataset, categorizing articles into 14 classes including "Company", "Educational Institution", "Artist", among others. # Usage ## Torch GPU backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # Use the f16 feature if your CUDA device supports FP16 (half precision) operations. May not work well on every device. export TORCH_CUDA_VERSION=cu128 # Set the cuda version (CUDA users) # AG News cargo run --example ag-news-train --release --features tch-gpu # Train on the ag news dataset cargo run --example ag-news-infer --release --features tch-gpu # Run inference on the ag news dataset # DbPedia cargo run --example db-pedia-train --release --features tch-gpu # Train on the db pedia dataset cargo run --example db-pedia-infer --release --features tch-gpu # Run inference db pedia dataset ``` ## Torch CPU backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # AG News cargo run --example ag-news-train --release --features tch-cpu # Train on the ag news dataset cargo run --example ag-news-infer --release --features tch-cpu # Run inference on the ag news dataset # DbPedia cargo run --example db-pedia-train --release --features tch-cpu # Train on the db pedia dataset cargo run --example db-pedia-infer --release --features tch-cpu # Run inference db pedia dataset ``` ## ndarray backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # Replace ndarray by ndarray-blas-netlib, ndarray-blas-openblas or ndarray-blas-accelerate for different matmul techniques # AG News cargo run --example ag-news-train --release --features ndarray # Train on the ag news dataset cargo run --example ag-news-infer --release --features ndarray # Run inference on the ag news dataset # DbPedia cargo run --example db-pedia-train --release --features ndarray # Train on the db pedia dataset cargo run --example db-pedia-infer --release --features ndarray # Run inference db pedia dataset ``` ## WGPU backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # AG News cargo run --example ag-news-train --release --features wgpu # Train on the ag news dataset cargo run --example ag-news-infer --release --features wgpu # Run inference on the ag news dataset # DbPedia cargo run --example db-pedia-train --release --features wgpu # Train on the db pedia dataset cargo run --example db-pedia-infer --release --features wgpu # Run inference db pedia dataset ``` ## CUDA backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # Add the f16 feature to run in f16. # AG News cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset ``` ## Metal backend ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. # Add the f16 feature to run in f16. # AG News cargo run --example ag-news-train --release --features metal # Train on the ag news dataset cargo run --example ag-news-infer --release --features metal # Run inference on the ag news dataset ``` ================================================ FILE: examples/text-classification/cubecl.toml ================================================ [profiling] logger = { log = "info", level = "disabled" } [autotune] level = "balanced" cache = "target" logger = { info = true, level = "full" } [compilation] logger = { level = "disabled" } cache = "target" [memory] logger = { level = "disabled", file = "/tmp/memory.log" } persistent_memory = "enabled" [streaming] max_streams = 8 ================================================ FILE: examples/text-classification/examples/ag-news-infer.rs ================================================ #![recursion_limit = "256"] use burn::tensor::backend::Backend; use text_classification::AgNewsDataset; #[cfg(not(feature = "f16"))] #[allow(dead_code)] type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-ag-news", // Samples from the test dataset, but you are free to test with your own text. vec![ "Jays power up to take finale Contrary to popular belief, the power never really \ snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \ took some extra time for the batting orders to provide some extra wattage." .to_string(), "Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \ man to death and 14 others to prison terms for a series of attacks and terrorist \ plots in 2002, including the bombing of a French oil tanker." .to_string(), "IBM puts grids to work at U.S. Open IBM will put a collection of its On \ Demand-related products and technologies to this test next week at the U.S. Open \ tennis championships, implementing a grid-based infrastructure capable of running \ multiple workloads including two not associated with the tournament." .to_string(), ], ); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{ElemType, launch}; pub fn run() { launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use crate::{ElemType, launch}; use burn::backend::libtorch::{LibTorch, LibTorchDevice}; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use crate::{ElemType, launch}; use burn::backend::libtorch::{LibTorch, LibTorchDevice}; pub fn run() { launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::{ElemType, launch}; use burn::backend::wgpu::{Wgpu, WgpuDevice}; pub fn run() { launch::>(WgpuDevice::default()); } } #[cfg(feature = "metal")] mod metal { use crate::{ElemType, launch}; use burn::backend::metal::{Metal, MetalDevice}; pub fn run() { launch::>(MetalDevice::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::{ElemType, launch}; use burn::backend::{Cuda, cuda::CudaDevice}; pub fn run() { launch::>(CudaDevice::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); } ================================================ FILE: examples/text-classification/examples/ag-news-train.rs ================================================ #![recursion_limit = "256"] use burn::{ nn::transformer::TransformerEncoderConfig, optim::{AdamConfig, decay::WeightDecayConfig}, prelude::*, tensor::backend::{AutodiffBackend, DeviceId}, }; use text_classification::{AgNewsDataset, training::ExperimentConfig}; #[cfg(not(any(feature = "f16", feature = "flex32")))] #[allow(unused)] type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; #[cfg(feature = "flex32")] type ElemType = burn::tensor::flex32; pub fn launch_multi() { let type_id = 0; let num_devices = B::Device::device_count(type_id); let devices = (0..num_devices) .map(|i| B::Device::from_id(DeviceId::new(type_id, i as u32))) .collect(); launch::(devices) } pub fn launch(devices: Vec) { let config = ExperimentConfig::new( TransformerEncoderConfig::new(256, 1024, 8, 4) .with_norm_first(true) .with_quiet_softmax(true), AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), ); text_classification::training::train::( devices, AgNewsDataset::train(), AgNewsDataset::test(), config, "/tmp/text-classification-ag-news", ); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::{ Autodiff, ndarray::{NdArray, NdArrayDevice}, }; use crate::{ElemType, launch}; pub fn run() { launch::>>(vec![NdArrayDevice::Cpu]); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use crate::{ElemType, launch}; use burn::backend::autodiff::checkpoint::strategy::BalancedCheckpointing; use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>>(vec![device]); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::{ElemType, launch}; pub fn run() { launch::>>(vec![LibTorchDevice::Cpu]); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::{ElemType, launch}; use burn::backend::{Autodiff, Wgpu}; pub fn run() { launch::>>(vec![Default::default()]); } } #[cfg(feature = "vulkan")] mod vulkan { use crate::{ElemType, launch}; use burn::backend::{Autodiff, Vulkan, autodiff::checkpoint::strategy::BalancedCheckpointing}; pub fn run() { type B = Autodiff, BalancedCheckpointing>; launch::(vec![Default::default()]); } } #[cfg(feature = "metal")] mod metal { use crate::{ElemType, launch}; use burn::backend::{Autodiff, Metal}; pub fn run() { launch::>>(vec![Default::default()]); } } #[cfg(feature = "remote")] mod remote { use crate::{ElemType, launch}; use burn::backend::{Autodiff, RemoteBackend}; pub fn run() { launch::>(vec![Default::default()]); } } #[cfg(feature = "cuda")] mod cuda { use crate::{ElemType, launch_multi}; use burn::backend::{Autodiff, Cuda, autodiff::checkpoint::strategy::BalancedCheckpointing}; pub fn run() { launch_multi::, BalancedCheckpointing>>(); } } #[cfg(feature = "rocm")] mod rocm { use crate::{ElemType, launch}; use burn::backend::{Autodiff, Rocm, autodiff::checkpoint::strategy::BalancedCheckpointing}; pub fn run() { launch::, BalancedCheckpointing>>(vec![Default::default()]); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); #[cfg(feature = "rocm")] rocm::run(); #[cfg(feature = "remote")] remote::run(); #[cfg(feature = "vulkan")] vulkan::run(); #[cfg(feature = "metal")] metal::run(); } ================================================ FILE: examples/text-classification/examples/db-pedia-infer.rs ================================================ use text_classification::DbPediaDataset; use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", // Samples from the test dataset, but you are free to test with your own text. vec![ " Magnus Eriksson is a Swedish former footballer who played as a forward.".to_string(), "Crossbeam Systems is headquartered in Boxborough Massachusetts and has offices in \ Europe Latin America and Asia Pacific. Crossbeam Systems was acquired by Blue Coat \ Systems in December 2012 and the Crossbeam brand has been fully absorbed into Blue \ Coat." .to_string(), " Zia is the sequel to the award-winning Island of the Blue Dolphins by Scott O'Dell. \ It was published in 1976 sixteen years after the publication of the first novel." .to_string(), ], ); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{ElemType, launch}; pub fn run() { launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{ElemType, launch}; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{ElemType, launch}; pub fn run() { launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{ElemType, launch}; pub fn run() { launch::>(WgpuDevice::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); } ================================================ FILE: examples/text-classification/examples/db-pedia-train.rs ================================================ use burn::{ nn::transformer::TransformerEncoderConfig, optim::{AdamConfig, decay::WeightDecayConfig}, tensor::backend::AutodiffBackend, }; use text_classification::{DbPediaDataset, training::ExperimentConfig}; #[cfg(not(feature = "f16"))] #[allow(dead_code)] type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; pub fn launch(devices: Vec) { let config = ExperimentConfig::new( TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true), AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))), ); text_classification::training::train::( devices, DbPediaDataset::train(), DbPediaDataset::test(), config, "/tmp/text-classification-db-pedia", ); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use crate::{ElemType, launch}; use burn::backend::{ Autodiff, ndarray::{NdArray, NdArrayDevice}, }; pub fn run() { launch::>>(vec![NdArrayDevice::Cpu]); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::{ElemType, launch}; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>>(vec![device]); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::{ElemType, launch}; pub fn run() { launch::>>(vec![LibTorchDevice::Cpu]); } } #[cfg(feature = "wgpu")] mod wgpu { use burn::backend::{ Autodiff, wgpu::{Wgpu, WgpuDevice}, }; use crate::{ElemType, launch}; pub fn run() { launch::>>(vec![WgpuDevice::default()]); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); } ================================================ FILE: examples/text-classification/src/data/batcher.rs ================================================ // The module defines two structs TextClassificationTrainingBatch and TextClassificationInferenceBatch // to handle batches of data during training and inference respectively. The TextClassificationBatcher // struct is implemented for creating these batches. It is parameterized on the type B: Backend to // support different computation backends (e.g., CPU, CUDA). // Two implementations of the Batcher trait are provided for TextClassificationBatcher, one for creating // training batches and one for creating inference batches. In each implementation, the batch function is // defined to convert a vector of items into a batch. For training, the items are instances of // TextClassificationItem and include both the text and the corresponding label. // For inference, the items are simply strings without labels. The function tokenizes the text, // generates a padding mask, and returns a batch object. use super::{dataset::TextClassificationItem, tokenizer::Tokenizer}; use burn::{ data::dataloader::batcher::Batcher, nn::attention::{SeqLengthOption, generate_padding_mask}, prelude::*, }; use std::sync::Arc; /// Struct for batching text classification items #[derive(Clone, new)] pub struct TextClassificationBatcher { tokenizer: Arc, // Tokenizer for converting text to token IDs seq_length: SeqLengthOption, // Sequence length option for tokenized text } /// Struct for training batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationTrainingBatch { pub tokens: Tensor, // Tokenized text pub labels: Tensor, // Labels of the text pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Struct for inference batch in text classification task #[derive(Debug, Clone, new)] pub struct TextClassificationInferenceBatch { pub tokens: Tensor, // Tokenized text pub mask_pad: Tensor, // Padding mask for the tokenized text } /// Implement Batcher trait for TextClassificationBatcher struct for training impl Batcher> for TextClassificationBatcher { /// Batches a vector of text classification items into a training batch fn batch( &self, items: Vec, device: &B::Device, ) -> TextClassificationTrainingBatch { let mut tokens_list = Vec::with_capacity(items.len()); let mut labels_list = Vec::with_capacity(items.len()); // Tokenize text and create label tensor for each item for item in items { tokens_list.push(self.tokenizer.encode(&item.text)); labels_list.push(Tensor::from_data( TensorData::from([(item.label as i64).elem::()]), device, )); } // Generate padding mask for tokenized text let mask = generate_padding_mask( self.tokenizer.pad_token(), tokens_list, self.seq_length, device, ); // Create and return training batch TextClassificationTrainingBatch { tokens: mask.tensor, labels: Tensor::cat(labels_list, 0), mask_pad: mask.mask, } } } /// Implement Batcher trait for TextClassificationBatcher struct for inference impl Batcher> for TextClassificationBatcher { /// Batches a vector of strings into an inference batch fn batch(&self, items: Vec, device: &B::Device) -> TextClassificationInferenceBatch { let mut tokens_list = Vec::with_capacity(items.len()); // Tokenize each string for item in items { tokens_list.push(self.tokenizer.encode(&item)); } // Generate padding mask for tokenized text let mask = generate_padding_mask( self.tokenizer.pad_token(), tokens_list, self.seq_length, device, ); // Create and return inference batch TextClassificationInferenceBatch { tokens: mask.tensor.to_device(device), mask_pad: mask.mask.to_device(device), } } } ================================================ FILE: examples/text-classification/src/data/dataset.rs ================================================ // The AgNewsDataset and DbPediaDataset structs are examples of specific text // classification datasets. Each dataset struct has a field for the underlying // SQLite dataset and implements methods for accessing and processing the data. // Each dataset is also provided with specific information about its classes via // the TextClassificationDataset trait. These implementations are designed to be used // with a machine learning framework for tasks such as training a text classification model. use burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader}; // Define a struct for text classification items #[derive(new, Clone, Debug)] pub struct TextClassificationItem { pub text: String, // The text for classification pub label: usize, // The label of the text (classification category) } // Trait for text classification datasets pub trait TextClassificationDataset: Dataset { fn num_classes() -> usize; // Returns the number of unique classes in the dataset fn class_name(label: usize) -> String; // Returns the name of the class given its label } // Struct for items in the AG News dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct AgNewsItem { pub text: String, // The text for classification pub label: usize, // The label of the text (classification category) } // Struct for the AG News dataset pub struct AgNewsDataset { dataset: SqliteDataset, // Underlying SQLite dataset } // Implement the Dataset trait for the AG News dataset impl Dataset for AgNewsDataset { /// Returns a specific item from the dataset fn get(&self, index: usize) -> Option { self.dataset .get(index) .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems } /// Returns the length of the dataset fn len(&self) -> usize { self.dataset.len() } } // Implement methods for constructing the AG News dataset impl AgNewsDataset { /// Returns the training portion of the dataset pub fn train() -> Self { Self::new("train") } /// Returns the testing portion of the dataset pub fn test() -> Self { Self::new("test") } /// Constructs the dataset from a split (either "train" or "test") pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("ag_news") .dataset(split) .unwrap(); Self { dataset } } } /// Implements the TextClassificationDataset trait for the AG News dataset impl TextClassificationDataset for AgNewsDataset { /// Returns the number of unique classes in the dataset fn num_classes() -> usize { 4 } /// Returns the name of a class given its label fn class_name(label: usize) -> String { match label { 0 => "World", 1 => "Sports", 2 => "Business", 3 => "Technology", _ => panic!("invalid class"), } .to_string() } } /// Struct for items in the DbPedia dataset #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { pub title: String, // The title of the item pub content: String, // The content of the item pub label: usize, // The label of the item (classification category) } /// Struct for the DbPedia dataset pub struct DbPediaDataset { dataset: SqliteDataset, // Underlying SQLite dataset } /// Implements the Dataset trait for the DbPedia dataset impl Dataset for DbPediaDataset { /// Returns a specific item from the dataset fn get(&self, index: usize) -> Option { self.dataset.get(index).map(|item| { TextClassificationItem::new( format!("Title: {} - Content: {}", item.title, item.content), item.label, ) }) } /// Returns the length of the dataset fn len(&self) -> usize { self.dataset.len() } } /// Implement methods for constructing the DbPedia dataset impl DbPediaDataset { /// Returns the training portion of the dataset pub fn train() -> Self { Self::new("train") } /// Returns the testing portion of the dataset pub fn test() -> Self { Self::new("test") } /// Constructs the dataset from a split (either "train" or "test") pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset(split) .unwrap(); Self { dataset } } } /// Implement the TextClassificationDataset trait for the DbPedia dataset impl TextClassificationDataset for DbPediaDataset { /// Returns the number of unique classes in the dataset fn num_classes() -> usize { 14 } /// Returns the name of a class given its label fn class_name(label: usize) -> String { match label { 0 => "Company", 1 => "EducationalInstitution", 2 => "Artist", 3 => "Athlete", 4 => "OfficeHolder", 5 => "MeanOfTransportation", 6 => "Building", 7 => "NaturalPlace", 8 => "Village", 9 => "Animal", 10 => "Plant", 11 => "Album", 12 => "Film", 13 => "WrittenWork", _ => panic!("invalid class"), } .to_string() } } ================================================ FILE: examples/text-classification/src/data/mod.rs ================================================ mod batcher; mod dataset; mod tokenizer; pub use batcher::*; pub use dataset::*; pub use tokenizer::*; ================================================ FILE: examples/text-classification/src/data/tokenizer.rs ================================================ // This module defines a trait `Tokenizer` that represents a common interface for all tokenizer // types used in the text classification library. A specific implementation of this trait, // `BertCasedTokenizer`, uses the BERT cased tokenization strategy provided by the `tokenizers` library. // This trait represents the common interface for all tokenizer types. // The `Send + Sync` bounds are necessary for allowing these operations // to work across thread boundaries. #[allow(dead_code)] pub trait Tokenizer: Send + Sync { /// Converts a text string into a sequence of tokens. fn encode(&self, value: &str) -> Vec; /// Converts a sequence of tokens back into a text string. fn decode(&self, tokens: &[usize]) -> String; /// Gets the size of the tokenizer's vocabulary. fn vocab_size(&self) -> usize; /// Gets the token used for padding sequences to a consistent length. fn pad_token(&self) -> usize; /// Gets the string representation of the padding token. /// The default implementation uses `decode` on the padding token. fn pad_token_value(&self) -> String { self.decode(&[self.pad_token()]) } } /// Struct represents a specific tokenizer using the BERT cased tokenization strategy. pub struct BertCasedTokenizer { // The underlying tokenizer from the `tokenizers` library. tokenizer: tokenizers::Tokenizer, } // Default implementation for creating a new BertCasedTokenizer. // This uses a pretrained BERT cased tokenizer model. impl Default for BertCasedTokenizer { fn default() -> Self { Self { tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), } } } // Implementation of the Tokenizer trait for BertCasedTokenizer. impl Tokenizer for BertCasedTokenizer { // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy. fn encode(&self, value: &str) -> Vec { let tokens = self.tokenizer.encode(value, true).unwrap(); tokens.get_ids().iter().map(|t| *t as usize).collect() } /// Converts a sequence of tokens back into a text string. fn decode(&self, tokens: &[usize]) -> String { let tokens = tokens.iter().map(|t| *t as u32).collect::>(); self.tokenizer.decode(&tokens, false).unwrap() } /// Gets the size of the BERT cased tokenizer's vocabulary. fn vocab_size(&self) -> usize { self.tokenizer.get_vocab_size(true) } /// Gets the token used for padding sequences to a consistent length. fn pad_token(&self) -> usize { self.tokenizer.token_to_id("[PAD]").unwrap() as usize } } ================================================ FILE: examples/text-classification/src/inference.rs ================================================ // This module defines the inference process for a text classification model. // It loads a model and its configuration from a directory, and uses a tokenizer // and a batcher to prepare the input data. The model is then used to make predictions // on the input samples, and the results are printed out for each sample. // Import required modules and types use crate::{ data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, model::TextClassificationModelConfig, training::ExperimentConfig, }; use burn::{ data::dataloader::batcher::Batcher, prelude::*, record::{CompactRecorder, Recorder}, }; use std::sync::Arc; // Define inference function pub fn infer( device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device) artifact_dir: &str, // Directory containing model and config files samples: Vec, // Text samples for inference ) { // Load experiment configuration let config = ExperimentConfig::load(format!("{artifact_dir}/config.json").as_str()) .expect("Config file present"); // Initialize tokenizer let tokenizer = Arc::new(BertCasedTokenizer::default()); // Get number of classes from dataset let n_classes = D::num_classes(); // Initialize batcher for batching samples let batcher = Arc::new(TextClassificationBatcher::new( tokenizer.clone(), config.seq_length, )); // Load pre-trained model weights println!("Loading weights ..."); let record = CompactRecorder::new() .load(format!("{artifact_dir}/model").into(), &device) .expect("Trained model weights tb"); // Create model using loaded weights println!("Creating model ..."); let model = TextClassificationModelConfig::new( config.transformer, n_classes, tokenizer.vocab_size(), config.seq_length, ) .init::(&device) .load_record(record); // Initialize model with loaded weights // Run inference on the given text samples println!("Running inference ..."); let item = batcher.batch(samples.clone(), &device); // Batch samples using the batcher let predictions = model.infer(item); // Get model predictions // Print out predictions for each sample for (i, text) in samples.into_iter().enumerate() { #[allow(clippy::single_range_in_vec_init)] let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample let logits = prediction.to_data(); // Convert prediction tensor to data let class_index = prediction.argmax(1).squeeze_dim::<1>(1).into_scalar(); // Get class index with the highest value let class = D::class_name(class_index.elem::() as usize); // Get class name // Print sample text, predicted logits and predicted class println!( "\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: \ {class}\n================" ); } } ================================================ FILE: examples/text-classification/src/lib.rs ================================================ #[macro_use] extern crate derive_new; pub mod data; pub mod inference; pub mod model; pub mod training; pub use data::{AgNewsDataset, DbPediaDataset, TextClassificationDataset}; ================================================ FILE: examples/text-classification/src/model.rs ================================================ // This is a basic text classification model implemented in Rust using the Burn framework. // It uses a Transformer as the base model and applies Linear and Embedding layers. // The model is then trained using Cross-Entropy loss. It contains methods for model initialization // (both with and without pre-trained weights), forward pass, inference, training, and validation. use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch}; use burn::{ nn::{ Embedding, EmbeddingConfig, Linear, LinearConfig, attention::SeqLengthOption, loss::CrossEntropyLossConfig, transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, }, prelude::*, tensor::{activation::softmax, backend::AutodiffBackend}, train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep}, }; // Define the model configuration #[derive(Config, Debug)] pub struct TextClassificationModelConfig { transformer: TransformerEncoderConfig, n_classes: usize, vocab_size: usize, seq_length: SeqLengthOption, } // Define the model structure #[derive(Module, Debug)] pub struct TextClassificationModel { transformer: TransformerEncoder, embedding_token: Embedding, embedding_pos: Embedding, output: Linear, n_classes: usize, } // Define functions for model initialization impl TextClassificationModelConfig { /// Initializes a model with default weights pub fn init(&self, device: &B::Device) -> TextClassificationModel { let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(device); let transformer = self.transformer.init(device); let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device); let max_seq_length = match self.seq_length { SeqLengthOption::Fixed(max) | SeqLengthOption::Max(max) => max, SeqLengthOption::NoMax => panic!( "Text classification requires a max sequence length because of the embedding strategy." ), }; let embedding_pos = EmbeddingConfig::new(max_seq_length, self.transformer.d_model).init(device); TextClassificationModel { transformer, embedding_token, embedding_pos, output, n_classes: self.n_classes, } } } /// Define model behavior impl TextClassificationModel { // Defines forward pass for training pub fn forward(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { // Get batch and sequence length, and the device let [batch_size, seq_length] = item.tokens.dims(); let device = &self.embedding_token.devices()[0]; // Move tensors to the correct device let tokens = item.tokens.to_device(device); let labels = item.labels.to_device(device); let mask_pad = item.mask_pad.to_device(device); // Calculate token and position embeddings, and combine them let index_positions = Tensor::arange(0..seq_length as i64, device) .reshape([1, seq_length]) .repeat_dim(0, batch_size); let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_tokens = self.embedding_token.forward(tokens); let embedding = (embedding_positions + embedding_tokens) / 2; // Perform transformer encoding, calculate output and loss let encoded = self .transformer .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); let output = self.output.forward(encoded); let output_classification = output .slice([0..batch_size, 0..1]) .reshape([batch_size, self.n_classes]); let loss = CrossEntropyLossConfig::new() .init(&output_classification.device()) .forward(output_classification.clone(), labels.clone()); // Return the output and loss ClassificationOutput { loss, output: output_classification, targets: labels, } } /// Defines forward pass for inference pub fn infer(&self, item: TextClassificationInferenceBatch) -> Tensor { // Get batch and sequence length, and the device let [batch_size, seq_length] = item.tokens.dims(); let device = &self.embedding_token.devices()[0]; // Move tensors to the correct device let tokens = item.tokens.to_device(device); let mask_pad = item.mask_pad.to_device(device); // Calculate token and position embeddings, and combine them let index_positions = Tensor::arange(0..seq_length as i64, device) .reshape([1, seq_length]) .repeat_dim(0, batch_size); let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_tokens = self.embedding_token.forward(tokens); let embedding = (embedding_positions + embedding_tokens) / 2; // Perform transformer encoding, calculate output and apply softmax for prediction let encoded = self .transformer .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad)); let output = self.output.forward(encoded); let output = output .slice([0..batch_size, 0..1]) .reshape([batch_size, self.n_classes]); softmax(output, 1) } } /// Define training step impl TrainStep for TextClassificationModel { type Input = TextClassificationTrainingBatch; type Output = ClassificationOutput; fn step( &self, item: TextClassificationTrainingBatch, ) -> TrainOutput> { // Run forward pass, calculate gradients and return them along with the output let item = self.forward(item); let grads = item.loss.backward(); TrainOutput::new(self, grads, item) } } /// Define validation step impl InferenceStep for TextClassificationModel { type Input = TextClassificationTrainingBatch; type Output = ClassificationOutput; fn step(&self, item: TextClassificationTrainingBatch) -> ClassificationOutput { // Run forward pass and return the output self.forward(item) } } ================================================ FILE: examples/text-classification/src/training.rs ================================================ // This module trains a text classification model using the provided training and testing datasets, // as well as the provided configuration. It first initializes a tokenizer and batchers for the datasets, // then initializes the model and data loaders for the datasets. The function then initializes // an optimizer and a learning rate scheduler, and uses them along with the model and datasets // to build a learner, which is used to train the model. The trained model and the configuration are // then saved to the specified directory. use crate::{ data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer}, model::TextClassificationModelConfig, }; #[cfg(feature = "ddp")] use burn::collective::{AllReduceStrategy, CollectiveConfig}; use burn::train::{Learner, SupervisedTraining}; #[cfg(not(feature = "ddp"))] use burn::{ data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset}, lr_scheduler::noam::NoamLrSchedulerConfig, nn::{attention::SeqLengthOption, transformer::TransformerEncoderConfig}, optim::AdamConfig, prelude::*, record::{CompactRecorder, Recorder}, tensor::backend::AutodiffBackend, train::{ MultiDeviceOptim, metric::{ AccuracyMetric, CudaMetric, IterationSpeedMetric, LearningRateMetric, LossMetric, }, }, }; use std::sync::Arc; // Define configuration struct for the experiment #[derive(Config, Debug)] pub struct ExperimentConfig { pub transformer: TransformerEncoderConfig, pub optimizer: AdamConfig, #[config(default = "SeqLengthOption::Fixed(256)")] pub seq_length: SeqLengthOption, #[config(default = 16)] pub batch_size: usize, #[config(default = 5)] pub num_epochs: usize, } // Define train function pub fn train( devices: Vec, // Device on which to perform computation (e.g., CPU or CUDA device) dataset_train: D, // Training dataset dataset_test: D, // Testing dataset config: ExperimentConfig, // Experiment configuration artifact_dir: &str, // Directory to save model and config files ) { // Initialize tokenizer let tokenizer = Arc::new(BertCasedTokenizer::default()); // Initialize batcher let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.seq_length); // Initialize model let model = TextClassificationModelConfig::new( config.transformer.clone(), D::num_classes(), tokenizer.vocab_size(), config.seq_length, ) .init::(&devices[0]); // Initialize data loaders for training and testing data let dataloader_train = DataLoaderBuilder::new(batcher.clone()) .batch_size(config.batch_size) .num_workers(1) .build(SamplerDataset::new(dataset_train, 50_000)); let dataloader_test = DataLoaderBuilder::new(batcher) .batch_size(config.batch_size) .num_workers(1) .build(SamplerDataset::new(dataset_test, 5_000)); // Initialize optimizer let optim = config.optimizer.init(); // Initialize learning rate scheduler let lr_scheduler = NoamLrSchedulerConfig::new(1e-2) .with_warmup_steps(1000) .with_model_size(config.transformer.d_model) .init() .unwrap(); // Initialize learner #[cfg(not(feature = "ddp"))] let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) .metric_train(CudaMetric::new()) .metric_valid(CudaMetric::new()) .metric_train(IterationSpeedMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .metric_train_numeric(AccuracyMetric::new()) .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .num_epochs(config.num_epochs) .summary() .with_training_strategy(burn::train::TrainingStrategy::MultiDevice( devices, MultiDeviceOptim::OptimSharded, )); #[cfg(feature = "ddp")] let collective_config = CollectiveConfig::default().with_local_all_reduce_strategy(AllReduceStrategy::Tree(2)); #[cfg(feature = "ddp")] let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) .metric_train(CudaMetric::new()) .metric_valid(CudaMetric::new()) .metric_train(IterationSpeedMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .metric_train_numeric(AccuracyMetric::new()) .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .with_training_strategy(burn::train::ddp(devices, collective_config)) .num_epochs(config.num_epochs) .summary(); // Train the model let result = training.launch(Learner::new(model, optim, lr_scheduler)); // Save the configuration and the trained model config.save(format!("{artifact_dir}/config.json")).unwrap(); CompactRecorder::new() .record( result.model.into_record(), format!("{artifact_dir}/model").into(), ) .unwrap(); } ================================================ FILE: examples/text-generation/Cargo.toml ================================================ [package] authors = ["nathanielsimard "] edition.workspace = true license.workspace = true name = "text-generation" publish = false version.workspace = true [lints] workspace = true [features] default = ["burn/dataset", "burn/sqlite-bundled"] f16 = [] [dependencies] # Burn burn = {path = "../../crates/burn", features=["train", "tch"]} # Tokenizer tokenizers = {version = "0.22.2", default-features = false, features = [ "onig", "http", ]} # Utils derive-new = {workspace = true} log = {workspace = true} serde = {workspace = true, features = ["std", "derive"]} ================================================ FILE: examples/text-generation/README.md ================================================ # Text Generation > **Note** > This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index) > library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/) > installed on your computer. The example can be run like so: ## CUDA users ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. export TORCH_CUDA_VERSION=cu128 cargo run --example text-generation --release ``` ## Mac users ```bash git clone https://github.com/tracel-ai/burn.git cd burn # Use the --release flag to really speed up training. cargo run --example text-generation --release ``` ================================================ FILE: examples/text-generation/examples/text-generation.rs ================================================ use burn::optim::decay::WeightDecayConfig; use text_generation::{DbPediaDataset, training::ExperimentConfig}; #[cfg(feature = "f16")] type Elem = burn::tensor::f16; #[cfg(not(feature = "f16"))] type Elem = f32; type Backend = burn::backend::Autodiff>; fn main() { let config = ExperimentConfig::new( burn::nn::transformer::TransformerEncoderConfig::new(384, 1536, 12, 6) .with_norm_first(true), burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))), ); text_generation::training::train::( if cfg!(target_os = "macos") { burn::tensor::Device::::Mps } else { burn::tensor::Device::::Cuda(0) }, DbPediaDataset::train(), DbPediaDataset::test(), config, "/tmp/text-generation", ); } ================================================ FILE: examples/text-generation/src/data/batcher.rs ================================================ use super::{dataset::TextGenerationItem, tokenizer::Tokenizer}; use burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*}; use std::sync::Arc; #[derive(Clone, new)] pub struct TextGenerationBatcher { tokenizer: Arc, max_seq_length: usize, } #[derive(Debug, Clone, new)] pub struct TextGenerationBatch { pub tokens: Tensor, pub mask_pad: Tensor, } #[derive(Debug, Clone, new)] pub struct TrainingTextGenerationBatch { pub tokens_inputs: Tensor, pub targets: Tensor, pub mask_pad: Tensor, } impl Batcher> for TextGenerationBatcher { fn batch(&self, items: Vec, device: &B::Device) -> TextGenerationBatch { let mut tokens_list = Vec::with_capacity(items.len()); for item in items { tokens_list.push(self.tokenizer.encode(&item.text, true)); } let mask = generate_padding_mask( self.tokenizer.pad_token(), tokens_list, Some(self.max_seq_length), device, ); TextGenerationBatch { tokens: mask.tensor, mask_pad: mask.mask, } } } impl Batcher> for TextGenerationBatcher { fn batch( &self, items: Vec, device: &B::Device, ) -> TrainingTextGenerationBatch { let item: TextGenerationBatch = self.batch(items, device); let [batch_size, seq_length] = item.tokens.dims(); let inputs = item .tokens .clone() .slice([0..batch_size, 0..seq_length - 1]); let targets = item.tokens.slice([0..batch_size, 1..seq_length]); let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]); TrainingTextGenerationBatch::new(inputs, targets, mask_pad) } } ================================================ FILE: examples/text-generation/src/data/dataset.rs ================================================ use burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader}; #[derive(new, Clone, Debug)] pub struct TextGenerationItem { pub text: String, } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct DbPediaItem { pub content: String, } pub struct DbPediaDataset { dataset: SqliteDataset, } impl Dataset for DbPediaDataset { fn get(&self, index: usize) -> Option { self.dataset .get(index) .map(|item| TextGenerationItem::new(item.content)) } fn len(&self) -> usize { self.dataset.len() } } impl DbPediaDataset { pub fn train() -> Self { Self::new("train") } pub fn test() -> Self { Self::new("test") } pub fn new(split: &str) -> Self { let dataset: SqliteDataset = HuggingfaceDatasetLoader::new("dbpedia_14") .dataset(split) .unwrap(); Self { dataset } } } ================================================ FILE: examples/text-generation/src/data/mod.rs ================================================ mod batcher; mod dataset; mod tokenizer; pub use batcher::*; pub use dataset::*; pub use tokenizer::*; ================================================ FILE: examples/text-generation/src/data/tokenizer.rs ================================================ #[allow(dead_code)] pub trait Tokenizer: Send + Sync { fn encode(&self, value: &str, special_tokens: bool) -> Vec; fn decode(&self, tokens: &[usize]) -> String; fn vocab_size(&self) -> usize; fn pad_token(&self) -> usize; fn start_token(&self) -> usize; fn end_token(&self) -> usize; fn pad_token_value(&self) -> String { self.decode(&[self.pad_token()]) } fn start_token_value(&self) -> String { self.decode(&[self.start_token()]) } fn end_token_value(&self) -> String { self.decode(&[self.end_token()]) } } pub struct Gpt2Tokenizer { tokenizer: tokenizers::Tokenizer, } impl Default for Gpt2Tokenizer { fn default() -> Self { let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None).unwrap(); tokenizer.add_special_tokens(&[ tokenizers::AddedToken::from("[START]", true), tokenizers::AddedToken::from("[END]", true), tokenizers::AddedToken::from("[PAD]", true), ]); Self { tokenizer } } } impl Tokenizer for Gpt2Tokenizer { fn encode(&self, value: &str, special_tokens: bool) -> Vec { let text = match special_tokens { true => "[START]".to_owned() + value + "[END]", false => value.to_string(), }; let tokens = self.tokenizer.encode(text, true).unwrap(); tokens.get_ids().iter().map(|t| *t as usize).collect() } fn decode(&self, tokens: &[usize]) -> String { let tokens = tokens.iter().map(|t| *t as u32).collect::>(); self.tokenizer.decode(&tokens, false).unwrap() } fn vocab_size(&self) -> usize { self.tokenizer.get_vocab_size(true) } fn pad_token(&self) -> usize { self.tokenizer.token_to_id("[PAD]").unwrap() as usize } fn start_token(&self) -> usize { self.tokenizer.token_to_id("[START]").unwrap() as usize } fn end_token(&self) -> usize { self.tokenizer.token_to_id("[END]").unwrap() as usize } } #[cfg(test)] mod tests { use super::*; #[test] fn test_encode_decode() { let tokenizer = Gpt2Tokenizer::default(); let text = "A sentence"; let tokens = tokenizer.encode(text, false); let decoded = tokenizer.decode(&tokens); assert_eq!(decoded, text); } #[test] fn test_add_start_end_token() { let tokenizer = Gpt2Tokenizer::default(); let text = "A sentence"; let tokens_without = tokenizer.encode(text, false); let tokens_with = tokenizer.encode(text, true); assert_eq!(tokens_with.len() - 2, tokens_without.len()); } } ================================================ FILE: examples/text-generation/src/lib.rs ================================================ #[macro_use] extern crate derive_new; mod data; mod model; pub mod training; pub use data::DbPediaDataset; ================================================ FILE: examples/text-generation/src/model.rs ================================================ use crate::data::TrainingTextGenerationBatch; use burn::{ nn::{ Embedding, EmbeddingConfig, Linear, LinearConfig, attention::generate_autoregressive_mask, loss::CrossEntropyLossConfig, transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput}, }, prelude::*, tensor::backend::AutodiffBackend, train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep}, }; #[derive(Config, Debug)] pub struct TextGenerationModelConfig { transformer: TransformerEncoderConfig, vocab_size: usize, pad_token: usize, max_seq_length: usize, } #[derive(Module, Debug)] pub struct TextGenerationModel { transformer: TransformerEncoder, embedding_token: Embedding, embedding_pos: Embedding, output: Linear, vocab_size: usize, pad_token: usize, max_seq_length: usize, } impl TextGenerationModelConfig { pub fn init(&self, device: &B::Device) -> TextGenerationModel { let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(device); let transformer = self.transformer.init(device); let embedding_token = EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device); let embedding_pos = EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(device); TextGenerationModel { transformer, embedding_token, embedding_pos, output, vocab_size: self.vocab_size, pad_token: self.pad_token, max_seq_length: self.max_seq_length, } } } impl TextGenerationModel { pub fn forward_training( &self, item: TrainingTextGenerationBatch, ) -> ClassificationOutput { let [batch_size, seq_length] = item.tokens_inputs.dims(); let device = &self.devices()[0]; let inputs = item.tokens_inputs.to_device(device); let targets = item.targets.to_device(device); let mask_pad = item.mask_pad.to_device(device); let index_positions = Tensor::arange(0..seq_length as i64, device) .reshape([1, seq_length]) .repeat_dim(0, batch_size); let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_tokens = self.embedding_token.forward(inputs); let embedding = (embedding_positions + embedding_tokens) / 2; let mask_attn = generate_autoregressive_mask::(batch_size, seq_length, device); let encoded = self.transformer.forward( TransformerEncoderInput::new(embedding) .mask_pad(mask_pad) .mask_attn(mask_attn), ); let output = self.output.forward(encoded); let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]); let targets_flatten = targets.reshape([batch_size * seq_length]); let loss = CrossEntropyLossConfig::new() .with_pad_tokens(Some(vec![self.pad_token])) .init(&output_flatten.device()); let loss = loss.forward(output_flatten.clone(), targets_flatten.clone()); ClassificationOutput { loss, output: output_flatten, targets: targets_flatten, } } } impl TrainStep for TextGenerationModel { type Input = TrainingTextGenerationBatch; type Output = ClassificationOutput; fn step(&self, item: TrainingTextGenerationBatch) -> TrainOutput> { let item = self.forward_training(item); let grads = item.loss.backward(); TrainOutput::new(self, grads, item) } } impl InferenceStep for TextGenerationModel { type Input = TrainingTextGenerationBatch; type Output = ClassificationOutput; fn step(&self, item: TrainingTextGenerationBatch) -> ClassificationOutput { self.forward_training(item) } } ================================================ FILE: examples/text-generation/src/training.rs ================================================ use crate::{ data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer}, model::TextGenerationModelConfig, }; use burn::{ data::{ dataloader::DataLoaderBuilder, dataset::{Dataset, transform::SamplerDataset}, }, lr_scheduler::noam::NoamLrSchedulerConfig, nn::transformer::TransformerEncoderConfig, optim::AdamConfig, prelude::*, record::{CompactRecorder, DefaultRecorder, Recorder}, tensor::backend::AutodiffBackend, train::{ Learner, SupervisedTraining, metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric, PerplexityMetric}, }, }; use std::sync::Arc; #[derive(Config, Debug)] pub struct ExperimentConfig { transformer: TransformerEncoderConfig, optimizer: AdamConfig, #[config(default = 512)] max_seq_length: usize, #[config(default = 6)] batch_size: usize, #[config(default = 50)] num_epochs: usize, } pub fn train + 'static>( device: B::Device, dataset_train: D, dataset_test: D, config: ExperimentConfig, artifact_dir: &str, ) { let tokenizer = Arc::new(Gpt2Tokenizer::default()); let batcher = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length); let model = TextGenerationModelConfig::new( config.transformer.clone(), tokenizer.vocab_size(), tokenizer.pad_token(), config.max_seq_length, ) .init::(&device); let dataloader_train = DataLoaderBuilder::new(batcher.clone()) .batch_size(config.batch_size) .num_workers(4) .build(SamplerDataset::new(dataset_train, 10_000)); let dataloader_test = DataLoaderBuilder::new(batcher) .batch_size(config.batch_size) .num_workers(4) .build(SamplerDataset::new(dataset_test, 1000)); let accum = 6; // Effective batch size = 6 * 6 = 32. let optim = config.optimizer.init(); let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64) .with_warmup_steps(6000) .with_model_size(config.transformer.d_model) .init() .unwrap(); let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test) .metric_train(CudaMetric::new()) .metric_valid(CudaMetric::new()) .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token())) .metric_train_numeric(PerplexityMetric::new().with_pad_token(tokenizer.pad_token())) .metric_valid_numeric(PerplexityMetric::new().with_pad_token(tokenizer.pad_token())) .metric_train(LossMetric::new()) .metric_valid(LossMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .grads_accumulation(accum) .num_epochs(config.num_epochs) .summary(); let result = training.launch(Learner::new(model, optim, lr_scheduler)); config.save(format!("{artifact_dir}/config.json")).unwrap(); DefaultRecorder::new() .record( result.model.into_record(), format!("{artifact_dir}/model").into(), ) .unwrap(); } ================================================ FILE: examples/wgan/Cargo.toml ================================================ [package] name = "wgan" version = "0.5.0" edition.workspace = true [lints] workspace = true [features] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] cuda = ["burn/cuda"] [dependencies] burn = { path = "../../crates/burn", features=["train", "vision"] } image = { workspace = true } ================================================ FILE: examples/wgan/README.md ================================================ # Wasserstein Generative Adversarial Network A burn implementation of an example WGAN model to generate MNIST digits inspired by [the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html). Please note that better performance maybe gained by adopting a convolution layer in [some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch). ## Usage ## Training ```sh # Cuda backend cargo run --example wgan-mnist --release --features cuda # Wgpu backend cargo run --example wgan-mnist --release --features wgpu # Tch GPU backend export TORCH_CUDA_VERSION=cu128 # Set the cuda version cargo run --example wgan-mnist --release --features tch-gpu # Tch CPU backend cargo run --example wgan-mnist --release --features tch-cpu # NdArray backend (CPU) cargo run --example wgan-mnist --release --features ndarray # f32 - single thread cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib ``` ### Generating To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. ```sh cargo run --example wgan-generate --release --features cuda ``` ================================================ FILE: examples/wgan/examples/wgan-generate.rs ================================================ use burn::tensor::backend::Backend; pub fn launch(device: B::Device) { wgan::infer::generate::("/tmp/wgan-mnist", device); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::launch; pub fn run() { launch::(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { launch::(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; use burn::backend::wgpu::Wgpu; pub fn run() { launch::(Default::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::launch; use burn::backend::Cuda; pub fn run() { launch::(Default::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); } ================================================ FILE: examples/wgan/examples/wgan-mnist.rs ================================================ use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend}; use wgan::{model::ModelConfig, training::TrainingConfig}; pub fn launch(device: B::Device) { let config = TrainingConfig::new( ModelConfig::new(), RmsPropConfig::new() .with_alpha(0.99) .with_momentum(0.0) .with_epsilon(0.00000008) .with_weight_decay(None) .with_centered(false), ); wgan::training::train::("/tmp/wgan-mnist", config, device); } #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] mod ndarray { use burn::backend::{ Autodiff, ndarray::{NdArray, NdArrayDevice}, }; use crate::launch; pub fn run() { launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::launch; pub fn run() { #[cfg(not(target_os = "macos"))] let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { use burn::backend::{ Autodiff, libtorch::{LibTorch, LibTorchDevice}, }; use crate::launch; pub fn run() { launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; use burn::backend::{Autodiff, wgpu::Wgpu}; pub fn run() { launch::>(Default::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::launch; use burn::backend::{Autodiff, Cuda, cuda::CudaDevice}; pub fn run() { launch::>(CudaDevice::default()); } } fn main() { #[cfg(any( feature = "ndarray", feature = "ndarray-blas-netlib", feature = "ndarray-blas-openblas", feature = "ndarray-blas-accelerate", ))] ndarray::run(); #[cfg(feature = "tch-gpu")] tch_gpu::run(); #[cfg(feature = "tch-cpu")] tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); #[cfg(feature = "cuda")] cuda::run(); } ================================================ FILE: examples/wgan/src/dataset.rs ================================================ use burn::{ data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, prelude::*, }; #[derive(Clone, Debug, Default)] pub struct MnistBatcher {} #[derive(Clone, Debug)] pub struct MnistBatch { pub images: Tensor, pub targets: Tensor, } impl Batcher> for MnistBatcher { fn batch(&self, items: Vec, device: &B::Device) -> MnistBatch { let images = items .iter() .map(|item| TensorData::from(item.image)) .map(|data| Tensor::::from_data(data.convert::(), device)) .map(|tensor| tensor.reshape([1, 28, 28])) // Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example .map(|tensor| ((tensor / 255) - 0.5) / 0.5) .collect(); let targets = items .iter() .map(|item| { Tensor::::from_data( TensorData::from([(item.label as i64).elem::()]), device, ) }) .collect(); let images = Tensor::stack(images, 0); let targets = Tensor::cat(targets, 0); MnistBatch { images, targets } } } ================================================ FILE: examples/wgan/src/infer.rs ================================================ use crate::training::{TrainingConfig, save_image}; use burn::{ prelude::*, record::{CompactRecorder, Recorder}, tensor::Distribution, }; pub fn generate(artifact_dir: &str, device: B::Device) { // Loading model let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) .expect("Config should exist for the model; run train first"); let record = CompactRecorder::new() .load(format!("{artifact_dir}/generator").into(), &device) .expect("Trained model should exist; run train first"); let (mut generator, _) = config.model.init::(&device); generator = generator.load_record(record); // Get a batch of noise let noise = Tensor::::random( [config.batch_size, config.model.latent_dim], Distribution::Normal(0.0, 1.0), &device, ); let fake_images = generator.forward(noise); // [batch_size, channesl*height*width] let fake_images = fake_images.reshape([ config.batch_size, config.model.channels, config.model.image_size, config.model.image_size, ]); // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); // Normalize the images. The Rgb32 images should be in range 0.0-1.0 let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1])) / (fake_images.clone().max().reshape([1, 1, 1, 1]) - fake_images.clone().min().reshape([1, 1, 1, 1])); // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer, refer to pytorch save_image source let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); // Save images in artifact directory save_image::(fake_images, 5, format!("{artifact_dir}/fake_image.png")).unwrap(); } ================================================ FILE: examples/wgan/src/lib.rs ================================================ pub mod dataset; pub mod infer; pub mod model; pub mod training; ================================================ FILE: examples/wgan/src/model.rs ================================================ use burn::{ module::{Module, ModuleMapper, Param}, prelude::*, tensor::backend::AutodiffBackend, }; /// Layer block of generator model #[derive(Module, Debug)] pub struct LayerBlock { fc: nn::Linear, bn: nn::BatchNorm, leakyrelu: nn::LeakyRelu, } impl LayerBlock { pub fn new(input: usize, output: usize, device: &B::Device) -> Self { let fc = nn::LinearConfig::new(input, output) .with_bias(true) .init(device); let bn: nn::BatchNorm = nn::BatchNormConfig::new(output) .with_epsilon(0.8) .init(device); let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); Self { fc, bn, leakyrelu } } pub fn forward(&self, input: Tensor) -> Tensor { let output = self.fc.forward(input); // output: [Batch, x] let output = self.bn.forward(output); // output: [Batch, x] self.leakyrelu.forward(output) // output: [Batch, x] } } /// Generator model #[derive(Module, Debug)] pub struct Generator { layer1: LayerBlock, layer2: LayerBlock, layer3: LayerBlock, layer4: LayerBlock, fc: nn::Linear, tanh: nn::Tanh, } impl Generator { /// Applies the forward pass on the input tensor by specified order pub fn forward(&self, noise: Tensor) -> Tensor { let output = self.layer1.forward(noise); let output = self.layer2.forward(output); let output = self.layer3.forward(output); let output = self.layer4.forward(output); let output = self.fc.forward(output); self.tanh.forward(output) // [batch_size, channels*height*width] } } /// Discriminator model #[derive(Module, Debug)] pub struct Discriminator { fc1: nn::Linear, leakyrelu1: nn::LeakyRelu, fc2: nn::Linear, leakyrelu2: nn::LeakyRelu, fc3: nn::Linear, } impl Discriminator { /// Applies the forward pass on the input tensor by specified order. /// The input image shape is [batch, channels, height, width] pub fn forward(&self, images: Tensor) -> Tensor { // Full connection for each batch let output = images.flatten(1, 3); // output: [batch, channels*height*width] let output = self.fc1.forward(output); // output: [batch, 512] let output = self.leakyrelu1.forward(output); // output: [batch, 512] let output = self.fc2.forward(output); // output: [batch, 256] let output = self.leakyrelu2.forward(output); // output: [batch, 256] self.fc3.forward(output) // output: [batch, 1] } } // Use model config to construct a generative and adversarial model #[derive(Config, Debug)] pub struct ModelConfig { /// Dimensionality of the latent space #[config(default = 100)] pub latent_dim: usize, #[config(default = 28)] pub image_size: usize, #[config(default = 1)] pub channels: usize, } impl ModelConfig { /// Initialize the generator and discriminator models based on the config. pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { // Construct the initialized generator let layer1 = LayerBlock::new(self.latent_dim, 128, device); let layer2 = LayerBlock::new(128, 256, device); let layer3 = LayerBlock::new(256, 512, device); let layer4 = LayerBlock::new(512, 1024, device); let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size) .with_bias(true) .init(device); let generator = Generator { layer1, layer2, layer3, layer4, fc, tanh: nn::Tanh::new(), }; // Construct the initialized discriminator let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512) .init(device); let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); let fc2 = nn::LinearConfig::new(512, 256).init(device); let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); let fc3 = nn::LinearConfig::new(256, 1).init(device); let discriminator = Discriminator { fc1, leakyrelu1, fc2, leakyrelu2, fc3, }; (generator, discriminator) } } /// Clip module mapper to clip all module parameters between a range of values #[derive(Module, Clone, Debug)] pub struct Clip { pub min: f32, pub max: f32, } impl ModuleMapper for Clip { fn map_float(&mut self, param: Param>) -> Param> { let (id, tensor, mapper) = param.consume(); let is_require_grad = tensor.is_require_grad(); let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); if is_require_grad { tensor = tensor.require_grad(); } Param::from_mapped_value(id, tensor, mapper) } } ================================================ FILE: examples/wgan/src/training.rs ================================================ use crate::dataset::MnistBatcher; use crate::model::{Clip, ModelConfig}; use burn::optim::{GradientsParams, Optimizer, RmsPropConfig}; use burn::{ data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, prelude::*, record::CompactRecorder, tensor::{Distribution, backend::AutodiffBackend}, }; use image::{Rgb32FImage, RgbImage, buffer::ConvertBuffer, error::ImageResult}; use std::path::Path; #[derive(Config, Debug)] pub struct TrainingConfig { pub model: ModelConfig, pub optimizer: RmsPropConfig, #[config(default = 200)] pub num_epochs: usize, #[config(default = 512)] pub batch_size: usize, #[config(default = 8)] pub num_workers: usize, #[config(default = 5)] pub seed: u64, #[config(default = 3e-4)] pub lr: f64, /// Number of training steps for discriminator before generator is trained per iteration #[config(default = 5)] pub num_critic: usize, /// Lower and upper clip value for disc. weights #[config(default = 0.01)] pub clip_value: f32, /// Save a sample of images every `sample_interval` epochs #[config(default = 10)] pub sample_interval: usize, } // Create the directory to save the model and model config fn create_artifact_dir(artifact_dir: &str) { // Remove existing artifacts std::fs::remove_dir_all(artifact_dir).ok(); std::fs::create_dir_all(artifact_dir).ok(); } /// Save the generated images // The images format is [B, H, W, C] pub fn save_image>( images: Tensor, nrow: u32, path: Q, ) -> ImageResult<()> { let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32; let width = images.dims()[2] as u32; let height = images.dims()[1] as u32; // Supports both 1 and 3 channels image let channels = match images.dims()[3] { 1 => 3, 3 => 1, _ => panic!("Wrong channels number"), }; let mut imgbuf = RgbImage::new(nrow * width, ncol * height); // Write images into a nrow*ncol grid layout for row in 0..nrow { for col in 0..ncol { let image: Tensor = images .clone() .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize) .squeeze_dim(0); // The Rgb32 should be in range 0.0-1.0 let image = image.into_data().iter::().collect::>(); // Supports both 1 and 3 channels image let image = image .into_iter() .flat_map(|n| std::iter::repeat_n(n, channels)) .collect(); let image = Rgb32FImage::from_vec(width, height, image).unwrap(); let image: RgbImage = image.convert(); for (x, y, pixel) in image.enumerate_pixels() { imgbuf.put_pixel(row * width + x, col * height + y, *pixel); } } } imgbuf.save(path) } pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { create_artifact_dir(artifact_dir); // Create the Clip module mapper let mut clip = Clip { min: -config.clip_value, max: config.clip_value, }; // Save training config config .save(format!("{artifact_dir}/config.json")) .expect("Config should be saved successfully"); B::seed(&device, config.seed); // Create the model and optimizer let (mut generator, mut discriminator) = config.model.init::(&device); let mut optimizer_g = config.optimizer.init(); let mut optimizer_d = config.optimizer.init(); // Create the dataset batcher let batcher_train = MnistBatcher::default(); // Create the dataloaders let dataloader_train = DataLoaderBuilder::new(batcher_train) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) .build(MnistDataset::train()); // Iterate over our training for X epochs for epoch in 0..config.num_epochs { // Implement our training loop for (iteration, batch) in dataloader_train.iter().enumerate() { // Generate a batch of fake images from noise (standarded normal distribution) let noise = Tensor::::random( [config.batch_size, config.model.latent_dim], Distribution::Normal(0.0, 1.0), &device, ); // datach: do not update generator, only discriminator is updated let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width] let fake_images = fake_images.reshape([ config.batch_size, config.model.channels, config.model.image_size, config.model.image_size, ]); // Adversarial loss let loss_d = -discriminator.forward(batch.images).mean() + discriminator.forward(fake_images.clone()).mean(); // Gradients for the current backward pass let grads = loss_d.backward(); // Gradients linked to each parameter of the discriminator let grads = GradientsParams::from_grads(grads, &discriminator); // Update the discriminator using the optimizer discriminator = optimizer_d.step(config.lr, discriminator, grads); // Clip parameters (weights) of discriminator discriminator = discriminator.map(&mut clip); // Train the generator every num_critic iterations if iteration % config.num_critic == 0 { // Generate a batch of images again without detaching let critic_fake_images = generator.forward(noise.clone()); let critic_fake_images = critic_fake_images.reshape([ config.batch_size, config.model.channels, config.model.image_size, config.model.image_size, ]); // Adversarial loss. Minimize it to make the fake images as truth let loss_g = -discriminator.forward(critic_fake_images).mean(); let grads = loss_g.backward(); let grads = GradientsParams::from_grads(grads, &generator); generator = optimizer_g.step(config.lr, generator, grads); // Print the progression let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32) .ceil() as usize; println!( "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]", epoch + 1, config.num_epochs, iteration, batch_num, loss_d.into_scalar(), loss_g.into_scalar() ); } // If at save interval => save the first 25 generated images if epoch % config.sample_interval == 0 && iteration == 0 { // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); // Normalize the images. The Rgb32 images should be in range 0.0-1.0 let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1])) / (fake_images.clone().max().reshape([1, 1, 1, 1]) - fake_images.clone().min().reshape([1, 1, 1, 1])); // Add 0.5/255.0 to the images, refer to pytorch save_image source let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); // Save images in artifact directory let path = format!("{artifact_dir}/image-{epoch}.png"); save_image::(fake_images, 5, path).unwrap(); } } } // Save the trained models generator .save_file(format!("{artifact_dir}/generator"), &CompactRecorder::new()) .expect("Generator should be saved successfully"); discriminator .save_file( format!("{artifact_dir}/discriminator"), &CompactRecorder::new(), ) .expect("Discriminator should be saved successfully"); } ================================================ FILE: rustfmt.toml ================================================ max_width = 100 # uncomment and run `cargo +nightly fmt --all` to find and fix lines that are too long (and therefore break autoformatting) # error_on_line_overflow = true # format_strings = true ================================================ FILE: xtask/Cargo.toml ================================================ [package] name = "xtask" version = "4.10.0" edition.workspace = true license = "MIT OR Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lints] workspace = true [dependencies] log = { workspace = true } strum = { workspace = true } tracel-xtask = { workspace = true } [dev-dependencies] rstest = { workspace = true } ================================================ FILE: xtask/src/commands/books.rs ================================================ use std::path::Path; use tracel_xtask::prelude::*; #[derive(clap::Args)] pub struct BooksArgs { #[command(subcommand)] book: BookKind, } #[derive(clap::Subcommand)] pub(crate) enum BookKind { /// Burn Book, a.k.a. the guide, made for the Burn users. Burn(BookKindArgs), /// Contributor book, made for people willing to get all the technical understanding and advice to contribute actively to the project. Contributor(BookKindArgs), } #[derive(clap::Args)] pub(crate) struct BookKindArgs { #[command(subcommand)] command: BookSubCommand, } #[derive(clap::Subcommand, strum::Display)] pub(crate) enum BookSubCommand { /// Build the book Build, /// Open the book on the specified port or random port and rebuild it automatically upon changes Open(OpenArgs), } #[derive(clap::Args)] pub(crate) struct OpenArgs { /// Specify the port to open the book on (defaults to a random port if not specified) #[clap(long, default_value_t = random_port())] port: u16, } /// Book information pub(crate) struct Book { name: &'static str, path: &'static Path, } impl BooksArgs { pub(crate) fn parse(&self) -> anyhow::Result<()> { Book::run(&self.book) } } impl Book { const BURN_BOOK_NAME: &'static str = "Burn Book"; const BURN_BOOK_PATH: &'static str = "./burn-book"; const CONTRIBUTOR_BOOK_NAME: &'static str = "Contributor Book"; const CONTRIBUTOR_BOOK_PATH: &'static str = "./contributor-book"; pub(crate) fn run(book_arg: &BookKind) -> anyhow::Result<()> { let (book, command) = match book_arg { BookKind::Burn(args) => ( Self { name: Self::BURN_BOOK_NAME, path: Path::new(Self::BURN_BOOK_PATH), }, &args.command, ), BookKind::Contributor(args) => ( Self { name: Self::CONTRIBUTOR_BOOK_NAME, path: Path::new(Self::CONTRIBUTOR_BOOK_PATH), }, &args.command, ), }; book.execute(command) } fn execute(&self, command: &BookSubCommand) -> anyhow::Result<()> { ensure_cargo_crate_is_installed("mdbook", None, None, false)?; group!("{}: {}", self.name, command); match command { BookSubCommand::Build => self.build(), BookSubCommand::Open(args) => self.open(args), }?; endgroup!(); Ok(()) } fn build(&self) -> anyhow::Result<()> { run_process( "mdbook", &["build"], None, Some(self.path), "mdbook should build the book successfully", ) } fn open(&self, args: &OpenArgs) -> anyhow::Result<()> { run_process( "mdbook", &["serve", "--open", "--port", &args.port.to_string()], None, Some(self.path), "mdbook should open the book successfully", ) } } ================================================ FILE: xtask/src/commands/build.rs ================================================ use std::collections::HashMap; use tracel_xtask::prelude::{clap::ValueEnum, *}; use crate::{ARM_NO_ATOMIC_PTR_TARGET, ARM_TARGET, NO_STD_CRATES, WASM32_TARGET}; #[macros::extend_command_args(BuildCmdArgs, Target, None)] pub struct BurnBuildCmdArgs { /// Build in CI mode which excludes unsupported crates. #[arg(long)] pub ci: bool, } pub(crate) fn handle_command( mut args: BurnBuildCmdArgs, env: Environment, context: Context, ) -> anyhow::Result<()> { match context { Context::NoStd => { [ "Default", WASM32_TARGET, ARM_TARGET, ARM_NO_ATOMIC_PTR_TARGET, ] .iter() .try_for_each(|build_target| { let mut build_args = vec!["--no-default-features"]; let mut env_vars = HashMap::new(); if *build_target != "Default" { build_args.extend(vec!["--target", *build_target]); } let mut crates = NO_STD_CRATES.to_vec(); if *build_target == ARM_NO_ATOMIC_PTR_TARGET { // Temporarily remove `burn-autodiff` from building with the // target `thumbv6m-none-eabi` as it requires enabling the // `arbitrary_self_types` feature for the // `clone_if_require_grad` method of // `burn-autodiff::graph::Node` crates.retain(|&v| v != "burn-autodiff"); env_vars.insert( "RUSTFLAGS", "--cfg portable_atomic_unsafe_assume_single_core", ); } helpers::custom_crates_build( crates, build_args, Some(env_vars), None, &format!("no-std with target {}", *build_target), ) })?; Ok(()) } Context::Std => { if args.ci { // Exclude crates that are not supported on CI args.exclude.extend(vec![ "burn-cuda".to_string(), "burn-rocm".to_string(), "burn-tch".to_string(), ]); if std::env::var("DISABLE_WGPU").is_ok() { args.exclude.extend(vec!["burn-wgpu".to_string()]); }; } // Build workspace base_commands::build::handle_command(args.try_into().unwrap(), env, context)?; // Specific additional commands to test specific features // burn-dataset helpers::custom_crates_build( vec!["burn-dataset"], vec!["--all-features"], None, None, "std with all features", )?; Ok(()) } Context::All => Context::value_variants() .iter() .filter(|ctx| **ctx != Context::All) .try_for_each(|ctx| { handle_command( BurnBuildCmdArgs { target: args.target.clone(), exclude: args.exclude.clone(), only: args.only.clone(), ci: args.ci, release: args.release, features: args.features.clone(), no_default_features: args.no_default_features, }, env.clone(), ctx.clone(), ) }), } } ================================================ FILE: xtask/src/commands/doc.rs ================================================ use tracel_xtask::prelude::*; pub(crate) fn handle_command( mut args: DocCmdArgs, env: Environment, ctx: Context, ) -> anyhow::Result<()> { if args.get_command() == DocSubCommand::Build { args.exclude .extend(vec!["burn-cuda".to_string(), "burn-rocm".to_string()]); } // Execute documentation command on workspace base_commands::doc::handle_command(args.clone(), env, ctx)?; // Specific additional commands to build other docs if args.get_command() == DocSubCommand::Build { // burn-dataset helpers::custom_crates_doc_build( vec!["burn-dataset"], vec!["--all-features"], None, None, "All features", )?; } Ok(()) } ================================================ FILE: xtask/src/commands/mod.rs ================================================ pub(crate) mod books; pub(crate) mod build; pub(crate) mod doc; pub(crate) mod test; pub(crate) mod validate; ================================================ FILE: xtask/src/commands/test.rs ================================================ use tracel_xtask::{ prelude::{clap::ValueEnum, *}, utils::{ process::{ExitSignal, ProcessExitError}, workspace::WorkspaceMember, }, }; use crate::NO_STD_CRATES; #[cfg(unix)] use std::os::unix::process::ExitStatusExt; #[macros::extend_command_args(TestCmdArgs, Target, TestSubCommand)] pub struct BurnTestCmdArgs { /// Test in CI mode which excludes unsupported crates. #[arg(long)] pub ci: CiTestType, } #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, ValueEnum, PartialEq)] pub enum CiTestType { GithubRunner, GithubMacRunner, GcpCudaRunner, GcpVulkanRunner, GcpWgpuRunner, } fn handle_backend_tests( mut args: TestCmdArgs, backend: &str, env: Environment, context: Context, ) -> anyhow::Result<()> { args.target = Target::AllPackages; args.only.push("burn-backend-tests".to_string()); args.no_default_features = true; let mut features = vec![String::from(backend)]; if !matches!(context, Context::NoStd) { features.push("std".into()) } args.features = Some(features); base_commands::test::handle_command(args, env, context) } fn handle_wgpu_test(member: &str, args: &TestCmdArgs) -> anyhow::Result<()> { #[cfg(unix)] let filter_err = |e: &&ProcessExitError| { e.status.signal() == Some(11) || matches!(e.signal, Some(ExitSignal { code: 11, .. })) }; #[cfg(not(unix))] let filter_err = |e: &&ProcessExitError| matches!(e.signal, Some(ExitSignal { code: 11, .. })); let workspace_member = WorkspaceMember { name: member.into(), path: "".into(), // unused }; if let Err(err) = base_commands::test::run_unit_test(&workspace_member, args) { let should_ignore = err .downcast_ref::() .filter(filter_err) // Failed to execute unit test for '{member}' .map(|e| e.message.contains(member)) .unwrap_or(false); if should_ignore { // Ignore intermittent successful failures // https://github.com/gfx-rs/wgpu/issues/2949 // https://github.com/KhronosGroup/Vulkan-ValidationLayers/issues/4391 eprintln!("⚠️ Ignored SIGSEGV in wgpu test"); } else { return Err(err); } } Ok(()) } pub(crate) fn handle_command( mut args: BurnTestCmdArgs, env: Environment, context: Context, ) -> anyhow::Result<()> { match context { Context::NoStd => { ["Default"].iter().try_for_each(|test_target| { let mut test_args = vec!["--no-default-features"]; if *test_target != "Default" { test_args.extend(vec!["--target", *test_target]); } helpers::custom_crates_tests( NO_STD_CRATES.to_vec(), handle_test_args(&test_args, args.release), None, None, "no-std", ) })?; handle_backend_tests(args.clone().try_into().unwrap(), "ndarray", env, context)?; Ok(()) } Context::Std => { // 1) Tests with default features // ------------------------------ match args.ci { CiTestType::GithubRunner => { // Exclude crates that are not supported on CI args.exclude.extend(vec![ "burn-cpu".to_string(), "burn-cuda".to_string(), "burn-rocm".to_string(), // "burn-router" uses "burn-wgpu" for the tests. "burn-router".to_string(), "burn-tch".to_string(), "burn-wgpu".to_string(), // dqn-agent example relies on gym-rs dependency which requires SDL2. // It would be good to remove the gym-rs dependency in the future. "dqn-agent".to_string(), // Requires wgpu runtime "burn-cubecl-fusion".to_string(), ]); // Burn remote tests don't work on windows for now #[cfg(target_os = "windows")] { args.exclude.extend(vec!["burn-remote".to_string()]); }; base_commands::test::handle_command( args.clone().try_into().unwrap(), env.clone(), context.clone(), )?; handle_backend_tests( args.clone().try_into().unwrap(), "ndarray", env, context, )?; } CiTestType::GithubMacRunner => { handle_backend_tests( args.clone().try_into().unwrap(), "metal", env.clone(), context.clone(), )?; args.target = Target::AllPackages; args.only.push("burn-wgpu".to_string()); args.features .get_or_insert_with(Vec::new) .push("metal".to_string()); base_commands::test::handle_command( args.clone().try_into().unwrap(), env, context, )?; } CiTestType::GcpCudaRunner => { handle_backend_tests(args.clone().try_into().unwrap(), "cuda", env, context)?; } CiTestType::GcpVulkanRunner => { handle_backend_tests(args.clone().try_into().unwrap(), "vulkan", env, context)?; args.target = Target::AllPackages; let mut args_vulkan: TestCmdArgs = args.clone().try_into().unwrap(); args_vulkan.features = Some(vec!["test-vulkan".into()]); handle_wgpu_test("burn-core", &args_vulkan)?; handle_wgpu_test("burn-optim", &args_vulkan)?; handle_wgpu_test("burn-nn", &args_vulkan)?; handle_wgpu_test("burn-vision", &args_vulkan)?; } CiTestType::GcpWgpuRunner => { handle_backend_tests(args.clone().try_into().unwrap(), "wgpu", env, context)?; // "burn-router" uses "burn-wgpu" for the tests. args.target = Target::AllPackages; let mut args_wgpu = args.clone().try_into().unwrap(); handle_wgpu_test("burn-wgpu", &args_wgpu)?; handle_wgpu_test("burn-router", &args_wgpu)?; handle_wgpu_test("burn-cubecl-fusion", &args_wgpu)?; args_wgpu.features = Some(vec!["test-wgpu".into()]); handle_wgpu_test("burn-core", &args_wgpu)?; handle_wgpu_test("burn-optim", &args_wgpu)?; handle_wgpu_test("burn-nn", &args_wgpu)?; handle_wgpu_test("burn-vision", &args_wgpu)?; } } // 2) Specific additional commands to test specific features // --------------------------------------------------------- match args.ci { CiTestType::GithubRunner => { // burn-dataset helpers::custom_crates_tests( vec!["burn-dataset"], handle_test_args(&["--all-features"], args.release), None, None, "std all features", )?; // burn-core helpers::custom_crates_tests( vec!["burn-core"], handle_test_args( &["--features", "test-tch,record-item-custom-serde"], args.release, ), None, None, "std with features: test-tch,record-item-custom-serde", )?; // burn-vision helpers::custom_crates_tests( vec!["burn-vision"], handle_test_args(&["--features", "test-cpu"], args.release), None, None, "std cpu", )?; // burn-train vision (LPIPS, DISTS metrics) helpers::custom_crates_tests( vec!["burn-train"], handle_test_args(&["--features", "vision"], args.release), None, None, "std vision", )?; // burn-nn (pretrained and local tests) let mut nn_features = "pretrained".to_string(); // If the "CI" environment variable is missing, we are running locally. if std::env::var("CI").is_err() { nn_features.push_str(",test-local"); } helpers::custom_crates_tests( vec!["burn-nn"], handle_test_args(&["--features", &nn_features], args.release), None, None, &format!("std burn-nn with features: {}", nn_features), )?; } CiTestType::GcpCudaRunner => (), CiTestType::GcpVulkanRunner | CiTestType::GcpWgpuRunner => (), // handled in tests above CiTestType::GithubMacRunner => { // burn-ndarray helpers::custom_crates_tests( vec!["burn-ndarray"], handle_test_args(&["--features", "blas-accelerate"], args.release), None, None, "std blas-accelerate", )?; // burn-train vision (LPIPS, DISTS metrics) helpers::custom_crates_tests( vec!["burn-train"], handle_test_args(&["--features", "vision"], args.release), None, None, "std vision", )?; helpers::custom_crates_tests( vec!["burn-core"], handle_test_args(&["--features", "test-metal"], args.release), None, None, "std metal", )?; helpers::custom_crates_tests( vec!["burn-vision"], handle_test_args(&["--features", "test-metal"], args.release), None, None, "std metal", )?; } } Ok(()) } Context::All => Context::value_variants() .iter() .filter(|ctx| **ctx != Context::All) .try_for_each(|ctx| { handle_command( BurnTestCmdArgs { command: args.command.clone(), target: args.target.clone(), exclude: args.exclude.clone(), only: args.only.clone(), threads: args.threads, jobs: args.jobs, ci: args.ci.clone(), features: args.features.clone(), no_default_features: args.no_default_features, release: args.release, test: args.test.clone(), force: args.force, no_capture: args.no_capture, }, env.clone(), ctx.clone(), ) }), } } fn handle_test_args<'a>(args: &'a [&'a str], release: bool) -> Vec<&'a str> { let mut args = args.to_vec(); if release { args.push("--release"); } args } ================================================ FILE: xtask/src/commands/validate.rs ================================================ use tracel_xtask::prelude::*; use crate::commands::{ build::BurnBuildCmdArgs, test::{BurnTestCmdArgs, CiTestType}, }; pub fn handle_command( args: &ValidateCmdArgs, env: Environment, context: Context, ) -> anyhow::Result<()> { let target = Target::Workspace; let exclude = vec![]; let only = vec![]; if context == Context::NoStd || context == Context::All { // ================= // no-std validation // ================= info!("Run validation for no-std execution environment..."); #[cfg(target_os = "linux")] { // build super::build::handle_command( BurnBuildCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), ci: true, release: args.release, features: args.features.clone(), no_default_features: args.no_default_features, }, env.clone(), Context::NoStd, )?; // tests super::test::handle_command( BurnTestCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), threads: None, jobs: None, command: Some(TestSubCommand::All), ci: CiTestType::GithubRunner, features: None, no_default_features: false, force: false, no_capture: false, release: args.release, test: None, }, env.clone(), Context::NoStd, )?; } } if context == Context::Std || context == Context::All { // ============== // std validation // ============== info!("Run validation for std execution environment..."); // checks [ CheckSubCommand::Audit, CheckSubCommand::Format, CheckSubCommand::Lint, CheckSubCommand::Typos, ] .iter() .try_for_each(|c| { base_commands::check::handle_command( CheckCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), command: Some(c.clone()), ignore_audit: args.ignore_audit, features: args.features.clone(), no_default_features: args.no_default_features, ignore_typos: args.ignore_typos, }, env.clone(), context.clone(), ) })?; // build super::build::handle_command( BurnBuildCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), ci: true, release: args.release, features: args.features.clone(), no_default_features: args.no_default_features, }, env.clone(), Context::Std, )?; // tests super::test::handle_command( BurnTestCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), threads: None, jobs: None, command: Some(TestSubCommand::All), ci: CiTestType::GithubRunner, features: None, no_default_features: false, release: args.release, test: None, force: false, no_capture: false, }, env.clone(), Context::Std, )?; // documentation [DocSubCommand::Build, DocSubCommand::Tests] .iter() .try_for_each(|c| { super::doc::handle_command( DocCmdArgs { target: target.clone(), exclude: exclude.clone(), only: only.clone(), command: Some(c.clone()), features: args.features.clone(), no_default_features: args.no_default_features, }, env.clone(), context.clone(), ) })?; } Ok(()) } ================================================ FILE: xtask/src/main.rs ================================================ mod commands; #[macro_use] extern crate log; use std::time::Instant; use tracel_xtask::prelude::*; // no-std const WASM32_TARGET: &str = "wasm32-unknown-unknown"; const ARM_TARGET: &str = "thumbv7m-none-eabi"; const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi"; const NO_STD_CRATES: &[&str] = &[ "burn", "burn-autodiff", "burn-core", "burn-std", "burn-backend", "burn-tensor", "burn-ndarray", "burn-no-std-tests", ]; #[macros::base_commands( Bump, Check, Compile, Coverage, Doc, Dependencies, Fix, Publish, Validate, Vulnerabilities )] pub enum Command { /// Run commands to manage Burn Books. Books(commands::books::BooksArgs), /// Build Burn in different modes. Build(commands::build::BurnBuildCmdArgs), /// Test Burn. Test(commands::test::BurnTestCmdArgs), } fn main() -> anyhow::Result<()> { let start = Instant::now(); let (args, environment) = init_xtask::(parse_args::()?)?; if args.context == Context::NoStd { // Install additional targets for no-std execution environments rustup_add_target(WASM32_TARGET)?; rustup_add_target(ARM_TARGET)?; rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET)?; } match args.command { Command::Books(cmd_args) => cmd_args.parse(), Command::Build(cmd_args) => { commands::build::handle_command(cmd_args, environment, args.context) } Command::Doc(cmd_args) => { commands::doc::handle_command(cmd_args, environment, args.context) } Command::Test(cmd_args) => { commands::test::handle_command(cmd_args, environment, args.context) } Command::Validate(cmd_args) => { commands::validate::handle_command(&cmd_args, environment, args.context) } _ => dispatch_base_commands(args, environment), }?; let duration = start.elapsed(); info!( "\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m", format_duration(&duration) ); Ok(()) }