Showing preview only (6,725K chars total). Download the full file or copy to clipboard to get everything.
Repository: ml-explore/mlx
Branch: main
Commit: 70a0da6fca8a
Files: 879
Total size: 6.3 MB
Directory structure:
gitextract_c0zbkz84/
├── .clang-format
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ ├── actions/
│ │ ├── build-cuda-release/
│ │ │ └── action.yml
│ │ ├── build-docs/
│ │ │ └── action.yml
│ │ ├── build-linux/
│ │ │ └── action.yml
│ │ ├── build-linux-release/
│ │ │ └── action.yml
│ │ ├── build-macos/
│ │ │ └── action.yml
│ │ ├── build-macos-release/
│ │ │ └── action.yml
│ │ ├── build-windows/
│ │ │ └── action.yml
│ │ ├── setup-linux/
│ │ │ └── action.yml
│ │ ├── setup-macos/
│ │ │ └── action.yml
│ │ ├── setup-windows/
│ │ │ └── action.yml
│ │ ├── test-linux/
│ │ │ └── action.yml
│ │ └── test-windows/
│ │ └── action.yml
│ ├── dependabot.yml
│ ├── pull_request_template.md
│ ├── scripts/
│ │ ├── build-sanitizer-tests.sh
│ │ └── setup+build-cpp-linux-fedora-container.sh
│ └── workflows/
│ ├── build_and_test.yml
│ ├── documentation.yml
│ ├── nightly.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── ACKNOWLEDGMENTS.md
├── CITATION.cff
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── benchmarks/
│ ├── cpp/
│ │ ├── CMakeLists.txt
│ │ ├── autograd.cpp
│ │ ├── compare_devices.cpp
│ │ ├── irregular_strides.cpp
│ │ ├── single_ops.cpp
│ │ └── time_utils.h
│ ├── numpy/
│ │ ├── single_ops.py
│ │ └── time_utils.py
│ └── python/
│ ├── batch_matmul_bench.py
│ ├── blas/
│ │ ├── bench_gemm.py
│ │ └── bench_gemv.py
│ ├── comparative/
│ │ ├── README.md
│ │ ├── bench_mlx.py
│ │ ├── bench_torch.py
│ │ └── compare.py
│ ├── compile_bench.py
│ ├── conv1d_bench.py
│ ├── conv2d_bench_cpu.py
│ ├── conv2d_train_bench_cpu.py
│ ├── conv2d_transpose_bench_cpu.py
│ ├── conv3d_bench.py
│ ├── conv3d_bench_cpu.py
│ ├── conv3d_train_bench_cpu.py
│ ├── conv3d_transpose_bench_cpu.py
│ ├── conv_bench.py
│ ├── conv_transpose_bench.py
│ ├── conv_unaligned_bench.py
│ ├── distributed_bench.py
│ ├── einsum_bench.py
│ ├── fft_bench.py
│ ├── gather_bench.py
│ ├── gather_mm_bench.py
│ ├── gather_qmm_bench.py
│ ├── hadamard_bench.py
│ ├── large_gemm_bench.py
│ ├── layer_norm_bench.py
│ ├── masked_scatter.py
│ ├── rms_norm_bench.py
│ ├── rope_bench.py
│ ├── scatter_bench.py
│ ├── sdpa_bench.py
│ ├── sdpa_vector_bench.py
│ ├── segmented_mm_bench.py
│ ├── single_ops.py
│ ├── slice_update_bench.py
│ ├── synchronize_bench.py
│ └── time_utils.py
├── cmake/
│ ├── FindCUDNN.cmake
│ ├── FindNCCL.cmake
│ ├── Findnvpl.cmake
│ └── extension.cmake
├── docs/
│ ├── .clang-format
│ ├── .gitignore
│ ├── .nojekyll
│ ├── Doxyfile
│ ├── Makefile
│ ├── README.md
│ ├── index.html
│ ├── requirements.txt
│ └── src/
│ ├── _templates/
│ │ ├── module-base-class.rst
│ │ ├── nn-module-template.rst
│ │ └── optimizers-template.rst
│ ├── conf.py
│ ├── cpp/
│ │ └── ops.rst
│ ├── dev/
│ │ ├── custom_metal_kernels.rst
│ │ ├── extensions.rst
│ │ ├── metal_debugger.rst
│ │ ├── metal_logging.rst
│ │ └── mlx_in_cpp.rst
│ ├── examples/
│ │ ├── data_parallelism.rst
│ │ ├── linear_regression.rst
│ │ ├── llama-inference.rst
│ │ ├── mlp.rst
│ │ └── tensor_parallelism.rst
│ ├── index.rst
│ ├── install.rst
│ ├── python/
│ │ ├── array.rst
│ │ ├── cuda.rst
│ │ ├── data_types.rst
│ │ ├── devices_and_streams.rst
│ │ ├── distributed.rst
│ │ ├── export.rst
│ │ ├── fast.rst
│ │ ├── fft.rst
│ │ ├── linalg.rst
│ │ ├── memory_management.rst
│ │ ├── metal.rst
│ │ ├── nn/
│ │ │ ├── distributed.rst
│ │ │ ├── functions.rst
│ │ │ ├── init.rst
│ │ │ ├── layers.rst
│ │ │ ├── losses.rst
│ │ │ └── module.rst
│ │ ├── nn.rst
│ │ ├── ops.rst
│ │ ├── optimizers/
│ │ │ ├── common_optimizers.rst
│ │ │ ├── optimizer.rst
│ │ │ └── schedulers.rst
│ │ ├── optimizers.rst
│ │ ├── random.rst
│ │ ├── transforms.rst
│ │ └── tree_utils.rst
│ └── usage/
│ ├── compile.rst
│ ├── distributed.rst
│ ├── export.rst
│ ├── function_transforms.rst
│ ├── indexing.rst
│ ├── launching_distributed.rst
│ ├── lazy_evaluation.rst
│ ├── numpy.rst
│ ├── quick_start.rst
│ ├── saving_and_loading.rst
│ ├── unified_memory.rst
│ └── using_streams.rst
├── examples/
│ ├── cmake_project/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ └── example.cpp
│ ├── cpp/
│ │ ├── CMakeLists.txt
│ │ ├── distributed.cpp
│ │ ├── linear_regression.cpp
│ │ ├── logistic_regression.cpp
│ │ ├── metal_capture.cpp
│ │ ├── timer.h
│ │ └── tutorial.cpp
│ ├── export/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ ├── eval_mlp.cpp
│ │ ├── eval_mlp.py
│ │ ├── train_mlp.cpp
│ │ └── train_mlp.py
│ ├── extensions/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ ├── axpby/
│ │ │ ├── axpby.cpp
│ │ │ ├── axpby.h
│ │ │ └── axpby.metal
│ │ ├── bindings.cpp
│ │ ├── mlx_sample_extensions/
│ │ │ └── __init__.py
│ │ ├── pyproject.toml
│ │ ├── requirements.txt
│ │ ├── setup.py
│ │ └── test.py
│ └── python/
│ ├── linear_regression.py
│ ├── logistic_regression.py
│ └── qqmm.py
├── mlx/
│ ├── 3rdparty/
│ │ ├── .clang-format
│ │ └── pocketfft.h
│ ├── CMakeLists.txt
│ ├── allocator.h
│ ├── api.h
│ ├── array.cpp
│ ├── array.h
│ ├── backend/
│ │ ├── common/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── binary.h
│ │ │ ├── broadcasting.cpp
│ │ │ ├── broadcasting.h
│ │ │ ├── buffer_cache.h
│ │ │ ├── common.cpp
│ │ │ ├── compiled.cpp
│ │ │ ├── compiled.h
│ │ │ ├── copy.h
│ │ │ ├── hadamard.h
│ │ │ ├── load.cpp
│ │ │ ├── matmul.h
│ │ │ ├── quantized.h
│ │ │ ├── reduce.cpp
│ │ │ ├── reduce.h
│ │ │ ├── slicing.cpp
│ │ │ ├── slicing.h
│ │ │ ├── ternary.h
│ │ │ ├── unary.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── cpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── arange.h
│ │ │ ├── arg_reduce.cpp
│ │ │ ├── binary.cpp
│ │ │ ├── binary.h
│ │ │ ├── binary_ops.h
│ │ │ ├── binary_two.h
│ │ │ ├── cholesky.cpp
│ │ │ ├── compiled.cpp
│ │ │ ├── compiled_preamble.h
│ │ │ ├── conv.cpp
│ │ │ ├── copy.cpp
│ │ │ ├── copy.h
│ │ │ ├── device_info.cpp
│ │ │ ├── device_info.h
│ │ │ ├── distributed.cpp
│ │ │ ├── eig.cpp
│ │ │ ├── eigh.cpp
│ │ │ ├── encoder.cpp
│ │ │ ├── encoder.h
│ │ │ ├── eval.cpp
│ │ │ ├── eval.h
│ │ │ ├── fft.cpp
│ │ │ ├── gemm.h
│ │ │ ├── gemms/
│ │ │ │ ├── bnns.cpp
│ │ │ │ ├── cblas.cpp
│ │ │ │ ├── simd_bf16.cpp
│ │ │ │ ├── simd_fp16.cpp
│ │ │ │ └── simd_gemm.h
│ │ │ ├── hadamard.cpp
│ │ │ ├── indexing.cpp
│ │ │ ├── inverse.cpp
│ │ │ ├── jit_compiler.cpp
│ │ │ ├── jit_compiler.h
│ │ │ ├── lapack.h
│ │ │ ├── logsumexp.cpp
│ │ │ ├── luf.cpp
│ │ │ ├── make_compiled_preamble.ps1
│ │ │ ├── make_compiled_preamble.sh
│ │ │ ├── masked_mm.cpp
│ │ │ ├── matmul.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── qrf.cpp
│ │ │ ├── quantized.cpp
│ │ │ ├── reduce.cpp
│ │ │ ├── scan.cpp
│ │ │ ├── select.cpp
│ │ │ ├── simd/
│ │ │ │ ├── accelerate_fp16_simd.h
│ │ │ │ ├── accelerate_simd.h
│ │ │ │ ├── base_simd.h
│ │ │ │ ├── math.h
│ │ │ │ ├── neon_fp16_simd.h
│ │ │ │ ├── simd.h
│ │ │ │ └── type.h
│ │ │ ├── slicing.h
│ │ │ ├── softmax.cpp
│ │ │ ├── sort.cpp
│ │ │ ├── svd.cpp
│ │ │ ├── ternary.h
│ │ │ ├── threefry.cpp
│ │ │ ├── threefry.h
│ │ │ ├── unary.cpp
│ │ │ ├── unary.h
│ │ │ └── unary_ops.h
│ │ ├── cuda/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── allocator.cpp
│ │ │ ├── allocator.h
│ │ │ ├── arange.cu
│ │ │ ├── arg_reduce.cu
│ │ │ ├── bin2h.cmake
│ │ │ ├── binary/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── add.cu
│ │ │ │ ├── arctan2.cu
│ │ │ │ ├── binary.cuh
│ │ │ │ ├── bitwise_binary.cu
│ │ │ │ ├── divide.cu
│ │ │ │ ├── equal.cu
│ │ │ │ ├── greater.cu
│ │ │ │ ├── greater_equal.cu
│ │ │ │ ├── less.cu
│ │ │ │ ├── less_equal.cu
│ │ │ │ ├── log_add_exp.cu
│ │ │ │ ├── logical_and.cu
│ │ │ │ ├── logical_or.cu
│ │ │ │ ├── maximum.cu
│ │ │ │ ├── minimum.cu
│ │ │ │ ├── multiply.cu
│ │ │ │ ├── not_equal.cu
│ │ │ │ ├── power.cu
│ │ │ │ ├── remainder.cu
│ │ │ │ └── subtract.cu
│ │ │ ├── binary_two.cu
│ │ │ ├── compiled.cpp
│ │ │ ├── conv/
│ │ │ │ ├── conv.h
│ │ │ │ ├── gemm_conv.cu
│ │ │ │ └── gemm_grouped_conv.cu
│ │ │ ├── conv.cpp
│ │ │ ├── copy/
│ │ │ │ ├── copy.cuh
│ │ │ │ ├── copy_contiguous.cu
│ │ │ │ ├── copy_general.cu
│ │ │ │ ├── copy_general_dynamic.cu
│ │ │ │ └── copy_general_input.cu
│ │ │ ├── copy.cu
│ │ │ ├── cublas_utils.cpp
│ │ │ ├── cublas_utils.h
│ │ │ ├── cuda.h
│ │ │ ├── cuda_utils.h
│ │ │ ├── cudnn_utils.cpp
│ │ │ ├── cudnn_utils.h
│ │ │ ├── custom_kernel.cpp
│ │ │ ├── cutlass_utils.cuh
│ │ │ ├── delayload.cpp
│ │ │ ├── device/
│ │ │ │ ├── atomic_ops.cuh
│ │ │ │ ├── binary_ops.cuh
│ │ │ │ ├── cast_op.cuh
│ │ │ │ ├── complex.cuh
│ │ │ │ ├── config.h
│ │ │ │ ├── fp16_math.cuh
│ │ │ │ ├── gather.cuh
│ │ │ │ ├── gather_axis.cuh
│ │ │ │ ├── hadamard.cuh
│ │ │ │ ├── indexing.cuh
│ │ │ │ ├── scatter.cuh
│ │ │ │ ├── scatter_axis.cuh
│ │ │ │ ├── scatter_ops.cuh
│ │ │ │ ├── slice_update.cuh
│ │ │ │ ├── ternary_ops.cuh
│ │ │ │ ├── unary_ops.cuh
│ │ │ │ └── utils.cuh
│ │ │ ├── device.cpp
│ │ │ ├── device.h
│ │ │ ├── device_info.cpp
│ │ │ ├── distributed.cu
│ │ │ ├── eval.cpp
│ │ │ ├── event.cu
│ │ │ ├── event.h
│ │ │ ├── fence.cpp
│ │ │ ├── fft.cu
│ │ │ ├── gemms/
│ │ │ │ ├── cublas_gemm.cpp
│ │ │ │ ├── cublas_gemm.h
│ │ │ │ ├── cublas_gemm_batched_12_0.cpp
│ │ │ │ ├── cublas_gemm_batched_12_9.cu
│ │ │ │ ├── gemv.cu
│ │ │ │ ├── gemv.h
│ │ │ │ ├── grouped_gemm.h
│ │ │ │ └── grouped_gemm_unaligned.cu
│ │ │ ├── hadamard.cu
│ │ │ ├── indexing.cpp
│ │ │ ├── jit_module.cpp
│ │ │ ├── jit_module.h
│ │ │ ├── kernel_utils.cu
│ │ │ ├── kernel_utils.cuh
│ │ │ ├── layer_norm.cu
│ │ │ ├── load.cpp
│ │ │ ├── logsumexp.cu
│ │ │ ├── lru_cache.h
│ │ │ ├── matmul.cpp
│ │ │ ├── no_cuda.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── quantized/
│ │ │ │ ├── affine_quantize.cu
│ │ │ │ ├── convert_fp8.cu
│ │ │ │ ├── cublas_qqmm.cpp
│ │ │ │ ├── cublas_qqmm.h
│ │ │ │ ├── fp_quantize.cu
│ │ │ │ ├── mxfp8_quantize.cuh
│ │ │ │ ├── no_qqmm_impl.cpp
│ │ │ │ ├── nvfp4_quantize.cuh
│ │ │ │ ├── qmm/
│ │ │ │ │ ├── CMakeLists.txt
│ │ │ │ │ ├── fp_qmv.cu
│ │ │ │ │ ├── qmm.cu
│ │ │ │ │ ├── qmm.h
│ │ │ │ │ ├── qmm_impl_sm80.cuh
│ │ │ │ │ ├── qmm_impl_sm80_m16.cu
│ │ │ │ │ ├── qmm_impl_sm80_m32.cu
│ │ │ │ │ ├── qmm_impl_sm80_m64.cu
│ │ │ │ │ ├── qmm_impl_sm90.cuh
│ │ │ │ │ ├── qmm_impl_sm90_m128_n128_m2.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n16_m1.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n256_m2.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n32_m1.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n64_m2.cu
│ │ │ │ │ └── qmv.cu
│ │ │ │ ├── qqmm.cpp
│ │ │ │ ├── qqmm_impl.cpp
│ │ │ │ ├── qqmm_impl.h
│ │ │ │ ├── qqmm_utils.cu
│ │ │ │ ├── qqmm_utils.h
│ │ │ │ ├── quantized.cpp
│ │ │ │ ├── quantized.h
│ │ │ │ └── quantized_utils.h
│ │ │ ├── random.cu
│ │ │ ├── reduce/
│ │ │ │ ├── all_reduce.cu
│ │ │ │ ├── col_reduce.cu
│ │ │ │ ├── init_reduce.cu
│ │ │ │ ├── reduce.cuh
│ │ │ │ ├── reduce_ops.cuh
│ │ │ │ ├── reduce_utils.cuh
│ │ │ │ └── row_reduce.cu
│ │ │ ├── reduce.cu
│ │ │ ├── rms_norm.cu
│ │ │ ├── rope.cu
│ │ │ ├── scaled_dot_product_attention.cpp
│ │ │ ├── scaled_dot_product_attention.cu
│ │ │ ├── scan.cu
│ │ │ ├── slicing.cpp
│ │ │ ├── softmax.cu
│ │ │ ├── sort.cu
│ │ │ ├── steel/
│ │ │ │ ├── defines.cuh
│ │ │ │ ├── gemm.cuh
│ │ │ │ ├── mma.cuh
│ │ │ │ ├── tiles.cuh
│ │ │ │ └── utils.cuh
│ │ │ ├── ternary.cu
│ │ │ ├── unary/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── abs.cu
│ │ │ │ ├── arccos.cu
│ │ │ │ ├── arccosh.cu
│ │ │ │ ├── arcsin.cu
│ │ │ │ ├── arcsinh.cu
│ │ │ │ ├── arctan.cu
│ │ │ │ ├── arctanh.cu
│ │ │ │ ├── bitwise_invert.cu
│ │ │ │ ├── ceil.cu
│ │ │ │ ├── conjugate.cu
│ │ │ │ ├── cos.cu
│ │ │ │ ├── cosh.cu
│ │ │ │ ├── erf.cu
│ │ │ │ ├── erf_inv.cu
│ │ │ │ ├── exp.cu
│ │ │ │ ├── expm1.cu
│ │ │ │ ├── floor.cu
│ │ │ │ ├── imag.cu
│ │ │ │ ├── log.cu
│ │ │ │ ├── log1p.cu
│ │ │ │ ├── logical_not.cu
│ │ │ │ ├── negative.cu
│ │ │ │ ├── real.cu
│ │ │ │ ├── round.cu
│ │ │ │ ├── sigmoid.cu
│ │ │ │ ├── sign.cu
│ │ │ │ ├── sin.cu
│ │ │ │ ├── sinh.cu
│ │ │ │ ├── sqrt.cu
│ │ │ │ ├── square.cu
│ │ │ │ ├── tan.cu
│ │ │ │ ├── tanh.cu
│ │ │ │ └── unary.cuh
│ │ │ ├── utils.cpp
│ │ │ ├── utils.h
│ │ │ ├── vector_types.cuh
│ │ │ ├── worker.cpp
│ │ │ └── worker.h
│ │ ├── gpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── copy.cpp
│ │ │ ├── copy.h
│ │ │ ├── device_info.h
│ │ │ ├── eval.h
│ │ │ ├── primitives.cpp
│ │ │ ├── scan.h
│ │ │ ├── slicing.cpp
│ │ │ └── slicing.h
│ │ ├── metal/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── allocator.cpp
│ │ │ ├── allocator.h
│ │ │ ├── binary.cpp
│ │ │ ├── binary.h
│ │ │ ├── compiled.cpp
│ │ │ ├── conv.cpp
│ │ │ ├── copy.cpp
│ │ │ ├── custom_kernel.cpp
│ │ │ ├── device.cpp
│ │ │ ├── device.h
│ │ │ ├── device_info.cpp
│ │ │ ├── distributed.cpp
│ │ │ ├── eval.cpp
│ │ │ ├── event.cpp
│ │ │ ├── fence.cpp
│ │ │ ├── fft.cpp
│ │ │ ├── hadamard.cpp
│ │ │ ├── indexing.cpp
│ │ │ ├── jit/
│ │ │ │ ├── includes.h
│ │ │ │ └── indexing.h
│ │ │ ├── jit_kernels.cpp
│ │ │ ├── kernels/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── arange.h
│ │ │ │ ├── arange.metal
│ │ │ │ ├── arg_reduce.metal
│ │ │ │ ├── atomic.h
│ │ │ │ ├── bf16.h
│ │ │ │ ├── bf16_math.h
│ │ │ │ ├── binary.h
│ │ │ │ ├── binary.metal
│ │ │ │ ├── binary_ops.h
│ │ │ │ ├── binary_two.h
│ │ │ │ ├── binary_two.metal
│ │ │ │ ├── cexpf.h
│ │ │ │ ├── complex.h
│ │ │ │ ├── conv.metal
│ │ │ │ ├── copy.h
│ │ │ │ ├── copy.metal
│ │ │ │ ├── defines.h
│ │ │ │ ├── erf.h
│ │ │ │ ├── expm1f.h
│ │ │ │ ├── fence.metal
│ │ │ │ ├── fft/
│ │ │ │ │ ├── radix.h
│ │ │ │ │ └── readwrite.h
│ │ │ │ ├── fft.h
│ │ │ │ ├── fft.metal
│ │ │ │ ├── fp4.h
│ │ │ │ ├── fp8.h
│ │ │ │ ├── fp_quantized.h
│ │ │ │ ├── fp_quantized.metal
│ │ │ │ ├── fp_quantized_nax.h
│ │ │ │ ├── fp_quantized_nax.metal
│ │ │ │ ├── gemv.metal
│ │ │ │ ├── gemv_masked.h
│ │ │ │ ├── gemv_masked.metal
│ │ │ │ ├── hadamard.h
│ │ │ │ ├── indexing/
│ │ │ │ │ ├── gather.h
│ │ │ │ │ ├── gather_axis.h
│ │ │ │ │ ├── gather_front.h
│ │ │ │ │ ├── indexing.h
│ │ │ │ │ ├── masked_scatter.h
│ │ │ │ │ ├── scatter.h
│ │ │ │ │ └── scatter_axis.h
│ │ │ │ ├── layer_norm.metal
│ │ │ │ ├── logging.h
│ │ │ │ ├── logsumexp.h
│ │ │ │ ├── logsumexp.metal
│ │ │ │ ├── quantized.h
│ │ │ │ ├── quantized.metal
│ │ │ │ ├── quantized_nax.h
│ │ │ │ ├── quantized_nax.metal
│ │ │ │ ├── quantized_utils.h
│ │ │ │ ├── random.metal
│ │ │ │ ├── reduce.h
│ │ │ │ ├── reduce.metal
│ │ │ │ ├── reduce_utils.h
│ │ │ │ ├── reduction/
│ │ │ │ │ ├── ops.h
│ │ │ │ │ ├── reduce_all.h
│ │ │ │ │ ├── reduce_col.h
│ │ │ │ │ ├── reduce_init.h
│ │ │ │ │ └── reduce_row.h
│ │ │ │ ├── rms_norm.metal
│ │ │ │ ├── rope.metal
│ │ │ │ ├── scaled_dot_product_attention.metal
│ │ │ │ ├── scan.h
│ │ │ │ ├── scan.metal
│ │ │ │ ├── sdpa_vector.h
│ │ │ │ ├── softmax.h
│ │ │ │ ├── softmax.metal
│ │ │ │ ├── sort.h
│ │ │ │ ├── sort.metal
│ │ │ │ ├── steel/
│ │ │ │ │ ├── attn/
│ │ │ │ │ │ ├── attn.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_attention.h
│ │ │ │ │ │ │ ├── steel_attention.metal
│ │ │ │ │ │ │ ├── steel_attention_nax.h
│ │ │ │ │ │ │ └── steel_attention_nax.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── mma.h
│ │ │ │ │ │ ├── nax.h
│ │ │ │ │ │ ├── params.h
│ │ │ │ │ │ └── transforms.h
│ │ │ │ │ ├── conv/
│ │ │ │ │ │ ├── conv.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_conv.h
│ │ │ │ │ │ │ ├── steel_conv.metal
│ │ │ │ │ │ │ ├── steel_conv_3d.h
│ │ │ │ │ │ │ ├── steel_conv_3d.metal
│ │ │ │ │ │ │ ├── steel_conv_general.h
│ │ │ │ │ │ │ └── steel_conv_general.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── loaders/
│ │ │ │ │ │ │ ├── loader_channel_l.h
│ │ │ │ │ │ │ ├── loader_channel_n.h
│ │ │ │ │ │ │ └── loader_general.h
│ │ │ │ │ │ └── params.h
│ │ │ │ │ ├── defines.h
│ │ │ │ │ ├── gemm/
│ │ │ │ │ │ ├── gemm.h
│ │ │ │ │ │ ├── gemm_nax.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_gemm_fused.h
│ │ │ │ │ │ │ ├── steel_gemm_fused.metal
│ │ │ │ │ │ │ ├── steel_gemm_fused_nax.h
│ │ │ │ │ │ │ ├── steel_gemm_fused_nax.metal
│ │ │ │ │ │ │ ├── steel_gemm_gather.h
│ │ │ │ │ │ │ ├── steel_gemm_gather.metal
│ │ │ │ │ │ │ ├── steel_gemm_gather_nax.h
│ │ │ │ │ │ │ ├── steel_gemm_gather_nax.metal
│ │ │ │ │ │ │ ├── steel_gemm_masked.h
│ │ │ │ │ │ │ ├── steel_gemm_masked.metal
│ │ │ │ │ │ │ ├── steel_gemm_segmented.h
│ │ │ │ │ │ │ ├── steel_gemm_segmented.metal
│ │ │ │ │ │ │ ├── steel_gemm_splitk.h
│ │ │ │ │ │ │ ├── steel_gemm_splitk.metal
│ │ │ │ │ │ │ ├── steel_gemm_splitk_nax.h
│ │ │ │ │ │ │ └── steel_gemm_splitk_nax.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── mma.h
│ │ │ │ │ │ ├── nax.h
│ │ │ │ │ │ ├── params.h
│ │ │ │ │ │ └── transforms.h
│ │ │ │ │ ├── utils/
│ │ │ │ │ │ ├── integral_constant.h
│ │ │ │ │ │ └── type_traits.h
│ │ │ │ │ └── utils.h
│ │ │ │ ├── ternary.h
│ │ │ │ ├── ternary.metal
│ │ │ │ ├── ternary_ops.h
│ │ │ │ ├── unary.h
│ │ │ │ ├── unary.metal
│ │ │ │ ├── unary_ops.h
│ │ │ │ └── utils.h
│ │ │ ├── kernels.h
│ │ │ ├── logsumexp.cpp
│ │ │ ├── make_compiled_preamble.sh
│ │ │ ├── matmul.cpp
│ │ │ ├── matmul.h
│ │ │ ├── metal.cpp
│ │ │ ├── metal.h
│ │ │ ├── no_metal.cpp
│ │ │ ├── nojit_kernels.cpp
│ │ │ ├── normalization.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── quantized.cpp
│ │ │ ├── reduce.cpp
│ │ │ ├── reduce.h
│ │ │ ├── resident.cpp
│ │ │ ├── resident.h
│ │ │ ├── rope.cpp
│ │ │ ├── scaled_dot_product_attention.cpp
│ │ │ ├── scan.cpp
│ │ │ ├── slicing.cpp
│ │ │ ├── softmax.cpp
│ │ │ ├── sort.cpp
│ │ │ ├── ternary.cpp
│ │ │ ├── ternary.h
│ │ │ ├── unary.cpp
│ │ │ ├── unary.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── no_cpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── compiled.cpp
│ │ │ ├── device_info.cpp
│ │ │ └── primitives.cpp
│ │ └── no_gpu/
│ │ ├── CMakeLists.txt
│ │ ├── allocator.cpp
│ │ ├── apple_memory.h
│ │ ├── device_info.cpp
│ │ ├── eval.cpp
│ │ ├── event.cpp
│ │ ├── fence.cpp
│ │ ├── linux_memory.h
│ │ └── primitives.cpp
│ ├── compile.cpp
│ ├── compile.h
│ ├── compile_impl.h
│ ├── device.cpp
│ ├── device.h
│ ├── distributed/
│ │ ├── CMakeLists.txt
│ │ ├── distributed.cpp
│ │ ├── distributed.h
│ │ ├── distributed_impl.h
│ │ ├── jaccl/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── jaccl.cpp
│ │ │ ├── jaccl.h
│ │ │ ├── mesh.cpp
│ │ │ ├── mesh.h
│ │ │ ├── mesh_impl.h
│ │ │ ├── no_jaccl.cpp
│ │ │ ├── ring.cpp
│ │ │ ├── ring.h
│ │ │ ├── ring_impl.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── mpi/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── mpi.cpp
│ │ │ ├── mpi.h
│ │ │ ├── mpi_declarations.h
│ │ │ └── no_mpi.cpp
│ │ ├── nccl/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── nccl.cpp
│ │ │ ├── nccl.h
│ │ │ └── no_nccl.cpp
│ │ ├── ops.cpp
│ │ ├── ops.h
│ │ ├── primitives.cpp
│ │ ├── primitives.h
│ │ ├── reduction_ops.h
│ │ ├── ring/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── no_ring.cpp
│ │ │ ├── ring.cpp
│ │ │ └── ring.h
│ │ ├── utils.cpp
│ │ └── utils.h
│ ├── dtype.cpp
│ ├── dtype.h
│ ├── dtype_utils.cpp
│ ├── dtype_utils.h
│ ├── einsum.cpp
│ ├── einsum.h
│ ├── event.h
│ ├── export.cpp
│ ├── export.h
│ ├── export_impl.h
│ ├── fast.cpp
│ ├── fast.h
│ ├── fast_primitives.h
│ ├── fence.h
│ ├── fft.cpp
│ ├── fft.h
│ ├── graph_utils.cpp
│ ├── graph_utils.h
│ ├── io/
│ │ ├── CMakeLists.txt
│ │ ├── gguf.cpp
│ │ ├── gguf.h
│ │ ├── gguf_quants.cpp
│ │ ├── load.cpp
│ │ ├── load.h
│ │ ├── no_gguf.cpp
│ │ ├── no_safetensors.cpp
│ │ └── safetensors.cpp
│ ├── io.h
│ ├── linalg.cpp
│ ├── linalg.h
│ ├── memory.h
│ ├── mlx.h
│ ├── ops.cpp
│ ├── ops.h
│ ├── primitives.cpp
│ ├── primitives.h
│ ├── random.cpp
│ ├── random.h
│ ├── scheduler.cpp
│ ├── scheduler.h
│ ├── small_vector.h
│ ├── stream.h
│ ├── threadpool.h
│ ├── transforms.cpp
│ ├── transforms.h
│ ├── transforms_impl.h
│ ├── types/
│ │ ├── bf16.h
│ │ ├── complex.h
│ │ ├── fp16.h
│ │ ├── half_types.h
│ │ └── limits.h
│ ├── utils.cpp
│ ├── utils.h
│ ├── version.cpp
│ └── version.h
├── mlx.pc.in
├── pyproject.toml
├── python/
│ ├── mlx/
│ │ ├── __main__.py
│ │ ├── _distributed_utils/
│ │ │ ├── common.py
│ │ │ ├── config.py
│ │ │ └── launch.py
│ │ ├── _reprlib_fix.py
│ │ ├── _stub_patterns.txt
│ │ ├── extension.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── init.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activations.py
│ │ │ │ ├── base.py
│ │ │ │ ├── containers.py
│ │ │ │ ├── convolution.py
│ │ │ │ ├── convolution_transpose.py
│ │ │ │ ├── distributed.py
│ │ │ │ ├── dropout.py
│ │ │ │ ├── embedding.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── normalization.py
│ │ │ │ ├── pooling.py
│ │ │ │ ├── positional_encoding.py
│ │ │ │ ├── quantized.py
│ │ │ │ ├── recurrent.py
│ │ │ │ ├── transformer.py
│ │ │ │ └── upsample.py
│ │ │ ├── losses.py
│ │ │ └── utils.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── optimizers.py
│ │ │ └── schedulers.py
│ │ ├── py.typed
│ │ └── utils.py
│ ├── src/
│ │ ├── CMakeLists.txt
│ │ ├── array.cpp
│ │ ├── buffer.h
│ │ ├── constants.cpp
│ │ ├── convert.cpp
│ │ ├── convert.h
│ │ ├── cuda.cpp
│ │ ├── device.cpp
│ │ ├── distributed.cpp
│ │ ├── export.cpp
│ │ ├── fast.cpp
│ │ ├── fft.cpp
│ │ ├── indexing.cpp
│ │ ├── indexing.h
│ │ ├── linalg.cpp
│ │ ├── load.cpp
│ │ ├── load.h
│ │ ├── memory.cpp
│ │ ├── metal.cpp
│ │ ├── mlx.cpp
│ │ ├── mlx_func.cpp
│ │ ├── mlx_func.h
│ │ ├── ops.cpp
│ │ ├── random.cpp
│ │ ├── small_vector.h
│ │ ├── stream.cpp
│ │ ├── transforms.cpp
│ │ ├── trees.cpp
│ │ ├── trees.h
│ │ ├── utils.cpp
│ │ └── utils.h
│ └── tests/
│ ├── __main__.py
│ ├── cuda_skip.py
│ ├── mlx_distributed_tests.py
│ ├── mlx_tests.py
│ ├── mpi_test_distributed.py
│ ├── nccl_test_distributed.py
│ ├── ring_test_distributed.py
│ ├── test_array.py
│ ├── test_autograd.py
│ ├── test_bf16.py
│ ├── test_blas.py
│ ├── test_compile.py
│ ├── test_constants.py
│ ├── test_conv.py
│ ├── test_conv_transpose.py
│ ├── test_device.py
│ ├── test_double.py
│ ├── test_einsum.py
│ ├── test_eval.py
│ ├── test_export_import.py
│ ├── test_fast.py
│ ├── test_fast_sdpa.py
│ ├── test_fft.py
│ ├── test_graph.py
│ ├── test_init.py
│ ├── test_linalg.py
│ ├── test_load.py
│ ├── test_losses.py
│ ├── test_memory.py
│ ├── test_nn.py
│ ├── test_ops.py
│ ├── test_optimizers.py
│ ├── test_quantized.py
│ ├── test_random.py
│ ├── test_reduce.py
│ ├── test_tree.py
│ ├── test_upsample.py
│ └── test_vmap.py
├── setup.py
└── tests/
├── CMakeLists.txt
├── allocator_tests.cpp
├── arg_reduce_tests.cpp
├── array_tests.cpp
├── autograd_tests.cpp
├── blas_tests.cpp
├── compile_tests.cpp
├── creations_tests.cpp
├── custom_vjp_tests.cpp
├── device_tests.cpp
├── einsum_tests.cpp
├── eval_tests.cpp
├── export_import_tests.cpp
├── fft_tests.cpp
├── gpu_tests.cpp
├── linalg_tests.cpp
├── load_tests.cpp
├── ops_tests.cpp
├── random_tests.cpp
├── scheduler_tests.cpp
├── test_teardown.cpp
├── tests.cpp
├── utils_tests.cpp
└── vmap_tests.cpp
================================================
FILE CONTENTS
================================================
================================================
FILE: .clang-format
================================================
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Include code snippet
```python
```
**Expected behavior**
A clear and concise description of what you expected to happen.
**Desktop (please complete the following information):**
- OS Version: [e.g. macOS 14.1.2]
- MLX Version: [e.g. 0.7.0]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/actions/build-cuda-release/action.yml
================================================
# Composite action: builds the CUDA wheel and repairs it with auditwheel.
#
# The single step installs the packaging tools, cleans previous build output,
# and runs `python -m build -w` with CMAKE_ARGS=-DMLX_BUILD_CUDA=ON,
# MLX_BUILD_STAGE=2 and MLX_DISABLE_SM90A_KERNELS=1. The resulting
# dist/mlx_cuda*.whl is then repaired for the manylinux_2_35_<arch> platform
# tag, with libraries matching libcublas*, libcuda*, libcudnn*, libnccl* and
# libnvrtc* excluded from the repaired wheel.
#
# NOTE(review): `type` / `options` on the `arch` input are not part of the
# documented action-metadata input schema; presumably they are informational
# only and the value is not validated — confirm.
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cuda*.whl \
--plat manylinux_2_35_${{ inputs.arch }} \
--exclude libcublas* \
--exclude libcuda* \
--exclude libcudnn* \
--exclude libnccl* \
--exclude libnvrtc*
================================================
FILE: .github/actions/build-docs/action.yml
================================================
# Composite action: builds the documentation and uploads it as an artifact.
#
# Steps, in order:
#   1. Reuse the local setup-linux action to prepare the machine.
#   2. Install doxygen, the docs requirements, and the package itself into
#      the .venv virtual environment.
#   3. Run doxygen followed by `make html O=-W` inside docs/
#      (presumably -W makes Sphinx treat warnings as errors — confirm
#      against docs/Makefile).
#   4. Pack docs/build/html and docs/index.html into artifact.tar, following
#      symlinks (--dereference).
#   5. Upload artifact.tar under the name `github-pages`, kept for one day;
#      the upload fails if the tarball is missing.
name: 'Build Documentation'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: bash
run: |
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error
================================================
FILE: .github/actions/build-linux/action.yml
================================================
# Composite action: installs the MLX Python package in editable mode and then
# builds the C++ library with CMake.
#
# Input `toolkit` selects the backend: any value starting with `cuda` turns
# MLX_BUILD_CUDA on, everything else (default `cpu`) leaves it off.
# The first step exports the final CMAKE_ARGS as a step output so the
# CMake-only build in the second step uses the same configuration.
name: 'Build and Test on Linux'
description: 'Build the MLX Python package and the C++ library on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=80"
# Can not build tests and stubs when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF -DMLX_BUILD_PYTHON_STUBS=OFF"
fi
# Install cpu-only torch to save space
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> "$GITHUB_OUTPUT"
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)
================================================
FILE: .github/actions/build-linux-release/action.yml
================================================
# Composite action: builds the Linux release wheels.
#
# Always builds the front-end wheel (MLX_BUILD_STAGE=1) and retags it as
# manylinux_2_35_<arch> without bundling libmlx.so (--exclude / --only-plat).
# When `build-backend` is true it additionally builds the backend wheel
# (MLX_BUILD_STAGE=2, dist/mlx_cpu*.whl) and repairs it for the same platform.
#
# Composite-action inputs reach the action as strings, so the backend step
# compares against 'true' explicitly; a bare `${{ inputs.build-backend }}`
# would treat the non-empty string 'false' as truthy and always run the step.
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Build MLX
shell: bash
run: pip install -e . -v
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
auditwheel repair dist/mlx-*.whl \
--plat manylinux_2_35_${{ inputs.arch }} \
--exclude libmlx.so* \
--only-plat
- name: Build backend package
if: ${{ inputs.build-backend == 'true' }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
================================================
FILE: .github/actions/build-macos/action.yml
================================================
# Composite action: builds and tests MLX on macOS.
#
# Steps, in order: install the package in editable mode (warnings as errors),
# install test-only dependencies, run the Python unit tests on CPU and GPU
# plus the MPI and ring distributed tests, build and test the example
# extension, build the C++ library and run its tests, then rebuild a
# minimal-size shared library with the Metal JIT enabled and re-run the
# Python tests against a JIT-enabled install.
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools typing_extensions
pip install -e ".[dev]" -v
- name: Install tests dependencies
shell: bash -l {0}
run: |
pip install tensorflow
# The ring test's stderr is mirrored into stderr.log and the step fails if
# any line contains "[WARN]". The check must use grep's exit status
# (`grep -q`): wrapping grep in $(...) would instead try to execute the
# matched text as a command, which never succeeds, so the failure branch
# could never be taken.
- name: Run Python tests
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu python -m unittest discover -v python/tests
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m unittest discover -v python/tests
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -q "\[WARN\]" stderr.log; then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash -l {0}
run: |
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
python -m unittest discover -v python/tests
================================================
FILE: .github/actions/build-macos-release/action.yml
================================================
# Composite action: builds the macOS release wheels.
#
# Both steps build with Xcode taken from /Applications/Xcode-latest.app and
# MACOSX_DEPLOYMENT_TARGET set from the `macos-target` input (default 15.0).
# The front-end wheel (MLX_BUILD_STAGE=1) is always built; the backend wheel
# (MLX_BUILD_STAGE=2) is built only when `build-backend` is true.
#
# Composite-action inputs reach the action as strings, so the backend step
# compares against 'true' explicitly; a bare `${{ inputs.build-backend }}`
# would treat the non-empty string 'false' as truthy and always run the step.
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash -l {0}
env:
DEVELOPER_DIR: /Applications/Xcode-latest.app
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend == 'true' }}
shell: bash -l {0}
env:
DEVELOPER_DIR: /Applications/Xcode-latest.app
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
================================================
FILE: .github/actions/build-windows/action.yml
================================================
# Composite action that installs the Python package with uv and then
# configures and builds the C++ tree with the same CMake arguments.
name: 'Build on Windows'
runs:
using: 'composite'
steps:
- name: Install Python package
# The step id lets the following step read the CMAKE_ARGS output written below.
id: python-build
shell: cmd
env:
# For MSVC, Ninja/Release is the only config supported by ccache.
CMAKE_ARGS: >-
-G Ninja
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=cl
-DCMAKE_CXX_COMPILER=cl
-DCMAKE_RC_COMPILER=rc
run: |
uv pip install ".[dev]" -v
:: Pass the CMAKE_ARGS to following steps.
>>%GITHUB_OUTPUT% ECHO CMAKE_ARGS=%CMAKE_ARGS%
- name: Build CPP only
# Configures into ./build using the arguments exported by the previous step,
# then builds with one job per logical processor.
shell: cmd
run: |
cmake . -B build ${{ steps.python-build.outputs.CMAKE_ARGS }}
cmake --build build -j %NUMBER_OF_PROCESSORS%
================================================
FILE: .github/actions/setup-linux/action.yml
================================================
# Composite action that prepares a Linux runner: apt packages, optional ccache,
# a Python venv exported to later steps, and (for `cuda-*` toolkits) the CUDA
# toolchain.
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
toolkit:
# Either 'cpu' or one of the keys of the PACKAGES map below (e.g. 'cuda-12.9').
description: 'Which toolkit to install'
required: false
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
default: '3.14'
use-ccache:
# Compared as the string 'true' in the step condition below.
description: 'Whether to enable ccache'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Install common dependencies
# BLAS/LAPACK and OpenMPI development packages, plus zip.
shell: bash
run: |
echo "::group::Install common dependencies"
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
zip \
libblas-dev liblapack-dev liblapacke-dev \
openmpi-bin openmpi-common libopenmpi-dev
echo "::endgroup::"
- name: Use ccache
# The cache key is separated by OS, architecture and toolkit.
if: ${{ inputs.use-ccache == 'true' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
max-size: 1GB
# ccache-action bug: running "apt-get update" fails on large arm runner.
update-package-index: false
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
- name: Setup Python venv
# Creates and activates .venv, installs the build prerequisites, then writes
# PATH and PYTHONPATH to GITHUB_ENV so later steps use the venv.
shell: bash
run: |
echo "::group::Setup Python venv"
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake typing_extensions
echo PATH=$PATH >> $GITHUB_ENV
# Search python packages in .venv
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
echo "::endgroup::"
- name: Install CUDA toolkit
# Runs only for toolkits whose name starts with 'cuda'. The apt package list
# is looked up in the PACKAGES JSON map using the toolkit name as the key.
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-compiler-12-6 cuda-libraries-dev-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-compiler-12-9 cuda-libraries-dev-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-compiler-13-0 cuda-libraries-dev-13-0"
}
run: |
echo "::group::Install CUDA toolkit"
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
echo "::endgroup::"
- name: CUDA packages and driver report
# Diagnostic output only; a failing nvidia-smi is tolerated via `|| true`.
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
echo "::group::Installed NVIDIA and CUDA packages"
dpkg -l | egrep "cuda|nvidia" -i
echo "::endgroup::"
echo "::group::NVIDIA-SMI Status"
nvidia-smi || true
echo "::endgroup::"
================================================
FILE: .github/actions/setup-macos/action.yml
================================================
# Composite action that prepares a macOS runner: OpenMPI from Homebrew, a check
# that the Metal toolchain component is present, and a Miniconda Python.
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
# Calls brew by absolute path rather than relying on PATH.
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
# NOTE(review): presumably fails the step when the component is missing — confirm.
shell: bash
run: xcodebuild -showComponent MetalToolchain
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}
================================================
FILE: .github/actions/setup-windows/action.yml
================================================
# Composite action that prepares a Windows runner: optional ccache, the Visual
# Studio build environment, uv, and a Python venv. Both environment-changing
# steps dump their full environment into GITHUB_ENV for later steps.
name: 'Setup Windows environment'
inputs:
python-version:
description: 'Version of python to set up'
required: false
default: '3.14'
use-ccache:
# Compared as the string 'true' in the step condition below.
description: 'Whether to enable ccache'
required: false
default: 'true'
runs:
using: 'composite'
steps:
- name: Use ccache
if: ${{ inputs.use-ccache == 'true' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-cpu
max-size: 1GB
- name: Setup Visual Studio cmd
# Locates the latest VS install with vswhere, imports the x64 toolchain
# variables via vcvarsall.bat, then exports the resulting environment.
shell: cmd
run: |
:: Find out path to VS.
pushd "C:\Program Files (x86)\Microsoft Visual Studio\Installer\"
for /f "delims=" %%x in ('.\vswhere.exe -latest -property InstallationPath') do set VSPATH=%%x
popd
:: Import VS vars.
call "%VSPATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
:: Export to all steps.
>>%GITHUB_ENV% set
- uses: astral-sh/setup-uv@v7
- name: Setup Python venv
# Creates .venv with the requested Python, activates it and exports the
# activated environment to the following steps.
shell: cmd
run: |
uv venv --python ${{ inputs.python-version }}
call ".venv/Scripts/activate.bat"
>>%GITHUB_ENV% set
================================================
FILE: .github/actions/test-linux/action.yml
================================================
# Composite action that runs the Linux test suites. The `has-gpu` input is
# compared as the strings 'true' / 'false' to select CPU-only or GPU steps.
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
# Unconditional: launches 8 local MPI ranks of the distributed test script.
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
# CPU runners only. stderr is teed into stderr.log and the step fails if any
# line containing the literal text [WARN] was logged.
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
# NOTE(review): this invokes the `tests` module rather than `unittest` as the
# CPU step does — confirm a `tests` entry point is importable here.
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
# No `if` condition: the CPU run of the C++ tests also happens on GPU runners.
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
# NOTE(review): -sfe is presumably the test runner's source-file-exclude
# filter, skipping linalg_tests.cpp on GPU — confirm.
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*linalg_tests.cpp"
echo "::endgroup::"
================================================
FILE: .github/actions/test-windows/action.yml
================================================
# Composite action that runs the Python and C++ test suites on Windows (CPU).
name: 'Run tests on Windows'
runs:
using: 'composite'
steps:
- name: Run Python tests - CPU
# Unlike the C++ step below, this step does not set DEVICE.
shell: bash
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
# NOTE(review): -tce is presumably a test-case-exclude filter, skipping the
# gguf tests and "test random uniform" — confirm.
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests.exe -tce="*gguf*,test random uniform"
echo "::endgroup::"
================================================
FILE: .github/dependabot.yml
================================================
# Dependabot configuration: check the GitHub Actions used by this repository
# for updates once a week.
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
================================================
FILE: .github/pull_request_template.md
================================================
## Proposed changes
Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
## Checklist
Put an `x` in the boxes that apply.
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
================================================
FILE: .github/scripts/build-sanitizer-tests.sh
================================================
#!/bin/bash
# Build the C++ tests with one sanitizer enabled and run them.
#
# Usage: build-sanitizer-tests.sh <ASAN|UBSAN|TSAN>   (case-insensitive)
#
# Exits non-zero on an unknown sanitizer name or, because of `set -e`, as soon
# as any build or test command fails.
set -ex

# NOTE(review): CMake reads the compiler from CC/CXX or -DCMAKE_<LANG>_COMPILER;
# confirm these environment variables have the intended effect.
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++

BASE_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=DEBUG -DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
# Metal is only built on macOS.
if [[ "$(uname -s)" != "Darwin" ]]; then
  BASE_CMAKE_ARGS="${BASE_CMAKE_ARGS} -DMLX_BUILD_METAL=OFF"
fi

# run_test <SANITIZER>: export that sanitizer's runtime options, rebuild from a
# clean ./build directory with -DUSE_<SANITIZER>=ON, run the test binary, then
# drop the exported options again.
run_test() {
  local sanitizer=$1
  local sanitizer_flag="-DUSE_${sanitizer}=ON"

  echo " Running tests with: ${sanitizer}"

  if [[ "$sanitizer" == "ASAN" ]]; then
    export ASAN_OPTIONS="detect_leaks=0"
  elif [[ "$sanitizer" == "UBSAN" ]]; then
    export UBSAN_OPTIONS="halt_on_error=0:print_stacktrace=1"
  elif [[ "$sanitizer" == "TSAN" ]]; then
    export TSAN_OPTIONS=""
  fi

  rm -rf build
  mkdir -p build
  pushd build > /dev/null
  # The argument strings are intentionally unquoted so they split into words.
  cmake .. ${BASE_CMAKE_ARGS} ${sanitizer_flag}
  make -j "$(nproc)"
  ./tests/tests
  popd > /dev/null

  unset "${sanitizer}_OPTIONS"
}

# Normalise the requested sanitizer to upper case before validating it.
sanitizer_arg=$(echo "$1" | tr '[:lower:]' '[:upper:]')

case "$sanitizer_arg" in
  ASAN|UBSAN|TSAN)
    run_test "$sanitizer_arg"
    echo " ${sanitizer_arg} test run completed successfully."
    ;;
  *)
    echo "Error: Invalid sanitizer '$1'. Please use one of: ASAN, UBSAN, TSAN."
    exit 1
    ;;
esac
================================================
FILE: .github/scripts/setup+build-cpp-linux-fedora-container.sh
================================================
#!/bin/bash
# Install the build dependencies inside a Fedora container, then configure,
# build and run the C++ tests as a compilation sanity check (not a release).
set -ex

# [Setup] Install dependencies inside the container.
fedora_packages=(
  blas-devel
  lapack-devel
  openblas-devel
  make
  cmake
  clang
  git
)
dnf update -y
dnf install -y "${fedora_packages[@]}"
dnf clean all

# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang CMAKE_CXX_COMPILER=/usr/bin/clang++

# Debug build without the Metal backend, followed by the C++ test binary.
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j "$(nproc)"
./tests/tests
popd
================================================
FILE: .github/workflows/build_and_test.yml
================================================
# CI workflow: lint first, then build and test on every supported platform.
# Every other job lists `check_lint` in `needs`, so nothing builds until lint passes.
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
# One run per workflow and ref; a newer run cancels an older one except on main.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
# CPU build and test on x86_64 and aarch64 hosted runners.
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
- run: df -h
cuda_build_and_test:
# CUDA builds run only in the upstream repository. Tests run only on the
# x86_64 leg; the aarch64 leg is build-only.
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
# Self-hosted macOS runners, upstream repository only. There is no separate
# test action; NOTE(review): the tests appear to run inside build-macos — confirm.
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0", "26.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
windows_build_and_test:
name: Windows (cpu, x86_64)
needs: check_lint
runs-on: windows-2025
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-windows
- uses: ./.github/actions/build-windows
- uses: ./.github/actions/test-windows
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_sanitizer_build_and_test:
# Runs the sanitizer build script once per sanitizer in the matrix.
name: Linux Sanitizer Tests (${{ matrix.sanitizer }})
needs: check_lint
strategy:
fail-fast: false
matrix:
sanitizer: [ASAN, UBSAN]
# todo 12/16/2025: enable TSAN later + consider enabling ASAN for GPU backend tests.
# sanitizer: [ASAN, UBSAN, TSAN]
runs-on: ubuntu-22.04-arm
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: Install Dependencies
run: |
export DEBIAN_FRONTEND=noninteractive
sudo apt-get update -y
sudo apt-get install -y \
build-essential \
libblas-dev \
liblapacke-dev \
libopenblas-dev \
cmake \
clang \
git
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
- name: Linux Build and Test with ${{ matrix.sanitizer }}
run: |
bash .github/scripts/build-sanitizer-tests.sh ${{ matrix.sanitizer }}
linux_fedora_build_cpp:
# Builds and tests the C++ tree inside a fedora:42 container on both architectures.
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
================================================
FILE: .github/workflows/documentation.yml
================================================
# Manually triggered workflow that builds the documentation and then deploys
# it to GitHub Pages.
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:
# Runs after `build`; only this job is granted the Pages and OIDC token permissions.
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
================================================
FILE: .github/workflows/nightly.yml
================================================
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
# Nightly Linux release wheels, one matrix leg per Python version. The backend
# (mlx-cpu) wheel is built and uploaded only on the 3.10 leg.
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
# Build each leg with its own interpreter; without this every leg would use
# the action's default Python version.
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
# The matrix key is `python_version` (underscore). `matrix.python-version`
# is undefined here, which made this comparison always false.
build-backend: ${{ matrix.python_version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v7
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v7
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
- run: df -h
build_linux_with_tests:
# Full build and test on every listed Python version for both hosted runner types.
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
- run: df -h
build_mac_release:
# Upstream repository only. Builds wheels for three macOS deployment targets;
# the backend package is requested only on the 3.10 leg. Note that this job's
# matrix key is `python-version` (hyphen), unlike the Linux jobs above.
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 26 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 26.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_release:
# Upstream repository only. Single CUDA 12.9 / x86_64 build whose wheels are
# kept as artifacts for 7 days.
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
arch: 'x86_64'
- name: Upload artifacts
uses: actions/upload-artifact@v7
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda_*.whl
retention-days: 7
================================================
FILE: .github/workflows/release.yml
================================================
# Release workflow: builds documentation and the Linux, macOS and CUDA wheels.
# Triggered by `v*` tags, by `test-publish/*` branches, or manually with the
# optional `dry_run` and `dev_release` inputs.
name: PyPI Release
on:
push:
tags:
- 'v*'
branches:
- 'test-publish/*'
workflow_dispatch:
inputs:
dry_run:
description: 'Dry run (do not publish to PyPi)'
required: false
type: boolean
dev_release:
description: 'Development release (DEV_RELEASE=1)'
required: false
type: boolean
permissions:
contents: read
jobs:
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
# Skipped on dry runs; only this job is granted the Pages and OIDC token permissions.
if: ${{ !inputs.dry_run }}
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
# One wheel per Python version and architecture; the mlx-cpu backend wheel is
# built and uploaded only on the 3.10 legs. ccache is disabled for release builds.
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
# 1 when the manual `dev_release` input is set, otherwise 0.
DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
use-ccache: false
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python_version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v7
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
if-no-files-found: error
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v7
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
if-no-files-found: error
build_mac_release:
# Builds wheels for macOS 14, 15 and 26 on each Python version; the mlx-metal
# backend wheel is requested and uploaded only on the 3.10 leg. This job's
# matrix key is `python-version` (hyphen), unlike the Linux job above.
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools typing_extensions
pip install -e . -v
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 26 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 26.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v7
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
if-no-files-found: error
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v7
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
if-no-files-found: error
build_cuda_release:
# One CUDA wheel per toolkit and architecture. The artifact name expands to
# `mlx-cuda-<version>-<arch>`, which the `mlx-cuda-*` download pattern in the
# CUDA publish job matches.
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ inputs.dev_release && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
use-ccache: false
- name: Build Python package
# NOTE(review): the nightly workflow also passes `toolkit` to this action;
# confirm that omitting it here is intended.
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v7
with:
overwrite: true
name: mlx-${{ matrix.toolkit }}-${{ matrix.arch }}
path: wheelhouse/mlx_cuda_*.whl
if-no-files-found: error
pypi-publish:
# Publishes the main `mlx` wheels. Each publish job follows the same shape:
# download the matching artifacts into dist/, list them, and upload to PyPI
# unless this is a dry run. The environment name switches between 'dry-run'
# and 'pypi' accordingly.
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v8
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v8
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: du -ah dist
- name: Publish package distributions to PyPI
if: ${{ !inputs.dry_run }}
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
# Publishes the CUDA wheels collected from every `mlx-cuda-*` artifact.
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [build_cuda_release]
permissions:
id-token: write
environment:
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v8
with:
pattern: mlx-cuda-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: du -ah dist
- name: Publish package distributions to PyPI
if: ${{ !inputs.dry_run }}
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
# Publishes the mlx-cpu backend wheels collected from every `mlx-cpu-*` artifact.
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [build_linux_release]
permissions:
id-token: write
environment:
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v8
with:
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: du -ah dist
- name: Publish package distributions to PyPI
if: ${{ !inputs.dry_run }}
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
# Publishes the mlx-metal backend wheel from the single `mlx-metal` artifact.
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [build_mac_release]
permissions:
id-token: write
environment:
name: ${{ inputs.dry_run && 'dry-run' || 'pypi' }}
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v8
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: du -ah dist
- name: Publish package distributions to PyPI
if: ${{ !inputs.dry_run }}
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# tensor files
*.safe
*.safetensors
# Metal libraries
*.metallib
# Distribution / packaging
python/mlx/core
python/mlx/share
python/mlx/include
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
venv/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
uv.lock
.DS_Store
# Prerequisites
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
*.ilk
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# Debug symbols
*.pdb
# VSCode
.vscode/
# Jetbrains
.cache/
# vim
*.swp
================================================
FILE: .pre-commit-config.yaml
================================================
# pre-commit hooks: YAML validation, clang-format, black and isort for Python,
# and cmake-format. Each hook repository is pinned to a fixed revision.
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# The two whitespace-fixing hooks below are currently disabled.
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v21.1.8
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 7.0.0
hooks:
- id: isort
# Use isort's black-compatible profile.
args:
- --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format
================================================
FILE: ACKNOWLEDGMENTS.md
================================================
# Individual Contributors
If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-Schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software packages, listed here together with
their licenses copied verbatim.
## PocketFFT
Copyright (C) 2010-2018 Max-Planck-Society
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## metal-cpp
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright © 2023 Apple Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT
================================================
FILE: CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.25)

# Resolve the project version.
#
# * When MLX_VERSION is not supplied, read the MLX_VERSION_{MAJOR,MINOR,PATCH}
#   defines from mlx/version.h and use "major.minor.patch" for both
#   MLX_PROJECT_VERSION and MLX_VERSION.
# * When MLX_VERSION is supplied, keep it as-is and derive MLX_PROJECT_VERSION
#   from its leading "N.N.N" part, dropping any suffix, since project(VERSION)
#   only accepts numeric components.
if(NOT MLX_VERSION)
  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_major ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_minor ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
  set(_patch ${CMAKE_MATCH_1})
  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
  set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
  # The dots are written as "\\." so that the regex engine receives "\." (a
  # literal dot). A single backslash would be consumed by CMake's own argument
  # escaping and leave a bare "." that matches any character.
  string(REGEX REPLACE "^([0-9]+\\.[0-9]+\\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                       ${MLX_VERSION})
endif()

project(
  mlx
  LANGUAGES C CXX
  VERSION ${MLX_PROJECT_VERSION})
# ----------------------------- Setup -----------------------------
# Global build settings applied to everything configured below.
# Extra CMake modules are looked up in the project's cmake/ directory.
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
# C++20 is mandatory (no fallback to an older standard).
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# Silence the per-file messages printed during install.
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
# User-facing switches: option(<name> "<help text>" <default>).
# Some of these are overridden further down depending on platform/compiler
# (MLX_BUILD_METAL is forced OFF off-Darwin, MLX_BUILD_GGUF is forced OFF for
# MSVC).
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# Sanitizer switches; ASan and TSan cannot be combined (checked below).
option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF)
option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF)
option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF)
# --------------------- Processor tests -------------------------
# Decide whether the Metal backend can be built on this host:
#  * Darwin + x86_64: fatal error unless MLX_ENABLE_X64_MAC is set, in which
#    case the build continues with Metal disabled and a warning.
#  * Darwin + any other processor: MLX_BUILD_METAL is left as configured.
#  * Any non-Darwin system: Metal is disabled.
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(
FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
# Explicit opt-in to x86_64 macOS: build without the Metal backend.
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
# Not macOS: the Metal backend is never built.
set(MLX_BUILD_METAL OFF)
endif()
# Optional compiler cache: when MLX_USE_CCACHE is ON and a `ccache` executable
# is found, use it as the compiler launcher for C, C++ and CUDA. If ccache is
# not found the build proceeds without it.
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif()
# Sanitizer configuration. The selected sanitizers are accumulated into
# SANITIZER_COMPILE_FLAGS / SANITIZER_LINK_FLAGS, which are attached to the mlx
# target as PUBLIC options once the target is created.
if(USE_ASAN AND USE_TSAN)
message(
FATAL_ERROR
"AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time."
)
endif()
set(SANITIZER_COMPILE_FLAGS "")
set(SANITIZER_LINK_FLAGS "")
# AddressSanitizer: MSVC uses the /fsanitize spelling, everything else the
# -fsanitize spelling; Linux additionally links pthread.
if(USE_ASAN)
if(WIN32 AND MSVC)
list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address)
list(APPEND SANITIZER_LINK_FLAGS /fsanitize=address)
else()
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address)
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
endif()
endif()
endif()
# UndefinedBehaviorSanitizer: in the MSVC environment it is only enabled when
# the compiler identifies as Clang; otherwise a warning is printed and no flag
# is added.
if(USE_UBSAN)
if(WIN32 AND MSVC)
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
else()
message(
WARNING
"UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC."
)
endif()
else()
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
endif()
endif()
# ThreadSanitizer: rejected outright with MSVC and on Darwin; elsewhere the
# -fsanitize=thread flags are added and Linux additionally links pthread.
if(USE_TSAN)
if(WIN32 AND MSVC)
message(
FATAL_ERROR
"ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC."
)
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.")
else()
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
endif()
endif()
endif()
# ----------------------------- Lib -----------------------------
# Create the mlx library target. Its sources are not listed here; the mlx/
# subdirectory is added further down in this file.
include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# PUBLIC so that targets linking against mlx inherit the sanitizer flags.
target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})
target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})
# CUDA backend: requires the CUDA toolkit and cuDNN. Toolkit versions in the
# range [13.1, 13.2) are rejected at configure time.
if(MLX_BUILD_CUDA)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
find_package(CUDNN REQUIRED)
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "13.1" AND CUDAToolkit_VERSION
VERSION_LESS "13.2")
message(FATAL_ERROR "CUDA Toolkit 13.1 is not supported.")
endif()
endif()
# Metal backend: locate the Apple frameworks, validate the SDK and deployment
# target versions, query the Metal language version and fetch metal-cpp.
if(MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
  if(METAL_LIB)
    message(STATUS "Metal found ${METAL_LIB}")
  else()
    message(
      FATAL_ERROR
        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
  endif()

  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  # Query the macOS SDK version. COMMAND_ERROR_IS_FATAL aborts configuration if
  # the command fails (e.g. xcrun is not available).
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_SDK_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

  # Compare as a version, not as a number: a plain LESS test evaluates to false
  # for values that are not valid numbers (such as "13.3.1"), which would let
  # an unsupported SDK slip through.
  if("${MACOS_SDK_VERSION}" VERSION_LESS 14.0)
    message(
      FATAL_ERROR
        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
  endif()
  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

  set(METAL_CPP_URL
      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)

  # When a deployment target is set, enforce the same minimum and forward it to
  # the metal compiler invocation below.
  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
    if("${CMAKE_OSX_DEPLOYMENT_TARGET}" VERSION_LESS 14.0)
      message(FATAL_ERROR "MLX requires macOS >= 14.0")
    endif()
    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
  endif()

  # Preprocess the __METAL_VERSION__ macro with the metal compiler and keep the
  # last output line (newline removed) as MLX_METAL_VERSION.
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

  # Download metal-cpp and expose its headers both in the build tree and, after
  # installation, under include/metal_cpp.
  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
# Platform-specific system libraries for the mlx target.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
set(MLX_BUILD_GGUF OFF)
endif()
# Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.
# This is only done when MLX is built as the top project.
if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
endif()
# Windows implementation of dlfcn.h APIs.
FetchContent_Declare(
dlfcn-win32
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.2
EXCLUDE_FROM_ALL)
# block() scopes the BUILD_SHARED_LIBS override to this dependency only, so
# dlfcn-win32 is built static without changing the setting for mlx itself.
block()
set(BUILD_SHARED_LIBS OFF)
FetchContent_MakeAvailable(dlfcn-win32)
endblock()
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
target_link_libraries(mlx PRIVATE dl)
endif()
# CPU backend linear-algebra dependency. Selection order:
#  1. Apple Accelerate, when find_library locates it.
#  2. On Windows, a prebuilt OpenBLAS downloaded at configure time.
#  3. Otherwise, LAPACK and BLAS found on the system.
# MLX_BUILD_ACCELERATE records whether Accelerate is in use.
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(WIN32)
# Download and link prebuilt binaries of OpenBLAS. Note that we can only
# link with the dynamic library, the prebuilt binaries were built with MinGW
# so static-linking would require linking with MinGW's runtime.
FetchContent_Declare(
openblas
URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip"
)
FetchContent_MakeAvailable(openblas)
target_link_libraries(mlx
PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib")
target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include")
# Make sure the DLL file is placed in the same dir with executables.
# OPENBLAS_DLL_FILE is also used by the install step later in this file.
set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll")
add_custom_command(
TARGET mlx
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE}
${CMAKE_BINARY_DIR})
else()
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
endif()
else()
# No CPU backend, so Accelerate is not used either.
set(MLX_BUILD_ACCELERATE OFF)
endif()
# Third-party dependencies fetched at configure time, followed by the
# project's own subdirectories.
# nlohmann/json: only its single-header include directory is used, privately
# and only when building mlx itself.
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
# Sources of the mlx target are added from the mlx/ subdirectory.
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
# Consumers include headers relative to the repository root in the build tree
# and relative to include/ after installation.
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
# fmt: either the system package or a pinned upstream tag; linked header-only
# and only within the build interface.
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 12.1.0
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
# Python bindings: need Python >= 3.10 (interpreter + module development
# files) and nanobind.
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(
Python 3.10
COMPONENTS Interpreter Development.Module
REQUIRED)
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
GIT_TAG v2.10.2
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nanobind)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
# Optional components, each gated by its MLX_BUILD_* option.
if(MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if(MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if(MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
# Installs the library, its headers, the metal-cpp headers (Metal builds only)
# and the CMake package files (MLXTargets, MLXConfig, MLXConfigVersion and the
# contents of the cmake/ module directory).
include(GNUInstallDirs)
if(WIN32)
# Install DLLs to the same dir with extension file (core.pyd) on Windows.
set(CMAKE_INSTALL_BINDIR ".")
if(MLX_BUILD_CPU)
# Install OpenBLAS.
# OPENBLAS_DLL_FILE is set in the Windows branch of the CPU backend setup.
install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN)
endif()
endif()
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
# Install headers
# Only *.h files under mlx/ are copied; backend/metal/kernels.h is excluded.
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING
PATTERN "*.h"
PATTERN "backend/metal/kernels.h" EXCLUDE)
# Install metal dependencies
if(MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source)
endif()
# Install cmake config
# The config and version files are generated into the build directory first
# and then installed alongside the exported targets file.
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
include(CMakePackageConfigHelpers)
# Uses the full MLX_VERSION (not the numeric-only MLX_PROJECT_VERSION).
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION})
# MLXConfig.cmake is generated from the mlx.pc.in template.
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
MLX_CMAKE_INSTALL_MODULE_DIR)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
# Ship the project's helper CMake modules with the package.
install(DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to MLX
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
6. For code formatting, install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and Python code.
You can also run the formatters manually as follows:
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
## License
By contributing to MLX, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: LICENSE
================================================
MIT License
Copyright © 2023 Apple Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561
================================================
FILE: README.md
================================================
# MLX
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)
[CircleCI build status](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference between MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.
The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
[PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
[ArrayFire](https://arrayfire.org/).
## Examples
The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
variety of examples, including:
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
## Quickstart
See the [quick start
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
```bash
pip install mlx
```
To install the CUDA backend on Linux, run:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
```
Check out the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.
## Citing MLX
The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTeX entry:
```text
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
url = {https://github.com/ml-explore},
version = {0.0},
year = {2023},
}
```
================================================
FILE: benchmarks/cpp/CMakeLists.txt
================================================
# Build one benchmark executable from SRCFILE. The target is named after the
# source file (directory and extension stripped) and links privately to mlx.
function(build_benchmark SRCFILE)
  get_filename_component(benchmark_name ${SRCFILE} NAME_WE)
  add_executable(${benchmark_name} ${SRCFILE})
  target_link_libraries(${benchmark_name} PRIVATE mlx)
endfunction()

# One executable per benchmark source in this directory.
foreach(benchmark_source single_ops.cpp irregular_strides.cpp
                         compare_devices.cpp autograd.cpp)
  build_benchmark(${benchmark_source})
endforeach()
================================================
FILE: benchmarks/cpp/autograd.cpp
================================================
// Copyright © 2023 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
// Times two ways of obtaining a function's value and gradient: calling the
// function and its mx::grad separately, versus one mx::value_and_grad call.
// NOTE: TIME stringifies its argument, so the lambda names below are part of
// the printed output.
void time_value_and_grad() {
  auto x = mx::ones({200, 1000});
  mx::eval(x); // evaluate the input before any timing starts
  // Applies log(exp(x)) 20 times, then reduces the result with a sum.
  auto fn = [](mx::array x) {
    for (int i = 0; i < 20; ++i) {
      x = mx::log(mx::exp(x));
    }
    return mx::sum(x);
  };

  auto grad_fn = mx::grad(fn);
  // Value and gradient computed through two independent calls.
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = mx::value_and_grad(fn);
  // Value and gradient computed through a single combined call.
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}
// Entry point: report the default device, then run the autograd benchmark.
int main() {
  const auto device = mx::default_device();
  std::cout << "Benchmarks for " << device << std::endl;
  time_value_and_grad();
  return 0;
}
================================================
FILE: benchmarks/cpp/compare_devices.cpp
================================================
// Copyright © 2023 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
// Times mx::add on the CPU and on the GPU for 1-D inputs whose length grows
// by a factor of ten from 1 up to 10^9.
void time_add_op() {
  std::vector<int> sizes{1};
  while (sizes.size() < 10) {
    sizes.push_back(sizes.back() * 10);
  }

  // Inputs are generated with the CPU as the default device.
  set_default_device(mx::Device::cpu);
  for (const int size : sizes) {
    auto a = mx::random::uniform({size});
    auto b = mx::random::uniform({size});
    mx::eval(a, b);

    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
  }
}
// Entry point: run the CPU-vs-GPU add comparison.
int main() {
  time_add_op();
  return 0;
}
================================================
FILE: benchmarks/cpp/irregular_strides.cpp
================================================
// Copyright © 2023 Apple Inc.
#include <cstring>
#include <iostream>
#include <sstream>
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
// Times mx::add on two 1-D inputs that are every-other-element slices of
// evaluated arrays.
void time_irregular_binary_ops_1D() {
  auto device = mx::default_device();
  const int size = 1000000;
  const int step = 2;

  auto lhs = mx::random::uniform({size});
  auto rhs = mx::random::uniform({size});
  mx::eval(lhs, rhs);

  lhs = slice(lhs, {0}, {size}, {step});
  rhs = slice(rhs, {0}, {size}, {step});
  TIMEM("1D strided", mx::add, lhs, rhs, device);
}
// Times mx::add on 2-D inputs where the right-hand operand is, in turn,
// regular, transposed, a row vector, and a column vector.
void time_irregular_binary_ops_2D() {
  auto device = mx::default_device();
  const int size = 2048;

  auto lhs = mx::random::uniform({size, size});
  auto rhs = mx::random::uniform({size, size});
  mx::eval(lhs, rhs);
  TIMEM("2D regular", mx::add, lhs, rhs, device);

  rhs = mx::transpose(rhs);
  mx::eval(rhs);
  TIMEM("2D mx::transpose", mx::add, lhs, rhs, device);

  // Shape {size}: broadcast against lhs along the leading dimension.
  rhs = mx::random::uniform({size});
  mx::eval(rhs);
  TIMEM("2D broadcast dim 0", mx::add, lhs, rhs, device);

  // Shape {size, 1}: broadcast against lhs along the trailing dimension.
  rhs = mx::reshape(rhs, {size, 1});
  mx::eval(rhs);
  TIMEM("2D broadcast dim 1", mx::add, lhs, rhs, device);
}
void time_irregular_binary_ops_3D() {
auto device = mx::default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = mx::random::uniform({d0, d1, d2});
auto b = mx::random::uniform({d0, d1, d2});
TIMEM("3D regular", mx::add, a, b, device);
b = mx::transpose(b, {0, 2, 1});
TIMEM("3D mx::transpose", mx::add, a, b, device);
b = mx::random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
b = mx::random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
b = mx::random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
b = mx::random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
b = mx::random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
b = mx::random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}
// Times mx::add on 4-D inputs where the right-hand operand is regular,
// transposed on its last two axes, or broadcast along every combination of
// one, two and three dimensions (i < j < k).
//
// `shape` is mutated in place: a dimension is set to 1 while it is being
// broadcast and restored from `a` afterwards.
void time_irregular_binary_ops_4D() {
  auto device = mx::default_device();
  mx::Shape shape = {8, 8, 512, 512};
  auto a = mx::random::uniform(shape);
  auto b = mx::random::uniform(shape);

  TIMEM("4D regular", mx::add, a, b, device);

  b = mx::transpose(b, {0, 1, 3, 2});
  TIMEM("4D mx::transpose", mx::add, a, b, device);

  const std::string om = "4D broadcast dims ";
  // Signed copy of the rank so the loop comparisons are int vs int.
  const int ndim = static_cast<int>(shape.size());

  for (int i = 0; i < ndim; ++i) {
    shape[i] = 1;
    b = mx::random::uniform(shape);
    std::ostringstream msg_i;
    msg_i << om << i;
    TIMEM(msg_i.str(), mx::add, a, b, device);

    for (int j = i + 1; j < ndim; ++j) {
      shape[j] = 1;
      std::ostringstream msg_ij;
      msg_ij << om << i << ", " << j;
      b = mx::random::uniform(shape);
      TIMEM(msg_ij.str(), mx::add, a, b, device);

      for (int k = j + 1; k < ndim; ++k) {
        shape[k] = 1;
        std::ostringstream msg_ijk;
        msg_ijk << om << i << ", " << j << ", " << k;
        b = mx::random::uniform(shape);
        TIMEM(msg_ijk.str(), mx::add, a, b, device);
        shape[k] = a.shape(k);
      }

      // Dimension j must stay at 1 for the whole k loop; restoring it
      // earlier would make the "i, j, k" runs broadcast only i and k.
      shape[j] = a.shape(j);
    }
    shape[i] = a.shape(i);
  }
}
// Times mx::reshape of a {d, d, d} input to {8 * size, size, size}, where the
// input is contiguous, transposed, or the result of mx::broadcast_to from a
// smaller array.
void time_irregular_reshape() {
  auto device = mx::default_device();
  int size = 64;
  int d = 2 * size;

  // Every case reshapes to the same target, so it is set once.
  mx::Shape shape = {8 * size, size, size};
  auto reshape_fn = [&shape, device](const mx::array& a) {
    return mx::reshape(a, shape, device);
  };

  auto a = mx::random::uniform({d, d, d});
  TIMEM("3D contiguous", reshape_fn, a);

  a = mx::transpose(a);
  TIMEM("3D mx::transpose", reshape_fn, a);

  // Applied on top of the transpose above.
  a = mx::transpose(a, {1, 2, 0});
  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
  TIMEM("3D broadcast dim 0", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
  TIMEM("3D broadcast dim 1", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
  TIMEM("3D broadcast dim 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
  TIMEM("3D broadcast dims 0, 1", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
  TIMEM("3D broadcast dims 0, 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2", reshape_fn, a);

  // All three dimensions are broadcast; labelled 0-indexed like the others.
  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 0, 1, 2", reshape_fn, a);
}
// Times mx::astype to int32 on an every-other-element slice of a 1-D array.
void time_irregular_astype_1D() {
  auto device = mx::default_device();
  const int size = 1000000;
  const int step = 2;

  auto input = mx::random::uniform({size});
  input = slice(input, {0}, {size}, {step});
  TIMEM("1D strided", mx::astype, input, mx::int32, device);
}
// Times mx::astype to int32 on 2-D inputs that are regular, transposed, or
// produced by mx::broadcast_to from a row or column vector.
void time_irregular_astype_2D() {
  auto device = mx::default_device();
  const int size = 2048;
  mx::Shape shape = {size, size};

  auto input = mx::random::uniform(shape);
  TIMEM("2D regular", mx::astype, input, mx::int32, device);

  input = mx::transpose(input);
  TIMEM("2D mx::transpose", mx::astype, input, mx::int32, device);

  input = mx::broadcast_to(mx::random::uniform({size}), shape);
  TIMEM("2D broadcast dim 0", mx::astype, input, mx::int32, device);

  input = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
  TIMEM("2D broadcast dim 1", mx::astype, input, mx::int32, device);
}
// Entry point. An optional first argument selects the default device: "gpu"
// picks the GPU, any other value picks the CPU. With no argument the default
// device is left unchanged.
int main(int argc, char** argv) {
  if (argc > 1) {
    const bool use_gpu = std::strcmp(argv[1], "gpu") == 0;
    if (use_gpu) {
      set_default_device(mx::Device::gpu);
    } else {
      set_default_device(mx::Device::cpu);
    }
  }

  std::cout << "Benchmarks for " << mx::default_device() << std::endl;

  time_irregular_binary_ops_1D();
  time_irregular_binary_ops_2D();
  time_irregular_binary_ops_3D();
  time_irregular_binary_ops_4D();
  time_irregular_reshape();
  time_irregular_astype_1D();
  time_irregular_astype_2D();
  return 0;
}
================================================
FILE: benchmarks/cpp/single_ops.cpp
================================================
// Copyright © 2023 Apple Inc.
#include "mlx/mlx.h"
#include "time_utils.h"
namespace mx = mlx::core;
// Times array-creation ops: mx::full, mx::zeros and mx::ones on an {M, N}
// shape, plus mx::arange(0.0, 10.0, 1e-4).
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_creation_ops() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
  TIME(full_fp32);
  auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
  TIME(zeros_fp32);
  auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
  TIME(ones_fp32);

  auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
  TIME(arange_fp32);
}
// Times mx::astype between float32, int32, uint32 and bool on an evaluated
// {M, N} array of zeros.
void time_type_conversions() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto device = mx::default_device();

  // float32 source
  auto src = mx::zeros(shape, mx::float32);
  mx::eval(src);
  TIMEM("mx::float32 to mx::int32", mx::astype, src, mx::int32, device);
  TIMEM("mx::float32 to mx::uint32", mx::astype, src, mx::uint32, device);

  // int32 source
  src = mx::zeros(shape, mx::int32);
  mx::eval(src);
  TIMEM("mx::int32 to mx::float32", mx::astype, src, mx::float32, device);

  // bool source
  src = mx::zeros(shape, mx::bool_);
  mx::eval(src);
  TIMEM("bool to mx::float32", mx::astype, src, mx::float32, device);
  TIMEM("bool to mx::int32", mx::astype, src, mx::int32, device);
  TIMEM("bool to mx::uint32", mx::astype, src, mx::uint32, device);
}
// Times float32 uniform and normal random generation for an {M, N} shape.
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_random_generation() {
  int M = 2000;
  int N = 500;

  auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
  TIME(uniform);
  auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
  TIME(normal);
}
// Times elementwise unary ops on an {M, N} array. All ops run on normally
// distributed data except log, which runs on a uniform sample.
void time_unary_ops() {
  const int M = 2000;
  const int N = 500;
  auto device = mx::default_device();

  auto input = mx::random::normal({M, N});
  mx::eval(input);
  TIME(mlx::core::abs, input, device);
  TIME(mx::negative, input, device);
  TIME(mx::sign, input, device);
  TIME(mx::square, input, device);
  TIME(mlx::core::sqrt, input, device);
  TIME(mx::rsqrt, input, device);
  TIME(mlx::core::exp, input, device);

  input = mx::random::uniform({M, N});
  TIME(mlx::core::log, input, device);
}
// Times elementwise binary ops and mx::where in three operand layouts:
// same-shape {M, N, K} inputs, a {M, N, K} input against a one-element array,
// and two one-element arrays broadcast to {1000, 100}.
void time_binary_ops() {
  int M = 1000, N = 100, K = 10;
  auto cond = mx::random::randint(0, 2, {M, N, K});
  auto lhs = mx::random::uniform({M, N, K});
  auto rhs = mx::random::uniform({M, N, K});
  auto device = mx::default_device();
  mx::eval(lhs, rhs);

  // Same-shape operands.
  TIME(mx::add, lhs, rhs, device);
  TIME(mx::subtract, lhs, rhs, device);
  TIME(mx::multiply, lhs, rhs, device);
  TIME(mx::divide, lhs, rhs, device);
  TIME(mx::maximum, lhs, rhs, device);
  TIME(mx::minimum, lhs, rhs, device);
  TIME(mx::where, cond, lhs, rhs, device);

  // One-element right-hand operand and condition.
  cond = mx::array({true});
  rhs = mx::random::uniform({1});
  mx::eval(rhs);
  TIMEM("scalar", mx::add, lhs, rhs, device);
  TIMEM("vector-scalar", mx::subtract, lhs, rhs, device);
  TIMEM("scalar-vector", mx::subtract, rhs, lhs, device);
  TIMEM("scalar", mx::multiply, lhs, rhs, device);
  TIMEM("vector-scalar", mx::divide, lhs, rhs, device);
  TIMEM("scalar-vector", mx::divide, rhs, lhs, device);
  TIMEM("scalar-vector", mx::where, cond, lhs, rhs, device);

  // Both operands are one-element arrays broadcast to {1000, 100}.
  cond = mx::broadcast_to(mx::array({true}), {1000, 100});
  lhs = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
  rhs = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
  mx::eval(lhs, rhs);
  TIMEM("scalar-scalar broadcast", mx::add, lhs, rhs, device);
  TIMEM("scalar-scalar broadcast", mx::subtract, lhs, rhs, device);
  TIMEM("scalar-scalar broadcast", mx::multiply, lhs, rhs, device);
  TIMEM("scalar-scalar broadcast", mx::divide, lhs, rhs, device);
  TIMEM("scalar-scalar broadcast", mx::where, cond, lhs, rhs, device);
}
// Times mx::add on 4-D inputs before and after both operands are transposed
// with different axis permutations.
void time_strided_ops() {
  const int M = 50, N = 50, O = 50, P = 50;
  auto lhs = mx::random::uniform({M, N, O, P});
  auto rhs = mx::random::uniform({M, N, O, P});
  auto device = mx::default_device();
  mx::eval(lhs, rhs);
  TIMEM("non-strided", mx::add, lhs, rhs, device);

  lhs = mx::transpose(lhs, {1, 0, 2, 3});
  rhs = mx::transpose(rhs, {3, 2, 0, 1});
  mx::eval(lhs, rhs);
  TIMEM("strided", mx::add, lhs, rhs, device);
}
// Times the elementwise comparison ops on two evaluated {M, N, K} arrays.
void time_comparisons() {
  const int M = 1000, N = 100, K = 10;
  auto lhs = mx::random::uniform({M, N, K});
  auto rhs = mx::random::uniform({M, N, K});
  auto device = mx::default_device();
  mx::eval(lhs, rhs);

  TIME(mx::equal, lhs, rhs, device);
  TIME(mx::greater, lhs, rhs, device);
  TIME(mx::greater_equal, lhs, rhs, device);
  TIME(mx::less, lhs, rhs, device);
  TIME(mx::less_equal, lhs, rhs, device);
}
// Times matrix-vector products: an {M, N} matrix times an {N} vector, and the
// transposed matrix times an {M} vector.
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_matvec() {
  int M = 2000, N = 200;
  auto a = mx::random::uniform({M, N});
  auto b = mx::random::uniform({N});
  auto c = mx::random::uniform({M});
  mx::eval(a, b, c);
  auto matvec = [&]() { return mx::matmul(a, b); };
  TIME(matvec);

  // The transpose is rebuilt inside the timed lambda on every call.
  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
  TIME(matvec_transpose);
}
// Times mx::matmul of an {M, K} by a {K, N} matrix, then the same product
// with the left operand transposed.
// NOTE: TIME stringifies its first argument, so `mx::matmul` and the lambda
// name are printed.
void time_matmul() {
  int M = 1000, N = 1000, K = 1000;
  auto a = mx::random::uniform({M, K});
  auto b = mx::random::uniform({K, N});
  auto device = mx::default_device();
  mx::eval(a, b);
  TIME(mx::matmul, a, b, device);

  // The transpose is rebuilt inside the timed lambda on every call.
  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
  TIME(transpose_matmul);
}
// Times reductions (sum, prod, all, any, argmin, max, min) over a
// {10000, 1000} array, either over everything or along axis 0 or axis 1.
// The trailing `false` passed to each reduction is presumably keepdims;
// TODO confirm against the mlx ops declarations.
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_reductions() {
  auto a = mx::random::normal({10000, 1000});
  mx::eval(a);
  auto sum_all = [&a]() { return mx::sum(a, false); };
  TIME(sum_all);

  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
  TIME(sum_along_0);

  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
  TIME(sum_along_1);

  auto prod_all = [&a]() { return mx::prod(a, false); };
  TIME(prod_all);

  auto all_true = [&a]() { return mx::all(a, false); };
  TIME(all_true);

  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
  TIME(all_along_0);

  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
  TIME(all_along_1);

  auto any_true = [&a]() { return mx::any(a, false); };
  TIME(any_true);

  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
  TIME(argmin_along_0);

  auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
  TIME(argmin_along_1);

  // b is `a` with a single NaN update scattered in at index 1 along axis 0,
  // so the max/min reductions below run on data that contains a NaN.
  auto indices = mx::array({1});
  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
  std::vector<int> axes{0};
  auto b = scatter(a, {indices}, updates, axes);
  mx::eval(b);

  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
  TIME(max_along_0);
  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
  TIME(max_along_1);

  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
  TIME(min_along_0);
  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
  TIME(min_along_1);
}
// Times mx::take, scatter and scatter_add, first on a {1000, 768} array with
// 256 indices along axis 0, then on flat index sets of size 256 * 768.
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_gather_scatter() {
  auto a = mx::random::normal({1000, 768});
  mx::eval(a);
  auto indices = mx::random::randint(0, 1000, {256});
  mx::eval(indices);

  // 256 indices taken along axis 0.
  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
  TIME(embedding_lookup);

  // take() without an axis, with indices drawn from [0, 768 * 1000).
  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
  mx::eval(indices);

  auto single_element_lookup = [&a, &indices]() {
    return mx::take(a, indices);
  };
  TIME(single_element_lookup);

  // 256 updates of shape {1, 768} scattered along axis 0.
  indices = mx::random::randint(0, 1000, {256});
  auto updates = mx::random::normal({256, 1, 768});
  mx::eval(indices, updates);

  auto embedding_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(embedding_update);

  auto embedding_add = [&a, &indices, &updates]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(embedding_add);

  // Flatten `a` and scatter 256 * 768 single-element updates into it.
  a = mx::reshape(a, {-1});
  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
  updates = mx::random::normal({256 * 768, 1});
  mx::eval(a, indices, updates);

  auto single_element_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(single_element_update);

  auto single_element_add = [&a, &indices, &updates]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(single_element_add);
}
// Times mx::divmod against computing mx::floor_divide and mx::remainder as
// two separate ops on the same pair of 1000-element arrays.
// NOTE: TIME stringifies its argument, so the lambda names are printed.
void time_divmod() {
  auto a = mx::random::normal({1000});
  auto b = mx::random::normal({1000});
  mx::eval({a, b});

  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
  TIME(divmod_fused);

  auto divmod_separate = [&a, &b]() {
    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
  };
  TIME(divmod_separate);
}
// Entry point: report the default device, then run every single-op benchmark
// group in a fixed order.
int main() {
  const auto device = mx::default_device();
  std::cout << "Benchmarks for " << device << std::endl;

  time_creation_ops();
  time_type_conversions();
  time_unary_ops();
  time_binary_ops();
  time_strided_ops();
  time_random_generation();
  time_comparisons();
  time_matvec();
  time_matmul();
  time_reductions();
  time_gather_scatter();
  time_divmod();
  return 0;
}
================================================
FILE: benchmarks/cpp/time_utils.h
================================================
// Copyright © 2023 Apple Inc.
#pragma once
#include <chrono>
#include <iomanip>
#include <iostream>
#include "mlx/mlx.h"
// Converts a std::chrono duration to fractional milliseconds: the duration is
// cast to whole nanoseconds and the count is divided by 1e6.
#define milliseconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
// Current time point of std::chrono::high_resolution_clock.
#define time_now() std::chrono::high_resolution_clock::now()

// Prints "Timing <FUNC> ... ", flushes, then prints the result of
// time_fn(FUNC, args...) with precision 5 followed by " msec".
// FUNC is stringified with #FUNC, so the exact spelling of the first argument
// is part of the output. `##__VA_ARGS__` lets the macro be used with no extra
// arguments.
#define TIME(FUNC, ...)                                                        \
  std::cout << "Timing " << #FUNC << " ... " << std::flush                     \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;

// Same as TIME, but prints "(<MSG>) " before the stringified FUNC. MSG is
// streamed as an expression, not stringified.
#define TIMEM(MSG, FUNC, ...)                                       \
  std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... "  \
            << std::flush << std::setprecision(5)                   \
            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
// Returns the average wall-clock time, in milliseconds, of one
// `eval(fn(args...))` call.
//
// `fn` is called 5 times as warmup and then 100 times under the timer.
//
// The arguments are passed to every call as lvalues instead of being
// perfectly forwarded. Forwarding inside a loop would let the first call move
// from an rvalue argument, leaving the remaining calls to run on a moved-from
// object.
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
  constexpr int num_warmup = 5;
  constexpr int num_iters = 100;

  // warmup
  for (int i = 0; i < num_warmup; ++i) {
    eval(fn(args...));
  }

  auto start = time_now();
  for (int i = 0; i < num_iters; ++i) {
    eval(fn(args...));
  }
  auto end = time_now();

  return milliseconds(end - start) / static_cast<double>(num_iters);
}
================================================
FILE: benchmarks/numpy/single_ops.py
================================================
# Copyright © 2023 Apple Inc.
import numpy as np
from time_utils import time_fn
def time_add():
    """Time ``np.add`` on two float32 arrays of ones with shape (100, 100, 10)."""
    shape = (100, 100, 10)
    lhs = np.ones(shape, dtype=np.float32)
    rhs = np.ones(shape, dtype=np.float32)
    time_fn(np.add, lhs, rhs)
def time_matmul():
    """Time ``np.matmul`` of a (1000, 500) by a (500, 1000) float32 matrix."""
    lhs = np.random.rand(1000, 500).astype(np.float32)
    rhs = np.random.rand(500, 1000).astype(np.float32)
    time_fn(np.matmul, lhs, rhs)
def time_exp():
    """Time ``np.exp`` on a (1000, 100) float32 array of normal samples."""
    samples = np.random.randn(1000, 100).astype(np.float32)
    time_fn(np.exp, samples)
def time_take():
    """Time 20 ``np.take`` row lookups of 10 random indices each.

    The source is a (10000, 500) array. A (20, 10) table of random row ids is
    split into 20 flat index vectors, and one timed call performs all 20
    lookups along axis 0.
    """
    table = np.random.rand(10000, 500)
    index_table = np.random.randint(0, 10000, (20, 10))
    index_vectors = [chunk.reshape(-1) for chunk in np.split(index_table, 20)]

    # The name of this closure is what time_fn prints.
    def random_take():
        return [np.take(table, vector, 0) for vector in index_vectors]

    time_fn(random_take)
if __name__ == "__main__":
    # Run every NumPy benchmark in a fixed order.
    for benchmark in (time_add, time_matmul, time_exp, time_take):
        benchmark()
================================================
FILE: benchmarks/numpy/time_utils.py
================================================
# Copyright © 2023 Apple Inc.
import time
def time_fn(fn, *args):
    """Print the average wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is called 5 times as warmup and then 100 times under the timer.
    The output is one line: ``Timing <fn.__name__> ... <msec> msec``.

    Args:
        fn: Callable to benchmark; it must have a ``__name__`` attribute.
        *args: Positional arguments passed to every call of ``fn``.
    """
    # Flush so the prefix is visible while the benchmark is still running;
    # without a newline it could otherwise stay buffered until the end.
    print(f"Timing {fn.__name__} ...", end=" ", flush=True)

    # warmup
    for _ in range(5):
        fn(*args)

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        fn(*args)
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
================================================
FILE: benchmarks/python/batch_matmul_bench.py
================================================
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
from time_utils import time_fn
# Dimensions of the benchmark operands: the batched inputs have shape
# (B, T, D), the unbatched ones (B * T, D), and the shared matrix (D, D).
B = 8
T = 1024
D = 512
def time_batch_matmul():
    """Time a (B, T, D) @ (D, D) matmul and its two VJPs obtained via ``mx.vjp``."""
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B, T, D))
    weight = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(B, T, D))
    mx.eval(lhs, weight, cotangent)

    time_fn(mx.matmul, lhs, weight)

    # VJP with respect to the first matmul operand.
    def batch_vjp_first():
        vjps = mx.vjp(mx.matmul, [lhs, weight], [cotangent])[1]
        return vjps[0]

    time_fn(batch_vjp_first)

    # VJP with respect to the second matmul operand.
    def batch_vjp_second():
        vjps = mx.vjp(mx.matmul, [lhs, weight], [cotangent])[1]
        return vjps[1]

    time_fn(batch_vjp_second)
def time_unbatch_matmul():
    """Time a (B * T, D) @ (D, D) matmul and two hand-written transposed products.

    The two extra timings compute ``c @ b.T`` and ``a.T @ c`` directly with
    ``mx.matmul`` and ``mx.transpose``.
    """
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B * T, D))
    weight = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(B * T, D))
    mx.eval(lhs, weight, cotangent)

    time_fn(mx.matmul, lhs, weight)

    def unbatch_vjp_first():
        weight_t = mx.transpose(weight)
        return mx.matmul(cotangent, weight_t)

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        lhs_t = mx.transpose(lhs)
        return mx.matmul(lhs_t, cotangent)

    time_fn(unbatch_vjp_second)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # --gpu selects the GPU; otherwise the CPU is set explicitly.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    for benchmark in (time_batch_matmul, time_unbatch_matmul):
        benchmark()
================================================
FILE: benchmarks/python/blas/bench_gemm.py
================================================
# Copyright © 2023 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
# CPU brand string as reported by `sysctl -n machdep.cpu.brand_string`,
# decoded to text with newlines stripped from both ends.
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

# Untimed calls made by bench() before the clock starts.
N_warmup = 8
# Timed calls made by bench().
N_iter_bench = 80
# Matmuls issued inside each gemm_* call.
N_iter_func = 5
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``f`` is first called ``N_warmup`` times untimed, followed by a
    ``torch.mps.synchronize()``; the clock starts after that.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()

    elapsed_ns = stop_ns - start_ns
    return elapsed_ns * 1e-9
def gemm_nn_mlx(a, b):
    """Build ``N_iter_func`` products ``a @ b``, evaluate them together, return them."""
    products = [a @ b for _ in range(N_iter_func)]
    mx.eval(products)
    return products
def gemm_nt_mlx(a, b):
    """Build ``N_iter_func`` products of ``a`` with ``b`` transposed on its last two axes.

    The products are evaluated together and returned as a list.
    """
    products = [a @ b.transpose((0, 2, 1)) for _ in range(N_iter_func)]
    mx.eval(products)
    return products
def gemm_tn_mlx(a, b):
    """Build ``N_iter_func`` products of ``a`` transposed on its last two axes with ``b``.

    The products are evaluated together and returned as a list.
    """
    products = [a.transpose((0, 2, 1)) @ b for _ in range(N_iter_func)]
    mx.eval(products)
    return products
def gemm_tt_mlx(a, b):
    """Build ``N_iter_func`` products with both operands transposed on their last two axes.

    The products are evaluated together and returned as a list.
    """
    products = [
        a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1)) for _ in range(N_iter_func)
    ]
    mx.eval(products)
    return products
@torch.no_grad()
def gemm_nn_torch(a, b):
    """Build ``N_iter_func`` products ``a @ b``, synchronize MPS, return them."""
    products = [a @ b for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
@torch.no_grad()
def gemm_nt_torch(a, b):
    """Build ``N_iter_func`` products of ``a`` with ``b`` transposed on its last two axes.

    MPS is synchronized before the list of products is returned.
    """
    products = [a @ b.transpose(-1, -2) for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
@torch.no_grad()
def gemm_tn_torch(a, b):
    """Build ``N_iter_func`` products of ``a`` transposed on its last two axes with ``b``.

    MPS is synchronized before the list of products is returned.
    """
    products = [a.transpose(-1, -2) @ b for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
@torch.no_grad()
def gemm_tt_torch(a, b):
    """Build ``N_iter_func`` products with both operands transposed on their last two axes.

    MPS is synchronized before the list of products is returned.
    """
    products = [
        a.transpose(-1, -2) @ b.transpose(-1, -2) for _ in range(N_iter_func)
    ]
    torch.mps.synchronize()
    return products
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    """Benchmark one batched GEMM configuration in MLX and in PyTorch (MPS).

    Args:
        B, M, N, K: Batch size and matmul dimensions.
        np_dtype: NumPy dtype the random operands are cast to.
        transpose: Two-character code; ``"n"`` in position 0 or 1 means that
            operand is stored untransposed, anything else means transposed.

    Returns:
        ``(time_mlx, time_torch)`` as returned by ``bench``.

    A message is printed when the MLX product differs from the NumPy product
    by more than the tolerance.
    """
    a_plain = transpose[0] == "n"
    b_plain = transpose[1] == "n"

    shape_a = (B, M, K) if a_plain else (B, K, M)
    shape_b = (B, K, N) if b_plain else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")
    torch.mps.synchronize()

    # Kernel lookup tables keyed by the transpose code.
    mlx_kernels = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }
    torch_kernels = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }
    f_mx = mlx_kernels[transpose]
    f_pt = torch_kernels[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Check the MLX result against NumPy on the same operands.
    keep_axes = (0, 1, 2)
    swap_last_two = (0, 2, 1)
    t_a = keep_axes if a_plain else swap_last_two
    t_b = keep_axes if b_plain else swap_last_two

    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    matches = np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol)
    if not matches:
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch
def get_gflop_count(B, M, N, K):
    """Return the operation count of one ``bench`` run, divided by ``1024**3``.

    Counts ``2 * B * M * N * K`` operations per matmul, times the number of
    matmuls per call and the number of timed calls.
    """
    total_ops = 2.0 * N_iter_bench * N_iter_func * B * M * N * K
    return float(total_ops) / float(1024.0**3)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float32", "float16", "complex64")
    transposes = ("nn", "nt", "tn")
    # Each entry is (B, M, N, K).
    shapes = (
        (16, 234, 768, 3072),
        (1, 64, 64, 25344),
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
    )

    for dtype in dtypes:
        np_dtype = getattr(np, dtype)
        for transpose in transposes:
            for B, M, N, K in shapes:
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / time_mlx
                gflops_pt = gflop_count / time_torch
                # Relative throughput of MLX over PyTorch, minus one.
                diff = gflops_mx / gflops_pt - 1.0

                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
                )
                # Flag rows where PyTorch is at least twice as fast.
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/blas/bench_gemv.py
================================================
# Copyright © 2023 Apple Inc.
import os
import subprocess
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
# Directory the result PDFs are written to; created if it does not exist.
results_dir = "./results"

if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

# CPU brand string as reported by `sysctl -n machdep.cpu.brand_string`; used
# in the figure titles and output file names.
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

# Untimed calls made by bench() before the clock starts.
N_warmup = 5
# Timed calls made by bench().
N_iter_bench = 50
# Matrix-vector products issued inside each gemv* call.
N_iter_func = 20

# Fixed lengths, one subplot row each, for the dimension that is held
# constant while the other one is swept.
out_vec_sizes = [128, 512, 2048, 4096]
in_vec_sizes = [128, 512, 2048, 4096]

# Lengths swept on the x axis: the 1st, 3rd, 5th and 7th multiples of 4096,
# 4095 and 4097, plus a few extra sizes, in ascending order.
benchmark_vector_lens = []
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
benchmark_vector_lens.sort()
def bench(f, m, v):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(m, v)``.

    ``f`` is first called ``N_warmup`` times untimed, followed by a
    ``torch.mps.synchronize()``; the clock starts after that.
    """
    for _ in range(N_warmup):
        f(m, v)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(m, v)
    stop_ns = time.perf_counter_ns()

    elapsed_ns = stop_ns - start_ns
    return elapsed_ns * 1e-9
def gemv_mlx(m, v):
    """Build ``N_iter_func`` products ``m @ v``, evaluate them together, return them."""
    products = [m @ v for _ in range(N_iter_func)]
    mx.eval(products)
    return products
def gemv_t_mlx(m, v):
    """Build ``N_iter_func`` products ``v @ m``, evaluate them together, return them."""
    products = [v @ m for _ in range(N_iter_func)]
    mx.eval(products)
    return products
@torch.no_grad()
def gemv_torch(m, v):
    """Build ``N_iter_func`` products ``m @ v``, synchronize MPS, return them."""
    products = [m @ v for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
@torch.no_grad()
def gemv_t_torch(m, v):
    """Build ``N_iter_func`` products ``v @ m``, synchronize MPS, return them."""
    products = [v @ m for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
    """Benchmark one matrix-vector product size in MLX and in PyTorch (MPS).

    Args:
        in_vec_len: Length of the input vector.
        out_vec_len: Length of the output vector.
        np_dtype: NumPy dtype the random operands are cast to.
        transpose: When true, time ``v @ m`` with a (1, in) vector and an
            (in, out) matrix; otherwise time ``m @ v`` with an (out, in)
            matrix and an (in, 1) vector.

    Returns:
        ``(time_mlx, time_torch)`` as returned by ``bench``.

    A message is printed when the MLX product differs from the NumPy product
    by more than the tolerance.
    """
    if transpose:
        shape_mat = (in_vec_len, out_vec_len)
        shape_vec = (1, in_vec_len)
    else:
        shape_mat = (out_vec_len, in_vec_len)
        shape_vec = (in_vec_len, 1)

    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)

    mat_mlx = mx.array(mat_npy)
    vec_mlx = mx.array(vec_npy)

    mat_trc = torch.from_numpy(mat_npy).to("mps")
    vec_trc = torch.from_numpy(vec_npy).to("mps")
    torch.mps.synchronize()

    # PyTorch is timed first, then MLX, then the MLX result is checked.
    if transpose:
        time_torch = bench(gemv_t_torch, mat_trc, vec_trc)
        time_mlx = bench(gemv_t_mlx, mat_mlx, vec_mlx)
        c_mlx = np.asarray(vec_mlx @ mat_mlx)
        c_npy = vec_npy @ mat_npy
    else:
        time_torch = bench(gemv_torch, mat_trc, vec_trc)
        time_mlx = bench(gemv_mlx, mat_mlx, vec_mlx)
        c_mlx = np.asarray(mat_mlx @ vec_mlx)
        c_npy = mat_npy @ vec_npy

    matches = np.allclose(c_mlx, c_npy, atol=2e-5)
    if not matches:
        print(
            f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch
def get_gflop_count(in_vec_len, out_vec_len):
    """Return the operation count of one ``bench`` run, divided by ``1024**3``.

    Counts ``2 * in_vec_len * out_vec_len`` operations per product, times the
    number of products per call and the number of timed calls.
    """
    total_ops = 2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len
    return float(total_ops) / float(1024**3)
def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    """Return the bytes touched by one ``bench`` run, divided by ``1024**3``.

    One product is counted as touching the matrix, the input vector and the
    output vector once each. That element count is scaled by the element size
    of ``np_dtype``, the number of products per call and the number of timed
    calls.

    Args:
        in_vec_len: Length of the input vector.
        out_vec_len: Length of the output vector.
        np_dtype: NumPy dtype of the operands.
    """
    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
    # Take the element size from the dtype itself. A float32-or-else-2 rule
    # is wrong for complex64 (8 bytes), which the benchmark driver also runs.
    item_size = np.dtype(np_dtype).itemsize
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    """Sweep the output length at a fixed input length and plot GB/s on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        in_vec_len: Fixed input vector length.
        out_vector_lens: Output vector lengths to benchmark (x axis).
        dtype: Name of a NumPy dtype attribute, e.g. ``"float32"``.
        transpose: Forwarded to ``bench_lens``; also selects the plot title.
    """
    np_dtype = getattr(np, dtype)

    mlx_gb_s, mlx_gflops = [], []
    pyt_gb_s, pyt_gflops = [], []

    for out_vec_len in out_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        # GFLOP/s figures are collected but only the GB/s series are plotted.
        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    title = (
        f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
        if transpose
        else f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"
    )

    ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
    ax.legend()
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    """Sweep the input length at a fixed output length and plot GB/s on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        out_vec_len: Fixed output vector length.
        in_vector_lens: Input vector lengths to benchmark (x axis).
        dtype: Name of a NumPy dtype attribute, e.g. ``"float32"``.
        transpose: Forwarded to ``bench_lens``; also selects the plot title.
    """
    np_dtype = getattr(np, dtype)

    mlx_gb_s, mlx_gflops = [], []
    pyt_gb_s, pyt_gflops = [], []

    for in_vec_len in in_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        # GFLOP/s figures are collected but only the GB/s series are plotted.
        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    title = (
        f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
        if transpose
        else f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"
    )

    ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
    ax.legend()
# One figure per (transpose, dtype) pair. The left column sweeps the output
# length at each fixed input length; the right column sweeps the input length
# at each fixed output length. Each figure is saved as a PDF in results_dir.
for transpose in (False, True):
    op_name = "gemv_t" if transpose else "gemv"

    for dtype in ("float32", "float16", "complex64"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )

        for i, in_vec_len in enumerate(in_vec_sizes):
            left_ax = axs[i][0]
            bench_with_in_len(
                left_ax, in_vec_len, benchmark_vector_lens, dtype, transpose
            )

        for i, out_vec_len in enumerate(out_vec_sizes):
            right_ax = axs[i][1]
            bench_with_out_len(
                right_ax, out_vec_len, benchmark_vector_lens, dtype, transpose
            )

        fig.suptitle(f"{device_name}: {dtype} {op_name}")

        pdf_name = f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
        fig.savefig(os.path.join(results_dir, pdf_name))
        plt.close(fig)
================================================
FILE: benchmarks/python/comparative/README.md
================================================
Microbenchmarks comparing MLX to PyTorch
========================================
Implement the same microbenchmarks in MLX and PyTorch to compare and make a
list of the biggest possible performance improvements and/or regressions.
Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
instance to measure the time it takes to sum across the 3rd axis of the above
tensor on the CPU.
`compare.py` runs several benchmarks and compares the speed-up or lack thereof
in comparison to PyTorch.
Each bench script can be run with `--print-pid` to print the PID and wait for a
key in order to ease attaching a debugger.
================================================
FILE: benchmarks/python/comparative/bench_mlx.py
================================================
# Copyright © 2023 Apple Inc.
import argparse
import math
import os
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
def int_or_list(x):
    """Parse ``x`` as an int, or as a comma-separated list of ints.

    Returns an ``int`` when ``int(x)`` succeeds; otherwise splits ``x`` on
    commas and returns the list of parsed ints. A ``ValueError`` from any
    list element is propagated to the caller.
    """
    try:
        parsed = int(x)
    except ValueError:
        parsed = [int(piece) for piece in x.split(",")]
    return parsed
def none_or_list(x):
    """Return ``None`` for the empty string, else a list of ints.

    A non-empty ``x`` is split on commas and every piece is parsed with
    ``int``.
    """
    if x != "":
        return [int(piece) for piece in x.split(",")]
    return None
def dtype_from_str(x):
    """Map a dtype name to the corresponding ``mx`` dtype.

    Args:
        x: Name of an attribute of ``mlx.core``; the empty string selects
            ``mx.float32``.

    Returns:
        The ``mx.Dtype`` named by ``x``.

    Raises:
        ValueError: If ``x`` does not name an ``mx.Dtype``. This covers both
            names that do not exist on the module and names that resolve to
            something that is not a dtype.
    """
    if x == "":
        return mx.float32

    # Default to None so an unknown name reaches the ValueError below instead
    # of escaping as an AttributeError from getattr.
    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
def bench(f, *args):
    """Return the seconds taken by 100 calls of ``f(*args)`` after 10 warmup calls."""
    warmup_iters = 10
    timed_iters = 100

    for _ in range(warmup_iters):
        f(*args)

    start = time.perf_counter()
    for _ in range(timed_iters):
        f(*args)
    stop = time.perf_counter()

    return stop - start
def matmul_square(x):
    """Right-multiply ``x`` by itself 10 times, evaluate the result and return it."""
    product = x
    remaining = 10
    while remaining > 0:
        product = product @ x
        remaining -= 1
    mx.eval(product)
    return product
def matmul(x, y):
    """Build 10 products ``x @ y`` and evaluate them together."""
    products = [x @ y for _ in range(10)]
    mx.eval(products)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    """Build 10 ``mx.quantized_matmul`` results with the given settings and evaluate them.

    ``x``, ``w``, ``s`` and ``b`` are forwarded positionally; ``transpose``,
    ``group_size`` and ``bits`` are forwarded as keyword arguments.
    """
    outputs = [
        mx.quantized_matmul(
            x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
        )
        for _ in range(10)
    ]
    mx.eval(outputs)
# Benchmark name -> pre-configured ``_quant_matmul``.
# Keys are "quant_matmul_<group>_<bits>" for the non-transposed variants and
# "quant_matmul_t_<group>_<bits>" for the transposed ones, for every
# combination of group_size in (32, 64, 128) and bits in (2, 4, 8).
quant_matmul = {
    f"quant_matmul_{'t_' if transposed else ''}{group}_{nbits}": partial(
        _quant_matmul, transpose=transposed, group_size=group, bits=nbits
    )
    for transposed in (False, True)
    for group in (32, 64, 128)
    for nbits in (2, 4, 8)
}
def conv1d(x, y):
    """Build 10 independent ``mx.conv1d(x, y)`` results and evaluate them together."""
    outputs = [mx.conv1d(x, y) for _ in range(10)]
    mx.eval(outputs)
def conv2d(x, y):
    """Build 10 independent ``mx.conv2d(x, y)`` results and evaluate them together."""
    outputs = [mx.conv2d(x, y) for _ in range(10)]
    mx.eval(outputs)
def binary(op, x, y):
    """Apply the ``mlx.core`` function named ``op`` 100 times as ``acc = op(x, acc)``.

    The accumulator starts at ``y``; only the final value is evaluated.
    """
    acc = y
    for _ in range(100):
        acc = getattr(mx, op)(x, acc)
    mx.eval(acc)
def reduction(op, axis, x):
    """Build 100 results of the ``mlx.core`` reduction named ``op`` over ``axis`` and evaluate them."""
    results = [getattr(mx, op)(x, axis=axis) for _ in range(100)]
    mx.eval(results)
def sum_and_add(axis, x, y):
    """Alternate ``+ y`` and a keep-dims sum over ``axis`` 50 times, then evaluate.

    The chain starts from ``x.sum(axis=axis, keepdims=True)``.
    """
    total = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        shifted = total + y
        total = shifted.sum(axis=axis, keepdims=True)
    mx.eval(total)
def softmax(axis, x):
    """Build 100 softmax results over ``axis`` from primitive ops and evaluate them.

    Each result is ``exp(x - max(x)) / sum(exp(x - max(x)))`` along ``axis``.
    """
    results = []
    for _ in range(100):
        shifted = x - mx.max(x, axis=axis, keepdims=True)
        exps = mx.exp(shifted)
        results.append(exps / mx.sum(exps, axis=axis, keepdims=True))
    mx.eval(results)
def softmax_fused(axis, x):
    """Build 100 ``mx.softmax(x, axis=axis)`` results and evaluate them together."""
    results = [mx.softmax(x, axis=axis) for _ in range(100)]
    mx.eval(results)
def relu(x):
    """Chain 100 ``nn.relu`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.relu(out)
    mx.eval(out)
def leaky_relu(x: mx.array):
    """Chain 100 ``nn.leaky_relu`` applications starting from ``x`` and evaluate the result.

    NOTE: a later ``def leaky_relu`` in this module rebinds the name.
    """
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
def prelu(x: mx.array):
    """Chain 100 ``nn.prelu`` applications (weight ``mx.ones(1)``) and evaluate the result.

    A fresh ``mx.ones(1)`` weight is created on every iteration.
    """
    out = x
    for _ in range(100):
        out = nn.prelu(out, mx.ones(1))
    mx.eval(out)
def softplus(x: mx.array):
    """Chain 100 ``nn.softplus`` applications starting from ``x`` and evaluate the result.

    NOTE: a later ``def softplus`` in this module rebinds the name.
    """
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
def mish(x: mx.array):
    """Chain 100 ``nn.mish`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.mish(out)
    mx.eval(out)
def leaky_relu(x):
    """Chain 100 ``nn.leaky_relu`` applications starting from ``x`` and evaluate the result.

    NOTE: this redefines the ``leaky_relu`` declared earlier in this module
    with an equivalent body; this later definition is the one that stays bound.
    """
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
def elu(x):
    """Chain 100 ``nn.elu`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.elu(out)
    mx.eval(out)
def relu6(x):
    """Chain 100 ``nn.relu6`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.relu6(out)
    mx.eval(out)
def softplus(x):
    """Chain 100 ``nn.softplus`` applications starting from ``x`` and evaluate the result.

    NOTE: this redefines the ``softplus`` declared earlier in this module
    with an equivalent body; this later definition is the one that stays bound.
    """
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
def celu(x):
    """Chain 100 ``nn.celu`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.celu(out)
    mx.eval(out)
def log_sigmoid(x):
    """Chain 100 ``nn.log_sigmoid`` applications starting from ``x`` and evaluate the result."""
    out = x
    for _ in range(100):
        out = nn.log_sigmoid(out)
    mx.eval(out)
def scalar_mult(x):
    """Multiply a running value by ``1 / (1 + step)`` for step 0..99, then evaluate it."""
    out = x
    for step in range(100):
        factor = 1.0 / (1 + step)
        out = out * factor
    mx.eval(out)
def cross_entropy(targets, x):
    """Build 100 mean cross-entropy values from primitive ops and evaluate them.

    Each value is ``mean(logsumexp(x, -1) - x[target])``, where the target
    entry is gathered with ``mx.take_along_axis`` along the last axis.
    """
    losses = []
    for _ in range(100):
        log_norm = mx.logsumexp(x, axis=-1, keepdims=True)
        picked = mx.take_along_axis(x, mx.reshape(targets, (-1, 1)), axis=-1)
        losses.append(mx.mean(log_norm - picked))
    mx.eval(losses)
def logsumexp(axis, x):
    """Build 100 ``mx.logsumexp(x, axis=axis)`` results and evaluate them together."""
    results = [mx.logsumexp(x, axis=axis) for _ in range(100)]
    mx.eval(results)
def linear(w, b, x):
    """Build 10 unfused linear layers ``x @ w.T + b`` and evaluate them together."""
    outputs = []
    for _ in range(10):
        w_t = mx.transpose(w, (1, 0))
        outputs.append(x @ w_t + b)
    mx.eval(outputs)
def linear_fused(w, b, x):
    """Build 10 linear layers via ``mx.addmm(b, x, w.T)`` and evaluate them together."""
    outputs = []
    for _ in range(10):
        w_t = mx.transpose(w, (1, 0))
        outputs.append(mx.addmm(b, x, w_t))
    mx.eval(outputs)
def rope(x):
    """Build 10 rotary-embedding-style rotations of ``x`` from primitive ops and evaluate them.

    ``N`` and ``D`` are the last two dimensions of ``x``. Each iteration
    reshapes ``x`` to ``(-1, N, D)``, rotates the even/odd feature pairs by
    position-dependent angles, and stacks the rotated halves back to
    ``(-1, N, D)``. The previously unused ``shape`` local has been removed.
    """
    *_, N, D = x.shape
    ys = []
    for _ in range(10):
        # The whole computation is kept inside the loop on purpose: it is the
        # work being benchmarked.
        x = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        y = mx.reshape(y, (-1, N, D))
        ys.append(y)
    mx.eval(ys)
def concatenate(axis, x, y):
    """Build 10 ``mx.concatenate([x, y], axis=axis)`` results and evaluate them together."""
    joined = [mx.concatenate([x, y], axis=axis) for _ in range(10)]
    mx.eval(joined)
def cumsum(axis, x):
    """Build 10 ``mx.cumsum(x, axis)`` results and evaluate them together."""
    sums = [mx.cumsum(x, axis) for _ in range(10)]
    mx.eval(sums)
def sort(axis, x):
    """Build 10 ``mx.sort(x, axis)`` results and evaluate them together."""
    ordered = [mx.sort(x, axis) for _ in range(10)]
    mx.eval(ordered)
def topk(axis, x):
    """Build 10 ``mx.topk`` results with ``k`` = a third of the ``axis`` length and evaluate them."""
    k = x.shape[axis] // 3
    tops = [mx.topk(x, k, axis) for _ in range(10)]
    mx.eval(tops)
def step_function(x):
    """Chain 100 ``nn.step`` applications starting from ``x`` and evaluate the result.

    Each iteration feeds the previous output back in, like the other
    activation benchmarks in this file. The earlier code applied ``nn.step``
    to ``x`` every time, so the 100 results were independent and only the
    last one was passed to ``mx.eval``.
    """
    y = x
    for _ in range(100):
        y = nn.step(y)
    mx.eval(y)
def selu(x):
    """Chain 100 ``nn.selu`` applications starting from ``x`` and evaluate the result.

    Each iteration feeds the previous output back in, like the other
    activation benchmarks in this file. The earlier code applied ``nn.selu``
    to ``x`` every time, so the 100 results were independent and only the
    last one was passed to ``mx.eval``.
    """
    y = x
    for _ in range(100):
        y = nn.selu(y)
    mx.eval(y)
if __name__ == "__main__":
    # Command line: a benchmark name plus repeatable --size/--axis/--transpose/
    # --dtype options describing the input tensors.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
    args = parser.parse_args()

    # action="append" appends user values after the default entry, so drop the
    # default as soon as at least one explicit value was given.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    if args.cpu:
        mx.set_default_device(mx.cpu)
    else:
        mx.set_default_device(mx.gpu)

    # One dtype per size; missing dtypes are filled with the first one given
    # (float32 when none was given).
    types = args.dtype
    if not types:
        types = [mx.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = []
    for size, dtype in zip(args.size, types):
        xs.append(mx.random.normal(size).astype(dtype))
    # The i-th --transpose applies to the i-th tensor; an empty value skips it.
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = mx.transpose(xs[i], t)
    mx.eval(xs)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))
    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))
    elif args.benchmark in quant_matmul:
        # Membership test instead of a prefix test: an unknown quant_matmul_*
        # name now reaches the "Unknown benchmark" error below rather than
        # failing with a KeyError.
        print(bench(quant_matmul[args.benchmark], *xs))
    elif args.benchmark == "linear":
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))
    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))
    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))
    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))
    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))
    elif args.benchmark == "mul":
        print(bench(binary, "multiply", *xs))
    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))
    elif args.benchmark == "relu":
        print(bench(relu, x))
    elif args.benchmark == "elu":
        print(bench(elu, x))
    elif args.benchmark == "relu6":
        print(bench(relu6, x))
    elif args.benchmark == "celu":
        print(bench(celu, x))
    elif args.benchmark == "log_sigmoid":
        print(bench(log_sigmoid, x))
    elif args.benchmark == "leaky_relu":
        print(bench(leaky_relu, x))
    elif args.benchmark == "prelu":
        print(bench(prelu, x))
    elif args.benchmark == "softplus":
        print(bench(softplus, x))
    elif args.benchmark == "mish":
        print(bench(mish, x))
    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))
    elif args.benchmark == "cross_entropy":
        # Check the tensor that is actually benchmarked. The previous check
        # read `size`, the variable leaked from the loop above, which holds
        # the *last* --size rather than the one used for x.
        if x.ndim != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
        targets = mx.zeros((len(x),), dtype=mx.uint32)
        print(bench(cross_entropy, targets, x))
    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))
    elif args.benchmark == "rope":
        print(bench(rope, x))
    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))
    elif args.benchmark == "cumsum":
        print(bench(cumsum, axis, *xs))
    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))
    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))
    elif args.benchmark == "sort":
        print(bench(sort, axis, x))
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))
    elif args.benchmark == "step":
        print(bench(step_function, x))
    elif args.benchmark == "selu":
        print(bench(selu, x))
    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))
    else:
        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
================================================
FILE: benchmarks/python/comparative/bench_torch.py
================================================
# Copyright © 2023 Apple Inc.
import argparse
import os
import time
import torch
import torch.cuda
import torch.mps
def int_or_list(x):
    """Parse ``x`` as a single int, or as a comma-separated list of ints.

    Returns an ``int`` when the whole string converts, otherwise a list with
    one ``int`` per comma-separated field. Raises ``ValueError`` when a field
    is not an integer.
    """
    try:
        parsed = int(x)
    except ValueError:
        parsed = list(map(int, x.split(",")))
    return parsed
def none_or_list(x):
    """Return ``None`` for the empty string, else the comma-separated ints in ``x``."""
    if x == "":
        return None
    return list(map(int, x.split(",")))
def dtype_from_str(x):
    """Map a dtype name to the matching ``torch.dtype``.

    The empty string selects ``torch.float32``. Any other value is looked up
    as an attribute of ``torch``.

    Raises:
        ValueError: if ``x`` is not an attribute of ``torch`` or the attribute
            is not a ``torch.dtype``. ``ValueError`` is used for both cases so
            that argparse reports a usage error instead of letting an
            ``AttributeError`` escape as a traceback.
    """
    if x == "":
        return torch.float32
    dt = getattr(torch, x, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{x} is not a torch dtype")
    return dt
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Returns the elapsed wall-clock time in seconds for the 100 timed calls.
    """
    for _ in range(10):
        f(*args)

    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    return time.perf_counter() - start
def sync_if_needed(x):
    """Block until queued MPS/CUDA work is finished when ``x`` lives on such a device.

    The device *type* is compared rather than the ``torch.device`` object.
    A tensor's device usually carries an index (``mps:0``, ``cuda:0``), and
    ``torch.device("cuda:0") == torch.device("cuda")`` is ``False``, so the
    earlier object comparison never triggered a synchronization.
    Does nothing for any other device type.
    """
    device_type = x.device.type
    if device_type == "mps":
        torch.mps.synchronize()
    elif device_type == "cuda":
        torch.cuda.synchronize()
@torch.no_grad()
def matmul_square(x):
    """Right-multiply a running product by ``x`` 10 times, then synchronize."""
    product = x
    for _ in range(10):
        product = product @ x
    sync_if_needed(x)
@torch.no_grad()
def matmul(x, y):
    """Compute ``x @ y`` 10 times, keeping every result, then synchronize."""
    products = [x @ y for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def conv1d(x, y):
    """Run ``torch.nn.functional.conv1d`` 10 times, then synchronize.

    The last two axes of both the input and the weight are swapped before
    the convolution.
    """
    inp = torch.transpose(x, -1, -2)
    weight = torch.transpose(y, -1, -2)
    outputs = [torch.nn.functional.conv1d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
@torch.no_grad()
def conv2d(x, y):
    """Run ``torch.nn.functional.conv2d`` 10 times, then synchronize.

    Both the input and the weight are permuted with ``(0, 3, 1, 2)`` before
    the convolution.
    """
    inp = torch.permute(x, (0, 3, 1, 2))
    weight = torch.permute(y, (0, 3, 1, 2))
    outputs = [torch.nn.functional.conv2d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
@torch.no_grad()
def binary(op, x, y):
    """Apply the ``torch`` function named ``op`` 100 times as ``acc = op(x, acc)``, then synchronize.

    The accumulator starts at ``y``.
    """
    acc = y
    for _ in range(100):
        acc = getattr(torch, op)(x, acc)
    sync_if_needed(x)
@torch.no_grad()
def reduction(op, axis, x):
    """Call the tensor method named ``op`` with ``axis`` 100 times, then synchronize."""
    results = [getattr(x, op)(axis) for _ in range(100)]
    sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
    """Alternate ``+ y`` and a keep-dims sum over ``axis`` 50 times, then synchronize.

    The chain starts from ``x.sum(axis=axis, keepdims=True)``.
    """
    total = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        shifted = total + y
        total = shifted.sum(axis=axis, keepdims=True)
    sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
    """Compute softmax over ``axis`` from primitive ops 100 times, then synchronize.

    Each result is ``exp(x - max(x)) / sum(exp(x - max(x)))`` along ``axis``.
    """
    results = []
    for _ in range(100):
        peak = torch.max(x, dim=axis, keepdims=True).values
        exps = torch.exp(x - peak)
        results.append(exps / torch.sum(exps, dim=axis, keepdims=True))
    sync_if_needed(x)
@torch.no_grad()
def softmax_fused(axis, x):
    """Call ``torch.nn.functional.softmax(x, dim=axis)`` 100 times, then synchronize."""
    results = [torch.nn.functional.softmax(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
@torch.no_grad()
def relu(x):
    """Chain 100 ``torch.nn.functional.relu`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu(out)
    sync_if_needed(x)
@torch.no_grad()
def leaky_relu(x):
    """Chain 100 ``torch.nn.functional.leaky_relu`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.leaky_relu(out)
    sync_if_needed(x)
@torch.no_grad()
def elu(x):
    """Chain 100 ``torch.nn.functional.elu`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.elu(out)
    sync_if_needed(x)
@torch.no_grad()
def celu(x):
    """Chain 100 ``torch.nn.functional.celu`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.celu(out)
    sync_if_needed(x)
@torch.no_grad()
def relu6(x):
    """Chain 100 ``torch.nn.functional.relu6`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu6(out)
    sync_if_needed(x)
@torch.no_grad()
def softplus(x):
    """Chain 100 ``torch.nn.functional.softplus`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.softplus(out)
    sync_if_needed(x)
@torch.no_grad()
def log_sigmoid(x):
    """Chain 100 ``torch.nn.functional.logsigmoid`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.logsigmoid(out)
    sync_if_needed(x)
@torch.no_grad()
def prelu(x: torch.Tensor) -> torch.Tensor:
    """Chain 100 ``torch.nn.functional.prelu`` applications, then synchronize.

    A fresh ``torch.ones(1)`` weight is created and moved to the running
    value's device on every iteration. Nothing is returned explicitly.
    """
    out = x
    for _ in range(100):
        weight = torch.ones(1).to(out.device)
        out = torch.nn.functional.prelu(out, weight)
    sync_if_needed(x)
@torch.no_grad()
def mish(x: torch.Tensor) -> torch.Tensor:
    """Chain 100 ``torch.nn.functional.mish`` applications, then synchronize.

    Nothing is returned explicitly.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.mish(out)
    sync_if_needed(x)
@torch.no_grad()
def scalar_mult(x):
    """Multiply a running value by ``1 / (1 + step)`` for step 0..99, then synchronize."""
    out = x
    for step in range(100):
        factor = 1.0 / (1 + step)
        out = out * factor
    sync_if_needed(x)
@torch.no_grad()
def cross_entropy(targets, x):
    """Call ``torch.nn.functional.cross_entropy(x, targets)`` 100 times, then synchronize."""
    losses = [torch.nn.functional.cross_entropy(x, targets) for _ in range(100)]
    sync_if_needed(x)
@torch.no_grad()
def logsumexp(axis, x):
    """Call ``torch.logsumexp(x, dim=axis)`` 100 times, then synchronize."""
    results = [torch.logsumexp(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
@torch.no_grad()
def linear_fused(w, b, x):
    """Call ``torch.nn.functional.linear(x, w, b)`` 10 times, then synchronize."""
    outputs = [torch.nn.functional.linear(x, w, b) for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def linear(w, b, x):
    """Compute the unfused linear layer ``x @ w.T + b`` 10 times, then synchronize."""
    outputs = []
    for _ in range(10):
        w_t = torch.transpose(w, -2, -1)
        outputs.append((x @ w_t) + b)
    sync_if_needed(x)
@torch.no_grad()
def rope(x):
    """Compute a rotary-embedding-style rotation of ``x`` 10 times, then synchronize.

    ``N`` and ``D`` are the last two dimensions of ``x``. Each iteration views
    ``x`` as ``(-1, N, D)``, rotates the even/odd feature pairs by
    position-dependent angles and reshapes the stacked halves to ``(-1, N, D)``.
    """
    *_, N, D = x.shape
    rotated = []
    for _ in range(10):
        x = x.view(-1, N, D)
        dev = x.device
        positions = torch.arange(N, device=dev)
        freqs = 10000 ** torch.linspace(0, 1, D // 2, device=dev)
        angles = positions[:, None] * freqs[None]
        cos_a = torch.cos(angles)
        sin_a = torch.sin(angles)
        even = x[..., ::2]
        odd = x[..., 1::2]
        rot_even = even * cos_a - odd * sin_a
        rot_odd = even * sin_a + odd * cos_a
        stacked = torch.cat([rot_even[..., None], rot_odd[..., None]], dim=-1)
        rotated.append(stacked.reshape(-1, N, D))
    sync_if_needed(x)
@torch.no_grad()
def concatenate(axis, x, y):
    """Call ``torch.cat([x, y], dim=axis)`` 10 times, then synchronize."""
    joined = [torch.cat([x, y], dim=axis) for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def cumsum(axis, x):
    """Call ``x.cumsum(axis)`` 10 times, then synchronize."""
    sums = [x.cumsum(axis) for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def sort(axis, x):
    """Call ``torch.sort(x, dim=axis)`` 10 times, keeping the sorted values, then synchronize."""
    ordered = [torch.sort(x, dim=axis)[0] for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def topk(axis, x):
    """Call ``torch.topk`` 10 times with ``k`` = a third of the ``axis`` length, then synchronize.

    Only the values (element 0 of each result) are kept.
    """
    k = x.shape[axis] // 3
    tops = [torch.topk(x, k, dim=axis)[0] for _ in range(10)]
    sync_if_needed(x)
@torch.no_grad()
def step_function(x):
    """Chain 100 step functions ``where(v < 0, 0, 1)`` starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        is_negative = out < 0
        out = torch.where(is_negative, 0, 1)
    sync_if_needed(x)
@torch.no_grad()
def selu(x):
    """Chain 100 ``torch.nn.functional.selu`` applications starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.selu(out)
    sync_if_needed(x)
if __name__ == "__main__":
    # Command line: a benchmark name plus repeatable --size/--axis/--transpose/
    # --dtype options describing the input tensors.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
    args = parser.parse_args()

    # action="append" appends user values after the default entry, so drop the
    # default as soon as at least one explicit value was given.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    torch.set_num_threads(1)
    # Device priority: --cpu overrides CUDA, which overrides the "mps" default.
    device = "mps"
    if torch.cuda.is_available():
        device = "cuda"
    if args.cpu:
        device = "cpu"

    # One dtype per size; missing dtypes are filled with the first one given
    # (float32 when none was given).
    types = args.dtype
    if not types:
        types = [torch.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = []
    for size, dtype in zip(args.size, types):
        xs.append(torch.randn(*size).to(device).to(dtype))
    # The i-th --transpose applies to the i-th tensor; an empty value skips it.
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = xs[i].permute(*t)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))
    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))
    elif args.benchmark == "linear":
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))
    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))
    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))
    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))
    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))
    elif args.benchmark == "mul":
        print(bench(binary, "mul", *xs))
    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))
    elif args.benchmark == "relu":
        print(bench(relu, x))
    elif args.benchmark == "leaky_relu":
        print(bench(leaky_relu, x))
    elif args.benchmark == "elu":
        print(bench(elu, x))
    elif args.benchmark == "relu6":
        print(bench(relu6, x))
    elif args.benchmark == "softplus":
        print(bench(softplus, x))
    elif args.benchmark == "celu":
        print(bench(celu, x))
    elif args.benchmark == "log_sigmoid":
        print(bench(log_sigmoid, x))
    elif args.benchmark == "prelu":
        print(bench(prelu, x))
    elif args.benchmark == "mish":
        print(bench(mish, x))
    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))
    elif args.benchmark == "cross_entropy":
        # Check the tensor that is actually benchmarked. The previous check
        # read `size`, the variable leaked from the loop above, which holds
        # the *last* --size rather than the one used for x.
        if x.ndim != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
        print(bench(cross_entropy, targets, x))
    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))
    elif args.benchmark == "rope":
        print(bench(rope, x))
    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))
    elif args.benchmark == "cumsum":
        print(bench(cumsum, axis, *xs))
    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))
    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))
    elif args.benchmark == "sort":
        print(bench(sort, axis, x))
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))
    elif args.benchmark == "step":
        print(bench(step_function, x))
    elif args.benchmark == "selu":
        print(bench(selu, x))
    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))
    else:
        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
================================================
FILE: benchmarks/python/comparative/compare.py
================================================
# Copyright © 2023 Apple Inc.
#!/usr/bin/env python
import argparse
import re
from pathlib import Path
from subprocess import run
# Paths of the two benchmark scripts, located next to this file.
_BENCH_DIR = Path(__file__).parent
BENCH_MLX = _BENCH_DIR / "bench_mlx.py"
BENCH_TORCH = _BENCH_DIR / "bench_torch.py"
def run_or_raise(*args, **kwargs):
    """Run a subprocess and return its stdout parsed as a float.

    All arguments are forwarded to ``subprocess.run`` with output capturing
    enabled.

    Raises:
        ValueError: if stdout cannot be parsed as a float; the message holds
            the decoded stdout and stderr of the process.
    """
    try:
        result = run(*args, capture_output=True, **kwargs)
        seconds = float(result.stdout)
    except ValueError:
        raise ValueError(
            f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
        )
    return seconds
def compare(args):
    """Run one benchmark with both scripts and print the relative difference.

    Prints ``(t_torch - t_mlx) / t_torch`` followed by the space-joined
    arguments, separated by a tab.
    """
    timings = {
        script: run_or_raise(["python", script] + args)
        for script in (BENCH_MLX, BENCH_TORCH)
    }
    t_mlx, t_torch = timings[BENCH_MLX], timings[BENCH_TORCH]
    relative = (t_torch - t_mlx) / t_torch
    print(relative, " ".join(args), sep="\t")
def compare_mlx_dtypes(args, dt1, dt2):
    """Run one MLX benchmark with two dtypes and print the relative difference.

    Prints ``(t_dt2 - t_dt1) / t_dt2`` followed by the space-joined
    arguments, separated by a tab.
    """
    base_cmd = ["python", BENCH_MLX] + args
    t_first = run_or_raise(base_cmd + ["--dtype", dt1])
    t_second = run_or_raise(base_cmd + ["--dtype", dt2])
    relative = (t_second - t_first) / t_second
    print(relative, " ".join(args), sep="\t")
def make_regex_search(regexes):
    """Compile ``regexes`` and return a function that searches a string with each.

    The returned function yields, lazily and in order, one bool per pattern
    telling whether that pattern matched somewhere in the string.
    """
    patterns = [re.compile(expr) for expr in regexes]

    def search(x):
        for pattern in patterns:
            yield pattern.search(x) is not None

    return search
def make_predicate(positive_filter, negative_filter):
    """Build a predicate from optional include and exclude regex lists.

    The predicate accepts a string when every pattern in ``positive_filter``
    matches it and no pattern in ``negative_filter`` does. A filter that is
    ``None`` imposes no restriction.
    """
    if positive_filter is None:

        def is_selected(x):
            return True

    else:
        include_search = make_regex_search(positive_filter)

        def is_selected(x):
            return all(include_search(x))

    if negative_filter is None:

        def is_kept(x):
            return True

    else:
        exclude_search = make_regex_search(negative_filter)

        def is_kept(x):
            return not any(exclude_search(x))

    def predicate(x):
        return is_selected(x) and is_kept(x)

    return predicate
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
    parser.add_argument(
        "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
    )
    parser.add_argument(
        "--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
    )
    parser.add_argument(
        "--mlx_dtypes",
        "-d",
        help="Compare mlx benchmarks between the 2 provided data types",
        nargs=2,
    )
    # Unrecognized arguments are collected in `rest` and appended to every
    # benchmark command line.
    args, rest = parser.parse_known_args()
    _filter = make_predicate(args.filter, args.negative_filter)

    def compare_filtered(x):
        """Run the comparison for command line ``x`` if it passes the filters."""
        if not _filter(x):
            return None
        if args.mlx_dtypes:
            return compare_mlx_dtypes(
                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
            )
        return compare(x.split() + rest)

    # Benchmark command lines, executed in this order.
    benchmark_commands = [
        # Binary ops
        "add --size 10x1024x128 --size 1x1024x128 --cpu",
        "add --size 10x1024x128 --size 1x1024x128",
        "add --size 1024x128 --size 1x128 --cpu",
        "add --size 1024x128 --size 1x128",
        "add --size 1024x4096 --size 1x4096 --cpu",
        "add --size 1024x4096 --size 1x4096",
        "add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu",
        "add --size 1024x4096 --size 1x1024 --transpose 1,0",
        "add --size 1024x1024 --size 1024x1024 --cpu",
        "add --size 1024x1024 --size 1024x1024",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0",
        # Reduction ops
        "sum_all --size 10x1024x128 --cpu",
        "sum_all --size 10x1024x128",
        "sum_axis --size 16x1024x128 --axis 2 --cpu",
        "sum_axis --size 16x1024x128 --axis 2",
        "sum_axis --size 16x128x1024 --axis 2 --cpu",
        "sum_axis --size 16x128x1024 --axis 2",
        "sum_axis --size 1024x1024 --axis 1 --cpu",
        "sum_axis --size 1024x1024 --axis 1",
        "sum_axis --size 1024x1024 --axis 0 --cpu",
        "sum_axis --size 1024x1024 --axis 0",
        "sum_axis --size 16x128x1024 --axis 1 --cpu",
        "sum_axis --size 16x128x1024 --axis 1",
        "sum_axis --size 16x128x1024 --axis 0 --cpu",
        "sum_axis --size 16x128x1024 --axis 0",
        "sum_axis --size 16x128x1024 --axis 0,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,1",
        "sum_axis --size 16x128x1024 --axis 0,2 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,2",
        "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1",
        "sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1",
        "argmax --size 10x1024x128 --axis 1 --cpu",
        "argmax --size 10x1024x128 --axis 1",
        "argmax --size 10x1024x128 --axis 2 --cpu",
        "argmax --size 10x1024x128 --axis 2",
        "argmax --size 1024x1024 --axis 1 --cpu",
        "argmax --size 1024x1024 --axis 1",
        # Matmul ops
        "matmul_square --size 1024x1024",
        "matmul_square --size 1024x1024 --cpu",
        "matmul_square --size 16x1024x1024",
        "matmul_square --size 16x1024x1024 --cpu",
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1",
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1",
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 512x8192 --size 8192x512",
        "matmul --size 512x8192 --size 8192x512 --cpu",
        # "matmul --size 512x131072 --size 131072x512",
        # "matmul --size 512x131072 --size 131072x512 --cpu",
        "matmul --size 8192x512 --size 512x8192",
        "matmul --size 8192x512 --size 512x8192 --cpu",
        # "matmul --size 131072x512 --size 512x512",
        # "matmul --size 131072x512 --size 512x512 --cpu",
        "linear --size 1024x1024 --size 1024 --size 128x1024",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --cpu",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --fused",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu",
        # Matvec ops
        "matmul --size 1x1x4096 --size 4096x4096 --cpu",
        "matmul --size 1x1x4096 --size 4096x4096",
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu",
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0",
        "matmul --size 32x1x1000 --size 32x1000x128 --cpu",
        "matmul --size 32x1x1000 --size 32x1000x128",
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1",
        # Various ops
        "softmax --size 32x16x1024 --axis 2",
        "softmax --size 32x16x1024 --axis 2 --cpu",
        "softmax --size 32x16x1024 --axis 2 --fused",
        "softmax --size 32x16x1024 --axis 2 --fused --cpu",
        "softmax --size 2x1024x1024 --axis 1",
        "softmax --size 2x1024x1024 --axis 1 --cpu",
        "softmax --size 2x1024x1024 --axis 1 --fused",
        "softmax --size 2x1024x1024 --axis 1 --fused --cpu",
        "relu --size 32x16x1024",
        "relu --size 32x16x1024 --cpu",
        "leaky_relu --size 32x16x1024",
        "leaky_relu --size 32x16x1024 --cpu",
        "elu --size 32x16x1024",
        "elu --size 32x16x1024 --cpu",
        "relu6 --size 32x16x1024",
        "relu6 --size 32x16x1024 --cpu",
        "softplus --size 32x16x1024",
        "softplus --size 32x16x1024 --cpu",
        "celu --size 32x16x1024",
        "celu --size 32x16x1024 --cpu",
        "log_sigmoid --size 32x16x1024",
        "log_sigmoid --size 32x16x1024 --cpu",
        "step --size 32x16x1024",
        "step --size 32x16x1024 --cpu",
        "selu --size 32x16x1024",
        "selu --size 32x16x1024 --cpu",
        # "mish --size 32x16x1024", NOTE: Torch does not implement Mish in MPS atm
        "mish --size 32x16x1024 --cpu",
        "prelu --size 32x16x1024",
        "prelu --size 32x16x1024 --cpu",
        "scalar_mul --size 32x16x1024",
        "scalar_mul --size 32x16x1024 --cpu",
        "cross_entropy --size 256x1024",
        "cross_entropy --size 256x1024 --cpu",
        "logsumexp --size 1024x1024 --axis 1",
        "logsumexp --size 1024x1024 --axis 1 --cpu",
        "logsumexp --size 1024x1024 --axis 0",
        "logsumexp --size 1024x1024 --axis 0 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 2",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 0",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu",
        "concatenate --size 32x1024x128 --size 32x16x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu",
        "concatenate --size 32x1024x128 --size 32x1x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu",
        "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2",
        "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu",
        "conv1d --size 1x1000x80 --size 128x11x80",
        "conv1d --size 1x1000x80 --size 128x11x80 --cpu",
        "conv1d --size 16x1000x80 --size 128x11x80",
        "conv1d --size 4x1000x80 --size 128x11x80 --cpu",
        "conv2d --size 1x256x256x3 --size 8x3x3x3",
        "conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu",
        "conv2d --size 16x256x256x3 --size 8x3x3x3",
        "conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu",
        "cumsum --size 1024x1024 --axis 1 --cpu",
        "cumsum --size 1024x1024 --axis 0 --cpu",
        "cumsum --size 1024x1024 --axis 1",
        "cumsum --size 1024x1024 --axis 0",
        "cumsum --size 128x1024 --axis 1",
        "cumsum --size 128x1024 --axis 0",
        "cumsum --size 1024x4096 --axis 1",
        "cumsum --size 1024x4096 --axis 0",
        "cumsum --size 128x4096 --axis 1",
        "cumsum --size 128x4096 --axis 0",
        "cumsum --size 1024x7777 --axis 1",
        "cumsum --size 1024x7777 --axis 0",
        "cumsum --size 128x7777 --axis 1",
        "cumsum --size 128x7777 --axis 0",
        "cumsum --size 32768x128 --axis 1",
        "cumsum --size 32768x128 --axis 0",
        "sort --size 1024x1024 --axis 0",
        "sort --size 1024x1024 --axis 1",
        "sort --size 32768x128 --axis 0",
        "sort --size 32768x128 --axis 1",
        "sort --size 128x128 --axis 0 --cpu",
        "sort --size 128x128 --axis 1 --cpu",
        "topk --size 1024x1024 --axis 0",
        "topk --size 1024x1024 --axis 1",
        "topk --size 32768x128 --axis 0",
        "topk --size 32768x128 --axis 1",
        "topk --size 128x128 --axis 0 --cpu",
        "topk --size 128x128 --axis 1 --cpu",
    ]
    for command in benchmark_commands:
        compare_filtered(command)
================================================
FILE: benchmarks/python/compile_bench.py
================================================
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
    """Time an erf-based GELU eagerly and under ``mx.compile``.

    Two scenarios are timed with ``time_fn``: a fixed-size input, and an input
    whose first dimension is sliced to a random length on every call. The
    random generator is re-seeded before each variable-size run, as
    ``bench_layernorm`` does, so every variant sees the same sequence of
    slice lengths.
    """

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fun(fun):
        # Wrap `fun` so that one benchmark call applies it 10 times in a chain.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fun(gelu), x, msg="fixed gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        # Random slice length between 1 and the first dimension of x.
        return random.randint(1, x.shape[0])

    def gen_fun(fun):
        # Variable-size variant: x is sliced to a random length on each call,
        # while y keeps its full size.
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))
    random.seed(0)
    time_fn(gen_fun(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_fun(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )
def bench_layernorm():
    """Time a hand-written layer norm eagerly and under ``mx.compile``.

    Two scenarios are timed with ``time_fn``: a fixed-size float16 input, and
    an input whose first dimension is sliced to a random length on every
    call. The random generator is re-seeded before each variable-size run.
    """
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        # Normalize in float32, then cast back to float16 before the affine step.
        x = x.astype(mx.float32)
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + 1e-4)
        x = x.astype(mx.float16)
        return weight * x + bias

    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)

    def repeat_fixed(fun):
        # One benchmark call applies `fun` 10 times in a chain.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(repeat_fixed(layernorm), x, msg="fixed layernorm")
    time_fn(repeat_fixed(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    def randint():
        return random.randint(1, x.shape[0])

    def repeat_variable(fun):
        # Same chain, but on a randomly sized leading slice of the input.
        def bench_fun(x):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    random.seed(0)
    time_fn(repeat_variable(layernorm), x, msg="variable layernorm")

    random.seed(0)
    time_fn(
        repeat_variable(mx.compile(layernorm)), x, msg="compiled variable layernorm"
    )

    random.seed(0)
    time_fn(
        repeat_variable(mx.compile(layernorm, shapeless=True)),
        x,
        msg="shapeless variable layernorm",
    )
if __name__ == "__main__":
    # No options are declared; the parser is still run on the command line.
    parser = argparse.ArgumentParser("Compile benchmarks.")
    args = parser.parse_args()

    # Run the benchmarks in their original order.
    for run_benchmark in (bench_gelu, bench_layernorm):
        run_benchmark()
================================================
FILE: benchmarks/python/conv1d_bench.py
================================================
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
# Read the CPU brand string via `sysctl -n machdep.cpu.brand_string`.
# The value is not referenced anywhere else in this script.
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
# Untimed calls made by bench() before the measurement starts.
N_warmup = 10
# Timed calls made by bench() per measurement.
N_iter_bench = 100
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made first, followed by
    ``torch.mps.synchronize()``, before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    started_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - started_ns
    return elapsed_ns * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
    """Build a callable that runs ``mx.conv1d`` ``N_iter_func`` times.

    The returned function evaluates all outputs with ``mx.eval`` and returns
    them as a list.
    """

    def mx_conv_1D(a, b):
        outputs = [
            mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
    """Build a no-grad callable that runs ``torch.conv1d`` ``N_iter_func`` times.

    The returned function calls ``torch.mps.synchronize()`` after issuing the
    convolutions and returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_1D(a, b):
        outputs = [
            torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        torch.mps.synchronize()
        return outputs

    return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    """Time MLX against PyTorch (MPS) for one 1-D convolution configuration.

    A random ``(N, iH, C)`` input and ``(O, wH, C / groups)`` weight are
    generated with NumPy and given to both frameworks; the PyTorch copies
    have their last axis moved to position 1. After timing, one MLX result is
    compared against a PyTorch CPU result and a message is printed on
    mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    scale = 1.0 / math.sqrt(wH * C)
    input_shape = (N, iH, C)
    weight_shape = (O, wH, int(C / groups))
    a_np = np.random.uniform(0, 0.5, input_shape).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    to_torch_axes = (0, 2, 1)
    a_pt = torch.from_numpy(a_np.transpose(to_torch_axes)).to("mps")
    b_pt = torch.from_numpy(b_np.transpose(to_torch_axes)).to("mps")
    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    conv_kwargs = dict(stride=strides, padding=padding, groups=groups)
    out_mx = mx.conv1d(a_mx, b_mx, **conv_kwargs)
    out_pt = torch.conv1d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_kwargs)
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 1)).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4
    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")
    dtypes = ("float32",)
    # Each row is (N, iH, C, wH, O, strides, padding, groups).
    # The groups=8 row used to be listed twice, so that case was benchmarked
    # and printed twice; it now appears once.
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )
    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv2d_bench_cpu.py
================================================
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
# Untimed calls made by bench() before the measurement starts.
N_warmup = 1
# Timed calls made by bench() per measurement.
N_iter_bench = 10
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
# Make the CPU the default MLX device for this script.
mx.set_default_device(mx.cpu)
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)

    started_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - started_ns
    return elapsed_ns * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv2d`` ``N_iter_func`` times.

    The returned function evaluates all outputs with ``mx.eval`` and returns
    them as a list.
    """

    def mx_conv_2D(a, b):
        outputs = [
            mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a no-grad callable that runs ``torch.conv2d`` ``N_iter_func`` times.

    The returned function returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_2D(a, b):
        return [
            torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]

    return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 2-D convolution configuration.

    A random ``(N, H, W, C)`` input and ``(O, kH, kW, C // groups)`` weight
    are generated with NumPy and given to both frameworks; the PyTorch copies
    have their last axis moved to position 1. After timing, one MLX result is
    compared against a PyTorch result and a message is printed on mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    # Scale by the full kernel area (kH * kW). The previous kH * kH was only
    # correct for square kernels.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for case in shapes:
            N, H, W, C, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(*case, np_dtype)
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv2d_train_bench_cpu.py
================================================
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
    """Train a small conv/conv-transpose network with MLX on the CPU.

    Runs ``steps`` Adam updates on one fixed random ``[10, 256, 256, 3]``
    batch with a mean-absolute-error reconstruction loss, printing the time
    of every step.

    Returns the summed step time in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            wide_channels = 2 * hidden_channels
            conv_opts = dict(kernel_size=3, padding=1)

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(hidden_channels, wide_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(wide_channels, hidden_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(hidden_channels, in_channels, **conv_opts),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    # Model and optimizer state evaluated at the end of every timed step.
    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        reconstruction = benchNet(image)
        return (reconstruction - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    print("MLX:")
    total_time = 0.0
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
def bench_torch(steps: int = 20) -> float:
    """Train a small conv/conv-transpose network with PyTorch on the CPU.

    Runs ``steps`` Adam updates on one fixed random ``(10, 3, 256, 256)``
    batch with a mean-absolute-error reconstruction loss, printing the time
    of every step.

    Returns the summed step time in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            wide_channels = 2 * hidden_channels
            conv_opts = dict(kernel_size=3, padding=1)

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.Conv2d(hidden_channels, wide_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(wide_channels, hidden_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(hidden_channels, in_channels, **conv_opts),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    print("PyTorch:")
    total_time = 0.0
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        loss = loss_fn(benchNet(inputs), inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
def main():
    """Run both training benchmarks for 20 steps and print a timing summary."""
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    summary = (
        f"average time of MLX: {time_mlx/steps:9.2f} ms",
        f"total time of MLX: {time_mlx:9.2f} ms",
        f"average time of PyTorch: {time_torch/steps:9.2f} ms",
        f"total time of PyTorch: {time_torch:9.2f} ms",
    )
    for line in summary:
        print(line)

    # Positive when MLX is faster than PyTorch.
    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
================================================
FILE: benchmarks/python/conv2d_transpose_bench_cpu.py
================================================
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
# Untimed calls made by bench() before the measurement starts.
N_warmup = 1
# Timed calls made by bench() per measurement.
N_iter_bench = 10
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)

    t_begin = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t_end = time.perf_counter_ns()

    nanoseconds = t_end - t_begin
    return nanoseconds * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv_transpose2d`` ``N_iter_func`` times.

    Every call is issued with ``stream=mx.cpu``. The returned function
    evaluates all outputs with ``mx.eval`` and returns them as a list.
    """

    def mx_conv_transpose_2D(a, b):
        outputs = [
            mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a no-grad callable that runs ``torch.conv_transpose2d`` ``N_iter_func`` times.

    The returned function returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        return [
            torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]

    return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 2-D transposed convolution.

    A random ``(N, H, W, C)`` input and ``(O // groups, kH, kW, C)`` weight
    are generated with NumPy and given to both frameworks. The PyTorch input
    has its last axis moved to position 1; the PyTorch weight has its last
    axis moved to the front. After timing, one MLX result is compared against
    a PyTorch result and a message is printed on mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    # Scale by the full kernel area (kH * kW). The previous kH * kH was only
    # correct for square kernels.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O // groups, kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for case in shapes:
            N, H, W, C, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(*case, np_dtype)
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv3d_bench.py
================================================
import math
import time
import mlx.core as mx
import numpy as np
import torch
# Untimed calls made by bench() before the measurement starts.
N_warmup = 2
# Timed calls made by bench() per measurement.
N_iter_bench = 10
# Loop iterations inside each benchmarked callable (two convolutions each).
N_iter_func = 10
def bench(f, a, b, b_prime):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b, b_prime)``.

    ``N_warmup`` untimed calls are made first, followed by
    ``torch.mps.synchronize()``, before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b, b_prime)
    torch.mps.synchronize()

    started_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b, b_prime)
    elapsed_ns = time.perf_counter_ns() - started_ns
    return elapsed_ns * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that chains ``mx.conv3d`` calls.

    Each of the ``N_iter_func`` iterations applies ``b`` and then ``b_prime``
    to the running result, starting from ``a``. The final array is evaluated
    with ``mx.eval`` and returned.
    """

    def mx_conv_3D(a, b, b_prime):
        conv_opts = dict(stride=strides, padding=padding, groups=groups)
        out = a
        for _ in range(N_iter_func):
            for weight in (b, b_prime):
                out = mx.conv3d(out, weight, **conv_opts)
        mx.eval(out)
        return out

    return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a no-grad callable that chains ``torch.conv3d`` calls.

    Each of the ``N_iter_func`` iterations applies ``b`` and then ``b_prime``
    to the running result, starting from ``a``. ``torch.mps.synchronize()``
    is called before the final tensor is returned.
    """

    @torch.no_grad()
    def pt_conv_3D(a, b, b_prime):
        conv_opts = dict(stride=strides, padding=padding, groups=groups)
        out = a
        for _ in range(N_iter_func):
            for weight in (b, b_prime):
                out = torch.conv3d(out, weight, **conv_opts)
        torch.mps.synchronize()
        return out

    return pt_conv_3D
# Benchmark one 3-D convolution configuration on MLX and on PyTorch (MPS).
#
# Returns a 6-tuple: (time_mlx, time_torch, mlx_peak_mb, mlx_active_mb,
# pt_current_mb, pt_driver_mb). The two times are in seconds as returned by
# bench(); the four memory figures are byte counts divided by 1024**2.
#
# The statement order in the memory sections matters: caches are cleared and
# the peak counter is reset immediately before the single measured convolution.
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C))
# b has O leading filters; b_prime has C leading filters, so the benchmarked
# callables can apply b and then b_prime repeatedly.
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups)))
b_prime_np = np.random.uniform(-scale, scale, (C, kD, kH, kW, int(O / groups)))
a_np, b_np, b_prime_np = map(lambda x: x.astype(np_dtype), (a_np, b_np, b_prime_np))
a_mx, b_mx, b_prime_mx = map(lambda x: mx.array(x), (a_np, b_np, b_prime_np))
# PyTorch copies have the last axis moved to position 1 and live on "mps".
a_pt, b_pt, b_prime_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("mps"),
(a_np, b_np, b_prime_np),
)
torch.mps.synchronize()
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
# PyTorch is timed first, then MLX.
time_torch = bench(f_pt, a_pt, b_pt, b_prime_pt)
time_mlx = bench(f_mx, a_mx, b_mx, b_prime_mx)
# Measure MLX memory
mx.clear_cache()
mx.reset_peak_memory()
y = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
mx.eval(y)
mlx_peak_mb = mx.get_peak_memory() / 1024**2
mlx_active_mb = mx.get_active_memory() / 1024**2
del y
# Measure PyTorch MPS memory
torch.mps.synchronize()
torch.mps.empty_cache()
y = torch.conv3d(a_pt, b_pt, stride=strides, padding=padding, groups=groups)
torch.mps.synchronize()
pt_current_mb = torch.mps.current_allocated_memory() / 1024**2
pt_driver_mb = torch.mps.driver_allocated_memory() / 1024**2
del y
# Correctness check: a single MLX convolution against PyTorch run on the CPU.
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
# Undo the axis permutation so both outputs are laid out the same way.
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
# A looser tolerance is used for every dtype other than float32.
atol = 2e-5 if np_dtype == np.float32 else 5e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} "
f"[strides = {strides}, padding = {padding}, groups = {groups}] "
f"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch, mlx_peak_mb, mlx_active_mb, pt_current_mb, pt_driver_mb
if __name__ == "__main__":
    dtypes = ("float16", "float32")
    # Each row is (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups).
    shapes = (
        # (C % 16 == 0)
        (4, 16, 16, 16, 32, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        # Larger spatial dims
        (2, 64, 64, 64, 32, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (1, 64, 64, 64, 64, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        # Strided
        (4, 32, 32, 32, 64, 3, 3, 3, 128, (2, 2, 2), (1, 1, 1), 1),
        # Asymmetric kernels
        (4, 32, 32, 32, 64, 3, 1, 1, 128, (1, 1, 1), (1, 0, 0), 1),
        (4, 32, 32, 32, 64, 1, 3, 3, 128, (1, 1, 1), (0, 1, 1), 1),
        # (C % 16 != 0)
        (4, 16, 16, 16, 21, 3, 3, 3, 21, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 3, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),
    )

    for dtype in dtypes:
        # Banner and column header for this dtype.
        print(f"\n{'=' * 120}" f"\n dtype: {dtype}" f"\n{'=' * 120}")
        print(
            f"{'(N, D, H, W, C)':<26s} {'( O, kD, kH, kW, C)':<24s} "
            f"{'stride':<12s} {'pads':<12s} {'groups':>6s} "
            f"{'diff%':>7s} "
            f"{'MLX peak':>9s} {'MLX act':>8s} {'PT cur':>8s} {'PT drv':>8s}"
        )
        for case in shapes:
            N, D, H, W, C, kD, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            results = bench_shape(*case, np_dtype)
            time_mlx, time_torch, mlx_peak, mlx_act, pt_cur, pt_drv = results
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), "
                f"{strides}, {padding}, {groups:6d}, "
                f"{100. * diff:+6.1f}% "
                f"{mlx_peak:8.1f} {mlx_act:7.1f} {pt_cur:7.1f} {pt_drv:7.1f}"
            )
================================================
FILE: benchmarks/python/conv3d_bench_cpu.py
================================================
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
# Untimed calls made by bench() before the measurement starts.
N_warmup = 1
# Timed calls made by bench() per measurement.
N_iter_bench = 10
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
# Make the CPU the default MLX device for this script.
mx.set_default_device(mx.cpu)
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)

    started_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    finished_ns = time.perf_counter_ns()
    return (finished_ns - started_ns) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``mx.conv3d`` ``N_iter_func`` times.

    The defaults are 3-tuples, one entry per spatial dimension. They were
    previously 2-tuples, which do not describe a 3-D convolution.

    The returned function evaluates all outputs with ``mx.eval`` and returns
    them as a list.
    """

    def mx_conv_3D(a, b):
        ys = [
            mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(ys)
        return ys

    return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a no-grad callable that runs ``torch.conv3d`` ``N_iter_func`` times.

    The defaults are 3-tuples, one entry per spatial dimension. They were
    previously 2-tuples, which ``torch.conv3d`` does not accept.

    The returned function returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_3D(a, b):
        return [
            torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]

    return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 3-D convolution configuration.

    A random ``(N, D, H, W, C)`` input and ``(O, kD, kH, kW, C / groups)``
    weight are generated with NumPy and given to both frameworks; the PyTorch
    copies have their last axis moved to position 1. After timing, one MLX
    result is compared against a PyTorch result and a message is printed on
    mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    input_shape = (N, D, H, W, C)
    weight_shape = (O, kD, kH, kW, int(C / groups))
    a_np = np.random.uniform(0, 0.5, input_shape).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    to_torch_axes = (0, 4, 1, 2, 3)
    a_pt = torch.from_numpy(a_np.transpose(to_torch_axes)).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose(to_torch_axes)).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    conv_opts = dict(stride=strides, padding=padding, groups=groups)
    out_mx = mx.conv3d(a_mx, b_mx, **conv_opts)
    out_pt = torch.conv3d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_opts)
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4
    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row is (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for case in shapes:
            N, D, H, W, C, kD, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(*case, np_dtype)
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv3d_train_bench_cpu.py
================================================
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    """Train a small 3-D conv/conv-transpose network with MLX on the CPU.

    Runs ``steps`` Adam updates on one fixed random batch of the given
    ``shape`` with a mean-absolute-error reconstruction loss, printing the
    time of every step.

    Returns the summed step time in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()
            wide_channels = 2 * hidden_channels
            conv_opts = dict(kernel_size=3, padding=1)

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(hidden_channels, wide_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(wide_channels, hidden_channels, **conv_opts),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(hidden_channels, in_channels, **conv_opts),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    # Model and optimizer state evaluated at the end of every timed step.
    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        reconstruction = benchNet(image)
        return (reconstruction - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    print("MLX:")
    total_time = 0.0
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    """Train a small 3-D conv/conv-transpose network with PyTorch on the CPU.

    Runs ``steps`` Adam updates on one fixed random batch of the given
    ``shape`` with a mean-absolute-error reconstruction loss, printing the
    time of every step.

    Returns the summed step time in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()
            wide_channels = 2 * hidden_channels
            conv_opts = dict(kernel_size=3, padding=1)

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.Conv3d(hidden_channels, wide_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(wide_channels, hidden_channels, **conv_opts),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(hidden_channels, in_channels, **conv_opts),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    print("PyTorch:")
    total_time = 0.0
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        loss = loss_fn(benchNet(inputs), inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
def main():
    """Run both training benchmarks for 10 steps and print a timing summary."""
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    summary = (
        f"average time of MLX: {time_mlx/steps:9.2f} ms",
        f"total time of MLX: {time_mlx:9.2f} ms",
        f"average time of PyTorch: {time_torch/steps:9.2f} ms",
        f"total time of PyTorch: {time_torch:9.2f} ms",
    )
    for line in summary:
        print(line)

    # Positive when MLX is faster than PyTorch.
    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
================================================
FILE: benchmarks/python/conv3d_transpose_bench_cpu.py
================================================
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
# Untimed calls made by bench() before the measurement starts.
N_warmup = 1
# Timed calls made by bench() per measurement.
N_iter_bench = 10
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
# Make the CPU the default MLX device for this script.
mx.set_default_device(mx.cpu)
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)

    t_begin = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t_end = time.perf_counter_ns()
    return (t_end - t_begin) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``mx.conv_transpose3d`` ``N_iter_func`` times.

    The returned function evaluates all outputs with ``mx.eval`` and returns
    them as a list.
    """

    def mx_conv_3D(a, b):
        outputs = [
            mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a no-grad callable that runs ``torch.conv_transpose3d`` ``N_iter_func`` times.

    The returned function returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_3D(a, b):
        return [
            torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]

    return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 3-D transposed convolution.

    A random ``(N, D, H, W, C)`` input and ``(O, kD, kH, kW, C / groups)``
    weight are generated with NumPy and given to both frameworks. The PyTorch
    input has its last axis moved to position 1; the PyTorch weight has its
    last axis moved to the front. After timing, one MLX result is compared
    against a PyTorch result and a message is printed on mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    input_shape = (N, D, H, W, C)
    weight_shape = (O, kD, kH, kW, int(C / groups))
    a_np = np.random.uniform(0, 0.5, input_shape).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    conv_opts = dict(stride=strides, padding=padding, groups=groups)
    out_mx = mx.conv_transpose3d(a_mx, b_mx, **conv_opts)
    out_pt = torch.conv_transpose3d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_opts)
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4
    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row is (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for case in shapes:
            N, D, H, W, C, kD, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(*case, np_dtype)
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv_bench.py
================================================
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
# Read the CPU brand string via `sysctl -n machdep.cpu.brand_string`.
# The value is not referenced anywhere else in this script.
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
# Untimed calls made by bench() before the measurement starts.
N_warmup = 10
# Timed calls made by bench() per measurement.
N_iter_bench = 100
# Convolutions issued inside each benchmarked callable.
N_iter_func = 5
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made first, followed by
    ``torch.mps.synchronize()``, before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    t_begin = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t_end = time.perf_counter_ns()
    return (t_end - t_begin) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv2d`` ``N_iter_func`` times.

    The returned function evaluates all outputs with ``mx.eval`` and returns
    them as a list.
    """

    def mx_conv_2D(a, b):
        outputs = [
            mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a no-grad callable that runs ``torch.conv2d`` ``N_iter_func`` times.

    The returned function calls ``torch.mps.synchronize()`` after issuing the
    convolutions and returns the outputs as a list.
    """

    @torch.no_grad()
    def pt_conv_2D(a, b):
        outputs = [
            torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        torch.mps.synchronize()
        return outputs

    return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (MPS) for one 2-D convolution configuration.

    A random ``(N, H, W, C)`` input and ``(O, kH, kW, C // groups)`` weight
    are generated with NumPy and given to both frameworks; the PyTorch copies
    have their last axis moved to position 1 and are placed on ``"mps"``.
    After timing, one MLX result is compared against a PyTorch CPU result and
    a message is printed on mismatch.

    Returns ``(time_mlx, time_torch)`` in seconds.
    """
    # Scale by the full kernel area (kH * kW). The previous kH * kH was only
    # correct for square kernels.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    # PyTorch is timed first, then MLX.
    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Undo the axis permutation so both outputs are laid out the same way.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for case in shapes:
            N, H, W, C, kH, kW, O, strides, padding, groups = case
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(*case, np_dtype)
            # Positive when MLX is faster than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag rows where MLX takes at least twice as long as PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
================================================
FILE: benchmarks/python/conv_transpose_bench.py
================================================
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
# Number of untimed calls bench() makes before it starts the clock.
N_warmup = 10
# Number of timed calls bench() makes per measurement.
N_iter_bench = 100
# Number of transposed convolutions issued inside each benchmarked call.
N_iter_func = 5
def bench(f, a, b):
    """Measure how long ``N_iter_bench`` calls of ``f(a, b)`` take.

    ``f`` is first called ``N_warmup`` times without timing, then the MPS
    device is synchronized before the clock starts.

    Args:
        f: Callable taking the two operands ``a`` and ``b``.
        a: First operand forwarded to ``f``.
        b: Second operand forwarded to ``f``.

    Returns:
        Elapsed time of the timed loop, in seconds.
    """
    # Untimed warm-up calls.
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    t_start = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - t_start

    # perf_counter_ns reports nanoseconds; convert to seconds.
    return elapsed_ns * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build the MLX workload used by the benchmark.

    Args:
        strides: Forwarded as ``stride`` to ``mx.conv_transpose2d``.
        padding: Forwarded as ``padding`` to ``mx.conv_transpose2d``.
        groups: Forwarded as ``groups`` to ``mx.conv_transpose2d``.

    Returns:
        A function ``(a, b) -> list`` that issues ``N_iter_func`` transposed
        convolutions, evaluates them together and returns the results.
    """

    def mx_conv_transpose_2D(a, b):
        outputs = [
            mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Evaluate all queued outputs in a single call before returning.
        mx.eval(outputs)
        return outputs

    return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build the PyTorch workload used by the benchmark.

    Args:
        strides: Forwarded as ``stride`` to ``torch.conv_transpose2d``.
        padding: Forwarded as ``padding`` to ``torch.conv_transpose2d``.
        groups: Forwarded as ``groups`` to ``torch.conv_transpose2d``.

    Returns:
        A function ``(a, b) -> list`` that runs ``N_iter_func`` transposed
        convolutions with gradient tracking disabled, synchronizes the MPS
        device and returns the results.
    """

    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        outputs = [
            torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Wait for the device so the caller's timing covers the real work.
        torch.mps.synchronize()
        return outputs

    return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (MPS) for one transposed-convolution shape.

    Random inputs are generated with NumPy, handed to both frameworks, timed
    with ``bench`` and then cross-checked once for numerical agreement. A
    mismatch is reported on stdout but does not raise.

    Args:
        N, H, W, C: Input batch size, height, width and channel count.
        kH, kW: Kernel height and width.
        O: Leading dimension of the weight array.
        strides: Stride passed to both frameworks.
        padding: Padding passed to both frameworks.
        groups: Group count passed to both frameworks.
        np_dtype: NumPy dtype used for the generated data.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of elapsed seconds from ``bench``.
    """
    # Scale weights by the full kernel window times channels. The previous
    # code multiplied kH by itself, ignoring kW for non-square kernels.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # PyTorch gets transposed copies of the same data: the input's last axis
    # moves to position 1 and the weight's last axis moves to the front.
    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check: the PyTorch reference is computed on the CPU.
    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Undo the axis move so the reference lines up with the MLX output.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128,
gitextract_c0zbkz84/
├── .clang-format
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ ├── actions/
│ │ ├── build-cuda-release/
│ │ │ └── action.yml
│ │ ├── build-docs/
│ │ │ └── action.yml
│ │ ├── build-linux/
│ │ │ └── action.yml
│ │ ├── build-linux-release/
│ │ │ └── action.yml
│ │ ├── build-macos/
│ │ │ └── action.yml
│ │ ├── build-macos-release/
│ │ │ └── action.yml
│ │ ├── build-windows/
│ │ │ └── action.yml
│ │ ├── setup-linux/
│ │ │ └── action.yml
│ │ ├── setup-macos/
│ │ │ └── action.yml
│ │ ├── setup-windows/
│ │ │ └── action.yml
│ │ ├── test-linux/
│ │ │ └── action.yml
│ │ └── test-windows/
│ │ └── action.yml
│ ├── dependabot.yml
│ ├── pull_request_template.md
│ ├── scripts/
│ │ ├── build-sanitizer-tests.sh
│ │ └── setup+build-cpp-linux-fedora-container.sh
│ └── workflows/
│ ├── build_and_test.yml
│ ├── documentation.yml
│ ├── nightly.yml
│ └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── ACKNOWLEDGMENTS.md
├── CITATION.cff
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── benchmarks/
│ ├── cpp/
│ │ ├── CMakeLists.txt
│ │ ├── autograd.cpp
│ │ ├── compare_devices.cpp
│ │ ├── irregular_strides.cpp
│ │ ├── single_ops.cpp
│ │ └── time_utils.h
│ ├── numpy/
│ │ ├── single_ops.py
│ │ └── time_utils.py
│ └── python/
│ ├── batch_matmul_bench.py
│ ├── blas/
│ │ ├── bench_gemm.py
│ │ └── bench_gemv.py
│ ├── comparative/
│ │ ├── README.md
│ │ ├── bench_mlx.py
│ │ ├── bench_torch.py
│ │ └── compare.py
│ ├── compile_bench.py
│ ├── conv1d_bench.py
│ ├── conv2d_bench_cpu.py
│ ├── conv2d_train_bench_cpu.py
│ ├── conv2d_transpose_bench_cpu.py
│ ├── conv3d_bench.py
│ ├── conv3d_bench_cpu.py
│ ├── conv3d_train_bench_cpu.py
│ ├── conv3d_transpose_bench_cpu.py
│ ├── conv_bench.py
│ ├── conv_transpose_bench.py
│ ├── conv_unaligned_bench.py
│ ├── distributed_bench.py
│ ├── einsum_bench.py
│ ├── fft_bench.py
│ ├── gather_bench.py
│ ├── gather_mm_bench.py
│ ├── gather_qmm_bench.py
│ ├── hadamard_bench.py
│ ├── large_gemm_bench.py
│ ├── layer_norm_bench.py
│ ├── masked_scatter.py
│ ├── rms_norm_bench.py
│ ├── rope_bench.py
│ ├── scatter_bench.py
│ ├── sdpa_bench.py
│ ├── sdpa_vector_bench.py
│ ├── segmented_mm_bench.py
│ ├── single_ops.py
│ ├── slice_update_bench.py
│ ├── synchronize_bench.py
│ └── time_utils.py
├── cmake/
│ ├── FindCUDNN.cmake
│ ├── FindNCCL.cmake
│ ├── Findnvpl.cmake
│ └── extension.cmake
├── docs/
│ ├── .clang-format
│ ├── .gitignore
│ ├── .nojekyll
│ ├── Doxyfile
│ ├── Makefile
│ ├── README.md
│ ├── index.html
│ ├── requirements.txt
│ └── src/
│ ├── _templates/
│ │ ├── module-base-class.rst
│ │ ├── nn-module-template.rst
│ │ └── optimizers-template.rst
│ ├── conf.py
│ ├── cpp/
│ │ └── ops.rst
│ ├── dev/
│ │ ├── custom_metal_kernels.rst
│ │ ├── extensions.rst
│ │ ├── metal_debugger.rst
│ │ ├── metal_logging.rst
│ │ └── mlx_in_cpp.rst
│ ├── examples/
│ │ ├── data_parallelism.rst
│ │ ├── linear_regression.rst
│ │ ├── llama-inference.rst
│ │ ├── mlp.rst
│ │ └── tensor_parallelism.rst
│ ├── index.rst
│ ├── install.rst
│ ├── python/
│ │ ├── array.rst
│ │ ├── cuda.rst
│ │ ├── data_types.rst
│ │ ├── devices_and_streams.rst
│ │ ├── distributed.rst
│ │ ├── export.rst
│ │ ├── fast.rst
│ │ ├── fft.rst
│ │ ├── linalg.rst
│ │ ├── memory_management.rst
│ │ ├── metal.rst
│ │ ├── nn/
│ │ │ ├── distributed.rst
│ │ │ ├── functions.rst
│ │ │ ├── init.rst
│ │ │ ├── layers.rst
│ │ │ ├── losses.rst
│ │ │ └── module.rst
│ │ ├── nn.rst
│ │ ├── ops.rst
│ │ ├── optimizers/
│ │ │ ├── common_optimizers.rst
│ │ │ ├── optimizer.rst
│ │ │ └── schedulers.rst
│ │ ├── optimizers.rst
│ │ ├── random.rst
│ │ ├── transforms.rst
│ │ └── tree_utils.rst
│ └── usage/
│ ├── compile.rst
│ ├── distributed.rst
│ ├── export.rst
│ ├── function_transforms.rst
│ ├── indexing.rst
│ ├── launching_distributed.rst
│ ├── lazy_evaluation.rst
│ ├── numpy.rst
│ ├── quick_start.rst
│ ├── saving_and_loading.rst
│ ├── unified_memory.rst
│ └── using_streams.rst
├── examples/
│ ├── cmake_project/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ └── example.cpp
│ ├── cpp/
│ │ ├── CMakeLists.txt
│ │ ├── distributed.cpp
│ │ ├── linear_regression.cpp
│ │ ├── logistic_regression.cpp
│ │ ├── metal_capture.cpp
│ │ ├── timer.h
│ │ └── tutorial.cpp
│ ├── export/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ ├── eval_mlp.cpp
│ │ ├── eval_mlp.py
│ │ ├── train_mlp.cpp
│ │ └── train_mlp.py
│ ├── extensions/
│ │ ├── CMakeLists.txt
│ │ ├── README.md
│ │ ├── axpby/
│ │ │ ├── axpby.cpp
│ │ │ ├── axpby.h
│ │ │ └── axpby.metal
│ │ ├── bindings.cpp
│ │ ├── mlx_sample_extensions/
│ │ │ └── __init__.py
│ │ ├── pyproject.toml
│ │ ├── requirements.txt
│ │ ├── setup.py
│ │ └── test.py
│ └── python/
│ ├── linear_regression.py
│ ├── logistic_regression.py
│ └── qqmm.py
├── mlx/
│ ├── 3rdparty/
│ │ ├── .clang-format
│ │ └── pocketfft.h
│ ├── CMakeLists.txt
│ ├── allocator.h
│ ├── api.h
│ ├── array.cpp
│ ├── array.h
│ ├── backend/
│ │ ├── common/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── binary.h
│ │ │ ├── broadcasting.cpp
│ │ │ ├── broadcasting.h
│ │ │ ├── buffer_cache.h
│ │ │ ├── common.cpp
│ │ │ ├── compiled.cpp
│ │ │ ├── compiled.h
│ │ │ ├── copy.h
│ │ │ ├── hadamard.h
│ │ │ ├── load.cpp
│ │ │ ├── matmul.h
│ │ │ ├── quantized.h
│ │ │ ├── reduce.cpp
│ │ │ ├── reduce.h
│ │ │ ├── slicing.cpp
│ │ │ ├── slicing.h
│ │ │ ├── ternary.h
│ │ │ ├── unary.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── cpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── arange.h
│ │ │ ├── arg_reduce.cpp
│ │ │ ├── binary.cpp
│ │ │ ├── binary.h
│ │ │ ├── binary_ops.h
│ │ │ ├── binary_two.h
│ │ │ ├── cholesky.cpp
│ │ │ ├── compiled.cpp
│ │ │ ├── compiled_preamble.h
│ │ │ ├── conv.cpp
│ │ │ ├── copy.cpp
│ │ │ ├── copy.h
│ │ │ ├── device_info.cpp
│ │ │ ├── device_info.h
│ │ │ ├── distributed.cpp
│ │ │ ├── eig.cpp
│ │ │ ├── eigh.cpp
│ │ │ ├── encoder.cpp
│ │ │ ├── encoder.h
│ │ │ ├── eval.cpp
│ │ │ ├── eval.h
│ │ │ ├── fft.cpp
│ │ │ ├── gemm.h
│ │ │ ├── gemms/
│ │ │ │ ├── bnns.cpp
│ │ │ │ ├── cblas.cpp
│ │ │ │ ├── simd_bf16.cpp
│ │ │ │ ├── simd_fp16.cpp
│ │ │ │ └── simd_gemm.h
│ │ │ ├── hadamard.cpp
│ │ │ ├── indexing.cpp
│ │ │ ├── inverse.cpp
│ │ │ ├── jit_compiler.cpp
│ │ │ ├── jit_compiler.h
│ │ │ ├── lapack.h
│ │ │ ├── logsumexp.cpp
│ │ │ ├── luf.cpp
│ │ │ ├── make_compiled_preamble.ps1
│ │ │ ├── make_compiled_preamble.sh
│ │ │ ├── masked_mm.cpp
│ │ │ ├── matmul.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── qrf.cpp
│ │ │ ├── quantized.cpp
│ │ │ ├── reduce.cpp
│ │ │ ├── scan.cpp
│ │ │ ├── select.cpp
│ │ │ ├── simd/
│ │ │ │ ├── accelerate_fp16_simd.h
│ │ │ │ ├── accelerate_simd.h
│ │ │ │ ├── base_simd.h
│ │ │ │ ├── math.h
│ │ │ │ ├── neon_fp16_simd.h
│ │ │ │ ├── simd.h
│ │ │ │ └── type.h
│ │ │ ├── slicing.h
│ │ │ ├── softmax.cpp
│ │ │ ├── sort.cpp
│ │ │ ├── svd.cpp
│ │ │ ├── ternary.h
│ │ │ ├── threefry.cpp
│ │ │ ├── threefry.h
│ │ │ ├── unary.cpp
│ │ │ ├── unary.h
│ │ │ └── unary_ops.h
│ │ ├── cuda/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── allocator.cpp
│ │ │ ├── allocator.h
│ │ │ ├── arange.cu
│ │ │ ├── arg_reduce.cu
│ │ │ ├── bin2h.cmake
│ │ │ ├── binary/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── add.cu
│ │ │ │ ├── arctan2.cu
│ │ │ │ ├── binary.cuh
│ │ │ │ ├── bitwise_binary.cu
│ │ │ │ ├── divide.cu
│ │ │ │ ├── equal.cu
│ │ │ │ ├── greater.cu
│ │ │ │ ├── greater_equal.cu
│ │ │ │ ├── less.cu
│ │ │ │ ├── less_equal.cu
│ │ │ │ ├── log_add_exp.cu
│ │ │ │ ├── logical_and.cu
│ │ │ │ ├── logical_or.cu
│ │ │ │ ├── maximum.cu
│ │ │ │ ├── minimum.cu
│ │ │ │ ├── multiply.cu
│ │ │ │ ├── not_equal.cu
│ │ │ │ ├── power.cu
│ │ │ │ ├── remainder.cu
│ │ │ │ └── subtract.cu
│ │ │ ├── binary_two.cu
│ │ │ ├── compiled.cpp
│ │ │ ├── conv/
│ │ │ │ ├── conv.h
│ │ │ │ ├── gemm_conv.cu
│ │ │ │ └── gemm_grouped_conv.cu
│ │ │ ├── conv.cpp
│ │ │ ├── copy/
│ │ │ │ ├── copy.cuh
│ │ │ │ ├── copy_contiguous.cu
│ │ │ │ ├── copy_general.cu
│ │ │ │ ├── copy_general_dynamic.cu
│ │ │ │ └── copy_general_input.cu
│ │ │ ├── copy.cu
│ │ │ ├── cublas_utils.cpp
│ │ │ ├── cublas_utils.h
│ │ │ ├── cuda.h
│ │ │ ├── cuda_utils.h
│ │ │ ├── cudnn_utils.cpp
│ │ │ ├── cudnn_utils.h
│ │ │ ├── custom_kernel.cpp
│ │ │ ├── cutlass_utils.cuh
│ │ │ ├── delayload.cpp
│ │ │ ├── device/
│ │ │ │ ├── atomic_ops.cuh
│ │ │ │ ├── binary_ops.cuh
│ │ │ │ ├── cast_op.cuh
│ │ │ │ ├── complex.cuh
│ │ │ │ ├── config.h
│ │ │ │ ├── fp16_math.cuh
│ │ │ │ ├── gather.cuh
│ │ │ │ ├── gather_axis.cuh
│ │ │ │ ├── hadamard.cuh
│ │ │ │ ├── indexing.cuh
│ │ │ │ ├── scatter.cuh
│ │ │ │ ├── scatter_axis.cuh
│ │ │ │ ├── scatter_ops.cuh
│ │ │ │ ├── slice_update.cuh
│ │ │ │ ├── ternary_ops.cuh
│ │ │ │ ├── unary_ops.cuh
│ │ │ │ └── utils.cuh
│ │ │ ├── device.cpp
│ │ │ ├── device.h
│ │ │ ├── device_info.cpp
│ │ │ ├── distributed.cu
│ │ │ ├── eval.cpp
│ │ │ ├── event.cu
│ │ │ ├── event.h
│ │ │ ├── fence.cpp
│ │ │ ├── fft.cu
│ │ │ ├── gemms/
│ │ │ │ ├── cublas_gemm.cpp
│ │ │ │ ├── cublas_gemm.h
│ │ │ │ ├── cublas_gemm_batched_12_0.cpp
│ │ │ │ ├── cublas_gemm_batched_12_9.cu
│ │ │ │ ├── gemv.cu
│ │ │ │ ├── gemv.h
│ │ │ │ ├── grouped_gemm.h
│ │ │ │ └── grouped_gemm_unaligned.cu
│ │ │ ├── hadamard.cu
│ │ │ ├── indexing.cpp
│ │ │ ├── jit_module.cpp
│ │ │ ├── jit_module.h
│ │ │ ├── kernel_utils.cu
│ │ │ ├── kernel_utils.cuh
│ │ │ ├── layer_norm.cu
│ │ │ ├── load.cpp
│ │ │ ├── logsumexp.cu
│ │ │ ├── lru_cache.h
│ │ │ ├── matmul.cpp
│ │ │ ├── no_cuda.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── quantized/
│ │ │ │ ├── affine_quantize.cu
│ │ │ │ ├── convert_fp8.cu
│ │ │ │ ├── cublas_qqmm.cpp
│ │ │ │ ├── cublas_qqmm.h
│ │ │ │ ├── fp_quantize.cu
│ │ │ │ ├── mxfp8_quantize.cuh
│ │ │ │ ├── no_qqmm_impl.cpp
│ │ │ │ ├── nvfp4_quantize.cuh
│ │ │ │ ├── qmm/
│ │ │ │ │ ├── CMakeLists.txt
│ │ │ │ │ ├── fp_qmv.cu
│ │ │ │ │ ├── qmm.cu
│ │ │ │ │ ├── qmm.h
│ │ │ │ │ ├── qmm_impl_sm80.cuh
│ │ │ │ │ ├── qmm_impl_sm80_m16.cu
│ │ │ │ │ ├── qmm_impl_sm80_m32.cu
│ │ │ │ │ ├── qmm_impl_sm80_m64.cu
│ │ │ │ │ ├── qmm_impl_sm90.cuh
│ │ │ │ │ ├── qmm_impl_sm90_m128_n128_m2.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n16_m1.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n256_m2.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n32_m1.cu
│ │ │ │ │ ├── qmm_impl_sm90_m128_n64_m2.cu
│ │ │ │ │ └── qmv.cu
│ │ │ │ ├── qqmm.cpp
│ │ │ │ ├── qqmm_impl.cpp
│ │ │ │ ├── qqmm_impl.h
│ │ │ │ ├── qqmm_utils.cu
│ │ │ │ ├── qqmm_utils.h
│ │ │ │ ├── quantized.cpp
│ │ │ │ ├── quantized.h
│ │ │ │ └── quantized_utils.h
│ │ │ ├── random.cu
│ │ │ ├── reduce/
│ │ │ │ ├── all_reduce.cu
│ │ │ │ ├── col_reduce.cu
│ │ │ │ ├── init_reduce.cu
│ │ │ │ ├── reduce.cuh
│ │ │ │ ├── reduce_ops.cuh
│ │ │ │ ├── reduce_utils.cuh
│ │ │ │ └── row_reduce.cu
│ │ │ ├── reduce.cu
│ │ │ ├── rms_norm.cu
│ │ │ ├── rope.cu
│ │ │ ├── scaled_dot_product_attention.cpp
│ │ │ ├── scaled_dot_product_attention.cu
│ │ │ ├── scan.cu
│ │ │ ├── slicing.cpp
│ │ │ ├── softmax.cu
│ │ │ ├── sort.cu
│ │ │ ├── steel/
│ │ │ │ ├── defines.cuh
│ │ │ │ ├── gemm.cuh
│ │ │ │ ├── mma.cuh
│ │ │ │ ├── tiles.cuh
│ │ │ │ └── utils.cuh
│ │ │ ├── ternary.cu
│ │ │ ├── unary/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── abs.cu
│ │ │ │ ├── arccos.cu
│ │ │ │ ├── arccosh.cu
│ │ │ │ ├── arcsin.cu
│ │ │ │ ├── arcsinh.cu
│ │ │ │ ├── arctan.cu
│ │ │ │ ├── arctanh.cu
│ │ │ │ ├── bitwise_invert.cu
│ │ │ │ ├── ceil.cu
│ │ │ │ ├── conjugate.cu
│ │ │ │ ├── cos.cu
│ │ │ │ ├── cosh.cu
│ │ │ │ ├── erf.cu
│ │ │ │ ├── erf_inv.cu
│ │ │ │ ├── exp.cu
│ │ │ │ ├── expm1.cu
│ │ │ │ ├── floor.cu
│ │ │ │ ├── imag.cu
│ │ │ │ ├── log.cu
│ │ │ │ ├── log1p.cu
│ │ │ │ ├── logical_not.cu
│ │ │ │ ├── negative.cu
│ │ │ │ ├── real.cu
│ │ │ │ ├── round.cu
│ │ │ │ ├── sigmoid.cu
│ │ │ │ ├── sign.cu
│ │ │ │ ├── sin.cu
│ │ │ │ ├── sinh.cu
│ │ │ │ ├── sqrt.cu
│ │ │ │ ├── square.cu
│ │ │ │ ├── tan.cu
│ │ │ │ ├── tanh.cu
│ │ │ │ └── unary.cuh
│ │ │ ├── utils.cpp
│ │ │ ├── utils.h
│ │ │ ├── vector_types.cuh
│ │ │ ├── worker.cpp
│ │ │ └── worker.h
│ │ ├── gpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── copy.cpp
│ │ │ ├── copy.h
│ │ │ ├── device_info.h
│ │ │ ├── eval.h
│ │ │ ├── primitives.cpp
│ │ │ ├── scan.h
│ │ │ ├── slicing.cpp
│ │ │ └── slicing.h
│ │ ├── metal/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── allocator.cpp
│ │ │ ├── allocator.h
│ │ │ ├── binary.cpp
│ │ │ ├── binary.h
│ │ │ ├── compiled.cpp
│ │ │ ├── conv.cpp
│ │ │ ├── copy.cpp
│ │ │ ├── custom_kernel.cpp
│ │ │ ├── device.cpp
│ │ │ ├── device.h
│ │ │ ├── device_info.cpp
│ │ │ ├── distributed.cpp
│ │ │ ├── eval.cpp
│ │ │ ├── event.cpp
│ │ │ ├── fence.cpp
│ │ │ ├── fft.cpp
│ │ │ ├── hadamard.cpp
│ │ │ ├── indexing.cpp
│ │ │ ├── jit/
│ │ │ │ ├── includes.h
│ │ │ │ └── indexing.h
│ │ │ ├── jit_kernels.cpp
│ │ │ ├── kernels/
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── arange.h
│ │ │ │ ├── arange.metal
│ │ │ │ ├── arg_reduce.metal
│ │ │ │ ├── atomic.h
│ │ │ │ ├── bf16.h
│ │ │ │ ├── bf16_math.h
│ │ │ │ ├── binary.h
│ │ │ │ ├── binary.metal
│ │ │ │ ├── binary_ops.h
│ │ │ │ ├── binary_two.h
│ │ │ │ ├── binary_two.metal
│ │ │ │ ├── cexpf.h
│ │ │ │ ├── complex.h
│ │ │ │ ├── conv.metal
│ │ │ │ ├── copy.h
│ │ │ │ ├── copy.metal
│ │ │ │ ├── defines.h
│ │ │ │ ├── erf.h
│ │ │ │ ├── expm1f.h
│ │ │ │ ├── fence.metal
│ │ │ │ ├── fft/
│ │ │ │ │ ├── radix.h
│ │ │ │ │ └── readwrite.h
│ │ │ │ ├── fft.h
│ │ │ │ ├── fft.metal
│ │ │ │ ├── fp4.h
│ │ │ │ ├── fp8.h
│ │ │ │ ├── fp_quantized.h
│ │ │ │ ├── fp_quantized.metal
│ │ │ │ ├── fp_quantized_nax.h
│ │ │ │ ├── fp_quantized_nax.metal
│ │ │ │ ├── gemv.metal
│ │ │ │ ├── gemv_masked.h
│ │ │ │ ├── gemv_masked.metal
│ │ │ │ ├── hadamard.h
│ │ │ │ ├── indexing/
│ │ │ │ │ ├── gather.h
│ │ │ │ │ ├── gather_axis.h
│ │ │ │ │ ├── gather_front.h
│ │ │ │ │ ├── indexing.h
│ │ │ │ │ ├── masked_scatter.h
│ │ │ │ │ ├── scatter.h
│ │ │ │ │ └── scatter_axis.h
│ │ │ │ ├── layer_norm.metal
│ │ │ │ ├── logging.h
│ │ │ │ ├── logsumexp.h
│ │ │ │ ├── logsumexp.metal
│ │ │ │ ├── quantized.h
│ │ │ │ ├── quantized.metal
│ │ │ │ ├── quantized_nax.h
│ │ │ │ ├── quantized_nax.metal
│ │ │ │ ├── quantized_utils.h
│ │ │ │ ├── random.metal
│ │ │ │ ├── reduce.h
│ │ │ │ ├── reduce.metal
│ │ │ │ ├── reduce_utils.h
│ │ │ │ ├── reduction/
│ │ │ │ │ ├── ops.h
│ │ │ │ │ ├── reduce_all.h
│ │ │ │ │ ├── reduce_col.h
│ │ │ │ │ ├── reduce_init.h
│ │ │ │ │ └── reduce_row.h
│ │ │ │ ├── rms_norm.metal
│ │ │ │ ├── rope.metal
│ │ │ │ ├── scaled_dot_product_attention.metal
│ │ │ │ ├── scan.h
│ │ │ │ ├── scan.metal
│ │ │ │ ├── sdpa_vector.h
│ │ │ │ ├── softmax.h
│ │ │ │ ├── softmax.metal
│ │ │ │ ├── sort.h
│ │ │ │ ├── sort.metal
│ │ │ │ ├── steel/
│ │ │ │ │ ├── attn/
│ │ │ │ │ │ ├── attn.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_attention.h
│ │ │ │ │ │ │ ├── steel_attention.metal
│ │ │ │ │ │ │ ├── steel_attention_nax.h
│ │ │ │ │ │ │ └── steel_attention_nax.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── mma.h
│ │ │ │ │ │ ├── nax.h
│ │ │ │ │ │ ├── params.h
│ │ │ │ │ │ └── transforms.h
│ │ │ │ │ ├── conv/
│ │ │ │ │ │ ├── conv.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_conv.h
│ │ │ │ │ │ │ ├── steel_conv.metal
│ │ │ │ │ │ │ ├── steel_conv_3d.h
│ │ │ │ │ │ │ ├── steel_conv_3d.metal
│ │ │ │ │ │ │ ├── steel_conv_general.h
│ │ │ │ │ │ │ └── steel_conv_general.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── loaders/
│ │ │ │ │ │ │ ├── loader_channel_l.h
│ │ │ │ │ │ │ ├── loader_channel_n.h
│ │ │ │ │ │ │ └── loader_general.h
│ │ │ │ │ │ └── params.h
│ │ │ │ │ ├── defines.h
│ │ │ │ │ ├── gemm/
│ │ │ │ │ │ ├── gemm.h
│ │ │ │ │ │ ├── gemm_nax.h
│ │ │ │ │ │ ├── kernels/
│ │ │ │ │ │ │ ├── steel_gemm_fused.h
│ │ │ │ │ │ │ ├── steel_gemm_fused.metal
│ │ │ │ │ │ │ ├── steel_gemm_fused_nax.h
│ │ │ │ │ │ │ ├── steel_gemm_fused_nax.metal
│ │ │ │ │ │ │ ├── steel_gemm_gather.h
│ │ │ │ │ │ │ ├── steel_gemm_gather.metal
│ │ │ │ │ │ │ ├── steel_gemm_gather_nax.h
│ │ │ │ │ │ │ ├── steel_gemm_gather_nax.metal
│ │ │ │ │ │ │ ├── steel_gemm_masked.h
│ │ │ │ │ │ │ ├── steel_gemm_masked.metal
│ │ │ │ │ │ │ ├── steel_gemm_segmented.h
│ │ │ │ │ │ │ ├── steel_gemm_segmented.metal
│ │ │ │ │ │ │ ├── steel_gemm_splitk.h
│ │ │ │ │ │ │ ├── steel_gemm_splitk.metal
│ │ │ │ │ │ │ ├── steel_gemm_splitk_nax.h
│ │ │ │ │ │ │ └── steel_gemm_splitk_nax.metal
│ │ │ │ │ │ ├── loader.h
│ │ │ │ │ │ ├── mma.h
│ │ │ │ │ │ ├── nax.h
│ │ │ │ │ │ ├── params.h
│ │ │ │ │ │ └── transforms.h
│ │ │ │ │ ├── utils/
│ │ │ │ │ │ ├── integral_constant.h
│ │ │ │ │ │ └── type_traits.h
│ │ │ │ │ └── utils.h
│ │ │ │ ├── ternary.h
│ │ │ │ ├── ternary.metal
│ │ │ │ ├── ternary_ops.h
│ │ │ │ ├── unary.h
│ │ │ │ ├── unary.metal
│ │ │ │ ├── unary_ops.h
│ │ │ │ └── utils.h
│ │ │ ├── kernels.h
│ │ │ ├── logsumexp.cpp
│ │ │ ├── make_compiled_preamble.sh
│ │ │ ├── matmul.cpp
│ │ │ ├── matmul.h
│ │ │ ├── metal.cpp
│ │ │ ├── metal.h
│ │ │ ├── no_metal.cpp
│ │ │ ├── nojit_kernels.cpp
│ │ │ ├── normalization.cpp
│ │ │ ├── primitives.cpp
│ │ │ ├── quantized.cpp
│ │ │ ├── reduce.cpp
│ │ │ ├── reduce.h
│ │ │ ├── resident.cpp
│ │ │ ├── resident.h
│ │ │ ├── rope.cpp
│ │ │ ├── scaled_dot_product_attention.cpp
│ │ │ ├── scan.cpp
│ │ │ ├── slicing.cpp
│ │ │ ├── softmax.cpp
│ │ │ ├── sort.cpp
│ │ │ ├── ternary.cpp
│ │ │ ├── ternary.h
│ │ │ ├── unary.cpp
│ │ │ ├── unary.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── no_cpu/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── compiled.cpp
│ │ │ ├── device_info.cpp
│ │ │ └── primitives.cpp
│ │ └── no_gpu/
│ │ ├── CMakeLists.txt
│ │ ├── allocator.cpp
│ │ ├── apple_memory.h
│ │ ├── device_info.cpp
│ │ ├── eval.cpp
│ │ ├── event.cpp
│ │ ├── fence.cpp
│ │ ├── linux_memory.h
│ │ └── primitives.cpp
│ ├── compile.cpp
│ ├── compile.h
│ ├── compile_impl.h
│ ├── device.cpp
│ ├── device.h
│ ├── distributed/
│ │ ├── CMakeLists.txt
│ │ ├── distributed.cpp
│ │ ├── distributed.h
│ │ ├── distributed_impl.h
│ │ ├── jaccl/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── jaccl.cpp
│ │ │ ├── jaccl.h
│ │ │ ├── mesh.cpp
│ │ │ ├── mesh.h
│ │ │ ├── mesh_impl.h
│ │ │ ├── no_jaccl.cpp
│ │ │ ├── ring.cpp
│ │ │ ├── ring.h
│ │ │ ├── ring_impl.h
│ │ │ ├── utils.cpp
│ │ │ └── utils.h
│ │ ├── mpi/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── mpi.cpp
│ │ │ ├── mpi.h
│ │ │ ├── mpi_declarations.h
│ │ │ └── no_mpi.cpp
│ │ ├── nccl/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── nccl.cpp
│ │ │ ├── nccl.h
│ │ │ └── no_nccl.cpp
│ │ ├── ops.cpp
│ │ ├── ops.h
│ │ ├── primitives.cpp
│ │ ├── primitives.h
│ │ ├── reduction_ops.h
│ │ ├── ring/
│ │ │ ├── CMakeLists.txt
│ │ │ ├── no_ring.cpp
│ │ │ ├── ring.cpp
│ │ │ └── ring.h
│ │ ├── utils.cpp
│ │ └── utils.h
│ ├── dtype.cpp
│ ├── dtype.h
│ ├── dtype_utils.cpp
│ ├── dtype_utils.h
│ ├── einsum.cpp
│ ├── einsum.h
│ ├── event.h
│ ├── export.cpp
│ ├── export.h
│ ├── export_impl.h
│ ├── fast.cpp
│ ├── fast.h
│ ├── fast_primitives.h
│ ├── fence.h
│ ├── fft.cpp
│ ├── fft.h
│ ├── graph_utils.cpp
│ ├── graph_utils.h
│ ├── io/
│ │ ├── CMakeLists.txt
│ │ ├── gguf.cpp
│ │ ├── gguf.h
│ │ ├── gguf_quants.cpp
│ │ ├── load.cpp
│ │ ├── load.h
│ │ ├── no_gguf.cpp
│ │ ├── no_safetensors.cpp
│ │ └── safetensors.cpp
│ ├── io.h
│ ├── linalg.cpp
│ ├── linalg.h
│ ├── memory.h
│ ├── mlx.h
│ ├── ops.cpp
│ ├── ops.h
│ ├── primitives.cpp
│ ├── primitives.h
│ ├── random.cpp
│ ├── random.h
│ ├── scheduler.cpp
│ ├── scheduler.h
│ ├── small_vector.h
│ ├── stream.h
│ ├── threadpool.h
│ ├── transforms.cpp
│ ├── transforms.h
│ ├── transforms_impl.h
│ ├── types/
│ │ ├── bf16.h
│ │ ├── complex.h
│ │ ├── fp16.h
│ │ ├── half_types.h
│ │ └── limits.h
│ ├── utils.cpp
│ ├── utils.h
│ ├── version.cpp
│ └── version.h
├── mlx.pc.in
├── pyproject.toml
├── python/
│ ├── mlx/
│ │ ├── __main__.py
│ │ ├── _distributed_utils/
│ │ │ ├── common.py
│ │ │ ├── config.py
│ │ │ └── launch.py
│ │ ├── _reprlib_fix.py
│ │ ├── _stub_patterns.txt
│ │ ├── extension.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── init.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activations.py
│ │ │ │ ├── base.py
│ │ │ │ ├── containers.py
│ │ │ │ ├── convolution.py
│ │ │ │ ├── convolution_transpose.py
│ │ │ │ ├── distributed.py
│ │ │ │ ├── dropout.py
│ │ │ │ ├── embedding.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── normalization.py
│ │ │ │ ├── pooling.py
│ │ │ │ ├── positional_encoding.py
│ │ │ │ ├── quantized.py
│ │ │ │ ├── recurrent.py
│ │ │ │ ├── transformer.py
│ │ │ │ └── upsample.py
│ │ │ ├── losses.py
│ │ │ └── utils.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── optimizers.py
│ │ │ └── schedulers.py
│ │ ├── py.typed
│ │ └── utils.py
│ ├── src/
│ │ ├── CMakeLists.txt
│ │ ├── array.cpp
│ │ ├── buffer.h
│ │ ├── constants.cpp
│ │ ├── convert.cpp
│ │ ├── convert.h
│ │ ├── cuda.cpp
│ │ ├── device.cpp
│ │ ├── distributed.cpp
│ │ ├── export.cpp
│ │ ├── fast.cpp
│ │ ├── fft.cpp
│ │ ├── indexing.cpp
│ │ ├── indexing.h
│ │ ├── linalg.cpp
│ │ ├── load.cpp
│ │ ├── load.h
│ │ ├── memory.cpp
│ │ ├── metal.cpp
│ │ ├── mlx.cpp
│ │ ├── mlx_func.cpp
│ │ ├── mlx_func.h
│ │ ├── ops.cpp
│ │ ├── random.cpp
│ │ ├── small_vector.h
│ │ ├── stream.cpp
│ │ ├── transforms.cpp
│ │ ├── trees.cpp
│ │ ├── trees.h
│ │ ├── utils.cpp
│ │ └── utils.h
│ └── tests/
│ ├── __main__.py
│ ├── cuda_skip.py
│ ├── mlx_distributed_tests.py
│ ├── mlx_tests.py
│ ├── mpi_test_distributed.py
│ ├── nccl_test_distributed.py
│ ├── ring_test_distributed.py
│ ├── test_array.py
│ ├── test_autograd.py
│ ├── test_bf16.py
│ ├── test_blas.py
│ ├── test_compile.py
│ ├── test_constants.py
│ ├── test_conv.py
│ ├── test_conv_transpose.py
│ ├── test_device.py
│ ├── test_double.py
│ ├── test_einsum.py
│ ├── test_eval.py
│ ├── test_export_import.py
│ ├── test_fast.py
│ ├── test_fast_sdpa.py
│ ├── test_fft.py
│ ├── test_graph.py
│ ├── test_init.py
│ ├── test_linalg.py
│ ├── test_load.py
│ ├── test_losses.py
│ ├── test_memory.py
│ ├── test_nn.py
│ ├── test_ops.py
│ ├── test_optimizers.py
│ ├── test_quantized.py
│ ├── test_random.py
│ ├── test_reduce.py
│ ├── test_tree.py
│ ├── test_upsample.py
│ └── test_vmap.py
├── setup.py
└── tests/
├── CMakeLists.txt
├── allocator_tests.cpp
├── arg_reduce_tests.cpp
├── array_tests.cpp
├── autograd_tests.cpp
├── blas_tests.cpp
├── compile_tests.cpp
├── creations_tests.cpp
├── custom_vjp_tests.cpp
├── device_tests.cpp
├── einsum_tests.cpp
├── eval_tests.cpp
├── export_import_tests.cpp
├── fft_tests.cpp
├── gpu_tests.cpp
├── linalg_tests.cpp
├── load_tests.cpp
├── ops_tests.cpp
├── random_tests.cpp
├── scheduler_tests.cpp
├── test_teardown.cpp
├── tests.cpp
├── utils_tests.cpp
└── vmap_tests.cpp
Showing preview only (317K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (4174 symbols across 517 files)
FILE: benchmarks/cpp/autograd.cpp
function time_value_and_grad (line 10) | void time_value_and_grad() {
function main (line 36) | int main() {
FILE: benchmarks/cpp/compare_devices.cpp
function time_add_op (line 9) | void time_add_op() {
function main (line 25) | int main() {
FILE: benchmarks/cpp/irregular_strides.cpp
function time_irregular_binary_ops_1D (line 12) | void time_irregular_binary_ops_1D() {
function time_irregular_binary_ops_2D (line 24) | void time_irregular_binary_ops_2D() {
function time_irregular_binary_ops_3D (line 45) | void time_irregular_binary_ops_3D() {
function time_irregular_binary_ops_4D (line 76) | void time_irregular_binary_ops_4D() {
function time_irregular_reshape (line 116) | void time_irregular_reshape() {
function time_irregular_astype_1D (line 161) | void time_irregular_astype_1D() {
function time_irregular_astype_2D (line 170) | void time_irregular_astype_2D() {
function main (line 188) | int main(int argc, char** argv) {
FILE: benchmarks/cpp/single_ops.cpp
function time_creation_ops (line 8) | void time_creation_ops() {
function time_type_conversions (line 23) | void time_type_conversions() {
function time_random_generation (line 45) | void time_random_generation() {
function time_unary_ops (line 55) | void time_unary_ops() {
function time_binary_ops (line 74) | void time_binary_ops() {
function time_strided_ops (line 112) | void time_strided_ops() {
function time_comparisons (line 125) | void time_comparisons() {
function time_matvec (line 138) | void time_matvec() {
function time_matmul (line 151) | void time_matmul() {
function time_reductions (line 163) | void time_reductions() {
function time_gather_scatter (line 213) | void time_gather_scatter() {
function time_divmod (line 260) | void time_divmod() {
function main (line 274) | int main() {
FILE: benchmarks/numpy/single_ops.py
function time_add (line 7) | def time_add():
function time_matmul (line 13) | def time_matmul():
function time_exp (line 19) | def time_exp():
function time_take (line 24) | def time_take():
FILE: benchmarks/numpy/time_utils.py
function time_fn (line 6) | def time_fn(fn, *args):
FILE: benchmarks/python/batch_matmul_bench.py
function time_batch_matmul (line 13) | def time_batch_matmul():
function time_unbatch_matmul (line 33) | def time_unbatch_matmul():
FILE: benchmarks/python/blas/bench_gemm.py
function bench (line 21) | def bench(f, a, b):
function gemm_nn_mlx (line 33) | def gemm_nn_mlx(a, b):
function gemm_nt_mlx (line 42) | def gemm_nt_mlx(a, b):
function gemm_tn_mlx (line 51) | def gemm_tn_mlx(a, b):
function gemm_tt_mlx (line 60) | def gemm_tt_mlx(a, b):
function gemm_nn_torch (line 70) | def gemm_nn_torch(a, b):
function gemm_nt_torch (line 80) | def gemm_nt_torch(a, b):
function gemm_tn_torch (line 90) | def gemm_tn_torch(a, b):
function gemm_tt_torch (line 100) | def gemm_tt_torch(a, b):
function bench_shape (line 109) | def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
function get_gflop_count (line 157) | def get_gflop_count(B, M, N, K):
FILE: benchmarks/python/blas/bench_gemv.py
function bench (line 36) | def bench(f, m, v):
function gemv_mlx (line 48) | def gemv_mlx(m, v):
function gemv_t_mlx (line 57) | def gemv_t_mlx(m, v):
function gemv_torch (line 67) | def gemv_torch(m, v):
function gemv_t_torch (line 77) | def gemv_t_torch(m, v):
function bench_lens (line 86) | def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
function get_gflop_count (line 123) | def get_gflop_count(in_vec_len, out_vec_len):
function get_gbyte_size (line 129) | def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
function bench_with_in_len (line 135) | def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
function bench_with_out_len (line 166) | def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
FILE: benchmarks/python/comparative/bench_mlx.py
function int_or_list (line 13) | def int_or_list(x):
function none_or_list (line 20) | def none_or_list(x):
function dtype_from_str (line 27) | def dtype_from_str(x):
function bench (line 37) | def bench(f, *args):
function matmul_square (line 48) | def matmul_square(x):
function matmul (line 56) | def matmul(x, y):
function _quant_matmul (line 63) | def _quant_matmul(x, w, s, b, transpose, group_size, bits):
function conv1d (line 120) | def conv1d(x, y):
function conv2d (line 127) | def conv2d(x, y):
function binary (line 134) | def binary(op, x, y):
function reduction (line 140) | def reduction(op, axis, x):
function sum_and_add (line 147) | def sum_and_add(axis, x, y):
function softmax (line 154) | def softmax(axis, x):
function softmax_fused (line 163) | def softmax_fused(axis, x):
function relu (line 171) | def relu(x):
function leaky_relu (line 178) | def leaky_relu(x: mx.array):
function prelu (line 185) | def prelu(x: mx.array):
function softplus (line 192) | def softplus(x: mx.array):
function mish (line 199) | def mish(x: mx.array):
function leaky_relu (line 206) | def leaky_relu(x):
function elu (line 213) | def elu(x):
function relu6 (line 220) | def relu6(x):
function softplus (line 227) | def softplus(x):
function celu (line 234) | def celu(x):
function log_sigmoid (line 241) | def log_sigmoid(x):
function scalar_mult (line 248) | def scalar_mult(x):
function cross_entropy (line 255) | def cross_entropy(targets, x):
function logsumexp (line 265) | def logsumexp(axis, x):
function linear (line 272) | def linear(w, b, x):
function linear_fused (line 279) | def linear_fused(w, b, x):
function rope (line 286) | def rope(x):
function concatenate (line 307) | def concatenate(axis, x, y):
function cumsum (line 314) | def cumsum(axis, x):
function sort (line 321) | def sort(axis, x):
function topk (line 328) | def topk(axis, x):
function step_function (line 336) | def step_function(x):
function selu (line 343) | def selu(x):
FILE: benchmarks/python/comparative/bench_torch.py
function int_or_list (line 12) | def int_or_list(x):
function none_or_list (line 19) | def none_or_list(x):
function dtype_from_str (line 26) | def dtype_from_str(x):
function bench (line 36) | def bench(f, *args):
function sync_if_needed (line 47) | def sync_if_needed(x):
function matmul_square (line 55) | def matmul_square(x):
function matmul (line 63) | def matmul(x, y):
function conv1d (line 71) | def conv1d(x, y):
function conv2d (line 81) | def conv2d(x, y):
function binary (line 91) | def binary(op, x, y):
function reduction (line 98) | def reduction(op, axis, x):
function sum_and_add (line 106) | def sum_and_add(axis, x, y):
function softmax (line 114) | def softmax(axis, x):
function softmax_fused (line 124) | def softmax_fused(axis, x):
function relu (line 132) | def relu(x):
function leaky_relu (line 140) | def leaky_relu(x):
function elu (line 148) | def elu(x):
function celu (line 156) | def celu(x):
function relu6 (line 164) | def relu6(x):
function softplus (line 172) | def softplus(x):
function log_sigmoid (line 180) | def log_sigmoid(x):
function prelu (line 188) | def prelu(x: torch.Tensor) -> torch.Tensor:
function mish (line 196) | def mish(x: torch.Tensor) -> torch.Tensor:
function scalar_mult (line 204) | def scalar_mult(x):
function cross_entropy (line 212) | def cross_entropy(targets, x):
function logsumexp (line 220) | def logsumexp(axis, x):
function linear_fused (line 228) | def linear_fused(w, b, x):
function linear (line 236) | def linear(w, b, x):
function rope (line 244) | def rope(x):
function concatenate (line 265) | def concatenate(axis, x, y):
function cumsum (line 273) | def cumsum(axis, x):
function sort (line 281) | def sort(axis, x):
function topk (line 289) | def topk(axis, x):
function step_function (line 298) | def step_function(x):
function selu (line 306) | def selu(x):
FILE: benchmarks/python/comparative/compare.py
function run_or_raise (line 14) | def run_or_raise(*args, **kwargs):
function compare (line 24) | def compare(args):
function compare_mlx_dtypes (line 31) | def compare_mlx_dtypes(args, dt1, dt2):
function make_regex_search (line 38) | def make_regex_search(regexes):
function make_predicate (line 47) | def make_predicate(positive_filter, negative_filter):
FILE: benchmarks/python/compile_bench.py
function bench_gelu (line 11) | def bench_gelu():
function bench_layernorm (line 52) | def bench_layernorm():
FILE: benchmarks/python/conv1d_bench.py
function bench (line 19) | def bench(f, a, b):
function make_mx_conv_1D (line 31) | def make_mx_conv_1D(strides=1, padding=0, groups=1):
function make_pt_conv_1D (line 43) | def make_pt_conv_1D(strides=1, padding=0, groups=1):
function bench_shape (line 56) | def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
FILE: benchmarks/python/conv2d_bench_cpu.py
function bench (line 15) | def bench(f, a, b):
function make_mx_conv_2D (line 26) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_2D (line 38) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 50) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
FILE: benchmarks/python/conv2d_train_bench_cpu.py
function bench_mlx (line 9) | def bench_mlx(steps: int = 20) -> float:
function bench_torch (line 73) | def bench_torch(steps: int = 20) -> float:
function main (line 128) | def main():
FILE: benchmarks/python/conv2d_transpose_bench_cpu.py
function bench (line 14) | def bench(f, a, b):
function make_mx_conv_transpose_2D (line 25) | def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_transpose_2D (line 39) | def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 53) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
FILE: benchmarks/python/conv3d_bench.py
function bench (line 13) | def bench(f, a, b, b_prime):
function make_mx_conv_3D (line 25) | def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
function make_pt_conv_3D (line 37) | def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
function bench_shape (line 50) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ...
FILE: benchmarks/python/conv3d_bench_cpu.py
function bench (line 15) | def bench(f, a, b):
function make_mx_conv_3D (line 26) | def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_3D (line 38) | def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 50) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ...
FILE: benchmarks/python/conv3d_train_bench_cpu.py
function bench_mlx (line 9) | def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
function bench_torch (line 73) | def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
function main (line 128) | def main():
FILE: benchmarks/python/conv3d_transpose_bench_cpu.py
function bench (line 15) | def bench(f, a, b):
function make_mx_conv_3D (line 26) | def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
function make_pt_conv_3D (line 40) | def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
function bench_shape (line 54) | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, ...
FILE: benchmarks/python/conv_bench.py
function bench (line 19) | def bench(f, a, b):
function make_mx_conv_2D (line 31) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_2D (line 43) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 56) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
FILE: benchmarks/python/conv_transpose_bench.py
function bench (line 16) | def bench(f, a, b):
function make_mx_conv_transpose_2D (line 28) | def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_transpose_2D (line 42) | def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 57) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
FILE: benchmarks/python/conv_unaligned_bench.py
function bench (line 13) | def bench(f, a, b):
function make_mx_conv_2D (line 25) | def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function make_pt_conv_2D (line 37) | def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
function bench_shape (line 50) | def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
FILE: benchmarks/python/distributed_bench.py
function time_fn (line 13) | def time_fn(fn, *args, **kwargs):
function time_all_sum (line 37) | def time_all_sum():
FILE: benchmarks/python/einsum_bench.py
function timeit (line 9) | def timeit(fn, its=100, args=[]):
function time_little_einsum_path (line 19) | def time_little_einsum_path():
function time_big_einsum_path (line 33) | def time_big_einsum_path():
function time_attention (line 55) | def time_attention():
FILE: benchmarks/python/fft_bench.py
function bandwidth_gb (line 14) | def bandwidth_gb(runtime_ms, system_size):
function run_bench (line 21) | def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
function time_fft (line 62) | def time_fft():
FILE: benchmarks/python/gather_bench.py
function benchmark_gather_mlx (line 10) | def benchmark_gather_mlx(x_shape, idx_shape):
function benchmark_gather_torch (line 21) | def benchmark_gather_torch(x_shape, idx_shape, device):
FILE: benchmarks/python/gather_mm_bench.py
function gather_sort (line 13) | def gather_sort(x, indices):
function scatter_unsort (line 21) | def scatter_unsort(x, inv_order, shape=None):
function gather_mm_simulate (line 28) | def gather_mm_simulate(x, w, indices):
function time_gather_mm (line 37) | def time_gather_mm():
FILE: benchmarks/python/gather_qmm_bench.py
function gather_sort (line 13) | def gather_sort(x, indices):
function scatter_unsort (line 21) | def scatter_unsort(x, inv_order, shape=None):
function gather_mm_simulate (line 28) | def gather_mm_simulate(x, w, indices):
function time_gather_qmm (line 43) | def time_gather_qmm():
FILE: benchmarks/python/hadamard_bench.py
function had (line 12) | def had(x):
function copy (line 17) | def copy(x):
function run (line 22) | def run(dtype):
FILE: benchmarks/python/large_gemm_bench.py
function bench_mlx (line 14) | def bench_mlx(a, b):
function bench_torch (line 29) | def bench_torch(a, b):
function check_correctness (line 45) | def check_correctness(out_mx, out_pt, rtol, M, N, K):
function bench_gemm (line 56) | def bench_gemm(M, N, K, dtype, rtol):
FILE: benchmarks/python/layer_norm_bench.py
function layer_norm (line 10) | def layer_norm(x, w, b, eps):
function time_layer_norm (line 23) | def time_layer_norm(N, dt):
FILE: benchmarks/python/masked_scatter.py
function get_device_name (line 28) | def get_device_name():
function _power_of_two_formatter (line 62) | def _power_of_two_formatter(value, _position):
function torch_sync (line 71) | def torch_sync():
function masked_scatter_mlx (line 78) | def masked_scatter_mlx(self_arr, mask_arr, src_arr):
function masked_scatter_torch (line 89) | def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
function measure (line 99) | def measure(fn):
function bytes_touched (line 109) | def bytes_touched(length, true_count, item_size):
function build_case (line 116) | def build_case(length, density, np_dtype, torch_dtype):
function bench_case (line 148) | def bench_case(length, density, dtype):
function plot_density (line 174) | def plot_density(ax_perf, ax_speedup, density, dtype):
function main (line 208) | def main():
FILE: benchmarks/python/rms_norm_bench.py
function rms_norm (line 8) | def rms_norm(x, w, eps):
function time_rms_norm (line 18) | def time_rms_norm():
FILE: benchmarks/python/rope_bench.py
function time_rope (line 8) | def time_rope():
FILE: benchmarks/python/scatter_bench.py
function benchmark_scatter_mlx (line 10) | def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
function benchmark_scatter_torch (line 25) | def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
FILE: benchmarks/python/sdpa_bench.py
function bench (line 20) | def bench(f, *args):
function prepare_inputs (line 31) | def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
function mlx_ref_attn (line 58) | def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
function mlx_fused_attn (line 104) | def mlx_fused_attn(q, k, v, scale, mask):
function do_attention (line 108) | def do_attention(f, q, k, v, scale, mask=None, transpose=False):
function do_attention_bench (line 119) | def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
function bench_shape (line 129) | def bench_shape(
function get_gflop_count (line 158) | def get_gflop_count(B, M, N, K):
FILE: benchmarks/python/sdpa_vector_bench.py
function upproject (line 16) | def upproject(x, w):
function attention (line 23) | def attention(q, k, v, mask=None, w=None):
function sdpa (line 45) | def sdpa(q, k, v, mask=None, w=None):
function time_self_attention_primitives (line 52) | def time_self_attention_primitives():
function time_self_attention_sdpa (line 62) | def time_self_attention_sdpa():
function time_self_attention_sdpa_with_mask (line 72) | def time_self_attention_sdpa_with_mask():
FILE: benchmarks/python/segmented_mm_bench.py
function parse_cases (line 16) | def parse_cases(cases):
function make_segments (line 24) | def make_segments(k, num_segments, pattern, seed):
function numpy_segmented_mm_ref (line 35) | def numpy_segmented_mm_ref(a, b, segments):
function mlx_segmented_mm_loop (line 43) | def mlx_segmented_mm_loop(a, b, segments):
function bench_mlx (line 52) | def bench_mlx(a, b, segments, warmup, iters):
function bench_mlx_loop (line 67) | def bench_mlx_loop(a, b, segments, warmup, iters):
function print_table (line 82) | def print_table(headers, rows):
function main (line 102) | def main():
FILE: benchmarks/python/single_ops.py
function time_add (line 9) | def time_add():
function time_matmul (line 40) | def time_matmul():
function time_maximum (line 47) | def time_maximum():
function time_max (line 54) | def time_max():
function time_min (line 61) | def time_min():
function time_negative (line 68) | def time_negative():
function time_exp (line 80) | def time_exp():
function time_logsumexp (line 86) | def time_logsumexp():
function time_take (line 92) | def time_take():
function time_reshape_transposed (line 104) | def time_reshape_transposed():
FILE: benchmarks/python/slice_update_bench.py
function benchmark_slice_update_mlx (line 10) | def benchmark_slice_update_mlx(dst_shape, slice_shape, slice_range, dtyp...
function benchmark_slice_update_torch (line 32) | def benchmark_slice_update_torch(
FILE: benchmarks/python/synchronize_bench.py
function timeit (line 8) | def timeit(fn, a):
function all_reduce_benchmark (line 23) | def all_reduce_benchmark():
function all_gather_benchmark (line 39) | def all_gather_benchmark():
FILE: benchmarks/python/time_utils.py
function time_fn (line 8) | def time_fn(fn, *args, **kwargs):
function measure_runtime (line 29) | def measure_runtime(fn, **kwargs):
FILE: docs/src/conf.py
function setup (line 71) | def setup(app):
FILE: examples/cmake_project/example.cpp
function main (line 9) | int main() {
FILE: examples/cpp/distributed.cpp
function main (line 9) | int main() {
FILE: examples/cpp/linear_regression.cpp
function main (line 15) | int main() {
FILE: examples/cpp/logistic_regression.cpp
function main (line 15) | int main() {
FILE: examples/cpp/metal_capture.cpp
function main (line 10) | int main() {
FILE: examples/cpp/timer.h
function namespace (line 7) | namespace timer {
FILE: examples/cpp/tutorial.cpp
function array_basics (line 10) | void array_basics() {
function automatic_differentiation (line 81) | void automatic_differentiation() {
function main (line 96) | int main() {
FILE: examples/export/eval_mlp.cpp
function main (line 8) | int main() {
FILE: examples/export/eval_mlp.py
class MLP (line 8) | class MLP(nn.Module):
method __init__ (line 11) | def __init__(
method __call__ (line 21) | def __call__(self, x):
function forward (line 39) | def forward(x):
FILE: examples/export/train_mlp.cpp
function main (line 8) | int main() {
FILE: examples/export/train_mlp.py
class MLP (line 9) | class MLP(nn.Module):
method __init__ (line 12) | def __init__(
method __call__ (line 22) | def __call__(self, x):
function init (line 34) | def init():
function loss_fn (line 51) | def loss_fn(params, X, y):
function step (line 55) | def step(*inputs):
FILE: examples/extensions/axpby/axpby.cpp
type my_ext (line 18) | namespace my_ext {
function current_binary_dir (line 22) | std::string current_binary_dir() {
function axpby (line 44) | mx::array axpby(
function axpby_impl (line 82) | void axpby_impl(
FILE: examples/extensions/axpby/axpby.h
function namespace (line 10) | namespace my_ext {
FILE: examples/extensions/bindings.cpp
function NB_MODULE (line 11) | NB_MODULE(_ext, m) {
FILE: examples/python/linear_regression.py
function loss_fn (line 26) | def loss_fn(w):
FILE: examples/python/logistic_regression.py
function loss_fn (line 26) | def loss_fn(w):
FILE: examples/python/qqmm.py
function ulp_bf16_at (line 14) | def ulp_bf16_at(x):
function test_qqmm (line 22) | def test_qqmm():
function test_qqmm_vjp (line 80) | def test_qqmm_vjp():
FILE: mlx/3rdparty/pocketfft.h
function namespace (line 91) | namespace pocketfft {
function cmplx (line 302) | static cmplx<Thigh> calc(size_t x, size_t n, Thigh ang)
type util (line 369) | struct util // hack to avoid duplicate symbols
function POCKETFFT_NOINLINE (line 383) | static POCKETFFT_NOINLINE double cost_guess (size_t n)
function POCKETFFT_NOINLINE (line 401) | static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n)
function POCKETFFT_NOINLINE (line 430) | static POCKETFFT_NOINLINE size_t good_size_real(size_t n)
function prod (line 456) | static size_t prod(const shape_t &shape)
function shutdown (line 753) | void shutdown()
function restart (line 759) | void restart()
function thread_pool (line 766) | inline thread_pool & get_pool()
type fctdata (line 829) | struct fctdata
function add_factor (line 839) | void add_factor(size_t factor)
function twsize (line 1498) | size_t twsize() const
function comp_twiddle (line 1512) | void comp_twiddle()
type fctdata (line 1555) | struct fctdata
function add_factor (line 1565) | void add_factor(size_t factor)
function twsize (line 2297) | size_t twsize() const
function comp_twiddle (line 2310) | void comp_twiddle()
function exec (line 2527) | void exec(T c[], T0 fct, bool fwd) const
function exec (line 2547) | void exec(T c[], T0 fct, bool ortho,
function exec (line 2578) | void exec(T c[], T0 fct,
function exec (line 2609) | void exec(T c[], T0 fct, bool ortho,
function exec (line 2687) | void exec(T c[], T0 fct,
function find_in_cache (line 2780) | auto find_in_cache = [&]() -> std::shared_ptr<T>
function class (line 2822) | class arr_info
function advance_i (line 2870) | void advance_i()
function advance (line 2916) | void advance(size_t n)
function iofs (line 2927) | ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*str_i; }
function iofs (line 2928) | ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*...
function oofs (line 2929) | ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*str_o; }
function oofs (line 2930) | ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*...
function class (line 2938) | class simple_iter
function class (line 2966) | class rev_iter
type ExecC2C (line 3166) | struct ExecC2C
type ExecHartley (line 3211) | struct ExecHartley
type ExecDcst (line 3223) | struct ExecDcst
type ExecR2R (line 3366) | struct ExecR2R
function ExecDcst (line 3407) | const ExecDcst exec{ortho, type, true};
function ExecDcst (line 3425) | const ExecDcst exec{ortho, type, false};
function newaxes (line 3461) | auto newaxes = shape_t{axes.begin(), --axes.end()};
function newaxes (line 3499) | auto newaxes = shape_t{axes.begin(), --axes.end()};
FILE: mlx/allocator.h
function namespace (line 9) | namespace mlx::core::allocator {
FILE: mlx/array.cpp
type mlx::core (line 11) | namespace mlx::core {
function array (line 60) | array array::unsafe_weak_copy(const array& other) {
FILE: mlx/array.h
function namespace (line 16) | namespace mlx::core {
function ArrayIterator (line 157) | struct MLX_API ArrayIterator {
function Data (line 231) | struct Data {
type Flags (line 248) | struct Flags {
function set_siblings (line 313) | void set_siblings(std::vector<array> siblings, uint16_t position) {
function buffer_size (line 359) | size_t buffer_size() const {
type Status (line 387) | enum Status {
function is_available (line 404) | bool is_available() const;
function set_status (line 414) | void set_status(Status s) const {
function attach_event (line 424) | void attach_event(Event e) const {
function set_tracer (line 433) | void set_tracer(bool is_tracer) {
function is_tracer (line 437) | bool is_tracer() const;
type MLX_API (line 468) | struct MLX_API
function offset (line 489) | int64_t offset{0};
FILE: mlx/backend/common/binary.h
function namespace (line 9) | namespace mlx::core {
FILE: mlx/backend/common/broadcasting.cpp
type mlx::core (line 5) | namespace mlx::core {
function broadcast (line 7) | void broadcast(const array& in, array& out) {
FILE: mlx/backend/common/broadcasting.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/common/buffer_cache.h
function T (line 30) | T* reuse_from_cache(size_t size) {
function recycle_to_cache (line 48) | void recycle_to_cache(T* buf) {
function release_cached_buffers (line 58) | int release_cached_buffers(size_t min_bytes_to_free) {
function clear (line 87) | int clear() {
function BufferHolder (line 150) | BufferHolder* tail_{nullptr};
FILE: mlx/backend/common/common.cpp
type mlx::core (line 8) | namespace mlx::core {
function prepare_reshape (line 147) | std::pair<bool, Strides> prepare_reshape(const array& in, const array&...
function shared_buffer_reshape (line 184) | void shared_buffer_reshape(
FILE: mlx/backend/common/compiled.cpp
type mlx::core (line 7) | namespace mlx::core {
function print_constant (line 9) | void print_constant(std::ostream& os, const array& x) {
function get_type_string (line 47) | std::string get_type_string(Dtype d) {
function compiled_check_contiguity (line 85) | bool compiled_check_contiguity(
function compiled_allocate_outputs (line 113) | void compiled_allocate_outputs(
function compiled_collapse_contiguous_dims (line 173) | std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contig...
function compiled_use_large_index (line 224) | bool compiled_use_large_index(
FILE: mlx/backend/common/compiled.h
function namespace (line 10) | namespace mlx::core {
function is_scalar (line 47) | inline bool is_scalar(const array& x) {
FILE: mlx/backend/common/copy.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/common/hadamard.h
function namespace (line 9) | namespace mlx::core {
FILE: mlx/backend/common/load.cpp
function swap_endianness (line 12) | void swap_endianness(uint8_t* data_bytes, size_t N) {
type mlx::core (line 28) | namespace mlx::core {
FILE: mlx/backend/common/matmul.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/common/quantized.h
function namespace (line 3) | namespace mlx::core {
FILE: mlx/backend/common/reduce.cpp
type mlx::core (line 5) | namespace mlx::core {
function shapes_without_reduction_axes (line 7) | std::pair<Shape, Strides> shapes_without_reduction_axes(
function shapes_without_reduction_axes (line 20) | std::pair<Shape, Strides> shapes_without_reduction_axes(
function ReductionPlan (line 29) | ReductionPlan get_reduction_plan(const array& x, const std::vector<int...
FILE: mlx/backend/common/reduce.h
type ReductionOpType (line 9) | enum ReductionOpType {
type ReductionPlan (line 39) | struct ReductionPlan {
FILE: mlx/backend/common/slicing.cpp
type mlx::core (line 5) | namespace mlx::core {
function prepare_slice (line 7) | std::tuple<int64_t, Strides> prepare_slice(
function shared_buffer_slice (line 20) | void shared_buffer_slice(
function slice (line 38) | void slice(
FILE: mlx/backend/common/slicing.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/common/ternary.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/backend/common/unary.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/backend/common/utils.cpp
type mlx::core (line 7) | namespace mlx::core {
function current_binary_dir (line 9) | std::filesystem::path current_binary_dir() {
function collapse_contiguous_dims (line 20) | std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
function collapse_contiguous_dims (line 83) | std::pair<Shape, Strides> collapse_contiguous_dims(
function collapse_contiguous_dims (line 111) | std::pair<Shape, Strides> collapse_contiguous_dims(
function Dims (line 117) | Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* =...
function Dims (line 148) | Dims get_2d_grid_dims_common(const Shape& shape, const Strides& stride...
function Dims (line 173) | Dims get_2d_grid_dims_common(
function get_grid_and_block_common (line 221) | std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, in...
FILE: mlx/backend/common/utils.h
function namespace (line 11) | namespace mlx::core {
function loc (line 189) | int64_t loc{0};
function is_donatable (line 218) | inline bool is_donatable(const array& in, const array& out) {
FILE: mlx/backend/cpu/arange.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/backend/cpu/arg_reduce.cpp
type mlx::core (line 9) | namespace mlx::core {
function arg_reduce (line 14) | void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
function arg_reduce_dispatch (line 35) | void arg_reduce_dispatch(
FILE: mlx/backend/cpu/binary.cpp
type mlx::core (line 15) | namespace mlx::core {
FILE: mlx/backend/cpu/binary.h
function namespace (line 13) | namespace mlx::core {
function binary_op_dispatch_dims (line 109) | void binary_op_dispatch_dims(
FILE: mlx/backend/cpu/binary_two.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/backend/cpu/cholesky.cpp
type mlx::core (line 10) | namespace mlx::core {
function cholesky_impl (line 13) | void cholesky_impl(const array& a, array& factor, bool upper, Stream s...
FILE: mlx/backend/cpu/compiled.cpp
type mlx::core (line 20) | namespace mlx::core {
type CompilerCache (line 22) | struct CompilerCache {
type DLib (line 23) | struct DLib {
method DLib (line 24) | DLib(const std::string& libname) {
function CompilerCache (line 44) | static CompilerCache& cache() {
type DLib (line 23) | struct DLib {
method DLib (line 24) | DLib(const std::string& libname) {
type detail (line 51) | namespace detail {
function compile_available_for_device (line 52) | bool compile_available_for_device(const Device& device) {
function build_kernel (line 150) | inline void build_kernel(
FILE: mlx/backend/cpu/conv.cpp
type mlx::core (line 12) | namespace mlx::core {
function slow_conv_1D (line 21) | void slow_conv_1D(
function slow_conv_2D (line 110) | void slow_conv_2D(
function slow_conv_3D (line 358) | void slow_conv_3D(
function dispatch_slow_conv_1D (line 673) | void dispatch_slow_conv_1D(
function dispatch_slow_conv_2D (line 726) | void dispatch_slow_conv_2D(
function dispatch_slow_conv_3D (line 779) | void dispatch_slow_conv_3D(
function flip_spatial_dims_inplace (line 837) | void flip_spatial_dims_inplace(
function explicit_gemm_conv_1D_cpu (line 856) | void explicit_gemm_conv_1D_cpu(
function explicit_gemm_conv_ND_cpu (line 999) | void explicit_gemm_conv_ND_cpu(
function conv_1D_cpu (line 1171) | void conv_1D_cpu(
function conv_2D_cpu (line 1213) | void conv_2D_cpu(
function conv_3D_cpu (line 1251) | void conv_3D_cpu(
FILE: mlx/backend/cpu/copy.cpp
type mlx::core (line 11) | namespace mlx::core {
function copy_single (line 16) | void copy_single(const array& src, array& dst) {
function copy_vector (line 25) | void copy_vector(const array& src, array& dst) {
function copy_dims (line 33) | inline void copy_dims(
function copy_general_general (line 57) | void copy_general_general(
function copy_general_general (line 128) | inline void copy_general_general(const array& src, array& dst) {
function copy_general (line 142) | void copy_general(
function copy_general (line 165) | inline void copy_general(const array& src, array& dst) {
function copy (line 179) | void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
function copy (line 197) | void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
function copy_inplace_dispatch (line 245) | inline void copy_inplace_dispatch(
function copy_cpu_inplace (line 298) | void copy_cpu_inplace(
function copy_cpu (line 312) | void copy_cpu(const array& src, array& dst, CopyType ctype, Stream str...
function copy_cpu_inplace (line 325) | void copy_cpu_inplace(
function array (line 380) | array contiguous_copy_cpu(const array& arr, Stream stream) {
FILE: mlx/backend/cpu/copy.h
function namespace (line 11) | namespace mlx::core {
FILE: mlx/backend/cpu/device_info.cpp
type mlx::core::cpu (line 15) | namespace mlx::core::cpu {
function get_cpu_architecture (line 20) | std::string get_cpu_architecture() {
function get_cpu_name (line 51) | std::string get_cpu_name() {
function is_available (line 96) | bool is_available() {
function device_count (line 100) | int device_count() {
FILE: mlx/backend/cpu/device_info.h
function namespace (line 9) | namespace mlx::core::cpu {
FILE: mlx/backend/cpu/distributed.cpp
type mlx::core::distributed (line 10) | namespace mlx::core::distributed {
function ensure_row_contiguous (line 12) | std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream ...
FILE: mlx/backend/cpu/eig.cpp
type mlx::core (line 11) | namespace mlx::core {
function complex64_t (line 16) | complex64_t to_complex(T r, T i) {
type EigWork (line 21) | struct EigWork {}
type EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> (line 24) | struct EigWork<
method EigWork (line 36) | EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
method run (line 65) | void run(T* a, O* values, O* vectors) {
type EigWork<std::complex<float>> (line 116) | struct EigWork<std::complex<float>> {
method EigWork (line 129) | EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
method run (line 155) | void run(T* a, T* values, T* vectors) {
function eig_impl (line 177) | void eig_impl(
FILE: mlx/backend/cpu/eigh.cpp
type mlx::core (line 11) | namespace mlx::core {
type EighWork (line 16) | struct EighWork {}
type EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> (line 19) | struct EighWork<
method EighWork (line 32) | EighWork(char jobz_, char uplo_, int N_)
method run (line 54) | void run(T* vectors, T* values) {
type EighWork<std::complex<float>> (line 71) | struct EighWork<std::complex<float>> {
method EighWork (line 84) | EighWork(char jobz_, char uplo_, int N_)
method run (line 111) | void run(T* vectors, R* values) {
function eigh_impl (line 143) | void eigh_impl(
FILE: mlx/backend/cpu/encoder.cpp
type mlx::core::cpu (line 5) | namespace mlx::core::cpu {
function CommandEncoder (line 7) | CommandEncoder& get_command_encoder(Stream stream) {
FILE: mlx/backend/cpu/encoder.h
function namespace (line 10) | namespace mlx::core::cpu {
function num_ops_ (line 62) | int num_ops_{0};
FILE: mlx/backend/cpu/eval.cpp
type mlx::core::cpu (line 8) | namespace mlx::core::cpu {
function eval (line 10) | void eval(array& arr) {
FILE: mlx/backend/cpu/eval.h
function namespace (line 8) | namespace mlx::core::cpu {
FILE: mlx/backend/cpu/fft.cpp
type mlx::core (line 10) | namespace mlx::core {
FILE: mlx/backend/cpu/gemm.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/cpu/gemms/bnns.cpp
type mlx::core (line 9) | namespace mlx::core {
function BNNSDataType (line 15) | constexpr BNNSDataType to_bnns_dtype<float>() {
function BNNSDataType (line 19) | constexpr BNNSDataType to_bnns_dtype<float16_t>() {
function BNNSDataType (line 24) | constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
function matmul_bnns (line 29) | void matmul_bnns(
FILE: mlx/backend/cpu/gemms/cblas.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/cpu/gemms/simd_bf16.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/cpu/gemms/simd_fp16.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/cpu/gemms/simd_gemm.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/cpu/hadamard.cpp
type mlx::core (line 10) | namespace mlx::core {
function hadamard_n (line 14) | void hadamard_n(T* out, int n, int m, float scale, size_t size) {
function hadamard_m (line 40) | void hadamard_m(T* out, int n, int m, float scale, size_t size) {
function hadamard (line 78) | void hadamard(array& out, int n, int m, float scale, Stream stream) {
FILE: mlx/backend/cpu/indexing.cpp
type mlx::core (line 16) | namespace mlx::core {
function offset_neg_idx (line 19) | inline size_t offset_neg_idx(IdxT idx, size_t size) {
function offset_neg_idx (line 24) | inline size_t offset_neg_idx(uint32_t idx, size_t) {
type None (line 28) | struct None {
type Sum (line 34) | struct Sum {
type Prod (line 41) | struct Prod {
type Max (line 48) | struct Max {
type Min (line 55) | struct Min {
function gather (line 63) | void gather(
function dispatch_gather (line 150) | void dispatch_gather(
function gather_axis (line 258) | void gather_axis(
function dispatch_gather_axis (line 304) | void dispatch_gather_axis(
function scatter (line 402) | void scatter(
function dispatch_scatter_inds (line 446) | void dispatch_scatter_inds(
function dispatch_scatter (line 472) | void dispatch_scatter(
function scatter_axis (line 586) | void scatter_axis(array& out, const array idx, const array& upd, int a...
function dispatch_scatter_axis_op (line 628) | void dispatch_scatter_axis_op(
function dispatch_scatter_axis (line 645) | void dispatch_scatter_axis(
function masked_scatter_impl (line 754) | void masked_scatter_impl(const array& mask, const array& src, array& o...
function slice_update_impl (line 858) | void slice_update_impl(
FILE: mlx/backend/cpu/inverse.cpp
type mlx::core (line 9) | namespace mlx::core {
function general_inv (line 12) | void general_inv(T* inv, int N) {
function tri_inv (line 72) | void tri_inv(T* inv, int N, bool upper) {
function inverse_impl (line 106) | void inverse_impl(
FILE: mlx/backend/cpu/jit_compiler.cpp
type mlx::core (line 11) | namespace mlx::core {
function str_split (line 18) | std::vector<std::string> str_split(const std::string& str, char delimi...
type VisualStudioInfo (line 29) | struct VisualStudioInfo {
method VisualStudioInfo (line 30) | VisualStudioInfo() {
function VisualStudioInfo (line 80) | const VisualStudioInfo& GetVisualStudioInfo() {
method VisualStudioInfo (line 30) | VisualStudioInfo() {
FILE: mlx/backend/cpu/jit_compiler.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/cpu/logsumexp.cpp
type mlx::core (line 12) | namespace mlx::core {
function logsumexp (line 19) | void logsumexp(const array& in, array& out, Stream stream) {
FILE: mlx/backend/cpu/luf.cpp
type mlx::core (line 11) | namespace mlx::core {
function luf_impl (line 14) | void luf_impl(
FILE: mlx/backend/cpu/masked_mm.cpp
type mlx::core (line 13) | namespace mlx::core {
function mask_matrix (line 18) | inline void mask_matrix(
function segmented_mm (line 57) | inline void segmented_mm(
FILE: mlx/backend/cpu/matmul.cpp
type mlx::core (line 12) | namespace mlx::core {
function matmul_dispatch (line 15) | void matmul_dispatch(
function matmul_general (line 69) | void matmul_general(
FILE: mlx/backend/cpu/primitives.cpp
type mlx::core (line 19) | namespace mlx::core {
function reshape (line 21) | void reshape(const array& in, array& out) {
function compute_dynamic_offset (line 31) | static std::pair<array, bool> compute_dynamic_offset(
FILE: mlx/backend/cpu/qrf.cpp
type mlx::core (line 9) | namespace mlx::core {
function qrf_impl (line 12) | void qrf_impl(const array& a, array& q, array& r, Stream stream) {
FILE: mlx/backend/cpu/quantized.cpp
type mlx::core (line 14) | namespace mlx::core {
function array (line 18) | array ensure_row_contiguous(
function T (line 50) | static inline T dequantize_scale(uint8_t s) {
function extract_bits (line 65) | void extract_bits(const uint8_t* w_in, T* w_out) {
function _qmm (line 97) | void _qmm(
function _qmm_t (line 155) | void _qmm_t(
function extract_bits_simd (line 215) | simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
function _qmm_t_simd (line 240) | void _qmm_t_simd(
function _qmm_dispatch_transpose (line 287) | void _qmm_dispatch_transpose(
function _qmm_dispatch_group (line 310) | void _qmm_dispatch_group(
function _qmm_dispatch_typed (line 341) | void _qmm_dispatch_typed(
function _qmm_dispatch_typed (line 384) | void _qmm_dispatch_typed(
function _qmm_dispatch (line 421) | void _qmm_dispatch(
function fp_qmm (line 450) | void fp_qmm(
function fp_qmm_t (line 492) | void fp_qmm_t(
function fp_extract_bits_simd (line 536) | simd::Simd<float, S> fp_extract_bits_simd(const uint32_t* w) {
function fp_qmm_t_simd (line 558) | void fp_qmm_t_simd(
function fp_qmm_dispatch_transpose (line 603) | void fp_qmm_dispatch_transpose(
function fp_qmm_dispatch_mode (line 625) | void fp_qmm_dispatch_mode(
function fp_qmm_dispatch_typed (line 656) | void fp_qmm_dispatch_typed(
function fp_qmm_dispatch (line 673) | void fp_qmm_dispatch(
function _bs_qmm_dispatch_typed (line 701) | void _bs_qmm_dispatch_typed(
function _bs_qmm_dispatch (line 749) | void _bs_qmm_dispatch(
function fp_bs_qmm_dispatch_mode (line 806) | void fp_bs_qmm_dispatch_mode(
function fp_bs_qmm_dispatch_typed (line 847) | void fp_bs_qmm_dispatch_typed(
function fp_bs_qmm_dispatch (line 869) | void fp_bs_qmm_dispatch(
function to_fp8_e8m0 (line 1049) | uint8_t to_fp8_e8m0(float x) {
function to_fp4_e2m1 (line 1064) | uint8_t to_fp4_e2m1(float x) {
function fp_quantize_dequantize (line 1094) | void fp_quantize_dequantize(
function dispatch_quantize_dequantize (line 1131) | void dispatch_quantize_dequantize(
function quantize (line 1149) | void quantize(
function dispatch_quantize (line 1214) | void dispatch_quantize(
FILE: mlx/backend/cpu/reduce.cpp
type mlx::core (line 12) | namespace mlx::core {
type Limits (line 15) | struct Limits {
type Limits<bool> (line 50) | struct Limits<bool> {
function strided_reduce (line 72) | void strided_reduce(
function contiguous_reduce (line 99) | void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U ...
function nd_loop (line 115) | void nd_loop(
function reduction_op (line 139) | void reduction_op(
type AndReduce (line 264) | struct AndReduce {
type OrReduce (line 290) | struct OrReduce {
type MaxReduce (line 316) | struct MaxReduce {
method T (line 318) | T operator()(T y, T x) {
type MinReduce (line 341) | struct MinReduce {
method T (line 343) | T operator()(T y, T x) {
type SumReduce (line 366) | struct SumReduce {
method U (line 368) | U operator()(U y, T x) {
method T (line 378) | T operator()(simd::Simd<T, N> x) {
type ProdReduce (line 383) | struct ProdReduce {
method U (line 385) | U operator()(U y, T x) {
method T (line 395) | T operator()(simd::Simd<T, N> x) {
function reduce_dispatch_and_or (line 401) | void reduce_dispatch_and_or(
function reduce_dispatch_sum_prod (line 414) | void reduce_dispatch_sum_prod(
function reduce_dispatch_min_max (line 435) | void reduce_dispatch_min_max(
FILE: mlx/backend/cpu/scan.cpp
type mlx::core (line 12) | namespace mlx::core {
function contiguous_scan (line 17) | void contiguous_scan(
function strided_scan (line 82) | void strided_scan(
function scan_op (line 157) | void scan_op(
function scan_dispatch (line 194) | void scan_dispatch(
FILE: mlx/backend/cpu/select.cpp
type mlx::core (line 9) | namespace mlx::core {
function select_op (line 14) | void select_op(
FILE: mlx/backend/cpu/simd/accelerate_fp16_simd.h
function namespace (line 9) | namespace mlx::core::simd {
FILE: mlx/backend/cpu/simd/accelerate_simd.h
function value (line 66) | value(v){}
function T (line 73) | T operator[](int idx) const {
FILE: mlx/backend/cpu/simd/base_simd.h
function namespace (line 16) | namespace mlx::core::simd {
function DEFAULT_UNARY (line 91) | DEFAULT_UNARY(operator!, std::logical_not{})
FILE: mlx/backend/cpu/simd/math.h
function namespace (line 7) | namespace mlx::core::simd {
function lhs (line 155) | auto lhs = [](auto t) {
function rhs (line 167) | auto rhs = [](auto t) {
FILE: mlx/backend/cpu/simd/neon_fp16_simd.h
function namespace (line 7) | namespace mlx::core::simd {
function Simd (line 160) | inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
function float16_t (line 182) | inline float16_t max(Simd<float16_t, N> x) {
function float16_t (line 189) | inline float16_t min(Simd<float16_t, N> x) {
function float16_t (line 196) | inline float16_t sum(Simd<float16_t, N> x) {
function float16_t (line 203) | inline float16_t prod(Simd<float16_t, N> x) {
FILE: mlx/backend/cpu/slicing.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/cpu/softmax.cpp
type mlx::core (line 12) | namespace mlx::core {
function softmax (line 19) | void softmax(const array& in, array& out, Stream stream) {
FILE: mlx/backend/cpu/sort.cpp
type mlx::core (line 14) | namespace mlx::core {
function nan_aware_less (line 24) | bool nan_aware_less(T a, T b) {
type StridedIterator (line 35) | struct StridedIterator {
method StridedIterator (line 43) | StridedIterator() = default;
method StridedIterator (line 45) | explicit StridedIterator(T* ptr, int64_t stride, difference_type off...
method StridedIterator (line 48) | explicit StridedIterator(array& arr, int axis, difference_type offse...
method reference (line 52) | reference operator*() const {
method reference (line 56) | reference operator[](difference_type idx) const {
method difference_type (line 85) | difference_type operator-(const StridedIterator& other) const {
method StridedIterator (line 90) | StridedIterator& operator++() {
method StridedIterator (line 95) | StridedIterator& operator--() {
method StridedIterator (line 100) | StridedIterator& operator+=(difference_type diff) {
method StridedIterator (line 105) | StridedIterator& operator-=(difference_type diff) {
method StridedIterator (line 110) | StridedIterator operator+(difference_type diff) {
method StridedIterator (line 114) | StridedIterator operator-(difference_type diff) {
function sort (line 124) | void sort(array& out, int axis) {
function argsort (line 155) | void argsort(const array& in, array& out, int axis) {
function partition (line 218) | void partition(array& out, int axis, int kth) {
function argpartition (line 252) | void argpartition(const array& in, array& out, int axis, int kth) {
FILE: mlx/backend/cpu/svd.cpp
type mlx::core (line 9) | namespace mlx::core {
type SVDWork (line 12) | struct SVDWork {}
type SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> (line 15) | struct SVDWork<
method SVDWork (line 30) | SVDWork(int N, int M, int K, char jobz)
method run (line 69) | void run(T* a, R* s, T* u, T* vt) {
type SVDWork<std::complex<float>> (line 99) | struct SVDWork<std::complex<float>> {
method SVDWork (line 113) | SVDWork(int N, int M, int K, char jobz)
method run (line 158) | void run(T* a, R* s, T* u, T* vt) {
function svd_impl (line 189) | void svd_impl(
FILE: mlx/backend/cpu/ternary.h
function namespace (line 9) | namespace mlx::core {
function ContiguousIterator (line 100) | ContiguousIterator a_it(shape, a_strides, ndim - 2);
function else (line 138) | else if (topt == TernaryOpType::VectorVectorVector) {
FILE: mlx/backend/cpu/threefry.cpp
type mlx::core::random (line 5) | namespace mlx::core::random {
function threefry2x32_hash (line 7) | std::pair<uint32_t, uint32_t> threefry2x32_hash(
FILE: mlx/backend/cpu/threefry.h
function namespace (line 8) | namespace mlx::core::random {
FILE: mlx/backend/cpu/unary.cpp
type mlx::core (line 12) | namespace mlx::core {
FILE: mlx/backend/cpu/unary.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/cpu/unary_ops.h
function namespace (line 11) | namespace mlx::core::detail {
type Square (line 103) | struct Square {
type ToFP8 (line 120) | struct ToFP8 {
function else (line 154) | struct FromFP8 {
FILE: mlx/backend/cuda/allocator.cpp
type mlx::core (line 18) | namespace mlx::core {
type cu (line 20) | namespace cu {
function is_windows (line 32) | bool is_windows() {
function supports_managed_memory (line 54) | bool supports_managed_memory() {
function unified_free (line 83) | inline void unified_free(void* data) {
function cudaMemLocation (line 92) | inline cudaMemLocation cuda_mem_loc(int i) {
function cuda_mem_loc (line 99) | inline int cuda_mem_loc(int i) {
function CudaBuffer (line 134) | CudaBuffer* SmallSizePool::malloc() {
function Buffer (line 183) | Buffer
function Buffer (line 270) | Buffer CudaAllocator::malloc(size_t size) {
function CudaAllocator (line 385) | CudaAllocator& allocator() {
function Buffer (line 397) | Buffer malloc_async(size_t size, CommandEncoder& encoder) {
type allocator (line 404) | namespace allocator {
function Allocator (line 406) | Allocator& allocator() {
function get_active_memory (line 421) | size_t get_active_memory() {
function get_peak_memory (line 424) | size_t get_peak_memory() {
function reset_peak_memory (line 427) | void reset_peak_memory() {
function set_memory_limit (line 430) | size_t set_memory_limit(size_t limit) {
function get_memory_limit (line 433) | size_t get_memory_limit() {
function get_cache_memory (line 436) | size_t get_cache_memory() {
function set_cache_limit (line 439) | size_t set_cache_limit(size_t limit) {
function clear_cache (line 442) | void clear_cache() {
function set_wired_limit (line 447) | size_t set_wired_limit(size_t) {
FILE: mlx/backend/cuda/allocator.h
type CudaBuffer (line 21) | struct CudaBuffer {
function Block (line 36) | Block* next_free_{nullptr};
function class (line 50) | class CudaAllocator : public allocator::Allocator {
FILE: mlx/backend/cuda/compiled.cpp
type mlx::core (line 13) | namespace mlx::core {
type cu (line 15) | namespace cu {
type FusedKernelBuilder (line 17) | struct FusedKernelBuilder {
method build (line 25) | void build(const char* name, bool contiguous) {
FILE: mlx/backend/cuda/conv.cpp
type mlx::core (line 14) | namespace mlx::core {
type ConvBackendType (line 18) | enum ConvBackendType {
type ConvCacheKey (line 25) | struct ConvCacheKey {
function get_conv_settings (line 49) | auto get_conv_settings(
function build_conv_graph (line 93) | std::optional<DnnGraph> build_conv_graph(
function array (line 147) | array group_transpose(
function prepare_args (line 185) | std::tuple<array, array, array> prepare_args(
function register_args (line 229) | void register_args(
FILE: mlx/backend/cuda/conv/conv.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/backend/cuda/cublas_utils.cpp
type mlx::core (line 7) | namespace mlx::core {
type cublas_utils (line 8) | namespace cublas_utils {
type CublasPreference (line 12) | struct CublasPreference {
method CublasPreference (line 13) | CublasPreference(cu::Device& device) {
function cublasLtMatmulPreference_t (line 38) | cublasLtMatmulPreference_t get_preference(cu::Device& device) {
function cublasLtMatrixLayout_t (line 43) | cublasLtMatrixLayout_t create_matrix_layout(
FILE: mlx/backend/cuda/cublas_utils.h
function namespace (line 10) | namespace cublas_utils {
function cublasLtMatrixLayout_t (line 63) | cublasLtMatrixLayout_t c_desc_{nullptr};
FILE: mlx/backend/cuda/cuda.h
function namespace (line 11) | namespace mlx::core::cu {
FILE: mlx/backend/cuda/cuda_utils.h
function namespace (line 10) | namespace mlx::core {
function namespace (line 67) | namespace cu {
FILE: mlx/backend/cuda/cudnn_utils.cpp
type mlx::core (line 6) | namespace mlx::core {
function normalized_strides (line 20) | std::vector<int64_t> normalized_strides(const array& x) {
function nhwc_to_nchw (line 39) | inline auto nhwc_to_nchw(const array& x) {
FILE: mlx/backend/cuda/cudnn_utils.h
function namespace (line 16) | namespace mlx::core {
function cached_is_updatable_ (line 191) | bool cached_is_updatable_{true};
FILE: mlx/backend/cuda/custom_kernel.cpp
type mlx::core::fast (line 15) | namespace mlx::core::fast {
function template_arguments_hash (line 28) | std::string template_arguments_hash(
function build_kernel (line 51) | std::string build_kernel(
function CustomKernelFunction (line 144) | CustomKernelFunction cuda_kernel(
function precompiled_cuda_kernel (line 244) | std::vector<array> precompiled_cuda_kernel(
FILE: mlx/backend/cuda/delayload.cpp
type mlx::core (line 10) | namespace mlx::core {
function relative_to_current_binary (line 14) | inline fs::path relative_to_current_binary(const char* relative) {
function cublas_bin_dir (line 18) | inline fs::path cublas_bin_dir() {
function load_nvrtc (line 26) | fs::path load_nvrtc() {
function load_cudnn (line 38) | fs::path load_cudnn() {
function FARPROC (line 61) | FARPROC WINAPI delayload_helper(unsigned dliNotify, PDelayLoadInfo pdl...
FILE: mlx/backend/cuda/device.cpp
type mlx::core::cu (line 14) | namespace mlx::core::cu {
function use_cuda_graphs (line 18) | bool use_cuda_graphs() {
function is_empty_dim (line 34) | inline bool is_empty_dim(dim3 dim) {
function CommandEncoder (line 79) | CommandEncoder& Device::get_command_encoder(Stream s) {
function cublasLtHandle_t (line 87) | cublasLtHandle_t Device::get_cublaslt_handle() {
function cudnnHandle_t (line 95) | cudnnHandle_t Device::get_cudnn_handle() {
function get_graph_limits (line 213) | std::pair<int, int> get_graph_limits(Device& d) {
function cudaGraphNode_t (line 371) | cudaGraphNode_t CommandEncoder::add_kernel_node_raw(
function CUgraphNode (line 379) | CUgraphNode CommandEncoder::add_kernel_node_raw(
function subgraph_to_key (line 387) | std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
function Device (line 573) | Device& device(int cuda_device) {
function Device (line 588) | Device& device(mlx::core::Device d) {
function CommandEncoder (line 592) | CommandEncoder& get_command_encoder(Stream s) {
FILE: mlx/backend/cuda/device.h
function class (line 22) | class CommandEncoder {
function add_temporary (line 103) | void add_temporary(const array& arr) {
type GraphNode (line 126) | struct GraphNode {
function node_count_ (line 143) | int node_count_{0}
function bytes_in_graph_ (line 155) | size_t bytes_in_graph_{0}
function is_graph_updatable_ (line 156) | bool is_graph_updatable_{true};
function cublasLtHandle_t (line 208) | cublasLtHandle_t cublaslt_handle_{nullptr};
FILE: mlx/backend/cuda/device_info.cpp
type mlx::core (line 14) | namespace mlx::core {
type nvmlDevice_st (line 22) | struct nvmlDevice_st
type nvmlMemory_t (line 23) | struct nvmlMemory_t {
type NVMLState (line 29) | struct NVMLState {
function nvml_init (line 38) | bool nvml_init(NVMLState& nvml) {
function nvml_get_memory (line 66) | bool nvml_get_memory(
function format_uuid (line 84) | std::string format_uuid(const cudaUUID_t& uuid) {
type DeviceInfo (line 118) | struct DeviceInfo {
type gpu (line 205) | namespace gpu {
function is_available (line 207) | bool is_available() {
function device_count (line 211) | int device_count() {
type cu (line 224) | namespace cu {
function is_available (line 226) | bool is_available() {
FILE: mlx/backend/cuda/eval.cpp
type mlx::core::gpu (line 11) | namespace mlx::core::gpu {
function new_stream (line 13) | void new_stream(Stream s) {
function eval (line 22) | void eval(array& arr) {
function finalize (line 60) | void finalize(Stream s) {
function synchronize (line 65) | void synchronize(Stream s) {
FILE: mlx/backend/cuda/event.h
function namespace (line 14) | namespace mlx::core::cu {
function class (line 27) | class CudaEvent {
FILE: mlx/backend/cuda/fence.cpp
type mlx::core (line 8) | namespace mlx::core {
type FenceImpl (line 10) | struct FenceImpl {
FILE: mlx/backend/cuda/gemms/cublas_gemm.cpp
type mlx::core (line 11) | namespace mlx::core {
function cublasComputeType_t (line 15) | cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
FILE: mlx/backend/cuda/gemms/cublas_gemm.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/cuda/gemms/gemv.h
function namespace (line 7) | namespace mlx::core::cu {
FILE: mlx/backend/cuda/gemms/grouped_gemm.h
function namespace (line 5) | namespace mlx::core {
FILE: mlx/backend/cuda/indexing.cpp
type mlx::core (line 24) | namespace mlx::core {
function append_indices_arg (line 32) | void append_indices_arg(
FILE: mlx/backend/cuda/jit_module.cpp
type mlx::core::cu (line 16) | namespace mlx::core::cu {
function check_nvrtc_error (line 22) | void check_nvrtc_error(const char* name, nvrtcResult err) {
function get_ptx_path (line 123) | std::filesystem::path get_ptx_path(
function read_cached_ptx (line 143) | bool read_cached_ptx(
function write_cached_ptx (line 177) | void write_cached_ptx(
function version_lower_equal (line 211) | inline bool version_lower_equal(Device& device, int major, int minor) {
function compiler_supports_device_sass (line 222) | bool compiler_supports_device_sass(Device& device) {
function compile (line 274) | void compile(
function load_module (line 345) | void load_module(
function CUfunction (line 436) | CUfunction JitModule::get_kernel(
function JitModule (line 447) | JitModule& get_jit_module(
FILE: mlx/backend/cuda/jit_module.h
function namespace (line 18) | namespace mlx::core::cu {
function append_ptr (line 65) | void append_ptr(const void* v) {
function class (line 88) | class JitModule {
FILE: mlx/backend/cuda/load.cpp
function swap_endianness (line 13) | void swap_endianness(uint8_t* data_bytes, size_t N) {
type mlx::core (line 29) | namespace mlx::core {
FILE: mlx/backend/cuda/lru_cache.h
function namespace (line 14) | namespace mlx::core {
FILE: mlx/backend/cuda/matmul.cpp
type mlx::core (line 14) | namespace mlx::core {
function check_transpose (line 18) | std::tuple<bool, int64_t, array>
function ensure_batch_contiguous (line 33) | std::tuple<bool, int64_t, array>
function array (line 52) | array ensure_row_contiguous(
function gemm_and_bias (line 65) | void gemm_and_bias(
function gather_mm_rhs (line 139) | void gather_mm_rhs(
FILE: mlx/backend/cuda/no_cuda.cpp
type mlx::core (line 6) | namespace mlx::core {
type cu (line 8) | namespace cu {
function is_available (line 10) | bool is_available() {
type fast (line 16) | namespace fast {
function CustomKernelFunction (line 18) | CustomKernelFunction cuda_kernel(
function precompiled_cuda_kernel (line 29) | std::vector<array> precompiled_cuda_kernel(
FILE: mlx/backend/cuda/primitives.cpp
type mlx::core (line 8) | namespace mlx::core {
type distributed (line 37) | namespace distributed {
FILE: mlx/backend/cuda/quantized/cublas_qqmm.cpp
type mlx::core (line 12) | namespace mlx::core {
type QuantModeConfig (line 16) | struct QuantModeConfig {
function QuantModeConfig (line 22) | QuantModeConfig get_quant_mode_config(const std::string& mode) {
FILE: mlx/backend/cuda/quantized/cublas_qqmm.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/cuda/quantized/no_qqmm_impl.cpp
type mlx::core (line 5) | namespace mlx::core {
function qqmm_impl (line 6) | void qqmm_impl(
FILE: mlx/backend/cuda/quantized/qmm/qmm.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/cuda/quantized/qqmm.cpp
type mlx::core (line 13) | namespace mlx::core {
function quantize_input (line 17) | std::tuple<array, array> quantize_input(
function GemmScalars (line 55) | GemmScalars create_nvfp4_scalars(
FILE: mlx/backend/cuda/quantized/qqmm_impl.cpp
type mlx::core (line 6) | namespace mlx::core {
function qqmm_impl (line 8) | void qqmm_impl(
FILE: mlx/backend/cuda/quantized/qqmm_impl.h
function namespace (line 9) | namespace mlx::core {
FILE: mlx/backend/cuda/quantized/qqmm_utils.h
function namespace (line 8) | namespace mlx::core {
function array (line 30) | inline array pad_and_swizzle_scales(
FILE: mlx/backend/cuda/quantized/quantized.cpp
type mlx::core (line 13) | namespace mlx::core {
FILE: mlx/backend/cuda/quantized/quantized.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/cuda/quantized/quantized_utils.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/cuda/scaled_dot_product_attention.cpp
type mlx::core (line 11) | namespace mlx::core {
function array (line 15) | array prepare_sdpa_input(const array& x, Stream s) {
function array (line 28) | array prepare_sdpa_sinks(const array& sinks, Stream s) {
function malloc_with_same_layout (line 40) | void malloc_with_same_layout(
function use_cudnn_for_decoding (line 72) | bool use_cudnn_for_decoding(
function array (line 111) | array unslice_kv(const array& kv) {
type SDPACacheKey (line 126) | struct SDPACacheKey {
function build_sdpa_cache_key (line 142) | inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
type UIDS (line 190) | enum UIDS {
function DnnGraph (line 208) | DnnGraph build_sdpa_graph(
function DnnGraph (line 259) | DnnGraph build_sdpa_backward_graph(
function supports_sdpa_cudnn (line 310) | bool supports_sdpa_cudnn(
function sdpa_cudnn (line 347) | void sdpa_cudnn(
function sdpa_backward_cudnn (line 446) | void sdpa_backward_cudnn(
type fast (line 545) | namespace fast {
FILE: mlx/backend/cuda/slicing.cpp
type mlx::core (line 12) | namespace mlx::core {
function concatenate_gpu (line 14) | void concatenate_gpu(
function array (line 44) | array compute_dynamic_offset(
FILE: mlx/backend/cuda/utils.cpp
type mlx::core (line 11) | namespace mlx::core {
function check_cublas_error (line 13) | void check_cublas_error(const char* name, cublasStatus_t err) {
function check_cuda_error (line 21) | void check_cuda_error(const char* name, cudaError_t err) {
function check_cuda_error (line 28) | void check_cuda_error(const char* name, CUresult err) {
function check_cudnn_error (line 36) | void check_cudnn_error(const char* name, cudnnStatus_t err) {
FILE: mlx/backend/cuda/utils.h
function namespace (line 11) | namespace mlx::core {
type Dtype (line 41) | struct Dtype
FILE: mlx/backend/cuda/worker.cpp
type mlx::core::cu (line 6) | namespace mlx::core::cu {
FILE: mlx/backend/cuda/worker.h
function namespace (line 13) | namespace mlx::core::cu {
FILE: mlx/backend/gpu/copy.cpp
type mlx::core (line 9) | namespace mlx::core {
function copy_gpu (line 11) | void copy_gpu(const array& in, array& out, CopyType ctype) {
function copy_gpu_inplace (line 15) | void copy_gpu_inplace(
function copy_gpu_inplace (line 25) | void copy_gpu_inplace(
function array (line 37) | array contiguous_copy_gpu(const array& arr, const Stream& s) {
function array (line 43) | array flatten_in_eval(const array& x, int start_axis, int end_axis, St...
function array (line 57) | array reshape_in_eval(const array& x, Shape shape, Stream s) {
function array (line 63) | array transpose_in_eval(const array& x, const std::vector<int>& axes) {
function array (line 84) | array swapaxes_in_eval(const array& x, int axis1, int axis2) {
FILE: mlx/backend/gpu/copy.h
function namespace (line 11) | namespace mlx::core {
FILE: mlx/backend/gpu/device_info.h
function namespace (line 11) | namespace mlx::core::gpu {
FILE: mlx/backend/gpu/eval.h
function namespace (line 11) | namespace mlx::core::gpu {
FILE: mlx/backend/gpu/primitives.cpp
type mlx::core (line 21) | namespace mlx::core {
FILE: mlx/backend/gpu/scan.h
function namespace (line 6) | namespace mlx::core {
FILE: mlx/backend/gpu/slicing.cpp
type mlx::core (line 7) | namespace mlx::core {
function slice_gpu (line 9) | void slice_gpu(
function pad_gpu (line 18) | void pad_gpu(
FILE: mlx/backend/gpu/slicing.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/allocator.cpp
type mlx::core (line 12) | namespace mlx::core {
type allocator (line 17) | namespace allocator {
function Allocator (line 19) | Allocator& allocator() {
type metal (line 32) | namespace metal {
function Buffer (line 104) | Buffer MetalAllocator::malloc(size_t size) {
function Buffer (line 208) | Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
function MetalAllocator (line 235) | MetalAllocator& allocator() {
function set_cache_limit (line 245) | size_t set_cache_limit(size_t limit) {
function set_memory_limit (line 248) | size_t set_memory_limit(size_t limit) {
function get_memory_limit (line 251) | size_t get_memory_limit() {
function set_wired_limit (line 254) | size_t set_wired_limit(size_t limit) {
function get_active_memory (line 263) | size_t get_active_memory() {
function get_peak_memory (line 266) | size_t get_peak_memory() {
function reset_peak_memory (line 269) | void reset_peak_memory() {
function get_cache_memory (line 272) | size_t get_cache_memory() {
function clear_cache (line 275) | void clear_cache() {
FILE: mlx/backend/metal/allocator.h
function namespace (line 14) | namespace mlx::core::metal {
FILE: mlx/backend/metal/binary.cpp
type mlx::core (line 19) | namespace mlx::core {
function get_kernel_name (line 21) | std::string get_kernel_name(
function binary_op_gpu_inplace (line 65) | void binary_op_gpu_inplace(
function binary_op_gpu (line 165) | void binary_op_gpu(
function binary_op_gpu (line 179) | void binary_op_gpu(
function binary_op_gpu_inplace (line 187) | void binary_op_gpu_inplace(
function binary_op_gpu (line 196) | void binary_op_gpu(
function binary_op_gpu (line 209) | void binary_op_gpu(
FILE: mlx/backend/metal/binary.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/compiled.cpp
type mlx::core (line 14) | namespace mlx::core {
function build_kernel (line 16) | inline void build_kernel(
FILE: mlx/backend/metal/conv.cpp
type mlx::core (line 19) | namespace mlx::core {
function array (line 23) | inline array
function explicit_gemm_conv_ND_gpu (line 34) | void explicit_gemm_conv_ND_gpu(
function explicit_gemm_conv_group_ND_gpu (line 105) | void explicit_gemm_conv_group_ND_gpu(
function implicit_gemm_conv_2D_gpu (line 191) | void implicit_gemm_conv_2D_gpu(
function implicit_gemm_conv_2D_general_gpu (line 324) | void implicit_gemm_conv_2D_general_gpu(
function implicit_gemm_conv_3D_gpu (line 503) | void implicit_gemm_conv_3D_gpu(
function pad_and_slice_conv_3D_gpu (line 624) | void pad_and_slice_conv_3D_gpu(
function dispatch_conv_3D_gpu (line 671) | void dispatch_conv_3D_gpu(
function winograd_conv_2D_gpu (line 714) | void winograd_conv_2D_gpu(
function depthwise_conv_2D_gpu (line 908) | void depthwise_conv_2D_gpu(
function dispatch_conv_2D_gpu (line 970) | void dispatch_conv_2D_gpu(
function depthwise_conv_1D_gpu (line 1032) | void depthwise_conv_1D_gpu(
function conv_1D_gpu (line 1078) | void conv_1D_gpu(
function conv_2D_gpu (line 1165) | void conv_2D_gpu(
function conv_3D_gpu (line 1210) | void conv_3D_gpu(
FILE: mlx/backend/metal/copy.cpp
type mlx::core (line 9) | namespace mlx::core {
function copy_gpu (line 13) | void copy_gpu(const array& in, array& out, CopyType ctype, const Strea...
function copy_gpu_inplace (line 26) | void copy_gpu_inplace(
function fill_gpu (line 182) | void fill_gpu(const array& val, array& out, const Stream& s) {
function reshape_gpu (line 216) | void reshape_gpu(const array& in, array& out, Stream s) {
FILE: mlx/backend/metal/custom_kernel.cpp
type mlx::core::fast (line 14) | namespace mlx::core::fast {
type CustomKernelCache (line 16) | struct CustomKernelCache {
function CustomKernelCache (line 20) | static CustomKernelCache& cache() {
function write_signature (line 25) | std::string write_signature(
function write_template (line 153) | std::string write_template(
function CustomKernelFunction (line 175) | CustomKernelFunction metal_kernel(
FILE: mlx/backend/metal/device.cpp
type std (line 16) | namespace std {
type hash<NS::SharedPtr<T>> (line 20) | struct hash<NS::SharedPtr<T>> {
type mlx::core::metal (line 28) | namespace mlx::core::metal {
function get_metal_version (line 34) | auto get_metal_version() {
function load_device (line 48) | auto load_device() {
function load_library_from_path (line 57) | std::pair<MTL::Library*, NS::Error*> load_library_from_path(
function load_colocated_library (line 103) | std::pair<MTL::Library*, NS::Error*> load_colocated_library(
function load_swiftpm_library (line 114) | std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
function CommandEncoder (line 523) | CommandEncoder& Device::get_command_encoder(int index) {
function Device (line 832) | Device& device(mlx::core::Device) {
function new_scoped_memory_pool (line 840) | std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_po...
FILE: mlx/backend/metal/device.h
function namespace (line 16) | namespace mlx::core::metal {
function set_threadgroup_memory_length (line 80) | void set_threadgroup_memory_length(size_t length, int idx) {
function ConcurrentContext (line 84) | ConcurrentContext start_concurrent() {
function needs_commit (line 90) | bool needs_commit() const;
function buffer_ops_ (line 108) | int buffer_ops_{0}
function buffer_sizes_ (line 109) | size_t buffer_sizes_{0}
function needs_barrier_ (line 116) | bool needs_barrier_{false};
FILE: mlx/backend/metal/device_info.cpp
type mlx::core::gpu (line 9) | namespace mlx::core::gpu {
function is_available (line 11) | bool is_available() {
function device_count (line 15) | int device_count() {
FILE: mlx/backend/metal/distributed.cpp
type mlx::core::distributed (line 15) | namespace mlx::core::distributed {
FILE: mlx/backend/metal/eval.cpp
type mlx::core::gpu (line 10) | namespace mlx::core::gpu {
function new_stream (line 12) | void new_stream(Stream stream) {
function check_error (line 18) | inline void check_error(MTL::CommandBuffer* cbuf) {
function eval (line 27) | void eval(array& arr) {
function finalize (line 74) | void finalize(Stream s) {
function synchronize (line 83) | void synchronize(Stream s) {
FILE: mlx/backend/metal/event.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/fence.cpp
type mlx::core (line 7) | namespace mlx::core {
type FenceImpl (line 9) | struct FenceImpl {
method FenceImpl (line 10) | FenceImpl() {
FILE: mlx/backend/metal/fft.cpp
type mlx::core (line 18) | namespace mlx::core {
function supported_radices (line 30) | inline const std::vector<int> supported_radices() {
function prime_factors (line 35) | std::vector<int> prime_factors(int n) {
type FourStepParams (line 52) | struct FourStepParams {
type FFTPlan (line 70) | struct FFTPlan {
function next_fast_n (line 85) | int next_fast_n(int n) {
function plan_stockham_fft (line 89) | std::vector<int> plan_stockham_fft(int n) {
function FFTPlan (line 113) | FFTPlan plan_fft(int n) {
function compute_elems_per_thread (line 174) | int compute_elems_per_thread(FFTPlan plan) {
function mod_exp (line 231) | int mod_exp(int x, int y, int n) {
function primitive_root (line 243) | int primitive_root(int n) {
function compute_raders_constants (line 261) | std::tuple<array, array, array> compute_raders_constants(
function compute_bluestein_constants (line 303) | std::pair<array, array> compute_bluestein_constants(int n, int blueste...
function multi_upload_bluestein_fft (line 349) | void multi_upload_bluestein_fft(
function four_step_fft (line 474) | void four_step_fft(
function fft_op (line 508) | void fft_op(
function fft_op (line 751) | void fft_op(
function nd_fft_op (line 762) | void nd_fft_op(
FILE: mlx/backend/metal/hadamard.cpp
type mlx::core (line 13) | namespace mlx::core {
function gen_hadamard_codelet (line 17) | std::string gen_hadamard_codelet(int m) {
function hadamard_mn_contiguous (line 60) | void hadamard_mn_contiguous(
FILE: mlx/backend/metal/indexing.cpp
type mlx::core (line 20) | namespace mlx::core {
function make_index_args (line 24) | std::pair<std::string, std::string> make_index_args(
function make_op (line 42) | inline std::string make_op(typename T::ReduceType r, const std::string...
FILE: mlx/backend/metal/jit/includes.h
function namespace (line 5) | namespace mlx::core::metal {
FILE: mlx/backend/metal/jit/indexing.h
function std (line 3) | constexpr std::string_view gather_kernels = R"(
function std (line 36) | constexpr std::string_view scatter_kernels = R"(
FILE: mlx/backend/metal/jit_kernels.cpp
type mlx::core (line 9) | namespace mlx::core {
function append_binary_kernels (line 54) | void append_binary_kernels(
FILE: mlx/backend/metal/kernels.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/backend/metal/kernels/arange.h
function arange (line 3) | [[kernel]] void arange(
FILE: mlx/backend/metal/kernels/atomic.h
function uint (line 162) | uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
function mlx_atomic_update_and_store (line 170) | void mlx_atomic_update_and_store(
function condition (line 193) | static bool condition(T a, T b) {
function T (line 199) | T operator()(T a, T b) {
function condition (line 207) | static bool condition(T a, T b) {
function T (line 213) | T operator()(T a, T b) {
function condition (line 220) | static bool condition(T a, T b) {
function T (line 225) | T operator()(T a, T b) {
function condition (line 232) | static bool condition(T a, T b) {
function T (line 236) | T operator()(T a, T b) {
function condition (line 243) | static bool condition(T a, T b) {
function T (line 247) | T operator()(T a, T b) {
FILE: mlx/backend/metal/kernels/bf16.h
type bfloat (line 9) | typedef bfloat bfloat16_t;
function bfloat16_to_uint16 (line 10) | inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
function bfloat16_t (line 14) | inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
FILE: mlx/backend/metal/kernels/bf16_math.h
function namespace (line 226) | namespace metal {
function namespace (line 370) | namespace metal {
FILE: mlx/backend/metal/kernels/binary.h
function binary_ss (line 4) | [[kernel]] void binary_ss(
function binary_sv (line 13) | void binary_sv(
function binary_vs (line 32) | void binary_vs(
function binary_vv (line 51) | void binary_vv(
function binary_sv2 (line 70) | void binary_sv2(
function binary_vs2 (line 90) | void binary_vs2(
function binary_vv2 (line 110) | void binary_vv2(
function binary_g_nd1 (line 130) | void binary_g_nd1(
function binary_g_nd2 (line 143) | void binary_g_nd2(
function binary_g_nd3 (line 158) | void binary_g_nd3(
FILE: mlx/backend/metal/kernels/binary_ops.h
type Add (line 10) | struct Add {
type FloorDivide (line 17) | struct FloorDivide {
type Divide (line 36) | struct Divide {
type Remainder (line 43) | struct Remainder {
type Equal (line 72) | struct Equal {
type NaNEqual (line 79) | struct NaNEqual {
type Greater (line 94) | struct Greater {
type GreaterEqual (line 101) | struct GreaterEqual {
type Less (line 108) | struct Less {
type LessEqual (line 115) | struct LessEqual {
type LogAddExp (line 122) | struct LogAddExp {
function complex64_t (line 136) | complex64_t operator()(complex64_t x, complex64_t y) {
type Maximum (line 155) | struct Maximum {
type Minimum (line 178) | struct Minimum {
type Multiply (line 201) | struct Multiply {
type NotEqual (line 208) | struct NotEqual {
type Power (line 219) | struct Power {
type Subtract (line 262) | struct Subtract {
type LogicalAnd (line 269) | struct LogicalAnd {
type LogicalOr (line 276) | struct LogicalOr {
type BitwiseAnd (line 283) | struct BitwiseAnd {
type BitwiseOr (line 290) | struct BitwiseOr {
type BitwiseXor (line 297) | struct BitwiseXor {
type LeftShift (line 304) | struct LeftShift {
type RightShift (line 311) | struct RightShift {
type ArcTan2 (line 318) | struct ArcTan2 {
type DivMod (line 325) | struct DivMod {
FILE: mlx/backend/metal/kernels/binary_two.h
function binary_ss (line 4) | [[kernel]] void binary_ss(
function binary_sv (line 16) | void binary_sv(
function binary_vs (line 40) | void binary_vs(
function binary_vv (line 64) | void binary_vv(
function binary_sv2 (line 88) | void binary_sv2(
function binary_vs2 (line 113) | void binary_vs2(
function binary_vv2 (line 138) | void binary_vv2(
function binary_g_nd1 (line 163) | void binary_g_nd1(
function binary_g_nd2 (line 179) | void binary_g_nd2(
function binary_g_nd3 (line 197) | void binary_g_nd3(
FILE: mlx/backend/metal/kernels/cexpf.h
function get_float_word (line 32) | inline void get_float_word(thread uint32_t& i, float d) {
function get_float_word (line 38) | inline void get_float_word(thread int32_t& i, float d) {
function set_float_word (line 44) | inline void set_float_word(thread float& d, uint32_t i) {
function frexp_expf (line 50) | inline float frexp_expf(float x, thread int* expt) {
FILE: mlx/backend/metal/kernels/complex.h
type complex64_t (line 9) | struct complex64_t
function complex64_t (line 145) | constexpr complex64_t operator*(complex64_t a, complex64_t b) {
FILE: mlx/backend/metal/kernels/copy.h
function copy_s (line 4) | void copy_s(
function copy_v (line 22) | void copy_v(
function copy_s2 (line 40) | void copy_s2(
function copy_v2 (line 59) | void copy_v2(
function copy_g_nd1 (line 78) | void copy_g_nd1(
function copy_g_nd2 (line 88) | void copy_g_nd2(
function copy_g_nd3 (line 100) | void copy_g_nd3(
function copy_g (line 113) | void copy_g(
function copy_gg_nd2 (line 151) | void copy_gg_nd2(
function copy_gg_nd3 (line 163) | void copy_gg_nd3(
function copy_gg (line 175) | void copy_gg(
function copy_gg_dynamic_nd2 (line 218) | void copy_gg_dynamic_nd2(
function copy_gg_dynamic_nd3 (line 232) | void copy_gg_dynamic_nd3(
FILE: mlx/backend/metal/kernels/erf.h
function erf (line 12) | float erf(float a) {
function erfinv (line 42) | float erfinv(float a) {
FILE: mlx/backend/metal/kernels/expm1f.h
function expm1f_scaled_unchecked (line 43) | float expm1f_scaled_unchecked(float a, float b) {
function expm1f (line 80) | float expm1f(float a) {
FILE: mlx/backend/metal/kernels/fft.h
function fft (line 180) | [[kernel]] void fft(
function rader_fft (line 219) | [[kernel]] void rader_fft(
function bluestein_fft (line 374) | [[kernel]] void bluestein_fft(
function four_step_fft (line 443) | void four_step_fft(
FILE: mlx/backend/metal/kernels/fft/radix.h
function METAL_FUNC (line 19) | METAL_FUNC float2 complex_mul(float2 a, float2 b) {
function METAL_FUNC (line 24) | METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
function METAL_FUNC (line 29) | METAL_FUNC float2 get_twiddle(int k, int p) {
function METAL_FUNC (line 36) | METAL_FUNC void radix2(thread float2* x, thread float2* y) {
function METAL_FUNC (line 41) | METAL_FUNC void radix3(thread float2* x, thread float2* y) {
function METAL_FUNC (line 56) | METAL_FUNC void radix4(thread float2* x, thread float2* y) {
function METAL_FUNC (line 69) | METAL_FUNC void radix5(thread float2* x, thread float2* y) {
function METAL_FUNC (line 96) | METAL_FUNC void radix6(thread float2* x, thread float2* y) {
function METAL_FUNC (line 122) | METAL_FUNC void radix7(thread float2* x, thread float2* y) {
function METAL_FUNC (line 151) | METAL_FUNC void radix8(thread float2* x, thread float2* y) {
function METAL_FUNC (line 201) | METAL_FUNC void radix11(thread float2* x, thread float2* y) {
function METAL_FUNC (line 290) | METAL_FUNC void radix13(thread float2* x, thread float2* y) {
FILE: mlx/backend/metal/kernels/fft/readwrite.h
function METAL_FUNC (line 77) | METAL_FUNC float2 post_in(float2 elem) const {
function METAL_FUNC (line 82) | METAL_FUNC float2 post_in(float elem) const {
function METAL_FUNC (line 86) | METAL_FUNC float2 pre_out(float2 elem) const {
function METAL_FUNC (line 90) | METAL_FUNC float2 pre_out(float2 elem, int length) const {
function METAL_FUNC (line 94) | METAL_FUNC bool out_of_bounds() const {
function METAL_FUNC (line 123) | METAL_FUNC void write() const {
function METAL_FUNC (line 146) | METAL_FUNC void load_padded(int length, const device float2* w_k) const {
function METAL_FUNC (line 163) | METAL_FUNC void write_padded(int length, const device float2* w_k) const {
function METAL_FUNC (line 180) | METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
function METAL_FUNC (line 202) | METAL_FUNC void load_strided(int stride, int overall_n) {
function METAL_FUNC (line 210) | METAL_FUNC void write_strided(int stride, int overall_n) {
function write_padded (line 505) | float>::write_padded(
FILE: mlx/backend/metal/kernels/fp4.h
type fp4_e2m1 (line 3) | struct fp4_e2m1 {
function operator (line 33) | operator float16_t() {
function operator (line 39) | operator float() {
function operator (line 43) | operator bfloat16_t() {
FILE: mlx/backend/metal/kernels/fp8.h
function else (line 3) | struct fp8_e4m3 {
function operator (line 32) | operator float16_t() {
function operator (line 40) | operator bfloat16_t() {
function operator (line 44) | operator float() {
function operator (line 69) | operator bfloat16_t() {
function operator (line 74) | operator float() {
FILE: mlx/backend/metal/kernels/fp_quantized.h
function get_pack_factor (line 21) | short get_pack_factor() {
function get_bytes_per_pack (line 26) | short get_bytes_per_pack() {
function U (line 53) | U operator()(uint8_t x) {
function load_vector (line 63) | inline void load_vector(const device T* x, thread U* x_thread) {
function load_vector_safe (line 71) | inline void load_vector_safe(const device T* x, thread U* x_thread, int ...
function U (line 82) | inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
function U (line 103) | inline U
function qouter (line 125) | inline void qouter(const thread uint8_t* w, U x, U scale, thread U* resu...
function load_safe (line 218) | void load_safe(short2 src_tile_dim) const {
function next (line 244) | void next() {
type U (line 283) | typedef float U;
type U (line 346) | typedef float U;
type U (line 408) | typedef float U;
type U (line 549) | typedef float U;
type vec_w (line 550) | typedef struct {
function fp_qmm_t_impl (line 635) | void fp_qmm_t_impl(
function fp_qmm_n_impl (line 759) | void fp_qmm_n_impl(
function fp_qmv_quad (line 972) | void fp_qmv_quad(
function fp_qmv_fast (line 1011) | void fp_qmv_fast(
function fp_qmv (line 1050) | void fp_qmv(
function fp_qvm (line 1089) | void fp_qvm(
function fp_qmm_t (line 1179) | void fp_qmm_t(
function fp_qmm_n (line 1233) | void fp_qmm_n(
function fp_gather_qmv_fast (line 1282) | void fp_gather_qmv_fast(
function fp_gather_qmv (line 1331) | void fp_gather_qmv(
function fp_gather_qvm (line 1380) | void fp_gather_qvm(
function fp_gather_qmm_t (line 1436) | void fp_gather_qmm_t(
function fp_gather_qmm_n (line 1499) | void fp_gather_qmm_n(
function fp_gather_qmm_rhs (line 1566) | void fp_gather_qmm_rhs(
function fp_quantize (line 1750) | void fp_quantize(
function fp_dequantize (line 1792) | void fp_dequantize(
function fp_quantize_dequantize (line 1824) | void fp_quantize_dequantize(
FILE: mlx/backend/metal/kernels/fp_quantized_nax.h
function get_pack_factor (line 21) | short get_pack_factor() {
function get_bytes_per_pack (line 26) | short get_bytes_per_pack() {
function U (line 53) | U operator()(uint8_t x) {
function load_safe (line 146) | void load_safe(short2 src_tile_dim) const {
function next (line 176) | void next() {
function fp_qmm_t_impl (line 199) | void fp_qmm_t_impl(
function fp_qmm_n_impl (line 343) | void fp_qmm_n_impl(
function fp_qmm_t_nax (line 550) | void fp_qmm_t_nax(
function fp_qmm_n_nax (line 606) | void fp_qmm_n_nax(
function fp_gather_qmm_t_nax (line 665) | void fp_gather_qmm_t_nax(
function fp_gather_qmm_n_nax (line 730) | void fp_gather_qmm_n_nax(
function fp_gather_qmm_rhs_nax (line 798) | void fp_gather_qmm_rhs_nax(
FILE: mlx/backend/metal/kernels/gemv_masked.h
type _NoMask (line 10) | struct _NoMask {
type nomask_t (line 27) | typedef struct _NoMask nomask_t;
function METAL_FUNC (line 33) | METAL_FUNC OutT apply(InT x) const {
function METAL_FUNC (line 125) | static METAL_FUNC void run(
function METAL_FUNC (line 404) | static METAL_FUNC void run(
FILE: mlx/backend/metal/kernels/hadamard.h
function hadamard_m (line 142) | void hadamard_m(
FILE: mlx/backend/metal/kernels/indexing/gather_axis.h
function gather_axis (line 6) | void gather_axis(
FILE: mlx/backend/metal/kernels/indexing/gather_front.h
function gather_front (line 8) | void gather_front(
FILE: mlx/backend/metal/kernels/indexing/masked_scatter.h
function masked_assign_impl (line 8) | [[kernel]] void masked_assign_impl(
FILE: mlx/backend/metal/kernels/indexing/scatter.h
function slice_update_op_impl (line 70) | void slice_update_op_impl(
FILE: mlx/backend/metal/kernels/indexing/scatter_axis.h
function scatter_axis (line 12) | [[kernel]] void scatter_axis(
FILE: mlx/backend/metal/kernels/logging.h
function namespace (line 8) | namespace mlx {
type os_log (line 15) | struct os_log {
FILE: mlx/backend/metal/kernels/logsumexp.h
function logsumexp (line 4) | void logsumexp(
function logsumexp_looped (line 76) | void logsumexp_looped(
FILE: mlx/backend/metal/kernels/quantized.h
function get_pack_factor (line 18) | short get_pack_factor() {
function get_bytes_per_pack (line 23) | short get_bytes_per_pack() {
function U (line 29) | inline U load_vector(const device T* x, thread U* x_thread) {
function U (line 108) | inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
function U (line 192) | inline U qdot(
function U (line 293) | inline U qdot_safe(
function qouter (line 395) | inline void
function dequantize (line 484) | inline void
function load_safe (line 641) | void load_safe(short2 src_tile_dim) const {
function next (line 671) | void next() {
type U (line 711) | typedef float U;
type U (line 772) | typedef float U;
type U (line 840) | typedef float U;
type U (line 1001) | typedef float U;
type vec_w (line 1002) | typedef struct {
function qmm_t_impl (line 1094) | void qmm_t_impl(
function qmm_n_impl (line 1220) | void qmm_n_impl(
function affine_qmv_quad (line 1443) | void affine_qmv_quad(
function affine_qmv_fast (line 1495) | void affine_qmv_fast(
function affine_qmv (line 1547) | void affine_qmv(
function affine_qvm (line 1599) | void affine_qvm(
function affine_qmm_t (line 1715) | void affine_qmm_t(
function affine_qmm_n (line 1773) | void affine_qmm_n(
function affine_gather_qmv_fast (line 1826) | void affine_gather_qmv_fast(
function affine_gather_qmv (line 1888) | void affine_gather_qmv(
function affine_gather_qvm (line 1950) | void affine_gather_qvm(
function affine_gather_qmm_t (line 2019) | void affine_gather_qmm_t(
function affine_gather_qmm_n (line 2086) | void affine_gather_qmm_n(
function affine_gather_qmm_rhs (line 2157) | void affine_gather_qmm_rhs(
function affine_quantize (line 2344) | void affine_quantize(
function affine_dequantize (line 2449) | void affine_dequantize(
FILE: mlx/backend/metal/kernels/quantized_nax.h
function get_pack_factor (line 21) | short get_pack_factor() {
function get_bytes_per_pack (line 26) | short get_bytes_per_pack() {
function U (line 32) | inline U load_vector(const device T* x, thread U* x_thread) {
function U (line 111) | inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
function U (line 195) | inline U qdot(
function U (line 296) | inline U qdot_safe(
function qouter (line 398) | inline void
function dequantize (line 487) | inline void
function load_safe (line 644) | void load_safe(short2 src_tile_dim) const {
function next (line 674) | void next() {
function load_safe (line 784) | void load_safe(short2 src_tile_dim) const {
function next (line 814) | void next() {
function qmm_t_nax_tgp_impl (line 938) | void qmm_t_nax_tgp_impl(
function qmm_n_nax_tgp_impl (line 1082) | void qmm_n_nax_tgp_impl(
function affine_qmm_t_nax (line 1205) | void affine_qmm_t_nax(
function affine_qmm_n_nax (line 1264) | void affine_qmm_n_nax(
function affine_gather_qmm_t_nax (line 1324) | void affine_gather_qmm_t_nax(
function affine_gather_qmm_n_nax (line 1392) | void affine_gather_qmm_n_nax(
function affine_gather_qmm_rhs_nax (line 1461) | void affine_gather_qmm_rhs_nax(
FILE: mlx/backend/metal/kernels/quantized_utils.h
function typename (line 77) | typename loader_b_t>
FILE: mlx/backend/metal/kernels/reduction/ops.h
type None (line 29) | struct None {
function simd_reduce_impl (line 40) | bool simd_reduce_impl(bool val) {
function update (line 67) | void update(device bool* out, bool val) {
function simd_reduce_impl (line 81) | bool simd_reduce_impl(bool val) {
function update (line 108) | void update(device bool* out, bool val) {
function U (line 135) | U operator()(U a, U b) {
function U (line 157) | U operator()(U a, U b) {
FILE: mlx/backend/metal/kernels/reduction/reduce_all.h
function all_reduce (line 9) | void all_reduce(
FILE: mlx/backend/metal/kernels/reduction/reduce_col.h
function col_reduce_small (line 4) | void col_reduce_small(
function col_reduce_longcolumn (line 97) | void col_reduce_longcolumn(
function col_reduce_looped (line 163) | void col_reduce_looped(
function col_reduce_2pass (line 302) | void col_reduce_2pass(
FILE: mlx/backend/metal/kernels/reduction/reduce_init.h
function init_reduce (line 4) | [[kernel]] void init_reduce(
FILE: mlx/backend/metal/kernels/reduction/reduce_row.h
function per_thread_row_reduce (line 19) | void per_thread_row_reduce(
function per_thread_row_reduce (line 70) | void per_thread_row_reduce(
function per_thread_row_reduce (line 98) | void per_thread_row_reduce(
function threadgroup_reduce (line 129) | void threadgroup_reduce(
function thread_reduce (line 165) | void
function row_reduce_small (line 199) | void row_reduce_small(
FILE: mlx/backend/metal/kernels/scan.h
function U (line 45) | U simd_scan_impl(U x) {
function U (line 49) | U simd_exclusive_scan_impl(U x) {
function U (line 66) | U simd_scan_impl(U x) {
function U (line 70) | U simd_exclusive_scan_impl(U x) {
function bool (line 76) | struct CumProd<bool> {
function U (line 107) | U simd_scan(U x) {
function U (line 115) | U simd_exclusive_scan(U x) {
function U (line 130) | U simd_scan(U x) {
function U (line 138) | U simd_exclusive_scan(U x) {
function U (line 153) | U simd_scan(U x) {
function U (line 161) | U simd_exclusive_scan(U x) {
function load_unsafe (line 168) | inline void load_unsafe(U values[N_READS], const device T* input) {
function load_safe (line 181) | inline void load_safe(
function write_unsafe (line 200) | inline void write_unsafe(U values[N_READS], device U* out) {
function write_safe (line 213) | inline void write_safe(U values[N_READS], device U* out, int start, int ...
function contiguous_scan (line 236) | void contiguous_scan(
function strided_scan (line 392) | void strided_scan(
FILE: mlx/backend/metal/kernels/sdpa_vector.h
function sdpa_vector (line 16) | void sdpa_vector(
function sdpa_vector_2pass_1 (line 180) | void sdpa_vector_2pass_1(
function sdpa_vector_2pass_2 (line 321) | [[kernel]] void sdpa_vector_2pass_2(
FILE: mlx/backend/metal/kernels/softmax.h
function softmax_single_row (line 11) | void softmax_single_row(
function softmax_looped (line 101) | void softmax_looped(
FILE: mlx/backend/metal/kernels/sort.h
function METAL_FUNC (line 35) | METAL_FUNC bool operator()(T a, T b) const {
type ThreadSort (line 54) | struct ThreadSort {
function METAL_FUNC (line 88) | static METAL_FUNC int merge_partition(
function METAL_FUNC (line 114) | static METAL_FUNC void merge_step(
function METAL_FUNC (line 146) | static METAL_FUNC void sort(
function METAL_FUNC (line 258) | static METAL_FUNC void block_sort(
function block_sort (line 307) | void block_sort(
function block_sort_nc (line 362) | void block_sort_nc(
function METAL_FUNC (line 434) | static METAL_FUNC void block_sort(
function METAL_FUNC (line 472) | static METAL_FUNC int merge_partition(
function mb_block_sort (line 505) | void mb_block_sort(
function mb_block_partition (line 549) | void mb_block_partition(
FILE: mlx/backend/metal/kernels/steel/attn/attn.h
function namespace (line 18) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
type MaxOp (line 18) | struct MaxOp {
type SumOp (line 25) | struct SumOp {
type MulOp (line 32) | struct MulOp {
type SubOp (line 39) | struct SubOp {
type ExpSubOp (line 46) | struct ExpSubOp {
type DivOp (line 53) | struct DivOp {
function ulong3 (line 88) | ulong3 tidl{tid.x, tid.y, tid.z};
FILE: mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h
function METAL_FUNC (line 24) | METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
function METAL_FUNC (line 26) | METAL_FUNC T apply(T x) const {
type MaxOp (line 31) | struct MaxOp {
type SumOp (line 38) | struct SumOp {
type MulOp (line 45) | struct MulOp {
type SubOp (line 52) | struct SubOp {
type ExpSubOp (line 59) | struct ExpSubOp {
type DivOp (line 66) | struct DivOp {
function ulong3 (line 102) | ulong3 tidl{tid.x, tid.y, tid.z};
FILE: mlx/backend/metal/kernels/steel/attn/loader.h
function namespace (line 11) | namespace mlx {
function METAL_FUNC (line 199) | METAL_FUNC void load_unsafe() const {
function METAL_FUNC (line 210) | METAL_FUNC void load_safe(short2 src_tile_dim) const {
function METAL_FUNC (line 258) | METAL_FUNC void next() {
FILE: mlx/backend/metal/kernels/steel/attn/mma.h
function namespace (line 19) | namespace mlx {
function METAL_FUNC (line 513) | METAL_FUNC BlockMMA(
function METAL_FUNC (line 533) | METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
function METAL_FUNC (line 560) | METAL_FUNC void store_result(device U* D, const int ldd) {
function METAL_FUNC (line 573) | METAL_FUNC void
function METAL_FUNC (line 675) | METAL_FUNC void store_result(
function METAL_FUNC (line 707) | METAL_FUNC void store_result_safe(
FILE: mlx/backend/metal/kernels/steel/attn/nax.h
function namespace (line 20) | namespace mlx {
function thread (line 356) | thread T* reduced_vals) {
function thread (line 376) | thread T* row_vals) {
type typename (line 560) | typedef typename NAXFrag_t::template dtype_frag_t<T> frag_type;
function METAL_FUNC (line 564) | METAL_FUNC NAXTile() thread {}
function METAL_FUNC (line 566) | METAL_FUNC constexpr void clear() {
function ta (line 844) | constexpr auto ta = metal::bool_constant<transpose_a>{}
function tb (line 845) | constexpr auto tb = metal::bool_constant<transpose_b>{}
FILE: mlx/backend/metal/kernels/steel/attn/params.h
function namespace (line 9) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/attn/transforms.h
function namespace (line 11) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
function implicit_gemm_conv_2d_general (line 16) | void
FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
function namespace (line 13) | namespace mlx {
function METAL_FUNC (line 290) | METAL_FUNC void load_unsafe() const {
function METAL_FUNC (line 315) | METAL_FUNC void next() {
function METAL_FUNC (line 438) | METAL_FUNC void next() {
function METAL_FUNC (line 562) | METAL_FUNC void load_unsafe() const {
function METAL_FUNC (line 591) | METAL_FUNC void next() {
function METAL_FUNC (line 780) | METAL_FUNC void load_unsafe() const {
function METAL_FUNC (line 807) | METAL_FUNC void next() {
function METAL_FUNC (line 941) | METAL_FUNC void next() {
FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
function namespace (line 13) | namespace mlx {
function METAL_FUNC (line 190) | METAL_FUNC void next() {
function METAL_FUNC (line 313) | METAL_FUNC void next() {
FILE: mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
function namespace (line 11) | namespace mlx {
function METAL_FUNC (line 311) | METAL_FUNC void load_safe(const short remaining_k) const {
function METAL_FUNC (line 361) | METAL_FUNC void next() {
FILE: mlx/backend/metal/kernels/steel/conv/params.h
function MLXConvParams (line 23) | static MLXConvParams<NDIM>
function namespace (line 48) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/gemm/gemm.h
function namespace (line 17) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/gemm/gemm_nax.h
function namespace (line 12) | namespace mlx::steel {
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h
type _NoMask (line 11) | struct _NoMask {
function METAL_FUNC (line 32) | METAL_FUNC OutT apply(InT x) const {
type nomask_t (line 37) | typedef struct _NoMask nomask_t;
function block_masked_gemm (line 52) | void
FILE: mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h
function gemm_splitk (line 21) | void gemm_splitk(
function gemm_splitk_accum (line 172) | [[kernel]] void gemm_splitk_accum(
function gemm_splitk_accum_axpby (line 199) | [[kernel]] void gemm_splitk_accum_axpby(
FILE: mlx/backend/metal/kernels/steel/gemm/loader.h
function namespace (line 11) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/gemm/mma.h
function namespace (line 19) | namespace mlx {
function METAL_FUNC (line 181) | METAL_FUNC static constexpr void mma(
function METAL_FUNC (line 200) | METAL_FUNC static constexpr void mma(
type typename (line 230) | typedef typename MMAFrag_t::mat_type mat_type;
type typename (line 231) | typedef typename MMAFrag_t::frag_type frag_type;
function METAL_FUNC (line 235) | METAL_FUNC MMATile() thread {}
function METAL_FUNC (line 237) | METAL_FUNC constexpr void clear() {
function METAL_FUNC (line 254) | METAL_FUNC mat_type mat_at(const short i, const short j) {
function elem_type (line 263) | elem_type* elems() {
function elem_type (line 267) | elem_type* elems() const {
function METAL_FUNC (line 426) | static METAL_FUNC complex64_t apply(complex64_t x) {
function METAL_FUNC (line 429) | static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
function METAL_FUNC (line 488) | METAL_FUNC BlockMMA(
function METAL_FUNC (line 508) | METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
function METAL_FUNC (line 535) | METAL_FUNC void store_result(device U* D, const int ldd) {
function METAL_FUNC (line 548) | METAL_FUNC void
function METAL_FUNC (line 568) | METAL_FUNC void
function METAL_FUNC (line 670) | METAL_FUNC void store_result(
function METAL_FUNC (line 702) | METAL_FUNC void store_result_safe(
function METAL_FUNC (line 820) | METAL_FUNC BlockMMA(
function METAL_FUNC (line 840) | METAL_FUNC void mma(
function METAL_FUNC (line 900) | METAL_FUNC void store_result(device U* D, const int ldd) {
function METAL_FUNC (line 919) | METAL_FUNC void
function METAL_FUNC (line 950) | METAL_FUNC void
function METAL_FUNC (line 1068) | METAL_FUNC void store_result(
function METAL_FUNC (line 1102) | METAL_FUNC void store_result_safe(
FILE: mlx/backend/metal/kernels/steel/gemm/nax.h
function namespace (line 20) | namespace mlx {
function thread (line 356) | thread T* reduced_vals) {
function thread (line 376) | thread T* row_vals) {
type typename (line 560) | typedef typename NAXFrag_t::template dtype_frag_t<T> frag_type;
function METAL_FUNC (line 564) | METAL_FUNC NAXTile() thread {}
function METAL_FUNC (line 566) | METAL_FUNC constexpr void clear() {
function ta (line 844) | constexpr auto ta = metal::bool_constant<transpose_a>{}
function tb (line 845) | constexpr auto tb = metal::bool_constant<transpose_b>{}
FILE: mlx/backend/metal/kernels/steel/gemm/params.h
function namespace (line 9) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/gemm/transforms.h
function namespace (line 11) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/utils.h
function METAL_FUNC (line 7) | METAL_FUNC ulong2 elem_to_loc_broadcast(
function METAL_FUNC (line 24) | METAL_FUNC ulong3 elem_to_loc_broadcast(
FILE: mlx/backend/metal/kernels/steel/utils/integral_constant.h
function namespace (line 10) | namespace mlx {
FILE: mlx/backend/metal/kernels/steel/utils/type_traits.h
function namespace (line 9) | namespace metal {
FILE: mlx/backend/metal/kernels/ternary.h
function ternary_v (line 9) | void ternary_v(
function ternary_v2 (line 38) | void ternary_v2(
function ternary_g_nd1 (line 63) | void ternary_g_nd1(
function ternary_g_nd2 (line 79) | void ternary_g_nd2(
function ternary_g_nd3 (line 97) | void ternary_g_nd3(
FILE: mlx/backend/metal/kernels/ternary_ops.h
type Select (line 5) | struct Select {
FILE: mlx/backend/metal/kernels/unary.h
function unary_v (line 4) | void unary_v(
function unary_v2 (line 22) | void unary_v2(
FILE: mlx/backend/metal/kernels/unary_ops.h
type Abs (line 17) | struct Abs {
function complex64_t (line 37) | complex64_t operator()(complex64_t x) {
type ArcCos (line 42) | struct ArcCos {
type ArcCosh (line 51) | struct ArcCosh {
type ArcSin (line 58) | struct ArcSin {
type ArcSinh (line 67) | struct ArcSinh {
type ArcTan (line 74) | struct ArcTan {
type ArcTanh (line 83) | struct ArcTanh {
type BitwiseInvert (line 90) | struct BitwiseInvert {
type Ceil (line 97) | struct Ceil {
type Cos (line 131) | struct Cos {
function complex64_t (line 137) | complex64_t operator()(complex64_t x) {
type Cosh (line 144) | struct Cosh {
function complex64_t (line 150) | complex64_t operator()(complex64_t x) {
type Conjugate (line 157) | struct Conjugate {
type Erf (line 163) | struct Erf {
type ErfInv (line 170) | struct ErfInv {
type Exp (line 177) | struct Exp {
function complex64_t (line 182) | complex64_t operator()(complex64_t x) {
type Expm1 (line 187) | struct Expm1 {
type Floor (line 194) | struct Floor {
type Imag (line 228) | struct Imag {
type Log (line 234) | struct Log {
function complex64_t (line 240) | complex64_t operator()(complex64_t x) {
type Log2 (line 247) | struct Log2 {
function complex64_t (line 253) | complex64_t operator()(complex64_t x) {
type Log10 (line 259) | struct Log10 {
function complex64_t (line 265) | complex64_t operator()(complex64_t x) {
type Log1p (line 271) | struct Log1p {
type LogicalNot (line 278) | struct LogicalNot {
type Negative (line 285) | struct Negative {
type Real (line 292) | struct Real {
type Round (line 298) | struct Round {
function complex64_t (line 303) | complex64_t operator()(complex64_t x) {
type Sigmoid (line 308) | struct Sigmoid {
type Sign (line 316) | struct Sign {
function complex64_t (line 324) | complex64_t operator()(complex64_t x) {
type Sin (line 333) | struct Sin {
function complex64_t (line 339) | complex64_t operator()(complex64_t x) {
type Sinh (line 346) | struct Sinh {
function complex64_t (line 352) | complex64_t operator()(complex64_t x) {
type Square (line 359) | struct Square {
type Sqrt (line 366) | struct Sqrt {
function complex64_t (line 372) | complex64_t operator()(complex64_t x) {
type Rsqrt (line 384) | struct Rsqrt {
function complex64_t (line 390) | complex64_t operator()(complex64_t x) {
type Tan (line 395) | struct Tan {
function complex64_t (line 401) | complex64_t operator()(complex64_t x) {
type Tanh (line 410) | struct Tanh {
function complex64_t (line 416) | complex64_t operator()(complex64_t x) {
function i (line 426) | auto i = complex64_t{0.0, 1.0};
function i (line 432) | auto i = complex64_t{0.0, 1.0};
function i (line 438) | auto i = complex64_t{0.0, 1.0};
type ToFP8 (line 443) | struct ToFP8 {
type FromFP8 (line 450) | struct FromFP8 {
FILE: mlx/backend/metal/kernels/utils.h
type half (line 13) | typedef half float16_t;
function bool (line 73) | struct Limits<bool> {
function complex64_t (line 79) | struct Limits<complex64_t> {
function OffsetT (line 205) | OffsetT offset{0}
function index (line 206) | int index{0}
function next (line 210) | void next(const constant int* shape, const constant int64_t* strides) {
function next (line 223) | void next(int n, const constant int* shape, const constant int64_t* stri...
function OffsetT (line 246) | OffsetT location() {
function OffsetT (line 254) | OffsetT offset{0};
function T (line 307) | T ceildiv(T N, U M) {
function log1p (line 312) | inline float log1p(float x) {
function bfloat16_t (line 324) | inline bfloat16_t log1p(bfloat16_t x) {
function complex64_t (line 336) | inline complex64_t log1p(complex64_t in) {
FILE: mlx/backend/metal/logsumexp.cpp
type mlx::core (line 10) | namespace mlx::core {
FILE: mlx/backend/metal/matmul.cpp
type mlx::core (line 21) | namespace mlx::core {
function check_transpose (line 25) | std::tuple<bool, int64_t, array> check_transpose(
function array (line 43) | inline array
function ensure_batch_contiguous (line 54) | inline std::tuple<bool, int64_t, array>
function steel_matmul_regular_axpby_nax (line 176) | void steel_matmul_regular_axpby_nax(
function steel_matmul_regular_axpby (line 341) | void steel_matmul_regular_axpby(
function steel_gemm_splitk_axpby (line 530) | void steel_gemm_splitk_axpby(
function steel_gemm_splitk_axpby_nax (line 687) | void steel_gemm_splitk_axpby_nax(
function steel_matmul_axpby (line 859) | void steel_matmul_axpby(
function gemv_axbpy (line 1033) | void gemv_axbpy(
function gemv (line 1172) | inline void gemv(
function gather_mm_rhs (line 1844) | void gather_mm_rhs(
function gather_mm_rhs_nax (line 1977) | void gather_mm_rhs_nax(
function gather_mv (line 2120) | void gather_mv(
function gather_mm (line 2237) | void gather_mm(
function segmented_mm (line 2424) | void segmented_mm(
FILE: mlx/backend/metal/matmul.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/metal.cpp
type mlx::core::metal (line 8) | namespace mlx::core::metal {
function is_available (line 10) | bool is_available() {
function start_capture (line 14) | void start_capture(std::string path, NS::Object* object) {
function start_capture (line 39) | void start_capture(std::string path) {
function stop_capture (line 44) | void stop_capture() {
FILE: mlx/backend/metal/metal.h
function namespace (line 11) | namespace mlx::core::metal {
FILE: mlx/backend/metal/no_metal.cpp
type mlx::core (line 8) | namespace mlx::core {
type metal (line 10) | namespace metal {
function is_available (line 12) | bool is_available() {
function start_capture (line 16) | void start_capture(std::string) {}
function stop_capture (line 17) | void stop_capture() {}
type fast (line 27) | namespace fast {
function CustomKernelFunction (line 29) | CustomKernelFunction metal_kernel(
FILE: mlx/backend/metal/nojit_kernels.cpp
type mlx::core (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/normalization.cpp
type mlx::core::fast (line 11) | namespace mlx::core::fast {
FILE: mlx/backend/metal/primitives.cpp
type mlx::core (line 18) | namespace mlx::core {
function arange_set_scalars (line 21) | void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
FILE: mlx/backend/metal/quantized.cpp
type mlx::core (line 15) | namespace mlx::core {
function get_quantized_kernel_wrapped (line 20) | auto get_quantized_kernel_wrapped(
function get_qmm_nax_kernel_wrapped (line 37) | auto get_qmm_nax_kernel_wrapped(
function array (line 53) | inline array
function array (line 64) | inline array ensure_row_contiguous_matrix(
function get_qmv_batch_limit (line 84) | inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
function add_strides_and_shapes (line 128) | inline int add_strides_and_shapes(
function add_gather_strides_and_shapes (line 158) | inline int add_gather_strides_and_shapes(
function qmv_quad (line 177) | void qmv_quad(
function qmv (line 235) | void qmv(
function qvm_split_k (line 298) | void qvm_split_k(
function qvm (line 418) | void qvm(
function qmm_nax (line 472) | void qmm_nax(
function gather_qmm_nax (line 575) | void gather_qmm_nax(
function qmm (line 679) | void qmm(
function gather_qmm (line 773) | void gather_qmm(
function gather_qmv (line 864) | void gather_qmv(
function gather_qvm (line 930) | void gather_qvm(
function gather_qmm_rhs_nax (line 988) | void gather_qmm_rhs_nax(
function gather_qmm_rhs (line 1119) | void gather_qmm_rhs(
function dispatch_qmv (line 1269) | void dispatch_qmv(
function quantize_dequantize (line 1463) | void quantize_dequantize(
FILE: mlx/backend/metal/reduce.cpp
type mlx::core (line 15) | namespace mlx::core {
type RowReduceArgs (line 19) | struct RowReduceArgs {
method RowReduceArgs (line 36) | RowReduceArgs(
method encode (line 58) | void encode(CommandEncoder& compute_encoder) {
type ColReduceArgs (line 89) | struct ColReduceArgs {
method ColReduceArgs (line 107) | ColReduceArgs(
method ColReduceArgs (line 144) | ColReduceArgs(const array& intermediate) {
method encode (line 154) | void encode(CommandEncoder& compute_encoder) {
function safe_div (line 188) | inline auto safe_div(size_t n, size_t m) {
function safe_divup (line 192) | inline auto safe_divup(size_t n, size_t m) {
function is_64b_int (line 196) | inline bool is_64b_int(Dtype dtype) {
function is_64b_dtype (line 200) | inline bool is_64b_dtype(Dtype dtype) {
function get_kernel_reduce_ndim (line 204) | inline int get_kernel_reduce_ndim(int reduce_ndim) {
function threadgroup_size_from_row_size (line 214) | inline int threadgroup_size_from_row_size(int row_size) {
function output_grid_for_col_reduce (line 233) | inline auto output_grid_for_col_reduce(
function remap_reduce_types (line 245) | std::pair<Dtype, Dtype> remap_reduce_types(
function init_reduce (line 289) | void init_reduce(
function all_reduce_dispatch (line 312) | void all_reduce_dispatch(
function row_reduce_small (line 393) | void row_reduce_small(
function row_reduce_simple (line 449) | void row_reduce_simple(
function row_reduce_looped (line 489) | void row_reduce_looped(
function row_reduce_general_dispatch (line 539) | void row_reduce_general_dispatch(
function strided_reduce_small (line 566) | void strided_reduce_small(
function strided_reduce_longcolumn (line 632) | void strided_reduce_longcolumn(
function strided_reduce_looped (line 743) | void strided_reduce_looped(
function strided_reduce_2pass (line 808) | void strided_reduce_2pass(
function strided_reduce_general_dispatch (line 919) | void strided_reduce_general_dispatch(
FILE: mlx/backend/metal/reduce.h
function namespace (line 9) | namespace mlx::core {
FILE: mlx/backend/metal/resident.cpp
type mlx::core::metal (line 5) | namespace mlx::core::metal {
FILE: mlx/backend/metal/resident.h
function namespace (line 7) | namespace mlx::core::metal {
FILE: mlx/backend/metal/rope.cpp
type mlx::core::fast (line 6) | namespace mlx::core::fast {
FILE: mlx/backend/metal/scaled_dot_product_attention.cpp
type mlx::core::fast (line 14) | namespace mlx::core::fast {
function sdpa_full_self_attention_nax (line 18) | void sdpa_full_self_attention_nax(
function sdpa_full_self_attention_metal (line 166) | void sdpa_full_self_attention_metal(
function sdpa_vector (line 329) | void sdpa_vector(
function sdpa_vector_2pass (line 418) | void sdpa_vector_2pass(
FILE: mlx/backend/metal/scan.cpp
type mlx::core (line 13) | namespace mlx::core {
function scan_gpu_inplace (line 15) | void scan_gpu_inplace(
FILE: mlx/backend/metal/slicing.cpp
type mlx::core (line 12) | namespace mlx::core {
function concatenate_gpu (line 14) | void concatenate_gpu(
function array (line 45) | array compute_dynamic_offset(
FILE: mlx/backend/metal/softmax.cpp
type mlx::core (line 11) | namespace mlx::core {
FILE: mlx/backend/metal/sort.cpp
type mlx::core (line 11) | namespace mlx::core {
function single_block_sort (line 15) | void single_block_sort(
function multi_block_sort (line 114) | void multi_block_sort(
function gpu_merge_sort (line 274) | void gpu_merge_sort(
FILE: mlx/backend/metal/ternary.cpp
type mlx::core (line 9) | namespace mlx::core {
function ternary_op_gpu_inplace (line 11) | void ternary_op_gpu_inplace(
function ternary_op_gpu (line 135) | void ternary_op_gpu(
function ternary_op_gpu (line 148) | void ternary_op_gpu(
FILE: mlx/backend/metal/ternary.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/unary.cpp
type mlx::core (line 14) | namespace mlx::core {
function unary_op_gpu_inplace (line 16) | void unary_op_gpu_inplace(
function unary_op_gpu (line 98) | void unary_op_gpu(
function unary_op_gpu (line 107) | void unary_op_gpu(
FILE: mlx/backend/metal/unary.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/backend/metal/utils.cpp
type mlx::core (line 6) | namespace mlx::core {
function type_to_name (line 8) | std::string type_to_name(const Dtype& t) {
function type_to_name (line 57) | std::string type_to_name(const array& a) {
function get_block_dims (line 61) | MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {
function get_2d_grid_dims (line 66) | MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
function get_2d_grid_dims (line 71) | MTL::Size
FILE: mlx/backend/metal/utils.h
function namespace (line 11) | namespace mlx::core {
FILE: mlx/backend/no_cpu/compiled.cpp
type mlx::core (line 6) | namespace mlx::core {
type detail (line 11) | namespace detail {
function compile_available_for_device (line 12) | bool compile_available_for_device(const Device& device) {
FILE: mlx/backend/no_cpu/device_info.cpp
type mlx::core::cpu (line 5) | namespace mlx::core::cpu {
function is_available (line 7) | bool is_available() {
function device_count (line 11) | int device_count() {
FILE: mlx/backend/no_cpu/primitives.cpp
type mlx::core (line 18) | namespace mlx::core {
type fast (line 133) | namespace fast {
type distributed (line 138) | namespace distributed {
FILE: mlx/backend/no_gpu/allocator.cpp
function get_memory_size (line 14) | size_t get_memory_size() {
type mlx::core (line 19) | namespace mlx::core {
type allocator (line 21) | namespace allocator {
class CommonAllocator (line 23) | class CommonAllocator : public Allocator {
method get_active_memory (line 30) | size_t get_active_memory() const {
method get_peak_memory (line 33) | size_t get_peak_memory() const {
method reset_peak_memory (line 36) | void reset_peak_memory() {
method get_memory_limit (line 40) | size_t get_memory_limit() {
method set_memory_limit (line 43) | size_t set_memory_limit(size_t limit) {
method CommonAllocator (line 54) | CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
function CommonAllocator (line 63) | CommonAllocator& common_allocator() {
method get_active_memory (line 30) | size_t get_active_memory() const {
method get_peak_memory (line 33) | size_t get_peak_memory() const {
method reset_peak_memory (line 36) | void reset_peak_memory() {
method get_memory_limit (line 40) | size_t get_memory_limit() {
method set_memory_limit (line 43) | size_t set_memory_limit(size_t limit) {
method CommonAllocator (line 54) | CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
function Allocator (line 68) | Allocator& allocator() {
function Buffer (line 79) | Buffer CommonAllocator::malloc(size_t size) {
function get_active_memory (line 106) | size_t get_active_memory() {
function get_peak_memory (line 109) | size_t get_peak_memory() {
function reset_peak_memory (line 112) | void reset_peak_memory() {
function set_memory_limit (line 115) | size_t set_memory_limit(size_t limit) {
function get_memory_limit (line 118) | size_t get_memory_limit() {
function get_cache_memory (line 123) | size_t get_cache_memory() {
function set_cache_limit (line 126) | size_t set_cache_limit(size_t) {
function set_wired_limit (line 129) | size_t set_wired_limit(size_t) {
function clear_cache (line 132) | void clear_cache() {}
FILE: mlx/backend/no_gpu/apple_memory.h
function get_memory_size (line 9) | size_t get_memory_size() {
FILE: mlx/backend/no_gpu/device_info.cpp
type mlx::core::gpu (line 5) | namespace mlx::core::gpu {
function is_available (line 7) | bool is_available() {
function device_count (line 11) | int device_count() {
FILE: mlx/backend/no_gpu/eval.cpp
type mlx::core::gpu (line 8) | namespace mlx::core::gpu {
function new_stream (line 10) | void new_stream(Stream) {}
function eval (line 12) | void eval(array&) {
function finalize (line 16) | void finalize(Stream) {
function synchronize (line 20) | void synchronize(Stream) {
FILE: mlx/backend/no_gpu/event.cpp
type mlx::core (line 9) | namespace mlx::core {
type EventCounter (line 11) | struct EventCounter {
FILE: mlx/backend/no_gpu/fence.cpp
type mlx::core (line 9) | namespace mlx::core {
type FenceImpl (line 11) | struct FenceImpl {
FILE: mlx/backend/no_gpu/linux_memory.h
function get_memory_size (line 9) | size_t get_memory_size() {
FILE: mlx/backend/no_gpu/primitives.cpp
type mlx::core (line 24) | namespace mlx::core {
type fast (line 164) | namespace fast {
type distributed (line 177) | namespace distributed {
FILE: mlx/compile.cpp
type mlx::core (line 21) | namespace mlx::core {
function is_unary (line 26) | bool is_unary(const Primitive& p) {
function is_binary (line 47) | bool is_binary(const Primitive& p) {
function is_ternary (line 61) | bool is_ternary(const Primitive& p) {
function is_broadcast (line 65) | bool is_broadcast(const Primitive& p) {
function is_noop (line 69) | bool is_noop(const Primitive& p) {
function is_reduction (line 73) | bool is_reduction(const Primitive& p) {
function is_fusable (line 77) | bool is_fusable(const Primitive& p) {
type detail (line 214) | namespace detail {
function merge_one (line 230) | void merge_one(array& dst, array& src, ParentsMap& parents_map) {
function merge (line 256) | void merge(array& dst, array& src, ParentsMap& parents_map) {
function array (line 268) | array split_one(
function get_function_address (line 291) | std::uintptr_t get_function_address(const std::function<T(U...)>& fu...
class CompilerCache (line 300) | class CompilerCache {
type CacheEntry (line 302) | struct CacheEntry {
method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless)
method CacheEntry (line 317) | CacheEntry& find(
method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless)
method erase (line 369) | void erase(std::uintptr_t fun_id) {
method clear (line 373) | void clear() {
method CompilerCache (line 378) | CompilerCache() {
function CompilerCache (line 388) | CompilerCache& compiler_cache() {
type CacheEntry (line 302) | struct CacheEntry {
method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless)
method CacheEntry (line 317) | CacheEntry& find(
method CacheEntry (line 303) | CacheEntry(Stream stream, bool shapeless)
method erase (line 369) | void erase(std::uintptr_t fun_id) {
method clear (line 373) | void clear() {
method CompilerCache (line 378) | CompilerCache() {
function compile_trace (line 393) | std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<v...
function compile_dfs (line 415) | std::pair<std::vector<array>, ParentsMap> compile_dfs(
function splitmix64 (line 544) | static inline uint64_t splitmix64(uint64_t x) noexcept {
type VecU64Hash (line 551) | struct VecU64Hash {
function compile_simplify (line 564) | void compile_simplify(
function compile_fuse (line 779) | void compile_fuse(
function compile_replace (line 1021) | std::vector<array> compile_replace(
function skip_compile (line 1088) | bool skip_compile() {
function ArrayFnWithExtra (line 1093) | ArrayFnWithExtra compile(
function compile (line 1160) | std::function<std::vector<array>(const std::vector<array>&)> compile(
function compile_erase (line 1187) | void compile_erase(std::uintptr_t fun_id) {
function compile_clear_cache (line 1191) | void compile_clear_cache() {
function compile (line 1197) | std::function<std::vector<array>(const std::vector<array>&)> compile(
function compile (line 1226) | std::function<std::vector<array>(const std::vector<array>&)> compile(
function disable_compile (line 1235) | void disable_compile() {
function enable_compile (line 1239) | void enable_compile() {
function set_compile_mode (line 1243) | void set_compile_mode(CompileMode mode) {
FILE: mlx/compile.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/compile_impl.h
function namespace (line 10) | namespace mlx::core::detail {
FILE: mlx/device.cpp
type mlx::core (line 9) | namespace mlx::core {
function Device (line 11) | Device& mutable_default_device() {
function Device (line 16) | const Device& default_device() {
function set_default_device (line 20) | void set_default_device(const Device& d) {
function is_available (line 36) | bool is_available(const Device& d) {
function device_count (line 47) | int device_count(Device::DeviceType type) {
FILE: mlx/device.h
function namespace (line 11) | namespace mlx::core {
FILE: mlx/distributed/distributed.cpp
type mlx::core::distributed (line 13) | namespace mlx::core::distributed {
type detail (line 15) | namespace detail {
function Stream (line 17) | Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
function all_sum (line 21) | void all_sum(Group group, const array& input, array& output, Stream ...
function all_max (line 25) | void all_max(Group group, const array& input, array& output, Stream ...
function all_min (line 29) | void all_min(Group group, const array& input, array& output, Stream ...
function all_gather (line 33) | void all_gather(Group group, const array& input, array& output, Stre...
function send (line 37) | void send(Group group, const array& input, int dst, Stream stream) {
function recv (line 41) | void recv(Group group, array& out, int src, Stream stream) {
function sum_scatter (line 45) | void sum_scatter(
class EmptyGroup (line 53) | class EmptyGroup : public GroupImpl {
method Stream (line 55) | Stream communication_stream(StreamOrDevice s) override {
method rank (line 59) | int rank() override {
method size (line 63) | int size() override {
method split (line 67) | std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
method all_sum (line 71) | void all_sum(const array&, array&, Stream) override {
method all_gather (line 75) | void all_gather(const array&, array&, Stream) override {
method send (line 79) | void send(const array&, int, Stream) override {
method recv (line 83) | void recv(array&, int, Stream) override {
method all_max (line 88) | void all_max(const array&, array&, Stream) override {
method all_min (line 93) | void all_min(const array&, array&, Stream) override {
method sum_scatter (line 97) | void sum_scatter(const array&, array&, Stream) override {
function is_available (line 105) | bool is_available() {
function is_available (line 110) | bool is_available(const std::string& bk) {
function Group (line 137) | Group Group::split(int color, int key /* = -1 */) const {
function Group (line 141) | Group init(bool strict /* = false */, const std::string& bk /* = "any"...
FILE: mlx/distributed/distributed.h
function namespace (line 11) | namespace mlx::core::distributed {
FILE: mlx/distributed/distributed_impl.h
function namespace (line 7) | namespace mlx::core::distributed::detail {
FILE: mlx/distributed/jaccl/jaccl.cpp
type DeviceFile (line 18) | struct DeviceFile {
method DeviceFile (line 19) | DeviceFile(const char* dev_file) {
method size (line 60) | int size() {
method is_valid_mesh (line 64) | bool is_valid_mesh() {
method is_valid_ring (line 76) | bool is_valid_ring() {
method extract_mesh_connectivity (line 101) | std::vector<std::string> extract_mesh_connectivity(int rank) {
method extract_ring_connectivity (line 111) | std::pair<std::vector<std::string>, std::vector<std::string>>
type mlx::core::distributed::jaccl (line 124) | namespace mlx::core::distributed::jaccl {
function is_available (line 126) | bool is_available() {
function init (line 130) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/jaccl/jaccl.h
function namespace (line 5) | namespace mlx::core::distributed::jaccl {
FILE: mlx/distributed/jaccl/mesh.cpp
type mlx::core::distributed::jaccl (line 8) | namespace mlx::core::distributed::jaccl {
FILE: mlx/distributed/jaccl/mesh.h
function namespace (line 12) | namespace mlx::core::distributed::jaccl {
FILE: mlx/distributed/jaccl/mesh_impl.h
function namespace (line 11) | namespace mlx::core::distributed::jaccl {
function all_gather (line 134) | void all_gather(const char* in_ptr, char* out_ptr, int64_t n_bytes) {
function send (line 214) | void send(const char* in_ptr, int64_t n_bytes, int dst) {
function recv (line 264) | void recv(char* out_ptr, int64_t n_bytes, int src) {
function recv_from (line 317) | void recv_from(int sz, int rank, int buff) {
function post_send_all (line 330) | void post_send_all(int sz, int buff) {
function post_recv_all (line 341) | void post_recv_all(int sz, int buff) {
FILE: mlx/distributed/jaccl/no_jaccl.cpp
type mlx::core::distributed::jaccl (line 5) | namespace mlx::core::distributed::jaccl {
function is_available (line 9) | bool is_available() {
function init (line 13) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/jaccl/ring.cpp
type mlx::core::distributed::jaccl (line 8) | namespace mlx::core::distributed::jaccl {
FILE: mlx/distributed/jaccl/ring.h
function namespace (line 11) | namespace mlx::core::distributed::jaccl {
FILE: mlx/distributed/jaccl/ring_impl.h
function namespace (line 11) | namespace mlx::core::distributed::jaccl {
function all_gather (line 299) | void
function send (line 405) | void send(const char* in_ptr, int64_t n_bytes, int dst, int n_wires) {
function recv (line 473) | void recv(char* out_ptr, int64_t n_bytes, int src, int n_wires) {
function recv_from (line 551) | void recv_from(int sz, int buff, int left_right, int wire) {
function post_recv_all (line 605) | void post_recv_all(int sz, int buff) {
function post_send_all (line 618) | void post_send_all(int sz, int buff) {
FILE: mlx/distributed/jaccl/utils.cpp
type mlx::core::distributed::jaccl (line 34) | namespace mlx::core::distributed::jaccl {
function IBVWrapper (line 69) | IBVWrapper& ibv() {
function Destination (line 177) | const Destination& Connection::info() {
function create_connections (line 257) | std::vector<Connection> create_connections(
FILE: mlx/distributed/jaccl/utils.h
function namespace (line 48) | namespace mlx::core::distributed::jaccl {
type Destination (line 90) | struct Destination {
function class (line 100) | class SharedBuffer {
function post_recv (line 198) | void post_recv(const SharedBuffer& buff, uint64_t work_request_id) {
function poll (line 216) | int poll(int num_completions, ibv_wc* work_completions) {
function poll (line 224) | inline int poll(
function poll (line 247) | inline int poll(
function class (line 268) | class SideChannel {
FILE: mlx/distributed/mpi/mpi.cpp
type mlx::core::distributed::mpi (line 35) | namespace mlx::core::distributed::mpi {
function simple_sum (line 42) | void simple_sum(
function simple_max (line 61) | void simple_max(
function simple_min (line 81) | void simple_min(
type MPIWrapper (line 100) | struct MPIWrapper {
method MPIWrapper (line 101) | MPIWrapper() {
method is_available (line 161) | bool is_available() {
method init_safe (line 165) | bool init_safe() {
method finalize_safe (line 195) | void finalize_safe() {
method MPI_Comm (line 201) | MPI_Comm world() {
method MPI_Datatype (line 205) | MPI_Datatype datatype(const array& arr) {
method MPI_Op (line 240) | MPI_Op op_sum(const array& arr) {
method MPI_Op (line 251) | MPI_Op op_max(const array& arr) {
method MPI_Op (line 264) | MPI_Op op_min(const array& arr) {
function MPIWrapper (line 339) | MPIWrapper& mpi() {
method MPIWrapper (line 101) | MPIWrapper() {
method is_available (line 161) | bool is_available() {
method init_safe (line 165) | bool init_safe() {
method finalize_safe (line 195) | void finalize_safe() {
method MPI_Comm (line 201) | MPI_Comm world() {
method MPI_Datatype (line 205) | MPI_Datatype datatype(const array& arr) {
method MPI_Op (line 240) | MPI_Op op_sum(const array& arr) {
method MPI_Op (line 251) | MPI_Op op_max(const array& arr) {
method MPI_Op (line 264) | MPI_Op op_min(const array& arr) {
class MPIGroup (line 346) | class MPIGroup : public GroupImpl {
method MPIGroup (line 348) | MPIGroup(MPI_Comm comm, bool global)
method Stream (line 359) | Stream communication_stream(StreamOrDevice s) override {
method rank (line 363) | int rank() override {
method size (line 370) | int size() override {
method split (line 377) | std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
method all_sum (line 389) | void all_sum(const array& input, array& output, Stream stream) overr...
method all_max (line 404) | void all_max(const array& input, array& output, Stream stream) overr...
method all_min (line 419) | void all_min(const array& input, array& output, Stream stream) overr...
method all_gather (line 434) | void all_gather(const array& input, array& output, Stream stream) ov...
method send (line 449) | void send(const array& input, int dst, Stream stream) override {
method recv (line 462) | void recv(array& out, int src, Stream stream) override {
method sum_scatter (line 475) | void sum_scatter(const array& input, array& output, Stream stream) o...
function is_available (line 486) | bool is_available() {
function init (line 490) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/mpi/mpi.h
function namespace (line 5) | namespace mlx::core::distributed::mpi {
FILE: mlx/distributed/mpi/mpi_declarations.h
type MPI_Status (line 22) | typedef struct ompi_status_public_t {
FILE: mlx/distributed/mpi/no_mpi.cpp
type mlx::core::distributed::mpi (line 5) | namespace mlx::core::distributed::mpi {
function is_available (line 9) | bool is_available() {
function init (line 13) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/nccl/nccl.cpp
type mlx::core::distributed::nccl (line 27) | namespace mlx::core::distributed::nccl {
type nccl_map (line 73) | struct nccl_map {
type detail (line 87) | namespace detail {
function dispatch_dtype (line 90) | void dispatch_dtype(const array& arr, F&& f) {
function sendAll (line 102) | inline void sendAll(int sock, const void* buf, size_t len) {
function recvAll (line 115) | inline void recvAll(int sock, void* buf, size_t len) {
function bootstrap_unique_id (line 130) | inline void bootstrap_unique_id(
function bootstrap_unique_id (line 258) | inline void bootstrap_unique_id(
function get_env_var_or_throw (line 439) | std::string get_env_var_or_throw(const char* env_var_name, bool stri...
type NCCLComm (line 271) | struct NCCLComm {
method NCCLComm (line 276) | NCCLComm(ncclComm_t c, int rank, int size)
method create (line 279) | static std::shared_ptr<NCCLComm>
method split (line 286) | static std::shared_ptr<NCCLComm> split(NCCLComm* source, int color, ...
method NCCLComm (line 297) | NCCLComm(const NCCLComm&) = delete;
method NCCLComm (line 298) | NCCLComm& operator=(const NCCLComm&) = delete;
class NCCLGroup (line 302) | class NCCLGroup : public GroupImpl {
method NCCLGroup (line 304) | NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
method NCCLGroup (line 316) | NCCLGroup(std::shared_ptr<NCCLComm> comm, int rank, int size)
method Stream (line 319) | Stream communication_stream(StreamOrDevice s) override {
method rank (line 323) | int rank() override {
method size (line 327) | int size() override {
method all_sum (line 331) | void all_sum(const array& input, array& output, Stream stream) overr...
method split (line 338) | std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
method all_gather (line 345) | void all_gather(const array& input, array& output, Stream stream) ov...
method send (line 359) | void send(const array& input, int dst, Stream stream) override {
method recv (line 363) | void recv(array& output, int src, Stream stream) override {
method all_max (line 367) | void all_max(const array& input, array& output, Stream stream) overr...
method all_min (line 374) | void all_min(const array& input, array& output, Stream stream) overr...
method sum_scatter (line 381) | void sum_scatter(const array& input, array& output, Stream stream) o...
method all_reduce_impl (line 389) | void all_reduce_impl(
method reduce_scatter_impl (line 408) | void reduce_scatter_impl(
function is_available (line 434) | bool is_available() {
type detail (line 438) | namespace detail {
function dispatch_dtype (line 90) | void dispatch_dtype(const array& arr, F&& f) {
function sendAll (line 102) | inline void sendAll(int sock, const void* buf, size_t len) {
function recvAll (line 115) | inline void recvAll(int sock, void* buf, size_t len) {
function bootstrap_unique_id (line 130) | inline void bootstrap_unique_id(
function bootstrap_unique_id (line 258) | inline void bootstrap_unique_id(
function get_env_var_or_throw (line 439) | std::string get_env_var_or_throw(const char* env_var_name, bool stri...
function init (line 455) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/nccl/nccl.h
function namespace (line 5) | namespace mlx::core::distributed::nccl {
FILE: mlx/distributed/nccl/no_nccl.cpp
type mlx::core::distributed::nccl (line 5) | namespace mlx::core::distributed::nccl {
function is_available (line 9) | bool is_available() {
function init (line 13) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/ops.cpp
type mlx::core::distributed (line 11) | namespace mlx::core::distributed {
function Group (line 15) | Group to_group(std::optional<Group> group) {
function array (line 25) | array all_sum(
function array (line 43) | array all_max(
function array (line 61) | array all_min(
function array (line 79) | array all_gather(
function array (line 103) | array send(
function array (line 126) | array recv(
function array (line 152) | array recv_like(
function array (line 160) | array sum_scatter(
FILE: mlx/distributed/ops.h
function namespace (line 11) | namespace mlx::core::distributed {
FILE: mlx/distributed/primitives.cpp
type mlx::core::distributed (line 10) | namespace mlx::core::distributed {
FILE: mlx/distributed/primitives.h
function namespace (line 9) | namespace mlx::core::distributed {
function class (line 24) | class AllReduce : public DistPrimitive {
function class (line 70) | class AllGather : public DistPrimitive {
function class (line 95) | class Send : public DistPrimitive {
function class (line 114) | class Recv : public DistPrimitive {
function class (line 130) | class ReduceScatter : public DistPrimitive {
FILE: mlx/distributed/reduction_ops.h
function namespace (line 3) | namespace mlx::core::distributed::detail {
FILE: mlx/distributed/ring/no_ring.cpp
type mlx::core::distributed::ring (line 5) | namespace mlx::core::distributed::ring {
function is_available (line 9) | bool is_available() {
function init (line 13) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/ring/ring.cpp
type mlx::core::distributed::ring (line 90) | namespace mlx::core::distributed::ring {
function log (line 105) | void log(std::ostream& os, T first) {
function log (line 110) | void log(std::ostream& os, T first, Args... args) {
function log_info (line 115) | void log_info(bool verbose, Args... args) {
function ceildiv (line 124) | decltype(T() * U()) ceildiv(T a, U b) {
class SocketThread (line 128) | class SocketThread {
method SocketThread (line 130) | SocketThread(int fd) : fd_(fd), stop_(false) {
method send (line 144) | std::future<void> send(const T* buffer, size_t size) {
method recv (line 149) | std::future<void> recv(T* buffer, size_t size) {
type SocketTask (line 154) | struct SocketTask {
method SocketTask (line 155) | SocketTask(void* b, size_t s, std::promise<void>&& p)
method SocketTask (line 157) | SocketTask(SocketTask&& t)
method send_impl (line 164) | std::future<void> send_impl(const char* buffer, size_t size) {
method recv_impl (line 181) | std::future<void> recv_impl(char* buffer, size_t size) {
method have_tasks (line 198) | bool have_tasks() {
method worker (line 202) | void worker() {
class CommunicationThreads (line 277) | class CommunicationThreads {
method add (line 279) | void add(const std::vector<int>& sockets) {
method send (line 286) | std::future<void> send(int socket, T* buffer, size_t size) {
method recv (line 291) | std::future<void> recv(int socket, T* buffer, size_t size) {
function load_nodes (line 311) | std::vector<std::vector<detail::address_t>> load_nodes(const char* hos...
function accept_connections (line 331) | std::vector<int> accept_connections(
function make_connections (line 349) | std::vector<int> make_connections(
class RingGroup (line 381) | class RingGroup : public GroupImpl {
method RingGroup (line 383) | RingGroup(
method Stream (line 465) | Stream communication_stream(StreamOrDevice s) override {
method rank (line 469) | int rank() override {
method size (line 473) | int size() override {
method all_sum (line 477) | void all_sum(const array& input, array& output, Stream stream) overr...
method all_max (line 482) | void all_max(const array& input, array& output, Stream stream) overr...
method all_min (line 487) | void all_min(const array& input, array& output, Stream stream) overr...
method split (line 492) | std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
method all_gather (line 496) | void all_gather(const array& input, array& output, Stream stream) ov...
method send (line 533) | void send(const array& input, int dst, Stream stream) override {
method recv (line 554) | void recv(array& out, int src, Stream stream) override {
method sum_scatter (line 578) | void sum_scatter(const array& input, array& output, Stream stream) o...
method all_reduce (line 584) | void all_reduce(
method all_reduce_impl (line 659) | void all_reduce_impl(
method all_gather_impl (line 756) | void all_gather_impl(
method send (line 792) | void
method recv (line 811) | void recv(const std::vector<int>& sockets, char* data, size_t data_s...
function is_available (line 843) | bool is_available() {
function init (line 847) | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
FILE: mlx/distributed/ring/ring.h
function namespace (line 5) | namespace mlx::core::distributed::ring {
FILE: mlx/distributed/utils.cpp
type mlx::core::distributed::detail (line 11) | namespace mlx::core::distributed::detail {
function address_t (line 16) | address_t parse_address(const std::string& ip, const std::string& port) {
function address_t (line 40) | address_t parse_address(const std::string& ip_port) {
function TCPSocket (line 67) | TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
function TCPSocket (line 126) | TCPSocket TCPSocket::accept(const char* tag) {
function TCPSocket (line 163) | TCPSocket TCPSocket::connect(
FILE: mlx/distributed/utils.h
function namespace (line 9) | namespace mlx::core::distributed::detail {
FILE: mlx/dtype.cpp
type mlx::core (line 7) | namespace mlx::core {
function Dtype (line 86) | Dtype promote_types(const Dtype& t1, const Dtype& t2) {
function kindof (line 91) | Dtype::Kind kindof(const Dtype& t) {
class MLX_API (line 95) | class MLX_API
class MLX_API (line 96) | class MLX_API
class MLX_API (line 97) | class MLX_API
class MLX_API (line 98) | class MLX_API
class MLX_API (line 99) | class MLX_API
class MLX_API (line 100) | class MLX_API
class MLX_API (line 101) | class MLX_API
class MLX_API (line 102) | class MLX_API
class MLX_API (line 103) | class MLX_API
class MLX_API (line 104) | class MLX_API
class MLX_API (line 105) | class MLX_API
class MLX_API (line 106) | class MLX_API
class MLX_API (line 107) | class MLX_API
class MLX_API (line 108) | class MLX_API
function issubdtype (line 180) | bool issubdtype(const Dtype& a, const Dtype& b) {
function issubdtype (line 184) | bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
function issubdtype (line 188) | bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
function issubdtype (line 192) | bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
FILE: mlx/dtype.h
type Dtype (line 14) | struct Dtype {
function Dtype (line 85) | inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64...
FILE: mlx/dtype_utils.cpp
type mlx::core (line 5) | namespace mlx::core {
FILE: mlx/dtype_utils.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/einsum.cpp
type mlx::core (line 10) | namespace mlx::core {
type Subscript (line 24) | struct Subscript {
method Subscript (line 25) | Subscript(std::string str, CharSet set)
type PathInfo (line 31) | struct PathInfo {
type PathNode (line 39) | struct PathNode {
method PathNode (line 40) | PathNode(
function parse (line 60) | std::pair<std::vector<std::string>, std::string> parse(std::string sub...
function disjoint (line 107) | bool disjoint(const CharSet& x, const CharSet& y) {
function term_size (line 117) | size_t term_size(const T& term, std::unordered_map<char, ShapeElem> di...
function flop_count (line 125) | size_t flop_count(
function compute_cost_and_scaling (line 141) | std::pair<size_t, int> compute_cost_and_scaling(
function greedy_path (line 161) | std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
function can_dot (line 335) | bool can_dot(const std::vector<Subscript>& inputs, const Subscript& ou...
function array (line 349) | array batch_tensordot(
function array (line 424) | array collapse_repeats(array in, Subscript& subscript, StreamOrDevice ...
function preprocess_einsum_inputs (line 490) | void preprocess_einsum_inputs(
function array (line 537) | array einsum_naive(
function einsum_path_helper (line 630) | std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
function einsum_path (line 829) | std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
function array (line 851) | array einsum(
FILE: mlx/einsum.h
function namespace (line 12) | namespace mlx::core {
FILE: mlx/event.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/export.cpp
function is_big_endian (line 23) | bool is_big_endian() {
type mlx::core (line 28) | namespace mlx::core {
type PrimitiveSerializer (line 35) | struct PrimitiveSerializer {
method PrimitiveSerializer (line 41) | PrimitiveSerializer(
type NotSerializable (line 90) | struct NotSerializable {
type NotDeserializable (line 95) | struct NotDeserializable {
function reverse_bytes (line 100) | void reverse_bytes(T& data) {
function serialize (line 114) | void serialize(Writer& os, T v) {
function T (line 146) | T deserialize(Reader& is) {
type VariantType (line 183) | enum class VariantType { Int = 0, Float = 1, Bool = 2 }
function serialize_variant (line 186) | void serialize_variant(Writer& os, T v) {
function T (line 206) | T deserialize_variant(Reader& is) {
function deserialize_tuple (line 222) | decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
function serialize (line 226) | void serialize(Writer& os, const Stream& s) {
function Stream (line 232) | Stream deserialize(Reader& is) {
function serialize (line 239) | void serialize(Writer& os, const Dtype& t) {
function Dtype (line 245) | Dtype deserialize(Reader& is) {
function serialize (line 251) | void serialize(Writer& os, const array& arr) {
function array (line 256) | array deserialize(Reader& is) {
function serialize_primitive (line 270) | void serialize_primitive(Writer& os, const Primitive& p) {
function extract_state (line 277) | void extract_state(const T state, std::vector<StateT>& unpacked_state) {
function primitive_state (line 296) | std::vector<StateT> primitive_state(const Primitive& p) {
function deserialize_primitive (line 305) | std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
type PrimitiveFactory (line 321) | struct PrimitiveFactory {
method PrimitiveFactory (line 453) | PrimitiveFactory() {
method save (line 461) | void save(Writer& os, const std::shared_ptr<Primitive>& p) {
method Stream (line 477) | Stream resolve_stream(const Stream& stream) {
method load (line 494) | std::shared_ptr<Primitive> load(Reader& is) {
method extract_state (line 505) | std::pair<std::string, std::vector<StateT>> extract_state(
function write_header (line 523) | void write_header(Writer& os, int count, bool shapeless) {
type FunctionTable (line 530) | struct FunctionTable {
method FunctionTable (line 531) | FunctionTable(bool shapeless = false) : shapeless(shapeless) {}
type Function (line 532) | struct Function {
method Function (line 533) | Function(
method Function (line 547) | Function(const Function&) = delete;
method Function (line 548) | Function& operator=(const Function&) = delete;
method Function (line 549) | Function(Function&&) = default;
method Function (line 550) | Function() = default;
method insert (line 558) | void insert(
method print_functions (line 571) | void print_functions(std::ostream& os) {
function FunctionExporter (line 891) | FunctionExporter exporter(
function FunctionExporter (line 901) | FunctionExporter exporter(
function FunctionExporter (line 911) | FunctionExporter exporter(
function export_function (line 918) | void export_function(
function export_function (line 926) | void export_function(
function export_function (line 934) | void export_function(
function FunctionExporter (line 943) | FunctionExporter exporter(
function FunctionExporter (line 953) | FunctionExporter exporter(
function FunctionExporter (line 963) | FunctionExporter exporter(
function export_function (line 970) | void export_function(
function export_function (line 978) | void export_function(
function export_function (line 986) | void export_function(
function ImportedFunction (line 1034) | ImportedFunction import_function(const std::string& file) {
FILE: mlx/export.h
function namespace (line 12) | namespace mlx::core {
FILE: mlx/export_impl.h
function namespace (line 8) | namespace mlx::core {
FILE: mlx/fast.cpp
type mlx::core::fast (line 11) | namespace mlx::core::fast {
function array (line 53) | array rms_norm(
function array (line 190) | array layer_norm(
function array (line 367) | array rope(
function array (line 531) | array rope(
function array (line 560) | array rope(
function array (line 613) | array scaled_dot_product_attention(
FILE: mlx/fast.h
function namespace (line 11) | namespace mlx::core::fast {
FILE: mlx/fast_primitives.h
function class (line 13) | class Custom : public Primitive {
function eval_cpu (line 49) | void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outp...
function DEFINE_INPUT_OUTPUT_SHAPE (line 62) | DEFINE_NAME(RMSNorm)
function DEFINE_NAME (line 86) | void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outp...
FILE: mlx/fence.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/fft.cpp
type mlx::core::fft (line 10) | namespace mlx::core::fft {
function array (line 12) | array fft_impl(
function array (line 101) | array fft_impl(
function array (line 117) | array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice...
function array (line 123) | array fftn(
function array (line 130) | array fftn(
function array (line 136) | array fftn(const array& a, StreamOrDevice s /* = {} */) {
function array (line 140) | array ifftn(
function array (line 147) | array ifftn(
function array (line 153) | array ifftn(const array& a, StreamOrDevice s /* = {} */) {
function array (line 157) | array rfftn(
function array (line 164) | array rfftn(
function array (line 170) | array rfftn(const array& a, StreamOrDevice s /* = {} */) {
function array (line 174) | array irfftn(
function array (line 181) | array irfftn(
function array (line 188) | array irfftn(const array& a, StreamOrDevice s /* = {} */) {
function array (line 192) | array fftshift(
function array (line 217) | array ifftshift(
function array (line 244) | array fftshift(const array& a, StreamOrDevice s /* = {} */) {
function array (line 253) | array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
FILE: mlx/fft.h
function namespace (line 12) | namespace mlx::core::fft {
FILE: mlx/graph_utils.cpp
type mlx::core (line 13) | namespace mlx::core {
function depth_first_traversal (line 37) | void depth_first_traversal(
function print_graph (line 62) | void print_graph(
function export_to_dot (line 105) | void export_to_dot(
FILE: mlx/graph_utils.h
function NodeNamer (line 12) | struct MLX_API NodeNamer {
function print_graph (line 24) | inline void print_graph(std::ostream& os, const std::vector<array>& outp...
function export_to_dot (line 48) | inline void export_to_dot(std::ostream& os, const std::vector<array>& ou...
FILE: mlx/io.h
function namespace (line 14) | namespace mlx::core {
FILE: mlx/io/gguf.cpp
type mlx::core (line 11) | namespace mlx::core {
function dtype_to_gguf_tensor_type (line 16) | std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
function gguf_type_to_dtype (line 33) | std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
function Shape (line 50) | Shape get_shape(const gguf_tensor& tensor) {
function extract_tensor_data (line 59) | std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* ...
function set_mx_value_from_gguf (line 90) | void set_mx_value_from_gguf(
function load_metadata (line 203) | std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ...
function load_arrays (line 214) | std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
function GGUFLoad (line 241) | GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
function append_kv_array (line 261) | void append_kv_array(
function save_gguf (line 294) | void save_gguf(
FILE: mlx/io/gguf.h
function namespace (line 13) | namespace mlx::core {
FILE: mlx/io/gguf_quants.cpp
type mlx::core (line 9) | namespace mlx::core {
function unpack_32_4 (line 11) | void unpack_32_4(uint8_t* data, int8_t* dst) {
function extract_q4_0_data (line 32) | void extract_q4_0_data(
function extract_q4_1_data (line 53) | void extract_q4_1_data(
function extract_q8_0_data (line 75) | void extract_q8_0_data(
function gguf_load_quantized (line 100) | void gguf_load_quantized(
FILE: mlx/io/load.cpp
type mlx::core (line 23) | namespace mlx::core {
function is_big_endian (line 36) | inline bool is_big_endian() {
function dtype_to_array_protocol (line 47) | std::string dtype_to_array_protocol(const Dtype& t) {
function Dtype (line 59) | Dtype dtype_from_array_protocol(std::string_view t) {
function pread (line 120) | int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
function save (line 144) | void save(std::shared_ptr<io::Writer> out_stream, array a) {
function save (line 219) | void save(std::string file, array a) {
function array (line 229) | array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
function array (line 331) | array load(std::string file, StreamOrDevice s) {
type io (line 335) | namespace io {
function ThreadPool (line 337) | ThreadPool& thread_pool() {
function ThreadPool (line 342) | ThreadPool& ParallelFileReader::thread_pool() {
FILE: mlx/io/load.h
function class (line 32) | class Reader {
FILE: mlx/io/no_gguf.cpp
type mlx::core (line 5) | namespace mlx::core {
function GGUFLoad (line 7) | GGUFLoad load_gguf(const std::string&, StreamOrDevice s) {
function save_gguf (line 12) | void save_gguf(
FILE: mlx/io/no_safetensors.cpp
type mlx::core (line 5) | namespace mlx::core {
function SafetensorsLoad (line 7) | SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOr...
function SafetensorsLoad (line 13) | SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {
function save_safetensors (line 19) | void save_safetensors(
function save_safetensors (line 28) | void save_safetensors(
FILE: mlx/io/safetensors.cpp
type mlx::core (line 35) | namespace mlx::core {
function dtype_to_safetensor_str (line 37) | std::string dtype_to_safetensor_str(Dtype t) {
function Dtype (line 70) | Dtype dtype_from_safetensor_str(std::string_view str) {
function SafetensorsLoad (line 106) | SafetensorsLoad load_safetensors(
function SafetensorsLoad (line 162) | SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevi...
function save_safetensors (line 166) | void save_safetensors(
function save_safetensors (line 220) | void save_safetensors(
FILE: mlx/linalg.cpp
type mlx::core::linalg (line 11) | namespace mlx::core::linalg {
function check_cpu_stream (line 13) | void check_cpu_stream(const StreamOrDevice& s, const std::string& pref...
function check_float (line 21) | void check_float(Dtype dtype, const std::string& prefix) {
function check_float_or_complex (line 30) | void check_float_or_complex(Dtype dtype, const std::string& prefix) {
function Dtype (line 39) | Dtype at_least_float(const Dtype& d) {
function array (line 43) | inline array l2_norm(
function array (line 55) | inline array vector_norm(
function array (line 80) | inline array matrix_norm(
function array (line 136) | inline array matrix_norm(
function array (line 165) | array norm(
function array (line 181) | array norm(
function array (line 204) | array norm(
function qr (line 226) | std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
function svd (line 250) | std::vector<array>
function array (line 296) | array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
function array (line 319) | array inv(const array& a, StreamOrDevice s /* = {} */) {
function array (line 323) | array tri_inv(
function array (line 330) | array cholesky(
function array (line 356) | array pinv(const array& a, StreamOrDevice s /* = {} */) {
function array (line 401) | array cholesky_inv(
function array (line 430) | array cross(
function validate_eig (line 502) | void validate_eig(
function array (line 521) | array eigvalsh(
function eigh (line 535) | std::pair<array, array> eigh(
function array (line 549) | array eigvals(const array& a, StreamOrDevice s /* = {} */) {
function eig (line 559) | std::pair<array, array> eig(const array& a, StreamOrDevice s /* = {} *...
function validate_lu (line 569) | void validate_lu(
function lu_helper (line 586) | std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} ...
function lu (line 602) | std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
function lu_factor (line 629) | std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* ...
function validate_solve (line 635) | void validate_solve(
function array (line 682) | array solve(const array& a, const array& b, StreamOrDevice s /* = {} *...
function array (line 698) | array solve_triangular(
FILE: mlx/linalg.h
function namespace (line 13) | namespace mlx::core::linalg {
FILE: mlx/memory.h
function namespace (line 9) | namespace mlx::core {
FILE: mlx/ops.cpp
type mlx::core (line 21) | namespace mlx::core {
function compute_reduce_shape (line 25) | std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
function Dtype (line 57) | Dtype at_least_float(const Dtype& d) {
function array (line 61) | array indices_or_default(
function validate_quantized_input (line 75) | void validate_quantized_input(
function extract_quantized_matmul_dims (line 117) | std::pair<int, int> extract_quantized_matmul_dims(
function array (line 150) | array arange(
function array (line 189) | array arange(
function array (line 196) | array arange(
function array (line 203) | array arange(double start, double stop, StreamOrDevice s /* = {} */) {
function array (line 206) | array arange(double stop, Dtype dtype, StreamOrDevice s /* = {} */) {
function array (line 209) | array arange(double stop, StreamOrDevice s /* = {} */) {
function array (line 212) | array arange(int start, int stop, int step, StreamOrDevice s /* = {} *...
function array (line 220) | array arange(int start, int stop, StreamOrDevice s /* = {} */) {
function array (line 228) | array arange(int stop, StreamOrDevice s /* = {} */) {
function array (line 232) | array linspace(
function array (line 258) | array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
function array (line 270) | array as_strided(
function array (line 287) | array copy(array a, StreamOrDevice s /* = {} */) {
function array (line 297) | array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
function array (line 305) | array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* =...
function array (line 312) | array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
function array (line 317) | array full_like(
function array (line 326) | array full_like(const array& a, array vals, StreamOrDevice s /* = {} *...
function array (line 330) | array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} ...
function array (line 334) | array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
function array (line 338) | array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} *...
function array (line 342) | array ones_like(const array& a, StreamOrDevice s /* = {} */) {
function array (line 346) | array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} *...
function array (line 366) | array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
function array (line 370) | array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
function array (line 376) | array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
function array (line 384) | array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
function array (line 392) | array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
function array (line 404) | array unflatten(
function array (line 457) | array flatten(
function array (line 496) | array flatten(const array& a, StreamOrDevice s /* = {} */) {
function array (line 500) | array hadamard_transform(
function array (line 529) | array squeeze_impl(
function array (line 557) | array squeeze(
function array (line 575) | array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
function array (line 579) | array squeeze(const array& a, StreamOrDevice s /* = {} */) {
function array (line 589) | array expand_dims_impl(
function array (line 612) | array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} *...
function array (line 616) | array expand_dims(
function normalize_slice (line 645) | inline auto
function normalize_dynamic_slice_inputs (line 701) | void normalize_dynamic_slice_inputs(
function array (line 751) | array slice(
function array (line 780) | array slice(
function array (line 789) | array slice(
function array (line 822) | array slice_update(
function array (line 862) | array slice_update(
function array (line 873) | array slice_update(
function array (line 902) | array slice_update(
function array (line 950) | array slice_update_add(
function array (line 967) | array slice_update_add(
function array (line 977) | array slice_update_prod(
function array (line 994) | array slice_update_prod(
function array (line 1004) | array slice_update_max(
function array (line 1021) | array slice_update_max(
function array (line 1031) | array slice_update_min(
function array (line 1048) | array slice_update_min(
function split (line 1058) | std::vector<array> split(
function split (line 1104) | std::vector<array>
function split (line 1109) | std::vector<array>
function split (line 1140) | std::vector<array>
function meshgrid (line 1145) | std::vector<array> meshgrid(
function array (line 1180) | array clip(
function array (line 1198) | array concatenate(
function array (line 1259) | array concatenate(std::vector<array> arrays, StreamOrDevice s /* = {} ...
function array (line 1267) | array stack(
function array (line 1289) | array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {}...
function array (line 1294) | array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
function array (line 1324) | array repeat(const array& arr, int repeats, StreamOrDevice s) {
function array (line 1328) | array tile(
function array (line 1358) | array edge_pad(
function array (line 1404) | array pad(
function array (line 1458) | array pad(
function array (line 1478) | array pad(
function array (line 1492) | array pad(
function array (line 1506) | array moveaxis(
function array (line 1533) | array swapaxes(
function array (line 1556) | array transpose(
function array (line 1595) | array transpose(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1601) | array broadcast_to(
function broadcast_arrays (line 1630) | std::vector<array> broadcast_arrays(
function broadcast_arrays (line 1700) | std::vector<array> broadcast_arrays(
function broadcast_arrays (line 1751) | std::pair<array, array>
function broadcast_arrays (line 1757) | std::pair<array, array> broadcast_arrays(
function array (line 1766) | array equal(const array& a, const array& b, StreamOrDevice s /* = {} *...
function array (line 1774) | array not_equal(const array& a, const array& b, StreamOrDevice s /* = ...
function array (line 1785) | array greater(const array& a, const array& b, StreamOrDevice s /* = {}...
function array (line 1793) | array greater_equal(
function array (line 1807) | array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
function array (line 1815) | array less_equal(const array& a, const array& b, StreamOrDevice s /* =...
function array (line 1826) | array array_equal(
function array (line 1847) | array isnan(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1854) | array isinf(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1861) | array isfinite(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1868) | array isposinf(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1875) | array isneginf(const array& a, StreamOrDevice s /* = {} */) {
function array (line 1882) | array where(
function array (line 1899) | array nan_to_num(
function array (line 1933) | array allclose(
function array (line 1943) | array isclose(
function array (line 1984) | array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 1990) | array all(
function array (line 2010) | array all(
function array (line 2018) | array any(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2024) | array any(
function array (line 2044) | array any(
function array (line 2052) | array sum(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2058) | array sum(
function array (line 2089) | array sum(
function array (line 2097) | array mean(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2103) | array mean(
function array (line 2122) | array mean(
function array (line 2130) | array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2136) | array median(
function array (line 2203) | array median(
function array (line 2211) | array var(
function array (line 2221) | array var(
function array (line 2248) | array var(
function array (line 2257) | array std(
function array (line 2267) | array std(
function array (line 2276) | array std(
function array (line 2285) | array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2291) | array prod(
function array (line 2322) | array prod(
function array (line 2330) | array max(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2336) | array max(
function array (line 2359) | array max(
function array (line 2367) | array min(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
function array (line 2373) | array min(
function array (line 2399) | array min(
function array (line 2407) | array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} *...
function array (line 2419) | array argmin(
function array (line 2444) | array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} *...
function array (line 2456) | array argmax(
function array (line 2481) | array bartlett(int M, StreamOrDevice s /* = {} */) {
function array (line 2496) | array hanning(int M, StreamOrDevice s /* = {} */) {
function array (line 2509) | array hamming(int M, StreamOrDevice s /* = {} */) {
function array (line 2530) | array blackman(int M, StreamOrDevice s /* = {} */) {
function array (line 2558) | array sort(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2564) | array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
function array (line 2579) | array argsort(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2585) | array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
function array (line 2603) | array partition(const array& a, int kth, StreamOrDevice s /* = {} */) {
function array (line 2612) | array partition(
function array (line 2644) | array argpartition(const array& a, int kth, StreamOrDevice s /* = {} *...
function array (line 2653) | array argpartition(
function array (line 2682) | array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
function array (line 2688) | array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
function array (line 2716) | array logsumexp(const array& a, bool keepdims, StreamOrDevice s /* = {...
function array (line 2722) | array logsumexp(
function array (line 2770) | array logsumexp(
function array (line 2778) | array abs(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2787) | array negative(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2795) | array operator-(const array& a) {
function array (line 2799) | array sign(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2803) | array logical_not(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2811) | array logical_and(const array& a, const array& b, StreamOrDevice s /* ...
function array (line 2821) | array operator&&(const array& a, const array& b) {
function array (line 2825) | array logical_or(const array& a, const array& b, StreamOrDevice s /* =...
function array (line 2835) | array operator||(const array& a, const array& b) {
function array (line 2839) | array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2844) | array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
function array (line 2853) | array operator+(const array& a, const array& b) {
function array (line 2857) | array subtract(const array& a, const array& b, StreamOrDevice s /* = {...
function array (line 2869) | array operator-(const array& a, const array& b) {
function array (line 2873) | array multiply(const array& a, const array& b, StreamOrDevice s /* = {...
function array (line 2885) | array operator*(const array& a, const array& b) {
function array (line 2889) | array divide(const array& a, const array& b, StreamOrDevice s /* = {} ...
function array (line 2897) | array operator/(const array& a, const array& b) {
function array (line 2900) | array operator/(double a, const array& b) {
function array (line 2903) | array operator/(const array& a, double b) {
function array (line 2907) | array floor_divide(
function array (line 2922) | array remainder(const array& a, const array& b, StreamOrDevice s /* = ...
function array (line 2933) | array operator%(const array& a, const array& b) {
function divmod (line 2937) | std::vector<array>
function array (line 2952) | array maximum(const array& a, const array& b, StreamOrDevice s /* = {}...
function array (line 2964) | array minimum(const array& a, const array& b, StreamOrDevice s /* = {}...
function array (line 2976) | array floor(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2984) | array ceil(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2991) | array square(const array& a, StreamOrDevice s /* = {} */) {
function array (line 2996) | array exp(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3002) | array expm1(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3009) | array sin(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3015) | array cos(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3021) | array tan(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3027) | array arcsin(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3034) | array arccos(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3041) | array arctan(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3048) | array arctan2(const array& a, const array& b, StreamOrDevice s /* = {}...
function array (line 3056) | array sinh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3062) | array cosh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3068) | array tanh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3074) | array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3081) | array arccosh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3088) | array arctanh(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3095) | array degrees(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3100) | array radians(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3105) | array log(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3115) | array log2(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3125) | array log10(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3135) | array log1p(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3142) | array logaddexp(const array& a, const array& b, StreamOrDevice s /* = ...
function array (line 3155) | array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3162) | array erf(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3171) | array erfinv(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3180) | array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3185) | array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
function array (line 3200) | array matmul(
function array (line 3277) | array gather(
function array (line 3365) | array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
function array (line 3393) | array take(
function array (line 3438) | array take(const array& a, const array& indices, StreamOrDevice s /* =...
function array (line 3442) | array take(const array& a, int index, int axis, StreamOrDevice s /* = ...
function array (line 3468) | array take(const array& a, int index, StreamOrDevice s /* = {} */) {
function array (line 3472) | array take_along_axis(
function array (line 3506) | array scatter_axis(
function array (line 3559) | array put_along_axis(
function array (line 3568) | array scatter_add_axis(
function array (line 3578) | array scatter(
function array (line 3672) | array scatter(
function array (line 3681) | array scatter_add(
function array (line 3690) | array scatter_prod(
function array (line 3699) | array scatter_max(
function array (line 3708) | array scatter_min(
function array (line 3717) | array masked_scatter(
function array (line 3800) | array sqrt(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3809) | array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
function array (line 3818) | array softmax(
function array (line 3862) | array softmax(
function array (line 3871) | array power(const array& a, const array& b, StreamOrDevice s /* = {} *...
function array (line 3881) | array cumsum(
function array (line 3904) | array cumsum(
function array (line 3912) | array cumprod(
function array (line 3934) | array cumprod(
function array (line 3942) | array cummax(
function array (line 3964) | array cummax(
function array (line 3972) | array cummin(
function array (line 3994) | array cummin(
function array (line 4002) | array logcumsumexp(
function array (line 4024) | array logcumsumexp(
function run_conv_checks (line 4037) | inline void
function array (line 4099) | array conv1d(
function array (line 4120) | array conv2d(
function array (line 4142) | array conv3d(
function array (line 4166) | array conv_transpose_general(
function array (line 4204) | array conv_transpose1d(
function array (line 4218) | array conv_transpose2d(
function array (line 4239) | array conv_transpose3d(
function array (line 4262) | array conv_general(
function quantization_params_from_mode (line 4374) | std::pair<int, int> quantization_params_from_mode(
function validate_mode_with_type (line 4403) | std::pair<Dtype, QuantizationMode> validate_mode_with_type(
function validate_global_scale (line 4455) | void validate_global_scale(
function array (line 4483) | array quantized_matmul(
function validate_qqmm_inputs (line 4535) | void validate_qqmm_inputs(
function extract_qqmm_dims (line 4595) | std::pair<int, int> extract_qqmm_dims(
function array (line 4625) | array qqmm(
function array (line 4694) | array pack_and_quantize(
function affine_quantize (line 4741) | std::vector<array>
function fp_quantize (line 4807) | std::vector<array> fp_quantize(
function quantize (line 4914) | std::vector<array> quantize(
function array (line 4960) | array affine_dequantize(
function array (line 5054) | array fp_dequantize(
function array (line 5177) | array dequantize(
function array (line 5238) | array from_fp8(array x, Dtype dtype, StreamOrDevice s) {
function array (line 5258) | array to_fp8(array x, StreamOrDevice s) {
function array (line 5272) | array gather_qmm(
function array (line 5369) | array tensordot(
function array (line 5392) | array tensordot(
function array (line 5454) | array outer(const array& a, const array& b, StreamOrDevice s /* = {} *...
function array (line 5459) | array inner(const array& a, const array& b, StreamOrDevice s /* = {} *...
function array (line 5472) | array addmm(
function array (line 5614) | array block_masked_mm(
function array (line 5791) | array gather_mm(
function array (line 5895) | array segmented_mm(
function array (line 5943) | array diagonal(
function array (line 5996) | array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} *...
function array (line 6019) | array trace(
function array (line 6061) | array trace(
function array (line 6070) | array trace(const array& a, StreamOrDevice s /* = {} */) {
function depends (line 6075) | std::vector<array> depends(
function array (line 6100) | array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
function atleast_1d (line 6107) | std::vector<array> atleast_1d(
function array (line 6118) | array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
function atleast_2d (line 6129) | std::vector<array> atleast_2d(
function array (line 6140) | array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
function atleast_3d (line 6153) | std::vector<array> atleast_3d(
function array (line 6164) | array number_of_elements(
function array (line 6196) | array conjugate(const array& a, StreamOrDevice s /* = {} */) {
function array (line 6205) | array bitwise_impl(
function array (line 6231) | array bitwise_and(const array& a, const array& b, StreamOrDevice s /* ...
function array (line 6234) | array operator&(const array& a, const array& b) {
function array (line 6238) | array bitwise_or(const array& a, const array& b, StreamOrDevice s /* =...
function array (line 6241) | array operator|(const array& a, const array& b) {
function array (line 6245) | array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* ...
function array (line 6248) | array operator^(const array& a, const array& b) {
function array (line 6252) | array left_shift(const array& a, const array& b, StreamOrDevice s /* =...
function array (line 6259) | array operator<<(const array& a, const array& b) {
function array (line 6263) | array right_shift(const array& a, const array& b, StreamOrDevice s /* ...
function array (line 6276) | array operator>>(const array& a, const array& b) {
function array (line 6280) | array bitwise_invert(const array& a, StreamOrDevice s /* = {} */) {
function array (line 6291) | array operator~(const array& a) {
function array (line 6295) | array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {...
function array (line 6323) | array roll(
function array (line 6367) | array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
function array (line 6375) | array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {...
function array (line 6383) | array roll(const array& a, int shift, int axis, StreamOrDevice s /* = ...
function array (line 6387) | array roll(
function array (line 6396) | array roll(
function array (line 6408) | array real(const array& a, StreamOrDevice s /* = {} */) {
function array (line 6415) | array imag(const array& a, StreamOrDevice s /* = {} */) {
function array (line 6422) | array contiguous(
FILE: mlx/ops.h
function namespace (line 13) | namespace mlx::core {
FILE: mlx/primitives.cpp
type mlx::core (line 20) | namespace mlx::core {
function vmap_binary_op (line 24) | std::tuple<array, array, int> vmap_binary_op(
function vmap_ternary_op (line 60) | std::tuple<array, array, array, int> vmap_ternary_op(
function array (line 121) | array gather_mm_grad(
function broadcast_vjp (line 805) | std::vector<array>
function Shape (line 859) | Shape Broadcast::output_shape(const std::vector<array>& inputs) {
function Shape (line 908) | Shape BroadcastAxes::output_shape(
function array (line 1175) | array conv_weight_backward_patches(
function conv_out_axis_size (line 1248) | inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int ...
function dilate_size (line 1253) | inline int dilate_size(int dim, int dil) {
function Shape (line 1259) | Shape Convolution::conv_out_shape(
function Shape (line 2021) | Shape ExpandDims::output_shape(
function Shape (line 2080) | Shape Flatten::output_shape(const array& input, int start_axis, int en...
function Shape (line 2134) | Shape Unflatten::output_shape(
function quantization_mode_to_string (line 3330) | std::string quantization_mode_to_string(QuantizationMode mode) {
function QuantizationMode (line 3344) | QuantizationMode string_to_quantization_mode(
function Shape (line 3810) | Shape Reshape::output_shape(const array& input, Shape shape) {
function Shape (line 5387) | Shape Squeeze::output_shape(const array& input, const std::vector<int>...
FILE: mlx/primitives.h
function class (line 49) | class MLX_API Primitive {
function eval_cpu (line 137) | inline void eval_cpu(
function eval_gpu (line 142) | inline void eval_gpu(
type class (line 155) | enum class
function DEFINE_VMAP (line 167) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function DEFINE_VMAP (line 308) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 380) | class ArgSort : public UnaryPrimitive {
type Op (line 450) | enum Op { And, Or, Xor, LeftShift, RightShift }
function explicit (line 452) | explicit BitwiseBinary(Stream stream, Op op)
function DEFINE_INPUT_OUTPUT_SHAPE (line 477) | bool is_equivalent(const Primitive& other) const override;
function DEFINE_NAME (line 559) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function DEFINE_VMAP (line 618) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 671) | class Concatenate : public UnaryPrimitive {
function class (line 692) | class Conjugate : public UnaryPrimitive {
function DEFINE_VMAP (line 823) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 884) | class Depends : public Primitive {
function DEFINE_VMAP (line 910) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 937) | class Select : public UnaryPrimitive {
function DEFINE_VMAP (line 1026) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 1100) | class Flatten : public UnaryPrimitive {
function DEFINE_VMAP (line 1130) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 1177) | class GatherAxis : public UnaryPrimitive {
function DEFINE_VMAP (line 1281) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 1313) | class Log : public UnaryPrimitive {
function DEFINE_VMAP (line 1408) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 1957) | class MaskedScatter : public UnaryPrimitive {
function DEFINE_VMAP (line 1976) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function state (line 2159) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 2221) | class Square : public UnaryPrimitive {
function DEFINE_VMAP (line 2284) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function DEFINE_VMAP (line 2337) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function DEFINE_VMAP (line 2377) | void eval_gpu(const std::vector<array>& inputs, array& out) override;
function class (line 2415) | class QRF : public Primitive {
FILE: mlx/random.cpp
type mlx::core::random (line 12) | namespace mlx::core::random {
function array (line 20) | array KeySequence::next() {
function seed (line 26) | void seed(uint64_t seed) {
function array (line 30) | array key(uint64_t seed) {
function array (line 36) | array bits(
function split (line 75) | std::pair<array, array> split(const array& key, StreamOrDevice s /* = ...
function array (line 81) | array split(const array& key, int num, StreamOrDevice s /* = {} */) {
function T (line 88) | T below_one() {
function array (line 95) | array uniform(
function array (line 143) | array uniform(
function array (line 152) | inline array complex_normal(
function array (line 174) | array normal(
function array (line 206) | array multivariate_normal(
function array (line 271) | array randint(
function array (line 286) | array bernoulli(
function array (line 311) | array bernoulli(
function array (line 318) | array bernoulli(
function array (line 324) | array truncated_normal(
function array (line 351) | array truncated_normal(
function array (line 361) | array gumbel(
function get_valid_axis (line 371) | int get_valid_axis(int axis, int ndim) {
function array (line 382) | array categorical_impl(
function array (line 395) | array categorical(
function array (line 418) | array categorical(
function array (line 432) | array categorical(
function array (line 443) | array laplace(
function array (line 477) | array permutation(
function array (line 485) | array permutation(
FILE: mlx/random.h
function namespace (line 13) | namespace mlx::core::random {
FILE: mlx/scheduler.cpp
type mlx::core (line 7) | namespace mlx::core {
function Stream (line 9) | Stream default_stream(Device d) {
function set_default_stream (line 17) | void set_default_stream(Stream s) {
function Stream (line 25) | Stream get_stream(int index) {
function get_streams (line 29) | std::vector<Stream> get_streams() {
function Stream (line 33) | Stream new_stream(Device d) {
function Stream (line 41) | Stream new_stream() {
function synchronize (line 45) | void synchronize(Stream s) {
function synchronize (line 56) | void synchronize() {
type scheduler (line 60) | namespace scheduler {
function Scheduler (line 63) | Scheduler& scheduler() {
FILE: mlx/scheduler.h
function namespace (line 16) | namespace mlx::core::scheduler {
function thread_fn (line 36) | void thread_fn() {
function class (line 67) | class Scheduler {
FILE: mlx/small_vector.h
function namespace (line 37) | namespace mlx::core {
function allocator_ (line 155) | allocator_(allocator) {
function T (line 233) | T* data() {
function T (line 236) | const T* data() const {
function iterator (line 240) | iterator begin() {
function iterator (line 247) | iterator end() {
function const (line 262) | auto rbegin() {
function const (line 269) | auto rend() {
function T (line 290) | const T& front() const {
function T (line 299) | const T& back() const {
function T (line 310) | const T& at(size_t index) const {
function T (line 318) | const T& operator[](size_t index) const {
function push_back (line 332) | void push_back(T x) {
function iterator (line 342) | iterator insert(iterator pos, T value) {
function iterator (line 346) | iterator insert(iterator pos, size_t count, T value) {
function reserve (line 431) | void reserve(size_t new_capacity) {
function clear (line 438) | void clear() {
function MLX_NOINLINE (line 473) | MLX_NOINLINE void free_storage() {
function reset_to_inline_storage (line 482) | void reset_to_inline_storage() {
function T (line 496) | T* inline_storage_begin() {
function T (line 499) | const T* inline_storage_begin() const {
type is_vector (line 526) | struct is_vector
FILE: mlx/stream.h
function namespace (line 10) | namespace mlx::core {
FILE: mlx/threadpool.h
function class (line 35) | class ThreadPool {
function ThreadPool (line 55) | inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
function task (line 64) | auto task = std::make_shared<std::packaged_task<return_type()>>(
function resize (line 82) | inline void ThreadPool::resize(size_t threads) {
function ThreadPool (line 93) | inline ThreadPool::~ThreadPool() {
function stop_and_wait (line 97) | inline void ThreadPool::stop_and_wait() {
FILE: mlx/transforms.cpp
type mlx::core (line 23) | namespace mlx::core {
class Synchronizer (line 29) | class Synchronizer : public Primitive {
method Synchronizer (line 31) | explicit Synchronizer(Stream stream) : Primitive(stream) {}
method eval_cpu (line 33) | void eval_cpu(const std::vector<array>&, std::vector<array>&) overri...
method eval_gpu (line 34) | void eval_gpu(const std::vector<array>&, std::vector<array>&) overri...
function array (line 52) | array eval_impl(std::vector<array> outputs, bool async) {
function async_eval (line 296) | void async_eval(std::vector<array> outputs) {
function eval (line 310) | void eval(std::vector<array> outputs) {
function vjp (line 327) | std::pair<std::vector<array>, std::vector<array>> vjp(
function vjp (line 506) | std::pair<std::vector<array>, std::vector<array>> vjp(
function vjp (line 515) | std::pair<array, array> vjp(
function jvp (line 526) | std::pair<std::vector<array>, std::vector<array>> jvp(
function jvp (line 641) | std::pair<array, array> jvp(
function ValueAndGradFn (line 652) | ValueAndGradFn value_and_grad(
type detail (line 692) | namespace detail {
function vmap_trace (line 694) | std::pair<std::vector<array>, std::vector<array>> vmap_trace(
function vmap_replace (line 754) | std::vector<array> vmap_replace(
function vmap (line 886) | std::function<std::vector<array>(const std::vector<array>&)> vmap(
function vmap (line 919) | std::function<array(const array&, const array&)> vmap(
function vmap (line 933) | std::function<array(const array&)> vmap(
function custom_function (line 946) | std::function<std::vector<array>(const std::vector<array>&)> custom_fu...
function custom_vjp (line 1043) | std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
function checkpoint (line 1052) | std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
FILE: mlx/transforms.h
function SimpleValueAndGradFn (line 102) | SimpleValueAndGradFn inline value_and_grad(
FILE: mlx/transforms_impl.h
function namespace (line 7) | namespace mlx::core::detail {
function retain_graph (line 51) | struct RetainGraph {
function in_tracing (line 69) | inline bool in_tracing() {
function in_dynamic_tracing (line 75) | inline bool in_dynamic_tracing() {
function in_grad_tracing (line 80) | inline bool in_grad_tracing() {
function retain_graph (line 84) | inline bool retain_graph() {
FILE: mlx/types/bf16.h
function namespace (line 13) | namespace mlx::core {
FILE: mlx/types/complex.h
function namespace (line 7) | namespace mlx::core {
FILE: mlx/types/fp16.h
function namespace (line 13) | namespace mlx::core {
FILE: mlx/types/half_types.h
function namespace (line 8) | namespace mlx::core {
function namespace (line 16) | namespace mlx::core {
function namespace (line 25) | namespace mlx::core {
function namespace (line 33) | namespace mlx::core {
FILE: mlx/types/limits.h
type numeric_limits (line 13) | struct numeric_limits
type numeric_limits (line 16) | struct numeric_limits
type numeric_limits (line 19) | struct numeric_limits
function float16_t (line 25) | constexpr static float16_t bits_to_half(uint16_t v) {
function bfloat16_t (line 45) | struct numeric_limits<bfloat16_t> {
FILE: mlx/utils.cpp
type mlx::core (line 12) | namespace mlx::core {
function Stream (line 14) | Stream to_stream(StreamOrDevice s) {
function Stream (line 24) | Stream to_stream(StreamOrDevice s, Device default_) {
function PrintFormatter (line 80) | PrintFormatter& get_global_formatter() {
function abort_with_exception (line 85) | void abort_with_exception(const std::exception& error) {
function Dtype (line 92) | Dtype result_type(const std::vector<array>& arrays) {
function Shape (line 100) | Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
function normalize_axis_index (line 133) | int normalize_axis_index(
function print_subarray (line 180) | void print_subarray(std::ostream& os, const array& a, size_t index, in...
function print_array (line 206) | void print_array(std::ostream& os, const array& a) {
type env (line 251) | namespace env {
function get_var (line 253) | int get_var(const char* name, int default_value) {
function get_var (line 261) | std::string get_var(const char* name, const char* default_value) {
function set_finfo_limits (line 272) | void set_finfo_limits(double& min, double& max, double& eps) {
function set_iinfo_limits (line 299) | void set_iinfo_limits(int64_t& min, uint64_t& max) {
FILE: mlx/utils.h
function namespace (line 14) | namespace mlx::core {
type PrintFormatter (line 41) | struct PrintFormatter {
function finfo (line 64) | struct MLX_API finfo {
function iinfo (line 73) | struct MLX_API iinfo {
function Dtype (line 81) | inline Dtype result_type(const array& a, const array& b) {
function Dtype (line 84) | inline Dtype result_type(const array& a, const array& b, const array& c) {
function is_power_of_2 (line 125) | inline bool is_power_of_2(int n) {
function next_power_of_2 (line 129) | inline int next_power_of_2(int n) {
function namespace (line 136) | namespace env {
FILE: mlx/version.cpp
type mlx::core (line 5) | namespace mlx::core {
FILE: mlx/version.h
function namespace (line 13) | namespace mlx::core {
FILE: python/mlx/__main__.py
function main (line 4) | def main() -> None:
FILE: python/mlx/_distributed_utils/common.py
class Host (line 13) | class Host:
class Hostfile (line 21) | class Hostfile:
method to_json (line 26) | def to_json(self):
method from_file (line 37) | def from_file(c
Condensed preview — 879 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (6,873K chars).
[
{
"path": ".clang-format",
"chars": 2552,
"preview": "---\nAccessModifierOffset: -1\nAlignAfterOpenBracket: AlwaysBreak\nAlignConsecutiveAssignments: false\nAlignConsecutiveDecla"
},
{
"path": ".github/ISSUE_TEMPLATE/bug_report.md",
"chars": 527,
"preview": "---\nname: Bug report\nabout: Create a report about an issue you've encountered\ntitle: \"[BUG] \"\nlabels: ''\nassignees: ''\n\n"
},
{
"path": ".github/actions/build-cuda-release/action.yml",
"chars": 768,
"preview": "name: 'Build CUDA wheel'\ndescription: 'Build CUDA wheel'\n\ninputs:\n arch:\n description: 'Platform architecture tag'\n "
},
{
"path": ".github/actions/build-docs/action.yml",
"chars": 947,
"preview": "name: 'Build Documentation'\ndescription: 'Build documentation'\n\nruns:\n using: \"composite\"\n steps:\n - name: Setup ma"
},
{
"path": ".github/actions/build-linux/action.yml",
"chars": 1314,
"preview": "name: 'Build and Test on Linux'\n\ninputs:\n toolkit:\n description: 'The toolkit to build with'\n required: false\n "
},
{
"path": ".github/actions/build-linux-release/action.yml",
"chars": 1049,
"preview": "name: 'Build Linux wheel'\ndescription: 'Build Linux wheel'\n\ninputs:\n build-backend:\n description: 'Build the backend"
},
{
"path": ".github/actions/build-macos/action.yml",
"chars": 2395,
"preview": "name: 'Build and Test on macOS'\ndescription: 'Build and test MLX on macOS'\n\nruns:\n using: \"composite\"\n steps:\n - na"
},
{
"path": ".github/actions/build-macos-release/action.yml",
"chars": 959,
"preview": "name: 'Build macOS release'\ndescription: 'Build MLX releases macOS'\n\ninputs:\n macos-target:\n description: 'macOS bui"
},
{
"path": ".github/actions/build-windows/action.yml",
"chars": 736,
"preview": "name: 'Build on Windows'\n\nruns:\n using: 'composite'\n steps:\n - name: Install Python package\n id: python-build\n"
},
{
"path": ".github/actions/setup-linux/action.yml",
"chars": 3360,
"preview": "name: 'Setup Linux Environment'\ndescription: 'Install dependencies for Linux builds'\n\ninputs:\n toolkit:\n description"
},
{
"path": ".github/actions/setup-macos/action.yml",
"chars": 602,
"preview": "name: 'Setup macOS Environment'\ndescription: 'Install dependencies for macOS builds'\n\ninputs:\n python-version:\n desc"
},
{
"path": ".github/actions/setup-windows/action.yml",
"chars": 1143,
"preview": "name: 'Setup Windows environment'\n\ninputs:\n python-version:\n description: 'Version of python to set up'\n required"
},
{
"path": ".github/actions/test-linux/action.yml",
"chars": 1806,
"preview": "name: 'Run Linux tests'\n\ninputs:\n has-gpu:\n description: 'Run GPU tests'\n required: false\n default: false\n\nrun"
},
{
"path": ".github/actions/test-windows/action.yml",
"chars": 478,
"preview": "name: 'Run tests on Windows'\n\nruns:\n using: 'composite'\n steps:\n - name: Run Python tests - CPU\n shell: bash\n "
},
{
"path": ".github/dependabot.yml",
"chars": 118,
"preview": "version: 2\nupdates:\n - package-ecosystem: \"github-actions\"\n directory: \"/\"\n schedule:\n interval: \"weekly\"\n"
},
{
"path": ".github/pull_request_template.md",
"chars": 571,
"preview": "## Proposed changes\n\nPlease include a description of the problem or feature this PR is addressing. If there is a corresp"
},
{
"path": ".github/scripts/build-sanitizer-tests.sh",
"chars": 1197,
"preview": "#!/bin/bash\nset -ex\n\nexport CMAKE_C_COMPILER=/usr/bin/clang\nexport CMAKE_CXX_COMPILER=/usr/bin/clang++\nBASE_CMAKE_ARGS=\""
},
{
"path": ".github/scripts/setup+build-cpp-linux-fedora-container.sh",
"chars": 554,
"preview": "#!/bin/bash\nset -ex\n\n# [Setup] Install dependencies inside the container.\ndnf update -y\ndnf install -y \\\n blas-devel \\\n"
},
{
"path": ".github/workflows/build_and_test.yml",
"chars": 4172,
"preview": "name: Build and Test\n\non:\n pull_request:\n push:\n branches:\n - main\n # For testing CI without starting a p"
},
{
"path": ".github/workflows/documentation.yml",
"chars": 529,
"preview": "name: Documentation\n\non:\n workflow_dispatch:\n\npermissions:\n contents: read\n\njobs:\n build:\n runs-on: ubuntu-22.04\n "
},
{
"path": ".github/workflows/nightly.yml",
"chars": 3002,
"preview": "name: Nightly Build\n\non:\n schedule:\n - cron: 33 6 * * 1-5\n workflow_dispatch:\n\npermissions:\n contents: read\n\njobs:"
},
{
"path": ".github/workflows/release.yml",
"chars": 7818,
"preview": "name: PyPI Release\n\non:\n push:\n tags:\n - 'v*'\n branches:\n - 'test-publish/*'\n workflow_dispatch:\n i"
},
{
"path": ".gitignore",
"chars": 757,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# tensor files\n*.safe\n*.safetensors\n\n# Metal "
},
{
"path": ".pre-commit-config.yaml",
"chars": 718,
"preview": "repos:\n- repo: https://github.com/pre-commit/pre-commit-hooks\n rev: v6.0.0\n hooks:\n - id: check-yaml\n # "
},
{
"path": "ACKNOWLEDGMENTS.md",
"chars": 14540,
"preview": "# Individual Contributors\n\nIf you wish to be acknowledged for your contributions, please list your name\nwith a short des"
},
{
"path": "CITATION.cff",
"chars": 583,
"preview": "cff-version: 1.2.0\ntitle: mlx\nmessage: >-\n If you use this software, please cite it using the\n metadata from this file"
},
{
"path": "CMakeLists.txt",
"chars": 14939,
"preview": "cmake_minimum_required(VERSION 3.25)\n\nif(NOT MLX_VERSION)\n file(STRINGS \"mlx/version.h\" _mlx_h_version REGEX \"^#define "
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 5544,
"preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
},
{
"path": "CONTRIBUTING.md",
"chars": 1284,
"preview": "# Contributing to MLX\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Reques"
},
{
"path": "LICENSE",
"chars": 1065,
"preview": "MIT License\n\nCopyright © 2023 Apple Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
},
{
"path": "MANIFEST.in",
"chars": 168,
"preview": "include CMakeLists.txt\ninclude mlx.pc.in\nrecursive-include mlx/ *\ninclude cmake/*\ninclude python/src/*\ninclude python/ml"
},
{
"path": "README.md",
"chars": 4597,
"preview": "# MLX\n\n[**Quickstart**](#quickstart) | [**Installation**](#installation) |\n[**Documentation**](https://ml-explore.github"
},
{
"path": "benchmarks/cpp/CMakeLists.txt",
"chars": 370,
"preview": "function(build_benchmark SRCFILE)\n get_filename_component(src_name ${SRCFILE} NAME_WE)\n set(target \"${src_name}\")\n ad"
},
{
"path": "benchmarks/cpp/autograd.cpp",
"chars": 891,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::c"
},
{
"path": "benchmarks/cpp/compare_devices.cpp",
"chars": 611,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <iostream>\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::co"
},
{
"path": "benchmarks/cpp/irregular_strides.cpp",
"chars": 5775,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <cstring>\n#include <iostream>\n#include <sstream>\n\n#include \"mlx/mlx.h\"\n#include"
},
{
"path": "benchmarks/cpp/single_ops.cpp",
"chars": 8345,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include \"mlx/mlx.h\"\n#include \"time_utils.h\"\n\nnamespace mx = mlx::core;\n\nvoid time_creat"
},
{
"path": "benchmarks/cpp/time_utils.h",
"chars": 1239,
"preview": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <chrono>\n#include <iomanip>\n#include <iostream>\n\n#include \"mlx/ml"
},
{
"path": "benchmarks/numpy/single_ops.py",
"chars": 829,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport numpy as np\nfrom time_utils import time_fn\n\n\ndef time_add():\n a = np.ones((100,"
},
{
"path": "benchmarks/numpy/time_utils.py",
"chars": 378,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport time\n\n\ndef time_fn(fn, *args):\n print(f\"Timing {fn.__name__} ...\", end=\" \")\n\n "
},
{
"path": "benchmarks/python/batch_matmul_bench.py",
"chars": 1374,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nB = 8\nT = 1024\nD ="
},
{
"path": "benchmarks/python/blas/bench_gemm.py",
"chars": 4677,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as m"
},
{
"path": "benchmarks/python/blas/bench_gemv.py",
"chars": 6331,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport os\nimport subprocess\nimport time\n\nimport matplotlib.pyplot as plt\nimport mlx.core "
},
{
"path": "benchmarks/python/comparative/README.md",
"chars": 639,
"preview": "Microbenchmarks comparing MLX to PyTorch\n========================================\n\nImplement the same microbenchmarks in"
},
{
"path": "benchmarks/python/comparative/bench_mlx.py",
"chars": 11914,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport time\nfrom functools import partial\n\nimport m"
},
{
"path": "benchmarks/python/comparative/bench_torch.py",
"chars": 10691,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport os\nimport time\n\nimport torch\nimport torch.cuda\nimport torch.mps\n\n\n"
},
{
"path": "benchmarks/python/comparative/compare.py",
"chars": 13380,
"preview": "# Copyright © 2023 Apple Inc.\n\n#!/usr/bin/env python\n\nimport argparse\nimport re\nfrom pathlib import Path\nfrom subprocess"
},
{
"path": "benchmarks/python/compile_bench.py",
"chars": 2682,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport math\nimport random\n\nimport mlx.core as mx\nfrom time_utils imp"
},
{
"path": "benchmarks/python/conv1d_bench.py",
"chars": 3653,
"preview": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport tor"
},
{
"path": "benchmarks/python/conv2d_bench_cpu.py",
"chars": 4227,
"preview": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_benc"
},
{
"path": "benchmarks/python/conv2d_train_bench_cpu.py",
"chars": 4139,
"preview": "import time\n\nimport mlx.core as mx\nimport mlx.nn\nimport mlx.optimizers as opt\nimport torch\n\n\ndef bench_mlx(steps: int = "
},
{
"path": "benchmarks/python/conv2d_transpose_bench_cpu.py",
"chars": 4245,
"preview": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_benc"
},
{
"path": "benchmarks/python/conv3d_bench.py",
"chars": 5639,
"preview": "import math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 2\nN_iter_bench = 10\nN_iter_fu"
},
{
"path": "benchmarks/python/conv3d_bench_cpu.py",
"chars": 3380,
"preview": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_benc"
},
{
"path": "benchmarks/python/conv3d_train_bench_cpu.py",
"chars": 4172,
"preview": "import time\n\nimport mlx.core as mx\nimport mlx.nn\nimport mlx.optimizers as opt\nimport torch\n\n\ndef bench_mlx(steps: int = "
},
{
"path": "benchmarks/python/conv3d_transpose_bench_cpu.py",
"chars": 3506,
"preview": "import argparse\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 1\nN_iter_benc"
},
{
"path": "benchmarks/python/conv_bench.py",
"chars": 4449,
"preview": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport tor"
},
{
"path": "benchmarks/python/conv_transpose_bench.py",
"chars": 4334,
"preview": "import argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport tor"
},
{
"path": "benchmarks/python/conv_unaligned_bench.py",
"chars": 3017,
"preview": "import math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_warmup = 10\nN_iter_bench = 100\nN_iter_"
},
{
"path": "benchmarks/python/distributed_bench.py",
"chars": 1310,
"preview": "# Copyright © 2024 Apple Inc.\n\n\"\"\"\nRun with:\n mpirun -n 2 python /path/to/distributed_bench.py\n\"\"\"\n\nimport time\n\nimpo"
},
{
"path": "benchmarks/python/einsum_bench.py",
"chars": 2508,
"preview": "# Copyright © 2024 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\nimport numpy as np\n\n\ndef timeit(fn, its=100, args=[]):"
},
{
"path": "benchmarks/python/fft_bench.py",
"chars": 3502,
"preview": "# Copyright © 2024 Apple Inc.\n\nimport matplotlib\nimport mlx.core as mx\nimport numpy as np\nimport sympy\nimport torch\nfrom"
},
{
"path": "benchmarks/python/gather_bench.py",
"chars": 1531,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_r"
},
{
"path": "benchmarks/python/gather_mm_bench.py",
"chars": 2080,
"preview": "# Copyright © 2025 Apple Inc.\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nN = 1024\nD = 1024\nM = 1024\nE = 32\nI"
},
{
"path": "benchmarks/python/gather_qmm_bench.py",
"chars": 2390,
"preview": "# Copyright © 2025 Apple Inc.\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nN = 1024\nD = 1024\nM = 1024\nE = 32\nI"
},
{
"path": "benchmarks/python/hadamard_bench.py",
"chars": 1916,
"preview": "import argparse\n\nimport matplotlib\nimport mlx.core as mx\nimport numpy as np\nfrom time_utils import measure_runtime\n\nmatp"
},
{
"path": "benchmarks/python/large_gemm_bench.py",
"chars": 3196,
"preview": "# Copyright © 2026 Apple Inc.\n\nimport math\nimport time\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\n\nN_WARMUP "
},
{
"path": "benchmarks/python/layer_norm_bench.py",
"chars": 2519,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_u"
},
{
"path": "benchmarks/python/masked_scatter.py",
"chars": 6973,
"preview": "import math\nimport os\nimport platform\nimport subprocess\nimport time\nfrom copy import copy\nfrom functools import partial\n"
},
{
"path": "benchmarks/python/rms_norm_bench.py",
"chars": 1812,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_utils import time_fn\n\n\ndef rms_n"
},
{
"path": "benchmarks/python/rope_bench.py",
"chars": 636,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom time_utils import time_fn\n\n\ndef time_"
},
{
"path": "benchmarks/python/scatter_bench.py",
"chars": 2677,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_r"
},
{
"path": "benchmarks/python/sdpa_bench.py",
"chars": 6954,
"preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport math\nimport os\nimport subprocess\nimport time\n\nimport mlx.core as m"
},
{
"path": "benchmarks/python/sdpa_vector_bench.py",
"chars": 2658,
"preview": "import argparse\nimport math\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\nL = 16384\nH = 32\nH_k = H // 4\nD = 128"
},
{
"path": "benchmarks/python/segmented_mm_bench.py",
"chars": 6141,
"preview": "# Copyright © 2026 Apple Inc.\n\nimport argparse\nimport time\n\nimport mlx.core as mx\nimport numpy as np\n\nMLX_DTYPES = {\n "
},
{
"path": "benchmarks/python/single_ops.py",
"chars": 2686,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nfrom time_utils import time_fn\n\n\ndef time_add():\n "
},
{
"path": "benchmarks/python/slice_update_bench.py",
"chars": 3585,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport torch\nfrom time_utils import measure_r"
},
{
"path": "benchmarks/python/synchronize_bench.py",
"chars": 1075,
"preview": "import time\n\nimport mlx.core as mx\n\nrank = mx.distributed.init().rank()\n\n\ndef timeit(fn, a):\n\n # warmup\n for _ in "
},
{
"path": "benchmarks/python/time_utils.py",
"chars": 801,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\n\ndef time_fn(fn, *args, **kwargs):\n msg = kwa"
},
{
"path": "cmake/FindCUDNN.cmake",
"chars": 5928,
"preview": "# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Permission is hereby granted, free of charge, to any "
},
{
"path": "cmake/FindNCCL.cmake",
"chars": 1712,
"preview": "# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include\n# directories.\n\nset(NCCL_ROOT_DIR\n $ENV{NC"
},
{
"path": "cmake/Findnvpl.cmake",
"chars": 211,
"preview": "# This file does nothing but to suppress the cmake warning: \"By not providing\n# Findnvpl.cmake in CMAKE_MODULE_PATH...\","
},
{
"path": "cmake/extension.cmake",
"chars": 1868,
"preview": "include(CMakeParseArguments)\n\n# clang format off\n#\n# ###################################################################"
},
{
"path": "docs/.clang-format",
"chars": 40,
"preview": "DisableFormat: true\nSortIncludes: Never\n"
},
{
"path": "docs/.gitignore",
"chars": 92,
"preview": "src/python/_autosummary*/\nsrc/python/nn/_autosummary*/\nsrc/python/optimizers/_autosummary*/\n"
},
{
"path": "docs/.nojekyll",
"chars": 0,
"preview": ""
},
{
"path": "docs/Doxyfile",
"chars": 2021,
"preview": "################################################################################\n# Primary project setup. "
},
{
"path": "docs/Makefile",
"chars": 580,
"preview": "# Minimal makefile for Sphinx documentation\n\n# You can set these variables from the command line.\nSPHINXOPTS =\nSPHINX"
},
{
"path": "docs/README.md",
"chars": 823,
"preview": "## Build the Docs\n\n### Setup (do once)\n\nInstall Doxygen:\n\n```\nbrew install doxygen\n```\n\nInstall Python packages:\n\n```\npi"
},
{
"path": "docs/index.html",
"chars": 71,
"preview": "<meta http-equiv=\"refresh\" content=\"0; url=./build/html/index.html\" />\n"
},
{
"path": "docs/requirements.txt",
"chars": 55,
"preview": "sphinx\nbreathe\nsphinx-book-theme\nsphinx-copybutton\nmlx\n"
},
{
"path": "docs/src/_templates/module-base-class.rst",
"chars": 693,
"preview": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. add toctree option to make autodoc generate the "
},
{
"path": "docs/src/_templates/nn-module-template.rst",
"chars": 398,
"preview": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n\n {% block methods %"
},
{
"path": "docs/src/_templates/optimizers-template.rst",
"chars": 375,
"preview": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n\n {% block methods %"
},
{
"path": "docs/src/conf.py",
"chars": 2510,
"preview": "# Copyright © 2023 Apple Inc.\n\n# -*- coding: utf-8 -*-\n\nimport os\nimport subprocess\n\nimport mlx.core as mx\n\n# -- Project"
},
{
"path": "docs/src/cpp/ops.rst",
"chars": 77,
"preview": ".. _cpp_ops:\n\nOperations\n==========\n\n.. doxygengroup:: ops\n :content-only:\n"
},
{
"path": "docs/src/dev/custom_metal_kernels.rst",
"chars": 15210,
"preview": ".. _custom_metal_kernels:\n\nCustom Metal Kernels\n====================\n\nMLX supports writing custom Metal kernels through "
},
{
"path": "docs/src/dev/extensions.rst",
"chars": 28102,
"preview": "Custom Extensions in MLX\n========================\n\nYou can extend MLX with custom operations on the CPU or GPU. This gui"
},
{
"path": "docs/src/dev/metal_debugger.rst",
"chars": 1873,
"preview": "Metal Debugger\n==============\n\n.. currentmodule:: mlx.core\n\nProfiling is a key step for performance optimization. You ca"
},
{
"path": "docs/src/dev/metal_logging.rst",
"chars": 1090,
"preview": "Metal Logging\n=============\n\nIn debug builds, MLX compiles Metal kernels with ``os_log`` enabled so shader\nwarnings and "
},
{
"path": "docs/src/dev/mlx_in_cpp.rst",
"chars": 2692,
"preview": ".. _mlx_in_cpp:\n\nUsing MLX in C++\n================\n\nYou can use MLX in a C++ project with CMake.\n\n.. note::\n\n This guid"
},
{
"path": "docs/src/examples/data_parallelism.rst",
"chars": 2545,
"preview": ".. _data_parallelism:\n\nData Parallelism\n================\n\nMLX enables efficient data parallel distributed training throu"
},
{
"path": "docs/src/examples/linear_regression.rst",
"chars": 2085,
"preview": ".. _linear_regression:\n\nLinear Regression\n-----------------\n\nLet's implement a basic linear regression model as a starti"
},
{
"path": "docs/src/examples/llama-inference.rst",
"chars": 19397,
"preview": "LLM inference\n==============\n\nMLX enables efficient inference of large-ish transformers on Apple silicon\nwithout comprom"
},
{
"path": "docs/src/examples/mlp.rst",
"chars": 4136,
"preview": ".. _mlp:\n\nMulti-Layer Perceptron\n----------------------\n\nIn this example we'll learn to use ``mlx.nn`` by implementing a"
},
{
"path": "docs/src/examples/tensor_parallelism.rst",
"chars": 10394,
"preview": ".. _tensor_parallelism:\n\nTensor Parallelism\n==================\n\nIn this example, we will explore how tensor parallelism "
},
{
"path": "docs/src/index.rst",
"chars": 2344,
"preview": "MLX\n===\n\nMLX is a NumPy-like array framework designed for efficient and flexible machine\nlearning on Apple silicon, brou"
},
{
"path": "docs/src/install.rst",
"chars": 8387,
"preview": ".. _build_and_install:\n\nBuild and Install\n=================\n\nPython Installation\n-------------------\n\nMLX is available o"
},
{
"path": "docs/src/python/array.rst",
"chars": 980,
"preview": ".. _array:\n\nArray\n=====\n\n.. currentmodule:: mlx.core\n\n.. autosummary:: \n :toctree: _autosummary \n\n array\n array."
},
{
"path": "docs/src/python/cuda.rst",
"chars": 104,
"preview": "CUDA\n=====\n\n.. currentmodule:: mlx.core.cuda\n\n.. autosummary::\n :toctree: _autosummary\n\n is_available\n"
},
{
"path": "docs/src/python/data_types.rst",
"chars": 1617,
"preview": ".. _data_types:\n\nData Types\n==========\n\n.. currentmodule:: mlx.core\n\nThe default floating point type is ``float32`` and "
},
{
"path": "docs/src/python/devices_and_streams.rst",
"chars": 309,
"preview": ".. _devices_and_streams:\n\nDevices and Streams\n===================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n :toct"
},
{
"path": "docs/src/python/distributed.rst",
"chars": 426,
"preview": ".. _distributed:\n\n.. currentmodule:: mlx.core.distributed\n\nDistributed Communication\n==========================\n\nMLX pro"
},
{
"path": "docs/src/python/export.rst",
"chars": 187,
"preview": ".. _export:\n\nExport Functions\n================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n :toctree: _autosummary\n\n"
},
{
"path": "docs/src/python/fast.rst",
"chars": 191,
"preview": ".. _fast:\n\nFast\n====\n\n.. currentmodule:: mlx.core.fast\n\n.. autosummary:: \n :toctree: _autosummary\n\n rms_norm\n layer_n"
},
{
"path": "docs/src/python/fft.rst",
"chars": 211,
"preview": ".. _fft:\n\nFFT\n===\n\n.. currentmodule:: mlx.core.fft\n\n.. autosummary:: \n :toctree: _autosummary\n\n fft\n ifft\n fft2\n if"
},
{
"path": "docs/src/python/linalg.rst",
"chars": 311,
"preview": ".. _linalg:\n\nLinear Algebra\n==============\n\n.. currentmodule:: mlx.core.linalg\n\n.. autosummary::\n :toctree: _autosumma"
},
{
"path": "docs/src/python/memory_management.rst",
"chars": 255,
"preview": "Memory Management\n=================\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n :toctree: _autosummary\n\n get_activ"
},
{
"path": "docs/src/python/metal.rst",
"chars": 151,
"preview": "Metal\n=====\n\n.. currentmodule:: mlx.core.metal\n\n.. autosummary::\n :toctree: _autosummary\n\n is_available\n device_info\n"
},
{
"path": "docs/src/python/nn/distributed.rst",
"chars": 577,
"preview": ".. _nn_distributed:\n\nDistributed\n-----------\n\nHelper Routines\n^^^^^^^^^^^^^^^\n\nThe :code:`mlx.nn.layers.distributed` pac"
},
{
"path": "docs/src/python/nn/functions.rst",
"chars": 536,
"preview": ".. _nn_functions:\n\n.. currentmodule:: mlx.nn\n\nFunctions\n---------\n\nLayers without parameters (e.g. activation functions)"
},
{
"path": "docs/src/python/nn/init.rst",
"chars": 929,
"preview": ".. _init:\n\n.. currentmodule:: mlx.nn.init\n\nInitializers\n------------\n\nThe ``mlx.nn.init`` package contains commonly used"
},
{
"path": "docs/src/python/nn/layers.rst",
"chars": 964,
"preview": ".. _layers:\n\n.. currentmodule:: mlx.nn\n\nLayers\n------\n\n.. autosummary::\n :toctree: _autosummary\n :template: nn-modul"
},
{
"path": "docs/src/python/nn/losses.rst",
"chars": 408,
"preview": ".. _losses:\n\n.. currentmodule:: mlx.nn.losses\n\nLoss Functions\n--------------\n\n.. autosummary::\n :toctree: _autosummary"
},
{
"path": "docs/src/python/nn/module.rst",
"chars": 700,
"preview": "Module\n======\n\n.. currentmodule:: mlx.nn\n\n.. autoclass:: Module\n\n .. rubric:: Attributes\n\n .. autosummary::\n :t"
},
{
"path": "docs/src/python/nn.rst",
"chars": 5869,
"preview": ".. _nn:\n\n.. currentmodule:: mlx.nn\n\nNeural Networks\n===============\n\nWriting arbitrarily complex neural networks in MLX "
},
{
"path": "docs/src/python/ops.rst",
"chars": 2069,
"preview": ".. _ops:\n\nOperations\n==========\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n :toctree: _autosummary\n\n abs\n add\n "
},
{
"path": "docs/src/python/optimizers/common_optimizers.rst",
"chars": 293,
"preview": ".. _common_optimizers:\n\nCommon Optimizers\n=================\n\n.. currentmodule:: mlx.optimizers\n\n.. autosummary::\n :toc"
},
{
"path": "docs/src/python/optimizers/optimizer.rst",
"chars": 340,
"preview": "Optimizer\n=========\n\n.. currentmodule:: mlx.optimizers\n\n.. autoclass:: Optimizer \n\n\n .. rubric:: Attributes\n\n .. aut"
},
{
"path": "docs/src/python/optimizers/schedulers.rst",
"chars": 219,
"preview": ".. _schedulers:\n\nSchedulers\n==========\n\n.. currentmodule:: mlx.optimizers\n\n.. autosummary::\n :toctree: _autosummary\n\n "
},
{
"path": "docs/src/python/optimizers.rst",
"chars": 2467,
"preview": ".. _optimizers:\n\n.. currentmodule:: mlx.optimizers\n\nOptimizers\n==========\n\nThe optimizers in MLX can be used both with :"
},
{
"path": "docs/src/python/random.rst",
"chars": 1064,
"preview": ".. _random:\n\nRandom\n======\n\nRandom sampling functions in MLX use an implicit global PRNG state by default.\nHowever, all "
},
{
"path": "docs/src/python/transforms.rst",
"chars": 263,
"preview": ".. _transforms:\n\nTransforms\n==========\n\n.. currentmodule:: mlx.core\n\n.. autosummary::\n :toctree: _autosummary\n\n eval\n"
},
{
"path": "docs/src/python/tree_utils.rst",
"chars": 579,
"preview": ".. _utils:\n\nTree Utils\n==========\n\nIn MLX we consider a python tree to be an arbitrarily nested collection of\ndictionari"
},
{
"path": "docs/src/usage/compile.rst",
"chars": 13358,
"preview": ".. _compile:\n\nCompilation\n===========\n\n.. currentmodule:: mlx.core\n\nMLX has a :func:`compile` function transformation wh"
},
{
"path": "docs/src/usage/distributed.rst",
"chars": 20583,
"preview": ".. _usage_distributed:\n\nDistributed Communication\n=========================\n\n.. currentmodule:: mlx.core.distributed\n\nML"
},
{
"path": "docs/src/usage/export.rst",
"chars": 9054,
"preview": ".. _export_usage:\n\nExporting Functions\n===================\n\n.. currentmodule:: mlx.core\n\nMLX has an API to export and im"
},
{
"path": "docs/src/usage/function_transforms.rst",
"chars": 6001,
"preview": ".. _function_transforms:\n\nFunction Transforms\n===================\n\n.. currentmodule:: mlx.core\n\nMLX uses composable func"
},
{
"path": "docs/src/usage/indexing.rst",
"chars": 5333,
"preview": ".. _indexing:\n\nIndexing Arrays\n===============\n\n.. currentmodule:: mlx.core\n\nFor the most part, indexing an MLX :obj:`ar"
},
{
"path": "docs/src/usage/launching_distributed.rst",
"chars": 8235,
"preview": ":orphan:\n\n.. _usage_launch_distributed:\n\nLaunching Distributed Programs\n==============================\n\n.. currentmodule"
},
{
"path": "docs/src/usage/lazy_evaluation.rst",
"chars": 4720,
"preview": ".. _lazy eval:\n\nLazy Evaluation\n===============\n\n.. currentmodule:: mlx.core\n\nWhy Lazy Evaluation\n-------------------\n\nW"
},
{
"path": "docs/src/usage/numpy.rst",
"chars": 3252,
"preview": ".. _numpy:\n\nConversion to NumPy and Other Frameworks\n========================================\n\nMLX array supports conver"
},
{
"path": "docs/src/usage/quick_start.rst",
"chars": 1806,
"preview": "Quick Start Guide\n=================\n\n\nBasics\n------\n\n.. currentmodule:: mlx.core\n\nImport ``mlx.core`` and make an :class"
},
{
"path": "docs/src/usage/saving_and_loading.rst",
"chars": 2096,
"preview": ".. _saving_and_loading:\n\nSaving and Loading Arrays\n=========================\n\n.. currentmodule:: mlx.core\n\nMLX supports "
},
{
"path": "docs/src/usage/unified_memory.rst",
"chars": 2609,
"preview": ".. _unified_memory:\n\nUnified Memory\n==============\n\n.. currentmodule:: mlx.core\n\nApple silicon has a unified memory arch"
},
{
"path": "docs/src/usage/using_streams.rst",
"chars": 651,
"preview": ".. _using_streams:\n\nUsing Streams\n=============\n\n.. currentmodule:: mlx.core\n\nSpecifying the :obj:`Stream`\n~~~~~~~~~~~~~"
},
{
"path": "examples/cmake_project/CMakeLists.txt",
"chars": 594,
"preview": "cmake_minimum_required(VERSION 3.27)\n\nproject(example LANGUAGES CXX)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_"
},
{
"path": "examples/cmake_project/README.md",
"chars": 286,
"preview": "## Build and Run \n\nInstall MLX with Python:\n\n```bash\npip install mlx>=0.22\n```\n\nBuild the C++ example:\n\n```bash\ncmake -B"
},
{
"path": "examples/cmake_project/example.cpp",
"chars": 230,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n aut"
},
{
"path": "examples/cpp/CMakeLists.txt",
"chars": 396,
"preview": "function(build_example SRCFILE)\n get_filename_component(src_name ${SRCFILE} NAME_WE)\n set(target \"${src_name}\")\n add_"
},
{
"path": "examples/cpp/distributed.cpp",
"chars": 498,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n\nint main() {\n if "
},
{
"path": "examples/cpp/linear_regression.cpp",
"chars": 1384,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <chrono>\n#include <cmath>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"t"
},
{
"path": "examples/cpp/logistic_regression.cpp",
"chars": 1342,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <chrono>\n#include <cmath>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n#include \"t"
},
{
"path": "examples/cpp/metal_capture.cpp",
"chars": 924,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include <cassert>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n"
},
{
"path": "examples/cpp/timer.h",
"chars": 331,
"preview": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <chrono>\n\nnamespace timer {\n\nusing namespace std::chrono;\n\ntempla"
},
{
"path": "examples/cpp/tutorial.cpp",
"chars": 2758,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <cassert>\n#include <iostream>\n\n#include \"mlx/mlx.h\"\n\nnamespace mx = mlx::core;\n"
},
{
"path": "examples/export/CMakeLists.txt",
"chars": 555,
"preview": "cmake_minimum_required(VERSION 3.27)\n\nproject(import_mlx LANGUAGES CXX)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDA"
},
{
"path": "examples/export/README.md",
"chars": 685,
"preview": "## Setup\n\nInstall MLX:\n\n```bash\npip install mlx>=0.22\n```\n\nBuild the C++ examples:\n\n```bash\ncmake -B build -DCMAKE_BUILD"
},
{
"path": "examples/export/eval_mlp.cpp",
"chars": 470,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include <mlx/mlx.h>\n#include <iostream>\n\nnamespace mx = mlx::core;\n\nint main() {\n int "
},
{
"path": "examples/export/eval_mlp.py",
"chars": 1363,
"preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.utils\n\n\nclass MLP(nn.Module):\n \"\""
},
{
"path": "examples/export/train_mlp.cpp",
"chars": 828,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include <mlx/mlx.h>\n#include <iostream>\n\nnamespace mx = mlx::core;\n\nint main() {\n int "
},
{
"path": "examples/export/train_mlp.py",
"chars": 2445,
"preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.optimizers as optim\nimport mlx.utils"
},
{
"path": "examples/extensions/CMakeLists.txt",
"chars": 1931,
"preview": "cmake_minimum_required(VERSION 3.27)\n\nproject(_ext LANGUAGES CXX)\n\n# ----------------------------- Setup ---------------"
},
{
"path": "examples/extensions/README.md",
"chars": 256,
"preview": "\n## Build\n\n```\npip install -e .\n```\n\nFor faster builds during development, you can also pre-install the requirements:\n\n`"
},
{
"path": "examples/extensions/axpby/axpby.cpp",
"chars": 10474,
"preview": "// Copyright © 2023-2025 Apple Inc.\n\n#include <dlfcn.h>\n#include <iostream>\n#include <sstream>\n\n#include \"mlx/backend/co"
},
{
"path": "examples/extensions/axpby/axpby.h",
"chars": 2720,
"preview": "// Copyright © 2023-2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/ops.h\"\n#include \"mlx/primitives.h\"\n\nnamespace mx = mlx:"
},
{
"path": "examples/extensions/axpby/axpby.metal",
"chars": 1696,
"preview": "// Copyright © 2023-2025 Apple Inc.\n\n#include <metal_stdlib>\n\n#include \"mlx/backend/metal/kernels/utils.h\"\n\ntemplate <ty"
},
{
"path": "examples/extensions/bindings.cpp",
"chars": 879,
"preview": "// Copyright © 2023-2024 Apple Inc.\n\n#include <nanobind/nanobind.h>\n#include <nanobind/stl/variant.h>\n\n#include \"axpby/a"
},
{
"path": "examples/extensions/mlx_sample_extensions/__init__.py",
"chars": 78,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport mlx.core as mx\n\nfrom ._ext import axpby\n"
},
{
"path": "examples/extensions/pyproject.toml",
"chars": 146,
"preview": "[build-system]\nrequires = [\n \"setuptools>=42\",\n \"cmake>=3.25\",\n \"mlx>=0.18.0\",\n \"nanobind==2.10.2\",\n]\nbuild-backend "
},
{
"path": "examples/extensions/requirements.txt",
"chars": 56,
"preview": "setuptools>=42\ncmake>=3.25\nmlx>=0.21.0\nnanobind==2.10.2\n"
},
{
"path": "examples/extensions/setup.py",
"chars": 591,
"preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom setuptools import setup\n\nfrom mlx import extension\n\nif __name__ == \"__main__\":\n"
},
{
"path": "examples/extensions/test.py",
"chars": 370,
"preview": "import mlx.core as mx\nfrom mlx_sample_extensions import axpby\n\na = mx.ones((3, 4))\nb = mx.ones((3, 4))\nc_cpu = axpby(a, "
},
{
"path": "examples/python/linear_regression.py",
"chars": 914,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\nnum_features = 100\nnum_examples = 1_000\nnum_iters = 1"
},
{
"path": "examples/python/logistic_regression.py",
"chars": 864,
"preview": "# Copyright © 2023 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\nnum_features = 100\nnum_examples = 1_000\nnum_iters = 1"
},
{
"path": "examples/python/qqmm.py",
"chars": 3830,
"preview": "from itertools import product\n\nimport mlx.core as mx\n\n\n# In mxfp8 mode, the results do not match exactly:\n# fewer than 1"
},
{
"path": "mlx/3rdparty/.clang-format",
"chars": 40,
"preview": "DisableFormat: true\nSortIncludes: Never\n"
},
{
"path": "mlx/3rdparty/pocketfft.h",
"chars": 110508,
"preview": "/*\nThis file is part of pocketfft.\n\nCopyright (C) 2010-2022 Max-Planck-Society\nCopyright (C) 2019-2020 Peter Bell\n\nFor t"
},
{
"path": "mlx/CMakeLists.txt",
"chars": 4113,
"preview": "target_sources(\n mlx\n PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp\n ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp\n"
},
{
"path": "mlx/allocator.h",
"chars": 1889,
"preview": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include <cstdlib>\n\n#include \"mlx/api.h\"\n\nnamespace mlx::core::allocator {"
},
{
"path": "mlx/api.h",
"chars": 624,
"preview": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n// MLX_API macro for controlling symbol visibility, must add for public AP"
},
{
"path": "mlx/array.cpp",
"chars": 9971,
"preview": "// Copyright © 2023-2024 Apple Inc.\n#include <functional>\n#include <unordered_map>\n\n#include \"mlx/array.h\"\n#include \"mlx"
},
{
"path": "mlx/array.h",
"chars": 17612,
"preview": "// Copyright © 2023 Apple Inc.\n#pragma once\n\n#include <algorithm>\n#include <cstdint>\n#include <functional>\n#include <mem"
},
{
"path": "mlx/backend/common/CMakeLists.txt",
"chars": 372,
"preview": "target_sources(\n mlx\n PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp\n ${CMAKE_CURRENT_SOURCE_DIR}/compi"
},
{
"path": "mlx/backend/common/binary.h",
"chars": 2626,
"preview": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/allocator.h\"\n#include \"mlx/array.h\"\n#include \"mlx/backend/co"
},
{
"path": "mlx/backend/common/broadcasting.cpp",
"chars": 619,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nvoid broadcast(const array"
},
{
"path": "mlx/backend/common/broadcasting.h",
"chars": 164,
"preview": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nvoid broadcast(const array&"
},
{
"path": "mlx/backend/common/buffer_cache.h",
"chars": 3737,
"preview": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include <algorithm>\n#include <cassert>\n#include <functional>\n#include <ma"
},
{
"path": "mlx/backend/common/common.cpp",
"chars": 9490,
"preview": "// Copyright © 2024 Apple Inc.\n#include <cassert>\n\n#include \"mlx/backend/common/broadcasting.h\"\n#include \"mlx/backend/co"
},
{
"path": "mlx/backend/common/compiled.cpp",
"chars": 6698,
"preview": "// Copyright © 2023-2024 Apple Inc.\n\n#include \"mlx/backend/common/compiled.h\"\n#include \"mlx/backend/common/utils.h\"\n#inc"
},
{
"path": "mlx/backend/common/compiled.h",
"chars": 2285,
"preview": "// Copyright © 2023-2024 Apple Inc.\n#pragma once\n\n#include <functional>\n#include <iomanip>\n\n#include \"mlx/array.h\"\n#incl"
},
{
"path": "mlx/backend/common/copy.h",
"chars": 1246,
"preview": "// Copyright © 2023-2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nenum cl"
},
{
"path": "mlx/backend/common/hadamard.h",
"chars": 2377,
"preview": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include <map>\n\n#include \"mlx/utils.h\"\n\nnamespace mlx::core {\n\n// From htt"
},
{
"path": "mlx/backend/common/load.cpp",
"chars": 1562,
"preview": "// Copyright © 2023 Apple Inc.\n\n#include <algorithm>\n#include <utility>\n\n#include \"mlx/primitives.h\"\n#include \"mlx/sched"
},
{
"path": "mlx/backend/common/matmul.h",
"chars": 1929,
"preview": "// Copyright © 2025 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n#include \"mlx/utils.h\"\n\n#include <ss"
},
{
"path": "mlx/backend/common/quantized.h",
"chars": 406,
"preview": "// Copyright © 2026 Apple Inc.\n\nnamespace mlx::core {\n\ninline constexpr short get_pack_factor(int bits, int wsize = 8) {"
},
{
"path": "mlx/backend/common/reduce.cpp",
"chars": 4938,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/reduce.h\"\n\nnamespace mlx::core {\n\nstd::pair<Shape, Strides>"
},
{
"path": "mlx/backend/common/reduce.h",
"chars": 1775,
"preview": "// Copyright © 2023 Apple Inc.\n\n#pragma once\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nenum Reducti"
},
{
"path": "mlx/backend/common/slicing.cpp",
"chars": 1972,
"preview": "// Copyright © 2024 Apple Inc.\n\n#include \"mlx/backend/common/utils.h\"\n\nnamespace mlx::core {\n\nstd::tuple<int64_t, Stride"
},
{
"path": "mlx/backend/common/slicing.h",
"chars": 352,
"preview": "// Copyright © 2024 Apple Inc.\n\n#pragma once\n\n#include \"mlx/array.h\"\n\nnamespace mlx::core {\n\nstd::tuple<int64_t, Strides"
}
]
// ... and 679 more files (download for full content)
About this extraction
This page contains the full source code of the ml-explore/mlx GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 879 files (6.3 MB), approximately 1.7M tokens, and a symbol index with 4174 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.