gitextract_c0zbkz84/ ├── .clang-format ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── bug_report.md │ ├── actions/ │ │ ├── build-cuda-release/ │ │ │ └── action.yml │ │ ├── build-docs/ │ │ │ └── action.yml │ │ ├── build-linux/ │ │ │ └── action.yml │ │ ├── build-linux-release/ │ │ │ └── action.yml │ │ ├── build-macos/ │ │ │ └── action.yml │ │ ├── build-macos-release/ │ │ │ └── action.yml │ │ ├── build-windows/ │ │ │ └── action.yml │ │ ├── setup-linux/ │ │ │ └── action.yml │ │ ├── setup-macos/ │ │ │ └── action.yml │ │ ├── setup-windows/ │ │ │ └── action.yml │ │ ├── test-linux/ │ │ │ └── action.yml │ │ └── test-windows/ │ │ └── action.yml │ ├── dependabot.yml │ ├── pull_request_template.md │ ├── scripts/ │ │ ├── build-sanitizer-tests.sh │ │ └── setup+build-cpp-linux-fedora-container.sh │ └── workflows/ │ ├── build_and_test.yml │ ├── documentation.yml │ ├── nightly.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGMENTS.md ├── CITATION.cff ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks/ │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── autograd.cpp │ │ ├── compare_devices.cpp │ │ ├── irregular_strides.cpp │ │ ├── single_ops.cpp │ │ └── time_utils.h │ ├── numpy/ │ │ ├── single_ops.py │ │ └── time_utils.py │ └── python/ │ ├── batch_matmul_bench.py │ ├── blas/ │ │ ├── bench_gemm.py │ │ └── bench_gemv.py │ ├── comparative/ │ │ ├── README.md │ │ ├── bench_mlx.py │ │ ├── bench_torch.py │ │ └── compare.py │ ├── compile_bench.py │ ├── conv1d_bench.py │ ├── conv2d_bench_cpu.py │ ├── conv2d_train_bench_cpu.py │ ├── conv2d_transpose_bench_cpu.py │ ├── conv3d_bench.py │ ├── conv3d_bench_cpu.py │ ├── conv3d_train_bench_cpu.py │ ├── conv3d_transpose_bench_cpu.py │ ├── conv_bench.py │ ├── conv_transpose_bench.py │ ├── conv_unaligned_bench.py │ ├── distributed_bench.py │ ├── einsum_bench.py │ ├── fft_bench.py │ ├── gather_bench.py │ ├── gather_mm_bench.py │ ├── gather_qmm_bench.py │ ├── hadamard_bench.py │ ├── large_gemm_bench.py │ ├── layer_norm_bench.py │ ├── masked_scatter.py │ ├── rms_norm_bench.py │ ├── rope_bench.py │ ├── scatter_bench.py │ ├── sdpa_bench.py │ ├── sdpa_vector_bench.py │ ├── segmented_mm_bench.py │ ├── single_ops.py │ ├── slice_update_bench.py │ ├── synchronize_bench.py │ └── time_utils.py ├── cmake/ │ ├── FindCUDNN.cmake │ ├── FindNCCL.cmake │ ├── Findnvpl.cmake │ └── extension.cmake ├── docs/ │ ├── .clang-format │ ├── .gitignore │ ├── .nojekyll │ ├── Doxyfile │ ├── Makefile │ ├── README.md │ ├── index.html │ ├── requirements.txt │ └── src/ │ ├── _templates/ │ │ ├── module-base-class.rst │ │ ├── nn-module-template.rst │ │ └── optimizers-template.rst │ ├── conf.py │ ├── cpp/ │ │ └── ops.rst │ ├── dev/ │ │ ├── custom_metal_kernels.rst │ │ ├── extensions.rst │ │ ├── metal_debugger.rst │ │ ├── metal_logging.rst │ │ └── mlx_in_cpp.rst │ ├── examples/ │ │ ├── data_parallelism.rst │ │ ├── linear_regression.rst │ │ ├── llama-inference.rst │ │ ├── mlp.rst │ │ └── tensor_parallelism.rst │ ├── index.rst │ ├── install.rst │ ├── python/ │ │ ├── array.rst │ │ ├── cuda.rst │ │ ├── data_types.rst │ │ ├── devices_and_streams.rst │ │ ├── distributed.rst │ │ ├── export.rst │ │ ├── fast.rst │ │ ├── fft.rst │ │ ├── linalg.rst │ │ ├── memory_management.rst │ │ ├── metal.rst │ │ ├── nn/ │ │ │ ├── distributed.rst │ │ │ ├── functions.rst │ │ │ ├── init.rst │ │ │ ├── layers.rst │ │ │ ├── losses.rst │ │ │ └── module.rst │ │ ├── nn.rst │ │ ├── ops.rst │ │ ├── optimizers/ │ │ │ ├── common_optimizers.rst │ │ │ ├── optimizer.rst │ │ │ └── schedulers.rst │ │ ├── optimizers.rst │ │ ├── random.rst │ │ ├── transforms.rst │ │ └── tree_utils.rst │ └── usage/ │ ├── compile.rst │ ├── distributed.rst │ ├── export.rst │ ├── function_transforms.rst │ ├── indexing.rst │ ├── launching_distributed.rst │ ├── lazy_evaluation.rst │ ├── numpy.rst │ ├── quick_start.rst │ ├── saving_and_loading.rst │ ├── unified_memory.rst │ └── using_streams.rst ├── examples/ │ ├── cmake_project/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ └── example.cpp │ ├── cpp/ │ │ ├── CMakeLists.txt │ │ ├── distributed.cpp │ │ ├── linear_regression.cpp │ │ ├── logistic_regression.cpp │ │ ├── metal_capture.cpp │ │ ├── timer.h │ │ └── tutorial.cpp │ ├── export/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── eval_mlp.cpp │ │ ├── eval_mlp.py │ │ ├── train_mlp.cpp │ │ └── train_mlp.py │ ├── extensions/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── axpby/ │ │ │ ├── axpby.cpp │ │ │ ├── axpby.h │ │ │ └── axpby.metal │ │ ├── bindings.cpp │ │ ├── mlx_sample_extensions/ │ │ │ └── __init__.py │ │ ├── pyproject.toml │ │ ├── requirements.txt │ │ ├── setup.py │ │ └── test.py │ └── python/ │ ├── linear_regression.py │ ├── logistic_regression.py │ └── qqmm.py ├── mlx/ │ ├── 3rdparty/ │ │ ├── .clang-format │ │ └── pocketfft.h │ ├── CMakeLists.txt │ ├── allocator.h │ ├── api.h │ ├── array.cpp │ ├── array.h │ ├── backend/ │ │ ├── common/ │ │ │ ├── CMakeLists.txt │ │ │ ├── binary.h │ │ │ ├── broadcasting.cpp │ │ │ ├── broadcasting.h │ │ │ ├── buffer_cache.h │ │ │ ├── common.cpp │ │ │ ├── compiled.cpp │ │ │ ├── compiled.h │ │ │ ├── copy.h │ │ │ ├── hadamard.h │ │ │ ├── load.cpp │ │ │ ├── matmul.h │ │ │ ├── quantized.h │ │ │ ├── reduce.cpp │ │ │ ├── reduce.h │ │ │ ├── slicing.cpp │ │ │ ├── slicing.h │ │ │ ├── ternary.h │ │ │ ├── unary.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── cpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── arange.h │ │ │ ├── arg_reduce.cpp │ │ │ ├── binary.cpp │ │ │ ├── binary.h │ │ │ ├── binary_ops.h │ │ │ ├── binary_two.h │ │ │ ├── cholesky.cpp │ │ │ ├── compiled.cpp │ │ │ ├── compiled_preamble.h │ │ │ ├── conv.cpp │ │ │ ├── copy.cpp │ │ │ ├── copy.h │ │ │ ├── device_info.cpp │ │ │ ├── device_info.h │ │ │ ├── distributed.cpp │ │ │ ├── eig.cpp │ │ │ ├── eigh.cpp │ │ │ ├── encoder.cpp │ │ │ ├── encoder.h │ │ │ ├── eval.cpp │ │ │ ├── eval.h │ │ │ ├── fft.cpp │ │ │ ├── gemm.h │ │ │ ├── gemms/ │ │ │ │ ├── bnns.cpp │ │ │ │ ├── cblas.cpp │ │ │ │ ├── simd_bf16.cpp │ │ │ │ ├── simd_fp16.cpp │ │ │ │ └── simd_gemm.h │ │ │ ├── hadamard.cpp │ │ │ ├── indexing.cpp │ │ │ ├── inverse.cpp │ │ │ ├── jit_compiler.cpp │ │ │ ├── jit_compiler.h │ │ │ ├── lapack.h │ │ │ ├── logsumexp.cpp │ │ │ ├── luf.cpp │ │ │ ├── make_compiled_preamble.ps1 │ │ │ ├── make_compiled_preamble.sh │ │ │ ├── masked_mm.cpp │ │ │ ├── matmul.cpp │ │ │ ├── primitives.cpp │ │ │ ├── qrf.cpp │ │ │ ├── quantized.cpp │ │ │ ├── reduce.cpp │ │ │ ├── scan.cpp │ │ │ ├── select.cpp │ │ │ ├── simd/ │ │ │ │ ├── accelerate_fp16_simd.h │ │ │ │ ├── accelerate_simd.h │ │ │ │ ├── base_simd.h │ │ │ │ ├── math.h │ │ │ │ ├── neon_fp16_simd.h │ │ │ │ ├── simd.h │ │ │ │ └── type.h │ │ │ ├── slicing.h │ │ │ ├── softmax.cpp │ │ │ ├── sort.cpp │ │ │ ├── svd.cpp │ │ │ ├── ternary.h │ │ │ ├── threefry.cpp │ │ │ ├── threefry.h │ │ │ ├── unary.cpp │ │ │ ├── unary.h │ │ │ └── unary_ops.h │ │ ├── cuda/ │ │ │ ├── CMakeLists.txt │ │ │ ├── allocator.cpp │ │ │ ├── allocator.h │ │ │ ├── arange.cu │ │ │ ├── arg_reduce.cu │ │ │ ├── bin2h.cmake │ │ │ ├── binary/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── add.cu │ │ │ │ ├── arctan2.cu │ │ │ │ ├── binary.cuh │ │ │ │ ├── bitwise_binary.cu │ │ │ │ ├── divide.cu │ │ │ │ ├── equal.cu │ │ │ │ ├── greater.cu │ │ │ │ ├── greater_equal.cu │ │ │ │ ├── less.cu │ │ │ │ ├── less_equal.cu │ │ │ │ ├── log_add_exp.cu │ │ │ │ ├── logical_and.cu │ │ │ │ ├── logical_or.cu │ │ │ │ ├── maximum.cu │ │ │ │ ├── minimum.cu │ │ │ │ ├── multiply.cu │ │ │ │ ├── not_equal.cu │ │ │ │ ├── power.cu │ │ │ │ ├── remainder.cu │ │ │ │ └── subtract.cu │ │ │ ├── binary_two.cu │ │ │ ├── compiled.cpp │ │ │ ├── conv/ │ │ │ │ ├── conv.h │ │ │ │ ├── gemm_conv.cu │ │ │ │ └── gemm_grouped_conv.cu │ │ │ ├── conv.cpp │ │ │ ├── copy/ │ │ │ │ ├── copy.cuh │ │ │ │ ├── copy_contiguous.cu │ │ │ │ ├── copy_general.cu │ │ │ │ ├── copy_general_dynamic.cu │ │ │ │ └── copy_general_input.cu │ │ │ ├── copy.cu │ │ │ ├── cublas_utils.cpp │ │ │ ├── cublas_utils.h │ │ │ ├── cuda.h │ │ │ ├── cuda_utils.h │ │ │ ├── cudnn_utils.cpp │ │ │ ├── cudnn_utils.h │ │ │ ├── custom_kernel.cpp │ │ │ ├── cutlass_utils.cuh │ │ │ ├── delayload.cpp │ │ │ ├── device/ │ │ │ │ ├── atomic_ops.cuh │ │ │ │ ├── binary_ops.cuh │ │ │ │ ├── cast_op.cuh │ │ │ │ ├── complex.cuh │ │ │ │ ├── config.h │ │ │ │ ├── fp16_math.cuh │ │ │ │ ├── gather.cuh │ │ │ │ ├── gather_axis.cuh │ │ │ │ ├── hadamard.cuh │ │ │ │ ├── indexing.cuh │ │ │ │ ├── scatter.cuh │ │ │ │ ├── scatter_axis.cuh │ │ │ │ ├── scatter_ops.cuh │ │ │ │ ├── slice_update.cuh │ │ │ │ ├── ternary_ops.cuh │ │ │ │ ├── unary_ops.cuh │ │ │ │ └── utils.cuh │ │ │ ├── device.cpp │ │ │ ├── device.h │ │ │ ├── device_info.cpp │ │ │ ├── distributed.cu │ │ │ ├── eval.cpp │ │ │ ├── event.cu │ │ │ ├── event.h │ │ │ ├── fence.cpp │ │ │ ├── fft.cu │ │ │ ├── gemms/ │ │ │ │ ├── cublas_gemm.cpp │ │ │ │ ├── cublas_gemm.h │ │ │ │ ├── cublas_gemm_batched_12_0.cpp │ │ │ │ ├── cublas_gemm_batched_12_9.cu │ │ │ │ ├── gemv.cu │ │ │ │ ├── gemv.h │ │ │ │ ├── grouped_gemm.h │ │ │ │ └── grouped_gemm_unaligned.cu │ │ │ ├── hadamard.cu │ │ │ ├── indexing.cpp │ │ │ ├── jit_module.cpp │ │ │ ├── jit_module.h │ │ │ ├── kernel_utils.cu │ │ │ ├── kernel_utils.cuh │ │ │ ├── layer_norm.cu │ │ │ ├── load.cpp │ │ │ ├── logsumexp.cu │ │ │ ├── lru_cache.h │ │ │ ├── matmul.cpp │ │ │ ├── no_cuda.cpp │ │ │ ├── primitives.cpp │ │ │ ├── quantized/ │ │ │ │ ├── affine_quantize.cu │ │ │ │ ├── convert_fp8.cu │ │ │ │ ├── cublas_qqmm.cpp │ │ │ │ ├── cublas_qqmm.h │ │ │ │ ├── fp_quantize.cu │ │ │ │ ├── mxfp8_quantize.cuh │ │ │ │ ├── no_qqmm_impl.cpp │ │ │ │ ├── nvfp4_quantize.cuh │ │ │ │ ├── qmm/ │ │ │ │ │ ├── CMakeLists.txt │ │ │ │ │ ├── fp_qmv.cu │ │ │ │ │ ├── qmm.cu │ │ │ │ │ ├── qmm.h │ │ │ │ │ ├── qmm_impl_sm80.cuh │ │ │ │ │ ├── qmm_impl_sm80_m16.cu │ │ │ │ │ ├── qmm_impl_sm80_m32.cu │ │ │ │ │ ├── qmm_impl_sm80_m64.cu │ │ │ │ │ ├── qmm_impl_sm90.cuh │ │ │ │ │ ├── qmm_impl_sm90_m128_n128_m2.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n16_m1.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n256_m2.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n32_m1.cu │ │ │ │ │ ├── qmm_impl_sm90_m128_n64_m2.cu │ │ │ │ │ └── qmv.cu │ │ │ │ ├── qqmm.cpp │ │ │ │ ├── qqmm_impl.cpp │ │ │ │ ├── qqmm_impl.h │ │ │ │ ├── qqmm_utils.cu │ │ │ │ ├── qqmm_utils.h │ │ │ │ ├── quantized.cpp │ │ │ │ ├── quantized.h │ │ │ │ └── quantized_utils.h │ │ │ ├── random.cu │ │ │ ├── reduce/ │ │ │ │ ├── all_reduce.cu │ │ │ │ ├── col_reduce.cu │ │ │ │ ├── init_reduce.cu │ │ │ │ ├── reduce.cuh │ │ │ │ ├── reduce_ops.cuh │ │ │ │ ├── reduce_utils.cuh │ │ │ │ └── row_reduce.cu │ │ │ ├── reduce.cu │ │ │ ├── rms_norm.cu │ │ │ ├── rope.cu │ │ │ ├── scaled_dot_product_attention.cpp │ │ │ ├── scaled_dot_product_attention.cu │ │ │ ├── scan.cu │ │ │ ├── slicing.cpp │ │ │ ├── softmax.cu │ │ │ ├── sort.cu │ │ │ ├── steel/ │ │ │ │ ├── defines.cuh │ │ │ │ ├── gemm.cuh │ │ │ │ ├── mma.cuh │ │ │ │ ├── tiles.cuh │ │ │ │ └── utils.cuh │ │ │ ├── ternary.cu │ │ │ ├── unary/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── abs.cu │ │ │ │ ├── arccos.cu │ │ │ │ ├── arccosh.cu │ │ │ │ ├── arcsin.cu │ │ │ │ ├── arcsinh.cu │ │ │ │ ├── arctan.cu │ │ │ │ ├── arctanh.cu │ │ │ │ ├── bitwise_invert.cu │ │ │ │ ├── ceil.cu │ │ │ │ ├── conjugate.cu │ │ │ │ ├── cos.cu │ │ │ │ ├── cosh.cu │ │ │ │ ├── erf.cu │ │ │ │ ├── erf_inv.cu │ │ │ │ ├── exp.cu │ │ │ │ ├── expm1.cu │ │ │ │ ├── floor.cu │ │ │ │ ├── imag.cu │ │ │ │ ├── log.cu │ │ │ │ ├── log1p.cu │ │ │ │ ├── logical_not.cu │ │ │ │ ├── negative.cu │ │ │ │ ├── real.cu │ │ │ │ ├── round.cu │ │ │ │ ├── sigmoid.cu │ │ │ │ ├── sign.cu │ │ │ │ ├── sin.cu │ │ │ │ ├── sinh.cu │ │ │ │ ├── sqrt.cu │ │ │ │ ├── square.cu │ │ │ │ ├── tan.cu │ │ │ │ ├── tanh.cu │ │ │ │ └── unary.cuh │ │ │ ├── utils.cpp │ │ │ ├── utils.h │ │ │ ├── vector_types.cuh │ │ │ ├── worker.cpp │ │ │ └── worker.h │ │ ├── gpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── copy.cpp │ │ │ ├── copy.h │ │ │ ├── device_info.h │ │ │ ├── eval.h │ │ │ ├── primitives.cpp │ │ │ ├── scan.h │ │ │ ├── slicing.cpp │ │ │ └── slicing.h │ │ ├── metal/ │ │ │ ├── CMakeLists.txt │ │ │ ├── allocator.cpp │ │ │ ├── allocator.h │ │ │ ├── binary.cpp │ │ │ ├── binary.h │ │ │ ├── compiled.cpp │ │ │ ├── conv.cpp │ │ │ ├── copy.cpp │ │ │ ├── custom_kernel.cpp │ │ │ ├── device.cpp │ │ │ ├── device.h │ │ │ ├── device_info.cpp │ │ │ ├── distributed.cpp │ │ │ ├── eval.cpp │ │ │ ├── event.cpp │ │ │ ├── fence.cpp │ │ │ ├── fft.cpp │ │ │ ├── hadamard.cpp │ │ │ ├── indexing.cpp │ │ │ ├── jit/ │ │ │ │ ├── includes.h │ │ │ │ └── indexing.h │ │ │ ├── jit_kernels.cpp │ │ │ ├── kernels/ │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── arange.h │ │ │ │ ├── arange.metal │ │ │ │ ├── arg_reduce.metal │ │ │ │ ├── atomic.h │ │ │ │ ├── bf16.h │ │ │ │ ├── bf16_math.h │ │ │ │ ├── binary.h │ │ │ │ ├── binary.metal │ │ │ │ ├── binary_ops.h │ │ │ │ ├── binary_two.h │ │ │ │ ├── binary_two.metal │ │ │ │ ├── cexpf.h │ │ │ │ ├── complex.h │ │ │ │ ├── conv.metal │ │ │ │ ├── copy.h │ │ │ │ ├── copy.metal │ │ │ │ ├── defines.h │ │ │ │ ├── erf.h │ │ │ │ ├── expm1f.h │ │ │ │ ├── fence.metal │ │ │ │ ├── fft/ │ │ │ │ │ ├── radix.h │ │ │ │ │ └── readwrite.h │ │ │ │ ├── fft.h │ │ │ │ ├── fft.metal │ │ │ │ ├── fp4.h │ │ │ │ ├── fp8.h │ │ │ │ ├── fp_quantized.h │ │ │ │ ├── fp_quantized.metal │ │ │ │ ├── fp_quantized_nax.h │ │ │ │ ├── fp_quantized_nax.metal │ │ │ │ ├── gemv.metal │ │ │ │ ├── gemv_masked.h │ │ │ │ ├── gemv_masked.metal │ │ │ │ ├── hadamard.h │ │ │ │ ├── indexing/ │ │ │ │ │ ├── gather.h │ │ │ │ │ ├── gather_axis.h │ │ │ │ │ ├── gather_front.h │ │ │ │ │ ├── indexing.h │ │ │ │ │ ├── masked_scatter.h │ │ │ │ │ ├── scatter.h │ │ │ │ │ └── scatter_axis.h │ │ │ │ ├── layer_norm.metal │ │ │ │ ├── logging.h │ │ │ │ ├── logsumexp.h │ │ │ │ ├── logsumexp.metal │ │ │ │ ├── quantized.h │ │ │ │ ├── quantized.metal │ │ │ │ ├── quantized_nax.h │ │ │ │ ├── quantized_nax.metal │ │ │ │ ├── quantized_utils.h │ │ │ │ ├── random.metal │ │ │ │ ├── reduce.h │ │ │ │ ├── reduce.metal │ │ │ │ ├── reduce_utils.h │ │ │ │ ├── reduction/ │ │ │ │ │ ├── ops.h │ │ │ │ │ ├── reduce_all.h │ │ │ │ │ ├── reduce_col.h │ │ │ │ │ ├── reduce_init.h │ │ │ │ │ └── reduce_row.h │ │ │ │ ├── rms_norm.metal │ │ │ │ ├── rope.metal │ │ │ │ ├── scaled_dot_product_attention.metal │ │ │ │ ├── scan.h │ │ │ │ ├── scan.metal │ │ │ │ ├── sdpa_vector.h │ │ │ │ ├── softmax.h │ │ │ │ ├── softmax.metal │ │ │ │ ├── sort.h │ │ │ │ ├── sort.metal │ │ │ │ ├── steel/ │ │ │ │ │ ├── attn/ │ │ │ │ │ │ ├── attn.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_attention.h │ │ │ │ │ │ │ ├── steel_attention.metal │ │ │ │ │ │ │ ├── steel_attention_nax.h │ │ │ │ │ │ │ └── steel_attention_nax.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── mma.h │ │ │ │ │ │ ├── nax.h │ │ │ │ │ │ ├── params.h │ │ │ │ │ │ └── transforms.h │ │ │ │ │ ├── conv/ │ │ │ │ │ │ ├── conv.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_conv.h │ │ │ │ │ │ │ ├── steel_conv.metal │ │ │ │ │ │ │ ├── steel_conv_3d.h │ │ │ │ │ │ │ ├── steel_conv_3d.metal │ │ │ │ │ │ │ ├── steel_conv_general.h │ │ │ │ │ │ │ └── steel_conv_general.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── loaders/ │ │ │ │ │ │ │ ├── loader_channel_l.h │ │ │ │ │ │ │ ├── loader_channel_n.h │ │ │ │ │ │ │ └── loader_general.h │ │ │ │ │ │ └── params.h │ │ │ │ │ ├── defines.h │ │ │ │ │ ├── gemm/ │ │ │ │ │ │ ├── gemm.h │ │ │ │ │ │ ├── gemm_nax.h │ │ │ │ │ │ ├── kernels/ │ │ │ │ │ │ │ ├── steel_gemm_fused.h │ │ │ │ │ │ │ ├── steel_gemm_fused.metal │ │ │ │ │ │ │ ├── steel_gemm_fused_nax.h │ │ │ │ │ │ │ ├── steel_gemm_fused_nax.metal │ │ │ │ │ │ │ ├── steel_gemm_gather.h │ │ │ │ │ │ │ ├── steel_gemm_gather.metal │ │ │ │ │ │ │ ├── steel_gemm_gather_nax.h │ │ │ │ │ │ │ ├── steel_gemm_gather_nax.metal │ │ │ │ │ │ │ ├── steel_gemm_masked.h │ │ │ │ │ │ │ ├── steel_gemm_masked.metal │ │ │ │ │ │ │ ├── steel_gemm_segmented.h │ │ │ │ │ │ │ ├── steel_gemm_segmented.metal │ │ │ │ │ │ │ ├── steel_gemm_splitk.h │ │ │ │ │ │ │ ├── steel_gemm_splitk.metal │ │ │ │ │ │ │ ├── steel_gemm_splitk_nax.h │ │ │ │ │ │ │ └── steel_gemm_splitk_nax.metal │ │ │ │ │ │ ├── loader.h │ │ │ │ │ │ ├── mma.h │ │ │ │ │ │ ├── nax.h │ │ │ │ │ │ ├── params.h │ │ │ │ │ │ └── transforms.h │ │ │ │ │ ├── utils/ │ │ │ │ │ │ ├── integral_constant.h │ │ │ │ │ │ └── type_traits.h │ │ │ │ │ └── utils.h │ │ │ │ ├── ternary.h │ │ │ │ ├── ternary.metal │ │ │ │ ├── ternary_ops.h │ │ │ │ ├── unary.h │ │ │ │ ├── unary.metal │ │ │ │ ├── unary_ops.h │ │ │ │ └── utils.h │ │ │ ├── kernels.h │ │ │ ├── logsumexp.cpp │ │ │ ├── make_compiled_preamble.sh │ │ │ ├── matmul.cpp │ │ │ ├── matmul.h │ │ │ ├── metal.cpp │ │ │ ├── metal.h │ │ │ ├── no_metal.cpp │ │ │ ├── nojit_kernels.cpp │ │ │ ├── normalization.cpp │ │ │ ├── primitives.cpp │ │ │ ├── quantized.cpp │ │ │ ├── reduce.cpp │ │ │ ├── reduce.h │ │ │ ├── resident.cpp │ │ │ ├── resident.h │ │ │ ├── rope.cpp │ │ │ ├── scaled_dot_product_attention.cpp │ │ │ ├── scan.cpp │ │ │ ├── slicing.cpp │ │ │ ├── softmax.cpp │ │ │ ├── sort.cpp │ │ │ ├── ternary.cpp │ │ │ ├── ternary.h │ │ │ ├── unary.cpp │ │ │ ├── unary.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── no_cpu/ │ │ │ ├── CMakeLists.txt │ │ │ ├── compiled.cpp │ │ │ ├── device_info.cpp │ │ │ └── primitives.cpp │ │ └── no_gpu/ │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── apple_memory.h │ │ ├── device_info.cpp │ │ ├── eval.cpp │ │ ├── event.cpp │ │ ├── fence.cpp │ │ ├── linux_memory.h │ │ └── primitives.cpp │ ├── compile.cpp │ ├── compile.h │ ├── compile_impl.h │ ├── device.cpp │ ├── device.h │ ├── distributed/ │ │ ├── CMakeLists.txt │ │ ├── distributed.cpp │ │ ├── distributed.h │ │ ├── distributed_impl.h │ │ ├── jaccl/ │ │ │ ├── CMakeLists.txt │ │ │ ├── jaccl.cpp │ │ │ ├── jaccl.h │ │ │ ├── mesh.cpp │ │ │ ├── mesh.h │ │ │ ├── mesh_impl.h │ │ │ ├── no_jaccl.cpp │ │ │ ├── ring.cpp │ │ │ ├── ring.h │ │ │ ├── ring_impl.h │ │ │ ├── utils.cpp │ │ │ └── utils.h │ │ ├── mpi/ │ │ │ ├── CMakeLists.txt │ │ │ ├── mpi.cpp │ │ │ ├── mpi.h │ │ │ ├── mpi_declarations.h │ │ │ └── no_mpi.cpp │ │ ├── nccl/ │ │ │ ├── CMakeLists.txt │ │ │ ├── nccl.cpp │ │ │ ├── nccl.h │ │ │ └── no_nccl.cpp │ │ ├── ops.cpp │ │ ├── ops.h │ │ ├── primitives.cpp │ │ ├── primitives.h │ │ ├── reduction_ops.h │ │ ├── ring/ │ │ │ ├── CMakeLists.txt │ │ │ ├── no_ring.cpp │ │ │ ├── ring.cpp │ │ │ └── ring.h │ │ ├── utils.cpp │ │ └── utils.h │ ├── dtype.cpp │ ├── dtype.h │ ├── dtype_utils.cpp │ ├── dtype_utils.h │ ├── einsum.cpp │ ├── einsum.h │ ├── event.h │ ├── export.cpp │ ├── export.h │ ├── export_impl.h │ ├── fast.cpp │ ├── fast.h │ ├── fast_primitives.h │ ├── fence.h │ ├── fft.cpp │ ├── fft.h │ ├── graph_utils.cpp │ ├── graph_utils.h │ ├── io/ │ │ ├── CMakeLists.txt │ │ ├── gguf.cpp │ │ ├── gguf.h │ │ ├── gguf_quants.cpp │ │ ├── load.cpp │ │ ├── load.h │ │ ├── no_gguf.cpp │ │ ├── no_safetensors.cpp │ │ └── safetensors.cpp │ ├── io.h │ ├── linalg.cpp │ ├── linalg.h │ ├── memory.h │ ├── mlx.h │ ├── ops.cpp │ ├── ops.h │ ├── primitives.cpp │ ├── primitives.h │ ├── random.cpp │ ├── random.h │ ├── scheduler.cpp │ ├── scheduler.h │ ├── small_vector.h │ ├── stream.h │ ├── threadpool.h │ ├── transforms.cpp │ ├── transforms.h │ ├── transforms_impl.h │ ├── types/ │ │ ├── bf16.h │ │ ├── complex.h │ │ ├── fp16.h │ │ ├── half_types.h │ │ └── limits.h │ ├── utils.cpp │ ├── utils.h │ ├── version.cpp │ └── version.h ├── mlx.pc.in ├── pyproject.toml ├── python/ │ ├── mlx/ │ │ ├── __main__.py │ │ ├── _distributed_utils/ │ │ │ ├── common.py │ │ │ ├── config.py │ │ │ └── launch.py │ │ ├── _reprlib_fix.py │ │ ├── _stub_patterns.txt │ │ ├── extension.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── init.py │ │ │ ├── layers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── base.py │ │ │ │ ├── containers.py │ │ │ │ ├── convolution.py │ │ │ │ ├── convolution_transpose.py │ │ │ │ ├── distributed.py │ │ │ │ ├── dropout.py │ │ │ │ ├── embedding.py │ │ │ │ ├── linear.py │ │ │ │ ├── normalization.py │ │ │ │ ├── pooling.py │ │ │ │ ├── positional_encoding.py │ │ │ │ ├── quantized.py │ │ │ │ ├── recurrent.py │ │ │ │ ├── transformer.py │ │ │ │ └── upsample.py │ │ │ ├── losses.py │ │ │ └── utils.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── optimizers.py │ │ │ └── schedulers.py │ │ ├── py.typed │ │ └── utils.py │ ├── src/ │ │ ├── CMakeLists.txt │ │ ├── array.cpp │ │ ├── buffer.h │ │ ├── constants.cpp │ │ ├── convert.cpp │ │ ├── convert.h │ │ ├── cuda.cpp │ │ ├── device.cpp │ │ ├── distributed.cpp │ │ ├── export.cpp │ │ ├── fast.cpp │ │ ├── fft.cpp │ │ ├── indexing.cpp │ │ ├── indexing.h │ │ ├── linalg.cpp │ │ ├── load.cpp │ │ ├── load.h │ │ ├── memory.cpp │ │ ├── metal.cpp │ │ ├── mlx.cpp │ │ ├── mlx_func.cpp │ │ ├── mlx_func.h │ │ ├── ops.cpp │ │ ├── random.cpp │ │ ├── small_vector.h │ │ ├── stream.cpp │ │ ├── transforms.cpp │ │ ├── trees.cpp │ │ ├── trees.h │ │ ├── utils.cpp │ │ └── utils.h │ └── tests/ │ ├── __main__.py │ ├── cuda_skip.py │ ├── mlx_distributed_tests.py │ ├── mlx_tests.py │ ├── mpi_test_distributed.py │ ├── nccl_test_distributed.py │ ├── ring_test_distributed.py │ ├── test_array.py │ ├── test_autograd.py │ ├── test_bf16.py │ ├── test_blas.py │ ├── test_compile.py │ ├── test_constants.py │ ├── test_conv.py │ ├── test_conv_transpose.py │ ├── test_device.py │ ├── test_double.py │ ├── test_einsum.py │ ├── test_eval.py │ ├── test_export_import.py │ ├── test_fast.py │ ├── test_fast_sdpa.py │ ├── test_fft.py │ ├── test_graph.py │ ├── test_init.py │ ├── test_linalg.py │ ├── test_load.py │ ├── test_losses.py │ ├── test_memory.py │ ├── test_nn.py │ ├── test_ops.py │ ├── test_optimizers.py │ ├── test_quantized.py │ ├── test_random.py │ ├── test_reduce.py │ ├── test_tree.py │ ├── test_upsample.py │ └── test_vmap.py ├── setup.py └── tests/ ├── CMakeLists.txt ├── allocator_tests.cpp ├── arg_reduce_tests.cpp ├── array_tests.cpp ├── autograd_tests.cpp ├── blas_tests.cpp ├── compile_tests.cpp ├── creations_tests.cpp ├── custom_vjp_tests.cpp ├── device_tests.cpp ├── einsum_tests.cpp ├── eval_tests.cpp ├── export_import_tests.cpp ├── fft_tests.cpp ├── gpu_tests.cpp ├── linalg_tests.cpp ├── load_tests.cpp ├── ops_tests.cpp ├── random_tests.cpp ├── scheduler_tests.cpp ├── test_teardown.cpp ├── tests.cpp ├── utils_tests.cpp └── vmap_tests.cpp