gitextract_82tcaglc/ ├── .cargo/ │ ├── audit.toml │ └── config.toml ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── doc_request.md │ │ └── feature_request.md │ ├── PULL_REQUEST_TEMPLATE/ │ │ └── template.md │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ ├── combine-dependabot-prs.yml │ ├── dependencies.yml │ ├── publish.yml │ ├── stale-pr.yml │ ├── test-gpu.yml │ ├── test.yml │ ├── valgrind.yml │ └── vulnerabilities.yml ├── .gitignore ├── CITATION.cff ├── CODE-OF-CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── NOTICES.md ├── POEM.md ├── README.md ├── _typos.toml ├── benchmarks.toml ├── burn-book/ │ ├── .gitignore │ ├── .prettierrc.json │ ├── book.toml │ └── src/ │ ├── SUMMARY.md │ ├── advanced/ │ │ ├── README.md │ │ ├── backend-extension/ │ │ │ ├── README.md │ │ │ ├── custom-cubecl-kernel.md │ │ │ └── custom-wgpu-kernel.md │ │ ├── no-std.md │ │ └── web-assembly.md │ ├── basic-workflow/ │ │ ├── README.md │ │ ├── backend.md │ │ ├── data.md │ │ ├── inference.md │ │ ├── model.md │ │ └── training.md │ ├── building-blocks/ │ │ ├── README.md │ │ ├── autodiff.md │ │ ├── backend.md │ │ ├── config.md │ │ ├── dataset.md │ │ ├── learner.md │ │ ├── metric.md │ │ ├── module.md │ │ ├── record.md │ │ └── tensor.md │ ├── custom-training-loop.md │ ├── distributed-computing.md │ ├── examples.md │ ├── getting-started.md │ ├── models-and-pretrained-weights.md │ ├── motivation.md │ ├── onnx-import.md │ ├── overview.md │ ├── performance/ │ │ ├── README.md │ │ ├── distributed-computing.md │ │ ├── good-practices/ │ │ │ ├── README.md │ │ │ ├── asynchronous-execution.md │ │ │ ├── kernel-fusion.md │ │ │ └── kernel-selection.md │ │ └── quantization.md │ └── saving-and-loading.md ├── codecov.yml ├── contributor-book/ │ ├── .gitignore │ ├── .prettierrc.json │ ├── book.toml │ └── src/ │ ├── SUMMARY.md │ ├── frequently-encountered-issues/ │ │ ├── README.md │ │ └── issues-while-adding-ops.md │ ├── getting-started/ │ │ ├── README.md │ │ ├── configuring-your-editor.md │ │ ├── setting-up-the-environment.md │ │ └── testing.md │ ├── guides/ │ │ ├── README.md │ │ ├── adding-a-new-operation-to-burn.md │ │ └── submitting-examples.md │ ├── how-to-read-this-book.md │ ├── overview.md │ └── project-architecture/ │ ├── README.md │ ├── backend.md │ ├── module.md │ ├── serialization.md │ └── tensor.md ├── crates/ │ ├── burn/ │ │ ├── Cargo.toml │ │ └── src/ │ │ ├── backend.rs │ │ ├── collective.rs │ │ └── lib.rs │ ├── burn-autodiff/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── checkpoint/ │ │ │ ├── base.rs │ │ │ ├── builder.rs │ │ │ ├── mod.rs │ │ │ ├── retro_forward.rs │ │ │ ├── state.rs │ │ │ └── strategy.rs │ │ ├── grads.rs │ │ ├── graph/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── node.rs │ │ │ ├── requirement.rs │ │ │ └── traversal.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── backward.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── maxmin.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── sort.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── runtime/ │ │ │ ├── client.rs │ │ │ ├── graph.rs │ │ │ ├── memory_management.rs │ │ │ ├── mod.rs │ │ │ └── server.rs │ │ ├── tensor.rs │ │ └── utils.rs │ ├── burn-backend/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend/ │ │ │ ├── base.rs │ │ │ ├── device.rs │ │ │ ├── mod.rs │ │ │ ├── ops/ │ │ │ │ ├── activation.rs │ │ │ │ ├── argwhere.rs │ │ │ │ ├── bool_tensor.rs │ │ │ │ ├── cat.rs │ │ │ │ ├── int_tensor.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── modules/ │ │ │ │ │ ├── attention.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── conv.rs │ │ │ │ │ ├── grid_sample.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── pool.rs │ │ │ │ │ └── unfold.rs │ │ │ │ ├── qtensor.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── sort.rs │ │ │ │ ├── tensor.rs │ │ │ │ └── transaction.rs │ │ │ └── primitive.rs │ │ ├── data/ │ │ │ ├── compare.rs │ │ │ ├── mod.rs │ │ │ └── tensor.rs │ │ ├── distribution.rs │ │ ├── element/ │ │ │ ├── base.rs │ │ │ ├── cast.rs │ │ │ ├── mod.rs │ │ │ └── scalar.rs │ │ ├── lib.rs │ │ └── tensor/ │ │ ├── alias.rs │ │ ├── container.rs │ │ ├── kind.rs │ │ ├── mod.rs │ │ ├── ops/ │ │ │ ├── autodiff.rs │ │ │ ├── base.rs │ │ │ ├── bool.rs │ │ │ ├── float.rs │ │ │ ├── int.rs │ │ │ ├── mod.rs │ │ │ ├── numeric.rs │ │ │ └── ordered.rs │ │ └── quantization/ │ │ ├── calibration.rs │ │ ├── mod.rs │ │ ├── parameters.rs │ │ └── scheme.rs │ ├── burn-backend-tests/ │ │ ├── .cargo/ │ │ │ └── config.toml │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── src/ │ │ │ └── lib.rs │ │ └── tests/ │ │ ├── autodiff/ │ │ │ ├── abs.rs │ │ │ ├── adaptive_avgpool1d.rs │ │ │ ├── adaptive_avgpool2d.rs │ │ │ ├── add.rs │ │ │ ├── aggregation.rs │ │ │ ├── avgpool1d.rs │ │ │ ├── avgpool2d.rs │ │ │ ├── backward.rs │ │ │ ├── bridge.rs │ │ │ ├── broadcast.rs │ │ │ ├── cast.rs │ │ │ ├── cat.rs │ │ │ ├── ceil.rs │ │ │ ├── checkpoint.rs │ │ │ ├── complex.rs │ │ │ ├── conv1d.rs │ │ │ ├── conv2d.rs │ │ │ ├── conv3d.rs │ │ │ ├── conv_transpose1d.rs │ │ │ ├── conv_transpose2d.rs │ │ │ ├── conv_transpose3d.rs │ │ │ ├── cross.rs │ │ │ ├── cross_entropy.rs │ │ │ ├── cummax.rs │ │ │ ├── cummin.rs │ │ │ ├── cumprod.rs │ │ │ ├── cumsum.rs │ │ │ ├── deform_conv2d.rs │ │ │ ├── div.rs │ │ │ ├── erf.rs │ │ │ ├── exp.rs │ │ │ ├── expand.rs │ │ │ ├── flip.rs │ │ │ ├── floor.rs │ │ │ ├── gather_scatter.rs │ │ │ ├── gelu.rs │ │ │ ├── gradients.rs │ │ │ ├── log.rs │ │ │ ├── log1p.rs │ │ │ ├── log_sigmoid.rs │ │ │ ├── mask.rs │ │ │ ├── matmul.rs │ │ │ ├── maxmin.rs │ │ │ ├── maxpool1d.rs │ │ │ ├── maxpool2d.rs │ │ │ ├── memory_management.rs │ │ │ ├── mod.rs │ │ │ ├── mul.rs │ │ │ ├── multithread.rs │ │ │ ├── nearest_interpolate.rs │ │ │ ├── neg.rs │ │ │ ├── nonzero.rs │ │ │ ├── permute.rs │ │ │ ├── pow.rs │ │ │ ├── recip.rs │ │ │ ├── relu.rs │ │ │ ├── remainder.rs │ │ │ ├── repeat_dim.rs │ │ │ ├── reshape.rs │ │ │ ├── round.rs │ │ │ ├── select.rs │ │ │ ├── sigmoid.rs │ │ │ ├── sign.rs │ │ │ ├── slice.rs │ │ │ ├── slice_assign.rs │ │ │ ├── softmax.rs │ │ │ ├── sort.rs │ │ │ ├── sqrt.rs │ │ │ ├── sub.rs │ │ │ ├── transpose.rs │ │ │ ├── trig.rs │ │ │ └── unfold.rs │ │ ├── autodiff.rs │ │ ├── common/ │ │ │ ├── autodiff.rs │ │ │ ├── backend.rs │ │ │ └── tensor.rs │ │ ├── cubecl/ │ │ │ ├── avg_pool2d.rs │ │ │ ├── bernoulli.rs │ │ │ ├── cast.rs │ │ │ ├── cat.rs │ │ │ ├── clamp.rs │ │ │ ├── contiguous.rs │ │ │ ├── conv2d.rs │ │ │ ├── conv3d.rs │ │ │ ├── conv_transpose2d.rs │ │ │ ├── conv_transpose3d.rs │ │ │ ├── cross.rs │ │ │ ├── gather.rs │ │ │ ├── mask_fill.rs │ │ │ ├── mask_where.rs │ │ │ ├── max_pool2d.rs │ │ │ ├── max_pool2d_backward.rs │ │ │ ├── mod.rs │ │ │ ├── normal.rs │ │ │ ├── quantization.rs │ │ │ ├── reduce.rs │ │ │ ├── repeat_dim.rs │ │ │ ├── scatter.rs │ │ │ ├── select.rs │ │ │ ├── select_assign.rs │ │ │ ├── slice.rs │ │ │ ├── slice_assign.rs │ │ │ ├── unary.rs │ │ │ └── uniform.rs │ │ ├── cubecl.rs │ │ ├── fused_ops/ │ │ │ ├── mod.rs │ │ │ └── reduce_broadcasted.rs │ │ ├── fusion.rs │ │ ├── tensor/ │ │ │ ├── bool/ │ │ │ │ ├── mod.rs │ │ │ │ └── ops/ │ │ │ │ ├── all.rs │ │ │ │ ├── any.rs │ │ │ │ ├── argwhere_nonzero.rs │ │ │ │ ├── cat.rs │ │ │ │ ├── comparison.rs │ │ │ │ ├── create_like.rs │ │ │ │ ├── expand.rs │ │ │ │ ├── flip.rs │ │ │ │ ├── full.rs │ │ │ │ ├── gather_scatter.rs │ │ │ │ ├── init.rs │ │ │ │ ├── logical.rs │ │ │ │ ├── mask.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── movedim.rs │ │ │ │ ├── permute.rs │ │ │ │ ├── repeat.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── reshape.rs │ │ │ │ ├── select.rs │ │ │ │ ├── stack.rs │ │ │ │ ├── take.rs │ │ │ │ ├── transpose.rs │ │ │ │ ├── tri_mask.rs │ │ │ │ └── unfold.rs │ │ │ ├── clone_invariance.rs │ │ │ ├── float/ │ │ │ │ ├── activation/ │ │ │ │ │ ├── celu.rs │ │ │ │ │ ├── elu.rs │ │ │ │ │ ├── gelu.rs │ │ │ │ │ ├── glu.rs │ │ │ │ │ ├── hard_sigmoid.rs │ │ │ │ │ ├── leaky_relu.rs │ │ │ │ │ ├── log_sigmoid.rs │ │ │ │ │ ├── mish.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── prelu.rs │ │ │ │ │ ├── quiet_softmax.rs │ │ │ │ │ ├── relu.rs │ │ │ │ │ ├── selu.rs │ │ │ │ │ ├── sigmoid.rs │ │ │ │ │ ├── silu.rs │ │ │ │ │ ├── softmax.rs │ │ │ │ │ ├── softmin.rs │ │ │ │ │ ├── softplus.rs │ │ │ │ │ ├── softsign.rs │ │ │ │ │ ├── tanh_activation.rs │ │ │ │ │ └── thresholded_relu.rs │ │ │ │ ├── grid/ │ │ │ │ │ ├── affine_grid.rs │ │ │ │ │ ├── meshgrid.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── linalg/ │ │ │ │ │ ├── cosine_similarity.rs │ │ │ │ │ ├── diag.rs │ │ │ │ │ ├── lu_decomposition.rs │ │ │ │ │ ├── matvec.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── outer.rs │ │ │ │ │ ├── trace.rs │ │ │ │ │ └── vector_norm.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── module/ │ │ │ │ │ ├── adaptive_avgpool1d.rs │ │ │ │ │ ├── adaptive_avgpool2d.rs │ │ │ │ │ ├── attention.rs │ │ │ │ │ ├── avgpool1d.rs │ │ │ │ │ ├── avgpool2d.rs │ │ │ │ │ ├── bicubic_interpolate.rs │ │ │ │ │ ├── bilinear_interpolate.rs │ │ │ │ │ ├── conv1d.rs │ │ │ │ │ ├── conv2d.rs │ │ │ │ │ ├── conv3d.rs │ │ │ │ │ ├── conv_transpose1d.rs │ │ │ │ │ ├── conv_transpose2d.rs │ │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ │ ├── deform_conv2d.rs │ │ │ │ │ ├── forward.rs │ │ │ │ │ ├── lanczos3_interpolate.rs │ │ │ │ │ ├── linear.rs │ │ │ │ │ ├── maxpool1d.rs │ │ │ │ │ ├── maxpool2d.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── nearest_interpolate.rs │ │ │ │ │ └── unfold4d.rs │ │ │ │ ├── ops/ │ │ │ │ │ ├── abs.rs │ │ │ │ │ ├── add.rs │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ ├── all.rs │ │ │ │ │ ├── any.rs │ │ │ │ │ ├── arg.rs │ │ │ │ │ ├── cast.rs │ │ │ │ │ ├── cat.rs │ │ │ │ │ ├── ceil.rs │ │ │ │ │ ├── chunk.rs │ │ │ │ │ ├── clamp.rs │ │ │ │ │ ├── close.rs │ │ │ │ │ ├── comparison.rs │ │ │ │ │ ├── create_like.rs │ │ │ │ │ ├── cross.rs │ │ │ │ │ ├── cumulative.rs │ │ │ │ │ ├── div.rs │ │ │ │ │ ├── dot.rs │ │ │ │ │ ├── erf.rs │ │ │ │ │ ├── exp.rs │ │ │ │ │ ├── expand.rs │ │ │ │ │ ├── finite.rs │ │ │ │ │ ├── flatten.rs │ │ │ │ │ ├── flip.rs │ │ │ │ │ ├── floor.rs │ │ │ │ │ ├── fmod.rs │ │ │ │ │ ├── full.rs │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ ├── grid_sample.rs │ │ │ │ │ ├── inf.rs │ │ │ │ │ ├── init.rs │ │ │ │ │ ├── iter_dim.rs │ │ │ │ │ ├── log.rs │ │ │ │ │ ├── log1p.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── matmul.rs │ │ │ │ │ ├── maxmin.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── movedim.rs │ │ │ │ │ ├── mul.rs │ │ │ │ │ ├── nan.rs │ │ │ │ │ ├── narrow.rs │ │ │ │ │ ├── neg.rs │ │ │ │ │ ├── one_hot.rs │ │ │ │ │ ├── padding.rs │ │ │ │ │ ├── permute.rs │ │ │ │ │ ├── powf.rs │ │ │ │ │ ├── powf_scalar.rs │ │ │ │ │ ├── prod.rs │ │ │ │ │ ├── random.rs │ │ │ │ │ ├── recip.rs │ │ │ │ │ ├── remainder.rs │ │ │ │ │ ├── repeat.rs │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ ├── reshape.rs │ │ │ │ │ ├── round.rs │ │ │ │ │ ├── select.rs │ │ │ │ │ ├── sign.rs │ │ │ │ │ ├── slice.rs │ │ │ │ │ ├── slice_assign.rs │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ ├── split.rs │ │ │ │ │ ├── sqrt.rs │ │ │ │ │ ├── square.rs │ │ │ │ │ ├── squeeze.rs │ │ │ │ │ ├── stack.rs │ │ │ │ │ ├── sub.rs │ │ │ │ │ ├── take.rs │ │ │ │ │ ├── topk.rs │ │ │ │ │ ├── transaction.rs │ │ │ │ │ ├── transpose.rs │ │ │ │ │ ├── tri.rs │ │ │ │ │ ├── trig.rs │ │ │ │ │ ├── trunc.rs │ │ │ │ │ └── unfold.rs │ │ │ │ ├── primitive.rs │ │ │ │ ├── quantization/ │ │ │ │ │ ├── calibration.rs │ │ │ │ │ ├── data.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── ops/ │ │ │ │ │ │ ├── extended/ │ │ │ │ │ │ │ ├── abs.rs │ │ │ │ │ │ │ ├── add.rs │ │ │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ │ │ ├── all.rs │ │ │ │ │ │ │ ├── any.rs │ │ │ │ │ │ │ ├── arg.rs │ │ │ │ │ │ │ ├── cat.rs │ │ │ │ │ │ │ ├── ceil.rs │ │ │ │ │ │ │ ├── chunk.rs │ │ │ │ │ │ │ ├── clamp.rs │ │ │ │ │ │ │ ├── cos.rs │ │ │ │ │ │ │ ├── cosh.rs │ │ │ │ │ │ │ ├── div.rs │ │ │ │ │ │ │ ├── erf.rs │ │ │ │ │ │ │ ├── exp.rs │ │ │ │ │ │ │ ├── expand.rs │ │ │ │ │ │ │ ├── flip.rs │ │ │ │ │ │ │ ├── floor.rs │ │ │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ │ │ ├── log.rs │ │ │ │ │ │ │ ├── log1p.rs │ │ │ │ │ │ │ ├── map_comparison.rs │ │ │ │ │ │ │ ├── mask.rs │ │ │ │ │ │ │ ├── maxmin.rs │ │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ │ ├── mul.rs │ │ │ │ │ │ │ ├── narrow.rs │ │ │ │ │ │ │ ├── neg.rs │ │ │ │ │ │ │ ├── permute.rs │ │ │ │ │ │ │ ├── powf.rs │ │ │ │ │ │ │ ├── powf_scalar.rs │ │ │ │ │ │ │ ├── recip.rs │ │ │ │ │ │ │ ├── remainder.rs │ │ │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ │ │ ├── reshape.rs │ │ │ │ │ │ │ ├── round.rs │ │ │ │ │ │ │ ├── select.rs │ │ │ │ │ │ │ ├── sin.rs │ │ │ │ │ │ │ ├── sinh.rs │ │ │ │ │ │ │ ├── slice.rs │ │ │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ │ │ ├── split.rs │ │ │ │ │ │ │ ├── sqrt.rs │ │ │ │ │ │ │ ├── stack.rs │ │ │ │ │ │ │ ├── sub.rs │ │ │ │ │ │ │ ├── tan.rs │ │ │ │ │ │ │ ├── tanh.rs │ │ │ │ │ │ │ ├── topk.rs │ │ │ │ │ │ │ └── transpose.rs │ │ │ │ │ │ ├── matmul.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ └── quantize.rs │ │ │ │ │ └── scheme.rs │ │ │ │ └── stats/ │ │ │ │ ├── cov.rs │ │ │ │ ├── display.rs │ │ │ │ ├── eye.rs │ │ │ │ ├── median.rs │ │ │ │ ├── mod.rs │ │ │ │ └── var.rs │ │ │ ├── int/ │ │ │ │ ├── mod.rs │ │ │ │ ├── ops/ │ │ │ │ │ ├── abs.rs │ │ │ │ │ ├── add.rs │ │ │ │ │ ├── aggregation.rs │ │ │ │ │ ├── all.rs │ │ │ │ │ ├── any.rs │ │ │ │ │ ├── arange.rs │ │ │ │ │ ├── arange_step.rs │ │ │ │ │ ├── arg.rs │ │ │ │ │ ├── bitwise.rs │ │ │ │ │ ├── cartesian_grid.rs │ │ │ │ │ ├── cast.rs │ │ │ │ │ ├── cat.rs │ │ │ │ │ ├── chunk.rs │ │ │ │ │ ├── comparison.rs │ │ │ │ │ ├── create_like.rs │ │ │ │ │ ├── cumulative.rs │ │ │ │ │ ├── div.rs │ │ │ │ │ ├── expand.rs │ │ │ │ │ ├── flip.rs │ │ │ │ │ ├── full.rs │ │ │ │ │ ├── gather_scatter.rs │ │ │ │ │ ├── init.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── matmul.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── movedim.rs │ │ │ │ │ ├── mul.rs │ │ │ │ │ ├── one_hot.rs │ │ │ │ │ ├── permute.rs │ │ │ │ │ ├── random.rs │ │ │ │ │ ├── remainder.rs │ │ │ │ │ ├── repeat.rs │ │ │ │ │ ├── repeat_dim.rs │ │ │ │ │ ├── reshape.rs │ │ │ │ │ ├── roll.rs │ │ │ │ │ ├── select.rs │ │ │ │ │ ├── sign.rs │ │ │ │ │ ├── slice.rs │ │ │ │ │ ├── slice_assign.rs │ │ │ │ │ ├── sort_argsort.rs │ │ │ │ │ ├── stack.rs │ │ │ │ │ ├── sub.rs │ │ │ │ │ ├── take.rs │ │ │ │ │ ├── topk.rs │ │ │ │ │ ├── transpose.rs │ │ │ │ │ ├── tri.rs │ │ │ │ │ └── unfold.rs │ │ │ │ └── primitive.rs │ │ │ ├── mod.rs │ │ │ └── multi_threads.rs │ │ └── tensor.rs │ ├── burn-candle/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── candle_utils.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── utils.rs │ │ └── tensor.rs │ ├── burn-collective/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── multinode-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── README.md │ │ │ └── src/ │ │ │ ├── bin/ │ │ │ │ ├── global.rs │ │ │ │ ├── node.rs │ │ │ │ └── test_launcher.rs │ │ │ ├── lib.rs │ │ │ └── shared.rs │ │ └── src/ │ │ ├── api.rs │ │ ├── config.rs │ │ ├── global/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── node/ │ │ │ │ ├── base.rs │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ring.rs │ │ │ │ ├── sync.rs │ │ │ │ ├── tree.rs │ │ │ │ └── worker.rs │ │ │ ├── orchestrator/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── state.rs │ │ │ └── shared.rs │ │ ├── lib.rs │ │ ├── local/ │ │ │ ├── all_reduce/ │ │ │ │ ├── base.rs │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ ├── ring.rs │ │ │ │ └── tree.rs │ │ │ ├── broadcast/ │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ └── tree.rs │ │ │ ├── client.rs │ │ │ ├── mod.rs │ │ │ ├── reduce/ │ │ │ │ ├── centralized.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── op.rs │ │ │ │ └── tree.rs │ │ │ ├── server.rs │ │ │ └── tensor_map.rs │ │ └── tests/ │ │ ├── all_reduce.rs │ │ ├── broadcast.rs │ │ ├── mod.rs │ │ └── reduce.rs │ ├── burn-communication/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── base.rs │ │ ├── data_service.rs │ │ ├── lib.rs │ │ ├── util.rs │ │ └── websocket/ │ │ ├── base.rs │ │ ├── client.rs │ │ ├── mod.rs │ │ └── server.rs │ ├── burn-core/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── config.rs │ │ │ ├── data/ │ │ │ │ ├── dataloader/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── batch.rs │ │ │ │ │ ├── batcher.rs │ │ │ │ │ ├── builder.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── multithread.rs │ │ │ │ │ ├── split.rs │ │ │ │ │ └── strategy.rs │ │ │ │ └── mod.rs │ │ │ ├── lib.rs │ │ │ ├── module/ │ │ │ │ ├── base.rs │ │ │ │ ├── display.rs │ │ │ │ ├── initializer.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── param/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── constant.rs │ │ │ │ │ ├── id.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── primitive.rs │ │ │ │ │ ├── running.rs │ │ │ │ │ ├── tensor.rs │ │ │ │ │ └── visitor.rs │ │ │ │ ├── quantize.rs │ │ │ │ └── reinit.rs │ │ │ ├── record/ │ │ │ │ ├── base.rs │ │ │ │ ├── file.rs │ │ │ │ ├── memory.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── primitive.rs │ │ │ │ ├── recorder.rs │ │ │ │ ├── serde/ │ │ │ │ │ ├── adapter.rs │ │ │ │ │ ├── data.rs │ │ │ │ │ ├── de.rs │ │ │ │ │ ├── error.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── ser.rs │ │ │ │ ├── settings.rs │ │ │ │ └── tensor.rs │ │ │ ├── tensor.rs │ │ │ └── vision.rs │ │ └── tests/ │ │ ├── test_derive_config.rs │ │ ├── test_derive_module.rs │ │ ├── test_derive_record.rs │ │ └── test_record_resilience.rs │ ├── burn-cpu/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-cubecl/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── fusion.rs │ │ ├── kernel/ │ │ │ ├── attention/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune.rs │ │ │ ├── binary.rs │ │ │ ├── binary_float.rs │ │ │ ├── binary_int.rs │ │ │ ├── cast/ │ │ │ │ ├── base.rs │ │ │ │ ├── bool_cast.rs │ │ │ │ └── mod.rs │ │ │ ├── clamp.rs │ │ │ ├── comparison.rs │ │ │ ├── contiguous.rs │ │ │ ├── conv/ │ │ │ │ ├── backward_data/ │ │ │ │ │ ├── fallback.rs │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── backward_weight/ │ │ │ │ │ ├── fallback.rs │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── base.rs │ │ │ │ ├── conv_transpose2d/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── col2im.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── transpose_direct.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ ├── deform_conv2d.rs │ │ │ │ ├── deform_conv_transpose2d.rs │ │ │ │ ├── direct.rs │ │ │ │ ├── forward/ │ │ │ │ │ ├── implicit_gemm/ │ │ │ │ │ │ ├── launch.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── tune.rs │ │ │ │ ├── im2col.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune_key.rs │ │ │ ├── cross.rs │ │ │ ├── grid_sample/ │ │ │ │ ├── base.rs │ │ │ │ ├── bilinear.rs │ │ │ │ └── mod.rs │ │ │ ├── index/ │ │ │ │ ├── flip.rs │ │ │ │ ├── gather.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── repeat_dim.rs │ │ │ │ ├── scatter.rs │ │ │ │ ├── select.rs │ │ │ │ ├── select_assign.rs │ │ │ │ ├── slice.rs │ │ │ │ └── slice_assign.rs │ │ │ ├── interpolate/ │ │ │ │ ├── base.rs │ │ │ │ ├── bicubic.rs │ │ │ │ ├── bilinear.rs │ │ │ │ ├── lanczos3.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── nearest.rs │ │ │ │ └── nearest_backward.rs │ │ │ ├── mask/ │ │ │ │ ├── base.rs │ │ │ │ ├── mask_fill.rs │ │ │ │ ├── mask_where.rs │ │ │ │ └── mod.rs │ │ │ ├── matmul/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── tune/ │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ └── utils.rs │ │ │ ├── mod.rs │ │ │ ├── pool/ │ │ │ │ ├── adaptive_avg_pool2d.rs │ │ │ │ ├── adaptive_avg_pool2d_backward.rs │ │ │ │ ├── avg_pool2d.rs │ │ │ │ ├── avg_pool2d_backward.rs │ │ │ │ ├── max_pool2d.rs │ │ │ │ ├── max_pool2d_backward.rs │ │ │ │ ├── mod.rs │ │ │ │ └── pool2d.rs │ │ │ ├── prng/ │ │ │ │ ├── bernoulli.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── normal.rs │ │ │ │ └── uniform.rs │ │ │ ├── quantization/ │ │ │ │ ├── dequantize.rs │ │ │ │ ├── mod.rs │ │ │ │ └── quantize.rs │ │ │ ├── reduce/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tune.rs │ │ │ ├── unary_float.rs │ │ │ ├── unary_int.rs │ │ │ ├── unary_numeric.rs │ │ │ └── utils.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── numeric.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── template/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── source.rs │ │ ├── tensor/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── quantization.rs │ │ └── tune_key.rs │ ├── burn-cubecl-fusion/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── base.rs │ │ ├── engine/ │ │ │ ├── codegen/ │ │ │ │ ├── base.rs │ │ │ │ ├── io.rs │ │ │ │ ├── ir.rs │ │ │ │ ├── kernel.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── tensor.rs │ │ │ │ └── view.rs │ │ │ ├── fuser.rs │ │ │ ├── launch/ │ │ │ │ ├── base.rs │ │ │ │ ├── executor.rs │ │ │ │ ├── input.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── output.rs │ │ │ │ ├── plan.rs │ │ │ │ ├── runner.rs │ │ │ │ └── vectorization/ │ │ │ │ ├── base.rs │ │ │ │ ├── mod.rs │ │ │ │ └── planner.rs │ │ │ ├── mod.rs │ │ │ ├── scoring.rs │ │ │ ├── settings.rs │ │ │ └── trace/ │ │ │ ├── base.rs │ │ │ ├── block.rs │ │ │ ├── fuser.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── optim/ │ │ │ ├── base.rs │ │ │ ├── elemwise/ │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ └── optimization.rs │ │ │ ├── matmul/ │ │ │ │ ├── args.rs │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── optimization.rs │ │ │ │ └── tune.rs │ │ │ ├── mod.rs │ │ │ ├── reduce/ │ │ │ │ ├── args.rs │ │ │ │ ├── fuser.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── optimization.rs │ │ │ │ └── tune.rs │ │ │ └── reduce_broadcasted/ │ │ │ ├── fuser/ │ │ │ │ ├── base.rs │ │ │ │ ├── block.rs │ │ │ │ ├── full.rs │ │ │ │ ├── full_analyzer.rs │ │ │ │ └── mod.rs │ │ │ ├── launch.rs │ │ │ ├── mod.rs │ │ │ ├── optimization.rs │ │ │ ├── tune.rs │ │ │ └── unit.rs │ │ └── tune.rs │ ├── burn-cuda/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-dataset/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── hf_dataset.rs │ │ │ └── speech_commands.rs │ │ ├── src/ │ │ │ ├── audio/ │ │ │ │ ├── mod.rs │ │ │ │ └── speech_commands.rs │ │ │ ├── dataset/ │ │ │ │ ├── base.rs │ │ │ │ ├── dataframe.rs │ │ │ │ ├── fake.rs │ │ │ │ ├── in_memory.rs │ │ │ │ ├── iterator.rs │ │ │ │ ├── mod.rs │ │ │ │ └── sqlite.rs │ │ │ ├── lib.rs │ │ │ ├── nlp/ │ │ │ │ ├── ag_news.rs │ │ │ │ ├── mod.rs │ │ │ │ └── text_folder.rs │ │ │ ├── source/ │ │ │ │ ├── huggingface/ │ │ │ │ │ ├── downloader.rs │ │ │ │ │ ├── importer.py │ │ │ │ │ └── mod.rs │ │ │ │ └── mod.rs │ │ │ ├── transform/ │ │ │ │ ├── composed.rs │ │ │ │ ├── mapper.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── options.rs │ │ │ │ ├── partial.rs │ │ │ │ ├── sampler.rs │ │ │ │ ├── selection.rs │ │ │ │ ├── shuffle.rs │ │ │ │ └── window.rs │ │ │ └── vision/ │ │ │ ├── cifar.rs │ │ │ ├── image_folder.rs │ │ │ ├── mnist.rs │ │ │ └── mod.rs │ │ └── tests/ │ │ └── data/ │ │ ├── dataset-fmt.csv │ │ ├── dataset.csv │ │ ├── dataset.json │ │ ├── dataset_coco.json │ │ ├── segmask_folder/ │ │ │ └── annotations/ │ │ │ ├── mask_checkerboard.txt │ │ │ ├── mask_random_2colors.txt │ │ │ └── mask_random_3colors.txt │ │ └── text_folder/ │ │ ├── negative/ │ │ │ ├── sample1.txt │ │ │ └── sample2.txt │ │ └── positive/ │ │ ├── sample1.txt │ │ └── sample2.txt │ ├── burn-derive/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── config/ │ │ │ ├── analyzer.rs │ │ │ ├── analyzer_enum.rs │ │ │ ├── analyzer_struct.rs │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── module/ │ │ │ ├── base.rs │ │ │ ├── codegen.rs │ │ │ ├── codegen_enum.rs │ │ │ ├── codegen_struct.rs │ │ │ ├── display.rs │ │ │ ├── generics.rs │ │ │ ├── mod.rs │ │ │ ├── record.rs │ │ │ ├── record_enum.rs │ │ │ └── record_struct.rs │ │ ├── record/ │ │ │ ├── base.rs │ │ │ ├── codegen.rs │ │ │ ├── item/ │ │ │ │ ├── codegen.rs │ │ │ │ ├── codegen_enum.rs │ │ │ │ ├── codegen_struct.rs │ │ │ │ └── mod.rs │ │ │ └── mod.rs │ │ └── shared/ │ │ ├── attribute.rs │ │ ├── enum_variant.rs │ │ ├── field.rs │ │ ├── generics.rs │ │ └── mod.rs │ ├── burn-dispatch/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── device.rs │ │ ├── lib.rs │ │ ├── macros.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ └── tensor.rs │ ├── burn-fusion/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── client.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── binary.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── unary.rs │ │ ├── search/ │ │ │ ├── block.rs │ │ │ ├── merging.rs │ │ │ ├── mod.rs │ │ │ └── optimization/ │ │ │ ├── blocks.rs │ │ │ ├── mod.rs │ │ │ └── stream.rs │ │ ├── server.rs │ │ ├── stream/ │ │ │ ├── base.rs │ │ │ ├── context.rs │ │ │ ├── execution/ │ │ │ │ ├── base.rs │ │ │ │ ├── explorer.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ordering.rs │ │ │ │ ├── policy.rs │ │ │ │ ├── processor.rs │ │ │ │ ├── tests.rs │ │ │ │ └── validator.rs │ │ │ ├── memory_checks.rs │ │ │ ├── mod.rs │ │ │ ├── multi.rs │ │ │ ├── queue/ │ │ │ │ ├── base.rs │ │ │ │ ├── execution.rs │ │ │ │ └── mod.rs │ │ │ ├── shared_tensors.rs │ │ │ └── store/ │ │ │ ├── base.rs │ │ │ ├── index.rs │ │ │ └── mod.rs │ │ └── tensor.rs │ ├── burn-ir/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── builder.rs │ │ ├── handle.rs │ │ ├── lib.rs │ │ ├── operation.rs │ │ ├── scalar.rs │ │ └── tensor.rs │ ├── burn-ndarray/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── adaptive_avgpool.rs │ │ │ ├── avgpool.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── conv.rs │ │ │ ├── deform_conv.rs │ │ │ ├── grid_sample.rs │ │ │ ├── int_tensor.rs │ │ │ ├── interpolate.rs │ │ │ ├── macros.rs │ │ │ ├── matmul.rs │ │ │ ├── maxpool.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── padding.rs │ │ │ ├── qtensor.rs │ │ │ ├── quantization.rs │ │ │ ├── simd/ │ │ │ │ ├── avgpool.rs │ │ │ │ ├── base.rs │ │ │ │ ├── binary.rs │ │ │ │ ├── binary_elemwise.rs │ │ │ │ ├── cmp.rs │ │ │ │ ├── conv.rs │ │ │ │ ├── maxpool.rs │ │ │ │ ├── mod.rs │ │ │ │ └── unary.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ ├── parallel.rs │ │ ├── rand.rs │ │ ├── sharing.rs │ │ ├── storage.rs │ │ └── tensor.rs │ ├── burn-nn/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── activation/ │ │ │ │ ├── activation_wrapper.rs │ │ │ │ ├── celu.rs │ │ │ │ ├── elu.rs │ │ │ │ ├── gelu.rs │ │ │ │ ├── glu.rs │ │ │ │ ├── hard_shrink.rs │ │ │ │ ├── hard_sigmoid.rs │ │ │ │ ├── hard_swish.rs │ │ │ │ ├── leaky_relu.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── prelu.rs │ │ │ │ ├── relu.rs │ │ │ │ ├── selu.rs │ │ │ │ ├── shrink.rs │ │ │ │ ├── sigmoid.rs │ │ │ │ ├── soft_shrink.rs │ │ │ │ ├── softplus.rs │ │ │ │ ├── softsign.rs │ │ │ │ ├── swiglu.rs │ │ │ │ ├── tanh.rs │ │ │ │ └── thresholded_relu.rs │ │ │ ├── lib.rs │ │ │ ├── loss/ │ │ │ │ ├── binary_cross_entropy.rs │ │ │ │ ├── cosine_embedding.rs │ │ │ │ ├── cross_entropy.rs │ │ │ │ ├── ctc.rs │ │ │ │ ├── huber.rs │ │ │ │ ├── kldiv.rs │ │ │ │ ├── lp_loss.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── mse.rs │ │ │ │ ├── poisson.rs │ │ │ │ ├── pretrained/ │ │ │ │ │ ├── gram_matrix/ │ │ │ │ │ │ ├── gram_matrix_loss.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ ├── vgg19.rs │ │ │ │ │ │ └── weights.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── reduction.rs │ │ │ │ ├── rnnt.rs │ │ │ │ └── smooth_l1.rs │ │ │ ├── modules/ │ │ │ │ ├── attention/ │ │ │ │ │ ├── cross_attention.rs │ │ │ │ │ ├── mask.rs │ │ │ │ │ ├── mha.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── cache/ │ │ │ │ │ ├── autoregressive.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── conv/ │ │ │ │ │ ├── checks.rs │ │ │ │ │ ├── conv1d.rs │ │ │ │ │ ├── conv2d.rs │ │ │ │ │ ├── conv3d.rs │ │ │ │ │ ├── conv_transpose1d.rs │ │ │ │ │ ├── conv_transpose2d.rs │ │ │ │ │ ├── conv_transpose3d.rs │ │ │ │ │ ├── deform_conv2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── dropout.rs │ │ │ │ ├── embedding.rs │ │ │ │ ├── interpolate/ │ │ │ │ │ ├── interpolate1d.rs │ │ │ │ │ ├── interpolate2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── linear.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── noise.rs │ │ │ │ ├── norm/ │ │ │ │ │ ├── batch.rs │ │ │ │ │ ├── group.rs │ │ │ │ │ ├── instance.rs │ │ │ │ │ ├── layer.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── normalization_wrapper.rs │ │ │ │ │ └── rms.rs │ │ │ │ ├── pool/ │ │ │ │ │ ├── adaptive_avg_pool1d.rs │ │ │ │ │ ├── adaptive_avg_pool2d.rs │ │ │ │ │ ├── avg_pool1d.rs │ │ │ │ │ ├── avg_pool2d.rs │ │ │ │ │ ├── max_pool1d.rs │ │ │ │ │ ├── max_pool2d.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── pos_encoding.rs │ │ │ │ ├── rnn/ │ │ │ │ │ ├── basic.rs │ │ │ │ │ ├── gate_controller.rs │ │ │ │ │ ├── gru.rs │ │ │ │ │ ├── lstm.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── rope_encoding.rs │ │ │ │ ├── transformer/ │ │ │ │ │ ├── decoder.rs │ │ │ │ │ ├── encoder.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── pwff.rs │ │ │ │ └── unfold.rs │ │ │ └── padding.rs │ │ └── tests/ │ │ └── quantize.rs │ ├── burn-no-std-tests/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── burnpack.rs │ │ │ ├── conv.rs │ │ │ ├── lib.rs │ │ │ ├── mlp.rs │ │ │ ├── model.rs │ │ │ └── safetensors.rs │ │ └── tests/ │ │ ├── burnpack_tests.rs │ │ ├── safetensors_tests.rs │ │ └── test_integration.rs │ ├── burn-optim/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── grad_clipping/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── lr_scheduler/ │ │ │ ├── base.rs │ │ │ ├── composed.rs │ │ │ ├── constant.rs │ │ │ ├── cosine.rs │ │ │ ├── exponential.rs │ │ │ ├── linear.rs │ │ │ ├── mod.rs │ │ │ ├── noam.rs │ │ │ └── step.rs │ │ └── optim/ │ │ ├── adagrad.rs │ │ ├── adam.rs │ │ ├── adamw.rs │ │ ├── base.rs │ │ ├── decay.rs │ │ ├── grad_accum.rs │ │ ├── grads.rs │ │ ├── lbfgs.rs │ │ ├── mod.rs │ │ ├── momentum.rs │ │ ├── muon.rs │ │ ├── rmsprop.rs │ │ ├── sgd.rs │ │ ├── simple/ │ │ │ ├── adaptor.rs │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── record/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ └── v1.rs │ │ └── visitor.rs │ ├── burn-remote/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── client/ │ │ │ ├── base.rs │ │ │ ├── channel.rs │ │ │ ├── mod.rs │ │ │ ├── runner.rs │ │ │ └── worker.rs │ │ ├── lib.rs │ │ ├── server/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── processor.rs │ │ │ ├── session.rs │ │ │ └── stream.rs │ │ └── shared/ │ │ ├── mod.rs │ │ └── task.rs │ ├── burn-rl/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── environment/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── policy/ │ │ │ ├── async_policy.rs │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ └── transition_buffer/ │ │ ├── base.rs │ │ ├── mod.rs │ │ └── slice_access.rs │ ├── burn-rocm/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-router/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── backend.rs │ │ ├── bridge/ │ │ │ ├── base.rs │ │ │ ├── byte.rs │ │ │ └── mod.rs │ │ ├── channel/ │ │ │ ├── base.rs │ │ │ ├── direct.rs │ │ │ └── mod.rs │ │ ├── client/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── binary.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ ├── transaction.rs │ │ │ └── unary.rs │ │ ├── runner.rs │ │ ├── tensor.rs │ │ └── types.rs │ ├── burn-std/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── id.rs │ │ ├── lib.rs │ │ ├── network.rs │ │ └── tensor/ │ │ ├── dtype.rs │ │ ├── mod.rs │ │ ├── quantization.rs │ │ ├── shape.rs │ │ └── slice.rs │ ├── burn-store/ │ │ ├── Cargo.toml │ │ ├── MIGRATION.md │ │ ├── README.md │ │ ├── benches/ │ │ │ ├── download_resnet18.py │ │ │ ├── generate_unified_models.py │ │ │ ├── resnet18_loading.rs │ │ │ ├── unified_loading.rs │ │ │ ├── unified_saving.rs │ │ │ └── zero_copy_loading.rs │ │ ├── examples/ │ │ │ ├── burnpack_inspect.rs │ │ │ └── half_precision.rs │ │ ├── pytorch-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── src/ │ │ │ │ └── lib.rs │ │ │ └── tests/ │ │ │ ├── backend.rs │ │ │ ├── batch_norm/ │ │ │ │ ├── batch_norm2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── boolean/ │ │ │ │ ├── boolean.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── buffer/ │ │ │ │ ├── buffer.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── complex_nested/ │ │ │ │ ├── complex_nested.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── config/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── mod.rs │ │ │ │ └── weights_with_config.pt │ │ │ ├── conv1d/ │ │ │ │ ├── conv1d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv2d/ │ │ │ │ ├── conv2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv_transpose1d/ │ │ │ │ ├── conv_transpose1d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── conv_transpose2d/ │ │ │ │ ├── conv_transpose2d.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── debug_test.pt │ │ │ ├── embedding/ │ │ │ │ ├── embedding.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── enum_module/ │ │ │ │ ├── enum_depthwise_false.pt │ │ │ │ ├── enum_depthwise_true.pt │ │ │ │ ├── export_weights.py │ │ │ │ └── mod.rs │ │ │ ├── group_norm/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── group_norm.pt │ │ │ │ └── mod.rs │ │ │ ├── integer/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── integer.pt │ │ │ │ └── mod.rs │ │ │ ├── key_remap/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── key_remap.pt │ │ │ │ └── mod.rs │ │ │ ├── key_remap_chained/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── key_remap.pt │ │ │ │ └── mod.rs │ │ │ ├── layer_norm/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── layer_norm.pt │ │ │ │ └── mod.rs │ │ │ ├── linear/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── linear.pt │ │ │ │ ├── linear_with_bias.pt │ │ │ │ └── mod.rs │ │ │ ├── missing_module_field/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── missing_module_field.pt │ │ │ │ └── mod.rs │ │ │ ├── non_contiguous_indexes/ │ │ │ │ ├── export_weights.py │ │ │ │ ├── mod.rs │ │ │ │ └── non_contiguous_indexes.pt │ │ │ ├── test_int.pt │ │ │ ├── test_mod.rs │ │ │ └── top_level_key/ │ │ │ ├── export_weights.py │ │ │ ├── mod.rs │ │ │ └── top_level_key.pt │ │ ├── safetensors-tests/ │ │ │ ├── Cargo.toml │ │ │ ├── src/ │ │ │ │ └── lib.rs │ │ │ └── tests/ │ │ │ ├── backend.rs │ │ │ ├── multi_layer/ │ │ │ │ ├── mod.rs │ │ │ │ ├── multi_layer.py │ │ │ │ └── multi_layer.safetensors │ │ │ └── test_mod.rs │ │ └── src/ │ │ ├── adapter.rs │ │ ├── applier.rs │ │ ├── apply_result.rs │ │ ├── burnpack/ │ │ │ ├── base.rs │ │ │ ├── mod.rs │ │ │ ├── reader.rs │ │ │ ├── store.rs │ │ │ ├── tests/ │ │ │ │ ├── alignment.rs │ │ │ │ ├── edge_cases.rs │ │ │ │ ├── header.rs │ │ │ │ ├── helpers.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── reader.rs │ │ │ │ ├── round_trip.rs │ │ │ │ ├── store.rs │ │ │ │ ├── writer.rs │ │ │ │ └── zero_copy.rs │ │ │ └── writer.rs │ │ ├── collector.rs │ │ ├── filter.rs │ │ ├── keyremapper.rs │ │ ├── lib.rs │ │ ├── pytorch/ │ │ │ ├── lazy_data.rs │ │ │ ├── mod.rs │ │ │ ├── pickle_reader.rs │ │ │ ├── reader.rs │ │ │ ├── store.rs │ │ │ └── tests/ │ │ │ ├── mod.rs │ │ │ ├── reader/ │ │ │ │ ├── create_legacy_with_offsets.py │ │ │ │ ├── create_tar_format.py │ │ │ │ ├── mod.rs │ │ │ │ ├── simple_legacy.py │ │ │ │ ├── test_data/ │ │ │ │ │ ├── bfloat16.pt │ │ │ │ │ ├── bool.pt │ │ │ │ │ ├── broken.pt │ │ │ │ │ ├── buffers.pt │ │ │ │ │ ├── checkpoint.pt │ │ │ │ │ ├── complex_structure.pt │ │ │ │ │ ├── empty.pt │ │ │ │ │ ├── extreme_values.pt │ │ │ │ │ ├── float16.pt │ │ │ │ │ ├── float32.pt │ │ │ │ │ ├── float64.pt │ │ │ │ │ ├── int16.pt │ │ │ │ │ ├── int32.pt │ │ │ │ │ ├── int64.pt │ │ │ │ │ ├── int8.pt │ │ │ │ │ ├── large_shape.pt │ │ │ │ │ ├── legacy_shared_storage.pt │ │ │ │ │ ├── legacy_with_offsets.pt │ │ │ │ │ ├── mixed_types.pt │ │ │ │ │ ├── nested_dict.pt │ │ │ │ │ ├── parameter.pt │ │ │ │ │ ├── scalar.pt │ │ │ │ │ ├── simple_legacy.pt │ │ │ │ │ ├── special_values.pt │ │ │ │ │ ├── state_dict.pt │ │ │ │ │ ├── tensor_2d.pt │ │ │ │ │ ├── tensor_3d.pt │ │ │ │ │ ├── tensor_4d.pt │ │ │ │ │ └── uint8.pt │ │ │ │ └── test_data.py │ │ │ └── store/ │ │ │ ├── mod.rs │ │ │ └── test_data/ │ │ │ ├── generate_enum_test.py │ │ │ └── model_without_enum_variants.pt │ │ ├── safetensors/ │ │ │ ├── mod.rs │ │ │ ├── store.rs │ │ │ └── tests/ │ │ │ ├── adapter.rs │ │ │ ├── direct_access.rs │ │ │ ├── error_handling.rs │ │ │ ├── file_io.rs │ │ │ ├── filtering.rs │ │ │ ├── integration.rs │ │ │ ├── metadata.rs │ │ │ ├── mixed_datatypes.rs │ │ │ ├── mod.rs │ │ │ ├── multi_layer_verify.rs │ │ │ ├── pytorch_import.rs │ │ │ └── round_trip.rs │ │ ├── tensor_snapshot.rs │ │ └── traits.rs │ ├── burn-tch/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ │ └── src/ │ │ ├── backend.rs │ │ ├── bin/ │ │ │ ├── cpu.rs │ │ │ ├── cuda.rs │ │ │ └── mps.rs │ │ ├── cuda_hack/ │ │ │ ├── dummy_cuda_dependency.cpp │ │ │ └── fake_cuda_dependency.cpp │ │ ├── element.rs │ │ ├── lib.rs │ │ ├── ops/ │ │ │ ├── activation.rs │ │ │ ├── base.rs │ │ │ ├── bool_tensor.rs │ │ │ ├── int_tensor.rs │ │ │ ├── mod.rs │ │ │ ├── module.rs │ │ │ ├── qtensor.rs │ │ │ ├── tensor.rs │ │ │ └── transaction.rs │ │ └── tensor.rs │ ├── burn-tensor/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── device.rs │ │ ├── lib.rs │ │ └── tensor/ │ │ ├── activation/ │ │ │ ├── base.rs │ │ │ └── mod.rs │ │ ├── api/ │ │ │ ├── autodiff.rs │ │ │ ├── base.rs │ │ │ ├── bool.rs │ │ │ ├── cartesian_grid.rs │ │ │ ├── check.rs │ │ │ ├── float.rs │ │ │ ├── fmod.rs │ │ │ ├── int.rs │ │ │ ├── mod.rs │ │ │ ├── numeric.rs │ │ │ ├── options.rs │ │ │ ├── orderable.rs │ │ │ ├── pad.rs │ │ │ ├── take.rs │ │ │ ├── transaction.rs │ │ │ └── trunc.rs │ │ ├── grid/ │ │ │ ├── affine_grid.rs │ │ │ ├── meshgrid.rs │ │ │ └── mod.rs │ │ ├── linalg/ │ │ │ ├── cosine_similarity.rs │ │ │ ├── diag.rs │ │ │ ├── lu_decomposition.rs │ │ │ ├── matvec.rs │ │ │ ├── mod.rs │ │ │ ├── outer.rs │ │ │ ├── trace.rs │ │ │ └── vector_norm.rs │ │ ├── loss/ │ │ │ └── mod.rs │ │ ├── mod.rs │ │ ├── module.rs │ │ ├── quantization.rs │ │ ├── report.rs │ │ └── stats/ │ │ └── mod.rs │ ├── burn-tensor-testgen/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ └── lib.rs │ ├── burn-train/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── checkpoint/ │ │ │ ├── async_checkpoint.rs │ │ │ ├── base.rs │ │ │ ├── file.rs │ │ │ ├── mod.rs │ │ │ └── strategy/ │ │ │ ├── base.rs │ │ │ ├── composed.rs │ │ │ ├── lastn.rs │ │ │ ├── metric.rs │ │ │ └── mod.rs │ │ ├── components.rs │ │ ├── evaluator/ │ │ │ ├── base.rs │ │ │ ├── builder.rs │ │ │ ├── components.rs │ │ │ └── mod.rs │ │ ├── learner/ │ │ │ ├── application_logger.rs │ │ │ ├── base.rs │ │ │ ├── classification.rs │ │ │ ├── early_stopping.rs │ │ │ ├── mod.rs │ │ │ ├── regression.rs │ │ │ ├── rl/ │ │ │ │ ├── checkpointer.rs │ │ │ │ ├── components.rs │ │ │ │ ├── env_runner/ │ │ │ │ │ ├── async_runner.rs │ │ │ │ │ ├── base.rs │ │ │ │ │ └── mod.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── off_policy.rs │ │ │ │ ├── output.rs │ │ │ │ ├── paradigm.rs │ │ │ │ └── strategy.rs │ │ │ ├── sequence.rs │ │ │ ├── summary.rs │ │ │ ├── supervised/ │ │ │ │ ├── mod.rs │ │ │ │ ├── paradigm.rs │ │ │ │ ├── step/ │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── train.rs │ │ │ │ └── strategies/ │ │ │ │ ├── base.rs │ │ │ │ ├── ddp/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── epoch.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── strategy.rs │ │ │ │ │ └── worker.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── multi/ │ │ │ │ │ ├── epoch.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── strategy.rs │ │ │ │ └── single/ │ │ │ │ ├── epoch.rs │ │ │ │ ├── mod.rs │ │ │ │ └── strategy.rs │ │ │ └── train_val.rs │ │ ├── lib.rs │ │ ├── logger/ │ │ │ ├── async_logger.rs │ │ │ ├── base.rs │ │ │ ├── file.rs │ │ │ ├── in_memory.rs │ │ │ ├── metric.rs │ │ │ └── mod.rs │ │ ├── metric/ │ │ │ ├── acc.rs │ │ │ ├── auroc.rs │ │ │ ├── base.rs │ │ │ ├── cer.rs │ │ │ ├── classification.rs │ │ │ ├── confusion_stats.rs │ │ │ ├── cpu_temp.rs │ │ │ ├── cpu_use.rs │ │ │ ├── cuda.rs │ │ │ ├── fbetascore.rs │ │ │ ├── hamming.rs │ │ │ ├── iteration.rs │ │ │ ├── learning_rate.rs │ │ │ ├── loss.rs │ │ │ ├── memory_use.rs │ │ │ ├── mod.rs │ │ │ ├── perplexity.rs │ │ │ ├── precision.rs │ │ │ ├── processor/ │ │ │ │ ├── async_wrapper.rs │ │ │ │ ├── base.rs │ │ │ │ ├── full.rs │ │ │ │ ├── metrics.rs │ │ │ │ ├── minimal.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── rl_metrics.rs │ │ │ │ └── rl_processor.rs │ │ │ ├── recall.rs │ │ │ ├── rl/ │ │ │ │ ├── cum_reward.rs │ │ │ │ ├── ep_len.rs │ │ │ │ ├── exploration_rate.rs │ │ │ │ └── mod.rs │ │ │ ├── state.rs │ │ │ ├── store/ │ │ │ │ ├── aggregate.rs │ │ │ │ ├── base.rs │ │ │ │ ├── client.rs │ │ │ │ ├── log.rs │ │ │ │ └── mod.rs │ │ │ ├── top_k_acc.rs │ │ │ ├── vision/ │ │ │ │ ├── dice.rs │ │ │ │ ├── dists/ │ │ │ │ │ ├── l2pool.rs │ │ │ │ │ ├── metric.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── vgg16_l2pool.rs │ │ │ │ │ └── weights.rs │ │ │ │ ├── lpips/ │ │ │ │ │ ├── alexnet.rs │ │ │ │ │ ├── metric.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── squeezenet.rs │ │ │ │ │ ├── vgg.rs │ │ │ │ │ └── weights.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── ms_ssim.rs │ │ │ │ ├── psnr.rs │ │ │ │ └── ssim.rs │ │ │ └── wer.rs │ │ └── renderer/ │ │ ├── base.rs │ │ ├── cli.rs │ │ ├── mod.rs │ │ └── tui/ │ │ ├── base.rs │ │ ├── controls.rs │ │ ├── full_history.rs │ │ ├── metric_numeric.rs │ │ ├── metric_text.rs │ │ ├── mod.rs │ │ ├── plot_utils.rs │ │ ├── popup.rs │ │ ├── progress.rs │ │ ├── recent_history.rs │ │ ├── renderer.rs │ │ └── status.rs │ ├── burn-vision/ │ │ ├── Cargo.toml │ │ ├── src/ │ │ │ ├── backends/ │ │ │ │ ├── cpu/ │ │ │ │ │ ├── base.rs │ │ │ │ │ ├── connected_components/ │ │ │ │ │ │ ├── spaghetti/ │ │ │ │ │ │ │ ├── Spaghetti_center_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_first_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_forest_labels.rs │ │ │ │ │ │ │ ├── Spaghetti_last_line_forest_code.rs │ │ │ │ │ │ │ ├── Spaghetti_single_line_forest_code.rs │ │ │ │ │ │ │ └── mod.rs │ │ │ │ │ │ └── spaghetti_4c/ │ │ │ │ │ │ ├── Spaghetti4C_center_line_forest_code.rs │ │ │ │ │ │ ├── Spaghetti4C_first_line_forest_code.rs │ │ │ │ │ │ ├── Spaghetti4C_forest_labels.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── connected_components.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ ├── morphology/ │ │ │ │ │ │ ├── filter.rs │ │ │ │ │ │ ├── filter_engine.rs │ │ │ │ │ │ └── mod.rs │ │ │ │ │ ├── nms.rs │ │ │ │ │ └── ops.rs │ │ │ │ ├── cube/ │ │ │ │ │ ├── connected_components/ │ │ │ │ │ │ ├── hardware_accelerated.rs │ │ │ │ │ │ ├── mod.rs │ │ │ │ │ │ └── prefix_sum.rs │ │ │ │ │ ├── mod.rs │ │ │ │ │ └── ops.rs │ │ │ │ └── mod.rs │ │ │ ├── base.rs │ │ │ ├── lib.rs │ │ │ ├── ops/ │ │ │ │ ├── base.rs │ │ │ │ └── mod.rs │ │ │ ├── tensor.rs │ │ │ ├── tests/ │ │ │ │ └── mod.rs │ │ │ ├── transform/ │ │ │ │ ├── mod.rs │ │ │ │ └── transform2d.rs │ │ │ └── utils/ │ │ │ ├── mod.rs │ │ │ └── save.rs │ │ └── tests/ │ │ ├── common/ │ │ │ └── mod.rs │ │ ├── connected_components.rs │ │ ├── morphology.rs │ │ └── nms.rs │ └── burn-wgpu/ │ ├── Cargo.toml │ ├── README.md │ └── src/ │ └── lib.rs ├── deny.toml ├── docs/ │ └── katex-header.html ├── examples/ │ ├── custom-csv-dataset/ │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── custom-csv-dataset.rs │ │ │ └── dataframe-dataset.rs │ │ └── src/ │ │ ├── dataframe_dataset.rs │ │ ├── dataset.rs │ │ ├── diabetes_patient.rs │ │ ├── lib.rs │ │ └── utils.rs │ ├── custom-cubecl-kernel/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-cubecl-kernel.rs │ │ └── src/ │ │ ├── backward.rs │ │ ├── forward.rs │ │ ├── kernel.rs │ │ └── lib.rs │ ├── custom-image-dataset/ │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── custom-image-dataset.rs │ │ └── src/ │ │ ├── data.rs │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── custom-learning-strategy/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-learning-strategy.rs │ │ └── src/ │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── custom-renderer/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-renderer.rs │ │ └── src/ │ │ └── lib.rs │ ├── custom-training-loop/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-training-loop.rs │ │ └── src/ │ │ └── lib.rs │ ├── custom-wgpu-kernel/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── custom-wgpu-kernel.rs │ │ └── src/ │ │ ├── backward.rs │ │ ├── forward.rs │ │ ├── kernel.wgsl │ │ └── lib.rs │ ├── dop_timer/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── event_utils.rs │ │ ├── main.rs │ │ ├── parsers.rs │ │ └── workers.rs │ ├── dqn-agent/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── dqn-agent.rs │ │ └── src/ │ │ ├── agent.rs │ │ ├── env.rs │ │ ├── lib.rs │ │ ├── training.rs │ │ └── utils.rs │ ├── guide/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── guide.rs │ │ └── src/ │ │ ├── bin/ │ │ │ ├── infer.rs │ │ │ ├── print.rs │ │ │ └── train.rs │ │ ├── data.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── import-model-weights/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── src/ │ │ │ ├── bin/ │ │ │ │ ├── burnpack.rs │ │ │ │ ├── convert.rs │ │ │ │ ├── pytorch.rs │ │ │ │ └── safetensors.rs │ │ │ ├── inference.rs │ │ │ ├── lib.rs │ │ │ └── model.rs │ │ └── weights/ │ │ ├── mnist.pt │ │ ├── mnist.safetensors │ │ └── mnist_train_export.py │ ├── mnist/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ └── mnist.rs │ │ └── src/ │ │ ├── data.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── mnist-inference-web/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build-for-web.sh │ │ ├── index.html │ │ ├── index.js │ │ ├── run-server.sh │ │ └── src/ │ │ ├── lib.rs │ │ ├── model.rs │ │ ├── state.rs │ │ └── web.rs │ ├── modern-lstm/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ ├── lstm-infer.rs │ │ │ └── lstm-train.rs │ │ └── src/ │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── multi-gpus/ │ │ ├── Cargo.toml │ │ ├── examples/ │ │ │ └── multi-gpus.rs │ │ └── src/ │ │ └── lib.rs │ ├── notebook/ │ │ ├── README.md │ │ ├── autodiff.ipynb │ │ └── basic-tensor-op.ipynb │ ├── server/ │ │ ├── Cargo.toml │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ └── server.rs │ │ └── src/ │ │ └── lib.rs │ ├── simple-regression/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── regression.rs │ │ └── src/ │ │ ├── dataset.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── text-classification/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── cubecl.toml │ │ ├── examples/ │ │ │ ├── ag-news-infer.rs │ │ │ ├── ag-news-train.rs │ │ │ ├── db-pedia-infer.rs │ │ │ └── db-pedia-train.rs │ │ └── src/ │ │ ├── data/ │ │ │ ├── batcher.rs │ │ │ ├── dataset.rs │ │ │ ├── mod.rs │ │ │ └── tokenizer.rs │ │ ├── inference.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ ├── text-generation/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── examples/ │ │ │ └── text-generation.rs │ │ └── src/ │ │ ├── data/ │ │ │ ├── batcher.rs │ │ │ ├── dataset.rs │ │ │ ├── mod.rs │ │ │ └── tokenizer.rs │ │ ├── lib.rs │ │ ├── model.rs │ │ └── training.rs │ └── wgan/ │ ├── Cargo.toml │ ├── README.md │ ├── examples/ │ │ ├── wgan-generate.rs │ │ └── wgan-mnist.rs │ └── src/ │ ├── dataset.rs │ ├── infer.rs │ ├── lib.rs │ ├── model.rs │ └── training.rs ├── rustfmt.toml └── xtask/ ├── Cargo.toml └── src/ ├── commands/ │ ├── books.rs │ ├── build.rs │ ├── doc.rs │ ├── mod.rs │ ├── test.rs │ └── validate.rs └── main.rs